From 3b4a4de45769c60336b4cda294671594a8e711ac Mon Sep 17 00:00:00 2001 From: chumyin Date: Mon, 16 Feb 2026 13:04:10 +0800 Subject: [PATCH 1/4] refactor(provider): unify Provider responses with ChatResponse - Switch Provider trait methods to return structured ChatResponse - Map OpenAI-compatible tool_calls into shared ToolCall type - Update reliable/router wrappers and provider tests for new interface - Make agent loop prefer structured tool calls with text fallback parsing - Adapt gateway replies to structured responses with safe tool-call fallback --- src/agent/loop_.rs | 95 +++++++++++++++++++++++++++---- src/gateway/mod.rs | 35 ++++++++++-- src/providers/anthropic.rs | 16 +++--- src/providers/compatible.rs | 110 ++++++++++++++++++++++-------------- src/providers/gemini.rs | 5 +- src/providers/mod.rs | 2 +- src/providers/ollama.rs | 18 +++--- src/providers/openai.rs | 20 +++---- src/providers/openrouter.rs | 10 ++-- src/providers/reliable.rs | 26 ++++----- src/providers/router.rs | 22 ++++---- src/providers/traits.rs | 19 ++++++- 12 files changed, 260 insertions(+), 118 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index a1aea97..45b37d2 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,7 +1,7 @@ use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; use crate::observability::{self, Observer, ObserverEvent}; -use crate::providers::{self, ChatMessage, Provider}; +use crate::providers::{self, ChatMessage, Provider, ToolCall}; use crate::runtime; use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; @@ -331,15 +331,71 @@ fn parse_tool_calls(response: &str) -> (String, Vec) { (text_parts.join("\n"), calls) } +fn parse_structured_tool_calls(tool_calls: &[ToolCall]) -> Vec { + tool_calls + .iter() + .map(|call| ParsedToolCall { + name: call.name.clone(), + arguments: serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())), + }) + .collect() +} + +fn build_assistant_history_with_tool_calls(text: &str, tool_calls: &[ToolCall]) -> String { + let mut parts = Vec::new(); + + if !text.trim().is_empty() { + parts.push(text.trim().to_string()); + } + + for call in tool_calls { + let arguments = serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::String(call.arguments.clone())); + let payload = serde_json::json!({ + "id": call.id, + "name": call.name, + "arguments": arguments, + }); + parts.push(format!("\n{payload}\n")); + } + + parts.join("\n") +} + #[derive(Debug)] struct ParsedToolCall { name: String, arguments: serde_json::Value, } +/// Execute a single turn for channel runtime paths. +/// +/// Channels currently do not thread an explicit provider label into this call, +/// so we route through the full loop with a stable placeholder provider name. +pub(crate) async fn agent_turn( + provider: &dyn Provider, + history: &mut Vec, + tools_registry: &[Box], + observer: &dyn Observer, + model: &str, + temperature: f64, +) -> Result { + run_tool_call_loop( + provider, + history, + tools_registry, + observer, + "channel-runtime", + model, + temperature, + ) + .await +} + /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. -pub(crate) async fn agent_turn( +pub(crate) async fn run_tool_call_loop( provider: &dyn Provider, history: &mut Vec, tools_registry: &[Box], @@ -382,17 +438,36 @@ pub(crate) async fn agent_turn( } }; - let (text, tool_calls) = parse_tool_calls(&response); + let response_text = response.text.unwrap_or_default(); + let mut assistant_history_content = response_text.clone(); + let mut parsed_text = response_text.clone(); + let mut tool_calls = parse_structured_tool_calls(&response.tool_calls); + + if !response.tool_calls.is_empty() { + assistant_history_content = + build_assistant_history_with_tool_calls(&response_text, &response.tool_calls); + } + + if tool_calls.is_empty() { + let (fallback_text, fallback_calls) = parse_tool_calls(&response_text); + parsed_text = fallback_text; + tool_calls = fallback_calls; + } if tool_calls.is_empty() { // No tool calls — this is the final response - history.push(ChatMessage::assistant(&response)); - return Ok(if text.is_empty() { response } else { text }); + let final_text = if parsed_text.is_empty() { + response_text + } else { + parsed_text + }; + history.push(ChatMessage::assistant(&final_text)); + return Ok(final_text); } // Print any text the LLM produced alongside tool calls - if !text.is_empty() { - print!("{text}"); + if !parsed_text.is_empty() { + print!("{parsed_text}"); let _ = std::io::stdout().flush(); } @@ -438,7 +513,7 @@ pub(crate) async fn agent_turn( } // Add assistant message with tool calls + tool results to history - history.push(ChatMessage::assistant(&response)); + history.push(ChatMessage::assistant(&assistant_history_content)); history.push(ChatMessage::user(format!("[Tool results]\n{tool_results}"))); } @@ -639,7 +714,7 @@ pub async fn run( ChatMessage::user(&enriched), ]; - let response = agent_turn( + let response = run_tool_call_loop( provider.as_ref(), &mut history, &tools_registry, @@ -694,7 +769,7 @@ pub async fn run( history.push(ChatMessage::user(&enriched)); - let response = match agent_turn( + let response = match run_tool_call_loop( provider.as_ref(), &mut history, &tools_registry, diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 11de562..2282e66 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,7 +10,7 @@ use crate::channels::{Channel, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; -use crate::providers::{self, Provider}; +use crate::providers::{self, ChatResponse, Provider}; use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; use crate::util::truncate_with_ellipsis; use anyhow::Result; @@ -45,6 +45,29 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String format!("whatsapp_{}_{}", msg.sender, msg.id) } +fn gateway_reply_from_response(response: ChatResponse) -> String { + let has_tool_calls = response.has_tool_calls(); + let tool_call_count = response.tool_calls.len(); + let mut reply = response.text.unwrap_or_default(); + + if has_tool_calls { + tracing::warn!( + tool_call_count, + "Provider requested tool calls in gateway mode; tool calls are not executed here" + ); + if reply.trim().is_empty() { + reply = "I need to use tools to answer that, but tool execution is not enabled for gateway requests yet." + .to_string(); + } + } + + if reply.trim().is_empty() { + reply = "Model returned an empty response.".to_string(); + } + + reply +} + #[derive(Debug)] struct SlidingWindowRateLimiter { limit_per_window: u32, @@ -497,7 +520,8 @@ async fn handle_webhook( .await { Ok(response) => { - let body = serde_json::json!({"response": response, "model": state.model}); + let reply = gateway_reply_from_response(response); + let body = serde_json::json!({"response": reply, "model": state.model}); (StatusCode::OK, Json(body)) } Err(e) => { @@ -651,8 +675,9 @@ async fn handle_whatsapp_message( .await { Ok(response) => { + let reply = gateway_reply_from_response(response); // Send reply via WhatsApp - if let Err(e) = wa.send(&response, &msg.sender).await { + if let Err(e) = wa.send(&reply, &msg.sender).await { tracing::error!("Failed to send WhatsApp reply: {e}"); } } @@ -822,9 +847,9 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); - Ok("ok".into()) + Ok(ChatResponse::with_text("ok")) } } diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 3202a01..c3c7870 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -26,7 +26,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { content: Vec, } @@ -72,7 +72,7 @@ impl Provider for AnthropicProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." @@ -109,13 +109,13 @@ impl Provider for AnthropicProvider { return Err(super::api_error("Anthropic", response).await); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; chat_response .content .into_iter() .next() - .map(|c| c.text) + .map(|c| ProviderChatResponse::with_text(c.text)) .ok_or_else(|| anyhow::anyhow!("No response from Anthropic")) } } @@ -241,7 +241,7 @@ mod tests { #[test] fn chat_response_deserializes() { let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 1); assert_eq!(resp.content[0].text, "Hello there!"); } @@ -249,7 +249,7 @@ mod tests { #[test] fn chat_response_empty_content() { let json = r#"{"content":[]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.content.is_empty()); } @@ -257,7 +257,7 @@ mod tests { fn chat_response_multiple_blocks() { let json = r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 2); assert_eq!(resp.content[0].text, "First"); assert_eq!(resp.content[1].text, "Second"); diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 2312741..de7bff0 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -2,7 +2,7 @@ //! Most LLM APIs follow the same `/v1/chat/completions` format. //! This module provides a single implementation that works for all of them. -use crate::providers::traits::{ChatMessage, Provider}; +use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -135,11 +135,12 @@ struct ResponseMessage { #[serde(default)] content: Option, #[serde(default)] - tool_calls: Option>, + tool_calls: Option>, } #[derive(Debug, Deserialize, Serialize)] -struct ToolCall { +struct ApiToolCall { + id: Option, #[serde(rename = "type")] kind: Option, function: Option, @@ -225,6 +226,44 @@ fn extract_responses_text(response: ResponsesResponse) -> Option { None } +fn map_response_message(message: ResponseMessage) -> ChatResponse { + let text = first_nonempty(message.content.as_deref()); + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .enumerate() + .filter_map(|(index, call)| map_api_tool_call(call, index)) + .collect(); + + ChatResponse { text, tool_calls } +} + +fn map_api_tool_call(call: ApiToolCall, index: usize) -> Option { + if call.kind.as_deref().is_some_and(|kind| kind != "function") { + return None; + } + + let function = call.function?; + let name = function + .name + .and_then(|value| first_nonempty(Some(value.as_str())))?; + let arguments = function + .arguments + .and_then(|value| first_nonempty(Some(value.as_str()))) + .unwrap_or_else(|| "{}".to_string()); + let id = call + .id + .and_then(|value| first_nonempty(Some(value.as_str()))) + .unwrap_or_else(|| format!("call_{}", index + 1)); + + Some(ToolCall { + id, + name, + arguments, + }) +} + impl OpenAiCompatibleProvider { fn apply_auth_header( &self, @@ -244,7 +283,7 @@ impl OpenAiCompatibleProvider { system_prompt: Option<&str>, message: &str, model: &str, - ) -> anyhow::Result { + ) -> anyhow::Result { let request = ResponsesRequest { model: model.to_string(), input: vec![ResponsesInput { @@ -270,6 +309,7 @@ impl OpenAiCompatibleProvider { let responses: ResponsesResponse = response.json().await?; extract_responses_text(responses) + .map(ChatResponse::with_text) .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name)) } } @@ -282,7 +322,7 @@ impl Provider for OpenAiCompatibleProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", @@ -339,27 +379,13 @@ impl Provider for OpenAiCompatibleProvider { let chat_response: ApiChatResponse = response.json().await?; - chat_response + let choice = chat_response .choices .into_iter() .next() - .map(|c| { - // If tool_calls are present, serialize the full message as JSON - // so parse_tool_calls can handle the OpenAI-style format - if c.message.tool_calls.is_some() - && c.message - .tool_calls - .as_ref() - .map_or(false, |t| !t.is_empty()) - { - serde_json::to_string(&c.message) - .unwrap_or_else(|_| c.message.content.unwrap_or_default()) - } else { - // No tool calls, return content as-is - c.message.content.unwrap_or_default() - } - }) - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; + + Ok(map_response_message(choice.message)) } async fn chat_with_history( @@ -367,7 +393,7 @@ impl Provider for OpenAiCompatibleProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", @@ -426,27 +452,13 @@ impl Provider for OpenAiCompatibleProvider { let chat_response: ApiChatResponse = response.json().await?; - chat_response + let choice = chat_response .choices .into_iter() .next() - .map(|c| { - // If tool_calls are present, serialize the full message as JSON - // so parse_tool_calls can handle the OpenAI-style format - if c.message.tool_calls.is_some() - && c.message - .tool_calls - .as_ref() - .map_or(false, |t| !t.is_empty()) - { - serde_json::to_string(&c.message) - .unwrap_or_else(|_| c.message.content.unwrap_or_default()) - } else { - // No tool calls, return content as-is - c.message.content.unwrap_or_default() - } - }) - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; + + Ok(map_response_message(choice.message)) } } @@ -530,6 +542,20 @@ mod tests { assert!(resp.choices.is_empty()); } + #[test] + fn response_with_tool_calls_maps_structured_data() { + let json = r#"{"choices":[{"message":{"content":"Running checks","tool_calls":[{"id":"call_1","type":"function","function":{"name":"shell","arguments":"{\"command\":\"pwd\"}"}}]}}]}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let choice = resp.choices.into_iter().next().unwrap(); + + let mapped = map_response_message(choice.message); + assert_eq!(mapped.text.as_deref(), Some("Running checks")); + assert_eq!(mapped.tool_calls.len(), 1); + assert_eq!(mapped.tool_calls[0].id, "call_1"); + assert_eq!(mapped.tool_calls[0].name, "shell"); + assert_eq!(mapped.tool_calls[0].arguments, r#"{"command":"pwd"}"#); + } + #[test] fn x_api_key_auth_style() { let p = OpenAiCompatibleProvider::new( diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index a988224..189daf0 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -3,7 +3,7 @@ //! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication) //! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`) -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse, Provider}; use async_trait::async_trait; use directories::UserDirs; use reqwest::Client; @@ -260,7 +260,7 @@ impl Provider for GeminiProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let auth = self.auth.as_ref().ok_or_else(|| { anyhow::anyhow!( "Gemini API key not found. Options:\n\ @@ -319,6 +319,7 @@ impl Provider for GeminiProvider { .and_then(|c| c.into_iter().next()) .and_then(|c| c.content.parts.into_iter().next()) .and_then(|p| p.text) + .map(ChatResponse::with_text) .ok_or_else(|| anyhow::anyhow!("No response from Gemini")) } } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 4164fff..5911904 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -8,7 +8,7 @@ pub mod reliable; pub mod router; pub mod traits; -pub use traits::{ChatMessage, Provider}; +pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; use reliable::ReliableProvider; diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index e3e08f2..481d0bf 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -28,7 +28,7 @@ struct Options { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { message: ResponseMessage, } @@ -61,7 +61,7 @@ impl Provider for OllamaProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut messages = Vec::new(); if let Some(sys) = system_prompt { @@ -92,8 +92,10 @@ impl Provider for OllamaProvider { anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)"); } - let chat_response: ChatResponse = response.json().await?; - Ok(chat_response.message.content) + let chat_response: ApiChatResponse = response.json().await?; + Ok(ProviderChatResponse::with_text( + chat_response.message.content, + )) } } @@ -168,21 +170,21 @@ mod tests { #[test] fn response_deserializes() { let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.message.content, "Hello from Ollama!"); } #[test] fn response_with_empty_content() { let json = r#"{"message":{"role":"assistant","content":""}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.message.content.is_empty()); } #[test] fn response_with_multiline() { let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.message.content.contains("line1")); } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index f202073..6b8bbe5 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -22,7 +22,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { choices: Vec, } @@ -57,7 +57,7 @@ impl Provider for OpenAiProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") })?; @@ -94,13 +94,13 @@ impl Provider for OpenAiProvider { return Err(super::api_error("OpenAI", response).await); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; chat_response .choices .into_iter() .next() - .map(|c| c.message.content) + .map(|c| ChatResponse::with_text(c.message.content)) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")) } } @@ -184,7 +184,7 @@ mod tests { #[test] fn response_deserializes_single_choice() { let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices.len(), 1); assert_eq!(resp.choices[0].message.content, "Hi!"); } @@ -192,14 +192,14 @@ mod tests { #[test] fn response_deserializes_empty_choices() { let json = r#"{"choices":[]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.choices.is_empty()); } #[test] fn response_deserializes_multiple_choices() { let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices.len(), 2); assert_eq!(resp.choices[0].message.content, "A"); } @@ -207,7 +207,7 @@ mod tests { #[test] fn response_with_unicode() { let json = r#"{"choices":[{"message":{"content":"こんにちは 🦀"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices[0].message.content, "こんにちは 🦀"); } @@ -215,7 +215,7 @@ mod tests { fn response_with_long_content() { let long = "x".repeat(100_000); let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#); - let resp: ChatResponse = serde_json::from_str(&json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(&json).unwrap(); assert_eq!(resp.choices[0].message.content.len(), 100_000); } } diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index 6cb90e3..287dd88 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::{ChatMessage, Provider}; +use crate::providers::traits::{ChatMessage, ChatResponse, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -71,7 +71,7 @@ impl Provider for OpenRouterProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; @@ -118,7 +118,7 @@ impl Provider for OpenRouterProvider { .choices .into_iter() .next() - .map(|c| c.message.content) + .map(|c| ChatResponse::with_text(c.message.content)) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } @@ -127,7 +127,7 @@ impl Provider for OpenRouterProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; @@ -168,7 +168,7 @@ impl Provider for OpenRouterProvider { .choices .into_iter() .next() - .map(|c| c.message.content) + .map(|c| ChatResponse::with_text(c.message.content)) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } } diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 366f013..12aaa62 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -1,4 +1,4 @@ -use super::traits::ChatMessage; +use super::traits::{ChatMessage, ChatResponse}; use super::Provider; use async_trait::async_trait; use std::time::Duration; @@ -66,7 +66,7 @@ impl Provider for ReliableProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut failures = Vec::new(); for (provider_name, provider) in &self.providers { @@ -128,7 +128,7 @@ impl Provider for ReliableProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut failures = Vec::new(); for (provider_name, provider) in &self.providers { @@ -207,12 +207,12 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; if attempt <= self.fail_until_attempt { anyhow::bail!(self.error); } - Ok(self.response.to_string()) + Ok(ChatResponse::with_text(self.response)) } async fn chat_with_history( @@ -220,12 +220,12 @@ mod tests { _messages: &[ChatMessage], _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; if attempt <= self.fail_until_attempt { anyhow::bail!(self.error); } - Ok(self.response.to_string()) + Ok(ChatResponse::with_text(self.response)) } } @@ -247,7 +247,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "ok"); + assert_eq!(result.text_or_empty(), "ok"); assert_eq!(calls.load(Ordering::SeqCst), 1); } @@ -269,7 +269,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "recovered"); + assert_eq!(result.text_or_empty(), "recovered"); assert_eq!(calls.load(Ordering::SeqCst), 2); } @@ -304,7 +304,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "from fallback"); + assert_eq!(result.text_or_empty(), "from fallback"); assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } @@ -401,7 +401,7 @@ mod tests { ); let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result, "from fallback"); + assert_eq!(result.text_or_empty(), "from fallback"); // Primary should have been called only once (no retries) assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); @@ -429,7 +429,7 @@ mod tests { .chat_with_history(&messages, "test", 0.0) .await .unwrap(); - assert_eq!(result, "history ok"); + assert_eq!(result.text_or_empty(), "history ok"); assert_eq!(calls.load(Ordering::SeqCst), 2); } @@ -468,7 +468,7 @@ mod tests { .chat_with_history(&messages, "test", 0.0) .await .unwrap(); - assert_eq!(result, "fallback ok"); + assert_eq!(result.text_or_empty(), "fallback ok"); assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } diff --git a/src/providers/router.rs b/src/providers/router.rs index 4ee36f3..eb3101f 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -1,4 +1,4 @@ -use super::traits::ChatMessage; +use super::traits::{ChatMessage, ChatResponse}; use super::Provider; use async_trait::async_trait; use std::collections::HashMap; @@ -98,7 +98,7 @@ impl Provider for RouterProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let (provider_idx, resolved_model) = self.resolve(model); let (provider_name, provider) = &self.providers[provider_idx]; @@ -118,7 +118,7 @@ impl Provider for RouterProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let (provider_idx, resolved_model) = self.resolve(model); let (_, provider) = &self.providers[provider_idx]; provider @@ -175,10 +175,10 @@ mod tests { _message: &str, model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); *self.last_model.lock().unwrap() = model.to_string(); - Ok(self.response.to_string()) + Ok(ChatResponse::with_text(self.response)) } } @@ -229,7 +229,7 @@ mod tests { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.as_ref() .chat_with_system(system_prompt, message, model, temperature) .await @@ -247,7 +247,7 @@ mod tests { ); let result = router.chat("hello", "hint:reasoning", 0.5).await.unwrap(); - assert_eq!(result, "smart-response"); + assert_eq!(result.text_or_empty(), "smart-response"); assert_eq!(mocks[1].call_count(), 1); assert_eq!(mocks[1].last_model(), "claude-opus"); assert_eq!(mocks[0].call_count(), 0); @@ -261,7 +261,7 @@ mod tests { ); let result = router.chat("hello", "hint:fast", 0.5).await.unwrap(); - assert_eq!(result, "fast-response"); + assert_eq!(result.text_or_empty(), "fast-response"); assert_eq!(mocks[0].call_count(), 1); assert_eq!(mocks[0].last_model(), "llama-3-70b"); } @@ -274,7 +274,7 @@ mod tests { ); let result = router.chat("hello", "hint:nonexistent", 0.5).await.unwrap(); - assert_eq!(result, "default-response"); + assert_eq!(result.text_or_empty(), "default-response"); assert_eq!(mocks[0].call_count(), 1); // Falls back to default with the hint as model name assert_eq!(mocks[0].last_model(), "hint:nonexistent"); @@ -294,7 +294,7 @@ mod tests { .chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5) .await .unwrap(); - assert_eq!(result, "primary-response"); + assert_eq!(result.text_or_empty(), "primary-response"); assert_eq!(mocks[0].call_count(), 1); assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514"); } @@ -355,7 +355,7 @@ mod tests { .chat_with_system(Some("system"), "hello", "model", 0.5) .await .unwrap(); - assert_eq!(result, "response"); + assert_eq!(result.text_or_empty(), "response"); assert_eq!(mock.call_count(), 1); } } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 84746ea..d1f8dd1 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -49,6 +49,14 @@ pub struct ChatResponse { } impl ChatResponse { + /// Convenience: construct a plain text response with no tool calls. + pub fn with_text(text: impl Into) -> Self { + Self { + text: Some(text.into()), + tool_calls: vec![], + } + } + /// True when the LLM wants to invoke at least one tool. pub fn has_tool_calls(&self) -> bool { !self.tool_calls.is_empty() @@ -84,7 +92,12 @@ pub enum ConversationMessage { #[async_trait] pub trait Provider: Send + Sync { - async fn chat(&self, message: &str, model: &str, temperature: f64) -> anyhow::Result { + async fn chat( + &self, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { self.chat_with_system(None, message, model, temperature) .await } @@ -95,7 +108,7 @@ pub trait Provider: Send + Sync { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result; + ) -> anyhow::Result; /// Multi-turn conversation. Default implementation extracts the last user /// message and delegates to `chat_with_system`. @@ -104,7 +117,7 @@ pub trait Provider: Send + Sync { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let system = messages .iter() .find(|m| m.role == "system") From 34306e32d8d4f76f75ae9c39024cdabc857feddc Mon Sep 17 00:00:00 2001 From: chumyin Date: Mon, 16 Feb 2026 13:17:23 +0800 Subject: [PATCH 2/4] fix(provider): complete ChatResponse integration across runtime surfaces --- src/gateway/mod.rs | 264 +++++++++++++++++++++++++++++++++--------- src/providers/mod.rs | 1 + src/tools/delegate.rs | 30 +++-- 3 files changed, 229 insertions(+), 66 deletions(-) diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 2282e66..acf62a4 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,8 +10,14 @@ use crate::channels::{Channel, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; -use crate::providers::{self, ChatResponse, Provider}; -use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; +use crate::observability::{self, Observer}; +use crate::providers::{self, ChatMessage, Provider}; +use crate::runtime; +use crate::security::{ + pairing::{constant_time_eq, is_public_bind, PairingGuard}, + SecurityPolicy, +}; +use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::Result; use axum::{ @@ -45,29 +51,33 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String format!("whatsapp_{}_{}", msg.sender, msg.id) } -fn gateway_reply_from_response(response: ChatResponse) -> String { - let has_tool_calls = response.has_tool_calls(); - let tool_call_count = response.tool_calls.len(); - let mut reply = response.text.unwrap_or_default(); - - if has_tool_calls { - tracing::warn!( - tool_call_count, - "Provider requested tool calls in gateway mode; tool calls are not executed here" - ); - if reply.trim().is_empty() { - reply = "I need to use tools to answer that, but tool execution is not enabled for gateway requests yet." - .to_string(); - } - } - +fn normalize_gateway_reply(reply: String) -> String { if reply.trim().is_empty() { - reply = "Model returned an empty response.".to_string(); + return "Model returned an empty response.".to_string(); } reply } +async fn gateway_agent_reply(state: &AppState, message: &str) -> Result { + let mut history = vec![ + ChatMessage::system(state.system_prompt.as_str()), + ChatMessage::user(message), + ]; + + let reply = crate::agent::loop_::run_tool_call_loop( + state.provider.as_ref(), + &mut history, + state.tools_registry.as_ref(), + state.observer.as_ref(), + &state.model, + state.temperature, + ) + .await?; + + Ok(normalize_gateway_reply(reply)) +} + #[derive(Debug)] struct SlidingWindowRateLimiter { limit_per_window: u32, @@ -182,6 +192,9 @@ fn client_key_from_headers(headers: &HeaderMap) -> String { #[derive(Clone)] pub struct AppState { pub provider: Arc, + pub observer: Arc, + pub tools_registry: Arc>>, + pub system_prompt: Arc, pub model: String, pub temperature: f64, pub mem: Arc, @@ -228,6 +241,47 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { &config.workspace_dir, config.api_key.as_deref(), )?); + let observer: Arc = + Arc::from(observability::create_observer(&config.observability)); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + + let composio_key = if config.composio.enabled { + config.composio.api_key.as_deref() + } else { + None + }; + + let tools_registry = Arc::new(tools::all_tools_with_runtime( + &security, + runtime, + Arc::clone(&mem), + composio_key, + &config.browser, + &config.agents, + config.api_key.as_deref(), + )); + let skills = crate::skills::load_skills(&config.workspace_dir); + let tool_descs: Vec<(&str, &str)> = tools_registry + .iter() + .map(|tool| (tool.name(), tool.description())) + .collect(); + + let mut system_prompt = crate::channels::build_system_prompt( + &config.workspace_dir, + &model, + &tool_descs, + &skills, + Some(&config.identity), + ); + system_prompt.push_str(&crate::agent::loop_::build_tool_instructions( + tools_registry.as_ref(), + )); + let system_prompt = Arc::new(system_prompt); // Extract webhook secret for authentication let webhook_secret: Option> = config @@ -331,6 +385,9 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { // Build shared state let state = AppState { provider, + observer, + tools_registry, + system_prompt, model, temperature, mem, @@ -514,13 +571,8 @@ async fn handle_webhook( .await; } - match state - .provider - .chat(message, &state.model, state.temperature) - .await - { - Ok(response) => { - let reply = gateway_reply_from_response(response); + match gateway_agent_reply(&state, message).await { + Ok(reply) => { let body = serde_json::json!({"response": reply, "model": state.model}); (StatusCode::OK, Json(body)) } @@ -669,13 +721,8 @@ async fn handle_whatsapp_message( } // Call the LLM - match state - .provider - .chat(&msg.content, &state.model, state.temperature) - .await - { - Ok(response) => { - let reply = gateway_reply_from_response(response); + match gateway_agent_reply(&state, &msg.content).await { + Ok(reply) => { // Send reply via WhatsApp if let Err(e) = wa.send(&reply, &msg.sender).await { tracing::error!("Failed to send WhatsApp reply: {e}"); @@ -847,9 +894,9 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); - Ok(ChatResponse::with_text("ok")) + Ok(crate::providers::ChatResponse::with_text("ok")) } } @@ -910,25 +957,36 @@ mod tests { } } - #[tokio::test] - async fn webhook_idempotency_skips_duplicate_provider_calls() { - let provider_impl = Arc::new(MockProvider::default()); - let provider: Arc = provider_impl.clone(); - let memory: Arc = Arc::new(MockMemory); - - let state = AppState { + fn test_app_state( + provider: Arc, + memory: Arc, + auto_save: bool, + ) -> AppState { + AppState { provider, + observer: Arc::new(crate::observability::NoopObserver), + tools_registry: Arc::new(Vec::new()), + system_prompt: Arc::new("test-system-prompt".into()), model: "test-model".into(), temperature: 0.0, mem: memory, - auto_save: false, + auto_save, webhook_secret: None, pairing: Arc::new(PairingGuard::new(false, &[])), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), whatsapp: None, whatsapp_app_secret: None, - }; + } + } + + #[tokio::test] + async fn webhook_idempotency_skips_duplicate_provider_calls() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = test_app_state(provider, memory, false); let mut headers = HeaderMap::new(); headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123")); @@ -964,19 +1022,7 @@ mod tests { let tracking_impl = Arc::new(TrackingMemory::default()); let memory: Arc = tracking_impl.clone(); - let state = AppState { - provider, - model: "test-model".into(), - temperature: 0.0, - mem: memory, - auto_save: true, - webhook_secret: None, - pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), - whatsapp: None, - whatsapp_app_secret: None, - }; + let state = test_app_state(provider, memory, true); let headers = HeaderMap::new(); @@ -1008,6 +1054,110 @@ mod tests { assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); } + #[derive(Default)] + struct StructuredToolCallProvider { + calls: AtomicUsize, + } + + #[async_trait] + impl Provider for StructuredToolCallProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let turn = self.calls.fetch_add(1, Ordering::SeqCst); + + if turn == 0 { + return Ok(crate::providers::ChatResponse { + text: Some("Running tool...".into()), + tool_calls: vec![crate::providers::ToolCall { + id: "call_1".into(), + name: "mock_tool".into(), + arguments: r#"{"query":"gateway"}"#.into(), + }], + }); + } + + Ok(crate::providers::ChatResponse::with_text( + "Gateway tool result ready.", + )) + } + } + + struct MockTool { + calls: Arc, + } + + #[async_trait] + impl Tool for MockTool { + fn name(&self) -> &str { + "mock_tool" + } + + fn description(&self) -> &str { + "Mock tool for gateway tests" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({ + "type": "object", + "properties": { + "query": {"type": "string"} + }, + "required": ["query"] + }) + } + + async fn execute( + &self, + args: serde_json::Value, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + assert_eq!(args["query"], "gateway"); + + Ok(crate::tools::ToolResult { + success: true, + output: "ok".into(), + error: None, + }) + } + } + + #[tokio::test] + async fn webhook_executes_structured_tool_calls() { + let provider_impl = Arc::new(StructuredToolCallProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let tool_calls = Arc::new(AtomicUsize::new(0)); + let tools: Vec> = vec![Box::new(MockTool { + calls: Arc::clone(&tool_calls), + })]; + + let mut state = test_app_state(provider, memory, false); + state.tools_registry = Arc::new(tools); + + let response = handle_webhook( + State(state), + HeaderMap::new(), + Ok(Json(WebhookBody { + message: "please use tool".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::OK); + let payload = response.into_body().collect().await.unwrap().to_bytes(); + let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap(); + assert_eq!(parsed["response"], "Gateway tool result ready."); + assert_eq!(tool_calls.load(Ordering::SeqCst), 1); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); + } + // ══════════════════════════════════════════════════════════ // WhatsApp Signature Verification Tests (CWE-345 Prevention) // ══════════════════════════════════════════════════════════ diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 5911904..7c30650 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -8,6 +8,7 @@ pub mod reliable; pub mod router; pub mod traits; +#[allow(unused_imports)] pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index c2660a4..f205a58 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -220,15 +220,27 @@ impl Tool for DelegateTool { }; match result { - Ok(response) => Ok(ToolResult { - success: true, - output: format!( - "[Agent '{agent_name}' ({provider}/{model})]\n{response}", - provider = agent_config.provider, - model = agent_config.model - ), - error: None, - }), + Ok(response) => { + let has_tool_calls = response.has_tool_calls(); + let mut rendered = response.text.unwrap_or_default(); + if rendered.trim().is_empty() { + if has_tool_calls { + rendered = "[Tool-only response; no text content]".to_string(); + } else { + rendered = "[Empty response]".to_string(); + } + } + + Ok(ToolResult { + success: true, + output: format!( + "[Agent '{agent_name}' ({provider}/{model})]\n{rendered}", + provider = agent_config.provider, + model = agent_config.model + ), + error: None, + }) + } Err(e) => Ok(ToolResult { success: false, output: String::new(), From 2d6ec2fb71a4ad162e505d7c58676519b4f6da03 Mon Sep 17 00:00:00 2001 From: chumyin Date: Mon, 16 Feb 2026 19:33:04 +0800 Subject: [PATCH 3/4] fix(rebase): resolve PR #266 conflicts against latest main --- src/agent/loop_.rs | 1 + src/channels/mod.rs | 37 +++++++-------- src/channels/telegram.rs | 5 +- src/daemon/mod.rs | 3 +- src/gateway/mod.rs | 3 ++ src/tools/git_operations.rs | 95 +++++++++++++++++++++++++++---------- src/tools/mod.rs | 49 +++++++++++++++++-- 7 files changed, 142 insertions(+), 51 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 45b37d2..4698032 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -113,6 +113,7 @@ async fn auto_compact_history( let summary_raw = provider .chat_with_system(Some(summarizer_system), &summarizer_user, model, 0.2) .await + .map(|resp| resp.text_or_empty().to_string()) .unwrap_or_else(|_| { // Fallback to deterministic local truncation when summarization fails. truncate_with_ellipsis(&transcript, COMPACTION_MAX_SUMMARY_CHARS) diff --git a/src/channels/mod.rs b/src/channels/mod.rs index f0399da..aa1fc6b 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -721,6 +721,7 @@ pub async fn start_channels(config: Config) -> Result<()> { composio_key, &config.browser, &config.http_request, + &config.workspace_dir, &config.agents, config.api_key.as_deref(), )); @@ -951,7 +952,7 @@ mod tests { use super::*; use crate::memory::{Memory, MemoryCategory, SqliteMemory}; use crate::observability::NoopObserver; - use crate::providers::{ChatMessage, Provider}; + use crate::providers::{ChatMessage, ChatResponse, Provider, ToolCall}; use crate::tools::{Tool, ToolResult}; use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; @@ -1018,27 +1019,23 @@ mod tests { message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { tokio::time::sleep(self.delay).await; - Ok(format!("echo: {message}")) + Ok(ChatResponse::with_text(format!("echo: {message}"))) } } struct ToolCallingProvider; - fn tool_call_payload() -> String { - serde_json::json!({ - "content": "", - "tool_calls": [{ - "id": "call_1", - "type": "function", - "function": { - "name": "mock_price", - "arguments": "{\"symbol\":\"BTC\"}" - } - }] - }) - .to_string() + fn tool_call_payload() -> ChatResponse { + ChatResponse { + text: Some(String::new()), + tool_calls: vec![ToolCall { + id: "call_1".into(), + name: "mock_price".into(), + arguments: r#"{"symbol":"BTC"}"#.into(), + }], + } } #[async_trait::async_trait] @@ -1049,7 +1046,7 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { Ok(tool_call_payload()) } @@ -1058,12 +1055,14 @@ mod tests { messages: &[ChatMessage], _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let has_tool_results = messages .iter() .any(|msg| msg.role == "user" && msg.content.contains("[Tool results]")); if has_tool_results { - Ok("BTC is currently around $65,000 based on latest tool output.".to_string()) + Ok(ChatResponse::with_text( + "BTC is currently around $65,000 based on latest tool output.", + )) } else { Ok(tool_call_payload()) } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index 40193fe..5b1435c 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -32,7 +32,10 @@ fn split_message_for_telegram(message: &str) -> Vec { pos + 1 } else { // Try space as fallback - search_area.rfind(' ').unwrap_or(TELEGRAM_MAX_MESSAGE_LENGTH) + 1 + search_area + .rfind(' ') + .unwrap_or(TELEGRAM_MAX_MESSAGE_LENGTH) + + 1 } } else if let Some(pos) = search_area.rfind(' ') { pos + 1 diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index af3b861..f1bc4a1 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -193,7 +193,8 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { for task in tasks { let prompt = format!("[Heartbeat Task] {task}"); let temp = config.default_temperature; - if let Err(e) = crate::agent::run(config.clone(), Some(prompt), None, None, temp).await + if let Err(e) = + crate::agent::run(config.clone(), Some(prompt), None, None, temp, false).await { crate::health::mark_component_error("heartbeat", e.to_string()); tracing::warn!("Heartbeat task failed: {e}"); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index acf62a4..8eaa57c 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -70,6 +70,7 @@ async fn gateway_agent_reply(state: &AppState, message: &str) -> Result &mut history, state.tools_registry.as_ref(), state.observer.as_ref(), + "gateway", &state.model, state.temperature, ) @@ -262,6 +263,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { Arc::clone(&mem), composio_key, &config.browser, + &config.http_request, + &config.workspace_dir, &config.agents, config.api_key.as_deref(), )); diff --git a/src/tools/git_operations.rs b/src/tools/git_operations.rs index 774115b..bf4e62c 100644 --- a/src/tools/git_operations.rs +++ b/src/tools/git_operations.rs @@ -14,7 +14,10 @@ pub struct GitOperationsTool { impl GitOperationsTool { pub fn new(security: Arc, workspace_dir: std::path::PathBuf) -> Self { - Self { security, workspace_dir } + Self { + security, + workspace_dir, + } } /// Sanitize git arguments to prevent injection attacks @@ -48,7 +51,10 @@ impl GitOperationsTool { /// Check if an operation is read-only fn is_read_only(&self, operation: &str) -> bool { - matches!(operation, "status" | "diff" | "log" | "show" | "branch" | "rev-parse") + matches!( + operation, + "status" | "diff" | "log" | "show" | "branch" | "rev-parse" + ) } async fn run_git_command(&self, args: &[&str]) -> anyhow::Result { @@ -67,7 +73,9 @@ impl GitOperationsTool { } async fn git_status(&self, _args: serde_json::Value) -> anyhow::Result { - let output = self.run_git_command(&["status", "--porcelain=2", "--branch"]).await?; + let output = self + .run_git_command(&["status", "--porcelain=2", "--branch"]) + .await?; // Parse git status output into structured format let mut result = serde_json::Map::new(); @@ -105,7 +113,10 @@ impl GitOperationsTool { result.insert("staged".to_string(), json!(staged)); result.insert("unstaged".to_string(), json!(unstaged)); result.insert("untracked".to_string(), json!(untracked)); - result.insert("clean".to_string(), json!(staged.is_empty() && unstaged.is_empty() && untracked.is_empty())); + result.insert( + "clean".to_string(), + json!(staged.is_empty() && unstaged.is_empty() && untracked.is_empty()), + ); Ok(ToolResult { success: true, @@ -116,7 +127,10 @@ impl GitOperationsTool { async fn git_diff(&self, args: serde_json::Value) -> anyhow::Result { let files = args.get("files").and_then(|v| v.as_str()).unwrap_or("."); - let cached = args.get("cached").and_then(|v| v.as_bool()).unwrap_or(false); + let cached = args + .get("cached") + .and_then(|v| v.as_bool()) + .unwrap_or(false); let mut git_args = vec!["diff", "--unified=3"]; if cached { @@ -191,12 +205,14 @@ impl GitOperationsTool { let limit = args.get("limit").and_then(|v| v.as_u64()).unwrap_or(10) as usize; let limit_str = limit.to_string(); - let output = self.run_git_command(&[ - "log", - &format!("-{limit_str}"), - "--pretty=format:%H|%an|%ae|%ad|%s", - "--date=iso", - ]).await?; + let output = self + .run_git_command(&[ + "log", + &format!("-{limit_str}"), + "--pretty=format:%H|%an|%ae|%ad|%s", + "--date=iso", + ]) + .await?; let mut commits = Vec::new(); @@ -215,13 +231,16 @@ impl GitOperationsTool { Ok(ToolResult { success: true, - output: serde_json::to_string_pretty(&json!({ "commits": commits })).unwrap_or_default(), + output: serde_json::to_string_pretty(&json!({ "commits": commits })) + .unwrap_or_default(), error: None, }) } async fn git_branch(&self, _args: serde_json::Value) -> anyhow::Result { - let output = self.run_git_command(&["branch", "--format=%(refname:short)|%(HEAD)"]).await?; + let output = self + .run_git_command(&["branch", "--format=%(refname:short)|%(HEAD)"]) + .await?; let mut branches = Vec::new(); let mut current = String::new(); @@ -244,18 +263,21 @@ impl GitOperationsTool { output: serde_json::to_string_pretty(&json!({ "current": current, "branches": branches - })).unwrap_or_default(), + })) + .unwrap_or_default(), error: None, }) } async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result { - let message = args.get("message") + let message = args + .get("message") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?; // Sanitize commit message - let sanitized = message.lines() + let sanitized = message + .lines() .map(|l| l.trim()) .filter(|l| !l.is_empty()) .collect::>() @@ -289,7 +311,8 @@ impl GitOperationsTool { } async fn git_add(&self, args: serde_json::Value) -> anyhow::Result { - let paths = args.get("paths") + let paths = args + .get("paths") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?; @@ -310,7 +333,8 @@ impl GitOperationsTool { } async fn git_checkout(&self, args: serde_json::Value) -> anyhow::Result { - let branch = args.get("branch") + let branch = args + .get("branch") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'branch' parameter"))?; @@ -345,15 +369,22 @@ impl GitOperationsTool { } async fn git_stash(&self, args: serde_json::Value) -> anyhow::Result { - let action = args.get("action").and_then(|v| v.as_str()).unwrap_or("push"); + let action = args + .get("action") + .and_then(|v| v.as_str()) + .unwrap_or("push"); let output = match action { - "push" | "save" => self.run_git_command(&["stash", "push", "-m", "auto-stash"]).await, + "push" | "save" => { + self.run_git_command(&["stash", "push", "-m", "auto-stash"]) + .await + } "pop" => self.run_git_command(&["stash", "pop"]).await, "list" => self.run_git_command(&["stash", "list"]).await, "drop" => { let index = args.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as i32; - self.run_git_command(&["stash", "drop", &format!("stash@{{{index}}}")]).await + self.run_git_command(&["stash", "drop", &format!("stash@{{{index}}}")]) + .await } _ => anyhow::bail!("Unknown stash action: {action}. Use: push, pop, list, drop"), }; @@ -470,7 +501,9 @@ impl Tool for GitOperationsTool { return Ok(ToolResult { success: false, output: String::new(), - error: Some("Action blocked: git write operations require higher autonomy level".into()), + error: Some( + "Action blocked: git write operations require higher autonomy level".into(), + ), }); } @@ -606,7 +639,11 @@ mod tests { .unwrap(); assert!(!result.success); // can_act() returns false for ReadOnly, so we get the "higher autonomy level" message - assert!(result.error.as_deref().unwrap_or("").contains("higher autonomy")); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("higher autonomy")); } #[tokio::test] @@ -632,7 +669,11 @@ mod tests { let result = tool.execute(json!({})).await.unwrap(); assert!(!result.success); - assert!(result.error.as_deref().unwrap_or("").contains("Missing 'operation'")); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Missing 'operation'")); } #[tokio::test] @@ -649,6 +690,10 @@ mod tests { let result = tool.execute(json!({"operation": "push"})).await.unwrap(); assert!(!result.success); - assert!(result.error.as_deref().unwrap_or("").contains("Unknown operation")); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Unknown operation")); } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 95660b3..22e8d1a 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -101,7 +101,10 @@ pub fn all_tools_with_runtime( Box::new(MemoryStoreTool::new(memory.clone())), Box::new(MemoryRecallTool::new(memory.clone())), Box::new(MemoryForgetTool::new(memory)), - Box::new(GitOperationsTool::new(security.clone(), workspace_dir.to_path_buf())), + Box::new(GitOperationsTool::new( + security.clone(), + workspace_dir.to_path_buf(), + )), ]; if browser_config.enabled { @@ -184,7 +187,16 @@ mod tests { }; let http = crate::config::HttpRequestConfig::default(); - let tools = all_tools(&security, mem, None, &browser, &http, tmp.path(), &HashMap::new(), None); + let tools = all_tools( + &security, + mem, + None, + &browser, + &http, + tmp.path(), + &HashMap::new(), + None, + ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"browser_open")); } @@ -208,7 +220,16 @@ mod tests { }; let http = crate::config::HttpRequestConfig::default(); - let tools = all_tools(&security, mem, None, &browser, &http, tmp.path(), &HashMap::new(), None); + let tools = all_tools( + &security, + mem, + None, + &browser, + &http, + tmp.path(), + &HashMap::new(), + None, + ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"browser_open")); } @@ -334,7 +355,16 @@ mod tests { }, ); - let tools = all_tools(&security, mem, None, &browser, &http, tmp.path(), &agents, Some("sk-test")); + let tools = all_tools( + &security, + mem, + None, + &browser, + &http, + tmp.path(), + &agents, + Some("sk-test"), + ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"delegate")); } @@ -353,7 +383,16 @@ mod tests { let browser = BrowserConfig::default(); let http = crate::config::HttpRequestConfig::default(); - let tools = all_tools(&security, mem, None, &browser, &http, tmp.path(), &HashMap::new(), None); + let tools = all_tools( + &security, + mem, + None, + &browser, + &http, + tmp.path(), + &HashMap::new(), + None, + ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"delegate")); } From dedb465377f8e848cac527359dea5e3141ccf909 Mon Sep 17 00:00:00 2001 From: chumyin Date: Mon, 16 Feb 2026 19:36:39 +0800 Subject: [PATCH 4/4] test(telegram): ensure newline split case exceeds max length --- src/channels/telegram.rs | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index 5b1435c..ea90e79 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -919,7 +919,8 @@ mod tests { #[test] fn telegram_split_at_newline() { - let text_block = "Line of text\n".repeat(TELEGRAM_MAX_MESSAGE_LENGTH / 13); + let line = "Line of text\n"; + let text_block = line.repeat(TELEGRAM_MAX_MESSAGE_LENGTH / line.len() + 1); let chunks = split_message_for_telegram(&text_block); assert!(chunks.len() >= 2); for chunk in chunks {