feat: add multi-turn conversation history and tool execution

* feat: add multi-turn conversation history and tool execution

Major enhancement to the agent loop:

**Multi-turn conversation:**
- Add `ChatMessage` type with system/user/assistant constructors
- Add `chat_with_history` method to Provider trait (default impl
  delegates to `chat_with_system` for backward compatibility)
- Implement native `chat_with_history` on OpenRouter, Compatible,
  Reliable, and Router providers to send full message history
- Interactive mode now maintains persistent history across turns

**Tool execution:**
- Agent loop now parses `<tool_call>` XML tags from LLM responses
- Executes tools from the registry and feeds results back as
  `<tool_result>` messages
- Agentic loop continues until LLM produces final text (no tool calls)
- MAX_TOOL_ITERATIONS (10) safety limit prevents runaway loops
- System prompt includes structured tool-use protocol with JSON schemas

**Types:**
- `ChatMessage`, `ChatResponse`, `ToolCall`, `ToolResultMessage`,
  `ConversationMessage` — full conversation modeling types

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address review comments on multi-turn + tool execution

- Add history sliding window (MAX_HISTORY_MESSAGES=50) to prevent
  unbounded conversation history growth in interactive mode
- Add 404→Responses API fallback in compatible.rs chat_with_history,
  matching chat_with_system behavior
- Use super::api_error() for error sanitization in compatible.rs
  instead of raw error body (prevents secret leakage)
- Add missing operational logs in reliable.rs chat_with_history:
  recovery, non-retryable, fallback switch warnings
- Add trim_history tests

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* fix: address second round of review comments

- Sanitize raw error text in compatible.rs chat_with_system using
  sanitize_api_error (prevents leaking secrets in error messages)
- Add chat_with_history to MockProvider in reliable.rs tests so
  the retry/fallback path is exercised end-to-end
- Add chat_with_history_retries_then_recovers and
  chat_with_history_falls_back tests
- Log warning on malformed <tool_call> JSON instead of silent drop
- Flush stdout after print! in agent_turn so output appears before
  tool execution on line-buffered terminals
- Make interactive mode resilient to transient errors (continue
  loop instead of terminating session)

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Edvard Schøyen 2026-02-15 14:43:02 -05:00 committed by GitHub
parent 92c42dc24d
commit 89b1ec6fa2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 829 additions and 21 deletions

View file

