feat: add multimodal image marker support with Ollama vision
This commit is contained in:
parent
63aacb09ff
commit
dcd0bf641d
21 changed files with 1152 additions and 78 deletions
568
src/multimodal.rs
Normal file
568
src/multimodal.rs
Normal file
|
|
@ -0,0 +1,568 @@
|
|||
use crate::config::{build_runtime_proxy_client_with_timeouts, MultimodalConfig};
|
||||
use crate::providers::ChatMessage;
|
||||
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
||||
use reqwest::Client;
|
||||
use std::path::Path;
|
||||
|
||||
const IMAGE_MARKER_PREFIX: &str = "[IMAGE:";
|
||||
const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/webp",
|
||||
"image/gif",
|
||||
"image/bmp",
|
||||
];
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PreparedMessages {
|
||||
pub messages: Vec<ChatMessage>,
|
||||
pub contains_images: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum MultimodalError {
|
||||
#[error("multimodal image limit exceeded: max_images={max_images}, found={found}")]
|
||||
TooManyImages { max_images: usize, found: usize },
|
||||
|
||||
#[error("multimodal image size limit exceeded for '{input}': {size_bytes} bytes > {max_bytes} bytes")]
|
||||
ImageTooLarge {
|
||||
input: String,
|
||||
size_bytes: usize,
|
||||
max_bytes: usize,
|
||||
},
|
||||
|
||||
#[error("multimodal image MIME type is not allowed for '{input}': {mime}")]
|
||||
UnsupportedMime { input: String, mime: String },
|
||||
|
||||
#[error("multimodal remote image fetch is disabled for '{input}'")]
|
||||
RemoteFetchDisabled { input: String },
|
||||
|
||||
#[error("multimodal image source not found or unreadable: '{input}'")]
|
||||
ImageSourceNotFound { input: String },
|
||||
|
||||
#[error("invalid multimodal image marker '{input}': {reason}")]
|
||||
InvalidMarker { input: String, reason: String },
|
||||
|
||||
#[error("failed to download remote image '{input}': {reason}")]
|
||||
RemoteFetchFailed { input: String, reason: String },
|
||||
|
||||
#[error("failed to read local image '{input}': {reason}")]
|
||||
LocalReadFailed { input: String, reason: String },
|
||||
}
|
||||
|
||||
pub fn parse_image_markers(content: &str) -> (String, Vec<String>) {
|
||||
let mut refs = Vec::new();
|
||||
let mut cleaned = String::with_capacity(content.len());
|
||||
let mut cursor = 0usize;
|
||||
|
||||
while let Some(rel_start) = content[cursor..].find(IMAGE_MARKER_PREFIX) {
|
||||
let start = cursor + rel_start;
|
||||
cleaned.push_str(&content[cursor..start]);
|
||||
|
||||
let marker_start = start + IMAGE_MARKER_PREFIX.len();
|
||||
let Some(rel_end) = content[marker_start..].find(']') else {
|
||||
cleaned.push_str(&content[start..]);
|
||||
cursor = content.len();
|
||||
break;
|
||||
};
|
||||
|
||||
let end = marker_start + rel_end;
|
||||
let candidate = content[marker_start..end].trim();
|
||||
|
||||
if candidate.is_empty() {
|
||||
cleaned.push_str(&content[start..=end]);
|
||||
} else {
|
||||
refs.push(candidate.to_string());
|
||||
}
|
||||
|
||||
cursor = end + 1;
|
||||
}
|
||||
|
||||
if cursor < content.len() {
|
||||
cleaned.push_str(&content[cursor..]);
|
||||
}
|
||||
|
||||
(cleaned.trim().to_string(), refs)
|
||||
}
|
||||
|
||||
pub fn count_image_markers(messages: &[ChatMessage]) -> usize {
|
||||
messages
|
||||
.iter()
|
||||
.filter(|m| m.role == "user")
|
||||
.map(|m| parse_image_markers(&m.content).1.len())
|
||||
.sum()
|
||||
}
|
||||
|
||||
pub fn contains_image_markers(messages: &[ChatMessage]) -> bool {
|
||||
count_image_markers(messages) > 0
|
||||
}
|
||||
|
||||
pub fn extract_ollama_image_payload(image_ref: &str) -> Option<String> {
|
||||
if image_ref.starts_with("data:") {
|
||||
let comma_idx = image_ref.find(',')?;
|
||||
let (_, payload) = image_ref.split_at(comma_idx + 1);
|
||||
let payload = payload.trim();
|
||||
if payload.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(payload.to_string())
|
||||
}
|
||||
} else {
|
||||
Some(image_ref.trim().to_string()).filter(|value| !value.is_empty())
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn prepare_messages_for_provider(
|
||||
messages: &[ChatMessage],
|
||||
config: &MultimodalConfig,
|
||||
) -> anyhow::Result<PreparedMessages> {
|
||||
let (max_images, max_image_size_mb) = config.effective_limits();
|
||||
let max_bytes = max_image_size_mb.saturating_mul(1024 * 1024);
|
||||
|
||||
let found_images = count_image_markers(messages);
|
||||
if found_images > max_images {
|
||||
return Err(MultimodalError::TooManyImages {
|
||||
max_images,
|
||||
found: found_images,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
if found_images == 0 {
|
||||
return Ok(PreparedMessages {
|
||||
messages: messages.to_vec(),
|
||||
contains_images: false,
|
||||
});
|
||||
}
|
||||
|
||||
let remote_client = build_runtime_proxy_client_with_timeouts("provider.ollama", 30, 10);
|
||||
|
||||
let mut normalized_messages = Vec::with_capacity(messages.len());
|
||||
for message in messages {
|
||||
if message.role != "user" {
|
||||
normalized_messages.push(message.clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
let (cleaned_text, refs) = parse_image_markers(&message.content);
|
||||
if refs.is_empty() {
|
||||
normalized_messages.push(message.clone());
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut normalized_refs = Vec::with_capacity(refs.len());
|
||||
for reference in refs {
|
||||
let data_uri =
|
||||
normalize_image_reference(&reference, config, max_bytes, &remote_client).await?;
|
||||
normalized_refs.push(data_uri);
|
||||
}
|
||||
|
||||
let content = compose_multimodal_message(&cleaned_text, &normalized_refs);
|
||||
normalized_messages.push(ChatMessage {
|
||||
role: message.role.clone(),
|
||||
content,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(PreparedMessages {
|
||||
messages: normalized_messages,
|
||||
contains_images: true,
|
||||
})
|
||||
}
|
||||
|
||||
fn compose_multimodal_message(text: &str, data_uris: &[String]) -> String {
|
||||
let mut content = String::new();
|
||||
let trimmed = text.trim();
|
||||
|
||||
if !trimmed.is_empty() {
|
||||
content.push_str(trimmed);
|
||||
content.push_str("\n\n");
|
||||
}
|
||||
|
||||
for (index, data_uri) in data_uris.iter().enumerate() {
|
||||
if index > 0 {
|
||||
content.push('\n');
|
||||
}
|
||||
content.push_str(IMAGE_MARKER_PREFIX);
|
||||
content.push_str(data_uri);
|
||||
content.push(']');
|
||||
}
|
||||
|
||||
content
|
||||
}
|
||||
|
||||
async fn normalize_image_reference(
|
||||
source: &str,
|
||||
config: &MultimodalConfig,
|
||||
max_bytes: usize,
|
||||
remote_client: &Client,
|
||||
) -> anyhow::Result<String> {
|
||||
if source.starts_with("data:") {
|
||||
return normalize_data_uri(source, max_bytes);
|
||||
}
|
||||
|
||||
if source.starts_with("http://") || source.starts_with("https://") {
|
||||
if !config.allow_remote_fetch {
|
||||
return Err(MultimodalError::RemoteFetchDisabled {
|
||||
input: source.to_string(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
return normalize_remote_image(source, max_bytes, remote_client).await;
|
||||
}
|
||||
|
||||
normalize_local_image(source, max_bytes).await
|
||||
}
|
||||
|
||||
fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
|
||||
let Some(comma_idx) = source.find(',') else {
|
||||
return Err(MultimodalError::InvalidMarker {
|
||||
input: source.to_string(),
|
||||
reason: "expected data URI payload".to_string(),
|
||||
}
|
||||
.into());
|
||||
};
|
||||
|
||||
let header = &source[..comma_idx];
|
||||
let payload = source[comma_idx + 1..].trim();
|
||||
|
||||
if !header.contains(";base64") {
|
||||
return Err(MultimodalError::InvalidMarker {
|
||||
input: source.to_string(),
|
||||
reason: "only base64 data URIs are supported".to_string(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let mime = header
|
||||
.trim_start_matches("data:")
|
||||
.split(';')
|
||||
.next()
|
||||
.unwrap_or_default()
|
||||
.trim()
|
||||
.to_ascii_lowercase();
|
||||
|
||||
validate_mime(source, &mime)?;
|
||||
|
||||
let decoded = STANDARD
|
||||
.decode(payload)
|
||||
.map_err(|error| MultimodalError::InvalidMarker {
|
||||
input: source.to_string(),
|
||||
reason: format!("invalid base64 payload: {error}"),
|
||||
})?;
|
||||
|
||||
validate_size(source, decoded.len(), max_bytes)?;
|
||||
|
||||
Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded)))
|
||||
}
|
||||
|
||||
async fn normalize_remote_image(
|
||||
source: &str,
|
||||
max_bytes: usize,
|
||||
remote_client: &Client,
|
||||
) -> anyhow::Result<String> {
|
||||
let response = remote_client.get(source).send().await.map_err(|error| {
|
||||
MultimodalError::RemoteFetchFailed {
|
||||
input: source.to_string(),
|
||||
reason: error.to_string(),
|
||||
}
|
||||
})?;
|
||||
|
||||
let status = response.status();
|
||||
if !status.is_success() {
|
||||
return Err(MultimodalError::RemoteFetchFailed {
|
||||
input: source.to_string(),
|
||||
reason: format!("HTTP {status}"),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
if let Some(content_length) = response.content_length() {
|
||||
let content_length = content_length as usize;
|
||||
validate_size(source, content_length, max_bytes)?;
|
||||
}
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(ToString::to_string);
|
||||
|
||||
let bytes = response
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|error| MultimodalError::RemoteFetchFailed {
|
||||
input: source.to_string(),
|
||||
reason: error.to_string(),
|
||||
})?;
|
||||
|
||||
validate_size(source, bytes.len(), max_bytes)?;
|
||||
|
||||
let mime = detect_mime(None, bytes.as_ref(), content_type.as_deref()).ok_or_else(|| {
|
||||
MultimodalError::UnsupportedMime {
|
||||
input: source.to_string(),
|
||||
mime: "unknown".to_string(),
|
||||
}
|
||||
})?;
|
||||
|
||||
validate_mime(source, &mime)?;
|
||||
|
||||
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
|
||||
}
|
||||
|
||||
async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result<String> {
|
||||
let path = Path::new(source);
|
||||
if !path.exists() || !path.is_file() {
|
||||
return Err(MultimodalError::ImageSourceNotFound {
|
||||
input: source.to_string(),
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
let metadata =
|
||||
tokio::fs::metadata(path)
|
||||
.await
|
||||
.map_err(|error| MultimodalError::LocalReadFailed {
|
||||
input: source.to_string(),
|
||||
reason: error.to_string(),
|
||||
})?;
|
||||
|
||||
validate_size(source, metadata.len() as usize, max_bytes)?;
|
||||
|
||||
let bytes = tokio::fs::read(path)
|
||||
.await
|
||||
.map_err(|error| MultimodalError::LocalReadFailed {
|
||||
input: source.to_string(),
|
||||
reason: error.to_string(),
|
||||
})?;
|
||||
|
||||
validate_size(source, bytes.len(), max_bytes)?;
|
||||
|
||||
let mime =
|
||||
detect_mime(Some(path), &bytes, None).ok_or_else(|| MultimodalError::UnsupportedMime {
|
||||
input: source.to_string(),
|
||||
mime: "unknown".to_string(),
|
||||
})?;
|
||||
|
||||
validate_mime(source, &mime)?;
|
||||
|
||||
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
|
||||
}
|
||||
|
||||
fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> {
|
||||
if size_bytes > max_bytes {
|
||||
return Err(MultimodalError::ImageTooLarge {
|
||||
input: source.to_string(),
|
||||
size_bytes,
|
||||
max_bytes,
|
||||
}
|
||||
.into());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_mime(source: &str, mime: &str) -> anyhow::Result<()> {
|
||||
if ALLOWED_IMAGE_MIME_TYPES
|
||||
.iter()
|
||||
.any(|allowed| *allowed == mime)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Err(MultimodalError::UnsupportedMime {
|
||||
input: source.to_string(),
|
||||
mime: mime.to_string(),
|
||||
}
|
||||
.into())
|
||||
}
|
||||
|
||||
fn detect_mime(
|
||||
path: Option<&Path>,
|
||||
bytes: &[u8],
|
||||
header_content_type: Option<&str>,
|
||||
) -> Option<String> {
|
||||
if let Some(header_mime) = header_content_type.and_then(normalize_content_type) {
|
||||
return Some(header_mime);
|
||||
}
|
||||
|
||||
if let Some(path) = path {
|
||||
if let Some(ext) = path.extension().and_then(|value| value.to_str()) {
|
||||
if let Some(mime) = mime_from_extension(ext) {
|
||||
return Some(mime.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mime_from_magic(bytes).map(ToString::to_string)
|
||||
}
|
||||
|
||||
fn normalize_content_type(content_type: &str) -> Option<String> {
|
||||
let mime = content_type.split(';').next()?.trim().to_ascii_lowercase();
|
||||
if mime.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(mime)
|
||||
}
|
||||
}
|
||||
|
||||
fn mime_from_extension(ext: &str) -> Option<&'static str> {
|
||||
match ext.to_ascii_lowercase().as_str() {
|
||||
"png" => Some("image/png"),
|
||||
"jpg" | "jpeg" => Some("image/jpeg"),
|
||||
"webp" => Some("image/webp"),
|
||||
"gif" => Some("image/gif"),
|
||||
"bmp" => Some("image/bmp"),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn mime_from_magic(bytes: &[u8]) -> Option<&'static str> {
|
||||
if bytes.len() >= 8 && bytes.starts_with(&[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n']) {
|
||||
return Some("image/png");
|
||||
}
|
||||
|
||||
if bytes.len() >= 3 && bytes.starts_with(&[0xff, 0xd8, 0xff]) {
|
||||
return Some("image/jpeg");
|
||||
}
|
||||
|
||||
if bytes.len() >= 6 && (bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a")) {
|
||||
return Some("image/gif");
|
||||
}
|
||||
|
||||
if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" {
|
||||
return Some("image/webp");
|
||||
}
|
||||
|
||||
if bytes.len() >= 2 && bytes.starts_with(b"BM") {
|
||||
return Some("image/bmp");
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn parse_image_markers_extracts_multiple_markers() {
|
||||
let input = "Check this [IMAGE:/tmp/a.png] and this [IMAGE:https://example.com/b.jpg]";
|
||||
let (cleaned, refs) = parse_image_markers(input);
|
||||
|
||||
assert_eq!(cleaned, "Check this and this");
|
||||
assert_eq!(refs.len(), 2);
|
||||
assert_eq!(refs[0], "/tmp/a.png");
|
||||
assert_eq!(refs[1], "https://example.com/b.jpg");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_image_markers_keeps_invalid_empty_marker() {
|
||||
let input = "hello [IMAGE:] world";
|
||||
let (cleaned, refs) = parse_image_markers(input);
|
||||
|
||||
assert_eq!(cleaned, "hello [IMAGE:] world");
|
||||
assert!(refs.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prepare_messages_normalizes_local_image_to_data_uri() {
|
||||
let temp = tempfile::tempdir().unwrap();
|
||||
let image_path = temp.path().join("sample.png");
|
||||
|
||||
// Minimal PNG signature bytes are enough for MIME detection.
|
||||
std::fs::write(
|
||||
&image_path,
|
||||
[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let messages = vec![ChatMessage::user(format!(
|
||||
"Please inspect this screenshot [IMAGE:{}]",
|
||||
image_path.display()
|
||||
))];
|
||||
|
||||
let prepared = prepare_messages_for_provider(&messages, &MultimodalConfig::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(prepared.contains_images);
|
||||
assert_eq!(prepared.messages.len(), 1);
|
||||
|
||||
let (cleaned, refs) = parse_image_markers(&prepared.messages[0].content);
|
||||
assert_eq!(cleaned, "Please inspect this screenshot");
|
||||
assert_eq!(refs.len(), 1);
|
||||
assert!(refs[0].starts_with("data:image/png;base64,"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prepare_messages_rejects_too_many_images() {
|
||||
let messages = vec![ChatMessage::user(
|
||||
"[IMAGE:/tmp/1.png]\n[IMAGE:/tmp/2.png]".to_string(),
|
||||
)];
|
||||
|
||||
let config = MultimodalConfig {
|
||||
max_images: 1,
|
||||
max_image_size_mb: 5,
|
||||
allow_remote_fetch: false,
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
.await
|
||||
.expect_err("should reject image count overflow");
|
||||
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("multimodal image limit exceeded"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prepare_messages_rejects_remote_url_when_disabled() {
|
||||
let messages = vec![ChatMessage::user(
|
||||
"Look [IMAGE:https://example.com/img.png]".to_string(),
|
||||
)];
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &MultimodalConfig::default())
|
||||
.await
|
||||
.expect_err("should reject remote image URL when fetch is disabled");
|
||||
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("multimodal remote image fetch is disabled"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn prepare_messages_rejects_oversized_local_image() {
|
||||
let temp = tempfile::tempdir().unwrap();
|
||||
let image_path = temp.path().join("big.png");
|
||||
|
||||
let bytes = vec![0u8; 1024 * 1024 + 1];
|
||||
std::fs::write(&image_path, bytes).unwrap();
|
||||
|
||||
let messages = vec![ChatMessage::user(format!(
|
||||
"[IMAGE:{}]",
|
||||
image_path.display()
|
||||
))];
|
||||
let config = MultimodalConfig {
|
||||
max_images: 4,
|
||||
max_image_size_mb: 1,
|
||||
allow_remote_fetch: false,
|
||||
};
|
||||
|
||||
let error = prepare_messages_for_provider(&messages, &config)
|
||||
.await
|
||||
.expect_err("should reject oversized local image");
|
||||
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("multimodal image size limit exceeded"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_ollama_image_payload_supports_data_uris() {
|
||||
let payload = extract_ollama_image_payload("data:image/png;base64,abcd==")
|
||||
.expect("payload should be extracted");
|
||||
assert_eq!(payload, "abcd==");
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue