feat: add screenshot and image_info vision tools

* feat: add screenshot and image_info vision tools

Add two new tools for visual capabilities:

- `screenshot`: captures screen using platform-native commands
  (screencapture on macOS, gnome-screenshot/scrot/import on Linux),
  returns file path + base64-encoded PNG data
- `image_info`: reads image metadata (format, dimensions, size) from
  header bytes without external deps, optionally returns base64 data
  for future multimodal provider support

Both tools are registered in the tool registry and agent system prompt.
Includes 24 inline tests covering format detection, dimension extraction,
schema validation, and execution edge cases.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: resolve unused variable warning after rebase

Prefix unused `resolved_key` with underscore to suppress compiler
warning introduced by upstream changes. Update Cargo.lock.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address review comments on vision tools

Security fixes:
- Fix JPEG parser infinite loop on malformed zero-length segments
- Add workspace path restriction to ImageInfoTool (prevents arbitrary
  file exfiltration via include_base64)
- Quote paths in Linux screenshot shell commands to prevent injection
- Add autonomy-level check in ScreenshotTool::execute

Robustness:
- Add file size guard in read_and_encode before loading into memory
- Wire resolve_api_key through all provider match arms (was dead code)
- Gate screenshot_command_exists test on macOS/Linux only
- Infer MIME type from file extension instead of hardcoding image/png

Tests:
- Add JPEG dimension extraction test
- Add JPEG malformed zero-length segment test

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
This commit is contained in:
Edvard Schøyen 2026-02-15 14:53:56 -05:00 committed by GitHub
parent 0f6648ceb1
commit 9b2f90018c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 837 additions and 25 deletions

1
Cargo.lock generated
View file