@ -1,16 +1,44 @@
use crate::config::Config; use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory}; use crate::memory::{self, Memory, MemoryCategory};
use crate::observability::{self, Observer, ObserverEvent}; use crate::observability::{self, Observer, ObserverEvent};
use crate::providers::{self, Provider}; use crate::providers::{self, ChatMessage, Provider};
use crate::runtime; use crate::runtime;
use crate::security::SecurityPolicy; use crate::security::SecurityPolicy;
use crate::tools; use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis; use crate::util::truncate_with_ellipsis;
use anyhow::Result; use anyhow::Result;
use std::fmt::Write; use std::fmt::Write;
use std::io::Write as IoWrite;
use std::sync::Arc; use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
const MAX_TOOL_ITERATIONS: usize = 10;
/// Maximum number of non-system messages to keep in history.
/// When exceeded, the oldest messages are dropped (system prompt is always preserved).
const MAX_HISTORY_MESSAGES: usize = 50;
/// Trim conversation history to prevent unbounded growth.
/// Preserves the system prompt (first message if role=system) and the most recent messages.
fn trim_history(history: &mut Vec<ChatMessage>) {
// Nothing to trim if within limit
let has_system = history.first().map_or(false, |m| m.role == "system");
let non_system_count = if has_system {
history.len() - 1
} else {
history.len()
};
if non_system_count <= MAX_HISTORY_MESSAGES {
return;
}
let start = if has_system { 1 } else { 0 };
let to_remove = non_system_count - MAX_HISTORY_MESSAGES;
history.drain(start..start + to_remove);
}
/// Build context preamble by searching memory for relevant entries /// Build context preamble by searching memory for relevant entries
async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new(); let mut context = String::new();
@ -29,6 +57,178 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
context context
} }
/// Find a tool by name in the registry.
fn find_tool<'a>(tools: &'a [Box<dyn Tool>], name: &str) -> Option<&'a dyn Tool> {
tools.iter().find(|t| t.name() == name).map(|t| t.as_ref())
}
/// Parse tool calls from an LLM response that uses XML-style function calling.
///
/// Expected format (common with system-prompt-guided tool use):
/// ```text
/// <tool_call>
/// {"name": "shell", "arguments": {"command": "ls"}}
/// </tool_call>
/// ```
///
/// Also supports JSON with `tool_calls` array from OpenAI-format responses.
fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
let mut text_parts = Vec::new();
let mut calls = Vec::new();
let mut remaining = response;
while let Some(start) = remaining.find("<tool_call>") {
// Everything before the tag is text
let before = &remaining[..start];
if !before.trim().is_empty() {
text_parts.push(before.trim().to_string());
}
if let Some(end) = remaining[start..].find("</tool_call>") {
let inner = &remaining[start + 11..start + end];
match serde_json::from_str::<serde_json::Value>(inner.trim()) {
Ok(parsed) => {
let name = parsed
.get("name")
.and_then(|v| v.as_str())
.unwrap_or("")
.to_string();
let arguments = parsed
.get("arguments")
.cloned()
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
calls.push(ParsedToolCall { name, arguments });
}
Err(e) => {
tracing::warn!("Malformed <tool_call> JSON: {e}");
}
}
remaining = &remaining[start + end + 12..];
} else {
break;
}
}
// Remaining text after last tool call
if !remaining.trim().is_empty() {
text_parts.push(remaining.trim().to_string());
}
(text_parts.join("\n"), calls)
}
#[derive(Debug)]
struct ParsedToolCall {
name: String,
arguments: serde_json::Value,
}
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
async fn agent_turn(
provider: &dyn Provider,
history: &mut Vec<ChatMessage>,
tools_registry: &[Box<dyn Tool>],
observer: &dyn Observer,
model: &str,
temperature: f64,
) -> Result<String> {
for _iteration in 0..MAX_TOOL_ITERATIONS {
let response = provider
.chat_with_history(history, model, temperature)
.await?;
let (text, tool_calls) = parse_tool_calls(&response);
if tool_calls.is_empty() {
// No tool calls — this is the final response
history.push(ChatMessage::assistant(&response));
return Ok(if text.is_empty() {
response
} else {
text
});
}
// Print any text the LLM produced alongside tool calls
if !text.is_empty() {
print!("{text}");
let _ = std::io::stdout().flush();
}
// Execute each tool call and build results
let mut tool_results = String::new();
for call in &tool_calls {
let start = Instant::now();
let result = if let Some(tool) = find_tool(tools_registry, &call.name) {
match tool.execute(call.arguments.clone()).await {
Ok(r) => {
observer.record_event(&ObserverEvent::ToolCall {
tool: call.name.clone(),
duration: start.elapsed(),
success: r.success,
});
if r.success {
r.output
} else {
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
}
}
Err(e) => {
observer.record_event(&ObserverEvent::ToolCall {
tool: call.name.clone(),
duration: start.elapsed(),
success: false,
});
format!("Error executing {}: {e}", call.name)
}
}
} else {
format!("Unknown tool: {}", call.name)
};
let _ = writeln!(
tool_results,
"<tool_result name=\"{}\">\n{}\n</tool_result>",
call.name, result
);
}
// Add assistant message with tool calls + tool results to history
history.push(ChatMessage::assistant(&response));
history.push(ChatMessage::user(format!(
"[Tool results]\n{tool_results}"
)));
}
anyhow::bail!("Agent exceeded maximum tool iterations ({MAX_TOOL_ITERATIONS})")
}
/// Build the tool instruction block for the system prompt so the LLM knows
/// how to invoke tools.
fn build_tool_instructions(tools_registry: &[Box<dyn Tool>]) -> String {
let mut instructions = String::new();
instructions.push_str("\n## Tool Use Protocol\n\n");
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
instructions.push_str("```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n");
instructions.push_str("You may use multiple tool calls in a single response. ");
instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
instructions.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
instructions.push_str("### Available Tools\n\n");
for tool in tools_registry {
let _ = writeln!(
instructions,
"**{}**: {}\nParameters: `{}`\n",
tool.name(),
tool.description(),
tool.parameters_schema()
);
}
instructions
}
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub async fn run( pub async fn run(
config: Config, config: Config,
@ -61,7 +261,7 @@ pub async fn run(
} else { } else {
None None
}; };
let _tools = tools::all_tools_with_runtime( let tools_registry = tools::all_tools_with_runtime(
&security, &security,
runtime, runtime,
mem.clone(), mem.clone(),
@ -133,7 +333,7 @@ pub async fn run(
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.", "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.",
)); ));
} }
let system_prompt = crate::channels::build_system_prompt( let mut system_prompt = crate::channels::build_system_prompt(
&config.workspace_dir, &config.workspace_dir,
model_name, model_name,
&tool_descs, &tool_descs,
@ -141,6 +341,9 @@ pub async fn run(
Some(&config.identity), Some(&config.identity),
); );
// Append structured tool-use instructions with schemas
system_prompt.push_str(&build_tool_instructions(&tools_registry));
// ── Execute ────────────────────────────────────────────────── // ── Execute ──────────────────────────────────────────────────
let start = Instant::now(); let start = Instant::now();
@ -160,8 +363,19 @@ pub async fn run(
format!("{context}{msg}") format!("{context}{msg}")
}; };
let response = provider let mut history = vec![
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature) ChatMessage::system(&system_prompt),
ChatMessage::user(&enriched),
];
let response = agent_turn(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
model_name,
temperature,
)
.await?; .await?;
println!("{response}"); println!("{response}");
@ -184,6 +398,9 @@ pub async fn run(
let _ = crate::channels::Channel::listen(&cli, tx).await; let _ = crate::channels::Channel::listen(&cli, tx).await;
}); });
// Persistent conversation history across turns
let mut history = vec![ChatMessage::system(&system_prompt)];
while let Some(msg) = rx.recv().await { while let Some(msg) = rx.recv().await {
// Auto-save conversation turns // Auto-save conversation turns
if config.memory.auto_save { if config.memory.auto_save {
@ -200,11 +417,29 @@ pub async fn run(
format!("{context}{}", msg.content) format!("{context}{}", msg.content)
}; };
let response = provider history.push(ChatMessage::user(&enriched));
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature)
.await?; let response = match agent_turn(
provider.as_ref(),
&mut history,
&tools_registry,
observer.as_ref(),
model_name,
temperature,
)
.await
{
Ok(resp) => resp,
Err(e) => {
eprintln!("\nError: {e}\n");
continue;
}
};
println!("\n{response}\n"); println!("\n{response}\n");
// Prevent unbounded history growth in long interactive sessions
trim_history(&mut history);
if config.memory.auto_save { if config.memory.auto_save {
let summary = truncate_with_ellipsis(&response, 100); let summary = truncate_with_ellipsis(&response, 100);
let _ = mem let _ = mem
@ -224,3 +459,126 @@ pub async fn run(
Ok(()) Ok(())
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_tool_calls_extracts_single_call() {
let response = r#"Let me check that.
<tool_call>
{"name": "shell", "arguments": {"command": "ls -la"}}
</tool_call>"#;
let (text, calls) = parse_tool_calls(response);
assert_eq!(text, "Let me check that.");
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].name, "shell");
assert_eq!(
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
"ls -la"
);
}
#[test]
fn parse_tool_calls_extracts_multiple_calls() {
let response = r#"<tool_call>
{"name": "file_read", "arguments": {"path": "a.txt"}}
</tool_call>
<tool_call>
{"name": "file_read", "arguments": {"path": "b.txt"}}
</tool_call>"#;
let (_, calls) = parse_tool_calls(response);
assert_eq!(calls.len(), 2);
assert_eq!(calls[0].name, "file_read");
assert_eq!(calls[1].name, "file_read");
}
#[test]
fn parse_tool_calls_returns_text_only_when_no_calls() {
let response = "Just a normal response with no tools.";
let (text, calls) = parse_tool_calls(response);
assert_eq!(text, "Just a normal response with no tools.");
assert!(calls.is_empty());
}
#[test]
fn parse_tool_calls_handles_malformed_json() {
let response = r#"<tool_call>
not valid json
</tool_call>
Some text after."#;
let (text, calls) = parse_tool_calls(response);
assert!(calls.is_empty());
assert!(text.contains("Some text after."));
}
#[test]
fn parse_tool_calls_text_before_and_after() {
let response = r#"Before text.
<tool_call>
{"name": "shell", "arguments": {"command": "echo hi"}}
</tool_call>
After text."#;
let (text, calls) = parse_tool_calls(response);
assert!(text.contains("Before text."));
assert!(text.contains("After text."));
assert_eq!(calls.len(), 1);
}
#[test]
fn build_tool_instructions_includes_all_tools() {
use crate::security::SecurityPolicy;
let security = Arc::new(SecurityPolicy::from_config(
&crate::config::AutonomyConfig::default(),
std::path::Path::new("/tmp"),
));
let tools = tools::default_tools(security);
let instructions = build_tool_instructions(&tools);
assert!(instructions.contains("## Tool Use Protocol"));
assert!(instructions.contains("<tool_call>"));
assert!(instructions.contains("shell"));
assert!(instructions.contains("file_read"));
assert!(instructions.contains("file_write"));
}
#[test]
fn trim_history_preserves_system_prompt() {
let mut history = vec![ChatMessage::system("system prompt")];
for i in 0..MAX_HISTORY_MESSAGES + 20 {
history.push(ChatMessage::user(format!("msg {i}")));
}
let original_len = history.len();
assert!(original_len > MAX_HISTORY_MESSAGES + 1);
trim_history(&mut history);
// System prompt preserved
assert_eq!(history[0].role, "system");
assert_eq!(history[0].content, "system prompt");
// Trimmed to limit
assert_eq!(history.len(), MAX_HISTORY_MESSAGES + 1); // +1 for system
// Most recent messages preserved
let last = &history[history.len() - 1];
assert_eq!(
last.content,
format!("msg {}", MAX_HISTORY_MESSAGES + 19)
);
}
#[test]
fn trim_history_noop_when_within_limit() {
let mut history = vec![
ChatMessage::system("sys"),
ChatMessage::user("hello"),
ChatMessage::assistant("hi"),
];
trim_history(&mut history);
assert_eq!(history.len(), 3);
}
}

View file

@ -2,7 +2,7 @@
//! Most LLM APIs follow the same `/v1/chat/completions` format. //! Most LLM APIs follow the same `/v1/chat/completions` format.
//! This module provides a single implementation that works for all of them. //! This module provides a single implementation that works for all of them.
use crate::providers::traits::Provider; use crate::providers::traits::{ChatMessage, Provider};
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -81,7 +81,7 @@ struct Message {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ChatResponse { struct ApiChatResponse {
choices: Vec<Choice>, choices: Vec<Choice>,
} }
@ -264,6 +264,7 @@ impl Provider for OpenAiCompatibleProvider {
if !response.status().is_success() { if !response.status().is_success() {
let status = response.status(); let status = response.status();
let error = response.text().await?; let error = response.text().await?;
let sanitized = super::sanitize_api_error(&error);
if status == reqwest::StatusCode::NOT_FOUND { if status == reqwest::StatusCode::NOT_FOUND {
return self return self
@ -271,16 +272,88 @@ impl Provider for OpenAiCompatibleProvider {
.await .await
.map_err(|responses_err| { .map_err(|responses_err| {
anyhow::anyhow!( anyhow::anyhow!(
"{} API error: {error} (chat completions unavailable; responses fallback failed: {responses_err})", "{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})",
self.name self.name
) )
}); });
} }
anyhow::bail!("{} API error: {error}", self.name); anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
} }
let chat_response: ChatResponse = response.json().await?; let chat_response: ApiChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
)
})?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let request = ChatRequest {
model: model.to_string(),
messages: api_messages,
temperature,
};
let url = self.chat_completions_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
// Mirror chat_with_system: 404 may mean this provider uses the Responses API
if status == reqwest::StatusCode::NOT_FOUND {
// Extract system prompt and last user message for responses fallback
let system = messages.iter().find(|m| m.role == "system");
let last_user = messages.iter().rfind(|m| m.role == "user");
if let Some(user_msg) = last_user {
return self
.chat_via_responses(
api_key,
system.map(|m| m.content.as_str()),
&user_msg.content,
model,
)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
"{} API error (chat completions unavailable; responses fallback failed: {responses_err})",
self.name
)
});
}
}
return Err(super::api_error(&self.name, response).await);
}
let chat_response: ApiChatResponse = response.json().await?;
chat_response chat_response
.choices .choices
@ -357,14 +430,14 @@ mod tests {
#[test] #[test]
fn response_deserializes() { fn response_deserializes() {
let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#; let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap(); let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.content, "Hello from Venice!"); assert_eq!(resp.choices[0].message.content, "Hello from Venice!");
} }
#[test] #[test]
fn response_empty_choices() { fn response_empty_choices() {
let json = r#"{"choices":[]}"#; let json = r#"{"choices":[]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap(); let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.choices.is_empty()); assert!(resp.choices.is_empty());
} }

View file

@ -8,7 +8,7 @@ pub mod reliable;
pub mod router; pub mod router;
pub mod traits; pub mod traits;
pub use traits::Provider; pub use traits::{ChatMessage, Provider};
use compatible::{AuthStyle, OpenAiCompatibleProvider}; use compatible::{AuthStyle, OpenAiCompatibleProvider};
use reliable::ReliableProvider; use reliable::ReliableProvider;

View file

@ -1,4 +1,4 @@
use crate::providers::traits::Provider; use crate::providers::traits::{ChatMessage, Provider};
use async_trait::async_trait; use async_trait::async_trait;
use reqwest::Client; use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -22,7 +22,7 @@ struct Message {
} }
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ChatResponse { struct ApiChatResponse {
choices: Vec<Choice>, choices: Vec<Choice>,
} }
@ -112,7 +112,57 @@ impl Provider for OpenRouterProvider {
return Err(super::api_error("OpenRouter", response).await); return Err(super::api_error("OpenRouter", response).await);
} }
let chat_response: ChatResponse = response.json().await?; let chat_response: ApiChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let request = ChatRequest {
model: model.to_string(),
messages: api_messages,
temperature,
};
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
)
.header("X-Title", "ZeroClaw")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("OpenRouter", response).await);
}
let chat_response: ApiChatResponse = response.json().await?;
chat_response chat_response
.choices .choices

