From 3b4a4de45769c60336b4cda294671594a8e711ac Mon Sep 17 00:00:00 2001 From: chumyin Date: Mon, 16 Feb 2026 13:04:10 +0800 Subject: [PATCH] refactor(provider): unify Provider responses with ChatResponse - Switch Provider trait methods to return structured ChatResponse - Map OpenAI-compatible tool_calls into shared ToolCall type - Update reliable/router wrappers and provider tests for new interface - Make agent loop prefer structured tool calls with text fallback parsing - Adapt gateway replies to structured responses with safe tool-call fallback --- src/agent/loop_.rs | 95 +++++++++++++++++++++++++++---- src/gateway/mod.rs | 35 ++++++++++-- src/providers/anthropic.rs | 16 +++--- src/providers/compatible.rs | 110 ++++++++++++++++++++++-------------- src/providers/gemini.rs | 5 +- src/providers/mod.rs | 2 +- src/providers/ollama.rs | 18 +++--- src/providers/openai.rs | 20 +++---- src/providers/openrouter.rs | 10 ++-- src/providers/reliable.rs | 26 ++++----- src/providers/router.rs | 22 ++++---- src/providers/traits.rs | 19 ++++++- 12 files changed, 260 insertions(+), 118 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index a1aea97..45b37d2 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,7 +1,7 @@ use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; use crate::observability::{self, Observer, ObserverEvent}; -use crate::providers::{self, ChatMessage, Provider}; +use crate::providers::{self, ChatMessage, Provider, ToolCall}; use crate::runtime; use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; @@ -331,15 +331,71 @@ fn parse_tool_calls(response: &str) -> (String, Vec) { (text_parts.join("\n"), calls) } +fn parse_structured_tool_calls(tool_calls: &[ToolCall]) -> Vec { + tool_calls + .iter() + .map(|call| ParsedToolCall { + name: call.name.clone(), + arguments: serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())), + }) + .collect() +} + +fn build_assistant_history_with_tool_calls(text: &str, tool_calls: &[ToolCall]) -> String { + let mut parts = Vec::new(); + + if !text.trim().is_empty() { + parts.push(text.trim().to_string()); + } + + for call in tool_calls { + let arguments = serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::String(call.arguments.clone())); + let payload = serde_json::json!({ + "id": call.id, + "name": call.name, + "arguments": arguments, + }); + parts.push(format!("\n{payload}\n")); + } + + parts.join("\n") +} + #[derive(Debug)] struct ParsedToolCall { name: String, arguments: serde_json::Value, } +/// Execute a single turn for channel runtime paths. +/// +/// Channels currently do not thread an explicit provider label into this call, +/// so we route through the full loop with a stable placeholder provider name. +pub(crate) async fn agent_turn( + provider: &dyn Provider, + history: &mut Vec, + tools_registry: &[Box], + observer: &dyn Observer, + model: &str, + temperature: f64, +) -> Result { + run_tool_call_loop( + provider, + history, + tools_registry, + observer, + "channel-runtime", + model, + temperature, + ) + .await +} + /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. -pub(crate) async fn agent_turn( +pub(crate) async fn run_tool_call_loop( provider: &dyn Provider, history: &mut Vec, tools_registry: &[Box], @@ -382,17 +438,36 @@ pub(crate) async fn agent_turn( } }; - let (text, tool_calls) = parse_tool_calls(&response); + let response_text = response.text.unwrap_or_default(); + let mut assistant_history_content = response_text.clone(); + let mut parsed_text = response_text.clone(); + let mut tool_calls = parse_structured_tool_calls(&response.tool_calls); + + if !response.tool_calls.is_empty() { + assistant_history_content = + build_assistant_history_with_tool_calls(&response_text, &response.tool_calls); + } + + if tool_calls.is_empty() { + let (fallback_text, fallback_calls) = parse_tool_calls(&response_text); + parsed_text = fallback_text; + tool_calls = fallback_calls; + } if tool_calls.is_empty() { // No tool calls — this is the final response - history.push(ChatMessage::assistant(&response)); - return Ok(if text.is_empty() { response } else { text }); + let final_text = if parsed_text.is_empty() { + response_text + } else { + parsed_text + }; + history.push(ChatMessage::assistant(&final_text)); + return Ok(final_text); } // Print any text the LLM produced alongside tool calls - if !text.is_empty() { - print!("{text}"); + if !parsed_text.is_empty() { + print!("{parsed_text}"); let _ = std::io::stdout().flush(); } @@ -438,7 +513,7 @@ pub(crate) async fn agent_turn( } // Add assistant message with tool calls + tool results to history - history.push(ChatMessage::assistant(&response)); + history.push(ChatMessage::assistant(&assistant_history_content)); history.push(ChatMessage::user(format!("[Tool results]\n{tool_results}"))); } @@ -639,7 +714,7 @@ pub async fn run( ChatMessage::user(&enriched), ]; - let response = agent_turn( + let response = run_tool_call_loop( provider.as_ref(), &mut history, &tools_registry, @@ -694,7 +769,7 @@ pub async fn run( history.push(ChatMessage::user(&enriched)); - let response = match agent_turn( + let response = match run_tool_call_loop( provider.as_ref(), &mut history, &tools_registry, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 11de562..2282e66 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,7 +10,7 @@ use crate::channels::{Channel, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; -use crate::providers::{self, Provider}; +use crate::providers::{self, ChatResponse, Provider}; use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; use crate::util::truncate_with_ellipsis; use anyhow::Result; @@ -45,6 +45,29 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String format!("whatsapp_{}_{}", msg.sender, msg.id) } +fn gateway_reply_from_response(response: ChatResponse) -> String { + let has_tool_calls = response.has_tool_calls(); + let tool_call_count = response.tool_calls.len(); + let mut reply = response.text.unwrap_or_default(); + + if has_tool_calls { + tracing::warn!( + tool_call_count, + "Provider requested tool calls in gateway mode; tool calls are not executed here" + ); + if reply.trim().is_empty() { + reply = "I need to use tools to answer that, but tool execution is not enabled for gateway requests yet." + .to_string(); + } + } + + if reply.trim().is_empty() { + reply = "Model returned an empty response.".to_string(); + } + + reply +} + #[derive(Debug)] struct SlidingWindowRateLimiter { limit_per_window: u32, @@ -497,7 +520,8 @@ async fn handle_webhook( .await { Ok(response) => { - let body = serde_json::json!({"response": response, "model": state.model}); + let reply = gateway_reply_from_response(response); + let body = serde_json::json!({"response": reply, "model": state.model}); (StatusCode::OK, Json(body)) } Err(e) => { @@ -651,8 +675,9 @@ async fn handle_whatsapp_message( .await { Ok(response) => { + let reply = gateway_reply_from_response(response); // Send reply via WhatsApp - if let Err(e) = wa.send(&response, &msg.sender).await { + if let Err(e) = wa.send(&reply, &msg.sender).await { tracing::error!("Failed to send WhatsApp reply: {e}"); } } @@ -822,9 +847,9 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); - Ok("ok".into()) + Ok(ChatResponse::with_text("ok")) } } diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 3202a01..c3c7870 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -26,7 +26,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { content: Vec, } @@ -72,7 +72,7 @@ impl Provider for AnthropicProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." @@ -109,13 +109,13 @@ impl Provider for AnthropicProvider { return Err(super::api_error("Anthropic", response).await); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; chat_response .content .into_iter() .next() - .map(|c| c.text) + .map(|c| ProviderChatResponse::with_text(c.text)) .ok_or_else(|| anyhow::anyhow!("No response from Anthropic")) } } @@ -241,7 +241,7 @@ mod tests { #[test] fn chat_response_deserializes() { let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 1); assert_eq!(resp.content[0].text, "Hello there!"); } @@ -249,7 +249,7 @@ mod tests { #[test] fn chat_response_empty_content() { let json = r#"{"content":[]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.content.is_empty()); } @@ -257,7 +257,7 @@ mod tests { fn chat_response_multiple_blocks() { let json = r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 2); assert_eq!(resp.content[0].text, "First"); assert_eq!(resp.content[1].text, "Second"); diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 2312741..de7bff0 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -2,7 +2,7 @@ //! Most LLM APIs follow the same `/v1/chat/completions` format. //! This module provides a single implementation that works for all of them. -use crate::providers::traits::{ChatMessage, Provider}; +use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -135,11 +135,12 @@ struct ResponseMessage { #[serde(default)] content: Option, #[serde(default)] - tool_calls: Option>, + tool_calls: Option>, } #[derive(Debug, Deserialize, Serialize)] -struct ToolCall { +struct ApiToolCall { + id: Option, #[serde(rename = "type")] kind: Option, function: Option, @@ -225,6 +226,44 @@ fn extract_responses_text(response: ResponsesResponse) -> Option { None } +fn map_response_message(message: ResponseMessage) -> ChatResponse { + let text = first_nonempty(message.content.as_deref()); + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .enumerate() + .filter_map(|(index, call)| map_api_tool_call(call, index)) + .collect(); + + ChatResponse { text, tool_calls } +} + +fn map_api_tool_call(call: ApiToolCall, index: usize) -> Option { + if call.kind.as_deref().is_some_and(|kind| kind != "function") { + return None; + } + + let function = call.function?; + let name = function + .name + .and_then(|value| first_nonempty(Some(value.as_str())))?; + let arguments = function + .arguments + .and_then(|value| first_nonempty(Some(value.as_str()))) + .unwrap_or_else(|| "{}".to_string()); + let id = call + .id + .and_then(|value| first_nonempty(Some(value.as_str()))) + .unwrap_or_else(|| format!("call_{}", index + 1)); + + Some(ToolCall { + id, + name, + arguments, + }) +} + impl OpenAiCompatibleProvider { fn apply_auth_header( &self, @@ -244,7 +283,7 @@ impl OpenAiCompatibleProvider { system_prompt: Option<&str>, message: &str, model: &str, - ) -> anyhow::Result { + ) -> anyhow::Result { let request = ResponsesRequest { model: model.to_string(), input: vec![ResponsesInput { @@ -270,6 +309,7 @@ impl OpenAiCompatibleProvider { let responses: ResponsesResponse = response.json().await?; extract_responses_text(responses) + .map(ChatResponse::with_text) .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name)) } } @@ -282,7 +322,7 @@ impl Provider for OpenAiCompatibleProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", @@ -339,27 +379,13 @@ impl Provider for OpenAiCompatibleProvider { let chat_response: ApiChatResponse = response.json().await?; - chat_response + let choice = chat_response .choices .into_iter() .next() - .map(|c| { - // If tool_calls are present, serialize the full message as JSON - // so parse_tool_calls can handle the OpenAI-style format - if c.message.tool_calls.is_some() - && c.message - .tool_calls - .as_ref() - .map_or(false, |t| !t.is_empty()) - { - serde_json::to_string(&c.message) - .unwrap_or_else(|_| c.message.content.unwrap_or_default()) - } else { - // No tool calls, return content as-is - c.message.content.unwrap_or_default() - } - }) - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; + + Ok(map_response_message(choice.message)) } async fn chat_with_history( @@ -367,7 +393,7 @@ impl Provider for OpenAiCompatibleProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", @@ -426,27 +452,13 @@ impl Provider for OpenAiCompatibleProvider { let chat_response: ApiChatResponse = response.json().await?; - chat_response + let choice = chat_response .choices .into_iter() .next() - .map(|c| { - // If tool_calls are present, serialize the full message as JSON - // so parse_tool_calls can handle the OpenAI-style format - if c.message.tool_calls.is_some() - && c.message - .tool_calls - .as_ref() - .map_or(false, |t| !t.is_empty()) - { - serde_json::to_string(&c.message) - .unwrap_or_else(|_| c.message.content.unwrap_or_default()) - } else { - // No tool calls, return content as-is - c.message.content.unwrap_or_default() - } - }) - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; + + Ok(map_response_message(choice.message)) } } @@ -530,6 +542,20 @@ mod tests { assert!(resp.choices.is_empty()); } + #[test] + fn response_with_tool_calls_maps_structured_data() { + let json = r#"{"choices":[{"message":{"content":"Running checks","tool_calls":[{"id":"call_1","type":"function","function":{"name":"shell","arguments":"{\"command\":\"pwd\"}"}}]}}]}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let choice = resp.choices.into_iter().next().unwrap(); + + let mapped = map_response_message(choice.message); + assert_eq!(mapped.text.as_deref(), Some("Running checks")); + assert_eq!(mapped.tool_calls.len(), 1); + assert_eq!(mapped.tool_calls[0].id, "call_1"); + assert_eq!(mapped.tool_calls[0].name, "shell"); + assert_eq!(mapped.tool_calls[0].arguments, r#"{"command":"pwd"}"#); + } + #[test] fn x_api_key_auth_style() { let p = OpenAiCompatibleProvider::new( diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index a988224..189daf0 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -3,7 +3,7 @@ //! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication) //! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`) -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse, Provider}; use async_trait::async_trait; use directories::UserDirs; use reqwest::Client; @@ -260,7 +260,7 @@ impl Provider for GeminiProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let auth = self.auth.as_ref().ok_or_else(|| { anyhow::anyhow!( "Gemini API key not found. Options:\n\ @@ -319,6 +319,7 @@ impl Provider for GeminiProvider { .and_then(|c| c.into_iter().next()) .and_then(|c| c.content.parts.into_iter().next()) .and_then(|p| p.text) + .map(ChatResponse::with_text) .ok_or_else(|| anyhow::anyhow!("No response from Gemini")) } } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 4164fff..5911904 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -8,7 +8,7 @@ pub mod reliable; pub mod router; pub mod traits; -pub use traits::{ChatMessage, Provider}; +pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; use reliable::ReliableProvider; diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index e3e08f2..481d0bf 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -28,7 +28,7 @@ struct Options { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { message: ResponseMessage, } @@ -61,7 +61,7 @@ impl Provider for OllamaProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut messages = Vec::new(); if let Some(sys) = system_prompt { @@ -92,8 +92,10 @@ impl Provider for OllamaProvider { anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)"); } - let chat_response: ChatResponse = response.json().await?; - Ok(chat_response.message.content) + let chat_response: ApiChatResponse = response.json().await?; + Ok(ProviderChatResponse::with_text( + chat_response.message.content, + )) } } @@ -168,21 +170,21 @@ mod tests { #[test] fn response_deserializes() { let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.message.content, "Hello from Ollama!"); } #[test] fn response_with_empty_content() { let json = r#"{"message":{"role":"assistant","content":""}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.message.content.is_empty()); } #[test] fn response_with_multiline() { let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.message.content.contains("line1")); } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index f202073..6b8bbe5 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -22,7 +22,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { choices: Vec, } @@ -57,7 +57,7 @@ impl Provider for OpenAiProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") })?; @@ -94,13 +94,13 @@ impl Provider for OpenAiProvider { return Err(super::api_error("OpenAI", response).await); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; chat_response .choices .into_iter() .next() - .map(|c| c.message.content) + .map(|c| ChatResponse::with_text(c.message.content)) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")) } } @@ -184,7 +184,7 @@ mod tests { #[test] fn response_deserializes_single_choice() { let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices.len(), 1); assert_eq!(resp.choices[0].message.content, "Hi!"); } @@ -192,14 +192,14 @@ mod tests { #[test] fn response_deserializes_empty_choices() { let json = r#"{"choices":[]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.choices.is_empty()); } #[test] fn response_deserializes_multiple_choices() { let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices.len(), 2); assert_eq!(resp.choices[0].message.content, "A"); } @@ -207,7 +207,7 @@ mod tests { #[test] fn response_with_unicode() { let json = r#"{"choices":[{"message":{"content":"こんにちは 🦀"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices[0].message.content, "こんにちは 🦀"); } @@ -215,7 +215,7 @@ mod tests { fn response_with_long_content() { let long = "x".repeat(100_000); let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#); - let resp: ChatResponse = serde_json::from_str(&json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(&json).unwrap(); assert_eq!(resp.choices[0].message.content.len(), 100_000); } } diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index 6cb90e3..287dd88 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::{ChatMessage, Provider}; +use crate::providers::traits::{ChatMessage, ChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -71,7 +71,7 @@ impl Provider for OpenRouterProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; @@ -118,7 +118,7 @@ impl Provider for OpenRouterProvider { .choices .into_iter() .next() - .map(|c| c.message.content) + .map(|c| ChatResponse::with_text(c.message.content)) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } @@ -127,7 +127,7 @@ impl Provider for OpenRouterProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; @@ -168,7 +168,7 @@ impl Provider for OpenRouterProvider { .choices .into_iter() .next() - .map(|c| c.message.content) + .map(|c| ChatResponse::with_text(c.message.content)) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } } diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 366f013..12aaa62 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -1,4 +1,4 @@ -use super::traits::ChatMessage; +use super::traits::{ChatMessage, ChatResponse}; use super::Provider; use async_trait::async_trait; use std::time::Duration; @@ -66,7 +66,7 @@ impl Provider for ReliableProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut failures = Vec::new(); for (provider_name, provider) in &self.providers { @@ -128,7 +128,7 @@ impl Provider for ReliableProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut failures = Vec::new(); for (provider_name, provider) in &self.providers { @@ -207,12 +207,12 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; if attempt <= self.fail_until_attempt { anyhow::bail!(self.error); } - Ok(self.response.to_string()) + Ok(ChatResponse::with_text(self.response)) } async fn chat_with_history( @@ -220,12 +220,12 @@ mod tests { _messages: &[ChatMessage], _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; if attempt <= self.fail_until_attempt { anyhow::bail!(self.error); } - Ok(self.response.to_string()) + Ok(ChatResponse::with_text(self.response)) } } @@ -247,7 +247,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "ok"); + assert_eq!(result.text_or_empty(), "ok"); assert_eq!(calls.load(Ordering::SeqCst), 1); } @@ -269,7 +269,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "recovered"); + assert_eq!(result.text_or_empty(), "recovered"); assert_eq!(calls.load(Ordering::SeqCst), 2); } @@ -304,7 +304,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "from fallback"); + assert_eq!(result.text_or_empty(), "from fallback"); assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } @@ -401,7 +401,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "from fallback"); + assert_eq!(result.text_or_empty(), "from fallback"); // Primary should have been called only once (no retries) assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); @@ -429,7 +429,7 @@ mod tests { .chat_with_history(&messages, "test", 0.0) .await .unwrap(); - assert_eq!(result, "history ok"); + assert_eq!(result.text_or_empty(), "history ok"); assert_eq!(calls.load(Ordering::SeqCst), 2); } @@ -468,7 +468,7 @@ mod tests { .chat_with_history(&messages, "test", 0.0) .await .unwrap(); - assert_eq!(result, "fallback ok"); + assert_eq!(result.text_or_empty(), "fallback ok"); assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } diff --git a/src/providers/router.rs b/src/providers/router.rs index 4ee36f3..eb3101f 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -1,4 +1,4 @@ -use super::traits::ChatMessage; +use super::traits::{ChatMessage, ChatResponse}; use super::Provider; use async_trait::async_trait; use std::collections::HashMap; @@ -98,7 +98,7 @@ impl Provider for RouterProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let (provider_idx, resolved_model) = self.resolve(model); let (provider_name, provider) = &self.providers[provider_idx]; @@ -118,7 +118,7 @@ impl Provider for RouterProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let (provider_idx, resolved_model) = self.resolve(model); let (_, provider) = &self.providers[provider_idx]; provider @@ -175,10 +175,10 @@ mod tests { _message: &str, model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); *self.last_model.lock().unwrap() = model.to_string(); - Ok(self.response.to_string()) + Ok(ChatResponse::with_text(self.response)) } } @@ -229,7 +229,7 @@ mod tests { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.as_ref() .chat_with_system(system_prompt, message, model, temperature) .await @@ -247,7 +247,7 @@ mod tests { ); let result = router.chat("hello", "hint:reasoning", 0.5).await.unwrap(); - assert_eq!(result, "smart-response"); + assert_eq!(result.text_or_empty(), "smart-response"); assert_eq!(mocks[1].call_count(), 1); assert_eq!(mocks[1].last_model(), "claude-opus"); assert_eq!(mocks[0].call_count(), 0); @@ -261,7 +261,7 @@ mod tests { ); let result = router.chat("hello", "hint:fast", 0.5).await.unwrap(); - assert_eq!(result, "fast-response"); + assert_eq!(result.text_or_empty(), "fast-response"); assert_eq!(mocks[0].call_count(), 1); assert_eq!(mocks[0].last_model(), "llama-3-70b"); } @@ -274,7 +274,7 @@ mod tests { ); let result = router.chat("hello", "hint:nonexistent", 0.5).await.unwrap(); - assert_eq!(result, "default-response"); + assert_eq!(result.text_or_empty(), "default-response"); assert_eq!(mocks[0].call_count(), 1); // Falls back to default with the hint as model name assert_eq!(mocks[0].last_model(), "hint:nonexistent"); @@ -294,7 +294,7 @@ mod tests { .chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5) .await .unwrap(); - assert_eq!(result, "primary-response"); + assert_eq!(result.text_or_empty(), "primary-response"); assert_eq!(mocks[0].call_count(), 1); assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514"); } @@ -355,7 +355,7 @@ mod tests { .chat_with_system(Some("system"), "hello", "model", 0.5) .await .unwrap(); - assert_eq!(result, "response"); + assert_eq!(result.text_or_empty(), "response"); assert_eq!(mock.call_count(), 1); } } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 84746ea..d1f8dd1 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -49,6 +49,14 @@ pub struct ChatResponse { } impl ChatResponse { + /// Convenience: construct a plain text response with no tool calls. + pub fn with_text(text: impl Into) -> Self { + Self { + text: Some(text.into()), + tool_calls: vec![], + } + } + /// True when the LLM wants to invoke at least one tool. pub fn has_tool_calls(&self) -> bool { !self.tool_calls.is_empty() @@ -84,7 +92,12 @@ pub enum ConversationMessage { #[async_trait] pub trait Provider: Send + Sync { - async fn chat(&self, message: &str, model: &str, temperature: f64) -> anyhow::Result { + async fn chat( + &self, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { self.chat_with_system(None, message, model, temperature) .await } @@ -95,7 +108,7 @@ pub trait Provider: Send + Sync { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result; + ) -> anyhow::Result; /// Multi-turn conversation. Default implementation extracts the last user /// message and delegates to `chat_with_system`. @@ -104,7 +117,7 @@ pub trait Provider: Send + Sync { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let system = messages .iter() .find(|m| m.role == "system")