@ -3191,6 +3191,7 @@ dependencies = [
"anyhow", "anyhow",
"async-trait", "async-trait",
"axum", "axum",
"base64",
"chacha20poly1305", "chacha20poly1305",
"chrono", "chrono",
"clap", "clap",

View file

@ -36,6 +36,9 @@ tracing-subscriber = { version = "0.3", default-features = false, features = ["f
# Observability - Prometheus metrics # Observability - Prometheus metrics
prometheus = { version = "0.13", default-features = false } prometheus = { version = "0.13", default-features = false }
# Base64 encoding (screenshots, image data)
base64 = "0.22"
# Error handling # Error handling
anyhow = "1.0" anyhow = "1.0"
thiserror = "2.0" thiserror = "2.0"

View file

@ -321,6 +321,14 @@ pub async fn run(
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.", "Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
), ),
]; ];
tool_descs.push((
"screenshot",
"Capture a screenshot of the current screen. Returns file path and base64-encoded PNG. Use when: visual verification, UI inspection, debugging displays.",
));
tool_descs.push((
"image_info",
"Read image file metadata (format, dimensions, size) and optionally base64-encode it. Use when: inspecting images, preparing visual data for analysis.",
));
if config.browser.enabled { if config.browser.enabled {
tool_descs.push(( tool_descs.push((
"browser_open", "browser_open",

View file

@ -154,25 +154,26 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
/// Factory: create the right provider from config /// Factory: create the right provider from config
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> { pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
let _resolved_key = resolve_api_key(name, api_key); let resolved_key = resolve_api_key(name, api_key);
let key = resolved_key.as_deref();
match name { match name {
// ── Primary providers (custom implementations) ─────── // ── Primary providers (custom implementations) ───────
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))), "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(api_key))), "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
"openai" => Ok(Box::new(openai::OpenAiProvider::new(api_key))), "openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
// Ollama is a local service that doesn't use API keys. // Ollama is a local service that doesn't use API keys.
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url. // The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))), "ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
"gemini" | "google" | "google-gemini" => { "gemini" | "google" | "google-gemini" => {
Ok(Box::new(gemini::GeminiProvider::new(api_key))) Ok(Box::new(gemini::GeminiProvider::new(key)))
} }
// ── OpenAI-compatible providers ────────────────────── // ── OpenAI-compatible providers ──────────────────────
"venice" => Ok(Box::new(OpenAiCompatibleProvider::new( "venice" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Venice", "https://api.venice.ai", api_key, AuthStyle::Bearer, "Venice", "https://api.venice.ai", key, AuthStyle::Bearer,
))), ))),
"vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Vercel AI Gateway", "https://api.vercel.ai", api_key, AuthStyle::Bearer, "Vercel AI Gateway", "https://api.vercel.ai", key, AuthStyle::Bearer,
))), ))),
"cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cloudflare AI Gateway", "Cloudflare AI Gateway",
@ -181,22 +182,22 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
AuthStyle::Bearer, AuthStyle::Bearer,
))), ))),
"moonshot" | "kimi" => Ok(Box::new(OpenAiCompatibleProvider::new( "moonshot" | "kimi" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Moonshot", "https://api.moonshot.cn", api_key, AuthStyle::Bearer, "Moonshot", "https://api.moonshot.cn", key, AuthStyle::Bearer,
))), ))),
"synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new( "synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Synthetic", "https://api.synthetic.com", api_key, AuthStyle::Bearer, "Synthetic", "https://api.synthetic.com", key, AuthStyle::Bearer,
))), ))),
"opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new( "opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new(
"OpenCode Zen", "https://api.opencode.ai", api_key, AuthStyle::Bearer, "OpenCode Zen", "https://api.opencode.ai", key, AuthStyle::Bearer,
))), ))),
"zai" | "z.ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "zai" | "z.ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Z.AI", "https://api.z.ai/api/coding/paas/v4", api_key, AuthStyle::Bearer, "Z.AI", "https://api.z.ai/api/coding/paas/v4", key, AuthStyle::Bearer,
))), ))),
"glm" | "zhipu" => Ok(Box::new(OpenAiCompatibleProvider::new( "glm" | "zhipu" => Ok(Box::new(OpenAiCompatibleProvider::new(
"GLM", "https://open.bigmodel.cn/api/paas", api_key, AuthStyle::Bearer, "GLM", "https://open.bigmodel.cn/api/paas", key, AuthStyle::Bearer,
))), ))),
"minimax" => Ok(Box::new(OpenAiCompatibleProvider::new( "minimax" => Ok(Box::new(OpenAiCompatibleProvider::new(
"MiniMax", "https://api.minimax.chat", api_key, AuthStyle::Bearer, "MiniMax", "https://api.minimax.chat", key, AuthStyle::Bearer,
))), ))),
"bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new( "bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Amazon Bedrock", "Amazon Bedrock",
@ -205,36 +206,36 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
AuthStyle::Bearer, AuthStyle::Bearer,
))), ))),
"qianfan" | "baidu" => Ok(Box::new(OpenAiCompatibleProvider::new( "qianfan" | "baidu" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Qianfan", "https://aip.baidubce.com", api_key, AuthStyle::Bearer, "Qianfan", "https://aip.baidubce.com", key, AuthStyle::Bearer,
))), ))),
// ── Extended ecosystem (community favorites) ───────── // ── Extended ecosystem (community favorites) ─────────
"groq" => Ok(Box::new(OpenAiCompatibleProvider::new( "groq" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Groq", "https://api.groq.com/openai", api_key, AuthStyle::Bearer, "Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
))), ))),
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new( "mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Mistral", "https://api.mistral.ai", api_key, AuthStyle::Bearer, "Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
))), ))),
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new( "xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
"xAI", "https://api.x.ai", api_key, AuthStyle::Bearer, "xAI", "https://api.x.ai", key, AuthStyle::Bearer,
))), ))),
"deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new( "deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new(
"DeepSeek", "https://api.deepseek.com", api_key, AuthStyle::Bearer, "DeepSeek", "https://api.deepseek.com", key, AuthStyle::Bearer,
))), ))),
"together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Together AI", "https://api.together.xyz", api_key, AuthStyle::Bearer, "Together AI", "https://api.together.xyz", key, AuthStyle::Bearer,
))), ))),
"fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new( "fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Fireworks AI", "https://api.fireworks.ai/inference", api_key, AuthStyle::Bearer, "Fireworks AI", "https://api.fireworks.ai/inference", key, AuthStyle::Bearer,
))), ))),
"perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new( "perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Perplexity", "https://api.perplexity.ai", api_key, AuthStyle::Bearer, "Perplexity", "https://api.perplexity.ai", key, AuthStyle::Bearer,
))), ))),
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new( "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cohere", "https://api.cohere.com/compatibility", api_key, AuthStyle::Bearer, "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
))), ))),
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new( "copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
"GitHub Copilot", "https://api.githubcopilot.com", api_key, AuthStyle::Bearer, "GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
))), ))),
// ── Bring Your Own Provider (custom URL) ─────────── // ── Bring Your Own Provider (custom URL) ───────────
@ -247,7 +248,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
Ok(Box::new(OpenAiCompatibleProvider::new( Ok(Box::new(OpenAiCompatibleProvider::new(
"Custom", "Custom",
base_url, base_url,
api_key, key,
AuthStyle::Bearer, AuthStyle::Bearer,
))) )))
} }
@ -260,7 +261,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
anyhow::bail!("Anthropic-custom provider requires a URL. Format: anthropic-custom:https://your-api.com"); anyhow::bail!("Anthropic-custom provider requires a URL. Format: anthropic-custom:https://your-api.com");
} }
Ok(Box::new(anthropic::AnthropicProvider::with_base_url( Ok(Box::new(anthropic::AnthropicProvider::with_base_url(
api_key, Some(base_url), key, Some(base_url),
))) )))
} }