View file

@ -1,3 +1,4 @@
use super::traits::ChatMessage;
use super::Provider; use super::Provider;
use async_trait::async_trait; use async_trait::async_trait;
use std::time::Duration; use std::time::Duration;
@ -121,6 +122,68 @@ impl Provider for ReliableProvider {
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n")) anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
} }
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let mut failures = Vec::new();
for (provider_name, provider) in &self.providers {
let mut backoff_ms = self.base_backoff_ms;
for attempt in 0..=self.max_retries {
match provider
.chat_with_history(messages, model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 {
tracing::info!(
provider = provider_name,
attempt,
"Provider recovered after retries"
);
}
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
failures.push(format!(
"{provider_name} attempt {}/{}: {e}",
attempt + 1,
self.max_retries + 1
));
if non_retryable {
tracing::warn!(
provider = provider_name,
"Non-retryable error, switching provider"
);
break;
}
if attempt < self.max_retries {
tracing::warn!(
provider = provider_name,
attempt = attempt + 1,
max_retries = self.max_retries,
"Provider call failed, retrying"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
}
}
}
}
tracing::warn!(provider = provider_name, "Switching to fallback provider");
}
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
}
} }
#[cfg(test)] #[cfg(test)]
@ -151,6 +214,19 @@ mod tests {
} }
Ok(self.response.to_string()) Ok(self.response.to_string())
} }
async fn chat_with_history(
&self,
_messages: &[ChatMessage],
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
if attempt <= self.fail_until_attempt {
anyhow::bail!(self.error);
}
Ok(self.response.to_string())
}
} }
#[tokio::test] #[tokio::test]
@ -330,4 +406,73 @@ mod tests {
assert_eq!(primary_calls.load(Ordering::SeqCst), 1); assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
} }
#[tokio::test]
async fn chat_with_history_retries_then_recovers() {
let calls = Arc::new(AtomicUsize::new(0));
let provider = ReliableProvider::new(
vec![(
"primary".into(),
Box::new(MockProvider {
calls: Arc::clone(&calls),
fail_until_attempt: 1,
response: "history ok",
error: "temporary",
}),
)],
2,
1,
);
let messages = vec![
ChatMessage::system("system"),
ChatMessage::user("hello"),
];
let result = provider
.chat_with_history(&messages, "test", 0.0)
.await
.unwrap();
assert_eq!(result, "history ok");
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn chat_with_history_falls_back() {
let primary_calls = Arc::new(AtomicUsize::new(0));
let fallback_calls = Arc::new(AtomicUsize::new(0));
let provider = ReliableProvider::new(
vec![
(
"primary".into(),
Box::new(MockProvider {
calls: Arc::clone(&primary_calls),
fail_until_attempt: usize::MAX,
response: "never",
error: "primary down",
}),
),
(
"fallback".into(),
Box::new(MockProvider {
calls: Arc::clone(&fallback_calls),
fail_until_attempt: 0,
response: "fallback ok",
error: "fallback err",
}),
),
],
1,
1,
);
let messages = vec![ChatMessage::user("hello")];
let result = provider
.chat_with_history(&messages, "test", 0.0)
.await
.unwrap();
assert_eq!(result, "fallback ok");
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
} }

