diff --git a/Cargo.lock b/Cargo.lock index 614cbb6..f39c66f 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3191,6 +3191,7 @@ dependencies = [ "anyhow", "async-trait", "axum", + "base64", "chacha20poly1305", "chrono", "clap", diff --git a/Cargo.toml b/Cargo.toml index 45dfcaf..6ead2f0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,9 @@ tracing-subscriber = { version = "0.3", default-features = false, features = ["f # Observability - Prometheus metrics prometheus = { version = "0.13", default-features = false } +# Base64 encoding (screenshots, image data) +base64 = "0.22" + # Error handling anyhow = "1.0" thiserror = "2.0" diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 991905b..0d6b89d 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -321,6 +321,14 @@ pub async fn run( "Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.", ), ]; + tool_descs.push(( + "screenshot", + "Capture a screenshot of the current screen. Returns file path and base64-encoded PNG. Use when: visual verification, UI inspection, debugging displays.", + )); + tool_descs.push(( + "image_info", + "Read image file metadata (format, dimensions, size) and optionally base64-encode it. Use when: inspecting images, preparing visual data for analysis.", + )); if config.browser.enabled { tool_descs.push(( "browser_open", diff --git a/src/providers/mod.rs b/src/providers/mod.rs index db65d63..1143374 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -154,25 +154,26 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option { /// Factory: create the right provider from config #[allow(clippy::too_many_lines)] pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result> { - let _resolved_key = resolve_api_key(name, api_key); + let resolved_key = resolve_api_key(name, api_key); + let key = resolved_key.as_deref(); match name { // ── Primary providers (custom implementations) ─────── - "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))), - "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(api_key))), - "openai" => Ok(Box::new(openai::OpenAiProvider::new(api_key))), + "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))), + "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))), + "openai" => Ok(Box::new(openai::OpenAiProvider::new(key))), // Ollama is a local service that doesn't use API keys. // The api_key parameter is ignored to avoid it being misinterpreted as a base_url. "ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))), "gemini" | "google" | "google-gemini" => { - Ok(Box::new(gemini::GeminiProvider::new(api_key))) + Ok(Box::new(gemini::GeminiProvider::new(key))) } // ── OpenAI-compatible providers ────────────────────── "venice" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Venice", "https://api.venice.ai", api_key, AuthStyle::Bearer, + "Venice", "https://api.venice.ai", key, AuthStyle::Bearer, ))), "vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Vercel AI Gateway", "https://api.vercel.ai", api_key, AuthStyle::Bearer, + "Vercel AI Gateway", "https://api.vercel.ai", key, AuthStyle::Bearer, ))), "cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "Cloudflare AI Gateway", @@ -181,22 +182,22 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( - "Moonshot", "https://api.moonshot.cn", api_key, AuthStyle::Bearer, + "Moonshot", "https://api.moonshot.cn", key, AuthStyle::Bearer, ))), "synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Synthetic", "https://api.synthetic.com", api_key, AuthStyle::Bearer, + "Synthetic", "https://api.synthetic.com", key, AuthStyle::Bearer, ))), "opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new( - "OpenCode Zen", "https://api.opencode.ai", api_key, AuthStyle::Bearer, + "OpenCode Zen", "https://api.opencode.ai", key, AuthStyle::Bearer, ))), "zai" | "z.ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Z.AI", "https://api.z.ai/api/coding/paas/v4", api_key, AuthStyle::Bearer, + "Z.AI", "https://api.z.ai/api/coding/paas/v4", key, AuthStyle::Bearer, ))), "glm" | "zhipu" => Ok(Box::new(OpenAiCompatibleProvider::new( - "GLM", "https://open.bigmodel.cn/api/paas", api_key, AuthStyle::Bearer, + "GLM", "https://open.bigmodel.cn/api/paas", key, AuthStyle::Bearer, ))), "minimax" => Ok(Box::new(OpenAiCompatibleProvider::new( - "MiniMax", "https://api.minimax.chat", api_key, AuthStyle::Bearer, + "MiniMax", "https://api.minimax.chat", key, AuthStyle::Bearer, ))), "bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new( "Amazon Bedrock", @@ -205,36 +206,36 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( - "Qianfan", "https://aip.baidubce.com", api_key, AuthStyle::Bearer, + "Qianfan", "https://aip.baidubce.com", key, AuthStyle::Bearer, ))), // ── Extended ecosystem (community favorites) ───────── "groq" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Groq", "https://api.groq.com/openai", api_key, AuthStyle::Bearer, + "Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer, ))), "mistral" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Mistral", "https://api.mistral.ai", api_key, AuthStyle::Bearer, + "Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer, ))), "xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new( - "xAI", "https://api.x.ai", api_key, AuthStyle::Bearer, + "xAI", "https://api.x.ai", key, AuthStyle::Bearer, ))), "deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new( - "DeepSeek", "https://api.deepseek.com", api_key, AuthStyle::Bearer, + "DeepSeek", "https://api.deepseek.com", key, AuthStyle::Bearer, ))), "together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Together AI", "https://api.together.xyz", api_key, AuthStyle::Bearer, + "Together AI", "https://api.together.xyz", key, AuthStyle::Bearer, ))), "fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Fireworks AI", "https://api.fireworks.ai/inference", api_key, AuthStyle::Bearer, + "Fireworks AI", "https://api.fireworks.ai/inference", key, AuthStyle::Bearer, ))), "perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Perplexity", "https://api.perplexity.ai", api_key, AuthStyle::Bearer, + "Perplexity", "https://api.perplexity.ai", key, AuthStyle::Bearer, ))), "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Cohere", "https://api.cohere.com/compatibility", api_key, AuthStyle::Bearer, + "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer, ))), "copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new( - "GitHub Copilot", "https://api.githubcopilot.com", api_key, AuthStyle::Bearer, + "GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer, ))), // ── Bring Your Own Provider (custom URL) ─────────── @@ -247,7 +248,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result) -> anyhow::Result, +} + +impl ImageInfoTool { + pub fn new(security: Arc) -> Self { + Self { security } + } + + /// Detect image format from first few bytes (magic numbers). + fn detect_format(bytes: &[u8]) -> &'static str { + if bytes.len() < 4 { + return "unknown"; + } + if bytes.starts_with(b"\x89PNG") { + "png" + } else if bytes.starts_with(b"\xFF\xD8\xFF") { + "jpeg" + } else if bytes.starts_with(b"GIF8") { + "gif" + } else if bytes.starts_with(b"RIFF") && bytes.len() >= 12 && &bytes[8..12] == b"WEBP" { + "webp" + } else if bytes.starts_with(b"BM") { + "bmp" + } else { + "unknown" + } + } + + /// Try to extract dimensions from image header bytes. + /// Returns (width, height) if detectable. + fn extract_dimensions(bytes: &[u8], format: &str) -> Option<(u32, u32)> { + match format { + "png" => { + // PNG IHDR chunk: bytes 16-19 = width, 20-23 = height (big-endian) + if bytes.len() >= 24 { + let w = u32::from_be_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]); + let h = u32::from_be_bytes([bytes[20], bytes[21], bytes[22], bytes[23]]); + Some((w, h)) + } else { + None + } + } + "gif" => { + // GIF: bytes 6-7 = width, 8-9 = height (little-endian) + if bytes.len() >= 10 { + let w = u32::from(u16::from_le_bytes([bytes[6], bytes[7]])); + let h = u32::from(u16::from_le_bytes([bytes[8], bytes[9]])); + Some((w, h)) + } else { + None + } + } + "bmp" => { + // BMP: bytes 18-21 = width, 22-25 = height (little-endian, signed) + if bytes.len() >= 26 { + let w = u32::from_le_bytes([bytes[18], bytes[19], bytes[20], bytes[21]]); + let h_raw = i32::from_le_bytes([bytes[22], bytes[23], bytes[24], bytes[25]]); + let h = h_raw.unsigned_abs(); + Some((w, h)) + } else { + None + } + } + "jpeg" => Self::jpeg_dimensions(bytes), + _ => None, + } + } + + /// Parse JPEG SOF markers to extract dimensions. + fn jpeg_dimensions(bytes: &[u8]) -> Option<(u32, u32)> { + let mut i = 2; // skip SOI marker + while i + 1 < bytes.len() { + if bytes[i] != 0xFF { + return None; + } + let marker = bytes[i + 1]; + i += 2; + + // SOF0..SOF3 markers contain dimensions + if (0xC0..=0xC3).contains(&marker) { + if i + 7 <= bytes.len() { + let h = u32::from(u16::from_be_bytes([bytes[i + 3], bytes[i + 4]])); + let w = u32::from(u16::from_be_bytes([bytes[i + 5], bytes[i + 6]])); + return Some((w, h)); + } + return None; + } + + // Skip this segment + if i + 1 < bytes.len() { + let seg_len = u16::from_be_bytes([bytes[i], bytes[i + 1]]) as usize; + if seg_len < 2 { + return None; // Malformed segment (valid segments have length >= 2) + } + i += seg_len; + } else { + return None; + } + } + None + } +} + +#[async_trait] +impl Tool for ImageInfoTool { + fn name(&self) -> &str { + "image_info" + } + + fn description(&self) -> &str { + "Read image file metadata (format, dimensions, size) and optionally return base64-encoded data." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "path": { + "type": "string", + "description": "Path to the image file (absolute or relative to workspace)" + }, + "include_base64": { + "type": "boolean", + "description": "Include base64-encoded image data in output (default: false)" + } + }, + "required": ["path"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + let path_str = args + .get("path") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?; + + let include_base64 = args + .get("include_base64") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false); + + let path = Path::new(path_str); + + // Restrict reads to workspace directory to prevent arbitrary file exfiltration + if !self.security.is_path_allowed(path_str) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Path not allowed: {path_str} (must be within workspace)")), + }); + } + + if !path.exists() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("File not found: {path_str}")), + }); + } + + let metadata = tokio::fs::metadata(path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read file metadata: {e}"))?; + + let file_size = metadata.len(); + + if file_size > MAX_IMAGE_BYTES { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Image too large: {file_size} bytes (max {MAX_IMAGE_BYTES} bytes)" + )), + }); + } + + let bytes = tokio::fs::read(path) + .await + .map_err(|e| anyhow::anyhow!("Failed to read image file: {e}"))?; + + let format = Self::detect_format(&bytes); + let dimensions = Self::extract_dimensions(&bytes, format); + + let mut output = format!("File: {path_str}\nFormat: {format}\nSize: {file_size} bytes"); + + if let Some((w, h)) = dimensions { + let _ = write!(output, "\nDimensions: {w}x{h}"); + } + + if include_base64 { + use base64::Engine; + let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + let mime = match format { + "png" => "image/png", + "jpeg" => "image/jpeg", + "gif" => "image/gif", + "webp" => "image/webp", + "bmp" => "image/bmp", + _ => "application/octet-stream", + }; + let _ = write!(output, "\ndata:{mime};base64,{encoded}"); + } + + Ok(ToolResult { + success: true, + output, + error: None, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Full, + workspace_dir: std::env::temp_dir(), + workspace_only: false, + forbidden_paths: vec![], + ..SecurityPolicy::default() + }) + } + + #[test] + fn image_info_tool_name() { + let tool = ImageInfoTool::new(test_security()); + assert_eq!(tool.name(), "image_info"); + } + + #[test] + fn image_info_tool_description() { + let tool = ImageInfoTool::new(test_security()); + assert!(!tool.description().is_empty()); + assert!(tool.description().contains("image")); + } + + #[test] + fn image_info_tool_schema() { + let tool = ImageInfoTool::new(test_security()); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["path"].is_object()); + assert!(schema["properties"]["include_base64"].is_object()); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&json!("path"))); + } + + #[test] + fn image_info_tool_spec() { + let tool = ImageInfoTool::new(test_security()); + let spec = tool.spec(); + assert_eq!(spec.name, "image_info"); + assert!(spec.parameters.is_object()); + } + + // ── Format detection ──────────────────────────────────────── + + #[test] + fn detect_png() { + let bytes = b"\x89PNG\r\n\x1a\n"; + assert_eq!(ImageInfoTool::detect_format(bytes), "png"); + } + + #[test] + fn detect_jpeg() { + let bytes = b"\xFF\xD8\xFF\xE0"; + assert_eq!(ImageInfoTool::detect_format(bytes), "jpeg"); + } + + #[test] + fn detect_gif() { + let bytes = b"GIF89a"; + assert_eq!(ImageInfoTool::detect_format(bytes), "gif"); + } + + #[test] + fn detect_webp() { + let bytes = b"RIFF\x00\x00\x00\x00WEBP"; + assert_eq!(ImageInfoTool::detect_format(bytes), "webp"); + } + + #[test] + fn detect_bmp() { + let bytes = b"BM\x00\x00"; + assert_eq!(ImageInfoTool::detect_format(bytes), "bmp"); + } + + #[test] + fn detect_unknown_short() { + let bytes = b"\x00\x01"; + assert_eq!(ImageInfoTool::detect_format(bytes), "unknown"); + } + + #[test] + fn detect_unknown_garbage() { + let bytes = b"this is not an image"; + assert_eq!(ImageInfoTool::detect_format(bytes), "unknown"); + } + + // ── Dimension extraction ──────────────────────────────────── + + #[test] + fn png_dimensions() { + // Minimal PNG IHDR: 8-byte signature + 4-byte length + 4-byte IHDR + 4-byte width + 4-byte height + let mut bytes = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // "IHDR" + 0x00, 0x00, 0x03, 0x20, // width: 800 + 0x00, 0x00, 0x02, 0x58, // height: 600 + ]; + bytes.extend_from_slice(&[0u8; 10]); // padding + let dims = ImageInfoTool::extract_dimensions(&bytes, "png"); + assert_eq!(dims, Some((800, 600))); + } + + #[test] + fn gif_dimensions() { + let bytes = [ + 0x47, 0x49, 0x46, 0x38, 0x39, 0x61, // GIF89a + 0x40, 0x01, // width: 320 (LE) + 0xF0, 0x00, // height: 240 (LE) + ]; + let dims = ImageInfoTool::extract_dimensions(&bytes, "gif"); + assert_eq!(dims, Some((320, 240))); + } + + #[test] + fn bmp_dimensions() { + let mut bytes = vec![0u8; 26]; + bytes[0] = b'B'; + bytes[1] = b'M'; + // width at offset 18 (LE): 1024 + bytes[18] = 0x00; + bytes[19] = 0x04; + bytes[20] = 0x00; + bytes[21] = 0x00; + // height at offset 22 (LE): 768 + bytes[22] = 0x00; + bytes[23] = 0x03; + bytes[24] = 0x00; + bytes[25] = 0x00; + let dims = ImageInfoTool::extract_dimensions(&bytes, "bmp"); + assert_eq!(dims, Some((1024, 768))); + } + + #[test] + fn jpeg_dimensions() { + // Minimal JPEG-like byte sequence with SOF0 marker + let mut bytes: Vec = vec![ + 0xFF, 0xD8, // SOI + 0xFF, 0xE0, // APP0 marker + 0x00, 0x10, // APP0 length = 16 + ]; + bytes.extend_from_slice(&[0u8; 14]); // APP0 payload + bytes.extend_from_slice(&[ + 0xFF, 0xC0, // SOF0 marker + 0x00, 0x11, // SOF0 length + 0x08, // precision + 0x01, 0xE0, // height: 480 + 0x02, 0x80, // width: 640 + ]); + let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg"); + assert_eq!(dims, Some((640, 480))); + } + + #[test] + fn jpeg_malformed_zero_length_segment() { + // Zero-length segment should return None instead of looping forever + let bytes: Vec = vec![ + 0xFF, 0xD8, // SOI + 0xFF, 0xE0, // APP0 marker + 0x00, 0x00, // length = 0 (malformed) + ]; + let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg"); + assert!(dims.is_none()); + } + + #[test] + fn unknown_format_no_dimensions() { + let bytes = b"random data here"; + let dims = ImageInfoTool::extract_dimensions(bytes, "unknown"); + assert!(dims.is_none()); + } + + // ── Execute tests ─────────────────────────────────────────── + + #[tokio::test] + async fn execute_missing_path() { + let tool = ImageInfoTool::new(test_security()); + let result = tool.execute(json!({})).await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn execute_nonexistent_file() { + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": "/tmp/nonexistent_image_xyz.png"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result.error.as_ref().unwrap().contains("not found")); + } + + #[tokio::test] + async fn execute_real_file() { + // Create a minimal valid PNG + let dir = std::env::temp_dir().join("zeroclaw_image_info_test"); + let _ = std::fs::create_dir_all(&dir); + let png_path = dir.join("test.png"); + + // Minimal 1x1 red PNG (67 bytes) + let png_bytes: Vec = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // signature + 0x00, 0x00, 0x00, 0x0D, // IHDR length + 0x49, 0x48, 0x44, 0x52, // IHDR + 0x00, 0x00, 0x00, 0x01, // width: 1 + 0x00, 0x00, 0x00, 0x01, // height: 1 + 0x08, 0x02, 0x00, 0x00, 0x00, // bit depth, color type, etc. + 0x90, 0x77, 0x53, 0xDE, // CRC + 0x00, 0x00, 0x00, 0x0C, // IDAT length + 0x49, 0x44, 0x41, 0x54, // IDAT + 0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, + 0xBC, 0x33, // CRC + 0x00, 0x00, 0x00, 0x00, // IEND length + 0x49, 0x45, 0x4E, 0x44, // IEND + 0xAE, 0x42, 0x60, 0x82, // CRC + ]; + std::fs::write(&png_path, &png_bytes).unwrap(); + + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": png_path.to_string_lossy()})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("Format: png")); + assert!(result.output.contains("Dimensions: 1x1")); + assert!(!result.output.contains("data:")); + + // Clean up + let _ = std::fs::remove_dir_all(&dir); + } + + #[tokio::test] + async fn execute_with_base64() { + let dir = std::env::temp_dir().join("zeroclaw_image_info_b64"); + let _ = std::fs::create_dir_all(&dir); + let png_path = dir.join("test_b64.png"); + + // Minimal 1x1 PNG + let png_bytes: Vec = vec![ + 0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48, + 0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00, + 0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08, + 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, 0xBC, + 0x33, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, + ]; + std::fs::write(&png_path, &png_bytes).unwrap(); + + let tool = ImageInfoTool::new(test_security()); + let result = tool + .execute(json!({"path": png_path.to_string_lossy(), "include_base64": true})) + .await + .unwrap(); + assert!(result.success); + assert!(result.output.contains("data:image/png;base64,")); + + let _ = std::fs::remove_dir_all(&dir); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 6f9891f..446c1ee 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -3,9 +3,11 @@ pub mod browser_open; pub mod composio; pub mod file_read; pub mod file_write; +pub mod image_info; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod screenshot; pub mod shell; pub mod traits; @@ -14,9 +16,11 @@ pub use browser_open::BrowserOpenTool; pub use composio::ComposioTool; pub use file_read::FileReadTool; pub use file_write::FileWriteTool; +pub use image_info::ImageInfoTool; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use screenshot::ScreenshotTool; pub use shell::ShellTool; pub use traits::Tool; #[allow(unused_imports)] @@ -91,6 +95,10 @@ pub fn all_tools_with_runtime( ))); } + // Vision tools are always available + tools.push(Box::new(ScreenshotTool::new(security.clone()))); + tools.push(Box::new(ImageInfoTool::new(security.clone()))); + if let Some(key) = composio_key { if !key.is_empty() { tools.push(Box::new(ComposioTool::new(key))); diff --git a/src/tools/screenshot.rs b/src/tools/screenshot.rs new file mode 100644 index 0000000..7581bc1 --- /dev/null +++ b/src/tools/screenshot.rs @@ -0,0 +1,300 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use serde_json::json; +use std::fmt::Write; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +/// Maximum time to wait for a screenshot command to complete. +const SCREENSHOT_TIMEOUT_SECS: u64 = 15; +/// Maximum base64 payload size to return (2 MB of base64 ≈ 1.5 MB image). +const MAX_BASE64_BYTES: usize = 2_097_152; + +/// Tool for capturing screenshots using platform-native commands. +/// +/// macOS: `screencapture` +/// Linux: tries `gnome-screenshot`, `scrot`, `import` (`ImageMagick`) in order. +pub struct ScreenshotTool { + security: Arc, +} + +impl ScreenshotTool { + pub fn new(security: Arc) -> Self { + Self { security } + } + + /// Determine the screenshot command for the current platform. + fn screenshot_command(output_path: &str) -> Option> { + if cfg!(target_os = "macos") { + Some(vec![ + "screencapture".into(), + "-x".into(), // no sound + output_path.into(), + ]) + } else if cfg!(target_os = "linux") { + Some(vec![ + "sh".into(), + "-c".into(), + format!( + "if command -v gnome-screenshot >/dev/null 2>&1; then \ + gnome-screenshot -f '{output_path}'; \ + elif command -v scrot >/dev/null 2>&1; then \ + scrot '{output_path}'; \ + elif command -v import >/dev/null 2>&1; then \ + import -window root '{output_path}'; \ + else \ + echo 'NO_SCREENSHOT_TOOL' >&2; exit 1; \ + fi" + ), + ]) + } else { + None + } + } + + /// Execute the screenshot capture and return the result. + async fn capture(&self, args: serde_json::Value) -> anyhow::Result { + let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S"); + let filename = args + .get("filename") + .and_then(|v| v.as_str()) + .map_or_else(|| format!("screenshot_{timestamp}.png"), String::from); + + // Sanitize filename to prevent path traversal + let safe_name = PathBuf::from(&filename).file_name().map_or_else( + || format!("screenshot_{timestamp}.png"), + |n| n.to_string_lossy().to_string(), + ); + + let output_path = self.security.workspace_dir.join(&safe_name); + let output_str = output_path.to_string_lossy().to_string(); + + let Some(mut cmd_args) = Self::screenshot_command(&output_str) else { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Screenshot not supported on this platform".into()), + }); + }; + + // macOS region flags + if cfg!(target_os = "macos") { + if let Some(region) = args.get("region").and_then(|v| v.as_str()) { + match region { + "selection" => cmd_args.insert(1, "-s".into()), + "window" => cmd_args.insert(1, "-w".into()), + _ => {} // ignore unknown regions + } + } + } + + let program = cmd_args.remove(0); + let result = tokio::time::timeout( + Duration::from_secs(SCREENSHOT_TIMEOUT_SECS), + tokio::process::Command::new(&program) + .args(&cmd_args) + .output(), + ) + .await; + + match result { + Ok(Ok(output)) => { + if !output.status.success() { + let stderr = String::from_utf8_lossy(&output.stderr); + if stderr.contains("NO_SCREENSHOT_TOOL") { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some( + "No screenshot tool found. Install gnome-screenshot, scrot, or ImageMagick." + .into(), + ), + }); + } + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Screenshot command failed: {stderr}")), + }); + } + + Self::read_and_encode(&output_path).await + } + Ok(Err(e)) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to execute screenshot command: {e}")), + }), + Err(_) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Screenshot timed out after {SCREENSHOT_TIMEOUT_SECS}s" + )), + }), + } + } + + /// Read the screenshot file and return base64-encoded result. + async fn read_and_encode(output_path: &std::path::Path) -> anyhow::Result { + // Check file size before reading to prevent OOM on large screenshots + const MAX_RAW_BYTES: u64 = 1_572_864; // ~1.5 MB (base64 expands ~33%) + if let Ok(meta) = tokio::fs::metadata(output_path).await { + if meta.len() > MAX_RAW_BYTES { + return Ok(ToolResult { + success: true, + output: format!( + "Screenshot saved to: {}\nSize: {} bytes (too large to base64-encode inline)", + output_path.display(), + meta.len(), + ), + error: None, + }); + } + } + + match tokio::fs::read(output_path).await { + Ok(bytes) => { + use base64::Engine; + let size = bytes.len(); + let mut encoded = base64::engine::general_purpose::STANDARD.encode(&bytes); + let truncated = if encoded.len() > MAX_BASE64_BYTES { + encoded.truncate(encoded.floor_char_boundary(MAX_BASE64_BYTES)); + true + } else { + false + }; + + let mut output_msg = format!( + "Screenshot saved to: {}\nSize: {size} bytes\nBase64 length: {}", + output_path.display(), + encoded.len(), + ); + if truncated { + output_msg.push_str(" (truncated)"); + } + let mime = match output_path.extension().and_then(|e| e.to_str()) { + Some("jpg" | "jpeg") => "image/jpeg", + Some("bmp") => "image/bmp", + Some("gif") => "image/gif", + Some("webp") => "image/webp", + _ => "image/png", + }; + let _ = write!(output_msg, "\ndata:{mime};base64,{encoded}"); + + Ok(ToolResult { + success: true, + output: output_msg, + error: None, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: format!("Screenshot saved to: {}", output_path.display()), + error: Some(format!("Failed to read screenshot file: {e}")), + }), + } + } +} + +#[async_trait] +impl Tool for ScreenshotTool { + fn name(&self) -> &str { + "screenshot" + } + + fn description(&self) -> &str { + "Capture a screenshot of the current screen. Returns the file path and base64-encoded PNG data." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "filename": { + "type": "string", + "description": "Optional filename (default: screenshot_.png). Saved in workspace." + }, + "region": { + "type": "string", + "description": "Optional region for macOS: 'selection' for interactive crop, 'window' for front window. Ignored on Linux." + } + } + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + self.capture(args).await + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Full, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }) + } + + #[test] + fn screenshot_tool_name() { + let tool = ScreenshotTool::new(test_security()); + assert_eq!(tool.name(), "screenshot"); + } + + #[test] + fn screenshot_tool_description() { + let tool = ScreenshotTool::new(test_security()); + assert!(!tool.description().is_empty()); + assert!(tool.description().contains("screenshot")); + } + + #[test] + fn screenshot_tool_schema() { + let tool = ScreenshotTool::new(test_security()); + let schema = tool.parameters_schema(); + assert!(schema["properties"]["filename"].is_object()); + assert!(schema["properties"]["region"].is_object()); + } + + #[test] + fn screenshot_tool_spec() { + let tool = ScreenshotTool::new(test_security()); + let spec = tool.spec(); + assert_eq!(spec.name, "screenshot"); + assert!(spec.parameters.is_object()); + } + + #[test] + #[cfg(any(target_os = "macos", target_os = "linux"))] + fn screenshot_command_exists() { + let cmd = ScreenshotTool::screenshot_command("/tmp/test.png"); + assert!(cmd.is_some()); + let args = cmd.unwrap(); + assert!(!args.is_empty()); + } + + #[test] + fn screenshot_command_contains_output_path() { + let cmd = ScreenshotTool::screenshot_command("/tmp/my_screenshot.png").unwrap(); + let joined = cmd.join(" "); + assert!( + joined.contains("/tmp/my_screenshot.png"), + "Command should contain the output path" + ); + } +}