From 1c0d7bbcb87e83cc6eb79eae58b3f64f6fe381c3 Mon Sep 17 00:00:00 2001 From: Kieran Date: Mon, 16 Feb 2026 22:48:40 +0000 Subject: [PATCH] feat: ollama tools --- src/providers/ollama.rs | 428 ++++++++++++++++++++++------------------ 1 file changed, 241 insertions(+), 187 deletions(-) diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index 582fdfe..c7b008a 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -8,6 +8,8 @@ pub struct OllamaProvider { client: Client, } +// ─── Request Structures ─────────────────────────────────────────────────────── + #[derive(Debug, Serialize)] struct ChatRequest { model: String, @@ -27,6 +29,8 @@ struct Options { temperature: f64, } +// ─── Response Structures ────────────────────────────────────────────────────── + #[derive(Debug, Deserialize)] struct ApiChatResponse { message: ResponseMessage, @@ -38,6 +42,9 @@ struct ResponseMessage { content: String, #[serde(default)] tool_calls: Vec, + /// Some models return a "thinking" field with internal reasoning + #[serde(default)] + thinking: Option, } #[derive(Debug, Deserialize)] @@ -53,6 +60,8 @@ struct OllamaFunction { arguments: serde_json::Value, } +// ─── Implementation ─────────────────────────────────────────────────────────── + impl OllamaProvider { pub fn new(base_url: Option<&str>) -> Self { Self { @@ -61,37 +70,20 @@ impl OllamaProvider { .trim_end_matches('/') .to_string(), client: Client::builder() - .timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow + .timeout(std::time::Duration::from_secs(300)) .connect_timeout(std::time::Duration::from_secs(10)) .build() .unwrap_or_else(|_| Client::new()), } } -} -#[async_trait] -impl Provider for OllamaProvider { - async fn chat_with_system( + /// Send a request to Ollama and get the parsed response + async fn send_request( &self, - system_prompt: Option<&str>, - message: &str, + messages: Vec, model: &str, temperature: f64, - ) -> anyhow::Result { - let mut messages = Vec::new(); - - if let Some(sys) = system_prompt { - messages.push(Message { - role: "system".to_string(), - content: sys.to_string(), - }); - } - - messages.push(Message { - role: "user".to_string(), - content: message.to_string(), - }); - + ) -> anyhow::Result { let request = ChatRequest { model: model.to_string(), messages, @@ -108,6 +100,7 @@ impl Provider for OllamaProvider { request.messages.len(), temperature ); + if tracing::enabled!(tracing::Level::TRACE) { if let Ok(req_json) = serde_json::to_string(&request) { tracing::trace!("Ollama request body: {}", req_json); @@ -118,11 +111,9 @@ impl Provider for OllamaProvider { let status = response.status(); tracing::debug!("Ollama response status: {}", status); - // Read raw body first to enable debugging if deserialization fails let body = response.bytes().await?; - let body_len = body.len(); + tracing::debug!("Ollama response body length: {} bytes", body.len()); - tracing::debug!("Ollama response body length: {} bytes", body_len); if tracing::enabled!(tracing::Level::TRACE) { let raw = String::from_utf8_lossy(&body); tracing::trace!( @@ -153,37 +144,140 @@ impl Provider for OllamaProvider { } }; - let content = chat_response.message.content; - tracing::debug!( - "Ollama response parsed: content_length={} content_preview='{}'", - content.len(), - if content.len() > 100 { - format!("{}...", &content[..100]) - } else { - content.clone() - } - ); + Ok(chat_response) + } - if content.is_empty() && chat_response.message.tool_calls.is_empty() { - let raw = String::from_utf8_lossy(&body); - tracing::warn!("Ollama returned empty content with no tool calls. Raw response: {}", raw); + /// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs + /// + /// Handles quirky model behavior where tool calls are wrapped: + /// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}` + /// - `{"name": "tool.shell", "arguments": {...}}` + fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String { + let formatted_calls: Vec = tool_calls + .iter() + .map(|tc| { + let (tool_name, tool_args) = self.extract_tool_name_and_args(tc); + + // Arguments must be a JSON string for parse_tool_calls compatibility + let args_str = serde_json::to_string(&tool_args) + .unwrap_or_else(|_| "{}".to_string()); + + serde_json::json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tool_name, + "arguments": args_str + } + }) + }) + .collect(); + + serde_json::json!({ + "content": "", + "tool_calls": formatted_calls + }) + .to_string() + } + + /// Extract the actual tool name and arguments from potentially nested structures + fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) { + let name = &tc.function.name; + let args = &tc.function.arguments; + + // Pattern 1: Nested tool_call wrapper (various malformed versions) + // {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}} + // {"name": "tool_call>") + || name.starts_with("tool_call<") + { + if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) { + let nested_args = args.get("arguments").cloned().unwrap_or(serde_json::json!({})); + tracing::debug!( + "Unwrapped nested tool call: {} -> {} with args {:?}", + name, + nested_name, + nested_args + ); + return (nested_name.to_string(), nested_args); + } + } + + // Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.) + if let Some(stripped) = name.strip_prefix("tool.") { + return (stripped.to_string(), args.clone()); + } + + // Pattern 3: Normal tool call + (name.clone(), args.clone()) + } +} + +#[async_trait] +impl Provider for OllamaProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let mut messages = Vec::new(); + + if let Some(sys) = system_prompt { + messages.push(Message { + role: "system".to_string(), + content: sys.to_string(), + }); + } + + messages.push(Message { + role: "user".to_string(), + content: message.to_string(), + }); + + let response = self.send_request(messages, model, temperature).await?; + + // If model returned tool calls, format them for loop_.rs's parse_tool_calls + if !response.message.tool_calls.is_empty() { + tracing::debug!( + "Ollama returned {} tool call(s), formatting for loop parser", + response.message.tool_calls.len() + ); + return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); + } + + // Plain text response + let content = response.message.content; + + // Handle edge case: model returned only "thinking" with no content or tool calls + if content.is_empty() { + if let Some(thinking) = &response.message.thinking { + tracing::warn!( + "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.", + if thinking.len() > 100 { &thinking[..100] } else { thinking } + ); + return Ok(format!( + "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", + if thinking.len() > 200 { &thinking[..200] } else { thinking } + )); + } + tracing::warn!("Ollama returned empty content with no tool calls"); } Ok(content) } - fn supports_native_tools(&self) -> bool { - true - } - - async fn chat( + async fn chat_with_history( &self, - request: crate::providers::ChatRequest<'_>, + messages: &[crate::providers::ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { - let messages: Vec = request - .messages + ) -> anyhow::Result { + let api_messages: Vec = messages .iter() .map(|m| Message { role: m.role.clone(), @@ -191,102 +285,50 @@ impl Provider for OllamaProvider { }) .collect(); - let api_request = ChatRequest { - model: model.to_string(), - messages, - stream: false, - options: Options { temperature }, - }; + let response = self.send_request(api_messages, model, temperature).await?; - let url = format!("{}/api/chat", self.base_url); - - tracing::debug!( - "Ollama chat request: url={} model={} message_count={} temperature={}", - url, - model, - api_request.messages.len(), - temperature - ); - if tracing::enabled!(tracing::Level::TRACE) { - if let Ok(req_json) = serde_json::to_string(&api_request) { - tracing::trace!("Ollama chat request body: {}", req_json); - } - } - - let response = self.client.post(&url).json(&api_request).send().await?; - let status = response.status(); - tracing::debug!("Ollama chat response status: {}", status); - - let body = response.bytes().await?; - tracing::debug!("Ollama chat response body length: {} bytes", body.len()); - - if tracing::enabled!(tracing::Level::TRACE) { - let raw = String::from_utf8_lossy(&body); - tracing::trace!( - "Ollama chat raw response: {}", - if raw.len() > 2000 { &raw[..2000] } else { &raw } + // If model returned tool calls, format them for loop_.rs's parse_tool_calls + if !response.message.tool_calls.is_empty() { + tracing::debug!( + "Ollama returned {} tool call(s), formatting for loop parser", + response.message.tool_calls.len() ); + return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); } - if !status.is_success() { - let raw = String::from_utf8_lossy(&body); - tracing::error!("Ollama chat error response: status={} body={}", status, raw); - anyhow::bail!( - "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)", - status, - if raw.len() > 200 { &raw[..200] } else { &raw } - ); - } - - let chat_response: ApiChatResponse = match serde_json::from_slice(&body) { - Ok(r) => r, - Err(e) => { - let raw = String::from_utf8_lossy(&body); - tracing::error!( - "Ollama chat response deserialization failed: {e}. Raw body: {}", - if raw.len() > 500 { &raw[..500] } else { &raw } + // Plain text response + let content = response.message.content; + + // Handle edge case: model returned only "thinking" with no content or tool calls + // This is a model quirk - it stopped after reasoning without producing output + if content.is_empty() { + if let Some(thinking) = &response.message.thinking { + tracing::warn!( + "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.", + if thinking.len() > 100 { &thinking[..100] } else { thinking } ); - anyhow::bail!("Failed to parse Ollama response: {e}"); + // Return a message indicating the model's thought process but no action + return Ok(format!( + "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", + if thinking.len() > 200 { &thinking[..200] } else { thinking } + )); } - }; - - let content = chat_response.message.content; - let tool_calls: Vec = chat_response - .message - .tool_calls - .into_iter() - .enumerate() - .map(|(i, tc)| { - let args_str = match &tc.function.arguments { - serde_json::Value::String(s) => s.clone(), - other => other.to_string(), - }; - crate::providers::ToolCall { - id: tc.id.unwrap_or_else(|| format!("call_{}", i)), - name: tc.function.name, - arguments: args_str, - } - }) - .collect(); - - tracing::debug!( - "Ollama chat response parsed: content_length={} tool_calls_count={}", - content.len(), - tool_calls.len() - ); - - if content.is_empty() && tool_calls.is_empty() { - let raw = String::from_utf8_lossy(&body); - tracing::warn!("Ollama returned empty content with no tool calls. Raw response: {}", raw); + tracing::warn!("Ollama returned empty content with no tool calls"); } - Ok(crate::providers::ChatResponse { - text: if content.is_empty() { None } else { Some(content) }, - tool_calls, - }) + Ok(content) + } + + fn supports_native_tools(&self) -> bool { + // Return false since loop_.rs uses XML-style tool parsing via system prompt + // The model may return native tool_calls but we convert them to JSON format + // that parse_tool_calls() understands + false } } +// ─── Tests ──────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; @@ -315,46 +357,6 @@ mod tests { assert_eq!(p.base_url, ""); } - #[test] - fn request_serializes_with_system() { - let req = ChatRequest { - model: "llama3".to_string(), - messages: vec![ - Message { - role: "system".to_string(), - content: "You are ZeroClaw".to_string(), - }, - Message { - role: "user".to_string(), - content: "hello".to_string(), - }, - ], - stream: false, - options: Options { temperature: 0.7 }, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(json.contains("\"stream\":false")); - assert!(json.contains("llama3")); - assert!(json.contains("system")); - assert!(json.contains("\"temperature\":0.7")); - } - - #[test] - fn request_serializes_without_system() { - let req = ChatRequest { - model: "mistral".to_string(), - messages: vec![Message { - role: "user".to_string(), - content: "test".to_string(), - }], - stream: false, - options: Options { temperature: 0.0 }, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(!json.contains("\"role\":\"system\"")); - assert!(json.contains("mistral")); - } - #[test] fn response_deserializes() { let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; @@ -371,7 +373,6 @@ mod tests { #[test] fn response_with_missing_content_defaults_to_empty() { - // Some models/versions may omit content field entirely let json = r#"{"message":{"role":"assistant"}}"#; let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.message.content.is_empty()); @@ -379,7 +380,6 @@ mod tests { #[test] fn response_with_thinking_field_extracts_content() { - // Models with thinking capability return additional fields let json = r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#; let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.message.content, "hello"); @@ -387,28 +387,82 @@ mod tests { #[test] fn response_with_tool_calls_parses_correctly() { - // Models may return tool_calls with empty content - let json = r#"{"message":{"role":"assistant","content":"","thinking":"some thinking","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"cmd":["ls","-la"]}}}]}}"#; + let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#; let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.message.content.is_empty()); assert_eq!(resp.message.tool_calls.len(), 1); assert_eq!(resp.message.tool_calls[0].function.name, "shell"); - assert_eq!(resp.message.tool_calls[0].id, Some("call_123".to_string())); } #[test] - fn response_with_tool_calls_no_id() { - // Some models may not include an id field - let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"test_tool","arguments":{}}}]}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert_eq!(resp.message.tool_calls.len(), 1); - assert!(resp.message.tool_calls[0].id.is_none()); + fn extract_tool_name_handles_nested_tool_call() { + let provider = OllamaProvider::new(None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "tool_call".into(), + arguments: serde_json::json!({ + "name": "shell", + "arguments": {"command": "date"} + }), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "shell"); + assert_eq!(args.get("command").unwrap(), "date"); } #[test] - fn response_with_multiline() { - let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.message.content.contains("line1")); + fn extract_tool_name_handles_prefixed_name() { + let provider = OllamaProvider::new(None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "tool.shell".into(), + arguments: serde_json::json!({"command": "ls"}), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "shell"); + assert_eq!(args.get("command").unwrap(), "ls"); + } + + #[test] + fn extract_tool_name_handles_normal_call() { + let provider = OllamaProvider::new(None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "file_read".into(), + arguments: serde_json::json!({"path": "/tmp/test"}), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "file_read"); + assert_eq!(args.get("path").unwrap(), "/tmp/test"); + } + + #[test] + fn format_tool_calls_produces_valid_json() { + let provider = OllamaProvider::new(None); + let tool_calls = vec![OllamaToolCall { + id: Some("call_abc".into()), + function: OllamaFunction { + name: "shell".into(), + arguments: serde_json::json!({"command": "date"}), + }, + }]; + + let formatted = provider.format_tool_calls_for_loop(&tool_calls); + let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap(); + + assert!(parsed.get("tool_calls").is_some()); + let calls = parsed.get("tool_calls").unwrap().as_array().unwrap(); + assert_eq!(calls.len(), 1); + + let func = calls[0].get("function").unwrap(); + assert_eq!(func.get("name").unwrap(), "shell"); + // arguments should be a string (JSON-encoded) + assert!(func.get("arguments").unwrap().is_string()); } }