feat: add multimodal image marker support with Ollama vision
This commit is contained in:
parent
63aacb09ff
commit
dcd0bf641d
21 changed files with 1152 additions and 78 deletions
|
|
@ -1,8 +1,11 @@
|
|||
use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse};
|
||||
use crate::config::Config;
|
||||
use crate::memory::{self, Memory, MemoryCategory};
|
||||
use crate::multimodal;
|
||||
use crate::observability::{self, Observer, ObserverEvent};
|
||||
use crate::providers::{self, ChatMessage, ChatRequest, Provider, ToolCall};
|
||||
use crate::providers::{
|
||||
self, ChatMessage, ChatRequest, Provider, ProviderCapabilityError, ToolCall,
|
||||
};
|
||||
use crate::runtime;
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::tools::{self, Tool};
|
||||
|
|
@ -826,6 +829,7 @@ pub(crate) async fn agent_turn(
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
silent: bool,
|
||||
multimodal_config: &crate::config::MultimodalConfig,
|
||||
max_tool_iterations: usize,
|
||||
) -> Result<String> {
|
||||
run_tool_call_loop(
|
||||
|
|
@ -839,6 +843,7 @@ pub(crate) async fn agent_turn(
|
|||
silent,
|
||||
None,
|
||||
"channel",
|
||||
multimodal_config,
|
||||
max_tool_iterations,
|
||||
None,
|
||||
)
|
||||
|
|
@ -859,6 +864,7 @@ pub(crate) async fn run_tool_call_loop(
|
|||
silent: bool,
|
||||
approval: Option<&ApprovalManager>,
|
||||
channel_name: &str,
|
||||
multimodal_config: &crate::config::MultimodalConfig,
|
||||
max_tool_iterations: usize,
|
||||
on_delta: Option<tokio::sync::mpsc::Sender<String>>,
|
||||
) -> Result<String> {
|
||||
|
|
@ -873,6 +879,21 @@ pub(crate) async fn run_tool_call_loop(
|
|||
let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty();
|
||||
|
||||
for _iteration in 0..max_iterations {
|
||||
let image_marker_count = multimodal::count_image_markers(history);
|
||||
if image_marker_count > 0 && !provider.supports_vision() {
|
||||
return Err(ProviderCapabilityError {
|
||||
provider: provider_name.to_string(),
|
||||
capability: "vision".to_string(),
|
||||
message: format!(
|
||||
"received {image_marker_count} image marker(s), but this provider does not support vision input"
|
||||
),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let prepared_messages =
|
||||
multimodal::prepare_messages_for_provider(history, multimodal_config).await?;
|
||||
|
||||
observer.record_event(&ObserverEvent::LlmRequest {
|
||||
provider: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
|
|
@ -893,7 +914,7 @@ pub(crate) async fn run_tool_call_loop(
|
|||
match provider
|
||||
.chat(
|
||||
ChatRequest {
|
||||
messages: history,
|
||||
messages: &prepared_messages.messages,
|
||||
tools: request_tools,
|
||||
},
|
||||
model,
|
||||
|
|
@ -1404,6 +1425,7 @@ pub async fn run(
|
|||
false,
|
||||
Some(&approval_manager),
|
||||
"cli",
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
)
|
||||
|
|
@ -1530,6 +1552,7 @@ pub async fn run(
|
|||
false,
|
||||
Some(&approval_manager),
|
||||
"cli",
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
None,
|
||||
)
|
||||
|
|
@ -1757,6 +1780,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|||
&model_name,
|
||||
config.default_temperature,
|
||||
true,
|
||||
&config.multimodal,
|
||||
config.agent.max_tool_iterations,
|
||||
)
|
||||
.await
|
||||
|
|
@ -1765,6 +1789,10 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use async_trait::async_trait;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
#[test]
|
||||
fn test_scrub_credentials() {
|
||||
|
|
@ -1785,8 +1813,191 @@ mod tests {
|
|||
assert!(scrubbed.contains("public"));
|
||||
}
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use crate::observability::NoopObserver;
|
||||
use crate::providers::traits::ProviderCapabilities;
|
||||
use crate::providers::ChatResponse;
|
||||
use tempfile::TempDir;
|
||||
|
||||
struct NonVisionProvider {
|
||||
calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for NonVisionProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
Ok("ok".to_string())
|
||||
}
|
||||
}
|
||||
|
||||
struct VisionProvider {
|
||||
calls: Arc<AtomicUsize>,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for VisionProvider {
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
vision: true,
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
Ok("ok".to_string())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
let marker_count = crate::multimodal::count_image_markers(request.messages);
|
||||
if marker_count == 0 {
|
||||
anyhow::bail!("expected image markers in request messages");
|
||||
}
|
||||
|
||||
if request.tools.is_some() {
|
||||
anyhow::bail!("no tools should be attached for this test");
|
||||
}
|
||||
|
||||
Ok(ChatResponse {
|
||||
text: Some("vision-ok".to_string()),
|
||||
tool_calls: Vec::new(),
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = NonVisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"please inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("provider without vision support should fail");
|
||||
|
||||
assert!(err.to_string().contains("provider_capability_error"));
|
||||
assert!(err.to_string().contains("capability=vision"));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_rejects_oversized_image_payload() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = VisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let oversized_payload = STANDARD.encode(vec![0_u8; (1024 * 1024) + 1]);
|
||||
let mut history = vec![ChatMessage::user(format!(
|
||||
"[IMAGE:data:image/png;base64,{oversized_payload}]"
|
||||
))];
|
||||
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
let multimodal = crate::config::MultimodalConfig {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
};
|
||||
|
||||
let err = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
&multimodal,
|
||||
3,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect_err("oversized payload must fail");
|
||||
|
||||
assert!(err
|
||||
.to_string()
|
||||
.contains("multimodal image size limit exceeded"));
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_tool_call_loop_accepts_valid_multimodal_request_flow() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = VisionProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
};
|
||||
|
||||
let mut history = vec![ChatMessage::user(
|
||||
"Analyze this [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(),
|
||||
)];
|
||||
let tools_registry: Vec<Box<dyn Tool>> = Vec::new();
|
||||
let observer = NoopObserver;
|
||||
|
||||
let result = run_tool_call_loop(
|
||||
&provider,
|
||||
&mut history,
|
||||
&tools_registry,
|
||||
&observer,
|
||||
"mock-provider",
|
||||
"mock-model",
|
||||
0.0,
|
||||
true,
|
||||
None,
|
||||
"cli",
|
||||
&crate::config::MultimodalConfig::default(),
|
||||
3,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.expect("valid multimodal payload should pass");
|
||||
|
||||
assert_eq!(result, "vision-ok");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_tool_calls_extracts_single_call() {
|
||||
let response = r#"Let me check that.
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue