From 9b2f90018cc0f579258066cea69bd84589614af1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edvard=20Sch=C3=B8yen?= <99178202+ecschoye@users.noreply.github.com> Date: Sun, 15 Feb 2026 14:53:56 -0500 Subject: [PATCH] feat: add screenshot and image_info vision tools * feat: add screenshot and image_info vision tools Add two new tools for visual capabilities: - `screenshot`: captures screen using platform-native commands (screencapture on macOS, gnome-screenshot/scrot/import on Linux), returns file path + base64-encoded PNG data - `image_info`: reads image metadata (format, dimensions, size) from header bytes without external deps, optionally returns base64 data for future multimodal provider support Both tools are registered in the tool registry and agent system prompt. Includes 24 inline tests covering format detection, dimension extraction, schema validation, and execution edge cases. Co-Authored-By: Claude Opus 4.6 * fix: resolve unused variable warning after rebase Prefix unused `resolved_key` with underscore to suppress compiler warning introduced by upstream changes. Update Cargo.lock. Co-Authored-By: Claude Opus 4.6 * fix: address review comments on vision tools Security fixes: - Fix JPEG parser infinite loop on malformed zero-length segments - Add workspace path restriction to ImageInfoTool (prevents arbitrary file exfiltration via include_base64) - Quote paths in Linux screenshot shell commands to prevent injection - Add autonomy-level check in ScreenshotTool::execute Robustness: - Add file size guard in read_and_encode before loading into memory - Wire resolve_api_key through all provider match arms (was dead code) - Gate screenshot_command_exists test on macOS/Linux only - Infer MIME type from file extension instead of hardcoding image/png Tests: - Add JPEG dimension extraction test - Add JPEG malformed zero-length segment test Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: argenis de la rosa --- Cargo.lock | 1 + Cargo.toml | 3 + src/agent/loop_.rs | 8 + src/providers/mod.rs | 51 +++-- src/tools/image_info.rs | 491 ++++++++++++++++++++++++++++++++++++++++ src/tools/mod.rs | 8 + src/tools/screenshot.rs | 300 ++++++++++++++++++++++++ 7 files changed, 837 insertions(+), 25 deletions(-) create mode 100644 src/tools/image_info.rs create mode 100644 src/tools/screenshot.rs diff --git a/Cargo.lock b/Cargo.lock index 614cbb6..f39c66f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3191,6 +3191,7 @@ dependencies = [ "anyhow", "async-trait", "axum", + "base64", "chacha20poly1305", "chrono", "clap", diff --git a/Cargo.toml b/Cargo.toml index 45dfcaf..6ead2f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,9 @@ tracing-subscriber = { version = "0.3", default-features = false, features = ["f # Observability - Prometheus metrics prometheus = { version = "0.13", default-features = false } +# Base64 encoding (screenshots, image data) +base64 = "0.22" + # Error handling anyhow = "1.0" thiserror = "2.0" diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 991905b..0d6b89d 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -321,6 +321,14 @@ pub async fn run( "Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.", ), ]; + tool_descs.push(( + "screenshot", + "Capture a screenshot of the current screen. Returns file path and base64-encoded PNG. Use when: visual verification, UI inspection, debugging displays.", + )); + tool_descs.push(( + "image_info", + "Read image file metadata (format, dimensions, size) and optionally base64-encode it. Use when: inspecting images, preparing visual data for analysis.", + )); if config.browser.enabled { tool_descs.push(( "browser_open", diff --git a/src/providers/mod.rs b/src/providers/mod.rs index db65d63..1143374 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -154,25 +154,26 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option { /// Factory: create the right provider from config #[allow(clippy::too_many_lines)] pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result> { - let _resolved_key = resolve_api_key(name, api_key); + let resolved_key = resolve_api_key(name, api_key); + let key = resolved_key.as_deref(); match name { // ── Primary providers (custom implementations) ─────── - "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))), - "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(api_key))), - "openai" => Ok(Box::new(openai::OpenAiProvider::new(api_key))), + "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))), + "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))), + "openai" => Ok(Box::new(openai::OpenAiProvider::new(key))), // Ollama is a local service that doesn't use API keys. // The api_key parameter is ignored to avoid it being misinterpreted as a base_url. "ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))), "gemini" | "google" | "google-gemini" => { - Ok(Box::new(gemini::GeminiProvider::new(api_key))) + Ok(Box::new(gemini::GeminiProvider::new(key))) } // ── OpenAI-compatible providers ────────────────────── "venice" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Venice", "https://api.venice.ai", api_key, AuthStyle::Bearer, + "Venice", "https://api.venice.ai", key, AuthStyle::Bearer, ))), "vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Vercel AI Gateway", "https://api.vercel.ai", api_key, AuthStyle::Bearer, + "Vercel AI Gateway", "https://api.vercel.ai", key, AuthStyle::Bearer, ))), "cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "Cloudflare AI Gateway", @@ -181,22 +182,22 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( - "Moonshot", "https://api.moonshot.cn", api_key, AuthStyle::Bearer, + "Moonshot", "https://api.moonshot.cn", key, AuthStyle::Bearer, ))), "synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Synthetic", "https://api.synthetic.com", api_key, AuthStyle::Bearer, + "Synthetic", "https://api.synthetic.com", key, AuthStyle::Bearer, ))), "opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new( - "OpenCode Zen", "https://api.opencode.ai", api_key, AuthStyle::Bearer, + "OpenCode Zen", "https://api.opencode.ai", key, AuthStyle::Bearer, ))), "zai" | "z.ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Z.AI", "https://api.z.ai/api/coding/paas/v4", api_key, AuthStyle::Bearer, + "Z.AI", "https://api.z.ai/api/coding/paas/v4", key, AuthStyle::Bearer, ))), "glm" | "zhipu" => Ok(Box::new(OpenAiCompatibleProvider::new( - "GLM", "https://open.bigmodel.cn/api/paas", api_key, AuthStyle::Bearer, + "GLM", "https://open.bigmodel.cn/api/paas", key, AuthStyle::Bearer, ))), "minimax" => Ok(Box::new(OpenAiCompatibleProvider::new( - "MiniMax", "https://api.minimax.chat", api_key, AuthStyle::Bearer, + "MiniMax", "https://api.minimax.chat", key, AuthStyle::Bearer, ))), "bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new( "Amazon Bedrock", @@ -205,36 +206,36 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( - "Qianfan", "https://aip.baidubce.com", api_key, AuthStyle::Bearer, + "Qianfan", "https://aip.baidubce.com", key, AuthStyle::Bearer, ))), // ── Extended ecosystem (community favorites) ───────── "groq" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Groq", "https://api.groq.com/openai", api_key, AuthStyle::Bearer, + "Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer, ))), "mistral" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Mistral", "https://api.mistral.ai", api_key, AuthStyle::Bearer, + "Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer, ))), "xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new( - "xAI", "https://api.x.ai", api_key, AuthStyle::Bearer, + "xAI", "https://api.x.ai", key, AuthStyle::Bearer, ))), "deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new( - "DeepSeek", "https://api.deepseek.com", api_key, AuthStyle::Bearer, + "DeepSeek", "https://api.deepseek.com", key, AuthStyle::Bearer, ))), "together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Together AI", "https://api.together.xyz", api_key, AuthStyle::Bearer, + "Together AI", "https://api.together.xyz", key, AuthStyle::Bearer, ))), "fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Fireworks AI", "https://api.fireworks.ai/inference", api_key, AuthStyle::Bearer, + "Fireworks AI", "https://api.fireworks.ai/inference", key, AuthStyle::Bearer, ))), "perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Perplexity", "https://api.perplexity.ai", api_key, AuthStyle::Bearer, + "Perplexity", "https://api.perplexity.ai", key, AuthStyle::Bearer, ))), "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Cohere", "https://api.cohere.com/compatibility", api_key, AuthStyle::Bearer, + "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer, ))), "copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new( - "GitHub Copilot", "https://api.githubcopilot.com", api_key, AuthStyle::Bearer, + "GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer, ))), // ── Bring Your Own Provider (custom URL) ─────────── @@ -247,7 +248,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result) -> anyhow::Result, +} + +impl ImageInfoTool { + pub fn new(security: Arc) -> Self { + Self { security } + } + + /// Detect image format from first few bytes (magic numbers). + fn detect_format(bytes: &[u8]) -> &'static str { + if bytes.len() < 4 { + return "unknown"; + } + if bytes.starts_with(b"\x89PNG") { + "png" + } else if bytes.starts_with(b"\xFF\xD8\xFF") { + "jpeg" + } else if bytes.starts_with(b"GIF8") { + "gif" + } else if bytes.starts_with(b"RIFF") && bytes.len() >= 12 && &bytes[8..12] == b"WEBP" { + "webp" + } else if bytes.starts_with(b"BM") { + "bmp" + } else { + "unknown" + } + } + + /// Try to extract dimensions from image header bytes. + /// Returns (width, height) if detectable. + fn extract_dimensions(bytes: &[u8], format: &str) -> Option<(u32, u32)> { + match format { + "png" => { + // PNG IHDR chunk: bytes 16-19 = width, 20-23 = height (big-endian) + if bytes.len() >= 24 { + let w = u32::from_be_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]); + let h = u32::from_be_bytes([bytes[20], bytes[21], bytes[22], bytes[23]]); + Some((w, h)) + } else { + None + } + } + "gif" => { + // GIF: bytes 6-7 = width, 8-9 = height (little-endian) + if bytes.len() >= 10 { + let w = u32::from(u16::from_le_bytes([bytes[6], bytes[7]])); + let h = u32::from(u16::from_le_bytes([bytes[8], bytes[9]])); + Some((w, h)) + } else { + None + } + } + "bmp" => { + // BMP: bytes 18-21 = width, 22-25 = height (little-endian, signed) + if bytes.len() >= 26 { + let w = u32::from_le_bytes([bytes[18], bytes[19], bytes[20], bytes[21]]); + let h_raw = i32::from_le_bytes([bytes[22], bytes[23], bytes[24], bytes[25]]); + let h = h_raw.unsigned_abs(); + Some((w, h)) + } else { + None + } + } + "jpeg" => Self::jpeg_dimensions(bytes), + _ => None, + } + } + + /// Parse JPEG SOF markers to extract dimensions. + fn jpeg_dimensions(bytes: &[u8]) -> Option<(u32, u32)> { + let mut i = 2; // skip SOI marker + while i + 1 < bytes.len() { + if bytes[i] != 0xFF { + return None; + } + let marker = bytes[i + 1]; + i += 2; + + // SOF0..SOF3 markers contain dimensions + if (0xC0..=0xC3).contains(&marker) { + if i + 7 <= bytes.len() { + let h = u32::from(u16::from_be_bytes([bytes[i + 3], bytes[i + 4]])); + let w = u32::from(u16::from_be_bytes([bytes[i + 5], bytes[i + 6]])); + return Some((w, h)); + } + return None; + } + + // Skip this segment + if i + 1 < bytes.len() { + let seg_len = u16::from_be_bytes([bytes[i], bytes[i + 1]]) as usize; + if seg_len < 2 { + return None; // Malformed segment (valid segments have length >= 2) + } + i += seg_len; + } else { + return None; + } + } + None + } +} + +#[async_trait] +impl Tool for ImageInfoTool { + fn name(&self) -> &str { + "image_info" + } + + fn description(&self) -> &str { + "Read image file metadata (format, dimensions, size) and optionally return base64-encoded data." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the image file (absolute or relative to workspace)" + }, + "include_base64": { + "type": "boolean", + "description": "Include base64-encoded image data in output (default: false)" + } + }, + "required": ["path"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path_str = args + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + + let include_base64 = args + .get("include_base64") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false); + + let path = Path::new(path_str); + + // Restrict reads to workspace directory to prevent arbitrary file exfiltration + if !self.security.is_path_allowed(path_str) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Path not allowed: {path_str} (must be within workspace)")), + }); + } + + if !path.exists() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("File not found: {path_str}")), + }); + } + + let metadata = tokio::fs::metadata(path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read file metadata: {e}"))?; + + let file_size = metadata.len(); + + if file_size > MAX_IMAGE_BYTES { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Image too large: {file_size} bytes (max {MAX_IMAGE_BYTES} bytes)" + )), + }); + } + + let bytes = tokio::fs::read(path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read image file: {e}"))?; + + let format = Self::detect_format(&bytes); + let dimensions = Self::extract_dimensions(&bytes, format); + + let mut output = format!("File: {path_str}\nFormat: {format}\nSize: {file_size} bytes"); + + if let Some((w, h)) = dimensions { + let _ = write!(output, "\nDimensions: {w}x{h}"); + } + + if include_base64 { + use base64::Engine; + let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + let mime = match format { + "png" => "image/png", + "jpeg" => "image/jpeg", + "gif" => "image/gif", + "webp" => "image/webp", + "bmp" => "image/bmp", + _ => "application/octet-stream", + }; + let _ = write!(output, "\ndata:{mime};base64,{encoded}"); + } + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Full, + workspace_dir: std::env::temp_dir(), + workspace_only: false, + forbidden_paths: vec![], + ..SecurityPolicy::default() + }) + } + + #[test] + fn image_info_tool_name() { + let tool = ImageInfoTool::new(test_security()); + assert_eq!(tool.name(), "image_info"); + } + + #[test] + fn image_info_tool_description() { + let tool = ImageInfoTool::new(test_security()); + assert!(!tool.description().is_empty()); + assert!(tool.description().contains("image")); + } + + #[test] + fn image_info_tool_schema() { + let tool = ImageInfoTool::new(test_security()); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["path"].is_object()); + assert!(schema["properties"]["include_base64"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("path"))); + } + + #[test] + fn image_info_tool_spec() { + let tool = ImageInfoTool::new(test_security()); + let spec = tool.spec(); + assert_eq!(spec.name, "image_info"); + assert!(spec.parameters.is_object()); + } + + // ── Format detection ──────────────────────────────────────── + + #[test] + fn detect_png() { + let bytes = b"\x89PNG\r\n\x1a\n"; + assert_eq!(ImageInfoTool::detect_format(bytes), "png"); + } + + #[test] + fn detect_jpeg() { + let bytes = b"\xFF\xD8\xFF\xE0"; + assert_eq!(ImageInfoTool::detect_format(bytes), "jpeg"); + } + + #[test] + fn detect_gif() { + let bytes = b"GIF89a"; + assert_eq!(ImageInfoTool::detect_format(bytes), "gif"); + } + + #[test] + fn detect_webp() { + let bytes = b"RIFF\x00\x00\x00\x00WEBP"; + assert_eq!(ImageInfoTool::detect_format(bytes), "webp"); + } + + #[test] + fn detect_bmp() { + let bytes = b"BM\x00\x00"; + assert_eq!(ImageInfoTool::detect_format(bytes), "bmp"); + } + + #[test] + fn detect_unknown_short() { + let bytes = b"\x00\x01"; + assert_eq!(ImageInfoTool::detect_format(bytes), "unknown"); + } + + #[test] + fn detect_unknown_garbage() { + let bytes = b"this is not an image"; + assert_eq!(ImageInfoTool::detect_format(bytes), "unknown"); + } + + // ── Dimension extraction ──────────────────────────────────── + + #[test] + fn png_dimensions() { + // Minimal PNG IHDR: 8-byte signature + 4-byte length + 4-byte IHDR + 4-byte width + 4-byte height + let mut bytes = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // "IHDR" + 0x00, 0x00, 0x03, 0x20, // width: 800 + 0x00, 0x00, 0x02, 0x58, // height: 600 + ]; + bytes.extend_from_slice(&[0u8; 10]); // padding + let dims = ImageInfoTool::extract_dimensions(&bytes, "png"); + assert_eq!(dims, Some((800, 600))); + } + + #[test] + fn gif_dimensions() { + let bytes = [ + 0x47, 0x49, 0x46, 0x38, 0x39, 0x61, // GIF89a + 0x40, 0x01, // width: 320 (LE) + 0xF0, 0x00, // height: 240 (LE) + ]; + let dims = ImageInfoTool::extract_dimensions(&bytes, "gif"); + assert_eq!(dims, Some((320, 240))); + } + + #[test] + fn bmp_dimensions() { + let mut bytes = vec![0u8; 26]; + bytes[0] = b'B'; + bytes[1] = b'M'; + // width at offset 18 (LE): 1024 + bytes[18] = 0x00; + bytes[19] = 0x04; + bytes[20] = 0x00; + bytes[21] = 0x00; + // height at offset 22 (LE): 768 + bytes[22] = 0x00; + bytes[23] = 0x03; + bytes[24] = 0x00; + bytes[25] = 0x00; + let dims = ImageInfoTool::extract_dimensions(&bytes, "bmp"); + assert_eq!(dims, Some((1024, 768))); + } + + #[test] + fn jpeg_dimensions() { + // Minimal JPEG-like byte sequence with SOF0 marker + let mut bytes: Vec = vec![ + 0xFF, 0xD8, // SOI + 0xFF, 0xE0, // APP0 marker + 0x00, 0x10, // APP0 length = 16 + ]; + bytes.extend_from_slice(&[0u8; 14]); // APP0 payload + bytes.extend_from_slice(&[ + 0xFF, 0xC0, // SOF0 marker + 0x00, 0x11, // SOF0 length + 0x08, // precision + 0x01, 0xE0, // height: 480 + 0x02, 0x80, // width: 640 + ]); + let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg"); + assert_eq!(dims, Some((640, 480))); + } + + #[test] + fn jpeg_malformed_zero_length_segment() { + // Zero-length segment should return None instead of looping forever + let bytes: Vec = vec![ + 0xFF, 0xD8, // SOI + 0xFF, 0xE0, // APP0 marker + 0x00, 0x00, // length = 0 (malformed) + ]; + let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg"); + assert!(dims.is_none()); + } + + #[test] + fn unknown_format_no_dimensions() { + let bytes = b"random data here"; + let dims = ImageInfoTool::extract_dimensions(bytes, "unknown"); + assert!(dims.is_none()); + } + + // ── Execute tests ─────────────────────────────────────────── + + #[tokio::test] + async fn execute_missing_path() { + let tool = ImageInfoTool::new(test_security()); + let result = tool.execute(json!({})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn execute_nonexistent_file() { + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": "/tmp/nonexistent_image_xyz.png"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("not found")); + } + + #[tokio::test] + async fn execute_real_file() { + // Create a minimal valid PNG + let dir = std::env::temp_dir().join("zeroclaw_image_info_test"); + let _ = std::fs::create_dir_all(&dir); + let png_path = dir.join("test.png"); + + // Minimal 1x1 red PNG (67 bytes) + let png_bytes: Vec = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // IHDR + 0x00, 0x00, 0x00, 0x01, // width: 1 + 0x00, 0x00, 0x00, 0x01, // height: 1 + 0x08, 0x02, 0x00, 0x00, 0x00, // bit depth, color type, etc. + 0x90, 0x77, 0x53, 0xDE, // CRC + 0x00, 0x00, 0x00, 0x0C, // IDAT length + 0x49, 0x44, 0x41, 0x54, // IDAT + 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, + 0xBC, 0x33, // CRC + 0x00, 0x00, 0x00, 0x00, // IEND length + 0x49, 0x45, 0x4E, 0x44, // IEND + 0xAE, 0x42, 0x60, 0x82, // CRC + ]; + std::fs::write(&png_path, &png_bytes).unwrap(); + + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": png_path.to_string_lossy()})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Format: png")); + assert!(result.output.contains("Dimensions: 1x1")); + assert!(!result.output.contains("data:")); + + // Clean up + let _ = std::fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn execute_with_base64() { + let dir = std::env::temp_dir().join("zeroclaw_image_info_b64"); + let _ = std::fs::create_dir_all(&dir); + let png_path = dir.join("test_b64.png"); + + // Minimal 1x1 PNG + let png_bytes: Vec = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, + 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00, + 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, + 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, 0xBC, + 0x33, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, + ]; + std::fs::write(&png_path, &png_bytes).unwrap(); + + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": png_path.to_string_lossy(), "include_base64": true})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("data:image/png;base64,")); + + let _ = std::fs::remove_dir_all(&dir); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 6f9891f..446c1ee 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -3,9 +3,11 @@ pub mod browser_open; pub mod composio; pub mod file_read; pub mod file_write; +pub mod image_info; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod screenshot; pub mod shell; pub mod traits; @@ -14,9 +16,11 @@ pub use browser_open::BrowserOpenTool; pub use composio::ComposioTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; +pub use image_info::ImageInfoTool; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use screenshot::ScreenshotTool; pub use shell::ShellTool; pub use traits::Tool; #[allow(unused_imports)] @@ -91,6 +95,10 @@ pub fn all_tools_with_runtime( ))); } + // Vision tools are always available + tools.push(Box::new(ScreenshotTool::new(security.clone()))); + tools.push(Box::new(ImageInfoTool::new(security.clone()))); + if let Some(key) = composio_key { if !key.is_empty() { tools.push(Box::new(ComposioTool::new(key))); diff --git a/src/tools/screenshot.rs b/src/tools/screenshot.rs new file mode 100644 index 0000000..7581bc1 --- /dev/null +++ b/src/tools/screenshot.rs @@ -0,0 +1,300 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::fmt::Write; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +/// Maximum time to wait for a screenshot command to complete. +const SCREENSHOT_TIMEOUT_SECS: u64 = 15; +/// Maximum base64 payload size to return (2 MB of base64 ≈ 1.5 MB image). +const MAX_BASE64_BYTES: usize = 2_097_152; + +/// Tool for capturing screenshots using platform-native commands. +/// +/// macOS: `screencapture` +/// Linux: tries `gnome-screenshot`, `scrot`, `import` (`ImageMagick`) in order. +pub struct ScreenshotTool { + security: Arc, +} + +impl ScreenshotTool { + pub fn new(security: Arc) -> Self { + Self { security } + } + + /// Determine the screenshot command for the current platform. + fn screenshot_command(output_path: &str) -> Option> { + if cfg!(target_os = "macos") { + Some(vec![ + "screencapture".into(), + "-x".into(), // no sound + output_path.into(), + ]) + } else if cfg!(target_os = "linux") { + Some(vec![ + "sh".into(), + "-c".into(), + format!( + "if command -v gnome-screenshot >/dev/null 2>&1; then \ + gnome-screenshot -f '{output_path}'; \ + elif command -v scrot >/dev/null 2>&1; then \ + scrot '{output_path}'; \ + elif command -v import >/dev/null 2>&1; then \ + import -window root '{output_path}'; \ + else \ + echo 'NO_SCREENSHOT_TOOL' >&2; exit 1; \ + fi" + ), + ]) + } else { + None + } + } + + /// Execute the screenshot capture and return the result. + async fn capture(&self, args: serde_json::Value) -> anyhow::Result { + let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S"); + let filename = args + .get("filename") + .and_then(|v| v.as_str()) + .map_or_else(|| format!("screenshot_{timestamp}.png"), String::from); + + // Sanitize filename to prevent path traversal + let safe_name = PathBuf::from(&filename).file_name().map_or_else( + || format!("screenshot_{timestamp}.png"), + |n| n.to_string_lossy().to_string(), + ); + + let output_path = self.security.workspace_dir.join(&safe_name); + let output_str = output_path.to_string_lossy().to_string(); + + let Some(mut cmd_args) = Self::screenshot_command(&output_str) else { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Screenshot not supported on this platform".into()), + }); + }; + + // macOS region flags + if cfg!(target_os = "macos") { + if let Some(region) = args.get("region").and_then(|v| v.as_str()) { + match region { + "selection" => cmd_args.insert(1, "-s".into()), + "window" => cmd_args.insert(1, "-w".into()), + _ => {} // ignore unknown regions + } + } + } + + let program = cmd_args.remove(0); + let result = tokio::time::timeout( + Duration::from_secs(SCREENSHOT_TIMEOUT_SECS), + tokio::process::Command::new(&program) + .args(&cmd_args) + .output(), + ) + .await; + + match result { + Ok(Ok(output)) => { + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + if stderr.contains("NO_SCREENSHOT_TOOL") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "No screenshot tool found. Install gnome-screenshot, scrot, or ImageMagick." + .into(), + ), + }); + } + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Screenshot command failed: {stderr}")), + }); + } + + Self::read_and_encode(&output_path).await + } + Ok(Err(e)) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to execute screenshot command: {e}")), + }), + Err(_) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Screenshot timed out after {SCREENSHOT_TIMEOUT_SECS}s" + )), + }), + } + } + + /// Read the screenshot file and return base64-encoded result. + async fn read_and_encode(output_path: &std::path::Path) -> anyhow::Result { + // Check file size before reading to prevent OOM on large screenshots + const MAX_RAW_BYTES: u64 = 1_572_864; // ~1.5 MB (base64 expands ~33%) + if let Ok(meta) = tokio::fs::metadata(output_path).await { + if meta.len() > MAX_RAW_BYTES { + return Ok(ToolResult { + success: true, + output: format!( + "Screenshot saved to: {}\nSize: {} bytes (too large to base64-encode inline)", + output_path.display(), + meta.len(), + ), + error: None, + }); + } + } + + match tokio::fs::read(output_path).await { + Ok(bytes) => { + use base64::Engine; + let size = bytes.len(); + let mut encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + let truncated = if encoded.len() > MAX_BASE64_BYTES { + encoded.truncate(encoded.floor_char_boundary(MAX_BASE64_BYTES)); + true + } else { + false + }; + + let mut output_msg = format!( + "Screenshot saved to: {}\nSize: {size} bytes\nBase64 length: {}", + output_path.display(), + encoded.len(), + ); + if truncated { + output_msg.push_str(" (truncated)"); + } + let mime = match output_path.extension().and_then(|e| e.to_str()) { + Some("jpg" | "jpeg") => "image/jpeg", + Some("bmp") => "image/bmp", + Some("gif") => "image/gif", + Some("webp") => "image/webp", + _ => "image/png", + }; + let _ = write!(output_msg, "\ndata:{mime};base64,{encoded}"); + + Ok(ToolResult { + success: true, + output: output_msg, + error: None, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: format!("Screenshot saved to: {}", output_path.display()), + error: Some(format!("Failed to read screenshot file: {e}")), + }), + } + } +} + +#[async_trait] +impl Tool for ScreenshotTool { + fn name(&self) -> &str { + "screenshot" + } + + fn description(&self) -> &str { + "Capture a screenshot of the current screen. Returns the file path and base64-encoded PNG data." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "Optional filename (default: screenshot_.png). Saved in workspace." + }, + "region": { + "type": "string", + "description": "Optional region for macOS: 'selection' for interactive crop, 'window' for front window. Ignored on Linux." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + self.capture(args).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Full, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }) + } + + #[test] + fn screenshot_tool_name() { + let tool = ScreenshotTool::new(test_security()); + assert_eq!(tool.name(), "screenshot"); + } + + #[test] + fn screenshot_tool_description() { + let tool = ScreenshotTool::new(test_security()); + assert!(!tool.description().is_empty()); + assert!(tool.description().contains("screenshot")); + } + + #[test] + fn screenshot_tool_schema() { + let tool = ScreenshotTool::new(test_security()); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["filename"].is_object()); + assert!(schema["properties"]["region"].is_object()); + } + + #[test] + fn screenshot_tool_spec() { + let tool = ScreenshotTool::new(test_security()); + let spec = tool.spec(); + assert_eq!(spec.name, "screenshot"); + assert!(spec.parameters.is_object()); + } + + #[test] + #[cfg(any(target_os = "macos", target_os = "linux"))] + fn screenshot_command_exists() { + let cmd = ScreenshotTool::screenshot_command("/tmp/test.png"); + assert!(cmd.is_some()); + let args = cmd.unwrap(); + assert!(!args.is_empty()); + } + + #[test] + fn screenshot_command_contains_output_path() { + let cmd = ScreenshotTool::screenshot_command("/tmp/my_screenshot.png").unwrap(); + let joined = cmd.join(" "); + assert!( + joined.contains("/tmp/my_screenshot.png"), + "Command should contain the output path" + ); + } +}