View file

@ -1,3 +1,4 @@
use super::traits::ChatMessage;
use super::Provider; use super::Provider;
use async_trait::async_trait; use async_trait::async_trait;
use std::collections::HashMap; use std::collections::HashMap;
@ -112,6 +113,19 @@ impl Provider for RouterProvider {
.await .await
} }
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let (provider_idx, resolved_model) = self.resolve(model);
let (_, provider) = &self.providers[provider_idx];
provider
.chat_with_history(messages, &resolved_model, temperature)
.await
}
async fn warmup(&self) -> anyhow::Result<()> { async fn warmup(&self) -> anyhow::Result<()> {
for (name, provider) in &self.providers { for (name, provider) in &self.providers {
tracing::info!(provider = name, "Warming up routed provider"); tracing::info!(provider = name, "Warming up routed provider");

View file

@ -1,4 +1,86 @@
use async_trait::async_trait; use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// A single message in a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
pub role: String,
pub content: String,
}
impl ChatMessage {
pub fn system(content: impl Into<String>) -> Self {
Self {
role: "system".into(),
content: content.into(),
}
}
pub fn user(content: impl Into<String>) -> Self {
Self {
role: "user".into(),
content: content.into(),
}
}
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: "assistant".into(),
content: content.into(),
}
}
}
/// A tool call requested by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
pub id: String,
pub name: String,
pub arguments: String,
}
/// An LLM response that may contain text, tool calls, or both.
#[derive(Debug, Clone)]
pub struct ChatResponse {
/// Text content of the response (may be empty if only tool calls).
pub text: Option<String>,
/// Tool calls requested by the LLM.
pub tool_calls: Vec<ToolCall>,
}
impl ChatResponse {
/// True when the LLM wants to invoke at least one tool.
pub fn has_tool_calls(&self) -> bool {
!self.tool_calls.is_empty()
}
/// Convenience: return text content or empty string.
pub fn text_or_empty(&self) -> &str {
self.text.as_deref().unwrap_or("")
}
}
/// A tool result to feed back to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
pub tool_call_id: String,
pub content: String,
}
/// A message in a multi-turn conversation, including tool interactions.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum ConversationMessage {
/// Regular chat message (system, user, assistant).
Chat(ChatMessage),
/// Tool calls from the assistant (stored for history fidelity).
AssistantToolCalls {
text: Option<String>,
tool_calls: Vec<ToolCall>,
},
/// Result of a tool execution, fed back to the LLM.
ToolResult(ToolResultMessage),
}
#[async_trait] #[async_trait]
pub trait Provider: Send + Sync { pub trait Provider: Send + Sync {
@ -15,9 +97,95 @@ pub trait Provider: Send + Sync {
temperature: f64, temperature: f64,
) -> anyhow::Result<String>; ) -> anyhow::Result<String>;
/// Multi-turn conversation. Default implementation extracts the last user
/// message and delegates to `chat_with_system`.
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.as_str());
let last_user = messages
.iter()
.rfind(|m| m.role == "user")
.map(|m| m.content.as_str())
.unwrap_or("");
self.chat_with_system(system, last_user, model, temperature)
.await
}
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup). /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
/// Default implementation is a no-op; providers with HTTP clients should override. /// Default implementation is a no-op; providers with HTTP clients should override.
async fn warmup(&self) -> anyhow::Result<()> { async fn warmup(&self) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
} }
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn chat_message_constructors() {
let sys = ChatMessage::system("Be helpful");
assert_eq!(sys.role, "system");
assert_eq!(sys.content, "Be helpful");
let user = ChatMessage::user("Hello");
assert_eq!(user.role, "user");
let asst = ChatMessage::assistant("Hi there");
assert_eq!(asst.role, "assistant");
}
#[test]
fn chat_response_helpers() {
let empty = ChatResponse {
text: None,
tool_calls: vec![],
};
assert!(!empty.has_tool_calls());
assert_eq!(empty.text_or_empty(), "");
let with_tools = ChatResponse {
text: Some("Let me check".into()),
tool_calls: vec![ToolCall {
id: "1".into(),
name: "shell".into(),
arguments: "{}".into(),
}],
};
assert!(with_tools.has_tool_calls());
assert_eq!(with_tools.text_or_empty(), "Let me check");
}
#[test]
fn tool_call_serialization() {
let tc = ToolCall {
id: "call_123".into(),
name: "file_read".into(),
arguments: r#"{"path":"test.txt"}"#.into(),
};
let json = serde_json::to_string(&tc).unwrap();
assert!(json.contains("call_123"));
assert!(json.contains("file_read"));
}
#[test]
fn conversation_message_variants() {
let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
let json = serde_json::to_string(&chat).unwrap();
assert!(json.contains("\"type\":\"Chat\""));
let tool_result = ConversationMessage::ToolResult(ToolResultMessage {
tool_call_id: "1".into(),
content: "done".into(),
});
let json = serde_json::to_string(&tool_result).unwrap();
assert!(json.contains("\"type\":\"ToolResult\""));
}
}