feat: add multimodal image marker support with Ollama vision

This commit is contained in:
Chummy 2026-02-19 20:24:56 +08:00
parent 63aacb09ff
commit dcd0bf641d
21 changed files with 1152 additions and 78 deletions

View file

@ -1,8 +1,11 @@
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::multimodal;
use crate::observability::{self, Observer, ObserverEvent};
use crate::providers::{self, ChatMessage, ChatRequest, Provider, ToolCall};
use crate::providers::{
self, ChatMessage, ChatRequest, Provider, ProviderCapabilityError, ToolCall,
};
use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools::{self, Tool};
@ -826,6 +829,7 @@ pub(crate) async fn agent_turn(
model: &str,
temperature: f64,
silent: bool,
multimodal_config: &crate::config::MultimodalConfig,
max_tool_iterations: usize,
) -> Result<String> {
run_tool_call_loop(
@ -839,6 +843,7 @@ pub(crate) async fn agent_turn(
silent,
None,
"channel",
multimodal_config,
max_tool_iterations,
None,
)
@ -859,6 +864,7 @@ pub(crate) async fn run_tool_call_loop(
silent: bool,
approval: Option<&ApprovalManager>,
channel_name: &str,
multimodal_config: &crate::config::MultimodalConfig,
max_tool_iterations: usize,
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
) -> Result<String> {
@ -873,6 +879,21 @@ pub(crate) async fn run_tool_call_loop(
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
for _iteration in 0..max_iterations {
let image_marker_count = multimodal::count_image_markers(history);
if image_marker_count > 0 && !provider.supports_vision() {
return Err(ProviderCapabilityError {
provider: provider_name.to_string(),
capability: "vision".to_string(),
message: format!(
"received {image_marker_count} image marker(s), but this provider does not support vision input"
),
}
.into());
}
let prepared_messages =
multimodal::prepare_messages_for_provider(history, multimodal_config).await?;
observer.record_event(&ObserverEvent::LlmRequest {
provider: provider_name.to_string(),
model: model.to_string(),
@ -893,7 +914,7 @@ pub(crate) async fn run_tool_call_loop(
match provider
.chat(
ChatRequest {
messages: history,
messages: &prepared_messages.messages,
tools: request_tools,
},
model,
@ -1404,6 +1425,7 @@ pub async fn run(
false,
Some(&approval_manager),
"cli",
&config.multimodal,
config.agent.max_tool_iterations,
None,
)
@ -1530,6 +1552,7 @@ pub async fn run(
false,
Some(&approval_manager),
"cli",
&config.multimodal,
config.agent.max_tool_iterations,
None,
)
@ -1757,6 +1780,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
&model_name,
config.default_temperature,
true,
&config.multimodal,
config.agent.max_tool_iterations,
)
.await
@ -1765,6 +1789,10 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
#[cfg(test)]
mod tests {
use super::*;
use async_trait::async_trait;
use base64::{engine::general_purpose::STANDARD, Engine as _};
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
#[test]
fn test_scrub_credentials() {
@ -1785,8 +1813,191 @@ mod tests {
assert!(scrubbed.contains("public"));
}
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use crate::observability::NoopObserver;
use crate::providers::traits::ProviderCapabilities;
use crate::providers::ChatResponse;
use tempfile::TempDir;
// Mock provider that relies on the trait's default capabilities —
// presumably vision=false and native tools=false (the capability test
// below depends on vision being unsupported; confirm against the
// `Provider` trait defaults). `calls` counts actual model invocations
// so tests can assert the model was (not) reached.
struct NonVisionProvider {
    calls: Arc<AtomicUsize>,
}
#[async_trait]
impl Provider for NonVisionProvider {
    // Minimal chat implementation: bump the counter and return a canned reply.
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        self.calls.fetch_add(1, Ordering::SeqCst);
        Ok("ok".to_string())
    }
}
// Mock provider that advertises vision support (but no native tool calling).
// `calls` counts model invocations across both chat entry points.
struct VisionProvider {
    calls: Arc<AtomicUsize>,
}
#[async_trait]
impl Provider for VisionProvider {
    // Explicitly opt in to vision so multimodal requests pass the capability gate.
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            native_tool_calling: false,
            vision: true,
        }
    }
    // Fallback text-only path; tests only assert on the call counter here.
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        self.calls.fetch_add(1, Ordering::SeqCst);
        Ok("ok".to_string())
    }
    // History-based path used by the tool-call loop: verifies that the
    // prepared messages still contain image markers and that no tools were
    // attached, then returns a sentinel response the tests match on.
    async fn chat(
        &self,
        request: ChatRequest<'_>,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<ChatResponse> {
        self.calls.fetch_add(1, Ordering::SeqCst);
        let marker_count = crate::multimodal::count_image_markers(request.messages);
        if marker_count == 0 {
            anyhow::bail!("expected image markers in request messages");
        }
        if request.tools.is_some() {
            anyhow::bail!("no tools should be attached for this test");
        }
        Ok(ChatResponse {
            text: Some("vision-ok".to_string()),
            tool_calls: Vec::new(),
        })
    }
}
// An image marker sent to a provider without vision support must fail
// fast with the structured capability error — before the model is called.
#[tokio::test]
async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() {
    let calls = Arc::new(AtomicUsize::new(0));
    let provider = NonVisionProvider {
        calls: Arc::clone(&calls),
    };
    let mut history = vec![ChatMessage::user(
        "please inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
    )];
    let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
    let observer = NoopObserver;
    let err = run_tool_call_loop(
        &provider,
        &mut history,
        &tools_registry,
        &observer,
        "mock-provider",
        "mock-model",
        0.0,
        true,
        None,
        "cli",
        &crate::config::MultimodalConfig::default(),
        3,
        None,
    )
    .await
    .expect_err("provider without vision support should fail");
    // Error text format comes from ProviderCapabilityError's Display impl.
    assert!(err.to_string().contains("provider_capability_error"));
    assert!(err.to_string().contains("capability=vision"));
    // The capability gate must reject before any model invocation.
    assert_eq!(calls.load(Ordering::SeqCst), 0);
}
// An inline base64 image just over the configured MiB limit must be
// rejected during multimodal preparation, before the model is called.
#[tokio::test]
async fn run_tool_call_loop_rejects_oversized_image_payload() {
    let calls = Arc::new(AtomicUsize::new(0));
    let provider = VisionProvider {
        calls: Arc::clone(&calls),
    };
    // One byte over the 1 MiB limit configured below.
    let oversized_payload = STANDARD.encode(vec![0_u8; (1024 * 1024) + 1]);
    let mut history = vec![ChatMessage::user(format!(
        "[IMAGE:data:image/png;base64,{oversized_payload}]"
    ))];
    let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
    let observer = NoopObserver;
    let multimodal = crate::config::MultimodalConfig {
        max_images: 4,
        max_image_size_mb: 1,
        allow_remote_fetch: false,
    };
    let err = run_tool_call_loop(
        &provider,
        &mut history,
        &tools_registry,
        &observer,
        "mock-provider",
        "mock-model",
        0.0,
        true,
        None,
        "cli",
        &multimodal,
        3,
        None,
    )
    .await
    .expect_err("oversized payload must fail");
    assert!(err
        .to_string()
        .contains("multimodal image size limit exceeded"));
    // Size validation happens before any provider call.
    assert_eq!(calls.load(Ordering::SeqCst), 0);
}
// Happy path: a small inline image on a vision-capable provider flows
// through preparation and reaches the provider exactly once.
#[tokio::test]
async fn run_tool_call_loop_accepts_valid_multimodal_request_flow() {
    let calls = Arc::new(AtomicUsize::new(0));
    let provider = VisionProvider {
        calls: Arc::clone(&calls),
    };
    let mut history = vec![ChatMessage::user(
        "Analyze this [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
    )];
    let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
    let observer = NoopObserver;
    let result = run_tool_call_loop(
        &provider,
        &mut history,
        &tools_registry,
        &observer,
        "mock-provider",
        "mock-model",
        0.0,
        true,
        None,
        "cli",
        &crate::config::MultimodalConfig::default(),
        3,
        None,
    )
    .await
    .expect("valid multimodal payload should pass");
    // "vision-ok" is the sentinel returned by VisionProvider::chat.
    assert_eq!(result, "vision-ok");
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
#[test]
fn parse_tool_calls_extracts_single_call() {
let response = r#"Let me check that.

View file

@ -37,6 +37,28 @@ impl LinqChannel {
&self.from_phone
}
fn media_part_to_image_marker(part: &serde_json::Value) -> Option<String> {
let source = part
.get("url")
.or_else(|| part.get("value"))
.and_then(|value| value.as_str())
.map(str::trim)
.filter(|value| !value.is_empty())?;
let mime_type = part
.get("mime_type")
.and_then(|value| value.as_str())
.map(str::trim)
.unwrap_or_default()
.to_ascii_lowercase();
if !mime_type.starts_with("image/") {
return None;
}
Some(format!("[IMAGE:{source}]"))
}
/// Parse an incoming webhook payload from Linq and extract messages.
///
/// Linq webhook envelope:
@ -124,25 +146,36 @@ impl LinqChannel {
return messages;
};
let text_parts: Vec<&str> = parts
let content_parts: Vec<String> = parts
.iter()
.filter_map(|part| {
let part_type = part.get("type").and_then(|t| t.as_str())?;
if part_type == "text" {
part.get("value").and_then(|v| v.as_str())
} else {
// Skip media parts for now
tracing::debug!("Linq: skipping {part_type} part");
None
match part_type {
"text" => part
.get("value")
.and_then(|v| v.as_str())
.map(ToString::to_string),
"media" | "image" => {
if let Some(marker) = Self::media_part_to_image_marker(part) {
Some(marker)
} else {
tracing::debug!("Linq: skipping unsupported {part_type} part");
None
}
}
_ => {
tracing::debug!("Linq: skipping {part_type} part");
None
}
}
})
.collect();
if text_parts.is_empty() {
if content_parts.is_empty() {
return messages;
}
let content = text_parts.join("\n");
let content = content_parts.join("\n").trim().to_string();
if content.is_empty() {
return messages;
@ -496,7 +529,7 @@ mod tests {
}
#[test]
fn linq_parse_media_only_skipped() {
fn linq_parse_media_only_translated_to_image_marker() {
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
let payload = serde_json::json!({
"event_type": "message.received",
@ -516,7 +549,32 @@ mod tests {
});
let msgs = ch.parse_webhook_payload(&payload);
assert!(msgs.is_empty(), "Media-only messages should be skipped");
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].content, "[IMAGE:https://example.com/image.jpg]");
}
#[test]
fn linq_parse_media_non_image_still_skipped() {
let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]);
let payload = serde_json::json!({
"event_type": "message.received",
"data": {
"chat_id": "chat-789",
"from": "+1234567890",
"is_from_me": false,
"message": {
"id": "msg-abc",
"parts": [{
"type": "media",
"url": "https://example.com/sound.mp3",
"mime_type": "audio/mpeg"
}]
}
}
});
let msgs = ch.parse_webhook_payload(&payload);
assert!(msgs.is_empty(), "Non-image media should still be skipped");
}
#[test]

View file

@ -139,6 +139,7 @@ struct ChannelRuntimeContext {
provider_runtime_options: providers::ProviderRuntimeOptions,
workspace_dir: Arc<PathBuf>,
message_timeout_secs: u64,
multimodal: crate::config::MultimodalConfig,
}
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
@ -810,6 +811,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
true,
None,
msg.channel.as_str(),
&ctx.multimodal,
ctx.max_tool_iterations,
delta_tx,
),
@ -2062,6 +2064,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
provider_runtime_options,
workspace_dir: Arc::new(config.workspace_dir.clone()),
message_timeout_secs,
multimodal: config.multimodal.clone(),
});
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
@ -2559,6 +2562,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
@ -2613,6 +2617,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
@ -2676,6 +2681,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
@ -2760,6 +2766,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
@ -2820,6 +2827,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
@ -2875,6 +2883,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
@ -2981,6 +2990,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
@ -3054,6 +3064,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
@ -3451,6 +3462,7 @@ mod tests {
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(

View file

@ -9,11 +9,11 @@ pub use schema::{
DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EmbeddingRouteConfig, GatewayConfig,
HardwareConfig, HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig,
IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig,
PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QueryClassificationConfig,
ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig,
SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, StorageConfig,
StorageProviderConfig, StorageProviderSection, StreamMode, TelegramConfig, TunnelConfig,
WebSearchConfig, WebhookConfig,
MultimodalConfig, ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig,
ProxyScope, QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig,
SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig,
StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, TelegramConfig,
TunnelConfig, WebSearchConfig, WebhookConfig,
};
#[cfg(test)]

View file

@ -124,6 +124,9 @@ pub struct Config {
#[serde(default)]
pub http_request: HttpRequestConfig,
#[serde(default)]
pub multimodal: MultimodalConfig,
#[serde(default)]
pub web_search: WebSearchConfig,
@ -284,6 +287,46 @@ impl Default for AgentConfig {
}
}
/// Limits applied to image attachments in multimodal requests.
///
/// All fields have serde defaults, so a config file may omit the whole
/// `[multimodal]` section or any individual key.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct MultimodalConfig {
    /// Maximum number of image attachments accepted per request.
    #[serde(default = "default_multimodal_max_images")]
    pub max_images: usize,
    /// Maximum image payload size in MiB before base64 encoding.
    #[serde(default = "default_multimodal_max_image_size_mb")]
    pub max_image_size_mb: usize,
    /// Allow fetching remote image URLs (http/https). Disabled by default.
    #[serde(default)]
    pub allow_remote_fetch: bool,
}
// serde default: up to 4 images per request.
fn default_multimodal_max_images() -> usize {
    4
}
// serde default: 5 MiB per image.
fn default_multimodal_max_image_size_mb() -> usize {
    5
}
impl MultimodalConfig {
    /// Clamp configured values to safe runtime bounds.
    ///
    /// Returns `(max_images, max_image_size_mb)` with `max_images` forced
    /// into 1..=16 and `max_image_size_mb` into 1..=20, so a zero or
    /// absurdly large user value can never disable or blow past the limits.
    pub fn effective_limits(&self) -> (usize, usize) {
        let max_images = self.max_images.clamp(1, 16);
        let max_image_size_mb = self.max_image_size_mb.clamp(1, 20);
        (max_images, max_image_size_mb)
    }
}
impl Default for MultimodalConfig {
    // Mirrors the serde field defaults so code-constructed configs match
    // configs deserialized from an empty section.
    fn default() -> Self {
        Self {
            max_images: default_multimodal_max_images(),
            max_image_size_mb: default_multimodal_max_image_size_mb(),
            allow_remote_fetch: false,
        }
    }
}
// ── Identity (AIEOS / OpenClaw format) ──────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
@ -2534,6 +2577,7 @@ impl Default for Config {
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: HttpRequestConfig::default(),
multimodal: MultimodalConfig::default(),
web_search: WebSearchConfig::default(),
proxy: ProxyConfig::default(),
identity: IdentityConfig::default(),
@ -3502,6 +3546,7 @@ default_temperature = 0.7
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: HttpRequestConfig::default(),
multimodal: MultimodalConfig::default(),
web_search: WebSearchConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),
@ -3656,6 +3701,7 @@ tool_dispatcher = "xml"
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: HttpRequestConfig::default(),
multimodal: MultimodalConfig::default(),
web_search: WebSearchConfig::default(),
proxy: ProxyConfig::default(),
agent: AgentConfig::default(),

View file

@ -10,7 +10,7 @@
use crate::channels::{Channel, LinqChannel, SendMessage, WhatsAppChannel};
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::providers::{self, Provider};
use crate::providers::{self, ChatMessage, Provider, ProviderCapabilityError};
use crate::runtime;
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
use crate::security::SecurityPolicy;
@ -666,6 +666,52 @@ async fn persist_pairing_tokens(config: Arc<Mutex<Config>>, pairing: &PairingGua
Ok(())
}
/// Shared gateway chat path (webhook / WhatsApp / Linq) with multimodal
/// support.
///
/// Gates vision capability up front: if the incoming message carries
/// `[IMAGE:...]` markers but the provider does not support vision, a
/// structured `ProviderCapabilityError` is returned before any model call.
/// Otherwise the message is wrapped with a workspace-aware system prompt,
/// image references are normalized per the multimodal config, and the
/// provider is invoked with the full history.
async fn run_gateway_chat_with_multimodal(
    state: &AppState,
    provider_label: &str,
    message: &str,
) -> anyhow::Result<String> {
    let user_messages = vec![ChatMessage::user(message)];
    let image_marker_count = crate::multimodal::count_image_markers(&user_messages);
    if image_marker_count > 0 && !state.provider.supports_vision() {
        return Err(ProviderCapabilityError {
            provider: provider_label.to_string(),
            capability: "vision".to_string(),
            message: format!(
                "received {image_marker_count} image marker(s), but this provider does not support vision input"
            ),
        }
        .into());
    }
    // Keep webhook/gateway prompts aligned with channel behavior by injecting
    // workspace-aware system context before model invocation.
    // The config lock is scoped to this block so it is released before the
    // (potentially slow) prepare/chat awaits below.
    let system_prompt = {
        let config_guard = state.config.lock();
        crate::channels::build_system_prompt(
            &config_guard.workspace_dir,
            &state.model,
            &[], // tools - empty for simple chat
            &[], // skills
            Some(&config_guard.identity),
            None, // bootstrap_max_chars - use default
        )
    };
    let mut messages = Vec::with_capacity(1 + user_messages.len());
    messages.push(ChatMessage::system(system_prompt));
    messages.extend(user_messages);
    // Clone the multimodal section out so no lock is held across the await.
    let multimodal_config = state.config.lock().multimodal.clone();
    let prepared =
        crate::multimodal::prepare_messages_for_provider(&messages, &multimodal_config).await?;
    state
        .provider
        .chat_with_history(&prepared.messages, &state.model, state.temperature)
        .await
}
/// Webhook request body
#[derive(serde::Deserialize)]
pub struct WebhookBody {
@ -787,30 +833,7 @@ async fn handle_webhook(
messages_count: 1,
});
// Build system prompt with workspace context (IDENTITY.md, AGENTS.md, etc.)
let system_prompt = {
let config_guard = state.config.lock();
crate::channels::build_system_prompt(
&config_guard.workspace_dir,
&state.model,
&[], // tools - empty for simple chat
&[], // skills
Some(&config_guard.identity),
None, // bootstrap_max_chars - use default
)
};
// Call the LLM with separate system prompt
match state
.provider
.chat_with_system(
Some(&system_prompt),
message,
&state.model,
state.temperature,
)
.await
{
match run_gateway_chat_with_multimodal(&state, &provider_label, message).await {
Ok(response) => {
let duration = started_at.elapsed();
state
@ -994,6 +1017,12 @@ async fn handle_whatsapp_message(
}
// Process each message
let provider_label = state
.config
.lock()
.default_provider
.clone()
.unwrap_or_else(|| "unknown".to_string());
for msg in &messages {
tracing::info!(
"WhatsApp message from {}: {}",
@ -1010,30 +1039,7 @@ async fn handle_whatsapp_message(
.await;
}
// Build system prompt with workspace context (IDENTITY.md, AGENTS.md, etc.)
let system_prompt = {
let config_guard = state.config.lock();
crate::channels::build_system_prompt(
&config_guard.workspace_dir,
&state.model,
&[], // tools - empty for simple chat
&[], // skills
Some(&config_guard.identity),
None, // bootstrap_max_chars - use default
)
};
// Call the LLM with separate system prompt
match state
.provider
.chat_with_system(
Some(&system_prompt),
&msg.content,
&state.model,
state.temperature,
)
.await
{
match run_gateway_chat_with_multimodal(&state, &provider_label, &msg.content).await {
Ok(response) => {
// Send reply via WhatsApp
if let Err(e) = wa
@ -1124,6 +1130,12 @@ async fn handle_linq_webhook(
}
// Process each message
let provider_label = state
.config
.lock()
.default_provider
.clone()
.unwrap_or_else(|| "unknown".to_string());
for msg in &messages {
tracing::info!(
"Linq message from {}: {}",
@ -1141,11 +1153,7 @@ async fn handle_linq_webhook(
}
// Call the LLM
match state
.provider
.simple_chat(&msg.content, &state.model, state.temperature)
.await
{
match run_gateway_chat_with_multimodal(&state, &provider_label, &msg.content).await {
Ok(response) => {
// Send reply via Linq
if let Err(e) = linq

View file

@ -55,6 +55,7 @@ pub mod identity;
pub mod integrations;
pub mod memory;
pub mod migration;
pub mod multimodal;
pub mod observability;
pub mod onboard;
pub mod peripherals;

View file

@ -58,6 +58,7 @@ mod identity;
mod integrations;
mod memory;
mod migration;
mod multimodal;
mod observability;
mod onboard;
mod peripherals;

568
src/multimodal.rs Normal file
View file

@ -0,0 +1,568 @@
use crate::config::{build_runtime_proxy_client_with_timeouts, MultimodalConfig};
use crate::providers::ChatMessage;
use base64::{engine::general_purpose::STANDARD, Engine as _};
use reqwest::Client;
use std::path::Path;
/// Prefix of inline image markers embedded in chat text, e.g.
/// `[IMAGE:/path/or/url]`; the closing `]` terminates the reference.
const IMAGE_MARKER_PREFIX: &str = "[IMAGE:";
/// MIME types accepted for image payloads; anything else is rejected
/// by `validate_mime`.
const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[
    "image/png",
    "image/jpeg",
    "image/webp",
    "image/gif",
    "image/bmp",
];
/// Result of multimodal preprocessing: the (possibly rewritten) chat
/// history plus a flag recording whether any image markers were found.
#[derive(Debug, Clone)]
pub struct PreparedMessages {
    /// Messages with image references normalized to base64 data-URI markers.
    pub messages: Vec<ChatMessage>,
    /// True when at least one `[IMAGE:...]` marker was present in user messages.
    pub contains_images: bool,
}
/// Errors produced while validating and normalizing image inputs.
///
/// Display strings double as stable, test-matched error messages — do not
/// reword them without updating the tests that match on substrings.
#[derive(Debug, thiserror::Error)]
pub enum MultimodalError {
    #[error("multimodal image limit exceeded: max_images={max_images}, found={found}")]
    TooManyImages { max_images: usize, found: usize },
    #[error("multimodal image size limit exceeded for '{input}': {size_bytes} bytes > {max_bytes} bytes")]
    ImageTooLarge {
        input: String,
        size_bytes: usize,
        max_bytes: usize,
    },
    #[error("multimodal image MIME type is not allowed for '{input}': {mime}")]
    UnsupportedMime { input: String, mime: String },
    #[error("multimodal remote image fetch is disabled for '{input}'")]
    RemoteFetchDisabled { input: String },
    #[error("multimodal image source not found or unreadable: '{input}'")]
    ImageSourceNotFound { input: String },
    #[error("invalid multimodal image marker '{input}': {reason}")]
    InvalidMarker { input: String, reason: String },
    #[error("failed to download remote image '{input}': {reason}")]
    RemoteFetchFailed { input: String, reason: String },
    #[error("failed to read local image '{input}': {reason}")]
    LocalReadFailed { input: String, reason: String },
}
/// Split `content` into marker-free text and the list of image references
/// found in `[IMAGE:...]` markers.
///
/// Empty markers (`[IMAGE:]`) carry no reference and are kept verbatim in
/// the cleaned text, as is an unterminated marker (no closing `]`).
///
/// Fix: when a valid marker is removed, the whitespace that immediately
/// follows it is skipped if the preceding text already ends in whitespace,
/// so `"a [IMAGE:x] b"` cleans to `"a b"` instead of `"a  b"` (matches the
/// unit test `parse_image_markers_extracts_multiple_markers`, which expects
/// `"Check this and this"`).
pub fn parse_image_markers(content: &str) -> (String, Vec<String>) {
    // Local copy of the marker prefix keeps this parser self-contained.
    const IMAGE_MARKER_PREFIX: &str = "[IMAGE:";
    let mut refs = Vec::new();
    let mut cleaned = String::with_capacity(content.len());
    let mut cursor = 0usize;
    while let Some(rel_start) = content[cursor..].find(IMAGE_MARKER_PREFIX) {
        let start = cursor + rel_start;
        cleaned.push_str(&content[cursor..start]);
        let marker_start = start + IMAGE_MARKER_PREFIX.len();
        let Some(rel_end) = content[marker_start..].find(']') else {
            // No closing bracket: keep the remainder verbatim and stop.
            cleaned.push_str(&content[start..]);
            cursor = content.len();
            break;
        };
        let end = marker_start + rel_end;
        let candidate = content[marker_start..end].trim();
        if candidate.is_empty() {
            // `[IMAGE:]` has no reference; preserve it as literal text.
            cleaned.push_str(&content[start..=end]);
        } else {
            refs.push(candidate.to_string());
        }
        cursor = end + 1;
        // Collapse the separator left behind by a removed marker: if the
        // cleaned text is empty or already ends in whitespace, skip the
        // whitespace directly after the marker so we don't emit doubles.
        if !candidate.is_empty()
            && (cleaned.is_empty() || cleaned.ends_with(char::is_whitespace))
        {
            let remainder = &content[cursor..];
            cursor += remainder.len() - remainder.trim_start().len();
        }
    }
    if cursor < content.len() {
        cleaned.push_str(&content[cursor..]);
    }
    (cleaned.trim().to_string(), refs)
}
/// Total number of `[IMAGE:...]` markers across all user-role messages.
/// Non-user messages (system, assistant, tool) are never scanned.
pub fn count_image_markers(messages: &[ChatMessage]) -> usize {
    let mut total = 0;
    for message in messages {
        if message.role == "user" {
            total += parse_image_markers(&message.content).1.len();
        }
    }
    total
}
/// Whether any user-role message carries at least one image marker.
pub fn contains_image_markers(messages: &[ChatMessage]) -> bool {
    count_image_markers(messages) > 0
}
/// Extract the raw base64 payload Ollama expects from an image reference.
///
/// For a `data:` URI the payload is everything after the first comma; a
/// `data:` reference with no comma yields `None`. Any other reference is
/// assumed to already be a bare payload. Blank payloads yield `None`.
pub fn extract_ollama_image_payload(image_ref: &str) -> Option<String> {
    let raw = if image_ref.starts_with("data:") {
        // `?` rejects malformed data URIs that lack a payload separator.
        image_ref.split_once(',').map(|(_, payload)| payload)?
    } else {
        image_ref
    };
    let trimmed = raw.trim();
    (!trimmed.is_empty()).then(|| trimmed.to_string())
}
/// Enforce multimodal limits and rewrite user messages so every image
/// reference becomes an inline base64 `data:` URI marker.
///
/// Counts `[IMAGE:...]` markers across user messages, rejects the request
/// if the count exceeds the configured limit, and returns the history
/// unchanged (cheap clone) when no markers are present. Otherwise each
/// reference — data URI, remote URL, or local path — is normalized via
/// `normalize_image_reference`; the first failure aborts the whole request.
pub async fn prepare_messages_for_provider(
    messages: &[ChatMessage],
    config: &MultimodalConfig,
) -> anyhow::Result<PreparedMessages> {
    let (max_images, max_image_size_mb) = config.effective_limits();
    let max_bytes = max_image_size_mb.saturating_mul(1024 * 1024);
    let found_images = count_image_markers(messages);
    if found_images > max_images {
        return Err(MultimodalError::TooManyImages {
            max_images,
            found: found_images,
        }
        .into());
    }
    if found_images == 0 {
        // Fast path: nothing to normalize, pass the history through.
        return Ok(PreparedMessages {
            messages: messages.to_vec(),
            contains_images: false,
        });
    }
    // Client is built unconditionally once markers exist, but only used for
    // http(s) references when remote fetch is allowed. NOTE(review): the
    // "provider.ollama" label and 30s/10s timeouts apply regardless of the
    // actual provider — confirm that is intended.
    let remote_client = build_runtime_proxy_client_with_timeouts("provider.ollama", 30, 10);
    let mut normalized_messages = Vec::with_capacity(messages.len());
    for message in messages {
        // Only user messages carry image markers; others pass through as-is.
        if message.role != "user" {
            normalized_messages.push(message.clone());
            continue;
        }
        let (cleaned_text, refs) = parse_image_markers(&message.content);
        if refs.is_empty() {
            normalized_messages.push(message.clone());
            continue;
        }
        let mut normalized_refs = Vec::with_capacity(refs.len());
        for reference in refs {
            let data_uri =
                normalize_image_reference(&reference, config, max_bytes, &remote_client).await?;
            normalized_refs.push(data_uri);
        }
        // Rebuild the message: cleaned prose first, then one marker per line.
        let content = compose_multimodal_message(&cleaned_text, &normalized_refs);
        normalized_messages.push(ChatMessage {
            role: message.role.clone(),
            content,
        });
    }
    Ok(PreparedMessages {
        messages: normalized_messages,
        contains_images: true,
    })
}
/// Join cleaned user text with normalized image data URIs, re-emitting each
/// URI as an `[IMAGE:...]` marker on its own line. Text and markers are
/// separated by a blank line; blank text yields markers only.
fn compose_multimodal_message(text: &str, data_uris: &[String]) -> String {
    let trimmed = text.trim();
    let markers: Vec<String> = data_uris
        .iter()
        .map(|uri| format!("[IMAGE:{uri}]"))
        .collect();
    let mut message = String::new();
    if !trimmed.is_empty() {
        message.push_str(trimmed);
        message.push_str("\n\n");
    }
    message.push_str(&markers.join("\n"));
    message
}
/// Dispatch an image reference to the right normalizer based on its scheme:
/// `data:` URIs are validated in place, `http(s)` URLs are fetched (only
/// when remote fetch is enabled in config), and anything else is treated as
/// a local filesystem path.
async fn normalize_image_reference(
    source: &str,
    config: &MultimodalConfig,
    max_bytes: usize,
    remote_client: &Client,
) -> anyhow::Result<String> {
    if source.starts_with("data:") {
        normalize_data_uri(source, max_bytes)
    } else if source.starts_with("http://") || source.starts_with("https://") {
        if !config.allow_remote_fetch {
            return Err(MultimodalError::RemoteFetchDisabled {
                input: source.to_string(),
            }
            .into());
        }
        normalize_remote_image(source, max_bytes, remote_client).await
    } else {
        normalize_local_image(source, max_bytes).await
    }
}
/// Validate an inline `data:` URI and return it in canonical form.
///
/// Requires a base64-encoded payload and an allow-listed image MIME type.
/// The payload is decoded (proving it is valid base64 and measuring its
/// true byte size against `max_bytes`) and then re-encoded, which
/// normalizes padding/whitespace and drops any extra header parameters.
fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
    let Some(comma_idx) = source.find(',') else {
        return Err(MultimodalError::InvalidMarker {
            input: source.to_string(),
            reason: "expected data URI payload".to_string(),
        }
        .into());
    };
    let header = &source[..comma_idx];
    let payload = source[comma_idx + 1..].trim();
    if !header.contains(";base64") {
        return Err(MultimodalError::InvalidMarker {
            input: source.to_string(),
            reason: "only base64 data URIs are supported".to_string(),
        }
        .into());
    }
    // MIME type is the first `;`-separated token after "data:".
    let mime = header
        .trim_start_matches("data:")
        .split(';')
        .next()
        .unwrap_or_default()
        .trim()
        .to_ascii_lowercase();
    validate_mime(source, &mime)?;
    let decoded = STANDARD
        .decode(payload)
        .map_err(|error| MultimodalError::InvalidMarker {
            input: source.to_string(),
            reason: format!("invalid base64 payload: {error}"),
        })?;
    // Size is checked against decoded bytes, not base64 length.
    validate_size(source, decoded.len(), max_bytes)?;
    Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded)))
}
/// Download a remote image and return it as a base64 `data:` URI.
///
/// Rejects non-2xx responses, checks the declared `Content-Length` (when
/// present) before downloading, and re-checks the actual body size after.
/// The MIME type is resolved from the `Content-Type` header first, then
/// from magic bytes. NOTE(review): the body is fully buffered by `.bytes()`
/// before the post-download size check, so a response without (or with a
/// lying) `Content-Length` can still occupy up to the full body in memory —
/// confirm the client's timeouts bound this acceptably.
async fn normalize_remote_image(
    source: &str,
    max_bytes: usize,
    remote_client: &Client,
) -> anyhow::Result<String> {
    let response = remote_client.get(source).send().await.map_err(|error| {
        MultimodalError::RemoteFetchFailed {
            input: source.to_string(),
            reason: error.to_string(),
        }
    })?;
    let status = response.status();
    if !status.is_success() {
        return Err(MultimodalError::RemoteFetchFailed {
            input: source.to_string(),
            reason: format!("HTTP {status}"),
        }
        .into());
    }
    // Best-effort early rejection based on the server-declared length.
    if let Some(content_length) = response.content_length() {
        let content_length = content_length as usize;
        validate_size(source, content_length, max_bytes)?;
    }
    let content_type = response
        .headers()
        .get(reqwest::header::CONTENT_TYPE)
        .and_then(|value| value.to_str().ok())
        .map(ToString::to_string);
    let bytes = response
        .bytes()
        .await
        .map_err(|error| MultimodalError::RemoteFetchFailed {
            input: source.to_string(),
            reason: error.to_string(),
        })?;
    // Authoritative size check on what was actually received.
    validate_size(source, bytes.len(), max_bytes)?;
    let mime = detect_mime(None, bytes.as_ref(), content_type.as_deref()).ok_or_else(|| {
        MultimodalError::UnsupportedMime {
            input: source.to_string(),
            mime: "unknown".to_string(),
        }
    })?;
    validate_mime(source, &mime)?;
    Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
}
async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result<String> {
let path = Path::new(source);
if !path.exists() || !path.is_file() {
return Err(MultimodalError::ImageSourceNotFound {
input: source.to_string(),
}
.into());
}
let metadata =
tokio::fs::metadata(path)
.await
.map_err(|error| MultimodalError::LocalReadFailed {
input: source.to_string(),
reason: error.to_string(),
})?;
validate_size(source, metadata.len() as usize, max_bytes)?;
let bytes = tokio::fs::read(path)
.await
.map_err(|error| MultimodalError::LocalReadFailed {
input: source.to_string(),
reason: error.to_string(),
})?;
validate_size(source, bytes.len(), max_bytes)?;
let mime =
detect_mime(Some(path), &bytes, None).ok_or_else(|| MultimodalError::UnsupportedMime {
input: source.to_string(),
mime: "unknown".to_string(),
})?;
validate_mime(source, &mime)?;
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
}
/// Reject payloads larger than `max_bytes`; `source` is only used to label
/// the error.
fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> {
    if size_bytes <= max_bytes {
        Ok(())
    } else {
        Err(MultimodalError::ImageTooLarge {
            input: source.to_string(),
            size_bytes,
            max_bytes,
        }
        .into())
    }
}
/// Accept only MIME types on the `ALLOWED_IMAGE_MIME_TYPES` allow-list;
/// anything else is an `UnsupportedMime` error labeled with `source`.
fn validate_mime(source: &str, mime: &str) -> anyhow::Result<()> {
    if ALLOWED_IMAGE_MIME_TYPES.contains(&mime) {
        Ok(())
    } else {
        Err(MultimodalError::UnsupportedMime {
            input: source.to_string(),
            mime: mime.to_string(),
        }
        .into())
    }
}
fn detect_mime(
path: Option<&Path>,
bytes: &[u8],
header_content_type: Option<&str>,
) -> Option<String> {
if let Some(header_mime) = header_content_type.and_then(normalize_content_type) {
return Some(header_mime);
}
if let Some(path) = path {
if let Some(ext) = path.extension().and_then(|value| value.to_str()) {
if let Some(mime) = mime_from_extension(ext) {
return Some(mime.to_string());
}
}
}
mime_from_magic(bytes).map(ToString::to_string)
}
/// Strip parameters (e.g. `; charset=...`) from a Content-Type value and
/// lowercase the bare MIME type; blank input yields `None`.
fn normalize_content_type(content_type: &str) -> Option<String> {
    let mime = content_type.split(';').next()?.trim().to_ascii_lowercase();
    (!mime.is_empty()).then_some(mime)
}
/// Map a file extension (any case) to its image MIME type; unknown
/// extensions yield `None`.
fn mime_from_extension(ext: &str) -> Option<&'static str> {
    let normalized = ext.to_ascii_lowercase();
    let mime = match normalized.as_str() {
        "png" => "image/png",
        "jpg" | "jpeg" => "image/jpeg",
        "webp" => "image/webp",
        "gif" => "image/gif",
        "bmp" => "image/bmp",
        _ => return None,
    };
    Some(mime)
}
/// Sniff an image MIME type from leading magic bytes.
///
/// `starts_with` already returns false on a too-short slice, so explicit
/// length guards are only needed where trailing bytes are indexed (WEBP).
fn mime_from_magic(bytes: &[u8]) -> Option<&'static str> {
    const PNG_SIGNATURE: [u8; 8] = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];
    if bytes.starts_with(&PNG_SIGNATURE) {
        return Some("image/png");
    }
    if bytes.starts_with(&[0xff, 0xd8, 0xff]) {
        return Some("image/jpeg");
    }
    if bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a") {
        return Some("image/gif");
    }
    // WEBP is a RIFF container: "RIFF" + 4 size bytes + "WEBP".
    if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" {
        return Some("image/webp");
    }
    if bytes.starts_with(b"BM") {
        return Some("image/bmp");
    }
    None
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_image_markers_extracts_multiple_markers() {
        let raw = "Check this [IMAGE:/tmp/a.png] and this [IMAGE:https://example.com/b.jpg]";
        let (stripped, markers) = parse_image_markers(raw);
        assert_eq!(stripped, "Check this and this");
        assert_eq!(markers.len(), 2);
        assert_eq!(markers[0], "/tmp/a.png");
        assert_eq!(markers[1], "https://example.com/b.jpg");
    }

    #[test]
    fn parse_image_markers_keeps_invalid_empty_marker() {
        // A marker with no payload is left in the text and yields no refs.
        let (stripped, markers) = parse_image_markers("hello [IMAGE:] world");
        assert_eq!(stripped, "hello [IMAGE:] world");
        assert_eq!(markers.len(), 0);
    }

    #[tokio::test]
    async fn prepare_messages_normalizes_local_image_to_data_uri() {
        let dir = tempfile::tempdir().unwrap();
        let png_path = dir.path().join("sample.png");
        // The PNG magic number alone is sufficient for MIME sniffing.
        let png_magic = [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'];
        std::fs::write(&png_path, png_magic).unwrap();

        let history = vec![ChatMessage::user(format!(
            "Please inspect this screenshot [IMAGE:{}]",
            png_path.display()
        ))];
        let prepared = prepare_messages_for_provider(&history, &MultimodalConfig::default())
            .await
            .unwrap();

        assert!(prepared.contains_images);
        assert_eq!(prepared.messages.len(), 1);
        let (stripped, markers) = parse_image_markers(&prepared.messages[0].content);
        assert_eq!(stripped, "Please inspect this screenshot");
        assert_eq!(markers.len(), 1);
        assert!(markers[0].starts_with("data:image/png;base64,"));
    }

    #[tokio::test]
    async fn prepare_messages_rejects_too_many_images() {
        let history = vec![ChatMessage::user(
            "[IMAGE:/tmp/1.png]\n[IMAGE:/tmp/2.png]".to_string(),
        )];
        let single_image_cap = MultimodalConfig {
            max_images: 1,
            max_image_size_mb: 5,
            allow_remote_fetch: false,
        };
        let error = prepare_messages_for_provider(&history, &single_image_cap)
            .await
            .expect_err("should reject image count overflow");
        assert!(error
            .to_string()
            .contains("multimodal image limit exceeded"));
    }

    #[tokio::test]
    async fn prepare_messages_rejects_remote_url_when_disabled() {
        // Default config leaves remote fetching disabled.
        let history = vec![ChatMessage::user(
            "Look [IMAGE:https://example.com/img.png]".to_string(),
        )];
        let error = prepare_messages_for_provider(&history, &MultimodalConfig::default())
            .await
            .expect_err("should reject remote image URL when fetch is disabled");
        assert!(error
            .to_string()
            .contains("multimodal remote image fetch is disabled"));
    }

    #[tokio::test]
    async fn prepare_messages_rejects_oversized_local_image() {
        let dir = tempfile::tempdir().unwrap();
        let big_path = dir.path().join("big.png");
        // One byte past the 1 MiB cap configured below.
        std::fs::write(&big_path, vec![0u8; 1024 * 1024 + 1]).unwrap();

        let history = vec![ChatMessage::user(format!(
            "[IMAGE:{}]",
            big_path.display()
        ))];
        let one_mb_cap = MultimodalConfig {
            max_images: 4,
            max_image_size_mb: 1,
            allow_remote_fetch: false,
        };
        let error = prepare_messages_for_provider(&history, &one_mb_cap)
            .await
            .expect_err("should reject oversized local image");
        assert!(error
            .to_string()
            .contains("multimodal image size limit exceeded"));
    }

    #[test]
    fn extract_ollama_image_payload_supports_data_uris() {
        let payload = extract_ollama_image_payload("data:image/png;base64,abcd==")
            .expect("payload should be extracted");
        assert_eq!(payload, "abcd==");
    }
}

View file

@ -173,6 +173,7 @@ pub async fn run_wizard() -> Result<Config> {
secrets: secrets_config,
browser: BrowserConfig::default(),
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),
@ -391,6 +392,7 @@ pub async fn run_quick_setup(
secrets: SecretsConfig::default(),
browser: BrowserConfig::default(),
http_request: crate::config::HttpRequestConfig::default(),
multimodal: crate::config::MultimodalConfig::default(),
web_search: crate::config::WebSearchConfig::default(),
proxy: crate::config::ProxyConfig::default(),
identity: crate::config::IdentityConfig::default(),

View file

@ -639,6 +639,7 @@ impl Provider for BedrockProvider {
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
native_tool_calling: true,
vision: false,
}
}

View file

@ -898,6 +898,7 @@ impl Provider for OpenAiCompatibleProvider {
fn capabilities(&self) -> crate::providers::traits::ProviderCapabilities {
crate::providers::traits::ProviderCapabilities {
native_tool_calling: true,
vision: false,
}
}

View file

@ -13,8 +13,8 @@ pub mod traits;
#[allow(unused_imports)]
pub use traits::{
ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ToolCall,
ToolResultMessage,
ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ProviderCapabilityError,
ToolCall, ToolResultMessage,
};
use compatible::{AuthStyle, OpenAiCompatibleProvider};

View file

@ -1,4 +1,7 @@
use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
use crate::multimodal;
use crate::providers::traits::{
ChatMessage, ChatResponse, Provider, ProviderCapabilities, ToolCall,
};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
@ -30,6 +33,8 @@ struct Message {
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
images: Option<Vec<String>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OutgoingToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_name: Option<String>,
@ -166,6 +171,31 @@ impl OllamaProvider {
}
}
fn convert_user_message_content(&self, content: &str) -> (Option<String>, Option<Vec<String>>) {
let (cleaned, image_refs) = multimodal::parse_image_markers(content);
if image_refs.is_empty() {
return (Some(content.to_string()), None);
}
let images: Vec<String> = image_refs
.iter()
.filter_map(|reference| multimodal::extract_ollama_image_payload(reference))
.collect();
if images.is_empty() {
return (Some(content.to_string()), None);
}
let cleaned = cleaned.trim();
let content = if cleaned.is_empty() {
None
} else {
Some(cleaned.to_string())
};
(content, Some(images))
}
/// Convert internal chat history format to Ollama's native tool-call message schema.
///
/// `run_tool_call_loop` stores native assistant/tool entries as JSON strings in
@ -205,6 +235,7 @@ impl OllamaProvider {
return Message {
role: "assistant".to_string(),
content,
images: None,
tool_calls: Some(outgoing_calls),
tool_name: None,
};
@ -238,15 +269,28 @@ impl OllamaProvider {
return Message {
role: "tool".to_string(),
content,
images: None,
tool_calls: None,
tool_name,
};
}
}
if message.role == "user" {
let (content, images) = self.convert_user_message_content(&message.content);
return Message {
role: "user".to_string(),
content,
images,
tool_calls: None,
tool_name: None,
};
}
Message {
role: message.role.clone(),
content: Some(message.content.clone()),
images: None,
tool_calls: None,
tool_name: None,
}
@ -398,6 +442,13 @@ impl OllamaProvider {
#[async_trait]
impl Provider for OllamaProvider {
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
native_tool_calling: true,
vision: true,
}
}
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
@ -413,14 +464,17 @@ impl Provider for OllamaProvider {
messages.push(Message {
role: "system".to_string(),
content: Some(sys.to_string()),
images: None,
tool_calls: None,
tool_name: None,
});
}
let (user_content, user_images) = self.convert_user_message_content(message);
messages.push(Message {
role: "user".to_string(),
content: Some(message.to_string()),
content: user_content,
images: user_images,
tool_calls: None,
tool_name: None,
});
@ -862,4 +916,34 @@ mod tests {
assert_eq!(converted[1].content.as_deref(), Some("ok"));
assert!(converted[1].tool_calls.is_none());
}
#[test]
fn convert_messages_extracts_images_from_user_marker() {
let provider = OllamaProvider::new(None, None);
let messages = vec![ChatMessage {
role: "user".into(),
content: "Inspect this screenshot [IMAGE:data:image/png;base64,abcd==]".into(),
}];
let converted = provider.convert_messages(&messages);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "user");
assert_eq!(
converted[0].content.as_deref(),
Some("Inspect this screenshot")
);
let images = converted[0]
.images
.as_ref()
.expect("images should be present");
assert_eq!(images, &vec!["abcd==".to_string()]);
}
#[test]
fn capabilities_include_native_tools_and_vision() {
let provider = OllamaProvider::new(None, None);
let caps = <OllamaProvider as Provider>::capabilities(&provider);
assert!(caps.native_tool_calling);
assert!(caps.vision);
}
}

View file

@ -511,6 +511,12 @@ impl Provider for ReliableProvider {
.unwrap_or(false)
}
fn supports_vision(&self) -> bool {
self.providers
.iter()
.any(|(_, provider)| provider.supports_vision())
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],

View file

@ -158,6 +158,12 @@ impl Provider for RouterProvider {
.unwrap_or(false)
}
fn supports_vision(&self) -> bool {
self.providers
.iter()
.any(|(_, provider)| provider.supports_vision())
}
async fn warmup(&self) -> anyhow::Result<()> {
for (name, provider) in &self.providers {
tracing::info!(provider = name, "Warming up routed provider");

View file

@ -192,6 +192,15 @@ pub enum StreamError {
Io(#[from] std::io::Error),
}
/// Structured error returned when a requested capability is not supported.
#[derive(Debug, Clone, thiserror::Error)]
#[error("provider_capability_error provider={provider} capability={capability} message={message}")]
pub struct ProviderCapabilityError {
pub provider: String,
pub capability: String,
pub message: String,
}
/// Provider capabilities declaration.
///
/// Describes what features a provider supports, enabling intelligent
@ -205,6 +214,8 @@ pub struct ProviderCapabilities {
///
/// When `false`, tools must be injected via system prompt as text.
pub native_tool_calling: bool,
/// Whether the provider supports vision / image inputs.
pub vision: bool,
}
/// Provider-specific tool payload formats.
@ -351,6 +362,11 @@ pub trait Provider: Send + Sync {
self.capabilities().native_tool_calling
}
/// Whether provider supports multimodal vision input.
fn supports_vision(&self) -> bool {
self.capabilities().vision
}
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
/// Default implementation is a no-op; providers with HTTP clients should override.
async fn warmup(&self) -> anyhow::Result<()> {
@ -458,6 +474,7 @@ mod tests {
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
native_tool_calling: true,
vision: true,
}
}
@ -539,18 +556,22 @@ mod tests {
fn provider_capabilities_default() {
let caps = ProviderCapabilities::default();
assert!(!caps.native_tool_calling);
assert!(!caps.vision);
}
#[test]
fn provider_capabilities_equality() {
let caps1 = ProviderCapabilities {
native_tool_calling: true,
vision: false,
};
let caps2 = ProviderCapabilities {
native_tool_calling: true,
vision: false,
};
let caps3 = ProviderCapabilities {
native_tool_calling: false,
vision: false,
};
assert_eq!(caps1, caps2);
@ -563,6 +584,12 @@ mod tests {
assert!(provider.supports_native_tools());
}
#[test]
fn supports_vision_reflects_capabilities_default_mapping() {
let provider = CapabilityMockProvider;
assert!(provider.supports_vision());
}
#[test]
fn tools_payload_variants() {
// Test Gemini variant

View file

@ -235,7 +235,9 @@ impl ProxyConfigTool {
}
if args.get("enabled").is_none() && touched_proxy_url {
proxy.enabled = true;
// Keep auto-enable behavior when users provide a proxy URL, but
// auto-disable when all proxy URLs are cleared in the same update.
proxy.enabled = proxy.has_any_proxy_url();
}
proxy.no_proxy = proxy.normalized_no_proxy();