568 lines
16 KiB
Rust
568 lines
16 KiB
Rust
use crate::config::{build_runtime_proxy_client_with_timeouts, MultimodalConfig};
|
|
use crate::providers::ChatMessage;
|
|
use base64::{engine::general_purpose::STANDARD, Engine as _};
|
|
use reqwest::Client;
|
|
use std::path::Path;
|
|
|
|
/// Opening delimiter of an inline image marker. A complete marker has the
/// form `[IMAGE:<reference>]`; the closing `]` is searched for separately.
const IMAGE_MARKER_PREFIX: &str = "[IMAGE:";
|
|
/// MIME types accepted by `validate_mime`. Entries are lowercase and are
/// compared by exact string equality.
const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[
    "image/png",
    "image/jpeg",
    "image/webp",
    "image/gif",
    "image/bmp",
];
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct PreparedMessages {
|
|
pub messages: Vec<ChatMessage>,
|
|
pub contains_images: bool,
|
|
}
|
|
|
|
#[derive(Debug, thiserror::Error)]
|
|
pub enum MultimodalError {
|
|
#[error("multimodal image limit exceeded: max_images={max_images}, found={found}")]
|
|
TooManyImages { max_images: usize, found: usize },
|
|
|
|
#[error("multimodal image size limit exceeded for '{input}': {size_bytes} bytes > {max_bytes} bytes")]
|
|
ImageTooLarge {
|
|
input: String,
|
|
size_bytes: usize,
|
|
max_bytes: usize,
|
|
},
|
|
|
|
#[error("multimodal image MIME type is not allowed for '{input}': {mime}")]
|
|
UnsupportedMime { input: String, mime: String },
|
|
|
|
#[error("multimodal remote image fetch is disabled for '{input}'")]
|
|
RemoteFetchDisabled { input: String },
|
|
|
|
#[error("multimodal image source not found or unreadable: '{input}'")]
|
|
ImageSourceNotFound { input: String },
|
|
|
|
#[error("invalid multimodal image marker '{input}': {reason}")]
|
|
InvalidMarker { input: String, reason: String },
|
|
|
|
#[error("failed to download remote image '{input}': {reason}")]
|
|
RemoteFetchFailed { input: String, reason: String },
|
|
|
|
#[error("failed to read local image '{input}': {reason}")]
|
|
LocalReadFailed { input: String, reason: String },
|
|
}
|
|
|
|
pub fn parse_image_markers(content: &str) -> (String, Vec<String>) {
|
|
let mut refs = Vec::new();
|
|
let mut cleaned = String::with_capacity(content.len());
|
|
let mut cursor = 0usize;
|
|
|
|
while let Some(rel_start) = content[cursor..].find(IMAGE_MARKER_PREFIX) {
|
|
let start = cursor + rel_start;
|
|
cleaned.push_str(&content[cursor..start]);
|
|
|
|
let marker_start = start + IMAGE_MARKER_PREFIX.len();
|
|
let Some(rel_end) = content[marker_start..].find(']') else {
|
|
cleaned.push_str(&content[start..]);
|
|
cursor = content.len();
|
|
break;
|
|
};
|
|
|
|
let end = marker_start + rel_end;
|
|
let candidate = content[marker_start..end].trim();
|
|
|
|
if candidate.is_empty() {
|
|
cleaned.push_str(&content[start..=end]);
|
|
} else {
|
|
refs.push(candidate.to_string());
|
|
}
|
|
|
|
cursor = end + 1;
|
|
}
|
|
|
|
if cursor < content.len() {
|
|
cleaned.push_str(&content[cursor..]);
|
|
}
|
|
|
|
(cleaned.trim().to_string(), refs)
|
|
}
|
|
|
|
pub fn count_image_markers(messages: &[ChatMessage]) -> usize {
|
|
messages
|
|
.iter()
|
|
.filter(|m| m.role == "user")
|
|
.map(|m| parse_image_markers(&m.content).1.len())
|
|
.sum()
|
|
}
|
|
|
|
pub fn contains_image_markers(messages: &[ChatMessage]) -> bool {
|
|
count_image_markers(messages) > 0
|
|
}
|
|
|
|
/// Extracts the payload portion of an image reference.
///
/// * For a reference starting with `data:`, returns everything after the
///   first `,`, trimmed. Returns `None` when there is no comma or nothing
///   but whitespace follows it.
/// * For any other reference, returns the trimmed reference itself, or
///   `None` when it is blank.
///
/// The payload is not decoded or validated here.
pub fn extract_ollama_image_payload(image_ref: &str) -> Option<String> {
    let payload = if image_ref.starts_with("data:") {
        let (_header, payload) = image_ref.split_once(',')?;
        payload
    } else {
        image_ref
    };

    let payload = payload.trim();
    if payload.is_empty() {
        None
    } else {
        Some(payload.to_string())
    }
}
|
|
|
|
pub async fn prepare_messages_for_provider(
|
|
messages: &[ChatMessage],
|
|
config: &MultimodalConfig,
|
|
) -> anyhow::Result<PreparedMessages> {
|
|
let (max_images, max_image_size_mb) = config.effective_limits();
|
|
let max_bytes = max_image_size_mb.saturating_mul(1024 * 1024);
|
|
|
|
let found_images = count_image_markers(messages);
|
|
if found_images > max_images {
|
|
return Err(MultimodalError::TooManyImages {
|
|
max_images,
|
|
found: found_images,
|
|
}
|
|
.into());
|
|
}
|
|
|
|
if found_images == 0 {
|
|
return Ok(PreparedMessages {
|
|
messages: messages.to_vec(),
|
|
contains_images: false,
|
|
});
|
|
}
|
|
|
|
let remote_client = build_runtime_proxy_client_with_timeouts("provider.ollama", 30, 10);
|
|
|
|
let mut normalized_messages = Vec::with_capacity(messages.len());
|
|
for message in messages {
|
|
if message.role != "user" {
|
|
normalized_messages.push(message.clone());
|
|
continue;
|
|
}
|
|
|
|
let (cleaned_text, refs) = parse_image_markers(&message.content);
|
|
if refs.is_empty() {
|
|
normalized_messages.push(message.clone());
|
|
continue;
|
|
}
|
|
|
|
let mut normalized_refs = Vec::with_capacity(refs.len());
|
|
for reference in refs {
|
|
let data_uri =
|
|
normalize_image_reference(&reference, config, max_bytes, &remote_client).await?;
|
|
normalized_refs.push(data_uri);
|
|
}
|
|
|
|
let content = compose_multimodal_message(&cleaned_text, &normalized_refs);
|
|
normalized_messages.push(ChatMessage {
|
|
role: message.role.clone(),
|
|
content,
|
|
});
|
|
}
|
|
|
|
Ok(PreparedMessages {
|
|
messages: normalized_messages,
|
|
contains_images: true,
|
|
})
|
|
}
|
|
|
|
fn compose_multimodal_message(text: &str, data_uris: &[String]) -> String {
|
|
let mut content = String::new();
|
|
let trimmed = text.trim();
|
|
|
|
if !trimmed.is_empty() {
|
|
content.push_str(trimmed);
|
|
content.push_str("\n\n");
|
|
}
|
|
|
|
for (index, data_uri) in data_uris.iter().enumerate() {
|
|
if index > 0 {
|
|
content.push('\n');
|
|
}
|
|
content.push_str(IMAGE_MARKER_PREFIX);
|
|
content.push_str(data_uri);
|
|
content.push(']');
|
|
}
|
|
|
|
content
|
|
}
|
|
|
|
async fn normalize_image_reference(
|
|
source: &str,
|
|
config: &MultimodalConfig,
|
|
max_bytes: usize,
|
|
remote_client: &Client,
|
|
) -> anyhow::Result<String> {
|
|
if source.starts_with("data:") {
|
|
return normalize_data_uri(source, max_bytes);
|
|
}
|
|
|
|
if source.starts_with("http://") || source.starts_with("https://") {
|
|
if !config.allow_remote_fetch {
|
|
return Err(MultimodalError::RemoteFetchDisabled {
|
|
input: source.to_string(),
|
|
}
|
|
.into());
|
|
}
|
|
|
|
return normalize_remote_image(source, max_bytes, remote_client).await;
|
|
}
|
|
|
|
normalize_local_image(source, max_bytes).await
|
|
}
|
|
|
|
fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result<String> {
|
|
let Some(comma_idx) = source.find(',') else {
|
|
return Err(MultimodalError::InvalidMarker {
|
|
input: source.to_string(),
|
|
reason: "expected data URI payload".to_string(),
|
|
}
|
|
.into());
|
|
};
|
|
|
|
let header = &source[..comma_idx];
|
|
let payload = source[comma_idx + 1..].trim();
|
|
|
|
if !header.contains(";base64") {
|
|
return Err(MultimodalError::InvalidMarker {
|
|
input: source.to_string(),
|
|
reason: "only base64 data URIs are supported".to_string(),
|
|
}
|
|
.into());
|
|
}
|
|
|
|
let mime = header
|
|
.trim_start_matches("data:")
|
|
.split(';')
|
|
.next()
|
|
.unwrap_or_default()
|
|
.trim()
|
|
.to_ascii_lowercase();
|
|
|
|
validate_mime(source, &mime)?;
|
|
|
|
let decoded = STANDARD
|
|
.decode(payload)
|
|
.map_err(|error| MultimodalError::InvalidMarker {
|
|
input: source.to_string(),
|
|
reason: format!("invalid base64 payload: {error}"),
|
|
})?;
|
|
|
|
validate_size(source, decoded.len(), max_bytes)?;
|
|
|
|
Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded)))
|
|
}
|
|
|
|
async fn normalize_remote_image(
|
|
source: &str,
|
|
max_bytes: usize,
|
|
remote_client: &Client,
|
|
) -> anyhow::Result<String> {
|
|
let response = remote_client.get(source).send().await.map_err(|error| {
|
|
MultimodalError::RemoteFetchFailed {
|
|
input: source.to_string(),
|
|
reason: error.to_string(),
|
|
}
|
|
})?;
|
|
|
|
let status = response.status();
|
|
if !status.is_success() {
|
|
return Err(MultimodalError::RemoteFetchFailed {
|
|
input: source.to_string(),
|
|
reason: format!("HTTP {status}"),
|
|
}
|
|
.into());
|
|
}
|
|
|
|
if let Some(content_length) = response.content_length() {
|
|
let content_length = content_length as usize;
|
|
validate_size(source, content_length, max_bytes)?;
|
|
}
|
|
|
|
let content_type = response
|
|
.headers()
|
|
.get(reqwest::header::CONTENT_TYPE)
|
|
.and_then(|value| value.to_str().ok())
|
|
.map(ToString::to_string);
|
|
|
|
let bytes = response
|
|
.bytes()
|
|
.await
|
|
.map_err(|error| MultimodalError::RemoteFetchFailed {
|
|
input: source.to_string(),
|
|
reason: error.to_string(),
|
|
})?;
|
|
|
|
validate_size(source, bytes.len(), max_bytes)?;
|
|
|
|
let mime = detect_mime(None, bytes.as_ref(), content_type.as_deref()).ok_or_else(|| {
|
|
MultimodalError::UnsupportedMime {
|
|
input: source.to_string(),
|
|
mime: "unknown".to_string(),
|
|
}
|
|
})?;
|
|
|
|
validate_mime(source, &mime)?;
|
|
|
|
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
|
|
}
|
|
|
|
async fn normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result<String> {
|
|
let path = Path::new(source);
|
|
if !path.exists() || !path.is_file() {
|
|
return Err(MultimodalError::ImageSourceNotFound {
|
|
input: source.to_string(),
|
|
}
|
|
.into());
|
|
}
|
|
|
|
let metadata =
|
|
tokio::fs::metadata(path)
|
|
.await
|
|
.map_err(|error| MultimodalError::LocalReadFailed {
|
|
input: source.to_string(),
|
|
reason: error.to_string(),
|
|
})?;
|
|
|
|
validate_size(source, metadata.len() as usize, max_bytes)?;
|
|
|
|
let bytes = tokio::fs::read(path)
|
|
.await
|
|
.map_err(|error| MultimodalError::LocalReadFailed {
|
|
input: source.to_string(),
|
|
reason: error.to_string(),
|
|
})?;
|
|
|
|
validate_size(source, bytes.len(), max_bytes)?;
|
|
|
|
let mime =
|
|
detect_mime(Some(path), &bytes, None).ok_or_else(|| MultimodalError::UnsupportedMime {
|
|
input: source.to_string(),
|
|
mime: "unknown".to_string(),
|
|
})?;
|
|
|
|
validate_mime(source, &mime)?;
|
|
|
|
Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes)))
|
|
}
|
|
|
|
fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> {
|
|
if size_bytes > max_bytes {
|
|
return Err(MultimodalError::ImageTooLarge {
|
|
input: source.to_string(),
|
|
size_bytes,
|
|
max_bytes,
|
|
}
|
|
.into());
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
fn validate_mime(source: &str, mime: &str) -> anyhow::Result<()> {
|
|
if ALLOWED_IMAGE_MIME_TYPES
|
|
.iter()
|
|
.any(|allowed| *allowed == mime)
|
|
{
|
|
return Ok(());
|
|
}
|
|
|
|
Err(MultimodalError::UnsupportedMime {
|
|
input: source.to_string(),
|
|
mime: mime.to_string(),
|
|
}
|
|
.into())
|
|
}
|
|
|
|
fn detect_mime(
|
|
path: Option<&Path>,
|
|
bytes: &[u8],
|
|
header_content_type: Option<&str>,
|
|
) -> Option<String> {
|
|
if let Some(header_mime) = header_content_type.and_then(normalize_content_type) {
|
|
return Some(header_mime);
|
|
}
|
|
|
|
if let Some(path) = path {
|
|
if let Some(ext) = path.extension().and_then(|value| value.to_str()) {
|
|
if let Some(mime) = mime_from_extension(ext) {
|
|
return Some(mime.to_string());
|
|
}
|
|
}
|
|
}
|
|
|
|
mime_from_magic(bytes).map(ToString::to_string)
|
|
}
|
|
|
|
/// Reduces a `Content-Type` header value to its bare media type: drops
/// everything from the first `;` on, trims whitespace, and lowercases.
/// Returns `None` when nothing is left.
fn normalize_content_type(content_type: &str) -> Option<String> {
    let mime = content_type.split(';').next()?.trim();
    if mime.is_empty() {
        None
    } else {
        Some(mime.to_ascii_lowercase())
    }
}
|
|
|
|
/// Maps a file extension (without the leading dot, any ASCII case) to its
/// image MIME type. Returns `None` for extensions outside the supported set.
fn mime_from_extension(ext: &str) -> Option<&'static str> {
    match ext.to_ascii_lowercase().as_str() {
        "png" => Some("image/png"),
        "jpg" | "jpeg" => Some("image/jpeg"),
        "webp" => Some("image/webp"),
        "gif" => Some("image/gif"),
        "bmp" => Some("image/bmp"),
        _ => None,
    }
}
|
|
|
|
/// Identifies an image format from the signature at the start of `bytes`.
///
/// Recognized signatures: PNG (8-byte header), JPEG (`FF D8 FF`), GIF
/// (`GIF87a` / `GIF89a`), WebP (`RIFF` at offset 0 and `WEBP` at offset 8),
/// and BMP (`BM`). Returns `None` for anything else, including input that
/// is too short to hold the signature.
fn mime_from_magic(bytes: &[u8]) -> Option<&'static str> {
    // `starts_with` is already false for input shorter than the prefix, so
    // no separate length checks are needed.
    if bytes.starts_with(&[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n']) {
        return Some("image/png");
    }

    if bytes.starts_with(&[0xff, 0xd8, 0xff]) {
        return Some("image/jpeg");
    }

    if bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a") {
        return Some("image/gif");
    }

    // RIFF is a generic container; the form type at bytes 8..12 tells WebP
    // apart from e.g. WAV or AVI.
    if bytes.starts_with(b"RIFF") && matches!(bytes.get(8..12), Some(tag) if tag == b"WEBP") {
        return Some("image/webp");
    }

    if bytes.starts_with(b"BM") {
        return Some("image/bmp");
    }

    None
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_image_markers_extracts_multiple_markers() {
        let input = "Check this [IMAGE:/tmp/a.png] and this [IMAGE:https://example.com/b.jpg]";
        let (cleaned, refs) = parse_image_markers(input);

        // Only the ends of the text are trimmed: the spaces on both sides of
        // the first marker survive, leaving two spaces between the words.
        assert_eq!(cleaned, "Check this  and this");
        assert_eq!(refs.len(), 2);
        assert_eq!(refs[0], "/tmp/a.png");
        assert_eq!(refs[1], "https://example.com/b.jpg");
    }

    #[test]
    fn parse_image_markers_keeps_invalid_empty_marker() {
        let input = "hello [IMAGE:] world";
        let (cleaned, refs) = parse_image_markers(input);

        assert_eq!(cleaned, "hello [IMAGE:] world");
        assert!(refs.is_empty());
    }

    #[tokio::test]
    async fn prepare_messages_normalizes_local_image_to_data_uri() {
        let temp = tempfile::tempdir().unwrap();
        let image_path = temp.path().join("sample.png");

        // Minimal PNG signature bytes are enough for MIME detection.
        std::fs::write(
            &image_path,
            [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'],
        )
        .unwrap();

        let messages = vec![ChatMessage::user(format!(
            "Please inspect this screenshot [IMAGE:{}]",
            image_path.display()
        ))];

        let prepared = prepare_messages_for_provider(&messages, &MultimodalConfig::default())
            .await
            .unwrap();

        assert!(prepared.contains_images);
        assert_eq!(prepared.messages.len(), 1);

        let (cleaned, refs) = parse_image_markers(&prepared.messages[0].content);
        assert_eq!(cleaned, "Please inspect this screenshot");
        assert_eq!(refs.len(), 1);
        assert!(refs[0].starts_with("data:image/png;base64,"));
    }

    #[tokio::test]
    async fn prepare_messages_rejects_too_many_images() {
        let messages = vec![ChatMessage::user(
            "[IMAGE:/tmp/1.png]\n[IMAGE:/tmp/2.png]".to_string(),
        )];

        let config = MultimodalConfig {
            max_images: 1,
            max_image_size_mb: 5,
            allow_remote_fetch: false,
        };

        let error = prepare_messages_for_provider(&messages, &config)
            .await
            .expect_err("should reject image count overflow");

        assert!(error
            .to_string()
            .contains("multimodal image limit exceeded"));
    }

    #[tokio::test]
    async fn prepare_messages_rejects_remote_url_when_disabled() {
        let messages = vec![ChatMessage::user(
            "Look [IMAGE:https://example.com/img.png]".to_string(),
        )];

        let error = prepare_messages_for_provider(&messages, &MultimodalConfig::default())
            .await
            .expect_err("should reject remote image URL when fetch is disabled");

        assert!(error
            .to_string()
            .contains("multimodal remote image fetch is disabled"));
    }

    #[tokio::test]
    async fn prepare_messages_rejects_oversized_local_image() {
        let temp = tempfile::tempdir().unwrap();
        let image_path = temp.path().join("big.png");

        // One byte over the 1 MiB limit configured below.
        let bytes = vec![0u8; 1024 * 1024 + 1];
        std::fs::write(&image_path, bytes).unwrap();

        let messages = vec![ChatMessage::user(format!(
            "[IMAGE:{}]",
            image_path.display()
        ))];
        let config = MultimodalConfig {
            max_images: 4,
            max_image_size_mb: 1,
            allow_remote_fetch: false,
        };

        let error = prepare_messages_for_provider(&messages, &config)
            .await
            .expect_err("should reject oversized local image");

        assert!(error
            .to_string()
            .contains("multimodal image size limit exceeded"));
    }

    #[test]
    fn extract_ollama_image_payload_supports_data_uris() {
        let payload = extract_ollama_image_payload("data:image/png;base64,abcd==")
            .expect("payload should be extracted");
        assert_eq!(payload, "abcd==");
    }
}
|