From b341fdb36892fb7e1f3cb3bf4e51d622553b2e3b Mon Sep 17 00:00:00 2001 From: mai1015 Date: Mon, 16 Feb 2026 00:40:43 -0500 Subject: [PATCH] feat: add agent structure and improve tooling for provider --- src/agent/agent.rs | 701 ++++++++++++++++++++++++++++++++++++ src/agent/dispatcher.rs | 312 ++++++++++++++++ src/agent/loop_.rs | 22 +- src/agent/memory_loader.rs | 118 ++++++ src/agent/mod.rs | 21 ++ src/agent/prompt.rs | 304 ++++++++++++++++ src/channels/mod.rs | 36 +- src/config/schema.rs | 67 ++++ src/gateway/mod.rs | 272 +++----------- src/onboard/wizard.rs | 2 + src/providers/anthropic.rs | 324 +++++++++++++++-- src/providers/compatible.rs | 155 ++++---- src/providers/gemini.rs | 5 +- src/providers/mod.rs | 5 +- src/providers/ollama.rs | 8 +- src/providers/openai.rs | 239 +++++++++++- src/providers/openrouter.rs | 238 +++++++++++- src/providers/reliable.rs | 42 +-- src/providers/router.rs | 54 ++- src/providers/traits.rs | 76 +++- src/tools/delegate.rs | 9 +- 21 files changed, 2567 insertions(+), 443 deletions(-) create mode 100644 src/agent/agent.rs create mode 100644 src/agent/dispatcher.rs create mode 100644 src/agent/memory_loader.rs create mode 100644 src/agent/prompt.rs diff --git a/src/agent/agent.rs b/src/agent/agent.rs new file mode 100644 index 0000000..8f9331e --- /dev/null +++ b/src/agent/agent.rs @@ -0,0 +1,701 @@ +use crate::agent::dispatcher::{ + NativeToolDispatcher, ParsedToolCall, ToolDispatcher, ToolExecutionResult, XmlToolDispatcher, +}; +use crate::agent::memory_loader::{DefaultMemoryLoader, MemoryLoader}; +use crate::agent::prompt::{PromptContext, SystemPromptBuilder}; +use crate::config::Config; +use crate::memory::{self, Memory, MemoryCategory}; +use crate::observability::{self, Observer, ObserverEvent}; +use crate::providers::{self, ChatMessage, ChatRequest, ConversationMessage, Provider}; +use crate::runtime; +use crate::security::SecurityPolicy; +use crate::tools::{self, Tool, ToolSpec}; +use crate::util::truncate_with_ellipsis; +use anyhow::Result; +use std::io::Write as IoWrite; +use std::sync::Arc; +use std::time::Instant; + +pub struct Agent { + provider: Box, + tools: Vec>, + tool_specs: Vec, + memory: Arc, + observer: Arc, + prompt_builder: SystemPromptBuilder, + tool_dispatcher: Box, + memory_loader: Box, + config: crate::config::AgentConfig, + model_name: String, + temperature: f64, + workspace_dir: std::path::PathBuf, + identity_config: crate::config::IdentityConfig, + skills: Vec, + auto_save: bool, + history: Vec, +} + +pub struct AgentBuilder { + provider: Option>, + tools: Option>>, + memory: Option>, + observer: Option>, + prompt_builder: Option, + tool_dispatcher: Option>, + memory_loader: Option>, + config: Option, + model_name: Option, + temperature: Option, + workspace_dir: Option, + identity_config: Option, + skills: Option>, + auto_save: Option, +} + +impl AgentBuilder { + pub fn new() -> Self { + Self { + provider: None, + tools: None, + memory: None, + observer: None, + prompt_builder: None, + tool_dispatcher: None, + memory_loader: None, + config: None, + model_name: None, + temperature: None, + workspace_dir: None, + identity_config: None, + skills: None, + auto_save: None, + } + } + + pub fn provider(mut self, provider: Box) -> Self { + self.provider = Some(provider); + self + } + + pub fn tools(mut self, tools: Vec>) -> Self { + self.tools = Some(tools); + self + } + + pub fn memory(mut self, memory: Arc) -> Self { + self.memory = Some(memory); + self + } + + pub fn observer(mut self, observer: Arc) -> Self { + self.observer = Some(observer); + self + } + + pub fn prompt_builder(mut self, prompt_builder: SystemPromptBuilder) -> Self { + self.prompt_builder = Some(prompt_builder); + self + } + + pub fn tool_dispatcher(mut self, tool_dispatcher: Box) -> Self { + self.tool_dispatcher = Some(tool_dispatcher); + self + } + + pub fn memory_loader(mut self, memory_loader: Box) -> Self { + self.memory_loader = Some(memory_loader); + self + } + + pub fn config(mut self, config: crate::config::AgentConfig) -> Self { + self.config = Some(config); + self + } + + pub fn model_name(mut self, model_name: String) -> Self { + self.model_name = Some(model_name); + self + } + + pub fn temperature(mut self, temperature: f64) -> Self { + self.temperature = Some(temperature); + self + } + + pub fn workspace_dir(mut self, workspace_dir: std::path::PathBuf) -> Self { + self.workspace_dir = Some(workspace_dir); + self + } + + pub fn identity_config(mut self, identity_config: crate::config::IdentityConfig) -> Self { + self.identity_config = Some(identity_config); + self + } + + pub fn skills(mut self, skills: Vec) -> Self { + self.skills = Some(skills); + self + } + + pub fn auto_save(mut self, auto_save: bool) -> Self { + self.auto_save = Some(auto_save); + self + } + + pub fn build(self) -> Result { + let tools = self + .tools + .ok_or_else(|| anyhow::anyhow!("tools are required"))?; + let tool_specs = tools.iter().map(|tool| tool.spec()).collect(); + + Ok(Agent { + provider: self + .provider + .ok_or_else(|| anyhow::anyhow!("provider is required"))?, + tools, + tool_specs, + memory: self + .memory + .ok_or_else(|| anyhow::anyhow!("memory is required"))?, + observer: self + .observer + .ok_or_else(|| anyhow::anyhow!("observer is required"))?, + prompt_builder: self + .prompt_builder + .unwrap_or_else(SystemPromptBuilder::with_defaults), + tool_dispatcher: self + .tool_dispatcher + .ok_or_else(|| anyhow::anyhow!("tool_dispatcher is required"))?, + memory_loader: self + .memory_loader + .unwrap_or_else(|| Box::new(DefaultMemoryLoader::default())), + config: self.config.unwrap_or_default(), + model_name: self + .model_name + .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()), + temperature: self.temperature.unwrap_or(0.7), + workspace_dir: self + .workspace_dir + .unwrap_or_else(|| std::path::PathBuf::from(".")), + identity_config: self.identity_config.unwrap_or_default(), + skills: self.skills.unwrap_or_default(), + auto_save: self.auto_save.unwrap_or(false), + history: Vec::new(), + }) + } +} + +impl Agent { + pub fn builder() -> AgentBuilder { + AgentBuilder::new() + } + + pub fn history(&self) -> &[ConversationMessage] { + &self.history + } + + pub fn clear_history(&mut self) { + self.history.clear(); + } + + pub fn from_config(config: &Config) -> Result { + let observer: Arc = + Arc::from(observability::create_observer(&config.observability)); + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); + let security = Arc::new(SecurityPolicy::from_config( + &config.autonomy, + &config.workspace_dir, + )); + + let memory: Arc = Arc::from(memory::create_memory( + &config.memory, + &config.workspace_dir, + config.api_key.as_deref(), + )?); + + let composio_key = if config.composio.enabled { + config.composio.api_key.as_deref() + } else { + None + }; + + let tools = tools::all_tools_with_runtime( + &security, + runtime, + memory.clone(), + composio_key, + &config.browser, + &config.http_request, + &config.workspace_dir, + &config.agents, + config.api_key.as_deref(), + ); + + let provider_name = config.default_provider.as_deref().unwrap_or("openrouter"); + + let model_name = config + .default_model + .as_deref() + .unwrap_or("anthropic/claude-sonnet-4-20250514") + .to_string(); + + let provider: Box = providers::create_routed_provider( + provider_name, + config.api_key.as_deref(), + &config.reliability, + &config.model_routes, + &model_name, + )?; + + let dispatcher_choice = config.agent.tool_dispatcher.as_str(); + let tool_dispatcher: Box = match dispatcher_choice { + "native" => Box::new(NativeToolDispatcher), + "xml" => Box::new(XmlToolDispatcher), + _ if provider.supports_native_tools() => Box::new(NativeToolDispatcher), + _ => Box::new(XmlToolDispatcher), + }; + + Agent::builder() + .provider(provider) + .tools(tools) + .memory(memory) + .observer(observer) + .tool_dispatcher(tool_dispatcher) + .memory_loader(Box::new(DefaultMemoryLoader::default())) + .prompt_builder(SystemPromptBuilder::with_defaults()) + .config(config.agent.clone()) + .model_name(model_name) + .temperature(config.default_temperature) + .workspace_dir(config.workspace_dir.clone()) + .identity_config(config.identity.clone()) + .skills(crate::skills::load_skills(&config.workspace_dir)) + .auto_save(config.memory.auto_save) + .build() + } + + fn trim_history(&mut self) { + let max = self.config.max_history_messages; + if self.history.len() <= max { + return; + } + + let mut system_messages = Vec::new(); + let mut other_messages = Vec::new(); + + for msg in self.history.drain(..) { + match &msg { + ConversationMessage::Chat(chat) if chat.role == "system" => { + system_messages.push(msg) + } + _ => other_messages.push(msg), + } + } + + if other_messages.len() > max { + let drop_count = other_messages.len() - max; + other_messages.drain(0..drop_count); + } + + self.history = system_messages; + self.history.extend(other_messages); + } + + fn build_system_prompt(&self) -> Result { + let instructions = self.tool_dispatcher.prompt_instructions(&self.tools); + let ctx = PromptContext { + workspace_dir: &self.workspace_dir, + model_name: &self.model_name, + tools: &self.tools, + skills: &self.skills, + identity_config: Some(&self.identity_config), + dispatcher_instructions: &instructions, + }; + self.prompt_builder.build(&ctx) + } + + async fn execute_tool_call(&self, call: &ParsedToolCall) -> ToolExecutionResult { + let start = Instant::now(); + + let result = if let Some(tool) = self.tools.iter().find(|t| t.name() == call.name) { + match tool.execute(call.arguments.clone()).await { + Ok(r) => { + self.observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: r.success, + }); + if r.success { + r.output + } else { + format!("Error: {}", r.error.unwrap_or(r.output)) + } + } + Err(e) => { + self.observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: false, + }); + format!("Error executing {}: {e}", call.name) + } + } + } else { + format!("Unknown tool: {}", call.name) + }; + + ToolExecutionResult { + name: call.name.clone(), + output: result, + success: true, + tool_call_id: call.tool_call_id.clone(), + } + } + + async fn execute_tools(&self, calls: &[ParsedToolCall]) -> Vec { + if !self.config.parallel_tools { + let mut results = Vec::with_capacity(calls.len()); + for call in calls { + results.push(self.execute_tool_call(call).await); + } + return results; + } + + let mut results = Vec::with_capacity(calls.len()); + for call in calls { + results.push(self.execute_tool_call(call).await); + } + results + } + + pub async fn turn(&mut self, user_message: &str) -> Result { + if self.history.is_empty() { + let system_prompt = self.build_system_prompt()?; + self.history + .push(ConversationMessage::Chat(ChatMessage::system( + system_prompt, + ))); + } + + if self.auto_save { + let _ = self + .memory + .store("user_msg", user_message, MemoryCategory::Conversation) + .await; + } + + let context = self + .memory_loader + .load_context(self.memory.as_ref(), user_message) + .await + .unwrap_or_default(); + + let enriched = if context.is_empty() { + user_message.to_string() + } else { + format!("{context}{user_message}") + }; + + self.history + .push(ConversationMessage::Chat(ChatMessage::user(enriched))); + + for _ in 0..self.config.max_tool_iterations { + let messages = self.tool_dispatcher.to_provider_messages(&self.history); + let response = match self + .provider + .chat( + ChatRequest { + messages: &messages, + tools: if self.tool_dispatcher.should_send_tool_specs() { + Some(&self.tool_specs) + } else { + None + }, + }, + &self.model_name, + self.temperature, + ) + .await + { + Ok(resp) => resp, + Err(err) => return Err(err), + }; + + let (text, calls) = self.tool_dispatcher.parse_response(&response); + if calls.is_empty() { + let final_text = if text.is_empty() { + response.text.unwrap_or_default() + } else { + text + }; + + self.history + .push(ConversationMessage::Chat(ChatMessage::assistant( + final_text.clone(), + ))); + self.trim_history(); + + if self.auto_save { + let summary = truncate_with_ellipsis(&final_text, 100); + let _ = self + .memory + .store("assistant_resp", &summary, MemoryCategory::Daily) + .await; + } + + return Ok(final_text); + } + + if !text.is_empty() { + self.history + .push(ConversationMessage::Chat(ChatMessage::assistant( + text.clone(), + ))); + print!("{text}"); + let _ = std::io::stdout().flush(); + } + + self.history.push(ConversationMessage::AssistantToolCalls { + text: response.text.clone(), + tool_calls: response.tool_calls.clone(), + }); + + let results = self.execute_tools(&calls).await; + let formatted = self.tool_dispatcher.format_results(&results); + self.history.push(formatted); + self.trim_history(); + } + + anyhow::bail!( + "Agent exceeded maximum tool iterations ({})", + self.config.max_tool_iterations + ) + } + + pub async fn run_single(&mut self, message: &str) -> Result { + self.turn(message).await + } + + pub async fn run_interactive(&mut self) -> Result<()> { + println!("๐Ÿฆ€ ZeroClaw Interactive Mode"); + println!("Type /quit to exit.\n"); + + let (tx, mut rx) = tokio::sync::mpsc::channel(32); + let cli = crate::channels::CliChannel::new(); + + let listen_handle = tokio::spawn(async move { + let _ = crate::channels::Channel::listen(&cli, tx).await; + }); + + while let Some(msg) = rx.recv().await { + let response = match self.turn(&msg.content).await { + Ok(resp) => resp, + Err(e) => { + eprintln!("\nError: {e}\n"); + continue; + } + }; + println!("\n{response}\n"); + } + + listen_handle.abort(); + Ok(()) + } +} + +pub async fn run( + config: Config, + message: Option, + provider_override: Option, + model_override: Option, + temperature: f64, +) -> Result<()> { + let start = Instant::now(); + + let mut effective_config = config; + if let Some(p) = provider_override { + effective_config.default_provider = Some(p); + } + if let Some(m) = model_override { + effective_config.default_model = Some(m); + } + effective_config.default_temperature = temperature; + + let mut agent = Agent::from_config(&effective_config)?; + + let provider_name = effective_config + .default_provider + .as_deref() + .unwrap_or("openrouter") + .to_string(); + let model_name = effective_config + .default_model + .as_deref() + .unwrap_or("anthropic/claude-sonnet-4-20250514") + .to_string(); + + agent.observer.record_event(&ObserverEvent::AgentStart { + provider: provider_name, + model: model_name, + }); + + if let Some(msg) = message { + let response = agent.run_single(&msg).await?; + println!("{response}"); + } else { + agent.run_interactive().await?; + } + + agent.observer.record_event(&ObserverEvent::AgentEnd { + duration: start.elapsed(), + tokens_used: None, + }); + + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use async_trait::async_trait; + use std::sync::Mutex; + + struct MockProvider { + responses: Mutex>, + } + + #[async_trait] + impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("ok".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + let mut guard = self.responses.lock().unwrap(); + if guard.is_empty() { + return Ok(crate::providers::ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }); + } + Ok(guard.remove(0)) + } + } + + struct MockTool; + + #[async_trait] + impl Tool for MockTool { + fn name(&self) -> &str { + "echo" + } + + fn description(&self) -> &str { + "echo" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type": "object"}) + } + + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(crate::tools::ToolResult { + success: true, + output: "tool-out".into(), + error: None, + }) + } + } + + #[tokio::test] + async fn turn_without_tools_returns_text() { + let provider = Box::new(MockProvider { + responses: Mutex::new(vec![crate::providers::ChatResponse { + text: Some("hello".into()), + tool_calls: vec![], + }]), + }); + + let memory_cfg = crate::config::MemoryConfig { + backend: "none".into(), + ..crate::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(), + ); + + let observer: Arc = Arc::from(crate::observability::NoopObserver {}); + let mut agent = Agent::builder() + .provider(provider) + .tools(vec![Box::new(MockTool)]) + .memory(mem) + .observer(observer) + .tool_dispatcher(Box::new(XmlToolDispatcher)) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .build() + .unwrap(); + + let response = agent.turn("hi").await.unwrap(); + assert_eq!(response, "hello"); + } + + #[tokio::test] + async fn turn_with_native_dispatcher_handles_tool_results_variant() { + let provider = Box::new(MockProvider { + responses: Mutex::new(vec![ + crate::providers::ChatResponse { + text: Some("".into()), + tool_calls: vec![crate::providers::ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: "{}".into(), + }], + }, + crate::providers::ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }, + ]), + }); + + let memory_cfg = crate::config::MemoryConfig { + backend: "none".into(), + ..crate::config::MemoryConfig::default() + }; + let mem: Arc = Arc::from( + crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(), + ); + + let observer: Arc = Arc::from(crate::observability::NoopObserver {}); + let mut agent = Agent::builder() + .provider(provider) + .tools(vec![Box::new(MockTool)]) + .memory(mem) + .observer(observer) + .tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::path::PathBuf::from("/tmp")) + .build() + .unwrap(); + + let response = agent.turn("hi").await.unwrap(); + assert_eq!(response, "done"); + assert!(matches!( + agent + .history() + .iter() + .find(|msg| matches!(msg, ConversationMessage::ToolResults(_))), + Some(_) + )); + } +} diff --git a/src/agent/dispatcher.rs b/src/agent/dispatcher.rs new file mode 100644 index 0000000..673ec8c --- /dev/null +++ b/src/agent/dispatcher.rs @@ -0,0 +1,312 @@ +use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage}; +use crate::tools::{Tool, ToolSpec}; +use serde_json::Value; +use std::fmt::Write; + +#[derive(Debug, Clone)] +pub struct ParsedToolCall { + pub name: String, + pub arguments: Value, + pub tool_call_id: Option, +} + +#[derive(Debug, Clone)] +pub struct ToolExecutionResult { + pub name: String, + pub output: String, + pub success: bool, + pub tool_call_id: Option, +} + +pub trait ToolDispatcher: Send + Sync { + fn parse_response(&self, response: &ChatResponse) -> (String, Vec); + fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage; + fn prompt_instructions(&self, tools: &[Box]) -> String; + fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec; + fn should_send_tool_specs(&self) -> bool; +} + +#[derive(Default)] +pub struct XmlToolDispatcher; + +impl XmlToolDispatcher { + fn parse_xml_tool_calls(response: &str) -> (String, Vec) { + let mut text_parts = Vec::new(); + let mut calls = Vec::new(); + let mut remaining = response; + + while let Some(start) = remaining.find("") { + let before = &remaining[..start]; + if !before.trim().is_empty() { + text_parts.push(before.trim().to_string()); + } + + if let Some(end) = remaining[start..].find("") { + let inner = &remaining[start + 11..start + end]; + match serde_json::from_str::(inner.trim()) { + Ok(parsed) => { + let name = parsed + .get("name") + .and_then(Value::as_str) + .unwrap_or("") + .to_string(); + if name.is_empty() { + remaining = &remaining[start + end + 12..]; + continue; + } + let arguments = parsed + .get("arguments") + .cloned() + .unwrap_or_else(|| Value::Object(serde_json::Map::new())); + calls.push(ParsedToolCall { + name, + arguments, + tool_call_id: None, + }); + } + Err(e) => { + tracing::warn!("Malformed JSON: {e}"); + } + } + remaining = &remaining[start + end + 12..]; + } else { + break; + } + } + + if !remaining.trim().is_empty() { + text_parts.push(remaining.trim().to_string()); + } + + (text_parts.join("\n"), calls) + } + + pub fn tool_specs(tools: &[Box]) -> Vec { + tools.iter().map(|tool| tool.spec()).collect() + } +} + +impl ToolDispatcher for XmlToolDispatcher { + fn parse_response(&self, response: &ChatResponse) -> (String, Vec) { + let text = response.text_or_empty(); + Self::parse_xml_tool_calls(text) + } + + fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { + let mut content = String::new(); + for result in results { + let status = if result.success { "ok" } else { "error" }; + let _ = writeln!( + content, + "\n{}\n", + result.name, status, result.output + ); + } + ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}"))) + } + + fn prompt_instructions(&self, tools: &[Box]) -> String { + let mut instructions = String::new(); + instructions.push_str("## Tool Use Protocol\n\n"); + instructions + .push_str("To use a tool, wrap a JSON object in tags:\n\n"); + instructions.push_str( + "```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n", + ); + instructions.push_str("### Available Tools\n\n"); + + for tool in tools { + let _ = writeln!( + instructions, + "- **{}**: {}\n Parameters: `{}`", + tool.name(), + tool.description(), + tool.parameters_schema() + ); + } + + instructions + } + + fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec { + history + .iter() + .flat_map(|msg| match msg { + ConversationMessage::Chat(chat) => vec![chat.clone()], + ConversationMessage::AssistantToolCalls { text, .. } => { + vec![ChatMessage::assistant(text.clone().unwrap_or_default())] + } + ConversationMessage::ToolResults(results) => { + let mut content = String::new(); + for result in results { + let _ = writeln!( + content, + "\n{}\n", + result.tool_call_id, result.content + ); + } + vec![ChatMessage::user(format!("[Tool results]\n{content}"))] + } + }) + .collect() + } + + fn should_send_tool_specs(&self) -> bool { + false + } +} + +pub struct NativeToolDispatcher; + +impl ToolDispatcher for NativeToolDispatcher { + fn parse_response(&self, response: &ChatResponse) -> (String, Vec) { + let text = response.text.clone().unwrap_or_default(); + let calls = response + .tool_calls + .iter() + .map(|tc| ParsedToolCall { + name: tc.name.clone(), + arguments: serde_json::from_str(&tc.arguments) + .unwrap_or_else(|_| Value::Object(serde_json::Map::new())), + tool_call_id: Some(tc.id.clone()), + }) + .collect(); + (text, calls) + } + + fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { + let messages = results + .iter() + .map(|result| ToolResultMessage { + tool_call_id: result + .tool_call_id + .clone() + .unwrap_or_else(|| "unknown".to_string()), + content: result.output.clone(), + }) + .collect(); + ConversationMessage::ToolResults(messages) + } + + fn prompt_instructions(&self, _tools: &[Box]) -> String { + String::new() + } + + fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec { + history + .iter() + .flat_map(|msg| match msg { + ConversationMessage::Chat(chat) => vec![chat.clone()], + ConversationMessage::AssistantToolCalls { text, tool_calls } => { + let payload = serde_json::json!({ + "content": text, + "tool_calls": tool_calls, + }); + vec![ChatMessage::assistant(payload.to_string())] + } + ConversationMessage::ToolResults(results) => results + .iter() + .map(|result| { + ChatMessage::tool( + serde_json::json!({ + "tool_call_id": result.tool_call_id, + "content": result.content, + }) + .to_string(), + ) + }) + .collect(), + }) + .collect() + } + + fn should_send_tool_specs(&self) -> bool { + true + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn xml_dispatcher_parses_tool_calls() { + let response = ChatResponse { + text: Some( + "Checking\n{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}" + .into(), + ), + tool_calls: vec![], + }; + let dispatcher = XmlToolDispatcher; + let (_, calls) = dispatcher.parse_response(&response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + } + + #[test] + fn native_dispatcher_roundtrip() { + let response = ChatResponse { + text: Some("ok".into()), + tool_calls: vec![crate::providers::ToolCall { + id: "tc1".into(), + name: "file_read".into(), + arguments: "{\"path\":\"a.txt\"}".into(), + }], + }; + let dispatcher = NativeToolDispatcher; + let (_, calls) = dispatcher.parse_response(&response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1")); + + let msg = dispatcher.format_results(&[ToolExecutionResult { + name: "file_read".into(), + output: "hello".into(), + success: true, + tool_call_id: Some("tc1".into()), + }]); + match msg { + ConversationMessage::ToolResults(results) => { + assert_eq!(results.len(), 1); + assert_eq!(results[0].tool_call_id, "tc1"); + } + _ => panic!("expected tool results"), + } + } + + #[test] + fn xml_format_results_contains_tool_result_tags() { + let dispatcher = XmlToolDispatcher; + let msg = dispatcher.format_results(&[ToolExecutionResult { + name: "shell".into(), + output: "ok".into(), + success: true, + tool_call_id: None, + }]); + let rendered = match msg { + ConversationMessage::Chat(chat) => chat.content, + _ => String::new(), + }; + assert!(rendered.contains(" { + assert_eq!(results.len(), 1); + assert_eq!(results[0].tool_call_id, "tc-1"); + } + _ => panic!("expected ToolResults variant"), + } + } +} diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index e7421ad..1888866 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -8,11 +8,10 @@ use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::Result; use std::fmt::Write; -use std::io::Write as IoWrite; +use std::io::Write as _; use std::sync::Arc; use std::time::Instant; use uuid::Uuid; - /// Maximum agentic tool-use iterations per user message to prevent runaway loops. const MAX_TOOL_ITERATIONS: usize = 10; @@ -113,7 +112,6 @@ async fn auto_compact_history( let summary_raw = provider .chat_with_system(Some(summarizer_system), &summarizer_user, model, 0.2) .await - .map(|resp| resp.text_or_empty().to_string()) .unwrap_or_else(|_| { // Fallback to deterministic local truncation when summarization fails. truncate_with_ellipsis(&transcript, COMPACTION_MAX_SUMMARY_CHARS) @@ -482,21 +480,11 @@ pub(crate) async fn run_tool_call_loop( } }; - let response_text = response.text.unwrap_or_default(); + let response_text = response; let mut assistant_history_content = response_text.clone(); - let mut parsed_text = response_text.clone(); - let mut tool_calls = parse_structured_tool_calls(&response.tool_calls); - - if !response.tool_calls.is_empty() { - assistant_history_content = - build_assistant_history_with_tool_calls(&response_text, &response.tool_calls); - } - - if tool_calls.is_empty() { - let (fallback_text, fallback_calls) = parse_tool_calls(&response_text); - parsed_text = fallback_text; - tool_calls = fallback_calls; - } + let (parsed_text, tool_calls) = parse_tool_calls(&response_text); + let mut parsed_text = parsed_text; + let mut tool_calls = tool_calls; if tool_calls.is_empty() { // No tool calls โ€” this is the final response diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs new file mode 100644 index 0000000..f5733ec --- /dev/null +++ b/src/agent/memory_loader.rs @@ -0,0 +1,118 @@ +use crate::memory::Memory; +use async_trait::async_trait; +use std::fmt::Write; + +#[async_trait] +pub trait MemoryLoader: Send + Sync { + async fn load_context(&self, memory: &dyn Memory, user_message: &str) + -> anyhow::Result; +} + +pub struct DefaultMemoryLoader { + limit: usize, +} + +impl Default for DefaultMemoryLoader { + fn default() -> Self { + Self { limit: 5 } + } +} + +impl DefaultMemoryLoader { + pub fn new(limit: usize) -> Self { + Self { + limit: limit.max(1), + } + } +} + +#[async_trait] +impl MemoryLoader for DefaultMemoryLoader { + async fn load_context( + &self, + memory: &dyn Memory, + user_message: &str, + ) -> anyhow::Result { + let entries = memory.recall(user_message, self.limit).await?; + if entries.is_empty() { + return Ok(String::new()); + } + + let mut context = String::from("[Memory context]\n"); + for entry in entries { + let _ = writeln!(context, "- {}: {}", entry.key, entry.content); + } + context.push('\n'); + Ok(context) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory::{Memory, MemoryCategory, MemoryEntry}; + + struct MockMemory; + + #[async_trait] + impl Memory for MockMemory { + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result> { + if limit == 0 { + return Ok(vec![]); + } + Ok(vec![MemoryEntry { + id: "1".into(), + key: "k".into(), + content: "v".into(), + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score: None, + }]) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + ) -> anyhow::Result> { + Ok(vec![]) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(true) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + + fn name(&self) -> &str { + "mock" + } + } + + #[tokio::test] + async fn default_loader_formats_context() { + let loader = DefaultMemoryLoader::default(); + let context = loader.load_context(&MockMemory, "hello").await.unwrap(); + assert!(context.contains("[Memory context]")); + assert!(context.contains("- k: v")); + } +} diff --git a/src/agent/mod.rs b/src/agent/mod.rs index e3d7d16..63bf3f8 100644 --- a/src/agent/mod.rs +++ b/src/agent/mod.rs @@ -1,3 +1,24 @@ +pub mod agent; +pub mod dispatcher; pub mod loop_; +pub mod memory_loader; +pub mod prompt; +#[allow(unused_imports)] +pub use agent::{Agent, AgentBuilder}; pub use loop_::{process_message, run}; + +#[cfg(test)] +mod tests { + use super::*; + + fn assert_reexport_exists(_value: F) {} + + #[test] + fn run_function_is_reexported() { + assert_reexport_exists(run); + assert_reexport_exists(process_message); + assert_reexport_exists(loop_::run); + assert_reexport_exists(loop_::process_message); + } +} diff --git a/src/agent/prompt.rs b/src/agent/prompt.rs new file mode 100644 index 0000000..bdc426f --- /dev/null +++ b/src/agent/prompt.rs @@ -0,0 +1,304 @@ +use crate::config::IdentityConfig; +use crate::identity; +use crate::skills::Skill; +use crate::tools::Tool; +use anyhow::Result; +use chrono::Local; +use std::fmt::Write; +use std::path::Path; + +const BOOTSTRAP_MAX_CHARS: usize = 20_000; + +pub struct PromptContext<'a> { + pub workspace_dir: &'a Path, + pub model_name: &'a str, + pub tools: &'a [Box], + pub skills: &'a [Skill], + pub identity_config: Option<&'a IdentityConfig>, + pub dispatcher_instructions: &'a str, +} + +pub trait PromptSection: Send + Sync { + fn name(&self) -> &str; + fn build(&self, ctx: &PromptContext<'_>) -> Result; +} + +#[derive(Default)] +pub struct SystemPromptBuilder { + sections: Vec>, +} + +impl SystemPromptBuilder { + pub fn with_defaults() -> Self { + Self { + sections: vec![ + Box::new(IdentitySection), + Box::new(ToolsSection), + Box::new(SafetySection), + Box::new(SkillsSection), + Box::new(WorkspaceSection), + Box::new(DateTimeSection), + Box::new(RuntimeSection), + ], + } + } + + pub fn add_section(mut self, section: Box) -> Self { + self.sections.push(section); + self + } + + pub fn build(&self, ctx: &PromptContext<'_>) -> Result { + let mut output = String::new(); + for section in &self.sections { + let part = section.build(ctx)?; + if part.trim().is_empty() { + continue; + } + output.push_str(part.trim_end()); + output.push_str("\n\n"); + } + Ok(output) + } +} + +pub struct IdentitySection; +pub struct ToolsSection; +pub struct SafetySection; +pub struct SkillsSection; +pub struct WorkspaceSection; +pub struct RuntimeSection; +pub struct DateTimeSection; + +impl PromptSection for IdentitySection { + fn name(&self) -> &str { + "identity" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + let mut prompt = String::from("## Project Context\n\n"); + if let Some(config) = ctx.identity_config { + if identity::is_aieos_configured(config) { + if let Ok(Some(aieos)) = identity::load_aieos_identity(config, ctx.workspace_dir) { + let rendered = identity::aieos_to_system_prompt(&aieos); + if !rendered.is_empty() { + prompt.push_str(&rendered); + return Ok(prompt); + } + } + } + } + + prompt.push_str( + "The following workspace files define your identity, behavior, and context.\n\n", + ); + for file in [ + "AGENTS.md", + "SOUL.md", + "TOOLS.md", + "IDENTITY.md", + "USER.md", + "HEARTBEAT.md", + "BOOTSTRAP.md", + "MEMORY.md", + ] { + inject_workspace_file(&mut prompt, ctx.workspace_dir, file); + } + + Ok(prompt) + } +} + +impl PromptSection for ToolsSection { + fn name(&self) -> &str { + "tools" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + let mut out = String::from("## Tools\n\n"); + for tool in ctx.tools { + let _ = writeln!( + out, + "- **{}**: {}\n Parameters: `{}`", + tool.name(), + tool.description(), + tool.parameters_schema() + ); + } + if !ctx.dispatcher_instructions.is_empty() { + out.push('\n'); + out.push_str(ctx.dispatcher_instructions); + } + Ok(out) + } +} + +impl PromptSection for SafetySection { + fn name(&self) -> &str { + "safety" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> Result { + Ok("## Safety\n\n- Do not exfiltrate private data.\n- Do not run destructive commands without asking.\n- Do not bypass oversight or approval mechanisms.\n- Prefer `trash` over `rm`.\n- When in doubt, ask before acting externally.".into()) + } +} + +impl PromptSection for SkillsSection { + fn name(&self) -> &str { + "skills" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + if ctx.skills.is_empty() { + return Ok(String::new()); + } + + let mut prompt = String::from("## Available Skills\n\n\n"); + for skill in ctx.skills { + let location = skill.location.clone().unwrap_or_else(|| { + ctx.workspace_dir + .join("skills") + .join(&skill.name) + .join("SKILL.md") + }); + let _ = writeln!( + prompt, + " \n {}\n {}\n {}\n ", + skill.name, + skill.description, + location.display() + ); + } + prompt.push_str(""); + Ok(prompt) + } +} + +impl PromptSection for WorkspaceSection { + fn name(&self) -> &str { + "workspace" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + Ok(format!( + "## Workspace\n\nWorking directory: `{}`", + ctx.workspace_dir.display() + )) + } +} + +impl PromptSection for RuntimeSection { + fn name(&self) -> &str { + "runtime" + } + + fn build(&self, ctx: &PromptContext<'_>) -> Result { + let host = + hostname::get().map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string()); + Ok(format!( + "## Runtime\n\nHost: {host} | OS: {} | Model: {}", + std::env::consts::OS, + ctx.model_name + )) + } +} + +impl PromptSection for DateTimeSection { + fn name(&self) -> &str { + "datetime" + } + + fn build(&self, _ctx: &PromptContext<'_>) -> Result { + let now = Local::now(); + Ok(format!( + "## Current Date & Time\n\nTimezone: {}", + now.format("%Z") + )) + } +} + +fn inject_workspace_file(prompt: &mut String, workspace_dir: &Path, filename: &str) { + let path = workspace_dir.join(filename); + match std::fs::read_to_string(&path) { + Ok(content) => { + let trimmed = content.trim(); + if trimmed.is_empty() { + return; + } + let _ = writeln!(prompt, "### {filename}\n"); + let truncated = if trimmed.chars().count() > BOOTSTRAP_MAX_CHARS { + trimmed + .char_indices() + .nth(BOOTSTRAP_MAX_CHARS) + .map(|(idx, _)| &trimmed[..idx]) + .unwrap_or(trimmed) + } else { + trimmed + }; + prompt.push_str(truncated); + if truncated.len() < trimmed.len() { + let _ = writeln!( + prompt, + "\n\n[... truncated at {BOOTSTRAP_MAX_CHARS} chars โ€” use `read` for full file]\n" + ); + } else { + prompt.push_str("\n\n"); + } + } + Err(_) => { + let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n"); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tools::traits::Tool; + use async_trait::async_trait; + + struct TestTool; + + #[async_trait] + impl Tool for TestTool { + fn name(&self) -> &str { + "test_tool" + } + + fn description(&self) -> &str { + "tool desc" + } + + fn parameters_schema(&self) -> serde_json::Value { + serde_json::json!({"type": "object"}) + } + + async fn execute( + &self, + _args: serde_json::Value, + ) -> anyhow::Result { + Ok(crate::tools::ToolResult { + success: true, + output: "ok".into(), + error: None, + }) + } + } + + #[test] + fn prompt_builder_assembles_sections() { + let tools: Vec> = vec![Box::new(TestTool)]; + let ctx = PromptContext { + workspace_dir: Path::new("/tmp"), + model_name: "test-model", + tools: &tools, + skills: &[], + identity_config: None, + dispatcher_instructions: "instr", + }; + let prompt = SystemPromptBuilder::with_defaults().build(&ctx).unwrap(); + assert!(prompt.contains("## Tools")); + assert!(prompt.contains("test_tool")); + assert!(prompt.contains("instr")); + } +} diff --git a/src/channels/mod.rs b/src/channels/mod.rs index a3d8281..3c96f19 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -765,18 +765,16 @@ pub async fn start_channels(config: Config) -> Result<()> { &config.autonomy, &config.workspace_dir, )); - let model = config .default_model .clone() - .unwrap_or_else(|| "anthropic/claude-sonnet-4".into()); + .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); let temperature = config.default_temperature; let mem: Arc = Arc::from(memory::create_memory( &config.memory, &config.workspace_dir, config.api_key.as_deref(), )?); - let (composio_key, composio_entity_id) = if config.composio.enabled { ( config.composio.api_key.as_deref(), @@ -785,6 +783,8 @@ pub async fn start_channels(config: Config) -> Result<()> { } else { (None, None) }; + // Build system prompt from workspace identity files + skills + let workspace = config.workspace_dir.clone(); let tools_registry = Arc::new(tools::all_tools_with_runtime( &security, runtime, @@ -793,14 +793,12 @@ pub async fn start_channels(config: Config) -> Result<()> { composio_entity_id, &config.browser, &config.http_request, - &config.workspace_dir, + &workspace, &config.agents, config.api_key.as_deref(), &config, )); - // Build system prompt from workspace identity files + skills - let workspace = config.workspace_dir.clone(); let skills = crate::skills::load_skills(&workspace); // Collect tool descriptions for the prompt @@ -1112,23 +1110,19 @@ mod tests { message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { tokio::time::sleep(self.delay).await; - Ok(ChatResponse::with_text(format!("echo: {message}"))) + Ok(format!("echo: {message}")) } } struct ToolCallingProvider; - fn tool_call_payload() -> ChatResponse { - ChatResponse { - text: Some(String::new()), - tool_calls: vec![ToolCall { - id: "call_1".into(), - name: "mock_price".into(), - arguments: r#"{"symbol":"BTC"}"#.into(), - }], - } + fn tool_call_payload() -> String { + r#" +{"name":"mock_price","arguments":{"symbol":"BTC"}} +"# + .to_string() } #[async_trait::async_trait] @@ -1139,7 +1133,7 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { Ok(tool_call_payload()) } @@ -1148,14 +1142,12 @@ mod tests { messages: &[ChatMessage], _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let has_tool_results = messages .iter() .any(|msg| msg.role == "user" && msg.content.contains("[Tool results]")); if has_tool_results { - Ok(ChatResponse::with_text( - "BTC is currently around $65,000 based on latest tool output.", - )) + Ok("BTC is currently around $65,000 based on latest tool output.".to_string()) } else { Ok(tool_call_payload()) } diff --git a/src/config/schema.rs b/src/config/schema.rs index f615d13..5183b81 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -37,6 +37,9 @@ pub struct Config { #[serde(default)] pub scheduler: SchedulerConfig, + #[serde(default)] + pub agent: AgentConfig, + /// Model routing rules โ€” route `hint:` to specific provider+model combos. #[serde(default)] pub model_routes: Vec, @@ -209,6 +212,41 @@ impl Default for HardwareConfig { } } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct AgentConfig { + #[serde(default = "default_agent_max_tool_iterations")] + pub max_tool_iterations: usize, + #[serde(default = "default_agent_max_history_messages")] + pub max_history_messages: usize, + #[serde(default)] + pub parallel_tools: bool, + #[serde(default = "default_agent_tool_dispatcher")] + pub tool_dispatcher: String, +} + +fn default_agent_max_tool_iterations() -> usize { + 10 +} + +fn default_agent_max_history_messages() -> usize { + 50 +} + +fn default_agent_tool_dispatcher() -> String { + "auto".into() +} + +impl Default for AgentConfig { + fn default() -> Self { + Self { + max_tool_iterations: default_agent_max_tool_iterations(), + max_history_messages: default_agent_max_history_messages(), + parallel_tools: false, + tool_dispatcher: default_agent_tool_dispatcher(), + } + } +} + // โ”€โ”€ Identity (AIEOS / OpenClaw format) โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ #[derive(Debug, Clone, Serialize, Deserialize)] @@ -1507,6 +1545,7 @@ impl Default for Config { runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), + agent: AgentConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), @@ -1873,6 +1912,7 @@ mod tests { secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + agent: AgentConfig::default(), identity: IdentityConfig::default(), cost: CostConfig::default(), peripherals: PeripheralsConfig::default(), @@ -1922,6 +1962,32 @@ default_temperature = 0.7 assert_eq!(parsed.memory.conversation_retention_days, 30); } + #[test] + fn agent_config_defaults() { + let cfg = AgentConfig::default(); + assert_eq!(cfg.max_tool_iterations, 10); + assert_eq!(cfg.max_history_messages, 50); + assert!(!cfg.parallel_tools); + assert_eq!(cfg.tool_dispatcher, "auto"); + } + + #[test] + fn agent_config_deserializes() { + let raw = r#" +default_temperature = 0.7 +[agent] +max_tool_iterations = 20 +max_history_messages = 80 +parallel_tools = true +tool_dispatcher = "xml" +"#; + let parsed: Config = toml::from_str(raw).unwrap(); + assert_eq!(parsed.agent.max_tool_iterations, 20); + assert_eq!(parsed.agent.max_history_messages, 80); + assert!(parsed.agent.parallel_tools); + assert_eq!(parsed.agent.tool_dispatcher, "xml"); + } + #[test] fn config_save_and_load_tmpdir() { let dir = std::env::temp_dir().join("zeroclaw_test_config"); @@ -1951,6 +2017,7 @@ default_temperature = 0.7 secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + agent: AgentConfig::default(), identity: IdentityConfig::default(), cost: CostConfig::default(), peripherals: PeripheralsConfig::default(), diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index f9f5b6e..580fe4b 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -10,14 +10,8 @@ use crate::channels::{Channel, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; -use crate::observability::{self, Observer}; -use crate::providers::{self, ChatMessage, Provider}; -use crate::runtime; -use crate::security::{ - pairing::{constant_time_eq, is_public_bind, PairingGuard}, - SecurityPolicy, -}; -use crate::tools::{self, Tool}; +use crate::providers::{self, Provider}; +use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; use crate::util::truncate_with_ellipsis; use anyhow::Result; use axum::{ @@ -51,35 +45,6 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String format!("whatsapp_{}_{}", msg.sender, msg.id) } -fn normalize_gateway_reply(reply: String) -> String { - if reply.trim().is_empty() { - return "Model returned an empty response.".to_string(); - } - - reply -} - -async fn gateway_agent_reply(state: &AppState, message: &str) -> Result { - let mut history = vec![ - ChatMessage::system(state.system_prompt.as_str()), - ChatMessage::user(message), - ]; - - let reply = crate::agent::loop_::run_tool_call_loop( - state.provider.as_ref(), - &mut history, - state.tools_registry.as_ref(), - state.observer.as_ref(), - "gateway", - &state.model, - state.temperature, - true, // silent โ€” gateway responses go over HTTP - ) - .await?; - - Ok(normalize_gateway_reply(reply)) -} - /// How often the rate limiter sweeps stale IP entries from its map. const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes @@ -207,9 +172,6 @@ fn client_key_from_headers(headers: &HeaderMap) -> String { #[derive(Clone)] pub struct AppState { pub provider: Arc, - pub observer: Arc, - pub tools_registry: Arc>>, - pub system_prompt: Arc, pub model: String, pub temperature: f64, pub mem: Arc, @@ -256,55 +218,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { &config.workspace_dir, config.api_key.as_deref(), )?); - let observer: Arc = - Arc::from(observability::create_observer(&config.observability)); - let runtime: Arc = - Arc::from(runtime::create_runtime(&config.runtime)?); - let security = Arc::new(SecurityPolicy::from_config( - &config.autonomy, - &config.workspace_dir, - )); - let (composio_key, composio_entity_id) = if config.composio.enabled { - ( - config.composio.api_key.as_deref(), - Some(config.composio.entity_id.as_str()), - ) - } else { - (None, None) - }; - - let tools_registry = Arc::new(tools::all_tools_with_runtime( - &security, - runtime, - Arc::clone(&mem), - composio_key, - composio_entity_id, - &config.browser, - &config.http_request, - &config.workspace_dir, - &config.agents, - config.api_key.as_deref(), - &config, - )); - let skills = crate::skills::load_skills(&config.workspace_dir); - let tool_descs: Vec<(&str, &str)> = tools_registry - .iter() - .map(|tool| (tool.name(), tool.description())) - .collect(); - - let mut system_prompt = crate::channels::build_system_prompt( - &config.workspace_dir, - &model, - &tool_descs, - &skills, - Some(&config.identity), - None, // bootstrap_max_chars โ€” no compact context for gateway - ); - system_prompt.push_str(&crate::agent::loop_::build_tool_instructions( - tools_registry.as_ref(), - )); - let system_prompt = Arc::new(system_prompt); // Extract webhook secret for authentication let webhook_secret: Option> = config @@ -408,9 +322,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { // Build shared state let state = AppState { provider, - observer, - tools_registry, - system_prompt, model, temperature, mem, @@ -594,9 +505,13 @@ async fn handle_webhook( .await; } - match gateway_agent_reply(&state, message).await { - Ok(reply) => { - let body = serde_json::json!({"response": reply, "model": state.model}); + match state + .provider + .simple_chat(message, &state.model, state.temperature) + .await + { + Ok(response) => { + let body = serde_json::json!({"response": response, "model": state.model}); (StatusCode::OK, Json(body)) } Err(e) => { @@ -744,10 +659,14 @@ async fn handle_whatsapp_message( } // Call the LLM - match gateway_agent_reply(&state, &msg.content).await { - Ok(reply) => { + match state + .provider + .simple_chat(&msg.content, &state.model, state.temperature) + .await + { + Ok(response) => { // Send reply via WhatsApp - if let Err(e) = wa.send(&reply, &msg.sender).await { + if let Err(e) = wa.send(&response, &msg.sender).await { tracing::error!("Failed to send WhatsApp reply: {e}"); } } @@ -966,9 +885,9 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); - Ok(crate::providers::ChatResponse::with_text("ok")) + Ok("ok".into()) } } @@ -1029,36 +948,25 @@ mod tests { } } - fn test_app_state( - provider: Arc, - memory: Arc, - auto_save: bool, - ) -> AppState { - AppState { - provider, - observer: Arc::new(crate::observability::NoopObserver), - tools_registry: Arc::new(Vec::new()), - system_prompt: Arc::new("test-system-prompt".into()), - model: "test-model".into(), - temperature: 0.0, - mem: memory, - auto_save, - webhook_secret: None, - pairing: Arc::new(PairingGuard::new(false, &[])), - rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), - idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), - whatsapp: None, - whatsapp_app_secret: None, - } - } - #[tokio::test] async fn webhook_idempotency_skips_duplicate_provider_calls() { let provider_impl = Arc::new(MockProvider::default()); let provider: Arc = provider_impl.clone(); let memory: Arc = Arc::new(MockMemory); - let state = test_app_state(provider, memory, false); + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret: None, + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; let mut headers = HeaderMap::new(); headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123")); @@ -1094,7 +1002,19 @@ mod tests { let tracking_impl = Arc::new(TrackingMemory::default()); let memory: Arc = tracking_impl.clone(); - let state = test_app_state(provider, memory, true); + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: true, + webhook_secret: None, + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; let headers = HeaderMap::new(); @@ -1126,110 +1046,6 @@ mod tests { assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); } - #[derive(Default)] - struct StructuredToolCallProvider { - calls: AtomicUsize, - } - - #[async_trait] - impl Provider for StructuredToolCallProvider { - async fn chat_with_system( - &self, - _system_prompt: Option<&str>, - _message: &str, - _model: &str, - _temperature: f64, - ) -> anyhow::Result { - let turn = self.calls.fetch_add(1, Ordering::SeqCst); - - if turn == 0 { - return Ok(crate::providers::ChatResponse { - text: Some("Running tool...".into()), - tool_calls: vec![crate::providers::ToolCall { - id: "call_1".into(), - name: "mock_tool".into(), - arguments: r#"{"query":"gateway"}"#.into(), - }], - }); - } - - Ok(crate::providers::ChatResponse::with_text( - "Gateway tool result ready.", - )) - } - } - - struct MockTool { - calls: Arc, - } - - #[async_trait] - impl Tool for MockTool { - fn name(&self) -> &str { - "mock_tool" - } - - fn description(&self) -> &str { - "Mock tool for gateway tests" - } - - fn parameters_schema(&self) -> serde_json::Value { - serde_json::json!({ - "type": "object", - "properties": { - "query": {"type": "string"} - }, - "required": ["query"] - }) - } - - async fn execute( - &self, - args: serde_json::Value, - ) -> anyhow::Result { - self.calls.fetch_add(1, Ordering::SeqCst); - assert_eq!(args["query"], "gateway"); - - Ok(crate::tools::ToolResult { - success: true, - output: "ok".into(), - error: None, - }) - } - } - - #[tokio::test] - async fn webhook_executes_structured_tool_calls() { - let provider_impl = Arc::new(StructuredToolCallProvider::default()); - let provider: Arc = provider_impl.clone(); - let memory: Arc = Arc::new(MockMemory); - - let tool_calls = Arc::new(AtomicUsize::new(0)); - let tools: Vec> = vec![Box::new(MockTool { - calls: Arc::clone(&tool_calls), - })]; - - let mut state = test_app_state(provider, memory, false); - state.tools_registry = Arc::new(tools); - - let response = handle_webhook( - State(state), - HeaderMap::new(), - Ok(Json(WebhookBody { - message: "please use tool".into(), - })), - ) - .await - .into_response(); - - assert_eq!(response.status(), StatusCode::OK); - let payload = response.into_body().collect().await.unwrap().to_bytes(); - let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap(); - assert_eq!(parsed["response"], "Gateway tool result ready."); - assert_eq!(tool_calls.load(Ordering::SeqCst), 1); - assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); - } - // โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ• // WhatsApp Signature Verification Tests (CWE-345 Prevention) // โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ• diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 13ed3a8..2deee91 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -114,6 +114,7 @@ pub fn run_wizard() -> Result { runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), scheduler: crate::config::schema::SchedulerConfig::default(), + agent: crate::config::schema::AgentConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config, @@ -318,6 +319,7 @@ pub fn run_quick_setup( runtime: RuntimeConfig::default(), reliability: crate::config::ReliabilityConfig::default(), scheduler: crate::config::schema::SchedulerConfig::default(), + agent: crate::config::schema::AgentConfig::default(), model_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index c3c7870..56efeb8 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -1,4 +1,8 @@ -use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider}; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -26,13 +30,76 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ApiChatResponse { +struct ChatResponse { content: Vec, } #[derive(Debug, Deserialize)] struct ContentBlock { - text: String, + #[serde(rename = "type")] + kind: String, + #[serde(default)] + text: Option, +} + +#[derive(Debug, Serialize)] +struct NativeChatRequest { + model: String, + max_tokens: u32, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + content: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(tag = "type")] +enum NativeContentOut { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { tool_use_id: String, content: String }, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + name: String, + description: String, + input_schema: serde_json::Value, +} + +#[derive(Debug, Deserialize)] +struct NativeChatResponse { + #[serde(default)] + content: Vec, +} + +#[derive(Debug, Deserialize)] +struct NativeContentIn { + #[serde(rename = "type")] + kind: String, + #[serde(default)] + text: Option, + #[serde(default)] + id: Option, + #[serde(default)] + name: Option, + #[serde(default)] + input: Option, } impl AnthropicProvider { @@ -62,6 +129,186 @@ impl AnthropicProvider { fn is_setup_token(token: &str) -> bool { token.starts_with("sk-ant-oat01-") } + + fn apply_auth( + &self, + request: reqwest::RequestBuilder, + credential: &str, + ) -> reqwest::RequestBuilder { + if Self::is_setup_token(credential) { + request.header("Authorization", format!("Bearer {credential}")) + } else { + request.header("x-api-key", credential) + } + } + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + let items = tools?; + if items.is_empty() { + return None; + } + Some( + items + .iter() + .map(|tool| NativeToolSpec { + name: tool.name.clone(), + description: tool.description.clone(), + input_schema: tool.parameters.clone(), + }) + .collect(), + ) + } + + fn parse_assistant_tool_call_message(content: &str) -> Option> { + let value = serde_json::from_str::(content).ok()?; + let tool_calls = value + .get("tool_calls") + .and_then(|v| serde_json::from_value::>(v.clone()).ok())?; + + let mut blocks = Vec::new(); + if let Some(text) = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(str::trim) + .filter(|t| !t.is_empty()) + { + blocks.push(NativeContentOut::Text { + text: text.to_string(), + }); + } + for call in tool_calls { + let input = serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())); + blocks.push(NativeContentOut::ToolUse { + id: call.id, + name: call.name, + input, + }); + } + Some(blocks) + } + + fn parse_tool_result_message(content: &str) -> Option { + let value = serde_json::from_str::(content).ok()?; + let tool_use_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str)? + .to_string(); + let result = value + .get("content") + .and_then(serde_json::Value::as_str) + .unwrap_or("") + .to_string(); + Some(NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::ToolResult { + tool_use_id, + content: result, + }], + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> (Option, Vec) { + let mut system_prompt = None; + let mut native_messages = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if system_prompt.is_none() { + system_prompt = Some(msg.content.clone()); + } + } + "assistant" => { + if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) { + native_messages.push(NativeMessage { + role: "assistant".to_string(), + content: blocks, + }); + } else { + native_messages.push(NativeMessage { + role: "assistant".to_string(), + content: vec![NativeContentOut::Text { + text: msg.content.clone(), + }], + }); + } + } + "tool" => { + if let Some(tool_result) = Self::parse_tool_result_message(&msg.content) { + native_messages.push(tool_result); + } else { + native_messages.push(NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::Text { + text: msg.content.clone(), + }], + }); + } + } + _ => { + native_messages.push(NativeMessage { + role: "user".to_string(), + content: vec![NativeContentOut::Text { + text: msg.content.clone(), + }], + }); + } + } + } + + (system_prompt, native_messages) + } + + fn parse_text_response(response: ChatResponse) -> anyhow::Result { + response + .content + .into_iter() + .find(|c| c.kind == "text") + .and_then(|c| c.text) + .ok_or_else(|| anyhow::anyhow!("No response from Anthropic")) + } + + fn parse_native_response(response: NativeChatResponse) -> ProviderChatResponse { + let mut text_parts = Vec::new(); + let mut tool_calls = Vec::new(); + + for block in response.content { + match block.kind.as_str() { + "text" => { + if let Some(text) = block.text.map(|t| t.trim().to_string()) { + if !text.is_empty() { + text_parts.push(text); + } + } + } + "tool_use" => { + let name = block.name.unwrap_or_default(); + if name.is_empty() { + continue; + } + let arguments = block + .input + .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new())); + tool_calls.push(ProviderToolCall { + id: block.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name, + arguments: arguments.to_string(), + }); + } + _ => {} + } + } + + ProviderChatResponse { + text: if text_parts.is_empty() { + None + } else { + Some(text_parts.join("\n")) + }, + tool_calls, + } + } } #[async_trait] @@ -72,7 +319,7 @@ impl Provider for AnthropicProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." @@ -97,11 +344,7 @@ impl Provider for AnthropicProvider { .header("content-type", "application/json") .json(&request); - if Self::is_setup_token(credential) { - request = request.header("Authorization", format!("Bearer {credential}")); - } else { - request = request.header("x-api-key", credential); - } + request = self.apply_auth(request, credential); let response = request.send().await?; @@ -109,14 +352,50 @@ impl Provider for AnthropicProvider { return Err(super::api_error("Anthropic", response).await); } - let chat_response: ApiChatResponse = response.json().await?; + let chat_response: ChatResponse = response.json().await?; + Self::parse_text_response(chat_response) + } - chat_response - .content - .into_iter() - .next() - .map(|c| ProviderChatResponse::with_text(c.text)) - .ok_or_else(|| anyhow::anyhow!("No response from Anthropic")) + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)." + ) + })?; + + let (system_prompt, messages) = Self::convert_messages(request.messages); + let native_request = NativeChatRequest { + model: model.to_string(), + max_tokens: 4096, + system: system_prompt, + messages, + temperature, + tools: Self::convert_tools(request.tools), + }; + + let req = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("anthropic-version", "2023-06-01") + .header("content-type", "application/json") + .json(&native_request); + + let response = self.apply_auth(req, credential).send().await?; + if !response.status().is_success() { + return Err(super::api_error("Anthropic", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + Ok(Self::parse_native_response(native_response)) + } + + fn supports_native_tools(&self) -> bool { + true } } @@ -241,15 +520,16 @@ mod tests { #[test] fn chat_response_deserializes() { let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 1); - assert_eq!(resp.content[0].text, "Hello there!"); + assert_eq!(resp.content[0].kind, "text"); + assert_eq!(resp.content[0].text.as_deref(), Some("Hello there!")); } #[test] fn chat_response_empty_content() { let json = r#"{"content":[]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.content.is_empty()); } @@ -257,10 +537,10 @@ mod tests { fn chat_response_multiple_blocks() { let json = r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.content.len(), 2); - assert_eq!(resp.content[0].text, "First"); - assert_eq!(resp.content[1].text, "Second"); + assert_eq!(resp.content[0].text.as_deref(), Some("First")); + assert_eq!(resp.content[1].text.as_deref(), Some("Second")); } #[test] diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index e9e39e1..a9942f0 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -2,7 +2,10 @@ //! Most LLM APIs follow the same `/v1/chat/completions` format. //! This module provides a single implementation that works for all of them. -use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall}; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -163,12 +166,11 @@ struct ResponseMessage { #[serde(default)] content: Option, #[serde(default)] - tool_calls: Option>, + tool_calls: Option>, } #[derive(Debug, Deserialize, Serialize)] -struct ApiToolCall { - id: Option, +struct ToolCall { #[serde(rename = "type")] kind: Option, function: Option, @@ -254,44 +256,6 @@ fn extract_responses_text(response: ResponsesResponse) -> Option { None } -fn map_response_message(message: ResponseMessage) -> ChatResponse { - let text = first_nonempty(message.content.as_deref()); - let tool_calls = message - .tool_calls - .unwrap_or_default() - .into_iter() - .enumerate() - .filter_map(|(index, call)| map_api_tool_call(call, index)) - .collect(); - - ChatResponse { text, tool_calls } -} - -fn map_api_tool_call(call: ApiToolCall, index: usize) -> Option { - if call.kind.as_deref().is_some_and(|kind| kind != "function") { - return None; - } - - let function = call.function?; - let name = function - .name - .and_then(|value| first_nonempty(Some(value.as_str())))?; - let arguments = function - .arguments - .and_then(|value| first_nonempty(Some(value.as_str()))) - .unwrap_or_else(|| "{}".to_string()); - let id = call - .id - .and_then(|value| first_nonempty(Some(value.as_str()))) - .unwrap_or_else(|| format!("call_{}", index + 1)); - - Some(ToolCall { - id, - name, - arguments, - }) -} - impl OpenAiCompatibleProvider { fn apply_auth_header( &self, @@ -311,7 +275,7 @@ impl OpenAiCompatibleProvider { system_prompt: Option<&str>, message: &str, model: &str, - ) -> anyhow::Result { + ) -> anyhow::Result { let request = ResponsesRequest { model: model.to_string(), input: vec![ResponsesInput { @@ -337,7 +301,6 @@ impl OpenAiCompatibleProvider { let responses: ResponsesResponse = response.json().await?; extract_responses_text(responses) - .map(ChatResponse::with_text) .ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name)) } } @@ -350,7 +313,7 @@ impl Provider for OpenAiCompatibleProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", @@ -408,13 +371,27 @@ impl Provider for OpenAiCompatibleProvider { let chat_response: ApiChatResponse = response.json().await?; - let choice = chat_response + chat_response .choices .into_iter() .next() - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; - - Ok(map_response_message(choice.message)) + .map(|c| { + // If tool_calls are present, serialize the full message as JSON + // so parse_tool_calls can handle the OpenAI-style format + if c.message.tool_calls.is_some() + && c.message + .tool_calls + .as_ref() + .map_or(false, |t| !t.is_empty()) + { + serde_json::to_string(&c.message) + .unwrap_or_else(|_| c.message.content.unwrap_or_default()) + } else { + // No tool calls, return content as-is + c.message.content.unwrap_or_default() + } + }) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) } async fn chat_with_history( @@ -422,7 +399,7 @@ impl Provider for OpenAiCompatibleProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", @@ -482,13 +459,71 @@ impl Provider for OpenAiCompatibleProvider { let chat_response: ApiChatResponse = response.json().await?; - let choice = chat_response + chat_response .choices .into_iter() .next() - .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; + .map(|c| { + // If tool_calls are present, serialize the full message as JSON + // so parse_tool_calls can handle the OpenAI-style format + if c.message.tool_calls.is_some() + && c.message + .tool_calls + .as_ref() + .map_or(false, |t| !t.is_empty()) + { + serde_json::to_string(&c.message) + .unwrap_or_else(|_| c.message.content.unwrap_or_default()) + } else { + // No tool calls, return content as-is + c.message.content.unwrap_or_default() + } + }) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) + } - Ok(map_response_message(choice.message)) + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let text = self + .chat_with_history(request.messages, model, temperature) + .await?; + + // Backward compatible path: chat_with_history may serialize tool_calls JSON into content. + if let Ok(message) = serde_json::from_str::(&text) { + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .filter_map(|tc| { + let function = tc.function?; + let name = function.name?; + let arguments = function.arguments.unwrap_or_else(|| "{}".to_string()); + Some(ProviderToolCall { + id: uuid::Uuid::new_v4().to_string(), + name, + arguments, + }) + }) + .collect::>(); + + return Ok(ProviderChatResponse { + text: message.content, + tool_calls, + }); + } + + Ok(ProviderChatResponse { + text: Some(text), + tool_calls: vec![], + }) + } + + fn supports_native_tools(&self) -> bool { + true } } @@ -573,20 +608,6 @@ mod tests { assert!(resp.choices.is_empty()); } - #[test] - fn response_with_tool_calls_maps_structured_data() { - let json = r#"{"choices":[{"message":{"content":"Running checks","tool_calls":[{"id":"call_1","type":"function","function":{"name":"shell","arguments":"{\"command\":\"pwd\"}"}}]}}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - let choice = resp.choices.into_iter().next().unwrap(); - - let mapped = map_response_message(choice.message); - assert_eq!(mapped.text.as_deref(), Some("Running checks")); - assert_eq!(mapped.tool_calls.len(), 1); - assert_eq!(mapped.tool_calls[0].id, "call_1"); - assert_eq!(mapped.tool_calls[0].name, "shell"); - assert_eq!(mapped.tool_calls[0].arguments, r#"{"command":"pwd"}"#); - } - #[test] fn x_api_key_auth_style() { let p = OpenAiCompatibleProvider::new( diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 189daf0..a988224 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -3,7 +3,7 @@ //! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication) //! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`) -use crate::providers::traits::{ChatResponse, Provider}; +use crate::providers::traits::Provider; use async_trait::async_trait; use directories::UserDirs; use reqwest::Client; @@ -260,7 +260,7 @@ impl Provider for GeminiProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let auth = self.auth.as_ref().ok_or_else(|| { anyhow::anyhow!( "Gemini API key not found. Options:\n\ @@ -319,7 +319,6 @@ impl Provider for GeminiProvider { .and_then(|c| c.into_iter().next()) .and_then(|c| c.content.parts.into_iter().next()) .and_then(|p| p.text) - .map(ChatResponse::with_text) .ok_or_else(|| anyhow::anyhow!("No response from Gemini")) } } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 713afe4..1ddaddc 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -9,7 +9,10 @@ pub mod router; pub mod traits; #[allow(unused_imports)] -pub use traits::{ChatMessage, ChatResponse, Provider, ToolCall}; +pub use traits::{ + ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ToolCall, + ToolResultMessage, +}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; use reliable::ReliableProvider; diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index 481d0bf..8ecfb5a 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::{ChatResponse as ProviderChatResponse, Provider}; +use crate::providers::traits::Provider; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -61,7 +61,7 @@ impl Provider for OllamaProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let mut messages = Vec::new(); if let Some(sys) = system_prompt { @@ -93,9 +93,7 @@ impl Provider for OllamaProvider { } let chat_response: ApiChatResponse = response.json().await?; - Ok(ProviderChatResponse::with_text( - chat_response.message.content, - )) + Ok(chat_response.message.content) } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 6b8bbe5..ef67678 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -1,4 +1,8 @@ -use crate::providers::traits::{ChatResponse, Provider}; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -22,7 +26,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ApiChatResponse { +struct ChatResponse { choices: Vec, } @@ -36,6 +40,75 @@ struct ResponseMessage { content: String, } +#[derive(Debug, Serialize)] +struct NativeChatRequest { + model: String, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct NativeChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct NativeChoice { + message: NativeResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct NativeResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + impl OpenAiProvider { pub fn new(api_key: Option<&str>) -> Self { Self { @@ -47,6 +120,107 @@ impl OpenAiProvider { .unwrap_or_else(|_| Client::new()), } } + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + tools.map(|items| { + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect() + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|m| { + if m.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&m.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>( + tool_calls_value.clone(), + ) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| NativeToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tc.name, + arguments: tc.arguments, + }, + }) + .collect::>(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; + } + } + } + } + + if m.role == "tool" { + if let Ok(value) = serde_json::from_str::(&m.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; + } + } + + NativeMessage { + role: m.role.clone(), + content: Some(m.content.clone()), + tool_call_id: None, + tool_calls: None, + } + }) + .collect() + } + + fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse { + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| ProviderToolCall { + id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tc.function.name, + arguments: tc.function.arguments, + }) + .collect::>(); + + ProviderChatResponse { + text: message.content, + tool_calls, + } + } } #[async_trait] @@ -57,7 +231,7 @@ impl Provider for OpenAiProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref().ok_or_else(|| { anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") })?; @@ -94,15 +268,60 @@ impl Provider for OpenAiProvider { return Err(super::api_error("OpenAI", response).await); } - let chat_response: ApiChatResponse = response.json().await?; + let chat_response: ChatResponse = response.json().await?; chat_response .choices .into_iter() .next() - .map(|c| ChatResponse::with_text(c.message.content)) + .map(|c| c.message.content) .ok_or_else(|| anyhow::anyhow!("No response from OpenAI")) } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let api_key = self.api_key.as_ref().ok_or_else(|| { + anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") + })?; + + let tools = Self::convert_tools(request.tools); + let native_request = NativeChatRequest { + model: model.to_string(), + messages: Self::convert_messages(request.messages), + temperature, + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + tools, + }; + + let response = self + .client + .post("https://api.openai.com/v1/chat/completions") + .header("Authorization", format!("Bearer {api_key}")) + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenAI", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; + Ok(Self::parse_native_response(message)) + } + + fn supports_native_tools(&self) -> bool { + true + } } #[cfg(test)] @@ -184,7 +403,7 @@ mod tests { #[test] fn response_deserializes_single_choice() { let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices.len(), 1); assert_eq!(resp.choices[0].message.content, "Hi!"); } @@ -192,14 +411,14 @@ mod tests { #[test] fn response_deserializes_empty_choices() { let json = r#"{"choices":[]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.choices.is_empty()); } #[test] fn response_deserializes_multiple_choices() { let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices.len(), 2); assert_eq!(resp.choices[0].message.content, "A"); } @@ -207,7 +426,7 @@ mod tests { #[test] fn response_with_unicode() { let json = r#"{"choices":[{"message":{"content":"ใ“ใ‚“ใซใกใฏ ๐Ÿฆ€"}}]}"#; - let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let resp: ChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices[0].message.content, "ใ“ใ‚“ใซใกใฏ ๐Ÿฆ€"); } @@ -215,7 +434,7 @@ mod tests { fn response_with_long_content() { let long = "x".repeat(100_000); let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#); - let resp: ApiChatResponse = serde_json::from_str(&json).unwrap(); + let resp: ChatResponse = serde_json::from_str(&json).unwrap(); assert_eq!(resp.choices[0].message.content.len(), 100_000); } } diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index 287dd88..5363651 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -1,4 +1,8 @@ -use crate::providers::traits::{ChatMessage, ChatResponse, Provider}; +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -36,6 +40,75 @@ struct ResponseMessage { content: String, } +#[derive(Debug, Serialize)] +struct NativeChatRequest { + model: String, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct NativeMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct NativeChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct NativeChoice { + message: NativeResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct NativeResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + impl OpenRouterProvider { pub fn new(api_key: Option<&str>) -> Self { Self { @@ -47,6 +120,111 @@ impl OpenRouterProvider { .unwrap_or_else(|_| Client::new()), } } + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + let items = tools?; + if items.is_empty() { + return None; + } + Some( + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect(), + ) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|m| { + if m.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&m.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>( + tool_calls_value.clone(), + ) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tc| NativeToolCall { + id: Some(tc.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tc.name, + arguments: tc.arguments, + }, + }) + .collect::>(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; + } + } + } + } + + if m.role == "tool" { + if let Ok(value) = serde_json::from_str::(&m.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return NativeMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; + } + } + + NativeMessage { + role: m.role.clone(), + content: Some(m.content.clone()), + tool_call_id: None, + tool_calls: None, + } + }) + .collect() + } + + fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse { + let tool_calls = message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| ProviderToolCall { + id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tc.function.name, + arguments: tc.function.arguments, + }) + .collect::>(); + + ProviderChatResponse { + text: message.content, + tool_calls, + } + } } #[async_trait] @@ -71,7 +249,7 @@ impl Provider for OpenRouterProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; @@ -118,7 +296,7 @@ impl Provider for OpenRouterProvider { .choices .into_iter() .next() - .map(|c| ChatResponse::with_text(c.message.content)) + .map(|c| c.message.content) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } @@ -127,7 +305,7 @@ impl Provider for OpenRouterProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let api_key = self.api_key.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; @@ -168,9 +346,59 @@ impl Provider for OpenRouterProvider { .choices .into_iter() .next() - .map(|c| ChatResponse::with_text(c.message.content)) + .map(|c| c.message.content) .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let api_key = self.api_key.as_ref().ok_or_else(|| anyhow::anyhow!( + "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." + ))?; + + let tools = Self::convert_tools(request.tools); + let native_request = NativeChatRequest { + model: model.to_string(), + messages: Self::convert_messages(request.messages), + temperature, + tool_choice: tools.as_ref().map(|_| "auto".to_string()), + tools, + }; + + let response = self + .client + .post("https://openrouter.ai/api/v1/chat/completions") + .header("Authorization", format!("Bearer {api_key}")) + .header( + "HTTP-Referer", + "https://github.com/theonlyhennygod/zeroclaw", + ) + .header("X-Title", "ZeroClaw") + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenRouter", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; + Ok(Self::parse_native_response(message)) + } + + fn supports_native_tools(&self) -> bool { + true + } } #[cfg(test)] diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 3494a41..9782ec4 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -1,4 +1,4 @@ -use super::traits::{ChatMessage, ChatResponse}; +use super::traits::ChatMessage; use super::Provider; use async_trait::async_trait; use std::collections::HashMap; @@ -156,7 +156,7 @@ impl Provider for ReliableProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); @@ -254,7 +254,7 @@ impl Provider for ReliableProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let models = self.model_chain(model); let mut failures = Vec::new(); @@ -359,12 +359,12 @@ mod tests { _message: &str, _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; if attempt <= self.fail_until_attempt { anyhow::bail!(self.error); } - Ok(ChatResponse::with_text(self.response)) + Ok(self.response.to_string()) } async fn chat_with_history( @@ -372,12 +372,12 @@ mod tests { _messages: &[ChatMessage], _model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; if attempt <= self.fail_until_attempt { anyhow::bail!(self.error); } - Ok(ChatResponse::with_text(self.response)) + Ok(self.response.to_string()) } } @@ -397,13 +397,13 @@ mod tests { _message: &str, model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); self.models_seen.lock().unwrap().push(model.to_string()); if self.fail_models.contains(&model) { anyhow::bail!("500 model {} unavailable", model); } - Ok(ChatResponse::with_text(self.response)) + Ok(self.response.to_string()) } } @@ -426,8 +426,8 @@ mod tests { 1, ); - let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result.text_or_empty(), "ok"); + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "ok"); assert_eq!(calls.load(Ordering::SeqCst), 1); } @@ -448,8 +448,8 @@ mod tests { 1, ); - let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result.text_or_empty(), "recovered"); + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "recovered"); assert_eq!(calls.load(Ordering::SeqCst), 2); } @@ -483,8 +483,8 @@ mod tests { 1, ); - let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result.text_or_empty(), "from fallback"); + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "from fallback"); assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } @@ -517,7 +517,7 @@ mod tests { ); let err = provider - .chat("hello", "test", 0.0) + .simple_chat("hello", "test", 0.0) .await .expect_err("all providers should fail"); let msg = err.to_string(); @@ -572,8 +572,8 @@ mod tests { 1, ); - let result = provider.chat("hello", "test", 0.0).await.unwrap(); - assert_eq!(result.text_or_empty(), "from fallback"); + let result = provider.simple_chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "from fallback"); // Primary should have been called only once (no retries) assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); @@ -601,7 +601,7 @@ mod tests { .chat_with_history(&messages, "test", 0.0) .await .unwrap(); - assert_eq!(result.text_or_empty(), "history ok"); + assert_eq!(result, "history ok"); assert_eq!(calls.load(Ordering::SeqCst), 2); } @@ -640,7 +640,7 @@ mod tests { .chat_with_history(&messages, "test", 0.0) .await .unwrap(); - assert_eq!(result.text_or_empty(), "fallback ok"); + assert_eq!(result, "fallback ok"); assert_eq!(primary_calls.load(Ordering::SeqCst), 2); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } @@ -827,7 +827,7 @@ mod tests { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.as_ref() .chat_with_system(system_prompt, message, model, temperature) .await diff --git a/src/providers/router.rs b/src/providers/router.rs index eb3101f..ccbdffb 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -1,4 +1,4 @@ -use super::traits::{ChatMessage, ChatResponse}; +use super::traits::{ChatMessage, ChatRequest, ChatResponse}; use super::Provider; use async_trait::async_trait; use std::collections::HashMap; @@ -98,7 +98,7 @@ impl Provider for RouterProvider { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let (provider_idx, resolved_model) = self.resolve(model); let (provider_name, provider) = &self.providers[provider_idx]; @@ -118,7 +118,7 @@ impl Provider for RouterProvider { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let (provider_idx, resolved_model) = self.resolve(model); let (_, provider) = &self.providers[provider_idx]; provider @@ -126,6 +126,24 @@ impl Provider for RouterProvider { .await } + async fn chat( + &self, + request: ChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (provider_idx, resolved_model) = self.resolve(model); + let (_, provider) = &self.providers[provider_idx]; + provider.chat(request, &resolved_model, temperature).await + } + + fn supports_native_tools(&self) -> bool { + self.providers + .get(self.default_index) + .map(|(_, p)| p.supports_native_tools()) + .unwrap_or(false) + } + async fn warmup(&self) -> anyhow::Result<()> { for (name, provider) in &self.providers { tracing::info!(provider = name, "Warming up routed provider"); @@ -175,10 +193,10 @@ mod tests { _message: &str, model: &str, _temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.calls.fetch_add(1, Ordering::SeqCst); *self.last_model.lock().unwrap() = model.to_string(); - Ok(ChatResponse::with_text(self.response)) + Ok(self.response.to_string()) } } @@ -229,7 +247,7 @@ mod tests { message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { self.as_ref() .chat_with_system(system_prompt, message, model, temperature) .await @@ -246,8 +264,11 @@ mod tests { ], ); - let result = router.chat("hello", "hint:reasoning", 0.5).await.unwrap(); - assert_eq!(result.text_or_empty(), "smart-response"); + let result = router + .simple_chat("hello", "hint:reasoning", 0.5) + .await + .unwrap(); + assert_eq!(result, "smart-response"); assert_eq!(mocks[1].call_count(), 1); assert_eq!(mocks[1].last_model(), "claude-opus"); assert_eq!(mocks[0].call_count(), 0); @@ -260,8 +281,8 @@ mod tests { vec![("fast", "fast", "llama-3-70b")], ); - let result = router.chat("hello", "hint:fast", 0.5).await.unwrap(); - assert_eq!(result.text_or_empty(), "fast-response"); + let result = router.simple_chat("hello", "hint:fast", 0.5).await.unwrap(); + assert_eq!(result, "fast-response"); assert_eq!(mocks[0].call_count(), 1); assert_eq!(mocks[0].last_model(), "llama-3-70b"); } @@ -273,8 +294,11 @@ mod tests { vec![], ); - let result = router.chat("hello", "hint:nonexistent", 0.5).await.unwrap(); - assert_eq!(result.text_or_empty(), "default-response"); + let result = router + .simple_chat("hello", "hint:nonexistent", 0.5) + .await + .unwrap(); + assert_eq!(result, "default-response"); assert_eq!(mocks[0].call_count(), 1); // Falls back to default with the hint as model name assert_eq!(mocks[0].last_model(), "hint:nonexistent"); @@ -291,10 +315,10 @@ mod tests { ); let result = router - .chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5) + .simple_chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5) .await .unwrap(); - assert_eq!(result.text_or_empty(), "primary-response"); + assert_eq!(result, "primary-response"); assert_eq!(mocks[0].call_count(), 1); assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514"); } @@ -355,7 +379,7 @@ mod tests { .chat_with_system(Some("system"), "hello", "model", 0.5) .await .unwrap(); - assert_eq!(result.text_or_empty(), "response"); + assert_eq!(result, "response"); assert_eq!(mock.call_count(), 1); } } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index d1f8dd1..fdbd5cc 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,3 +1,4 @@ +use crate::tools::ToolSpec; use async_trait::async_trait; use serde::{Deserialize, Serialize}; @@ -29,6 +30,13 @@ impl ChatMessage { content: content.into(), } } + + pub fn tool(content: impl Into) -> Self { + Self { + role: "tool".into(), + content: content.into(), + } + } } /// A tool call requested by the LLM. @@ -49,14 +57,6 @@ pub struct ChatResponse { } impl ChatResponse { - /// Convenience: construct a plain text response with no tool calls. - pub fn with_text(text: impl Into) -> Self { - Self { - text: Some(text.into()), - tool_calls: vec![], - } - } - /// True when the LLM wants to invoke at least one tool. pub fn has_tool_calls(&self) -> bool { !self.tool_calls.is_empty() @@ -68,6 +68,13 @@ impl ChatResponse { } } +/// Request payload for provider chat calls. +#[derive(Debug, Clone, Copy)] +pub struct ChatRequest<'a> { + pub messages: &'a [ChatMessage], + pub tools: Option<&'a [ToolSpec]>, +} + /// A tool result to feed back to the LLM. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ToolResultMessage { @@ -77,7 +84,7 @@ pub struct ToolResultMessage { /// A message in a multi-turn conversation, including tool interactions. #[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(tag = "type")] +#[serde(tag = "type", content = "data")] pub enum ConversationMessage { /// Regular chat message (system, user, assistant). Chat(ChatMessage), @@ -86,29 +93,34 @@ pub enum ConversationMessage { text: Option, tool_calls: Vec, }, - /// Result of a tool execution, fed back to the LLM. - ToolResult(ToolResultMessage), + /// Results of tool executions, fed back to the LLM. + ToolResults(Vec), } #[async_trait] pub trait Provider: Send + Sync { - async fn chat( + /// Simple one-shot chat (single user message, no explicit system prompt). + /// + /// This is the preferred API for non-agentic direct interactions. + async fn simple_chat( &self, message: &str, model: &str, temperature: f64, - ) -> anyhow::Result { - self.chat_with_system(None, message, model, temperature) - .await + ) -> anyhow::Result { + self.chat_with_system(None, message, model, temperature).await } + /// One-shot chat with optional system prompt. + /// + /// Kept for compatibility and advanced one-shot prompting. async fn chat_with_system( &self, system_prompt: Option<&str>, message: &str, model: &str, temperature: f64, - ) -> anyhow::Result; + ) -> anyhow::Result; /// Multi-turn conversation. Default implementation extracts the last user /// message and delegates to `chat_with_system`. @@ -117,7 +129,7 @@ pub trait Provider: Send + Sync { messages: &[ChatMessage], model: &str, temperature: f64, - ) -> anyhow::Result { + ) -> anyhow::Result { let system = messages .iter() .find(|m| m.role == "system") @@ -131,6 +143,27 @@ pub trait Provider: Send + Sync { .await } + /// Structured chat API for agent loop callers. + async fn chat( + &self, + request: ChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let text = self + .chat_with_history(request.messages, model, temperature) + .await?; + Ok(ChatResponse { + text: Some(text), + tool_calls: Vec::new(), + }) + } + + /// Whether provider supports native tool calls over API. + fn supports_native_tools(&self) -> bool { + false + } + /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup). /// Default implementation is a no-op; providers with HTTP clients should override. async fn warmup(&self) -> anyhow::Result<()> { @@ -153,6 +186,9 @@ mod tests { let asst = ChatMessage::assistant("Hi there"); assert_eq!(asst.role, "assistant"); + + let tool = ChatMessage::tool("{}"); + assert_eq!(tool.role, "tool"); } #[test] @@ -194,11 +230,11 @@ mod tests { let json = serde_json::to_string(&chat).unwrap(); assert!(json.contains("\"type\":\"Chat\"")); - let tool_result = ConversationMessage::ToolResult(ToolResultMessage { + let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage { tool_call_id: "1".into(), content: "done".into(), - }); + }]); let json = serde_json::to_string(&tool_result).unwrap(); - assert!(json.contains("\"type\":\"ToolResult\"")); + assert!(json.contains("\"type\":\"ToolResults\"")); } } diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index f205a58..7f30b64 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -221,14 +221,9 @@ impl Tool for DelegateTool { match result { Ok(response) => { - let has_tool_calls = response.has_tool_calls(); - let mut rendered = response.text.unwrap_or_default(); + let mut rendered = response; if rendered.trim().is_empty() { - if has_tool_calls { - rendered = "[Tool-only response; no text content]".to_string(); - } else { - rendered = "[Empty response]".to_string(); - } + rendered = "[Empty response]".to_string(); } Ok(ToolResult {