491
src/tools/image_info.rs Normal file
View file

@ -0,0 +1,491 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::path::Path;
use std::sync::Arc;
/// Maximum file size we will read and base64-encode (5 MB).
const MAX_IMAGE_BYTES: u64 = 5_242_880;
/// Tool to read image metadata and optionally return base64-encoded data.
///
/// Since providers are currently text-only, this tool extracts what it can
/// (file size, format, dimensions from header bytes) and provides base64
/// data for future multimodal provider support.
pub struct ImageInfoTool {
security: Arc<SecurityPolicy>,
}
impl ImageInfoTool {
pub fn new(security: Arc<SecurityPolicy>) -> Self {
Self { security }
}
/// Detect image format from first few bytes (magic numbers).
fn detect_format(bytes: &[u8]) -> &'static str {
if bytes.len() < 4 {
return "unknown";
}
if bytes.starts_with(b"\x89PNG") {
"png"
} else if bytes.starts_with(b"\xFF\xD8\xFF") {
"jpeg"
} else if bytes.starts_with(b"GIF8") {
"gif"
} else if bytes.starts_with(b"RIFF") && bytes.len() >= 12 && &bytes[8..12] == b"WEBP" {
"webp"
} else if bytes.starts_with(b"BM") {
"bmp"
} else {
"unknown"
}
}
/// Try to extract dimensions from image header bytes.
/// Returns (width, height) if detectable.
fn extract_dimensions(bytes: &[u8], format: &str) -> Option<(u32, u32)> {
match format {
"png" => {
// PNG IHDR chunk: bytes 16-19 = width, 20-23 = height (big-endian)
if bytes.len() >= 24 {
let w = u32::from_be_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]);
let h = u32::from_be_bytes([bytes[20], bytes[21], bytes[22], bytes[23]]);
Some((w, h))
} else {
None
}
}
"gif" => {
// GIF: bytes 6-7 = width, 8-9 = height (little-endian)
if bytes.len() >= 10 {
let w = u32::from(u16::from_le_bytes([bytes[6], bytes[7]]));
let h = u32::from(u16::from_le_bytes([bytes[8], bytes[9]]));
Some((w, h))
} else {
None
}
}
"bmp" => {
// BMP: bytes 18-21 = width, 22-25 = height (little-endian, signed)
if bytes.len() >= 26 {
let w = u32::from_le_bytes([bytes[18], bytes[19], bytes[20], bytes[21]]);
let h_raw = i32::from_le_bytes([bytes[22], bytes[23], bytes[24], bytes[25]]);
let h = h_raw.unsigned_abs();
Some((w, h))
} else {
None
}
}
"jpeg" => Self::jpeg_dimensions(bytes),
_ => None,
}
}
/// Parse JPEG SOF markers to extract dimensions.
fn jpeg_dimensions(bytes: &[u8]) -> Option<(u32, u32)> {
let mut i = 2; // skip SOI marker
while i + 1 < bytes.len() {
if bytes[i] != 0xFF {
return None;
}
let marker = bytes[i + 1];
i += 2;
// SOF0..SOF3 markers contain dimensions
if (0xC0..=0xC3).contains(&marker) {
if i + 7 <= bytes.len() {
let h = u32::from(u16::from_be_bytes([bytes[i + 3], bytes[i + 4]]));
let w = u32::from(u16::from_be_bytes([bytes[i + 5], bytes[i + 6]]));
return Some((w, h));
}
return None;
}
// Skip this segment
if i + 1 < bytes.len() {
let seg_len = u16::from_be_bytes([bytes[i], bytes[i + 1]]) as usize;
if seg_len < 2 {
return None; // Malformed segment (valid segments have length >= 2)
}
i += seg_len;
} else {
return None;
}
}
None
}
}
#[async_trait]
impl Tool for ImageInfoTool {
fn name(&self) -> &str {
"image_info"
}
fn description(&self) -> &str {
"Read image file metadata (format, dimensions, size) and optionally return base64-encoded data."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the image file (absolute or relative to workspace)"
},
"include_base64": {
"type": "boolean",
"description": "Include base64-encoded image data in output (default: false)"
}
},
"required": ["path"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let path_str = args
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
let include_base64 = args
.get("include_base64")
.and_then(serde_json::Value::as_bool)
.unwrap_or(false);
let path = Path::new(path_str);
// Restrict reads to workspace directory to prevent arbitrary file exfiltration
if !self.security.is_path_allowed(path_str) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Path not allowed: {path_str} (must be within workspace)")),
});
}
if !path.exists() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("File not found: {path_str}")),
});
}
let metadata = tokio::fs::metadata(path)
.await
.map_err(|e| anyhow::anyhow!("Failed to read file metadata: {e}"))?;
let file_size = metadata.len();
if file_size > MAX_IMAGE_BYTES {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Image too large: {file_size} bytes (max {MAX_IMAGE_BYTES} bytes)"
)),
});
}
let bytes = tokio::fs::read(path)
.await
.map_err(|e| anyhow::anyhow!("Failed to read image file: {e}"))?;
let format = Self::detect_format(&bytes);
let dimensions = Self::extract_dimensions(&bytes, format);
let mut output = format!("File: {path_str}\nFormat: {format}\nSize: {file_size} bytes");
if let Some((w, h)) = dimensions {
let _ = write!(output, "\nDimensions: {w}x{h}");
}
if include_base64 {
use base64::Engine;
let encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
let mime = match format {
"png" => "image/png",
"jpeg" => "image/jpeg",
"gif" => "image/gif",
"webp" => "image/webp",
"bmp" => "image/bmp",
_ => "application/octet-stream",
};
let _ = write!(output, "\ndata:{mime};base64,{encoded}");
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security() -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Full,
workspace_dir: std::env::temp_dir(),
workspace_only: false,
forbidden_paths: vec![],
..SecurityPolicy::default()
})
}
#[test]
fn image_info_tool_name() {
let tool = ImageInfoTool::new(test_security());
assert_eq!(tool.name(), "image_info");
}
#[test]
fn image_info_tool_description() {
let tool = ImageInfoTool::new(test_security());
assert!(!tool.description().is_empty());
assert!(tool.description().contains("image"));
}
#[test]
fn image_info_tool_schema() {
let tool = ImageInfoTool::new(test_security());
let schema = tool.parameters_schema();
assert!(schema["properties"]["path"].is_object());
assert!(schema["properties"]["include_base64"].is_object());
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&json!("path")));
}
#[test]
fn image_info_tool_spec() {
let tool = ImageInfoTool::new(test_security());
let spec = tool.spec();
assert_eq!(spec.name, "image_info");
assert!(spec.parameters.is_object());
}
// ── Format detection ────────────────────────────────────────
#[test]
fn detect_png() {
let bytes = b"\x89PNG\r\n\x1a\n";
assert_eq!(ImageInfoTool::detect_format(bytes), "png");
}
#[test]
fn detect_jpeg() {
let bytes = b"\xFF\xD8\xFF\xE0";
assert_eq!(ImageInfoTool::detect_format(bytes), "jpeg");
}
#[test]
fn detect_gif() {
let bytes = b"GIF89a";
assert_eq!(ImageInfoTool::detect_format(bytes), "gif");
}
#[test]
fn detect_webp() {
let bytes = b"RIFF\x00\x00\x00\x00WEBP";
assert_eq!(ImageInfoTool::detect_format(bytes), "webp");
}
#[test]
fn detect_bmp() {
let bytes = b"BM\x00\x00";
assert_eq!(ImageInfoTool::detect_format(bytes), "bmp");
}
#[test]
fn detect_unknown_short() {
let bytes = b"\x00\x01";
assert_eq!(ImageInfoTool::detect_format(bytes), "unknown");
}
#[test]
fn detect_unknown_garbage() {
let bytes = b"this is not an image";
assert_eq!(ImageInfoTool::detect_format(bytes), "unknown");
}
// ── Dimension extraction ────────────────────────────────────
#[test]
fn png_dimensions() {
// Minimal PNG IHDR: 8-byte signature + 4-byte length + 4-byte IHDR + 4-byte width + 4-byte height
let mut bytes = vec![
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // PNG signature
0x00, 0x00, 0x00, 0x0D, // IHDR length
0x49, 0x48, 0x44, 0x52, // "IHDR"
0x00, 0x00, 0x03, 0x20, // width: 800
0x00, 0x00, 0x02, 0x58, // height: 600
];
bytes.extend_from_slice(&[0u8; 10]); // padding
let dims = ImageInfoTool::extract_dimensions(&bytes, "png");
assert_eq!(dims, Some((800, 600)));
}
#[test]
fn gif_dimensions() {
let bytes = [
0x47, 0x49, 0x46, 0x38, 0x39, 0x61, // GIF89a
0x40, 0x01, // width: 320 (LE)
0xF0, 0x00, // height: 240 (LE)
];
let dims = ImageInfoTool::extract_dimensions(&bytes, "gif");
assert_eq!(dims, Some((320, 240)));
}
#[test]
fn bmp_dimensions() {
let mut bytes = vec![0u8; 26];
bytes[0] = b'B';
bytes[1] = b'M';
// width at offset 18 (LE): 1024
bytes[18] = 0x00;
bytes[19] = 0x04;
bytes[20] = 0x00;
bytes[21] = 0x00;
// height at offset 22 (LE): 768
bytes[22] = 0x00;
bytes[23] = 0x03;
bytes[24] = 0x00;
bytes[25] = 0x00;
let dims = ImageInfoTool::extract_dimensions(&bytes, "bmp");
assert_eq!(dims, Some((1024, 768)));
}
#[test]
fn jpeg_dimensions() {
// Minimal JPEG-like byte sequence with SOF0 marker
let mut bytes: Vec<u8> = vec![
0xFF, 0xD8, // SOI
0xFF, 0xE0, // APP0 marker
0x00, 0x10, // APP0 length = 16
];
bytes.extend_from_slice(&[0u8; 14]); // APP0 payload
bytes.extend_from_slice(&[
0xFF, 0xC0, // SOF0 marker
0x00, 0x11, // SOF0 length
0x08, // precision
0x01, 0xE0, // height: 480
0x02, 0x80, // width: 640
]);
let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg");
assert_eq!(dims, Some((640, 480)));
}
#[test]
fn jpeg_malformed_zero_length_segment() {
// Zero-length segment should return None instead of looping forever
let bytes: Vec<u8> = vec![
0xFF, 0xD8, // SOI
0xFF, 0xE0, // APP0 marker
0x00, 0x00, // length = 0 (malformed)
];
let dims = ImageInfoTool::extract_dimensions(&bytes, "jpeg");
assert!(dims.is_none());
}
#[test]
fn unknown_format_no_dimensions() {
let bytes = b"random data here";
let dims = ImageInfoTool::extract_dimensions(bytes, "unknown");
assert!(dims.is_none());
}
// ── Execute tests ───────────────────────────────────────────
#[tokio::test]
async fn execute_missing_path() {
let tool = ImageInfoTool::new(test_security());
let result = tool.execute(json!({})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn execute_nonexistent_file() {
let tool = ImageInfoTool::new(test_security());
let result = tool
.execute(json!({"path": "/tmp/nonexistent_image_xyz.png"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not found"));
}
#[tokio::test]
async fn execute_real_file() {
// Create a minimal valid PNG
let dir = std::env::temp_dir().join("zeroclaw_image_info_test");
let _ = std::fs::create_dir_all(&dir);
let png_path = dir.join("test.png");
// Minimal 1x1 red PNG (67 bytes)
let png_bytes: Vec<u8> = vec![
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, // signature
0x00, 0x00, 0x00, 0x0D, // IHDR length
0x49, 0x48, 0x44, 0x52, // IHDR
0x00, 0x00, 0x00, 0x01, // width: 1
0x00, 0x00, 0x00, 0x01, // height: 1
0x08, 0x02, 0x00, 0x00, 0x00, // bit depth, color type, etc.
0x90, 0x77, 0x53, 0xDE, // CRC
0x00, 0x00, 0x00, 0x0C, // IDAT length
0x49, 0x44, 0x41, 0x54, // IDAT
0x08, 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21,
0xBC, 0x33, // CRC
0x00, 0x00, 0x00, 0x00, // IEND length
0x49, 0x45, 0x4E, 0x44, // IEND
0xAE, 0x42, 0x60, 0x82, // CRC
];
std::fs::write(&png_path, &png_bytes).unwrap();
let tool = ImageInfoTool::new(test_security());
let result = tool
.execute(json!({"path": png_path.to_string_lossy()}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Format: png"));
assert!(result.output.contains("Dimensions: 1x1"));
assert!(!result.output.contains("data:"));
// Clean up
let _ = std::fs::remove_dir_all(&dir);
}
#[tokio::test]
async fn execute_with_base64() {
let dir = std::env::temp_dir().join("zeroclaw_image_info_b64");
let _ = std::fs::create_dir_all(&dir);
let png_path = dir.join("test_b64.png");
// Minimal 1x1 PNG
let png_bytes: Vec<u8> = vec![
0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A, 0x00, 0x00, 0x00, 0x0D, 0x49, 0x48,
0x44, 0x52, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x08, 0x02, 0x00, 0x00,
0x00, 0x90, 0x77, 0x53, 0xDE, 0x00, 0x00, 0x00, 0x0C, 0x49, 0x44, 0x41, 0x54, 0x08,
0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, 0xBC,
0x33, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82,
];
std::fs::write(&png_path, &png_bytes).unwrap();
let tool = ImageInfoTool::new(test_security());
let result = tool
.execute(json!({"path": png_path.to_string_lossy(), "include_base64": true}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("data:image/png;base64,"));
let _ = std::fs::remove_dir_all(&dir);
}
}

View file

@ -3,9 +3,11 @@ pub mod browser_open;
pub mod composio; pub mod composio;
pub mod file_read; pub mod file_read;
pub mod file_write; pub mod file_write;
pub mod image_info;
pub mod memory_forget; pub mod memory_forget;
pub mod memory_recall; pub mod memory_recall;
pub mod memory_store; pub mod memory_store;
pub mod screenshot;
pub mod shell; pub mod shell;
pub mod traits; pub mod traits;
@ -14,9 +16,11 @@ pub use browser_open::BrowserOpenTool;
pub use composio::ComposioTool; pub use composio::ComposioTool;
pub use file_read::FileReadTool; pub use file_read::FileReadTool;
pub use file_write::FileWriteTool; pub use file_write::FileWriteTool;
pub use image_info::ImageInfoTool;
pub use memory_forget::MemoryForgetTool; pub use memory_forget::MemoryForgetTool;
pub use memory_recall::MemoryRecallTool; pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool; pub use memory_store::MemoryStoreTool;
pub use screenshot::ScreenshotTool;
pub use shell::ShellTool; pub use shell::ShellTool;
pub use traits::Tool; pub use traits::Tool;
#[allow(unused_imports)] #[allow(unused_imports)]
@ -91,6 +95,10 @@ pub fn all_tools_with_runtime(
))); )));
} }
// Vision tools are always available
tools.push(Box::new(ScreenshotTool::new(security.clone())));
tools.push(Box::new(ImageInfoTool::new(security.clone())));
if let Some(key) = composio_key { if let Some(key) = composio_key {
if !key.is_empty() { if !key.is_empty() {
tools.push(Box::new(ComposioTool::new(key))); tools.push(Box::new(ComposioTool::new(key)));

300
src/tools/screenshot.rs Normal file
View file

@ -0,0 +1,300 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
/// Maximum time to wait for a screenshot command to complete.
const SCREENSHOT_TIMEOUT_SECS: u64 = 15;
/// Maximum base64 payload size to return (2 MB of base64 ≈ 1.5 MB image).
const MAX_BASE64_BYTES: usize = 2_097_152;
/// Tool for capturing screenshots using platform-native commands.
///
/// macOS: `screencapture`
/// Linux: tries `gnome-screenshot`, `scrot`, `import` (`ImageMagick`) in order.
pub struct ScreenshotTool {
security: Arc<SecurityPolicy>,
}
impl ScreenshotTool {
pub fn new(security: Arc<SecurityPolicy>) -> Self {
Self { security }
}
/// Determine the screenshot command for the current platform.
fn screenshot_command(output_path: &str) -> Option<Vec<String>> {
if cfg!(target_os = "macos") {
Some(vec![
"screencapture".into(),
"-x".into(), // no sound
output_path.into(),
])
} else if cfg!(target_os = "linux") {
Some(vec![
"sh".into(),
"-c".into(),
format!(
"if command -v gnome-screenshot >/dev/null 2>&1; then \
gnome-screenshot -f '{output_path}'; \
elif command -v scrot >/dev/null 2>&1; then \
scrot '{output_path}'; \
elif command -v import >/dev/null 2>&1; then \
import -window root '{output_path}'; \
else \
echo 'NO_SCREENSHOT_TOOL' >&2; exit 1; \
fi"
),
])
} else {
None
}
}
/// Execute the screenshot capture and return the result.
async fn capture(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let timestamp = chrono::Utc::now().format("%Y%m%d_%H%M%S");
let filename = args
.get("filename")
.and_then(|v| v.as_str())
.map_or_else(|| format!("screenshot_{timestamp}.png"), String::from);
// Sanitize filename to prevent path traversal
let safe_name = PathBuf::from(&filename).file_name().map_or_else(
|| format!("screenshot_{timestamp}.png"),
|n| n.to_string_lossy().to_string(),
);
let output_path = self.security.workspace_dir.join(&safe_name);
let output_str = output_path.to_string_lossy().to_string();
let Some(mut cmd_args) = Self::screenshot_command(&output_str) else {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Screenshot not supported on this platform".into()),
});
};
// macOS region flags
if cfg!(target_os = "macos") {
if let Some(region) = args.get("region").and_then(|v| v.as_str()) {
match region {
"selection" => cmd_args.insert(1, "-s".into()),
"window" => cmd_args.insert(1, "-w".into()),
_ => {} // ignore unknown regions
}
}
}
let program = cmd_args.remove(0);
let result = tokio::time::timeout(
Duration::from_secs(SCREENSHOT_TIMEOUT_SECS),
tokio::process::Command::new(&program)
.args(&cmd_args)
.output(),
)
.await;
match result {
Ok(Ok(output)) => {
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
if stderr.contains("NO_SCREENSHOT_TOOL") {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(
"No screenshot tool found. Install gnome-screenshot, scrot, or ImageMagick."
.into(),
),
});
}
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Screenshot command failed: {stderr}")),
});
}
Self::read_and_encode(&output_path).await
}
Ok(Err(e)) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to execute screenshot command: {e}")),
}),
Err(_) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Screenshot timed out after {SCREENSHOT_TIMEOUT_SECS}s"
)),
}),
}
}
/// Read the screenshot file and return base64-encoded result.
async fn read_and_encode(output_path: &std::path::Path) -> anyhow::Result<ToolResult> {
// Check file size before reading to prevent OOM on large screenshots
const MAX_RAW_BYTES: u64 = 1_572_864; // ~1.5 MB (base64 expands ~33%)
if let Ok(meta) = tokio::fs::metadata(output_path).await {
if meta.len() > MAX_RAW_BYTES {
return Ok(ToolResult {
success: true,
output: format!(
"Screenshot saved to: {}\nSize: {} bytes (too large to base64-encode inline)",
output_path.display(),
meta.len(),
),
error: None,
});
}
}
match tokio::fs::read(output_path).await {
Ok(bytes) => {
use base64::Engine;
let size = bytes.len();
let mut encoded = base64::engine::general_purpose::STANDARD.encode(&bytes);
let truncated = if encoded.len() > MAX_BASE64_BYTES {
encoded.truncate(encoded.floor_char_boundary(MAX_BASE64_BYTES));
true
} else {
false
};
let mut output_msg = format!(
"Screenshot saved to: {}\nSize: {size} bytes\nBase64 length: {}",
output_path.display(),
encoded.len(),
);
if truncated {
output_msg.push_str(" (truncated)");
}
let mime = match output_path.extension().and_then(|e| e.to_str()) {
Some("jpg" | "jpeg") => "image/jpeg",
Some("bmp") => "image/bmp",
Some("gif") => "image/gif",
Some("webp") => "image/webp",
_ => "image/png",
};
let _ = write!(output_msg, "\ndata:{mime};base64,{encoded}");
Ok(ToolResult {
success: true,
output: output_msg,
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
output: format!("Screenshot saved to: {}", output_path.display()),
error: Some(format!("Failed to read screenshot file: {e}")),
}),
}
}
}
#[async_trait]
impl Tool for ScreenshotTool {
fn name(&self) -> &str {
"screenshot"
}
fn description(&self) -> &str {
"Capture a screenshot of the current screen. Returns the file path and base64-encoded PNG data."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"filename": {
"type": "string",
"description": "Optional filename (default: screenshot_<timestamp>.png). Saved in workspace."
},
"region": {
"type": "string",
"description": "Optional region for macOS: 'selection' for interactive crop, 'window' for front window. Ignored on Linux."
}
}
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
if !self.security.can_act() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Action blocked: autonomy is read-only".into()),
});
}
self.capture(args).await
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security() -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Full,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
})
}
#[test]
fn screenshot_tool_name() {
let tool = ScreenshotTool::new(test_security());
assert_eq!(tool.name(), "screenshot");
}
#[test]
fn screenshot_tool_description() {
let tool = ScreenshotTool::new(test_security());
assert!(!tool.description().is_empty());
assert!(tool.description().contains("screenshot"));
}
#[test]
fn screenshot_tool_schema() {
let tool = ScreenshotTool::new(test_security());
let schema = tool.parameters_schema();
assert!(schema["properties"]["filename"].is_object());
assert!(schema["properties"]["region"].is_object());
}
#[test]
fn screenshot_tool_spec() {
let tool = ScreenshotTool::new(test_security());
let spec = tool.spec();
assert_eq!(spec.name, "screenshot");
assert!(spec.parameters.is_object());
}
#[test]
#[cfg(any(target_os = "macos", target_os = "linux"))]
fn screenshot_command_exists() {
let cmd = ScreenshotTool::screenshot_command("/tmp/test.png");
assert!(cmd.is_some());
let args = cmd.unwrap();
assert!(!args.is_empty());
}
#[test]
fn screenshot_command_contains_output_path() {
let cmd = ScreenshotTool::screenshot_command("/tmp/my_screenshot.png").unwrap();
let joined = cmd.join(" ");
assert!(
joined.contains("/tmp/my_screenshot.png"),
"Command should contain the output path"
);
}
}