diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index b902935..6a43ad2 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -2,6 +2,7 @@ use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; pub struct OllamaProvider { base_url: String, @@ -23,7 +24,25 @@ struct ChatRequest { #[derive(Debug, Serialize)] struct Message { role: String, - content: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_name: Option, +} + +#[derive(Debug, Serialize)] +struct OutgoingToolCall { + #[serde(rename = "type")] + kind: String, + function: OutgoingFunction, +} + +#[derive(Debug, Serialize)] +struct OutgoingFunction { + name: String, + arguments: serde_json::Value, } #[derive(Debug, Serialize)] @@ -114,6 +133,98 @@ impl OllamaProvider { Ok((normalized_model, should_auth)) } + fn parse_tool_arguments(arguments: &str) -> serde_json::Value { + serde_json::from_str(arguments).unwrap_or_else(|_| serde_json::json!({})) + } + + /// Convert internal chat history format to Ollama's native tool-call message schema. + /// + /// `run_tool_call_loop` stores native assistant/tool entries as JSON strings in + /// `ChatMessage.content`. We decode those payloads here so follow-up requests send + /// structured `assistant.tool_calls` and `tool.tool_name`, as expected by Ollama. + fn convert_messages(&self, messages: &[ChatMessage]) -> Vec { + let mut tool_name_by_id: HashMap = HashMap::new(); + + messages + .iter() + .map(|message| { + if message.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&message.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let outgoing_calls: Vec = parsed_calls + .into_iter() + .map(|call| { + tool_name_by_id.insert(call.id.clone(), call.name.clone()); + OutgoingToolCall { + kind: "function".to_string(), + function: OutgoingFunction { + name: call.name, + arguments: Self::parse_tool_arguments( + &call.arguments, + ), + }, + } + }) + .collect(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return Message { + role: "assistant".to_string(), + content, + tool_calls: Some(outgoing_calls), + tool_name: None, + }; + } + } + } + } + + if message.role == "tool" { + if let Ok(value) = serde_json::from_str::(&message.content) { + let tool_name = value + .get("tool_name") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + .or_else(|| { + value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .and_then(|id| tool_name_by_id.get(id)) + .cloned() + }); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + .or_else(|| { + (!message.content.trim().is_empty()) + .then_some(message.content.clone()) + }); + + return Message { + role: "tool".to_string(), + content, + tool_calls: None, + tool_name, + }; + } + } + + Message { + role: message.role.clone(), + content: Some(message.content.clone()), + tool_calls: None, + tool_name: None, + } + }) + .collect() + } + /// Send a request to Ollama and get the parsed response. /// Pass `tools` to enable native function-calling for models that support it. async fn send_request( @@ -277,13 +388,17 @@ impl Provider for OllamaProvider { if let Some(sys) = system_prompt { messages.push(Message { role: "system".to_string(), - content: sys.to_string(), + content: Some(sys.to_string()), + tool_calls: None, + tool_name: None, }); } messages.push(Message { role: "user".to_string(), - content: message.to_string(), + content: Some(message.to_string()), + tool_calls: None, + tool_name: None, }); let response = self @@ -328,16 +443,16 @@ impl Provider for OllamaProvider { ) -> anyhow::Result { let (normalized_model, should_auth) = self.resolve_request_details(model)?; - let api_messages: Vec = messages - .iter() - .map(|m| Message { - role: m.role.clone(), - content: m.content.clone(), - }) - .collect(); + let api_messages = self.convert_messages(messages); let response = self - .send_request(api_messages, &normalized_model, temperature, should_auth, None) + .send_request( + api_messages, + &normalized_model, + temperature, + should_auth, + None, + ) .await?; // If model returned tool calls, format them for loop_.rs's parse_tool_calls @@ -349,7 +464,6 @@ impl Provider for OllamaProvider { return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); } - // Plain text response let content = response.message.content; @@ -382,20 +496,20 @@ impl Provider for OllamaProvider { ) -> anyhow::Result { let (normalized_model, should_auth) = self.resolve_request_details(model)?; - let api_messages: Vec = messages - .iter() - .map(|m| Message { - role: m.role.clone(), - content: m.content.clone(), - }) - .collect(); + let api_messages = self.convert_messages(messages); // Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from // tools_to_openai_format() in loop_.rs — pass them through directly. let tools_opt = if tools.is_empty() { None } else { Some(tools) }; let response = self - .send_request(api_messages, &normalized_model, temperature, should_auth, tools_opt) + .send_request( + api_messages, + &normalized_model, + temperature, + should_auth, + tools_opt, + ) .await?; // Native tool calls returned by the model. @@ -425,7 +539,6 @@ impl Provider for OllamaProvider { return Ok(ChatResponse { text, tool_calls }); } - // Plain text response. let content = response.message.content; if content.is_empty() { @@ -641,4 +754,50 @@ mod tests { // arguments should be a string (JSON-encoded) assert!(func.get("arguments").unwrap().is_string()); } -} \ No newline at end of file + + #[test] + fn convert_messages_parses_native_assistant_tool_calls() { + let provider = OllamaProvider::new(None, None); + let messages = vec![ChatMessage { + role: "assistant".into(), + content: r#"{"content":null,"tool_calls":[{"id":"call_1","name":"shell","arguments":"{\"command\":\"ls\"}"}]}"#.into(), + }]; + + let converted = provider.convert_messages(&messages); + + assert_eq!(converted.len(), 1); + assert_eq!(converted[0].role, "assistant"); + assert!(converted[0].content.is_none()); + let calls = converted[0] + .tool_calls + .as_ref() + .expect("tool calls expected"); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].kind, "function"); + assert_eq!(calls[0].function.name, "shell"); + assert_eq!(calls[0].function.arguments.get("command").unwrap(), "ls"); + } + + #[test] + fn convert_messages_maps_tool_result_call_id_to_tool_name() { + let provider = OllamaProvider::new(None, None); + let messages = vec![ + ChatMessage { + role: "assistant".into(), + content: r#"{"content":null,"tool_calls":[{"id":"call_7","name":"file_read","arguments":"{\"path\":\"README.md\"}"}]}"#.into(), + }, + ChatMessage { + role: "tool".into(), + content: r#"{"tool_call_id":"call_7","content":"ok"}"#.into(), + }, + ]; + + let converted = provider.convert_messages(&messages); + + assert_eq!(converted.len(), 2); + assert_eq!(converted[1].role, "tool"); + assert_eq!(converted[1].tool_name.as_deref(), Some("file_read")); + assert_eq!(converted[1].content.as_deref(), Some("ok")); + assert!(converted[1].tool_calls.is_none()); + } +}