diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 54b88f4..991905b 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,16 +1,44 @@ use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; use crate::observability::{self, Observer, ObserverEvent}; -use crate::providers::{self, Provider}; +use crate::providers::{self, ChatMessage, Provider}; use crate::runtime; use crate::security::SecurityPolicy; -use crate::tools; +use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::Result; use std::fmt::Write; +use std::io::Write as IoWrite; use std::sync::Arc; use std::time::Instant; +/// Maximum agentic tool-use iterations per user message to prevent runaway loops. +const MAX_TOOL_ITERATIONS: usize = 10; + +/// Maximum number of non-system messages to keep in history. +/// When exceeded, the oldest messages are dropped (system prompt is always preserved). +const MAX_HISTORY_MESSAGES: usize = 50; + +/// Trim conversation history to prevent unbounded growth. +/// Preserves the system prompt (first message if role=system) and the most recent messages. +fn trim_history(history: &mut Vec) { + // Nothing to trim if within limit + let has_system = history.first().map_or(false, |m| m.role == "system"); + let non_system_count = if has_system { + history.len() - 1 + } else { + history.len() + }; + + if non_system_count <= MAX_HISTORY_MESSAGES { + return; + } + + let start = if has_system { 1 } else { 0 }; + let to_remove = non_system_count - MAX_HISTORY_MESSAGES; + history.drain(start..start + to_remove); +} + /// Build context preamble by searching memory for relevant entries async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { let mut context = String::new(); @@ -29,6 +57,178 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { context } +/// Find a tool by name in the registry. +fn find_tool<'a>(tools: &'a [Box], name: &str) -> Option<&'a dyn Tool> { + tools.iter().find(|t| t.name() == name).map(|t| t.as_ref()) +} + +/// Parse tool calls from an LLM response that uses XML-style function calling. +/// +/// Expected format (common with system-prompt-guided tool use): +/// ```text +/// +/// {"name": "shell", "arguments": {"command": "ls"}} +/// +/// ``` +/// +/// Also supports JSON with `tool_calls` array from OpenAI-format responses. +fn parse_tool_calls(response: &str) -> (String, Vec) { + let mut text_parts = Vec::new(); + let mut calls = Vec::new(); + let mut remaining = response; + + while let Some(start) = remaining.find("") { + // Everything before the tag is text + let before = &remaining[..start]; + if !before.trim().is_empty() { + text_parts.push(before.trim().to_string()); + } + + if let Some(end) = remaining[start..].find("") { + let inner = &remaining[start + 11..start + end]; + match serde_json::from_str::(inner.trim()) { + Ok(parsed) => { + let name = parsed + .get("name") + .and_then(|v| v.as_str()) + .unwrap_or("") + .to_string(); + let arguments = parsed + .get("arguments") + .cloned() + .unwrap_or(serde_json::Value::Object(serde_json::Map::new())); + calls.push(ParsedToolCall { name, arguments }); + } + Err(e) => { + tracing::warn!("Malformed JSON: {e}"); + } + } + remaining = &remaining[start + end + 12..]; + } else { + break; + } + } + + // Remaining text after last tool call + if !remaining.trim().is_empty() { + text_parts.push(remaining.trim().to_string()); + } + + (text_parts.join("\n"), calls) +} + +#[derive(Debug)] +struct ParsedToolCall { + name: String, + arguments: serde_json::Value, +} + +/// Execute a single turn of the agent loop: send messages, parse tool calls, +/// execute tools, and loop until the LLM produces a final text response. +async fn agent_turn( + provider: &dyn Provider, + history: &mut Vec, + tools_registry: &[Box], + observer: &dyn Observer, + model: &str, + temperature: f64, +) -> Result { + for _iteration in 0..MAX_TOOL_ITERATIONS { + let response = provider + .chat_with_history(history, model, temperature) + .await?; + + let (text, tool_calls) = parse_tool_calls(&response); + + if tool_calls.is_empty() { + // No tool calls — this is the final response + history.push(ChatMessage::assistant(&response)); + return Ok(if text.is_empty() { + response + } else { + text + }); + } + + // Print any text the LLM produced alongside tool calls + if !text.is_empty() { + print!("{text}"); + let _ = std::io::stdout().flush(); + } + + // Execute each tool call and build results + let mut tool_results = String::new(); + for call in &tool_calls { + let start = Instant::now(); + let result = if let Some(tool) = find_tool(tools_registry, &call.name) { + match tool.execute(call.arguments.clone()).await { + Ok(r) => { + observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: r.success, + }); + if r.success { + r.output + } else { + format!("Error: {}", r.error.unwrap_or_else(|| r.output)) + } + } + Err(e) => { + observer.record_event(&ObserverEvent::ToolCall { + tool: call.name.clone(), + duration: start.elapsed(), + success: false, + }); + format!("Error executing {}: {e}", call.name) + } + } + } else { + format!("Unknown tool: {}", call.name) + }; + + let _ = writeln!( + tool_results, + "\n{}\n", + call.name, result + ); + } + + // Add assistant message with tool calls + tool results to history + history.push(ChatMessage::assistant(&response)); + history.push(ChatMessage::user(format!( + "[Tool results]\n{tool_results}" + ))); + } + + anyhow::bail!("Agent exceeded maximum tool iterations ({MAX_TOOL_ITERATIONS})") +} + +/// Build the tool instruction block for the system prompt so the LLM knows +/// how to invoke tools. +fn build_tool_instructions(tools_registry: &[Box]) -> String { + let mut instructions = String::new(); + instructions.push_str("\n## Tool Use Protocol\n\n"); + instructions.push_str("To use a tool, wrap a JSON object in tags:\n\n"); + instructions.push_str("```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n"); + instructions.push_str("You may use multiple tool calls in a single response. "); + instructions.push_str("After tool execution, results appear in tags. "); + instructions.push_str("Continue reasoning with the results until you can give a final answer.\n\n"); + instructions.push_str("### Available Tools\n\n"); + + for tool in tools_registry { + let _ = writeln!( + instructions, + "**{}**: {}\nParameters: `{}`\n", + tool.name(), + tool.description(), + tool.parameters_schema() + ); + } + + instructions +} + #[allow(clippy::too_many_lines)] pub async fn run( config: Config, @@ -61,7 +261,7 @@ pub async fn run( } else { None }; - let _tools = tools::all_tools_with_runtime( + let tools_registry = tools::all_tools_with_runtime( &security, runtime, mem.clone(), @@ -133,7 +333,7 @@ pub async fn run( "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", )); } - let system_prompt = crate::channels::build_system_prompt( + let mut system_prompt = crate::channels::build_system_prompt( &config.workspace_dir, model_name, &tool_descs, @@ -141,6 +341,9 @@ pub async fn run( Some(&config.identity), ); + // Append structured tool-use instructions with schemas + system_prompt.push_str(&build_tool_instructions(&tools_registry)); + // ── Execute ────────────────────────────────────────────────── let start = Instant::now(); @@ -160,9 +363,20 @@ pub async fn run( format!("{context}{msg}") }; - let response = provider - .chat_with_system(Some(&system_prompt), &enriched, model_name, temperature) - .await?; + let mut history = vec![ + ChatMessage::system(&system_prompt), + ChatMessage::user(&enriched), + ]; + + let response = agent_turn( + provider.as_ref(), + &mut history, + &tools_registry, + observer.as_ref(), + model_name, + temperature, + ) + .await?; println!("{response}"); // Auto-save assistant response to daily log @@ -184,6 +398,9 @@ pub async fn run( let _ = crate::channels::Channel::listen(&cli, tx).await; }); + // Persistent conversation history across turns + let mut history = vec![ChatMessage::system(&system_prompt)]; + while let Some(msg) = rx.recv().await { // Auto-save conversation turns if config.memory.auto_save { @@ -200,11 +417,29 @@ pub async fn run( format!("{context}{}", msg.content) }; - let response = provider - .chat_with_system(Some(&system_prompt), &enriched, model_name, temperature) - .await?; + history.push(ChatMessage::user(&enriched)); + + let response = match agent_turn( + provider.as_ref(), + &mut history, + &tools_registry, + observer.as_ref(), + model_name, + temperature, + ) + .await + { + Ok(resp) => resp, + Err(e) => { + eprintln!("\nError: {e}\n"); + continue; + } + }; println!("\n{response}\n"); + // Prevent unbounded history growth in long interactive sessions + trim_history(&mut history); + if config.memory.auto_save { let summary = truncate_with_ellipsis(&response, 100); let _ = mem @@ -224,3 +459,126 @@ pub async fn run( Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_tool_calls_extracts_single_call() { + let response = r#"Let me check that. + +{"name": "shell", "arguments": {"command": "ls -la"}} +"#; + + let (text, calls) = parse_tool_calls(response); + assert_eq!(text, "Let me check that."); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "shell"); + assert_eq!( + calls[0].arguments.get("command").unwrap().as_str().unwrap(), + "ls -la" + ); + } + + #[test] + fn parse_tool_calls_extracts_multiple_calls() { + let response = r#" +{"name": "file_read", "arguments": {"path": "a.txt"}} + + +{"name": "file_read", "arguments": {"path": "b.txt"}} +"#; + + let (_, calls) = parse_tool_calls(response); + assert_eq!(calls.len(), 2); + assert_eq!(calls[0].name, "file_read"); + assert_eq!(calls[1].name, "file_read"); + } + + #[test] + fn parse_tool_calls_returns_text_only_when_no_calls() { + let response = "Just a normal response with no tools."; + let (text, calls) = parse_tool_calls(response); + assert_eq!(text, "Just a normal response with no tools."); + assert!(calls.is_empty()); + } + + #[test] + fn parse_tool_calls_handles_malformed_json() { + let response = r#" +not valid json + +Some text after."#; + + let (text, calls) = parse_tool_calls(response); + assert!(calls.is_empty()); + assert!(text.contains("Some text after.")); + } + + #[test] + fn parse_tool_calls_text_before_and_after() { + let response = r#"Before text. + +{"name": "shell", "arguments": {"command": "echo hi"}} + +After text."#; + + let (text, calls) = parse_tool_calls(response); + assert!(text.contains("Before text.")); + assert!(text.contains("After text.")); + assert_eq!(calls.len(), 1); + } + + #[test] + fn build_tool_instructions_includes_all_tools() { + use crate::security::SecurityPolicy; + let security = Arc::new(SecurityPolicy::from_config( + &crate::config::AutonomyConfig::default(), + std::path::Path::new("/tmp"), + )); + let tools = tools::default_tools(security); + let instructions = build_tool_instructions(&tools); + + assert!(instructions.contains("## Tool Use Protocol")); + assert!(instructions.contains("")); + assert!(instructions.contains("shell")); + assert!(instructions.contains("file_read")); + assert!(instructions.contains("file_write")); + } + + #[test] + fn trim_history_preserves_system_prompt() { + let mut history = vec![ChatMessage::system("system prompt")]; + for i in 0..MAX_HISTORY_MESSAGES + 20 { + history.push(ChatMessage::user(format!("msg {i}"))); + } + let original_len = history.len(); + assert!(original_len > MAX_HISTORY_MESSAGES + 1); + + trim_history(&mut history); + + // System prompt preserved + assert_eq!(history[0].role, "system"); + assert_eq!(history[0].content, "system prompt"); + // Trimmed to limit + assert_eq!(history.len(), MAX_HISTORY_MESSAGES + 1); // +1 for system + // Most recent messages preserved + let last = &history[history.len() - 1]; + assert_eq!( + last.content, + format!("msg {}", MAX_HISTORY_MESSAGES + 19) + ); + } + + #[test] + fn trim_history_noop_when_within_limit() { + let mut history = vec![ + ChatMessage::system("sys"), + ChatMessage::user("hello"), + ChatMessage::assistant("hi"), + ]; + trim_history(&mut history); + assert_eq!(history.len(), 3); + } +} diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 7c2eeec..5c1348c 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -2,7 +2,7 @@ //! Most LLM APIs follow the same `/v1/chat/completions` format. //! This module provides a single implementation that works for all of them. -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatMessage, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -81,7 +81,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { choices: Vec, } @@ -264,6 +264,7 @@ impl Provider for OpenAiCompatibleProvider { if !response.status().is_success() { let status = response.status(); let error = response.text().await?; + let sanitized = super::sanitize_api_error(&error); if status == reqwest::StatusCode::NOT_FOUND { return self @@ -271,16 +272,88 @@ impl Provider for OpenAiCompatibleProvider { .await .map_err(|responses_err| { anyhow::anyhow!( - "{} API error: {error} (chat completions unavailable; responses fallback failed: {responses_err})", + "{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})", self.name ) }); } - anyhow::bail!("{} API error: {error}", self.name); + anyhow::bail!("{} API error ({status}): {sanitized}", self.name); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| c.message.content) + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let api_key = self.api_key.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", + self.name + ) + })?; + + let api_messages: Vec = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let request = ChatRequest { + model: model.to_string(), + messages: api_messages, + temperature, + }; + + let url = self.chat_completions_url(); + let response = self + .apply_auth_header(self.client.post(&url).json(&request), api_key) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + + // Mirror chat_with_system: 404 may mean this provider uses the Responses API + if status == reqwest::StatusCode::NOT_FOUND { + // Extract system prompt and last user message for responses fallback + let system = messages.iter().find(|m| m.role == "system"); + let last_user = messages.iter().rfind(|m| m.role == "user"); + if let Some(user_msg) = last_user { + return self + .chat_via_responses( + api_key, + system.map(|m| m.content.as_str()), + &user_msg.content, + model, + ) + .await + .map_err(|responses_err| { + anyhow::anyhow!( + "{} API error (chat completions unavailable; responses fallback failed: {responses_err})", + self.name + ) + }); + } + } + + return Err(super::api_error(&self.name, response).await); + } + + let chat_response: ApiChatResponse = response.json().await?; chat_response .choices @@ -357,14 +430,14 @@ mod tests { #[test] fn response_deserializes() { let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert_eq!(resp.choices[0].message.content, "Hello from Venice!"); } #[test] fn response_empty_choices() { let json = r#"{"choices":[]}"#; - let resp: ChatResponse = serde_json::from_str(json).unwrap(); + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); assert!(resp.choices.is_empty()); } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 1ff85b7..db65d63 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -8,7 +8,7 @@ pub mod reliable; pub mod router; pub mod traits; -pub use traits::Provider; +pub use traits::{ChatMessage, Provider}; use compatible::{AuthStyle, OpenAiCompatibleProvider}; use reliable::ReliableProvider; diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index a760eaf..51aefcc 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -1,4 +1,4 @@ -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatMessage, Provider}; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; @@ -22,7 +22,7 @@ struct Message { } #[derive(Debug, Deserialize)] -struct ChatResponse { +struct ApiChatResponse { choices: Vec, } @@ -112,7 +112,57 @@ impl Provider for OpenRouterProvider { return Err(super::api_error("OpenRouter", response).await); } - let chat_response: ChatResponse = response.json().await?; + let chat_response: ApiChatResponse = response.json().await?; + + chat_response + .choices + .into_iter() + .next() + .map(|c| c.message.content) + .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter")) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let api_key = self.api_key.as_ref() + .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; + + let api_messages: Vec = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let request = ChatRequest { + model: model.to_string(), + messages: api_messages, + temperature, + }; + + let response = self + .client + .post("https://openrouter.ai/api/v1/chat/completions") + .header("Authorization", format!("Bearer {api_key}")) + .header( + "HTTP-Referer", + "https://github.com/theonlyhennygod/zeroclaw", + ) + .header("X-Title", "ZeroClaw") + .json(&request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenRouter", response).await); + } + + let chat_response: ApiChatResponse = response.json().await?; chat_response .choices diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 921eeef..2b3cd96 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -1,3 +1,4 @@ +use super::traits::ChatMessage; use super::Provider; use async_trait::async_trait; use std::time::Duration; @@ -121,6 +122,68 @@ impl Provider for ReliableProvider { anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n")) } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let mut failures = Vec::new(); + + for (provider_name, provider) in &self.providers { + let mut backoff_ms = self.base_backoff_ms; + + for attempt in 0..=self.max_retries { + match provider + .chat_with_history(messages, model, temperature) + .await + { + Ok(resp) => { + if attempt > 0 { + tracing::info!( + provider = provider_name, + attempt, + "Provider recovered after retries" + ); + } + return Ok(resp); + } + Err(e) => { + let non_retryable = is_non_retryable(&e); + failures.push(format!( + "{provider_name} attempt {}/{}: {e}", + attempt + 1, + self.max_retries + 1 + )); + + if non_retryable { + tracing::warn!( + provider = provider_name, + "Non-retryable error, switching provider" + ); + break; + } + + if attempt < self.max_retries { + tracing::warn!( + provider = provider_name, + attempt = attempt + 1, + max_retries = self.max_retries, + "Provider call failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + } + } + } + } + + tracing::warn!(provider = provider_name, "Switching to fallback provider"); + } + + anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n")) + } } #[cfg(test)] @@ -151,6 +214,19 @@ mod tests { } Ok(self.response.to_string()) } + + async fn chat_with_history( + &self, + _messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; + if attempt <= self.fail_until_attempt { + anyhow::bail!(self.error); + } + Ok(self.response.to_string()) + } } #[tokio::test] @@ -330,4 +406,73 @@ mod tests { assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); } + + #[tokio::test] + async fn chat_with_history_retries_then_recovers() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 1, + response: "history ok", + error: "temporary", + }), + )], + 2, + 1, + ); + + let messages = vec![ + ChatMessage::system("system"), + ChatMessage::user("hello"), + ]; + let result = provider + .chat_with_history(&messages, "test", 0.0) + .await + .unwrap(); + assert_eq!(result, "history ok"); + assert_eq!(calls.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn chat_with_history_falls_back() { + let primary_calls = Arc::new(AtomicUsize::new(0)); + let fallback_calls = Arc::new(AtomicUsize::new(0)); + + let provider = ReliableProvider::new( + vec![ + ( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&primary_calls), + fail_until_attempt: usize::MAX, + response: "never", + error: "primary down", + }), + ), + ( + "fallback".into(), + Box::new(MockProvider { + calls: Arc::clone(&fallback_calls), + fail_until_attempt: 0, + response: "fallback ok", + error: "fallback err", + }), + ), + ], + 1, + 1, + ); + + let messages = vec![ChatMessage::user("hello")]; + let result = provider + .chat_with_history(&messages, "test", 0.0) + .await + .unwrap(); + assert_eq!(result, "fallback ok"); + assert_eq!(primary_calls.load(Ordering::SeqCst), 2); + assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); + } } diff --git a/src/providers/router.rs b/src/providers/router.rs index 2085276..2fec083 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -1,3 +1,4 @@ +use super::traits::ChatMessage; use super::Provider; use async_trait::async_trait; use std::collections::HashMap; @@ -112,6 +113,19 @@ impl Provider for RouterProvider { .await } + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (provider_idx, resolved_model) = self.resolve(model); + let (_, provider) = &self.providers[provider_idx]; + provider + .chat_with_history(messages, &resolved_model, temperature) + .await + } + async fn warmup(&self) -> anyhow::Result<()> { for (name, provider) in &self.providers { tracing::info!(provider = name, "Warming up routed provider"); diff --git a/src/providers/traits.rs b/src/providers/traits.rs index ff9adad..84746ea 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -1,4 +1,86 @@ use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +/// A single message in a conversation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ChatMessage { + pub role: String, + pub content: String, +} + +impl ChatMessage { + pub fn system(content: impl Into) -> Self { + Self { + role: "system".into(), + content: content.into(), + } + } + + pub fn user(content: impl Into) -> Self { + Self { + role: "user".into(), + content: content.into(), + } + } + + pub fn assistant(content: impl Into) -> Self { + Self { + role: "assistant".into(), + content: content.into(), + } + } +} + +/// A tool call requested by the LLM. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCall { + pub id: String, + pub name: String, + pub arguments: String, +} + +/// An LLM response that may contain text, tool calls, or both. +#[derive(Debug, Clone)] +pub struct ChatResponse { + /// Text content of the response (may be empty if only tool calls). + pub text: Option, + /// Tool calls requested by the LLM. + pub tool_calls: Vec, +} + +impl ChatResponse { + /// True when the LLM wants to invoke at least one tool. + pub fn has_tool_calls(&self) -> bool { + !self.tool_calls.is_empty() + } + + /// Convenience: return text content or empty string. + pub fn text_or_empty(&self) -> &str { + self.text.as_deref().unwrap_or("") + } +} + +/// A tool result to feed back to the LLM. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolResultMessage { + pub tool_call_id: String, + pub content: String, +} + +/// A message in a multi-turn conversation, including tool interactions. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ConversationMessage { + /// Regular chat message (system, user, assistant). + Chat(ChatMessage), + /// Tool calls from the assistant (stored for history fidelity). + AssistantToolCalls { + text: Option, + tool_calls: Vec, + }, + /// Result of a tool execution, fed back to the LLM. + ToolResult(ToolResultMessage), +} #[async_trait] pub trait Provider: Send + Sync { @@ -15,9 +97,95 @@ pub trait Provider: Send + Sync { temperature: f64, ) -> anyhow::Result; + /// Multi-turn conversation. Default implementation extracts the last user + /// message and delegates to `chat_with_system`. + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let system = messages + .iter() + .find(|m| m.role == "system") + .map(|m| m.content.as_str()); + let last_user = messages + .iter() + .rfind(|m| m.role == "user") + .map(|m| m.content.as_str()) + .unwrap_or(""); + self.chat_with_system(system, last_user, model, temperature) + .await + } + /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup). /// Default implementation is a no-op; providers with HTTP clients should override. async fn warmup(&self) -> anyhow::Result<()> { Ok(()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn chat_message_constructors() { + let sys = ChatMessage::system("Be helpful"); + assert_eq!(sys.role, "system"); + assert_eq!(sys.content, "Be helpful"); + + let user = ChatMessage::user("Hello"); + assert_eq!(user.role, "user"); + + let asst = ChatMessage::assistant("Hi there"); + assert_eq!(asst.role, "assistant"); + } + + #[test] + fn chat_response_helpers() { + let empty = ChatResponse { + text: None, + tool_calls: vec![], + }; + assert!(!empty.has_tool_calls()); + assert_eq!(empty.text_or_empty(), ""); + + let with_tools = ChatResponse { + text: Some("Let me check".into()), + tool_calls: vec![ToolCall { + id: "1".into(), + name: "shell".into(), + arguments: "{}".into(), + }], + }; + assert!(with_tools.has_tool_calls()); + assert_eq!(with_tools.text_or_empty(), "Let me check"); + } + + #[test] + fn tool_call_serialization() { + let tc = ToolCall { + id: "call_123".into(), + name: "file_read".into(), + arguments: r#"{"path":"test.txt"}"#.into(), + }; + let json = serde_json::to_string(&tc).unwrap(); + assert!(json.contains("call_123")); + assert!(json.contains("file_read")); + } + + #[test] + fn conversation_message_variants() { + let chat = ConversationMessage::Chat(ChatMessage::user("hi")); + let json = serde_json::to_string(&chat).unwrap(); + assert!(json.contains("\"type\":\"Chat\"")); + + let tool_result = ConversationMessage::ToolResult(ToolResultMessage { + tool_call_id: "1".into(), + content: "done".into(), + }); + let json = serde_json::to_string(&tool_result).unwrap(); + assert!(json.contains("\"type\":\"ToolResult\"")); + } +}