use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage}; use crate::tools::{Tool, ToolSpec}; use serde_json::Value; use std::fmt::Write; #[derive(Debug, Clone)] pub struct ParsedToolCall { pub name: String, pub arguments: Value, pub tool_call_id: Option, } #[derive(Debug, Clone)] pub struct ToolExecutionResult { pub name: String, pub output: String, pub success: bool, pub tool_call_id: Option, } pub trait ToolDispatcher: Send + Sync { fn parse_response(&self, response: &ChatResponse) -> (String, Vec); fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage; fn prompt_instructions(&self, tools: &[Box]) -> String; fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec; fn should_send_tool_specs(&self) -> bool; } #[derive(Default)] pub struct XmlToolDispatcher; impl XmlToolDispatcher { fn parse_xml_tool_calls(response: &str) -> (String, Vec) { let mut text_parts = Vec::new(); let mut calls = Vec::new(); let mut remaining = response; while let Some(start) = remaining.find("") { let before = &remaining[..start]; if !before.trim().is_empty() { text_parts.push(before.trim().to_string()); } if let Some(end) = remaining[start..].find("") { let inner = &remaining[start + 11..start + end]; match serde_json::from_str::(inner.trim()) { Ok(parsed) => { let name = parsed .get("name") .and_then(Value::as_str) .unwrap_or("") .to_string(); if name.is_empty() { remaining = &remaining[start + end + 12..]; continue; } let arguments = parsed .get("arguments") .cloned() .unwrap_or_else(|| Value::Object(serde_json::Map::new())); calls.push(ParsedToolCall { name, arguments, tool_call_id: None, }); } Err(e) => { tracing::warn!("Malformed JSON: {e}"); } } remaining = &remaining[start + end + 12..]; } else { break; } } if !remaining.trim().is_empty() { text_parts.push(remaining.trim().to_string()); } (text_parts.join("\n"), calls) } pub fn tool_specs(tools: &[Box]) -> Vec { tools.iter().map(|tool| tool.spec()).collect() } } impl ToolDispatcher for XmlToolDispatcher { fn parse_response(&self, response: &ChatResponse) -> (String, Vec) { let text = response.text_or_empty(); Self::parse_xml_tool_calls(text) } fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { let mut content = String::new(); for result in results { let status = if result.success { "ok" } else { "error" }; let _ = writeln!( content, "\n{}\n", result.name, status, result.output ); } ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}"))) } fn prompt_instructions(&self, tools: &[Box]) -> String { let mut instructions = String::new(); instructions.push_str("## Tool Use Protocol\n\n"); instructions .push_str("To use a tool, wrap a JSON object in tags:\n\n"); instructions.push_str( "```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n", ); instructions.push_str("### Available Tools\n\n"); for tool in tools { let _ = writeln!( instructions, "- **{}**: {}\n Parameters: `{}`", tool.name(), tool.description(), tool.parameters_schema() ); } instructions } fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec { history .iter() .flat_map(|msg| match msg { ConversationMessage::Chat(chat) => vec![chat.clone()], ConversationMessage::AssistantToolCalls { text, .. } => { vec![ChatMessage::assistant(text.clone().unwrap_or_default())] } ConversationMessage::ToolResults(results) => { let mut content = String::new(); for result in results { let _ = writeln!( content, "\n{}\n", result.tool_call_id, result.content ); } vec![ChatMessage::user(format!("[Tool results]\n{content}"))] } }) .collect() } fn should_send_tool_specs(&self) -> bool { false } } pub struct NativeToolDispatcher; impl ToolDispatcher for NativeToolDispatcher { fn parse_response(&self, response: &ChatResponse) -> (String, Vec) { let text = response.text.clone().unwrap_or_default(); let calls = response .tool_calls .iter() .map(|tc| ParsedToolCall { name: tc.name.clone(), arguments: serde_json::from_str(&tc.arguments) .unwrap_or_else(|_| Value::Object(serde_json::Map::new())), tool_call_id: Some(tc.id.clone()), }) .collect(); (text, calls) } fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage { let messages = results .iter() .map(|result| ToolResultMessage { tool_call_id: result .tool_call_id .clone() .unwrap_or_else(|| "unknown".to_string()), content: result.output.clone(), }) .collect(); ConversationMessage::ToolResults(messages) } fn prompt_instructions(&self, _tools: &[Box]) -> String { String::new() } fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec { history .iter() .flat_map(|msg| match msg { ConversationMessage::Chat(chat) => vec![chat.clone()], ConversationMessage::AssistantToolCalls { text, tool_calls } => { let payload = serde_json::json!({ "content": text, "tool_calls": tool_calls, }); vec![ChatMessage::assistant(payload.to_string())] } ConversationMessage::ToolResults(results) => results .iter() .map(|result| { ChatMessage::tool( serde_json::json!({ "tool_call_id": result.tool_call_id, "content": result.content, }) .to_string(), ) }) .collect(), }) .collect() } fn should_send_tool_specs(&self) -> bool { true } } #[cfg(test)] mod tests { use super::*; #[test] fn xml_dispatcher_parses_tool_calls() { let response = ChatResponse { text: Some( "Checking\n{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}" .into(), ), tool_calls: vec![], }; let dispatcher = XmlToolDispatcher; let (_, calls) = dispatcher.parse_response(&response); assert_eq!(calls.len(), 1); assert_eq!(calls[0].name, "shell"); } #[test] fn native_dispatcher_roundtrip() { let response = ChatResponse { text: Some("ok".into()), tool_calls: vec![crate::providers::ToolCall { id: "tc1".into(), name: "file_read".into(), arguments: "{\"path\":\"a.txt\"}".into(), }], }; let dispatcher = NativeToolDispatcher; let (_, calls) = dispatcher.parse_response(&response); assert_eq!(calls.len(), 1); assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1")); let msg = dispatcher.format_results(&[ToolExecutionResult { name: "file_read".into(), output: "hello".into(), success: true, tool_call_id: Some("tc1".into()), }]); match msg { ConversationMessage::ToolResults(results) => { assert_eq!(results.len(), 1); assert_eq!(results[0].tool_call_id, "tc1"); } _ => panic!("expected tool results"), } } #[test] fn xml_format_results_contains_tool_result_tags() { let dispatcher = XmlToolDispatcher; let msg = dispatcher.format_results(&[ToolExecutionResult { name: "shell".into(), output: "ok".into(), success: true, tool_call_id: None, }]); let rendered = match msg { ConversationMessage::Chat(chat) => chat.content, _ => String::new(), }; assert!(rendered.contains(" { assert_eq!(results.len(), 1); assert_eq!(results[0].tool_call_id, "tc-1"); } _ => panic!("expected ToolResults variant"), } } }