312 lines
11 KiB
Rust
312 lines
11 KiB
Rust
use crate::providers::{ChatMessage, ChatResponse, ConversationMessage, ToolResultMessage};
|
|
use crate::tools::{Tool, ToolSpec};
|
|
use serde_json::Value;
|
|
use std::fmt::Write;
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ParsedToolCall {
|
|
pub name: String,
|
|
pub arguments: Value,
|
|
pub tool_call_id: Option<String>,
|
|
}
|
|
|
|
#[derive(Debug, Clone)]
|
|
pub struct ToolExecutionResult {
|
|
pub name: String,
|
|
pub output: String,
|
|
pub success: bool,
|
|
pub tool_call_id: Option<String>,
|
|
}
|
|
|
|
pub trait ToolDispatcher: Send + Sync {
|
|
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>);
|
|
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage;
|
|
fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String;
|
|
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage>;
|
|
fn should_send_tool_specs(&self) -> bool;
|
|
}
|
|
|
|
#[derive(Default)]
|
|
pub struct XmlToolDispatcher;
|
|
|
|
impl XmlToolDispatcher {
|
|
fn parse_xml_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
|
let mut text_parts = Vec::new();
|
|
let mut calls = Vec::new();
|
|
let mut remaining = response;
|
|
|
|
while let Some(start) = remaining.find("<tool_call>") {
|
|
let before = &remaining[..start];
|
|
if !before.trim().is_empty() {
|
|
text_parts.push(before.trim().to_string());
|
|
}
|
|
|
|
if let Some(end) = remaining[start..].find("</tool_call>") {
|
|
let inner = &remaining[start + 11..start + end];
|
|
match serde_json::from_str::<Value>(inner.trim()) {
|
|
Ok(parsed) => {
|
|
let name = parsed
|
|
.get("name")
|
|
.and_then(Value::as_str)
|
|
.unwrap_or("")
|
|
.to_string();
|
|
if name.is_empty() {
|
|
remaining = &remaining[start + end + 12..];
|
|
continue;
|
|
}
|
|
let arguments = parsed
|
|
.get("arguments")
|
|
.cloned()
|
|
.unwrap_or_else(|| Value::Object(serde_json::Map::new()));
|
|
calls.push(ParsedToolCall {
|
|
name,
|
|
arguments,
|
|
tool_call_id: None,
|
|
});
|
|
}
|
|
Err(e) => {
|
|
tracing::warn!("Malformed <tool_call> JSON: {e}");
|
|
}
|
|
}
|
|
remaining = &remaining[start + end + 12..];
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
if !remaining.trim().is_empty() {
|
|
text_parts.push(remaining.trim().to_string());
|
|
}
|
|
|
|
(text_parts.join("\n"), calls)
|
|
}
|
|
|
|
pub fn tool_specs(tools: &[Box<dyn Tool>]) -> Vec<ToolSpec> {
|
|
tools.iter().map(|tool| tool.spec()).collect()
|
|
}
|
|
}
|
|
|
|
impl ToolDispatcher for XmlToolDispatcher {
|
|
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
|
|
let text = response.text_or_empty();
|
|
Self::parse_xml_tool_calls(text)
|
|
}
|
|
|
|
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
|
|
let mut content = String::new();
|
|
for result in results {
|
|
let status = if result.success { "ok" } else { "error" };
|
|
let _ = writeln!(
|
|
content,
|
|
"<tool_result name=\"{}\" status=\"{}\">\n{}\n</tool_result>",
|
|
result.name, status, result.output
|
|
);
|
|
}
|
|
ConversationMessage::Chat(ChatMessage::user(format!("[Tool results]\n{content}")))
|
|
}
|
|
|
|
fn prompt_instructions(&self, tools: &[Box<dyn Tool>]) -> String {
|
|
let mut instructions = String::new();
|
|
instructions.push_str("## Tool Use Protocol\n\n");
|
|
instructions
|
|
.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
|
instructions.push_str(
|
|
"```\n<tool_call>\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n</tool_call>\n```\n\n",
|
|
);
|
|
instructions.push_str("### Available Tools\n\n");
|
|
|
|
for tool in tools {
|
|
let _ = writeln!(
|
|
instructions,
|
|
"- **{}**: {}\n Parameters: `{}`",
|
|
tool.name(),
|
|
tool.description(),
|
|
tool.parameters_schema()
|
|
);
|
|
}
|
|
|
|
instructions
|
|
}
|
|
|
|
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
|
|
history
|
|
.iter()
|
|
.flat_map(|msg| match msg {
|
|
ConversationMessage::Chat(chat) => vec![chat.clone()],
|
|
ConversationMessage::AssistantToolCalls { text, .. } => {
|
|
vec![ChatMessage::assistant(text.clone().unwrap_or_default())]
|
|
}
|
|
ConversationMessage::ToolResults(results) => {
|
|
let mut content = String::new();
|
|
for result in results {
|
|
let _ = writeln!(
|
|
content,
|
|
"<tool_result id=\"{}\">\n{}\n</tool_result>",
|
|
result.tool_call_id, result.content
|
|
);
|
|
}
|
|
vec![ChatMessage::user(format!("[Tool results]\n{content}"))]
|
|
}
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn should_send_tool_specs(&self) -> bool {
|
|
false
|
|
}
|
|
}
|
|
|
|
pub struct NativeToolDispatcher;
|
|
|
|
impl ToolDispatcher for NativeToolDispatcher {
|
|
fn parse_response(&self, response: &ChatResponse) -> (String, Vec<ParsedToolCall>) {
|
|
let text = response.text.clone().unwrap_or_default();
|
|
let calls = response
|
|
.tool_calls
|
|
.iter()
|
|
.map(|tc| ParsedToolCall {
|
|
name: tc.name.clone(),
|
|
arguments: serde_json::from_str(&tc.arguments)
|
|
.unwrap_or_else(|_| Value::Object(serde_json::Map::new())),
|
|
tool_call_id: Some(tc.id.clone()),
|
|
})
|
|
.collect();
|
|
(text, calls)
|
|
}
|
|
|
|
fn format_results(&self, results: &[ToolExecutionResult]) -> ConversationMessage {
|
|
let messages = results
|
|
.iter()
|
|
.map(|result| ToolResultMessage {
|
|
tool_call_id: result
|
|
.tool_call_id
|
|
.clone()
|
|
.unwrap_or_else(|| "unknown".to_string()),
|
|
content: result.output.clone(),
|
|
})
|
|
.collect();
|
|
ConversationMessage::ToolResults(messages)
|
|
}
|
|
|
|
fn prompt_instructions(&self, _tools: &[Box<dyn Tool>]) -> String {
|
|
String::new()
|
|
}
|
|
|
|
fn to_provider_messages(&self, history: &[ConversationMessage]) -> Vec<ChatMessage> {
|
|
history
|
|
.iter()
|
|
.flat_map(|msg| match msg {
|
|
ConversationMessage::Chat(chat) => vec![chat.clone()],
|
|
ConversationMessage::AssistantToolCalls { text, tool_calls } => {
|
|
let payload = serde_json::json!({
|
|
"content": text,
|
|
"tool_calls": tool_calls,
|
|
});
|
|
vec![ChatMessage::assistant(payload.to_string())]
|
|
}
|
|
ConversationMessage::ToolResults(results) => results
|
|
.iter()
|
|
.map(|result| {
|
|
ChatMessage::tool(
|
|
serde_json::json!({
|
|
"tool_call_id": result.tool_call_id,
|
|
"content": result.content,
|
|
})
|
|
.to_string(),
|
|
)
|
|
})
|
|
.collect(),
|
|
})
|
|
.collect()
|
|
}
|
|
|
|
fn should_send_tool_specs(&self) -> bool {
|
|
true
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn xml_dispatcher_parses_tool_calls() {
|
|
let response = ChatResponse {
|
|
text: Some(
|
|
"Checking\n<tool_call>{\"name\":\"shell\",\"arguments\":{\"command\":\"ls\"}}</tool_call>"
|
|
.into(),
|
|
),
|
|
tool_calls: vec![],
|
|
};
|
|
let dispatcher = XmlToolDispatcher;
|
|
let (_, calls) = dispatcher.parse_response(&response);
|
|
assert_eq!(calls.len(), 1);
|
|
assert_eq!(calls[0].name, "shell");
|
|
}
|
|
|
|
#[test]
|
|
fn native_dispatcher_roundtrip() {
|
|
let response = ChatResponse {
|
|
text: Some("ok".into()),
|
|
tool_calls: vec![crate::providers::ToolCall {
|
|
id: "tc1".into(),
|
|
name: "file_read".into(),
|
|
arguments: "{\"path\":\"a.txt\"}".into(),
|
|
}],
|
|
};
|
|
let dispatcher = NativeToolDispatcher;
|
|
let (_, calls) = dispatcher.parse_response(&response);
|
|
assert_eq!(calls.len(), 1);
|
|
assert_eq!(calls[0].tool_call_id.as_deref(), Some("tc1"));
|
|
|
|
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
|
name: "file_read".into(),
|
|
output: "hello".into(),
|
|
success: true,
|
|
tool_call_id: Some("tc1".into()),
|
|
}]);
|
|
match msg {
|
|
ConversationMessage::ToolResults(results) => {
|
|
assert_eq!(results.len(), 1);
|
|
assert_eq!(results[0].tool_call_id, "tc1");
|
|
}
|
|
_ => panic!("expected tool results"),
|
|
}
|
|
}
|
|
|
|
#[test]
|
|
fn xml_format_results_contains_tool_result_tags() {
|
|
let dispatcher = XmlToolDispatcher;
|
|
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
|
name: "shell".into(),
|
|
output: "ok".into(),
|
|
success: true,
|
|
tool_call_id: None,
|
|
}]);
|
|
let rendered = match msg {
|
|
ConversationMessage::Chat(chat) => chat.content,
|
|
_ => String::new(),
|
|
};
|
|
assert!(rendered.contains("<tool_result"));
|
|
assert!(rendered.contains("shell"));
|
|
}
|
|
|
|
#[test]
|
|
fn native_format_results_keeps_tool_call_id() {
|
|
let dispatcher = NativeToolDispatcher;
|
|
let msg = dispatcher.format_results(&[ToolExecutionResult {
|
|
name: "shell".into(),
|
|
output: "ok".into(),
|
|
success: true,
|
|
tool_call_id: Some("tc-1".into()),
|
|
}]);
|
|
|
|
match msg {
|
|
ConversationMessage::ToolResults(results) => {
|
|
assert_eq!(results.len(), 1);
|
|
assert_eq!(results[0].tool_call_id, "tc-1");
|
|
}
|
|
_ => panic!("expected ToolResults variant"),
|
|
}
|
|
}
|
|
}
|