diff --git a/docs/channels-reference.md b/docs/channels-reference.md index 2ab904e..a25050c 100644 --- a/docs/channels-reference.md +++ b/docs/channels-reference.md @@ -51,6 +51,22 @@ Notes: - Model cache previews come from `zeroclaw models refresh --provider <id>`. - These are runtime chat commands, not CLI subcommands. +## Inbound Image Marker Protocol + +ZeroClaw supports multimodal input through inline message markers: + +- Syntax: ``[IMAGE:<source>]`` +- ``<source>`` can be: +  - Local file path +  - Data URI (`data:image/...;base64,...`) +  - Remote URL only when `[multimodal].allow_remote_fetch = true` + +Operational notes: + +- Marker parsing applies to user-role messages before provider calls. +- Provider capability is enforced at runtime: if the selected provider does not support vision, the request fails with a structured capability error (`capability=vision`). +- Linq webhook `media` parts with `image/*` MIME type are automatically converted to this marker format. + ## Channel Matrix --- @@ -349,4 +365,3 @@ If a specific channel task crashes or exits, the channel supervisor in `channels - `Channel message worker crashed:` These messages indicate automatic restart behavior is active, and you should inspect preceding logs for root cause. - diff --git a/docs/config-reference.md b/docs/config-reference.md index 2b8d87f..3635878 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -62,6 +62,24 @@ Notes: - `reasoning_enabled = true` explicitly requests reasoning for supported providers (`think: true` on `ollama`). - Unset keeps provider defaults. +## `[multimodal]` + +| Key | Default | Purpose | +|---|---|---| +| `max_images` | `4` | Maximum image markers accepted per request | +| `max_image_size_mb` | `5` | Per-image size limit before base64 encoding | +| `allow_remote_fetch` | `false` | Allow fetching `http(s)` image URLs from markers | + +Notes: + +- Runtime accepts image markers in user messages with syntax: ``[IMAGE:<source>]``.
+- Supported sources: +  - Local file path (for example ``[IMAGE:/tmp/screenshot.png]``) +  - Data URI (for example ``[IMAGE:data:image/png;base64,...]``) +  - Remote URL only when `allow_remote_fetch = true` +- Allowed MIME types: `image/png`, `image/jpeg`, `image/webp`, `image/gif`, `image/bmp`. +- When the active provider does not support vision, requests fail with a structured capability error (`capability=vision`) instead of silently dropping images. + ## `[gateway]` | Key | Default | Purpose | diff --git a/docs/providers-reference.md b/docs/providers-reference.md index 790ee76..9fa2cca 100644 --- a/docs/providers-reference.md +++ b/docs/providers-reference.md @@ -56,6 +56,13 @@ credential is not reused for fallback providers. | `lmstudio` | `lm-studio` | Yes | (optional; local by default) | | `nvidia` | `nvidia-nim`, `build.nvidia.com` | No | `NVIDIA_API_KEY` | +### Ollama Vision Notes + +- Provider ID: `ollama` +- Vision input is supported through user message image markers: ``[IMAGE:<source>]``. +- After multimodal normalization, ZeroClaw sends image payloads through Ollama's native `messages[].images` field. +- If a non-vision provider is selected, ZeroClaw returns a structured capability error instead of silently ignoring images.
+ ### Bedrock Notes - Provider ID: `bedrock` (alias: `aws-bedrock`) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 99b0ce7..c26f970 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,8 +1,11 @@ use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; +use crate::multimodal; use crate::observability::{self, Observer, ObserverEvent}; -use crate::providers::{self, ChatMessage, ChatRequest, Provider, ToolCall}; +use crate::providers::{ + self, ChatMessage, ChatRequest, Provider, ProviderCapabilityError, ToolCall, +}; use crate::runtime; use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; @@ -826,6 +829,7 @@ pub(crate) async fn agent_turn( model: &str, temperature: f64, silent: bool, + multimodal_config: &crate::config::MultimodalConfig, max_tool_iterations: usize, ) -> Result { run_tool_call_loop( @@ -839,6 +843,7 @@ pub(crate) async fn agent_turn( silent, None, "channel", + multimodal_config, max_tool_iterations, None, ) @@ -859,6 +864,7 @@ pub(crate) async fn run_tool_call_loop( silent: bool, approval: Option<&ApprovalManager>, channel_name: &str, + multimodal_config: &crate::config::MultimodalConfig, max_tool_iterations: usize, on_delta: Option>, ) -> Result { @@ -873,6 +879,21 @@ pub(crate) async fn run_tool_call_loop( let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty(); for _iteration in 0..max_iterations { + let image_marker_count = multimodal::count_image_markers(history); + if image_marker_count > 0 && !provider.supports_vision() { + return Err(ProviderCapabilityError { + provider: provider_name.to_string(), + capability: "vision".to_string(), + message: format!( + "received {image_marker_count} image marker(s), but this provider does not support vision input" + ), + } + .into()); + } + + let prepared_messages = + multimodal::prepare_messages_for_provider(history, multimodal_config).await?; + 
observer.record_event(&ObserverEvent::LlmRequest { provider: provider_name.to_string(), model: model.to_string(), @@ -893,7 +914,7 @@ pub(crate) async fn run_tool_call_loop( match provider .chat( ChatRequest { - messages: history, + messages: &prepared_messages.messages, tools: request_tools, }, model, @@ -1404,6 +1425,7 @@ pub async fn run( false, Some(&approval_manager), "cli", + &config.multimodal, config.agent.max_tool_iterations, None, ) @@ -1530,6 +1552,7 @@ pub async fn run( false, Some(&approval_manager), "cli", + &config.multimodal, config.agent.max_tool_iterations, None, ) @@ -1757,6 +1780,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { &model_name, config.default_temperature, true, + &config.multimodal, config.agent.max_tool_iterations, ) .await @@ -1765,6 +1789,10 @@ pub async fn process_message(config: Config, message: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use async_trait::async_trait; + use base64::{engine::general_purpose::STANDARD, Engine as _}; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; #[test] fn test_scrub_credentials() { @@ -1785,8 +1813,191 @@ mod tests { assert!(scrubbed.contains("public")); } use crate::memory::{Memory, MemoryCategory, SqliteMemory}; + use crate::observability::NoopObserver; + use crate::providers::traits::ProviderCapabilities; + use crate::providers::ChatResponse; use tempfile::TempDir; + struct NonVisionProvider { + calls: Arc, + } + + #[async_trait] + impl Provider for NonVisionProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("ok".to_string()) + } + } + + struct VisionProvider { + calls: Arc, + } + + #[async_trait] + impl Provider for VisionProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: false, + vision: true, + } + } + + 
async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("ok".to_string()) + } + + async fn chat( + &self, + request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + let marker_count = crate::multimodal::count_image_markers(request.messages); + if marker_count == 0 { + anyhow::bail!("expected image markers in request messages"); + } + + if request.tools.is_some() { + anyhow::bail!("no tools should be attached for this test"); + } + + Ok(ChatResponse { + text: Some("vision-ok".to_string()), + tool_calls: Vec::new(), + }) + } + } + + #[tokio::test] + async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = NonVisionProvider { + calls: Arc::clone(&calls), + }; + + let mut history = vec![ChatMessage::user( + "please inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(), + )]; + let tools_registry: Vec> = Vec::new(); + let observer = NoopObserver; + + let err = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "cli", + &crate::config::MultimodalConfig::default(), + 3, + None, + ) + .await + .expect_err("provider without vision support should fail"); + + assert!(err.to_string().contains("provider_capability_error")); + assert!(err.to_string().contains("capability=vision")); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn run_tool_call_loop_rejects_oversized_image_payload() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = VisionProvider { + calls: Arc::clone(&calls), + }; + + let oversized_payload = STANDARD.encode(vec![0_u8; (1024 * 1024) + 1]); + let mut history = vec![ChatMessage::user(format!( + 
"[IMAGE:data:image/png;base64,{oversized_payload}]" + ))]; + + let tools_registry: Vec> = Vec::new(); + let observer = NoopObserver; + let multimodal = crate::config::MultimodalConfig { + max_images: 4, + max_image_size_mb: 1, + allow_remote_fetch: false, + }; + + let err = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "cli", + &multimodal, + 3, + None, + ) + .await + .expect_err("oversized payload must fail"); + + assert!(err + .to_string() + .contains("multimodal image size limit exceeded")); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn run_tool_call_loop_accepts_valid_multimodal_request_flow() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = VisionProvider { + calls: Arc::clone(&calls), + }; + + let mut history = vec![ChatMessage::user( + "Analyze this [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(), + )]; + let tools_registry: Vec> = Vec::new(); + let observer = NoopObserver; + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "cli", + &crate::config::MultimodalConfig::default(), + 3, + None, + ) + .await + .expect("valid multimodal payload should pass"); + + assert_eq!(result, "vision-ok"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + #[test] fn parse_tool_calls_extracts_single_call() { let response = r#"Let me check that. 
diff --git a/src/channels/linq.rs b/src/channels/linq.rs index d3361c9..228d43e 100644 --- a/src/channels/linq.rs +++ b/src/channels/linq.rs @@ -37,6 +37,28 @@ impl LinqChannel { &self.from_phone } + fn media_part_to_image_marker(part: &serde_json::Value) -> Option { + let source = part + .get("url") + .or_else(|| part.get("value")) + .and_then(|value| value.as_str()) + .map(str::trim) + .filter(|value| !value.is_empty())?; + + let mime_type = part + .get("mime_type") + .and_then(|value| value.as_str()) + .map(str::trim) + .unwrap_or_default() + .to_ascii_lowercase(); + + if !mime_type.starts_with("image/") { + return None; + } + + Some(format!("[IMAGE:{source}]")) + } + /// Parse an incoming webhook payload from Linq and extract messages. /// /// Linq webhook envelope: @@ -124,25 +146,36 @@ impl LinqChannel { return messages; }; - let text_parts: Vec<&str> = parts + let content_parts: Vec = parts .iter() .filter_map(|part| { let part_type = part.get("type").and_then(|t| t.as_str())?; - if part_type == "text" { - part.get("value").and_then(|v| v.as_str()) - } else { - // Skip media parts for now - tracing::debug!("Linq: skipping {part_type} part"); - None + match part_type { + "text" => part + .get("value") + .and_then(|v| v.as_str()) + .map(ToString::to_string), + "media" | "image" => { + if let Some(marker) = Self::media_part_to_image_marker(part) { + Some(marker) + } else { + tracing::debug!("Linq: skipping unsupported {part_type} part"); + None + } + } + _ => { + tracing::debug!("Linq: skipping {part_type} part"); + None + } } }) .collect(); - if text_parts.is_empty() { + if content_parts.is_empty() { return messages; } - let content = text_parts.join("\n"); + let content = content_parts.join("\n").trim().to_string(); if content.is_empty() { return messages; @@ -496,7 +529,7 @@ mod tests { } #[test] - fn linq_parse_media_only_skipped() { + fn linq_parse_media_only_translated_to_image_marker() { let ch = LinqChannel::new("tok".into(), "+15551234567".into(), 
vec!["*".into()]); let payload = serde_json::json!({ "event_type": "message.received", @@ -516,7 +549,32 @@ mod tests { }); let msgs = ch.parse_webhook_payload(&payload); - assert!(msgs.is_empty(), "Media-only messages should be skipped"); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "[IMAGE:https://example.com/image.jpg]"); + } + + #[test] + fn linq_parse_media_non_image_still_skipped() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [{ + "type": "media", + "url": "https://example.com/sound.mp3", + "mime_type": "audio/mpeg" + }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty(), "Non-image media should still be skipped"); } #[test] diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 023064a..c6a58af 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -139,6 +139,7 @@ struct ChannelRuntimeContext { provider_runtime_options: providers::ProviderRuntimeOptions, workspace_dir: Arc, message_timeout_secs: u64, + multimodal: crate::config::MultimodalConfig, } fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { @@ -810,6 +811,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C true, None, msg.channel.as_str(), + &ctx.multimodal, ctx.max_tool_iterations, delta_tx, ), @@ -2062,6 +2064,7 @@ pub async fn start_channels(config: Config) -> Result<()> { provider_runtime_options, workspace_dir: Arc::new(config.workspace_dir.clone()), message_timeout_secs, + multimodal: config.multimodal.clone(), }); run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; @@ -2559,6 +2562,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: 
CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2613,6 +2617,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2676,6 +2681,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2760,6 +2766,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2820,6 +2827,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2875,6 +2883,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2981,6 +2990,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); let (tx, rx) = tokio::sync::mpsc::channel::(4); @@ -3054,6 +3064,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), 
workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -3451,6 +3462,7 @@ mod tests { provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( diff --git a/src/config/mod.rs b/src/config/mod.rs index 72fbbf0..6bc5d71 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -9,11 +9,11 @@ pub use schema::{ DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EmbeddingRouteConfig, GatewayConfig, HardwareConfig, HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, - PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QueryClassificationConfig, - ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, - SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, StorageConfig, - StorageProviderConfig, StorageProviderSection, StreamMode, TelegramConfig, TunnelConfig, - WebSearchConfig, WebhookConfig, + MultimodalConfig, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, + ProxyScope, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, + SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, + StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, TelegramConfig, + TunnelConfig, WebSearchConfig, WebhookConfig, }; #[cfg(test)] diff --git a/src/config/schema.rs b/src/config/schema.rs index 6fd28d7..f1fbb4f 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -124,6 +124,9 @@ pub struct Config { #[serde(default)] pub http_request: HttpRequestConfig, + 
#[serde(default)] + pub multimodal: MultimodalConfig, + #[serde(default)] pub web_search: WebSearchConfig, @@ -284,6 +287,46 @@ impl Default for AgentConfig { } } +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct MultimodalConfig { + /// Maximum number of image attachments accepted per request. + #[serde(default = "default_multimodal_max_images")] + pub max_images: usize, + /// Maximum image payload size in MiB before base64 encoding. + #[serde(default = "default_multimodal_max_image_size_mb")] + pub max_image_size_mb: usize, + /// Allow fetching remote image URLs (http/https). Disabled by default. + #[serde(default)] + pub allow_remote_fetch: bool, +} + +fn default_multimodal_max_images() -> usize { + 4 +} + +fn default_multimodal_max_image_size_mb() -> usize { + 5 +} + +impl MultimodalConfig { + /// Clamp configured values to safe runtime bounds. + pub fn effective_limits(&self) -> (usize, usize) { + let max_images = self.max_images.clamp(1, 16); + let max_image_size_mb = self.max_image_size_mb.clamp(1, 20); + (max_images, max_image_size_mb) + } +} + +impl Default for MultimodalConfig { + fn default() -> Self { + Self { + max_images: default_multimodal_max_images(), + max_image_size_mb: default_multimodal_max_image_size_mb(), + allow_remote_fetch: false, + } + } +} + // ── Identity (AIEOS / OpenClaw format) ────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] @@ -2534,6 +2577,7 @@ impl Default for Config { secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + multimodal: MultimodalConfig::default(), web_search: WebSearchConfig::default(), proxy: ProxyConfig::default(), identity: IdentityConfig::default(), @@ -3502,6 +3546,7 @@ default_temperature = 0.7 secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + multimodal: MultimodalConfig::default(), web_search: WebSearchConfig::default(), 
proxy: ProxyConfig::default(), agent: AgentConfig::default(), @@ -3656,6 +3701,7 @@ tool_dispatcher = "xml" secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + multimodal: MultimodalConfig::default(), web_search: WebSearchConfig::default(), proxy: ProxyConfig::default(), agent: AgentConfig::default(), diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index db55c00..2f56909 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,7 +10,7 @@ use crate::channels::{Channel, LinqChannel, SendMessage, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; -use crate::providers::{self, Provider}; +use crate::providers::{self, ChatMessage, Provider, ProviderCapabilityError}; use crate::runtime; use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; use crate::security::SecurityPolicy; @@ -666,6 +666,52 @@ async fn persist_pairing_tokens(config: Arc>, pairing: &PairingGua Ok(()) } +async fn run_gateway_chat_with_multimodal( + state: &AppState, + provider_label: &str, + message: &str, +) -> anyhow::Result { + let user_messages = vec![ChatMessage::user(message)]; + let image_marker_count = crate::multimodal::count_image_markers(&user_messages); + if image_marker_count > 0 && !state.provider.supports_vision() { + return Err(ProviderCapabilityError { + provider: provider_label.to_string(), + capability: "vision".to_string(), + message: format!( + "received {image_marker_count} image marker(s), but this provider does not support vision input" + ), + } + .into()); + } + + // Keep webhook/gateway prompts aligned with channel behavior by injecting + // workspace-aware system context before model invocation. 
+ let system_prompt = { + let config_guard = state.config.lock(); + crate::channels::build_system_prompt( + &config_guard.workspace_dir, + &state.model, + &[], // tools - empty for simple chat + &[], // skills + Some(&config_guard.identity), + None, // bootstrap_max_chars - use default + ) + }; + + let mut messages = Vec::with_capacity(1 + user_messages.len()); + messages.push(ChatMessage::system(system_prompt)); + messages.extend(user_messages); + + let multimodal_config = state.config.lock().multimodal.clone(); + let prepared = + crate::multimodal::prepare_messages_for_provider(&messages, &multimodal_config).await?; + + state + .provider + .chat_with_history(&prepared.messages, &state.model, state.temperature) + .await +} + /// Webhook request body #[derive(serde::Deserialize)] pub struct WebhookBody { @@ -787,30 +833,7 @@ async fn handle_webhook( messages_count: 1, }); - // Build system prompt with workspace context (IDENTITY.md, AGENTS.md, etc.) - let system_prompt = { - let config_guard = state.config.lock(); - crate::channels::build_system_prompt( - &config_guard.workspace_dir, - &state.model, - &[], // tools - empty for simple chat - &[], // skills - Some(&config_guard.identity), - None, // bootstrap_max_chars - use default - ) - }; - - // Call the LLM with separate system prompt - match state - .provider - .chat_with_system( - Some(&system_prompt), - message, - &state.model, - state.temperature, - ) - .await - { + match run_gateway_chat_with_multimodal(&state, &provider_label, message).await { Ok(response) => { let duration = started_at.elapsed(); state @@ -994,6 +1017,12 @@ async fn handle_whatsapp_message( } // Process each message + let provider_label = state + .config + .lock() + .default_provider + .clone() + .unwrap_or_else(|| "unknown".to_string()); for msg in &messages { tracing::info!( "WhatsApp message from {}: {}", @@ -1010,30 +1039,7 @@ async fn handle_whatsapp_message( .await; } - // Build system prompt with workspace context (IDENTITY.md, 
AGENTS.md, etc.) - let system_prompt = { - let config_guard = state.config.lock(); - crate::channels::build_system_prompt( - &config_guard.workspace_dir, - &state.model, - &[], // tools - empty for simple chat - &[], // skills - Some(&config_guard.identity), - None, // bootstrap_max_chars - use default - ) - }; - - // Call the LLM with separate system prompt - match state - .provider - .chat_with_system( - Some(&system_prompt), - &msg.content, - &state.model, - state.temperature, - ) - .await - { + match run_gateway_chat_with_multimodal(&state, &provider_label, &msg.content).await { Ok(response) => { // Send reply via WhatsApp if let Err(e) = wa @@ -1124,6 +1130,12 @@ async fn handle_linq_webhook( } // Process each message + let provider_label = state + .config + .lock() + .default_provider + .clone() + .unwrap_or_else(|| "unknown".to_string()); for msg in &messages { tracing::info!( "Linq message from {}: {}", @@ -1141,11 +1153,7 @@ async fn handle_linq_webhook( } // Call the LLM - match state - .provider - .simple_chat(&msg.content, &state.model, state.temperature) - .await - { + match run_gateway_chat_with_multimodal(&state, &provider_label, &msg.content).await { Ok(response) => { // Send reply via Linq if let Err(e) = linq diff --git a/src/lib.rs b/src/lib.rs index bb56298..600fa1d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -55,6 +55,7 @@ pub mod identity; pub mod integrations; pub mod memory; pub mod migration; +pub mod multimodal; pub mod observability; pub mod onboard; pub mod peripherals; diff --git a/src/main.rs b/src/main.rs index 8e21ae8..672e7e6 100644 --- a/src/main.rs +++ b/src/main.rs @@ -58,6 +58,7 @@ mod identity; mod integrations; mod memory; mod migration; +mod multimodal; mod observability; mod onboard; mod peripherals; diff --git a/src/multimodal.rs b/src/multimodal.rs new file mode 100644 index 0000000..bd15900 --- /dev/null +++ b/src/multimodal.rs @@ -0,0 +1,568 @@ +use crate::config::{build_runtime_proxy_client_with_timeouts, 
MultimodalConfig}; +use crate::providers::ChatMessage; +use base64::{engine::general_purpose::STANDARD, Engine as _}; +use reqwest::Client; +use std::path::Path; + +const IMAGE_MARKER_PREFIX: &str = "[IMAGE:"; +const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[ + "image/png", + "image/jpeg", + "image/webp", + "image/gif", + "image/bmp", +]; + +#[derive(Debug, Clone)] +pub struct PreparedMessages { + pub messages: Vec, + pub contains_images: bool, +} + +#[derive(Debug, thiserror::Error)] +pub enum MultimodalError { + #[error("multimodal image limit exceeded: max_images={max_images}, found={found}")] + TooManyImages { max_images: usize, found: usize }, + + #[error("multimodal image size limit exceeded for '{input}': {size_bytes} bytes > {max_bytes} bytes")] + ImageTooLarge { + input: String, + size_bytes: usize, + max_bytes: usize, + }, + + #[error("multimodal image MIME type is not allowed for '{input}': {mime}")] + UnsupportedMime { input: String, mime: String }, + + #[error("multimodal remote image fetch is disabled for '{input}'")] + RemoteFetchDisabled { input: String }, + + #[error("multimodal image source not found or unreadable: '{input}'")] + ImageSourceNotFound { input: String }, + + #[error("invalid multimodal image marker '{input}': {reason}")] + InvalidMarker { input: String, reason: String }, + + #[error("failed to download remote image '{input}': {reason}")] + RemoteFetchFailed { input: String, reason: String }, + + #[error("failed to read local image '{input}': {reason}")] + LocalReadFailed { input: String, reason: String }, +} + +pub fn parse_image_markers(content: &str) -> (String, Vec) { + let mut refs = Vec::new(); + let mut cleaned = String::with_capacity(content.len()); + let mut cursor = 0usize; + + while let Some(rel_start) = content[cursor..].find(IMAGE_MARKER_PREFIX) { + let start = cursor + rel_start; + cleaned.push_str(&content[cursor..start]); + + let marker_start = start + IMAGE_MARKER_PREFIX.len(); + let Some(rel_end) = 
content[marker_start..].find(']') else { + cleaned.push_str(&content[start..]); + cursor = content.len(); + break; + }; + + let end = marker_start + rel_end; + let candidate = content[marker_start..end].trim(); + + if candidate.is_empty() { + cleaned.push_str(&content[start..=end]); + } else { + refs.push(candidate.to_string()); + } + + cursor = end + 1; + } + + if cursor < content.len() { + cleaned.push_str(&content[cursor..]); + } + + (cleaned.trim().to_string(), refs) +} + +pub fn count_image_markers(messages: &[ChatMessage]) -> usize { + messages + .iter() + .filter(|m| m.role == "user") + .map(|m| parse_image_markers(&m.content).1.len()) + .sum() +} + +pub fn contains_image_markers(messages: &[ChatMessage]) -> bool { + count_image_markers(messages) > 0 +} + +pub fn extract_ollama_image_payload(image_ref: &str) -> Option { + if image_ref.starts_with("data:") { + let comma_idx = image_ref.find(',')?; + let (_, payload) = image_ref.split_at(comma_idx + 1); + let payload = payload.trim(); + if payload.is_empty() { + None + } else { + Some(payload.to_string()) + } + } else { + Some(image_ref.trim().to_string()).filter(|value| !value.is_empty()) + } +} + +pub async fn prepare_messages_for_provider( + messages: &[ChatMessage], + config: &MultimodalConfig, +) -> anyhow::Result { + let (max_images, max_image_size_mb) = config.effective_limits(); + let max_bytes = max_image_size_mb.saturating_mul(1024 * 1024); + + let found_images = count_image_markers(messages); + if found_images > max_images { + return Err(MultimodalError::TooManyImages { + max_images, + found: found_images, + } + .into()); + } + + if found_images == 0 { + return Ok(PreparedMessages { + messages: messages.to_vec(), + contains_images: false, + }); + } + + let remote_client = build_runtime_proxy_client_with_timeouts("provider.ollama", 30, 10); + + let mut normalized_messages = Vec::with_capacity(messages.len()); + for message in messages { + if message.role != "user" { + 
normalized_messages.push(message.clone());
+            continue;
+        }
+
+        let (cleaned_text, refs) = parse_image_markers(&message.content);
+        if refs.is_empty() {
+            normalized_messages.push(message.clone());
+            continue;
+        }
+
+        let mut normalized_refs = Vec::with_capacity(refs.len());
+        for reference in refs {
+            let data_uri =
+                normalize_image_reference(&reference, config, max_bytes, &remote_client).await?;
+            normalized_refs.push(data_uri);
+        }
+
+        let content = compose_multimodal_message(&cleaned_text, &normalized_refs);
+        normalized_messages.push(ChatMessage {
+            role: message.role.clone(),
+            content,
+        });
+    }
+
+    Ok(PreparedMessages {
+        messages: normalized_messages,
+        contains_images: true,
+    })
+}
+
+fn compose_multimodal_message(text: &str, data_uris: &[String]) -> String {
+    let mut content = String::new();
+    let trimmed = text.trim();
+
+    if !trimmed.is_empty() {
+        content.push_str(trimmed);
+        content.push_str("\n\n");
+    }
+
+    for (index, data_uri) in data_uris.iter().enumerate() {
+        if index > 0 {
+            content.push('\n');
+        }
+        content.push_str(IMAGE_MARKER_PREFIX);
+        content.push_str(data_uri);
+        content.push(']');
+    }
+
+    content
+}
+
+async fn normalize_image_reference(
+    source: &str,
+    config: &MultimodalConfig,
+    max_bytes: usize,
+    remote_client: &Client,
+) -> anyhow::Result<String> {
+    if source.starts_with("data:") {
+        return normalize_data_uri(source, max_bytes);
+    }
+
+    if source.starts_with("http://") || source.starts_with("https://") {
+        if !config.allow_remote_fetch {
+            return Err(MultimodalError::RemoteFetchDisabled {
+                input: source.to_string(),
+            }
+            .into());
+        }
+
+        return normalize_remote_image(source, max_bytes, remote_client).await;
+    }
+
+    normalize_local_image(source, max_bytes).await
+}
+
+fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
+    let Some(comma_idx) = source.find(',') else {
+        return Err(MultimodalError::InvalidMarker {
+            input: source.to_string(),
+            reason: "expected data URI payload".to_string(),
+        }
+        .into());
+    };
+
+    let header = &source[..comma_idx];
+    let payload = source[comma_idx + 1..].trim();
+
+    if !header.contains(";base64") {
+        return Err(MultimodalError::InvalidMarker {
+            input: source.to_string(),
+            reason: "only base64 data URIs are supported".to_string(),
+        }
+        .into());
+    }
+
+    let mime = header
+        .trim_start_matches("data:")
+        .split(';')
+        .next()
+        .unwrap_or_default()
+        .trim()
+        .to_ascii_lowercase();
+
+    validate_mime(source, &mime)?;
+
+    let decoded = STANDARD
+        .decode(payload)
+        .map_err(|error| MultimodalError::InvalidMarker {
+            input: source.to_string(),
+            reason: format!("invalid base64 payload: {error}"),
+        })?;
+
+    validate_size(source, decoded.len(), max_bytes)?;
+
+    Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded)))
+}
+
+async fn normalize_remote_image(
+    source: &str,
+    max_bytes: usize,
+    remote_client: &Client,
+) -> anyhow::Result<String> {
+    let response = remote_client.get(source).send().await.map_err(|error| {
+        MultimodalError::RemoteFetchFailed {
+            input: source.to_string(),
+            reason: error.to_string(),
+        }
+    })?;
+
+    let status = response.status();
+    if !status.is_success() {
+        return Err(MultimodalError::RemoteFetchFailed {
+            input: source.to_string(),
+            reason: format!("HTTP {status}"),
+        }
+        .into());
+    }
+
+    if let Some(content_length) = response.content_length() {
+        let content_length = content_length as usize;
+        validate_size(source, content_length, max_bytes)?;
+    }
+
+    let content_type = response
+        .headers()
+        .get(reqwest::header::CONTENT_TYPE)
+        .and_then(|value| value.to_str().ok())
+        .map(ToString::to_string);
+
+    let bytes = response
+        .bytes()
+        .await
+        .map_err(|error| MultimodalError::RemoteFetchFailed {
+            input: source.to_string(),
+            reason: error.to_string(),
+        })?;
+
+    validate_size(source, bytes.len(), max_bytes)?;
+
+    let mime = detect_mime(None, bytes.as_ref(), content_type.as_deref()).ok_or_else(|| {
+        MultimodalError::UnsupportedMime {
+            input: source.to_string(),
+            mime: "unknown".to_string(),
+        }
+    })?;
+
+    validate_mime(source, &mime)?;
+
+    Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
+}
+
+async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result<String> {
+    let path = Path::new(source);
+    if !path.exists() || !path.is_file() {
+        return Err(MultimodalError::ImageSourceNotFound {
+            input: source.to_string(),
+        }
+        .into());
+    }
+
+    let metadata =
+        tokio::fs::metadata(path)
+            .await
+            .map_err(|error| MultimodalError::LocalReadFailed {
+                input: source.to_string(),
+                reason: error.to_string(),
+            })?;
+
+    validate_size(source, metadata.len() as usize, max_bytes)?;
+
+    let bytes = tokio::fs::read(path)
+        .await
+        .map_err(|error| MultimodalError::LocalReadFailed {
+            input: source.to_string(),
+            reason: error.to_string(),
+        })?;
+
+    validate_size(source, bytes.len(), max_bytes)?;
+
+    let mime =
+        detect_mime(Some(path), &bytes, None).ok_or_else(|| MultimodalError::UnsupportedMime {
+            input: source.to_string(),
+            mime: "unknown".to_string(),
+        })?;
+
+    validate_mime(source, &mime)?;
+
+    Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
+}
+
+fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> {
+    if size_bytes > max_bytes {
+        return Err(MultimodalError::ImageTooLarge {
+            input: source.to_string(),
+            size_bytes,
+            max_bytes,
+        }
+        .into());
+    }
+
+    Ok(())
+}
+
+fn validate_mime(source: &str, mime: &str) -> anyhow::Result<()> {
+    if ALLOWED_IMAGE_MIME_TYPES
+        .iter()
+        .any(|allowed| *allowed == mime)
+    {
+        return Ok(());
+    }
+
+    Err(MultimodalError::UnsupportedMime {
+        input: source.to_string(),
+        mime: mime.to_string(),
+    }
+    .into())
+}
+
+fn detect_mime(
+    path: Option<&Path>,
+    bytes: &[u8],
+    header_content_type: Option<&str>,
+) -> Option<String> {
+    if let Some(header_mime) = header_content_type.and_then(normalize_content_type) {
+        return Some(header_mime);
+    }
+
+    if let Some(path) = path {
+        if let Some(ext) = path.extension().and_then(|value| value.to_str()) {
+            if let Some(mime) = mime_from_extension(ext) {
+                return Some(mime.to_string());
+            }
+        }
+    }
+
+    mime_from_magic(bytes).map(ToString::to_string)
+}
+
+fn normalize_content_type(content_type: &str) -> Option<String> {
+    let mime = content_type.split(';').next()?.trim().to_ascii_lowercase();
+    if mime.is_empty() {
+        None
+    } else {
+        Some(mime)
+    }
+}
+
+fn mime_from_extension(ext: &str) -> Option<&'static str> {
+    match ext.to_ascii_lowercase().as_str() {
+        "png" => Some("image/png"),
+        "jpg" | "jpeg" => Some("image/jpeg"),
+        "webp" => Some("image/webp"),
+        "gif" => Some("image/gif"),
+        "bmp" => Some("image/bmp"),
+        _ => None,
+    }
+}
+
+fn mime_from_magic(bytes: &[u8]) -> Option<&'static str> {
+    if bytes.len() >= 8 && bytes.starts_with(&[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n']) {
+        return Some("image/png");
+    }
+
+    if bytes.len() >= 3 && bytes.starts_with(&[0xff, 0xd8, 0xff]) {
+        return Some("image/jpeg");
+    }
+
+    if bytes.len() >= 6 && (bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a")) {
+        return Some("image/gif");
+    }
+
+    if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" {
+        return Some("image/webp");
+    }
+
+    if bytes.len() >= 2 && bytes.starts_with(b"BM") {
+        return Some("image/bmp");
+    }
+
+    None
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn parse_image_markers_extracts_multiple_markers() {
+        let input = "Check this [IMAGE:/tmp/a.png] and this [IMAGE:https://example.com/b.jpg]";
+        let (cleaned, refs) = parse_image_markers(input);
+
+        assert_eq!(cleaned, "Check this and this");
+        assert_eq!(refs.len(), 2);
+        assert_eq!(refs[0], "/tmp/a.png");
+        assert_eq!(refs[1], "https://example.com/b.jpg");
+    }
+
+    #[test]
+    fn parse_image_markers_keeps_invalid_empty_marker() {
+        let input = "hello [IMAGE:] world";
+        let (cleaned, refs) = parse_image_markers(input);
+
+        assert_eq!(cleaned, "hello [IMAGE:] world");
+        assert!(refs.is_empty());
+    }
+
+    #[tokio::test]
+    async fn prepare_messages_normalizes_local_image_to_data_uri() {
+        let temp = tempfile::tempdir().unwrap();
+        let image_path = temp.path().join("sample.png");
+
+        // Minimal PNG signature bytes are enough for MIME detection.
+        std::fs::write(
+            &image_path,
+            [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'],
+        )
+        .unwrap();
+
+        let messages = vec![ChatMessage::user(format!(
+            "Please inspect this screenshot [IMAGE:{}]",
+            image_path.display()
+        ))];
+
+        let prepared = prepare_messages_for_provider(&messages, &MultimodalConfig::default())
+            .await
+            .unwrap();
+
+        assert!(prepared.contains_images);
+        assert_eq!(prepared.messages.len(), 1);
+
+        let (cleaned, refs) = parse_image_markers(&prepared.messages[0].content);
+        assert_eq!(cleaned, "Please inspect this screenshot");
+        assert_eq!(refs.len(), 1);
+        assert!(refs[0].starts_with("data:image/png;base64,"));
+    }
+
+    #[tokio::test]
+    async fn prepare_messages_rejects_too_many_images() {
+        let messages = vec![ChatMessage::user(
+            "[IMAGE:/tmp/1.png]\n[IMAGE:/tmp/2.png]".to_string(),
+        )];
+
+        let config = MultimodalConfig {
+            max_images: 1,
+            max_image_size_mb: 5,
+            allow_remote_fetch: false,
+        };
+
+        let error = prepare_messages_for_provider(&messages, &config)
+            .await
+            .expect_err("should reject image count overflow");
+
+        assert!(error
+            .to_string()
+            .contains("multimodal image limit exceeded"));
+    }
+
+    #[tokio::test]
+    async fn prepare_messages_rejects_remote_url_when_disabled() {
+        let messages = vec![ChatMessage::user(
+            "Look [IMAGE:https://example.com/img.png]".to_string(),
+        )];
+
+        let error = prepare_messages_for_provider(&messages, &MultimodalConfig::default())
+            .await
+            .expect_err("should reject remote image URL when fetch is disabled");
+
+        assert!(error
+            .to_string()
+            .contains("multimodal remote image fetch is disabled"));
+    }
+
+    #[tokio::test]
+    async fn prepare_messages_rejects_oversized_local_image() {
+        let temp = tempfile::tempdir().unwrap();
+        let image_path = temp.path().join("big.png");
+
+        let bytes = vec![0u8; 1024 * 1024 + 1];
+        std::fs::write(&image_path, bytes).unwrap();
+
+        let messages = vec![ChatMessage::user(format!(
+            "[IMAGE:{}]",
+            image_path.display()
+        ))];
+        let config = MultimodalConfig {
+            max_images: 4,
+            max_image_size_mb: 1,
+            allow_remote_fetch: false,
+        };
+
+        let error = prepare_messages_for_provider(&messages, &config)
+            .await
+            .expect_err("should reject oversized local image");
+
+        assert!(error
+            .to_string()
+            .contains("multimodal image size limit exceeded"));
+    }
+
+    #[test]
+    fn extract_ollama_image_payload_supports_data_uris() {
+        let payload = extract_ollama_image_payload("data:image/png;base64,abcd==")
+            .expect("payload should be extracted");
+        assert_eq!(payload, "abcd==");
+    }
+}
diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs
index db28394..952de31 100644
--- a/src/onboard/wizard.rs
+++ b/src/onboard/wizard.rs
@@ -173,6 +173,7 @@ pub async fn run_wizard() -> Result {
         secrets: secrets_config,
         browser: BrowserConfig::default(),
         http_request: crate::config::HttpRequestConfig::default(),
+        multimodal: crate::config::MultimodalConfig::default(),
         web_search: crate::config::WebSearchConfig::default(),
         proxy: crate::config::ProxyConfig::default(),
         identity: crate::config::IdentityConfig::default(),
@@ -391,6 +392,7 @@ pub async fn run_quick_setup(
         secrets: SecretsConfig::default(),
         browser: BrowserConfig::default(),
         http_request: crate::config::HttpRequestConfig::default(),
+        multimodal: crate::config::MultimodalConfig::default(),
        web_search: crate::config::WebSearchConfig::default(),
        proxy: crate::config::ProxyConfig::default(),
        identity: crate::config::IdentityConfig::default(),
diff --git a/src/providers/bedrock.rs b/src/providers/bedrock.rs
index ba3ccf6..3807fc2 100644
--- a/src/providers/bedrock.rs
+++ b/src/providers/bedrock.rs
@@ -639,6 +639,7 @@ impl Provider for BedrockProvider {
     fn capabilities(&self) -> ProviderCapabilities {
         ProviderCapabilities {
             native_tool_calling: true,
+            vision: false,
         }
     }
 
diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs
index 43716a1..615ac6d 100644
--- a/src/providers/compatible.rs
+++ b/src/providers/compatible.rs
@@ -898,6 +898,7 @@ impl Provider for OpenAiCompatibleProvider {
     fn capabilities(&self) -> crate::providers::traits::ProviderCapabilities {
         crate::providers::traits::ProviderCapabilities {
             native_tool_calling: true,
+            vision: false,
         }
     }
 
diff --git a/src/providers/mod.rs b/src/providers/mod.rs
index 107866c..aa453e5 100644
--- a/src/providers/mod.rs
+++ b/src/providers/mod.rs
@@ -13,8 +13,8 @@ pub mod traits;
 
 #[allow(unused_imports)]
 pub use traits::{
-    ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ToolCall,
-    ToolResultMessage,
+    ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ProviderCapabilityError,
+    ToolCall, ToolResultMessage,
 };
 
 use compatible::{AuthStyle, OpenAiCompatibleProvider};
diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs
index 91a45ad..0582872 100644
--- a/src/providers/ollama.rs
+++ b/src/providers/ollama.rs
@@ -1,4 +1,7 @@
-use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
+use crate::multimodal;
+use crate::providers::traits::{
+    ChatMessage, ChatResponse, Provider, ProviderCapabilities, ToolCall,
+};
 use async_trait::async_trait;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
@@ -30,6 +33,8 @@ struct Message {
     #[serde(skip_serializing_if = "Option::is_none")]
     content: Option<String>,
     #[serde(skip_serializing_if = "Option::is_none")]
+    images: Option<Vec<String>>,
+    #[serde(skip_serializing_if = "Option::is_none")]
     tool_calls: Option>,
     #[serde(skip_serializing_if = "Option::is_none")]
     tool_name: Option<String>,
@@ -166,6 +171,31 @@ impl OllamaProvider {
         }
     }
 
+    fn convert_user_message_content(&self, content: &str) -> (Option<String>, Option<Vec<String>>) {
+        let (cleaned, image_refs) = multimodal::parse_image_markers(content);
+        if image_refs.is_empty() {
+            return (Some(content.to_string()), None);
+        }
+
+        let images: Vec<String> = image_refs
+            .iter()
+            .filter_map(|reference| multimodal::extract_ollama_image_payload(reference))
+            .collect();
+
+        if images.is_empty() {
+            return (Some(content.to_string()), None);
+        }
+
+        let cleaned = cleaned.trim();
+        let content = if cleaned.is_empty() {
+            None
+        } else {
+            Some(cleaned.to_string())
+        };
+
+        (content, Some(images))
+    }
+
     /// Convert internal chat history format to Ollama's native tool-call message schema.
     ///
     /// `run_tool_call_loop` stores native assistant/tool entries as JSON strings in
@@ -205,6 +235,7 @@ impl OllamaProvider {
             return Message {
                 role: "assistant".to_string(),
                 content,
+                images: None,
                 tool_calls: Some(outgoing_calls),
                 tool_name: None,
             };
@@ -238,15 +269,28 @@ impl OllamaProvider {
                 return Message {
                     role: "tool".to_string(),
                     content,
+                    images: None,
                     tool_calls: None,
                     tool_name,
                 };
             }
         }
 
+        if message.role == "user" {
+            let (content, images) = self.convert_user_message_content(&message.content);
+            return Message {
+                role: "user".to_string(),
+                content,
+                images,
+                tool_calls: None,
+                tool_name: None,
+            };
+        }
+
         Message {
             role: message.role.clone(),
             content: Some(message.content.clone()),
+            images: None,
             tool_calls: None,
             tool_name: None,
         }
@@ -398,6 +442,13 @@ impl OllamaProvider {
 
 #[async_trait]
 impl Provider for OllamaProvider {
+    fn capabilities(&self) -> ProviderCapabilities {
+        ProviderCapabilities {
+            native_tool_calling: true,
+            vision: true,
+        }
+    }
+
     async fn chat_with_system(
         &self,
         system_prompt: Option<&str>,
@@ -413,14 +464,17 @@ impl Provider for OllamaProvider {
             messages.push(Message {
                 role: "system".to_string(),
                 content: Some(sys.to_string()),
+                images: None,
                 tool_calls: None,
                 tool_name: None,
             });
         }
 
+        let (user_content, user_images) = self.convert_user_message_content(message);
         messages.push(Message {
             role: "user".to_string(),
-            content: Some(message.to_string()),
+            content: user_content,
+            images: user_images,
             tool_calls: None,
             tool_name: None,
         });
@@ -862,4 +916,34 @@ mod tests {
         assert_eq!(converted[1].content.as_deref(), Some("ok"));
         assert!(converted[1].tool_calls.is_none());
     }
+
+    #[test]
+    fn convert_messages_extracts_images_from_user_marker() {
+        let provider = OllamaProvider::new(None, None);
+        let messages = vec![ChatMessage {
+            role: "user".into(),
+            content: "Inspect this screenshot [IMAGE:data:image/png;base64,abcd==]".into(),
+        }];
+
+        let converted = provider.convert_messages(&messages);
+        assert_eq!(converted.len(), 1);
+        assert_eq!(converted[0].role, "user");
+        assert_eq!(
+            converted[0].content.as_deref(),
+            Some("Inspect this screenshot")
+        );
+        let images = converted[0]
+            .images
+            .as_ref()
+            .expect("images should be present");
+        assert_eq!(images, &vec!["abcd==".to_string()]);
+    }
+
+    #[test]
+    fn capabilities_include_native_tools_and_vision() {
+        let provider = OllamaProvider::new(None, None);
+        let caps = <OllamaProvider as Provider>::capabilities(&provider);
+        assert!(caps.native_tool_calling);
+        assert!(caps.vision);
+    }
 }
diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs
index bafe1bc..61812e7 100644
--- a/src/providers/reliable.rs
+++ b/src/providers/reliable.rs
@@ -511,6 +511,12 @@ impl Provider for ReliableProvider {
             .unwrap_or(false)
     }
 
+    fn supports_vision(&self) -> bool {
+        self.providers
+            .iter()
+            .any(|(_, provider)| provider.supports_vision())
+    }
+
     async fn chat_with_tools(
         &self,
         messages: &[ChatMessage],
diff --git a/src/providers/router.rs b/src/providers/router.rs
index 2d55869..b12bd52 100644
--- a/src/providers/router.rs
+++ b/src/providers/router.rs
@@ -158,6 +158,12 @@ impl Provider for RouterProvider {
             .unwrap_or(false)
     }
 
+    fn supports_vision(&self) -> bool {
+        self.providers
+            .iter()
+            .any(|(_, provider)| provider.supports_vision())
+    }
+
     async fn warmup(&self) -> anyhow::Result<()> {
         for (name, provider) in &self.providers {
             tracing::info!(provider = name, "Warming up routed provider");
diff --git a/src/providers/traits.rs b/src/providers/traits.rs
index fe830ef..bfb3506 100644
--- a/src/providers/traits.rs
+++ b/src/providers/traits.rs
@@ -192,6 +192,15 @@ pub enum StreamError {
     Io(#[from] std::io::Error),
 }
 
+/// Structured error returned when a requested capability is not supported.
+#[derive(Debug, Clone, thiserror::Error)]
+#[error("provider_capability_error provider={provider} capability={capability} message={message}")]
+pub struct ProviderCapabilityError {
+    pub provider: String,
+    pub capability: String,
+    pub message: String,
+}
+
 /// Provider capabilities declaration.
 ///
 /// Describes what features a provider supports, enabling intelligent
@@ -205,6 +214,8 @@ pub struct ProviderCapabilities {
     ///
     /// When `false`, tools must be injected via system prompt as text.
     pub native_tool_calling: bool,
+    /// Whether the provider supports vision / image inputs.
+    pub vision: bool,
 }
 
 /// Provider-specific tool payload formats.
@@ -351,6 +362,11 @@ pub trait Provider: Send + Sync {
         self.capabilities().native_tool_calling
     }
 
+    /// Whether provider supports multimodal vision input.
+    fn supports_vision(&self) -> bool {
+        self.capabilities().vision
+    }
+
     /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
     /// Default implementation is a no-op; providers with HTTP clients should override.
     async fn warmup(&self) -> anyhow::Result<()> {
@@ -458,6 +474,7 @@ mod tests {
         fn capabilities(&self) -> ProviderCapabilities {
             ProviderCapabilities {
                 native_tool_calling: true,
+                vision: true,
             }
         }
 
@@ -539,18 +556,22 @@ mod tests {
     fn provider_capabilities_default() {
         let caps = ProviderCapabilities::default();
         assert!(!caps.native_tool_calling);
+        assert!(!caps.vision);
     }
 
     #[test]
     fn provider_capabilities_equality() {
         let caps1 = ProviderCapabilities {
             native_tool_calling: true,
+            vision: false,
         };
         let caps2 = ProviderCapabilities {
             native_tool_calling: true,
+            vision: false,
         };
         let caps3 = ProviderCapabilities {
             native_tool_calling: false,
+            vision: false,
         };
 
         assert_eq!(caps1, caps2);
@@ -563,6 +584,12 @@ mod tests {
         assert!(provider.supports_native_tools());
     }
 
+    #[test]
+    fn supports_vision_reflects_capabilities_default_mapping() {
+        let provider = CapabilityMockProvider;
+        assert!(provider.supports_vision());
+    }
+
     #[test]
     fn tools_payload_variants() {
         // Test Gemini variant
diff --git a/src/tools/proxy_config.rs b/src/tools/proxy_config.rs
index a4d90d1..213a57e 100644
--- a/src/tools/proxy_config.rs
+++ b/src/tools/proxy_config.rs
@@ -235,7 +235,9 @@ impl ProxyConfigTool {
         }
 
         if args.get("enabled").is_none() && touched_proxy_url {
-            proxy.enabled = true;
+            // Keep auto-enable behavior when users provide a proxy URL, but
+            // auto-disable when all proxy URLs are cleared in the same update.
+            proxy.enabled = proxy.has_any_proxy_url();
         }
 
         proxy.no_proxy = proxy.normalized_no_proxy();