feat: add multi-turn conversation history and tool execution
* feat: add multi-turn conversation history and tool execution Major enhancement to the agent loop: **Multi-turn conversation:** - Add `ChatMessage` type with system/user/assistant constructors - Add `chat_with_history` method to Provider trait (default impl delegates to `chat_with_system` for backward compatibility) - Implement native `chat_with_history` on OpenRouter, Compatible, Reliable, and Router providers to send full message history - Interactive mode now maintains persistent history across turns **Tool execution:** - Agent loop now parses `<tool_call>` XML tags from LLM responses - Executes tools from the registry and feeds results back as `<tool_result>` messages - Agentic loop continues until LLM produces final text (no tool calls) - MAX_TOOL_ITERATIONS (10) safety limit prevents runaway loops - System prompt includes structured tool-use protocol with JSON schemas **Types:** - `ChatMessage`, `ChatResponse`, `ToolCall`, `ToolResultMessage`, `ConversationMessage` — full conversation modeling types Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: address review comments on multi-turn + tool execution - Add history sliding window (MAX_HISTORY_MESSAGES=50) to prevent unbounded conversation history growth in interactive mode - Add 404→Responses API fallback in compatible.rs chat_with_history, matching chat_with_system behavior - Use super::api_error() for error sanitization in compatible.rs instead of raw error body (prevents secret leakage) - Add missing operational logs in reliable.rs chat_with_history: recovery, non-retryable, fallback switch warnings - Add trim_history tests Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix: address second round of review comments - Sanitize raw error text in compatible.rs chat_with_system using sanitize_api_error (prevents leaking secrets in error messages) - Add chat_with_history to MockProvider in reliable.rs tests so the retry/fallback path is exercised end-to-end - Add chat_with_history_retries_then_recovers and chat_with_history_falls_back tests - Log warning on malformed <tool_call> JSON instead of silent drop - Flush stdout after print! in agent_turn so output appears before tool execution on line-buffered terminals - Make interactive mode resilient to transient errors (continue loop instead of terminating session) Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
92c42dc24d
commit
89b1ec6fa2
7 changed files with 829 additions and 21 deletions
|
|
@ -1,16 +1,44 @@
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::memory::{self, Memory, MemoryCategory};
|
use crate::memory::{self, Memory, MemoryCategory};
|
||||||
use crate::observability::{self, Observer, ObserverEvent};
|
use crate::observability::{self, Observer, ObserverEvent};
|
||||||
use crate::providers::{self, Provider};
|
use crate::providers::{self, ChatMessage, Provider};
|
||||||
use crate::runtime;
|
use crate::runtime;
|
||||||
use crate::security::SecurityPolicy;
|
use crate::security::SecurityPolicy;
|
||||||
use crate::tools;
|
use crate::tools::{self, Tool};
|
||||||
use crate::util::truncate_with_ellipsis;
|
use crate::util::truncate_with_ellipsis;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
use std::io::Write as IoWrite;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
|
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||||
|
const MAX_TOOL_ITERATIONS: usize = 10;
|
||||||
|
|
||||||
|
/// Maximum number of non-system messages to keep in history.
|
||||||
|
/// When exceeded, the oldest messages are dropped (system prompt is always preserved).
|
||||||
|
const MAX_HISTORY_MESSAGES: usize = 50;
|
||||||
|
|
||||||
|
/// Trim conversation history to prevent unbounded growth.
|
||||||
|
/// Preserves the system prompt (first message if role=system) and the most recent messages.
|
||||||
|
fn trim_history(history: &mut Vec<ChatMessage>) {
|
||||||
|
// Nothing to trim if within limit
|
||||||
|
let has_system = history.first().map_or(false, |m| m.role == "system");
|
||||||
|
let non_system_count = if has_system {
|
||||||
|
history.len() - 1
|
||||||
|
} else {
|
||||||
|
history.len()
|
||||||
|
};
|
||||||
|
|
||||||
|
if non_system_count <= MAX_HISTORY_MESSAGES {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
let start = if has_system { 1 } else { 0 };
|
||||||
|
let to_remove = non_system_count - MAX_HISTORY_MESSAGES;
|
||||||
|
history.drain(start..start + to_remove);
|
||||||
|
}
|
||||||
|
|
||||||
/// Build context preamble by searching memory for relevant entries
|
/// Build context preamble by searching memory for relevant entries
|
||||||
async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||||
let mut context = String::new();
|
let mut context = String::new();
|
||||||
|
|
@ -29,6 +57,178 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||||
context
|
context
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Find a tool by name in the registry.
|
||||||
|
fn find_tool<'a>(tools: &'a [Box<dyn Tool>], name: &str) -> Option<&'a dyn Tool> {
|
||||||
|
tools.iter().find(|t| t.name() == name).map(|t| t.as_ref())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse tool calls from an LLM response that uses XML-style function calling.
|
||||||
|
///
|
||||||
|
/// Expected format (common with system-prompt-guided tool use):
|
||||||
|
/// ```text
|
||||||
|
/// <tool_call>
|
||||||
|
/// {"name": "shell", "arguments": {"command": "ls"}}
|
||||||
|
/// </tool_call>
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// Also supports JSON with `tool_calls` array from OpenAI-format responses.
|
||||||
|
fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
|
let mut text_parts = Vec::new();
|
||||||
|
let mut calls = Vec::new();
|
||||||
|
let mut remaining = response;
|
||||||
|
|
||||||
|
while let Some(start) = remaining.find("<tool_call>") {
|
||||||
|
// Everything before the tag is text
|
||||||
|
let before = &remaining[..start];
|
||||||
|
if !before.trim().is_empty() {
|
||||||
|
text_parts.push(before.trim().to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(end) = remaining[start..].find("</tool_call>") {
|
||||||
|
let inner = &remaining[start + 11..start + end];
|
||||||
|
match serde_json::from_str::<serde_json::Value>(inner.trim()) {
|
||||||
|
Ok(parsed) => {
|
||||||
|
let name = parsed
|
||||||
|
.get("name")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_string();
|
||||||
|
let arguments = parsed
|
||||||
|
.get("arguments")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(serde_json::Value::Object(serde_json::Map::new()));
|
||||||
|
calls.push(ParsedToolCall { name, arguments });
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!("Malformed <tool_call> JSON: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
remaining = &remaining[start + end + 12..];
|
||||||
|
} else {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Remaining text after last tool call
|
||||||
|
if !remaining.trim().is_empty() {
|
||||||
|
text_parts.push(remaining.trim().to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
(text_parts.join("\n"), calls)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct ParsedToolCall {
|
||||||
|
name: String,
|
||||||
|
arguments: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||||
|
/// execute tools, and loop until the LLM produces a final text response.
|
||||||
|
async fn agent_turn(
|
||||||
|
provider: &dyn Provider,
|
||||||
|
history: &mut Vec<ChatMessage>,
|
||||||
|
tools_registry: &[Box<dyn Tool>],
|
||||||
|
observer: &dyn Observer,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> Result<String> {
|
||||||
|
for _iteration in 0..MAX_TOOL_ITERATIONS {
|
||||||
|
let response = provider
|
||||||
|
.chat_with_history(history, model, temperature)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let (text, tool_calls) = parse_tool_calls(&response);
|
||||||
|
|
||||||
|
if tool_calls.is_empty() {
|
||||||
|
// No tool calls — this is the final response
|
||||||
|
history.push(ChatMessage::assistant(&response));
|
||||||
|
return Ok(if text.is_empty() {
|
||||||
|
response
|
||||||
|
} else {
|
||||||
|
text
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
// Print any text the LLM produced alongside tool calls
|
||||||
|
if !text.is_empty() {
|
||||||
|
print!("{text}");
|
||||||
|
let _ = std::io::stdout().flush();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Execute each tool call and build results
|
||||||
|
let mut tool_results = String::new();
|
||||||
|
for call in &tool_calls {
|
||||||
|
let start = Instant::now();
|
||||||
|
let result = if let Some(tool) = find_tool(tools_registry, &call.name) {
|
||||||
|
match tool.execute(call.arguments.clone()).await {
|
||||||
|
Ok(r) => {
|
||||||
|
observer.record_event(&ObserverEvent::ToolCall {
|
||||||
|
tool: call.name.clone(),
|
||||||
|
duration: start.elapsed(),
|
||||||
|
success: r.success,
|
||||||
|
});
|
||||||
|
if r.success {
|
||||||
|
r.output
|
||||||
|
} else {
|
||||||
|
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
observer.record_event(&ObserverEvent::ToolCall {
|
||||||
|
tool: call.name.clone(),
|
||||||
|
duration: start.elapsed(),
|
||||||
|
success: false,
|
||||||
|
});
|
||||||
|
format!("Error executing {}: {e}", call.name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
format!("Unknown tool: {}", call.name)
|
||||||
|
};
|
||||||
|
|
||||||
|
let _ = writeln!(
|
||||||
|
tool_results,
|
||||||
|
"<tool_result name=\"{}\">\n{}\n</tool_result>",
|
||||||
|
call.name, result
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add assistant message with tool calls + tool results to history
|
||||||
|
history.push(ChatMessage::assistant(&response));
|
||||||
|
history.push(ChatMessage::user(format!(
|
||||||
|
"[Tool results]\n{tool_results}"
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Agent exceeded maximum tool iterations ({MAX_TOOL_ITERATIONS})")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Build the tool instruction block for the system prompt so the LLM knows
|
||||||
|
/// how to invoke tools.
|
||||||
|
fn build_tool_instructions(tools_registry: &[Box<dyn Tool>]) -> String {
|
||||||
|
let mut instructions = String::new();
|
||||||
|
instructions.push_str("\n## Tool Use Protocol\n\n");
|
||||||
|
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
||||||
|
instructions.push_str("```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n");
|
||||||
|
instructions.push_str("You may use multiple tool calls in a single response. ");
|
||||||
|
instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
|
||||||
|
instructions.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||||
|
instructions.push_str("### Available Tools\n\n");
|
||||||
|
|
||||||
|
for tool in tools_registry {
|
||||||
|
let _ = writeln!(
|
||||||
|
instructions,
|
||||||
|
"**{}**: {}\nParameters: `{}`\n",
|
||||||
|
tool.name(),
|
||||||
|
tool.description(),
|
||||||
|
tool.parameters_schema()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
instructions
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub async fn run(
|
pub async fn run(
|
||||||
config: Config,
|
config: Config,
|
||||||
|
|
@ -61,7 +261,7 @@ pub async fn run(
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let _tools = tools::all_tools_with_runtime(
|
let tools_registry = tools::all_tools_with_runtime(
|
||||||
&security,
|
&security,
|
||||||
runtime,
|
runtime,
|
||||||
mem.clone(),
|
mem.clone(),
|
||||||
|
|
@ -133,7 +333,7 @@ pub async fn run(
|
||||||
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.",
|
"Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run, 'connect' to OAuth.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
let system_prompt = crate::channels::build_system_prompt(
|
let mut system_prompt = crate::channels::build_system_prompt(
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
model_name,
|
model_name,
|
||||||
&tool_descs,
|
&tool_descs,
|
||||||
|
|
@ -141,6 +341,9 @@ pub async fn run(
|
||||||
Some(&config.identity),
|
Some(&config.identity),
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// Append structured tool-use instructions with schemas
|
||||||
|
system_prompt.push_str(&build_tool_instructions(&tools_registry));
|
||||||
|
|
||||||
// ── Execute ──────────────────────────────────────────────────
|
// ── Execute ──────────────────────────────────────────────────
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
|
||||||
|
|
@ -160,9 +363,20 @@ pub async fn run(
|
||||||
format!("{context}{msg}")
|
format!("{context}{msg}")
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = provider
|
let mut history = vec![
|
||||||
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature)
|
ChatMessage::system(&system_prompt),
|
||||||
.await?;
|
ChatMessage::user(&enriched),
|
||||||
|
];
|
||||||
|
|
||||||
|
let response = agent_turn(
|
||||||
|
provider.as_ref(),
|
||||||
|
&mut history,
|
||||||
|
&tools_registry,
|
||||||
|
observer.as_ref(),
|
||||||
|
model_name,
|
||||||
|
temperature,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
println!("{response}");
|
println!("{response}");
|
||||||
|
|
||||||
// Auto-save assistant response to daily log
|
// Auto-save assistant response to daily log
|
||||||
|
|
@ -184,6 +398,9 @@ pub async fn run(
|
||||||
let _ = crate::channels::Channel::listen(&cli, tx).await;
|
let _ = crate::channels::Channel::listen(&cli, tx).await;
|
||||||
});
|
});
|
||||||
|
|
||||||
|
// Persistent conversation history across turns
|
||||||
|
let mut history = vec![ChatMessage::system(&system_prompt)];
|
||||||
|
|
||||||
while let Some(msg) = rx.recv().await {
|
while let Some(msg) = rx.recv().await {
|
||||||
// Auto-save conversation turns
|
// Auto-save conversation turns
|
||||||
if config.memory.auto_save {
|
if config.memory.auto_save {
|
||||||
|
|
@ -200,11 +417,29 @@ pub async fn run(
|
||||||
format!("{context}{}", msg.content)
|
format!("{context}{}", msg.content)
|
||||||
};
|
};
|
||||||
|
|
||||||
let response = provider
|
history.push(ChatMessage::user(&enriched));
|
||||||
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature)
|
|
||||||
.await?;
|
let response = match agent_turn(
|
||||||
|
provider.as_ref(),
|
||||||
|
&mut history,
|
||||||
|
&tools_registry,
|
||||||
|
observer.as_ref(),
|
||||||
|
model_name,
|
||||||
|
temperature,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => resp,
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("\nError: {e}\n");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
};
|
||||||
println!("\n{response}\n");
|
println!("\n{response}\n");
|
||||||
|
|
||||||
|
// Prevent unbounded history growth in long interactive sessions
|
||||||
|
trim_history(&mut history);
|
||||||
|
|
||||||
if config.memory.auto_save {
|
if config.memory.auto_save {
|
||||||
let summary = truncate_with_ellipsis(&response, 100);
|
let summary = truncate_with_ellipsis(&response, 100);
|
||||||
let _ = mem
|
let _ = mem
|
||||||
|
|
@ -224,3 +459,126 @@ pub async fn run(
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_extracts_single_call() {
|
||||||
|
let response = r#"Let me check that.
|
||||||
|
<tool_call>
|
||||||
|
{"name": "shell", "arguments": {"command": "ls -la"}}
|
||||||
|
</tool_call>"#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert_eq!(text, "Let me check that.");
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].name, "shell");
|
||||||
|
assert_eq!(
|
||||||
|
calls[0].arguments.get("command").unwrap().as_str().unwrap(),
|
||||||
|
"ls -la"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_extracts_multiple_calls() {
|
||||||
|
let response = r#"<tool_call>
|
||||||
|
{"name": "file_read", "arguments": {"path": "a.txt"}}
|
||||||
|
</tool_call>
|
||||||
|
<tool_call>
|
||||||
|
{"name": "file_read", "arguments": {"path": "b.txt"}}
|
||||||
|
</tool_call>"#;
|
||||||
|
|
||||||
|
let (_, calls) = parse_tool_calls(response);
|
||||||
|
assert_eq!(calls.len(), 2);
|
||||||
|
assert_eq!(calls[0].name, "file_read");
|
||||||
|
assert_eq!(calls[1].name, "file_read");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_returns_text_only_when_no_calls() {
|
||||||
|
let response = "Just a normal response with no tools.";
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert_eq!(text, "Just a normal response with no tools.");
|
||||||
|
assert!(calls.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_handles_malformed_json() {
|
||||||
|
let response = r#"<tool_call>
|
||||||
|
not valid json
|
||||||
|
</tool_call>
|
||||||
|
Some text after."#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(calls.is_empty());
|
||||||
|
assert!(text.contains("Some text after."));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_tool_calls_text_before_and_after() {
|
||||||
|
let response = r#"Before text.
|
||||||
|
<tool_call>
|
||||||
|
{"name": "shell", "arguments": {"command": "echo hi"}}
|
||||||
|
</tool_call>
|
||||||
|
After text."#;
|
||||||
|
|
||||||
|
let (text, calls) = parse_tool_calls(response);
|
||||||
|
assert!(text.contains("Before text."));
|
||||||
|
assert!(text.contains("After text."));
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_tool_instructions_includes_all_tools() {
|
||||||
|
use crate::security::SecurityPolicy;
|
||||||
|
let security = Arc::new(SecurityPolicy::from_config(
|
||||||
|
&crate::config::AutonomyConfig::default(),
|
||||||
|
std::path::Path::new("/tmp"),
|
||||||
|
));
|
||||||
|
let tools = tools::default_tools(security);
|
||||||
|
let instructions = build_tool_instructions(&tools);
|
||||||
|
|
||||||
|
assert!(instructions.contains("## Tool Use Protocol"));
|
||||||
|
assert!(instructions.contains("<tool_call>"));
|
||||||
|
assert!(instructions.contains("shell"));
|
||||||
|
assert!(instructions.contains("file_read"));
|
||||||
|
assert!(instructions.contains("file_write"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn trim_history_preserves_system_prompt() {
|
||||||
|
let mut history = vec![ChatMessage::system("system prompt")];
|
||||||
|
for i in 0..MAX_HISTORY_MESSAGES + 20 {
|
||||||
|
history.push(ChatMessage::user(format!("msg {i}")));
|
||||||
|
}
|
||||||
|
let original_len = history.len();
|
||||||
|
assert!(original_len > MAX_HISTORY_MESSAGES + 1);
|
||||||
|
|
||||||
|
trim_history(&mut history);
|
||||||
|
|
||||||
|
// System prompt preserved
|
||||||
|
assert_eq!(history[0].role, "system");
|
||||||
|
assert_eq!(history[0].content, "system prompt");
|
||||||
|
// Trimmed to limit
|
||||||
|
assert_eq!(history.len(), MAX_HISTORY_MESSAGES + 1); // +1 for system
|
||||||
|
// Most recent messages preserved
|
||||||
|
let last = &history[history.len() - 1];
|
||||||
|
assert_eq!(
|
||||||
|
last.content,
|
||||||
|
format!("msg {}", MAX_HISTORY_MESSAGES + 19)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn trim_history_noop_when_within_limit() {
|
||||||
|
let mut history = vec![
|
||||||
|
ChatMessage::system("sys"),
|
||||||
|
ChatMessage::user("hello"),
|
||||||
|
ChatMessage::assistant("hi"),
|
||||||
|
];
|
||||||
|
trim_history(&mut history);
|
||||||
|
assert_eq!(history.len(), 3);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
//! Most LLM APIs follow the same `/v1/chat/completions` format.
|
//! Most LLM APIs follow the same `/v1/chat/completions` format.
|
||||||
//! This module provides a single implementation that works for all of them.
|
//! This module provides a single implementation that works for all of them.
|
||||||
|
|
||||||
use crate::providers::traits::Provider;
|
use crate::providers::traits::{ChatMessage, Provider};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -81,7 +81,7 @@ struct Message {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ChatResponse {
|
struct ApiChatResponse {
|
||||||
choices: Vec<Choice>,
|
choices: Vec<Choice>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -264,6 +264,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let error = response.text().await?;
|
let error = response.text().await?;
|
||||||
|
let sanitized = super::sanitize_api_error(&error);
|
||||||
|
|
||||||
if status == reqwest::StatusCode::NOT_FOUND {
|
if status == reqwest::StatusCode::NOT_FOUND {
|
||||||
return self
|
return self
|
||||||
|
|
@ -271,16 +272,88 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
.await
|
.await
|
||||||
.map_err(|responses_err| {
|
.map_err(|responses_err| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"{} API error: {error} (chat completions unavailable; responses fallback failed: {responses_err})",
|
"{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})",
|
||||||
self.name
|
self.name
|
||||||
)
|
)
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::bail!("{} API error: {error}", self.name);
|
anyhow::bail!("{} API error ({status}): {sanitized}", self.name);
|
||||||
}
|
}
|
||||||
|
|
||||||
let chat_response: ChatResponse = response.json().await?;
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
|
|
||||||
|
chat_response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.map(|c| c.message.content)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
|
self.name
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let api_messages: Vec<Message> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: m.content.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages: api_messages,
|
||||||
|
temperature,
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = self.chat_completions_url();
|
||||||
|
let response = self
|
||||||
|
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
|
||||||
|
// Mirror chat_with_system: 404 may mean this provider uses the Responses API
|
||||||
|
if status == reqwest::StatusCode::NOT_FOUND {
|
||||||
|
// Extract system prompt and last user message for responses fallback
|
||||||
|
let system = messages.iter().find(|m| m.role == "system");
|
||||||
|
let last_user = messages.iter().rfind(|m| m.role == "user");
|
||||||
|
if let Some(user_msg) = last_user {
|
||||||
|
return self
|
||||||
|
.chat_via_responses(
|
||||||
|
api_key,
|
||||||
|
system.map(|m| m.content.as_str()),
|
||||||
|
&user_msg.content,
|
||||||
|
model,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|responses_err| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"{} API error (chat completions unavailable; responses fallback failed: {responses_err})",
|
||||||
|
self.name
|
||||||
|
)
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return Err(super::api_error(&self.name, response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
|
|
||||||
chat_response
|
chat_response
|
||||||
.choices
|
.choices
|
||||||
|
|
@ -357,14 +430,14 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn response_deserializes() {
|
fn response_deserializes() {
|
||||||
let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
|
let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
|
||||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(resp.choices[0].message.content, "Hello from Venice!");
|
assert_eq!(resp.choices[0].message.content, "Hello from Venice!");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_empty_choices() {
|
fn response_empty_choices() {
|
||||||
let json = r#"{"choices":[]}"#;
|
let json = r#"{"choices":[]}"#;
|
||||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert!(resp.choices.is_empty());
|
assert!(resp.choices.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ pub mod reliable;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
pub use traits::Provider;
|
pub use traits::{ChatMessage, Provider};
|
||||||
|
|
||||||
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
||||||
use reliable::ReliableProvider;
|
use reliable::ReliableProvider;
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::providers::traits::Provider;
|
use crate::providers::traits::{ChatMessage, Provider};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -22,7 +22,7 @@ struct Message {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ChatResponse {
|
struct ApiChatResponse {
|
||||||
choices: Vec<Choice>,
|
choices: Vec<Choice>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -112,7 +112,57 @@ impl Provider for OpenRouterProvider {
|
||||||
return Err(super::api_error("OpenRouter", response).await);
|
return Err(super::api_error("OpenRouter", response).await);
|
||||||
}
|
}
|
||||||
|
|
||||||
let chat_response: ChatResponse = response.json().await?;
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
|
|
||||||
|
chat_response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.map(|c| c.message.content)
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let api_key = self.api_key.as_ref()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||||
|
|
||||||
|
let api_messages: Vec<Message> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: m.content.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages: api_messages,
|
||||||
|
temperature,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
|
.header("Authorization", format!("Bearer {api_key}"))
|
||||||
|
.header(
|
||||||
|
"HTTP-Referer",
|
||||||
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
)
|
||||||
|
.header("X-Title", "ZeroClaw")
|
||||||
|
.json(&request)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(super::api_error("OpenRouter", response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
|
|
||||||
chat_response
|
chat_response
|
||||||
.choices
|
.choices
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use super::traits::ChatMessage;
|
||||||
use super::Provider;
|
use super::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
@ -121,6 +122,68 @@ impl Provider for ReliableProvider {
|
||||||
|
|
||||||
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
|
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let mut failures = Vec::new();
|
||||||
|
|
||||||
|
for (provider_name, provider) in &self.providers {
|
||||||
|
let mut backoff_ms = self.base_backoff_ms;
|
||||||
|
|
||||||
|
for attempt in 0..=self.max_retries {
|
||||||
|
match provider
|
||||||
|
.chat_with_history(messages, model, temperature)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => {
|
||||||
|
if attempt > 0 {
|
||||||
|
tracing::info!(
|
||||||
|
provider = provider_name,
|
||||||
|
attempt,
|
||||||
|
"Provider recovered after retries"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
let non_retryable = is_non_retryable(&e);
|
||||||
|
failures.push(format!(
|
||||||
|
"{provider_name} attempt {}/{}: {e}",
|
||||||
|
attempt + 1,
|
||||||
|
self.max_retries + 1
|
||||||
|
));
|
||||||
|
|
||||||
|
if non_retryable {
|
||||||
|
tracing::warn!(
|
||||||
|
provider = provider_name,
|
||||||
|
"Non-retryable error, switching provider"
|
||||||
|
);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt < self.max_retries {
|
||||||
|
tracing::warn!(
|
||||||
|
provider = provider_name,
|
||||||
|
attempt = attempt + 1,
|
||||||
|
max_retries = self.max_retries,
|
||||||
|
"Provider call failed, retrying"
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
||||||
|
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!(provider = provider_name, "Switching to fallback provider");
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -151,6 +214,19 @@ mod tests {
|
||||||
}
|
}
|
||||||
Ok(self.response.to_string())
|
Ok(self.response.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
_messages: &[ChatMessage],
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
|
if attempt <= self.fail_until_attempt {
|
||||||
|
anyhow::bail!(self.error);
|
||||||
|
}
|
||||||
|
Ok(self.response.to_string())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
@ -330,4 +406,73 @@ mod tests {
|
||||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
||||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_with_history_retries_then_recovers() {
|
||||||
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let provider = ReliableProvider::new(
|
||||||
|
vec![(
|
||||||
|
"primary".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&calls),
|
||||||
|
fail_until_attempt: 1,
|
||||||
|
response: "history ok",
|
||||||
|
error: "temporary",
|
||||||
|
}),
|
||||||
|
)],
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage::system("system"),
|
||||||
|
ChatMessage::user("hello"),
|
||||||
|
];
|
||||||
|
let result = provider
|
||||||
|
.chat_with_history(&messages, "test", 0.0)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "history ok");
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_with_history_falls_back() {
|
||||||
|
let primary_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let fallback_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
|
||||||
|
let provider = ReliableProvider::new(
|
||||||
|
vec![
|
||||||
|
(
|
||||||
|
"primary".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&primary_calls),
|
||||||
|
fail_until_attempt: usize::MAX,
|
||||||
|
response: "never",
|
||||||
|
error: "primary down",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"fallback".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&fallback_calls),
|
||||||
|
fail_until_attempt: 0,
|
||||||
|
response: "fallback ok",
|
||||||
|
error: "fallback err",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage::user("hello")];
|
||||||
|
let result = provider
|
||||||
|
.chat_with_history(&messages, "test", 0.0)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result, "fallback ok");
|
||||||
|
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
||||||
|
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use super::traits::ChatMessage;
|
||||||
use super::Provider;
|
use super::Provider;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
@ -112,6 +113,19 @@ impl Provider for RouterProvider {
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let (provider_idx, resolved_model) = self.resolve(model);
|
||||||
|
let (_, provider) = &self.providers[provider_idx];
|
||||||
|
provider
|
||||||
|
.chat_with_history(messages, &resolved_model, temperature)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
for (name, provider) in &self.providers {
|
for (name, provider) in &self.providers {
|
||||||
tracing::info!(provider = name, "Warming up routed provider");
|
tracing::info!(provider = name, "Warming up routed provider");
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,86 @@
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
/// A single message in a conversation.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ChatMessage {
|
||||||
|
pub role: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatMessage {
|
||||||
|
pub fn system(content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "system".into(),
|
||||||
|
content: content.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn user(content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "user".into(),
|
||||||
|
content: content.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn assistant(content: impl Into<String>) -> Self {
|
||||||
|
Self {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: content.into(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A tool call requested by the LLM.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolCall {
|
||||||
|
pub id: String,
|
||||||
|
pub name: String,
|
||||||
|
pub arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// An LLM response that may contain text, tool calls, or both.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct ChatResponse {
|
||||||
|
/// Text content of the response (may be empty if only tool calls).
|
||||||
|
pub text: Option<String>,
|
||||||
|
/// Tool calls requested by the LLM.
|
||||||
|
pub tool_calls: Vec<ToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ChatResponse {
|
||||||
|
/// True when the LLM wants to invoke at least one tool.
|
||||||
|
pub fn has_tool_calls(&self) -> bool {
|
||||||
|
!self.tool_calls.is_empty()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convenience: return text content or empty string.
|
||||||
|
pub fn text_or_empty(&self) -> &str {
|
||||||
|
self.text.as_deref().unwrap_or("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A tool result to feed back to the LLM.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ToolResultMessage {
|
||||||
|
pub tool_call_id: String,
|
||||||
|
pub content: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// A message in a multi-turn conversation, including tool interactions.
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
#[serde(tag = "type")]
|
||||||
|
pub enum ConversationMessage {
|
||||||
|
/// Regular chat message (system, user, assistant).
|
||||||
|
Chat(ChatMessage),
|
||||||
|
/// Tool calls from the assistant (stored for history fidelity).
|
||||||
|
AssistantToolCalls {
|
||||||
|
text: Option<String>,
|
||||||
|
tool_calls: Vec<ToolCall>,
|
||||||
|
},
|
||||||
|
/// Result of a tool execution, fed back to the LLM.
|
||||||
|
ToolResult(ToolResultMessage),
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
|
|
@ -15,9 +97,95 @@ pub trait Provider: Send + Sync {
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String>;
|
) -> anyhow::Result<String>;
|
||||||
|
|
||||||
|
/// Multi-turn conversation. Default implementation extracts the last user
|
||||||
|
/// message and delegates to `chat_with_system`.
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let system = messages
|
||||||
|
.iter()
|
||||||
|
.find(|m| m.role == "system")
|
||||||
|
.map(|m| m.content.as_str());
|
||||||
|
let last_user = messages
|
||||||
|
.iter()
|
||||||
|
.rfind(|m| m.role == "user")
|
||||||
|
.map(|m| m.content.as_str())
|
||||||
|
.unwrap_or("");
|
||||||
|
self.chat_with_system(system, last_user, model, temperature)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
||||||
/// Default implementation is a no-op; providers with HTTP clients should override.
|
/// Default implementation is a no-op; providers with HTTP clients should override.
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_message_constructors() {
|
||||||
|
let sys = ChatMessage::system("Be helpful");
|
||||||
|
assert_eq!(sys.role, "system");
|
||||||
|
assert_eq!(sys.content, "Be helpful");
|
||||||
|
|
||||||
|
let user = ChatMessage::user("Hello");
|
||||||
|
assert_eq!(user.role, "user");
|
||||||
|
|
||||||
|
let asst = ChatMessage::assistant("Hi there");
|
||||||
|
assert_eq!(asst.role, "assistant");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn chat_response_helpers() {
|
||||||
|
let empty = ChatResponse {
|
||||||
|
text: None,
|
||||||
|
tool_calls: vec![],
|
||||||
|
};
|
||||||
|
assert!(!empty.has_tool_calls());
|
||||||
|
assert_eq!(empty.text_or_empty(), "");
|
||||||
|
|
||||||
|
let with_tools = ChatResponse {
|
||||||
|
text: Some("Let me check".into()),
|
||||||
|
tool_calls: vec![ToolCall {
|
||||||
|
id: "1".into(),
|
||||||
|
name: "shell".into(),
|
||||||
|
arguments: "{}".into(),
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
assert!(with_tools.has_tool_calls());
|
||||||
|
assert_eq!(with_tools.text_or_empty(), "Let me check");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_call_serialization() {
|
||||||
|
let tc = ToolCall {
|
||||||
|
id: "call_123".into(),
|
||||||
|
name: "file_read".into(),
|
||||||
|
arguments: r#"{"path":"test.txt"}"#.into(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&tc).unwrap();
|
||||||
|
assert!(json.contains("call_123"));
|
||||||
|
assert!(json.contains("file_read"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn conversation_message_variants() {
|
||||||
|
let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
|
||||||
|
let json = serde_json::to_string(&chat).unwrap();
|
||||||
|
assert!(json.contains("\"type\":\"Chat\""));
|
||||||
|
|
||||||
|
let tool_result = ConversationMessage::ToolResult(ToolResultMessage {
|
||||||
|
tool_call_id: "1".into(),
|
||||||
|
content: "done".into(),
|
||||||
|
});
|
||||||
|
let json = serde_json::to_string(&tool_result).unwrap();
|
||||||
|
assert!(json.contains("\"type\":\"ToolResult\""));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue