From f322360248cd07f801c905ffd5b9f7c4e75567c6 Mon Sep 17 00:00:00 2001 From: Vernon Stinebaker Date: Tue, 17 Feb 2026 12:05:08 +0800 Subject: [PATCH] feat(providers): add native tool-call API support via chat_with_tools Add chat_with_tools() to the Provider trait with a default fallback to chat_with_history(). Implement native tool calling in OpenRouterProvider, reusing existing NativeChatRequest/NativeChatResponse structs. Wire the agent loop to use native tool calls when the provider supports them, falling back to XML-based parsing otherwise. Changes are purely additive to traits.rs and openrouter.rs. The only deletions (36 lines) are within run_tool_call_loop() in loop_.rs where the LLM call section was replaced with a branching if/else for native vs XML tool calling. Includes 5 new tests covering: - chat_with_tools error path (missing API key) - NativeChatResponse deserialization (tool calls only, mixed) - parse_native_response conversion to ChatResponse - tools_to_openai_format schema validation --- src/agent/loop_.rs | 163 ++++++++++++++++++++++++++------- src/providers/openrouter.rs | 178 ++++++++++++++++++++++++++++++++++++ src/providers/traits.rs | 17 ++++ 3 files changed, 325 insertions(+), 33 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 4495995..9a21395 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -27,6 +27,23 @@ const COMPACTION_MAX_SOURCE_CHARS: usize = 12_000; /// Max characters retained in stored compaction summary. const COMPACTION_MAX_SUMMARY_CHARS: usize = 2_000; +/// Convert a tool registry to OpenAI function-calling format for native tool support. +fn tools_to_openai_format(tools_registry: &[Box]) -> Vec { + tools_registry + .iter() + .map(|tool| { + serde_json::json!({ + "type": "function", + "function": { + "name": tool.name(), + "description": tool.description(), + "parameters": tool.parameters_schema() + } + }) + }) + .collect() +} + fn autosave_memory_key(prefix: &str) -> String { format!("{prefix}_{}", Uuid::new_v4()) } @@ -454,6 +471,14 @@ pub(crate) async fn run_tool_call_loop( temperature: f64, silent: bool, ) -> Result { + // Build native tool definitions once if the provider supports them. + let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty(); + let tool_definitions = if use_native_tools { + tools_to_openai_format(tools_registry) + } else { + Vec::new() + }; + for _iteration in 0..MAX_TOOL_ITERATIONS { observer.record_event(&ObserverEvent::LlmRequest { provider: provider_name.to_string(), @@ -462,49 +487,95 @@ pub(crate) async fn run_tool_call_loop( }); let llm_started_at = Instant::now(); - let response = match provider - .chat_with_history(history, model, temperature) - .await - { - Ok(resp) => { - observer.record_event(&ObserverEvent::LlmResponse { - provider: provider_name.to_string(), - model: model.to_string(), - duration: llm_started_at.elapsed(), - success: true, - error_message: None, - }); - resp + + // Choose between native tool-call API and prompt-based tool use. + let (response_text, parsed_text, tool_calls, assistant_history_content) = if use_native_tools { + match provider + .chat_with_tools(history, &tool_definitions, model, temperature) + .await + { + Ok(resp) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: true, + error_message: None, + }); + let response_text = resp.text_or_empty().to_string(); + let mut calls = parse_structured_tool_calls(&resp.tool_calls); + let mut parsed_text = String::new(); + + if calls.is_empty() { + let (fallback_text, fallback_calls) = parse_tool_calls(&response_text); + if !fallback_text.is_empty() { + parsed_text = fallback_text; + } + calls = fallback_calls; + } + + let assistant_history_content = if resp.tool_calls.is_empty() { + response_text.clone() + } else { + build_assistant_history_with_tool_calls(&response_text, &resp.tool_calls) + }; + + (response_text, parsed_text, calls, assistant_history_content) + } + Err(e) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: false, + error_message: Some(crate::providers::sanitize_api_error(&e.to_string())), + }); + return Err(e); + } } - Err(e) => { - observer.record_event(&ObserverEvent::LlmResponse { - provider: provider_name.to_string(), - model: model.to_string(), - duration: llm_started_at.elapsed(), - success: false, - error_message: Some(crate::providers::sanitize_api_error(&e.to_string())), - }); - return Err(e); + } else { + match provider.chat_with_history(history, model, temperature).await { + Ok(resp) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: true, + error_message: None, + }); + let response_text = resp; + let assistant_history_content = response_text.clone(); + let (parsed_text, calls) = parse_tool_calls(&response_text); + (response_text, parsed_text, calls, assistant_history_content) + } + Err(e) => { + observer.record_event(&ObserverEvent::LlmResponse { + provider: provider_name.to_string(), + model: model.to_string(), + duration: llm_started_at.elapsed(), + success: false, + error_message: Some(crate::providers::sanitize_api_error(&e.to_string())), + }); + return Err(e); + } } }; - let response_text = response; - let assistant_history_content = response_text.clone(); - let (parsed_text, tool_calls) = parse_tool_calls(&response_text); + let display_text = if parsed_text.is_empty() { + response_text.clone() + } else { + parsed_text + }; if tool_calls.is_empty() { // No tool calls — this is the final response history.push(ChatMessage::assistant(response_text.clone())); - return Ok(if parsed_text.is_empty() { - response_text - } else { - parsed_text - }); + return Ok(display_text); } // Print any text the LLM produced alongside tool calls (unless silent) - if !silent && !parsed_text.is_empty() { - print!("{parsed_text}"); + if !silent && !display_text.is_empty() { + print!("{display_text}"); let _ = std::io::stdout().flush(); } @@ -550,7 +621,7 @@ pub(crate) async fn run_tool_call_loop( } // Add assistant message with tool calls + tool results to history - history.push(ChatMessage::assistant(assistant_history_content.clone())); + history.push(ChatMessage::assistant(assistant_history_content)); history.push(ChatMessage::user(format!("[Tool results]\n{tool_results}"))); } @@ -1309,6 +1380,32 @@ I will now call the tool with this payload: assert!(instructions.contains("file_write")); } + #[test] + fn tools_to_openai_format_produces_valid_schema() { + use crate::security::SecurityPolicy; + let security = Arc::new(SecurityPolicy::from_config( + &crate::config::AutonomyConfig::default(), + std::path::Path::new("/tmp"), + )); + let tools = tools::default_tools(security); + let formatted = tools_to_openai_format(&tools); + + assert!(!formatted.is_empty()); + for tool_json in &formatted { + assert_eq!(tool_json["type"], "function"); + assert!(tool_json["function"]["name"].is_string()); + assert!(tool_json["function"]["description"].is_string()); + assert!(!tool_json["function"]["name"].as_str().unwrap().is_empty()); + } + // Verify known tools are present + let names: Vec<&str> = formatted + .iter() + .filter_map(|t| t["function"]["name"].as_str()) + .collect(); + assert!(names.contains(&"shell")); + assert!(names.contains(&"file_read")); + } + #[test] fn trim_history_preserves_system_prompt() { let mut history = vec![ChatMessage::system("system prompt")]; diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index 3a02e2d..8e84524 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -401,6 +401,90 @@ impl Provider for OpenRouterProvider { fn supports_native_tools(&self) -> bool { true } + + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let api_key = self.api_key.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." + ) + })?; + + // Convert tool JSON values to NativeToolSpec + let native_tools: Option> = if tools.is_empty() { + None + } else { + let specs: Vec = tools + .iter() + .filter_map(|t| { + let func = t.get("function")?; + Some(NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: func.get("name")?.as_str()?.to_string(), + description: func + .get("description") + .and_then(|d| d.as_str()) + .unwrap_or("") + .to_string(), + parameters: func + .get("parameters") + .cloned() + .unwrap_or(serde_json::json!({})), + }, + }) + }) + .collect(); + if specs.is_empty() { + None + } else { + Some(specs) + } + }; + + // Convert ChatMessage to NativeMessage, preserving structured assistant/tool entries + // when history contains native tool-call metadata. + let native_messages = Self::convert_messages(messages); + + let native_request = NativeChatRequest { + model: model.to_string(), + messages: native_messages, + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let response = self + .client + .post("https://openrouter.ai/api/v1/chat/completions") + .header("Authorization", format!("Bearer {api_key}")) + .header( + "HTTP-Referer", + "https://github.com/theonlyhennygod/zeroclaw", + ) + .header("X-Title", "ZeroClaw") + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenRouter", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?; + Ok(Self::parse_native_response(message)) + } } #[cfg(test)] @@ -534,4 +618,98 @@ mod tests { assert!(response.choices.is_empty()); } + + #[tokio::test] + async fn chat_with_tools_fails_without_key() { + let provider = OpenRouterProvider::new(None); + let messages = vec![ChatMessage { + role: "user".into(), + content: "What is the date?".into(), + }]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": {"type": "object", "properties": {"command": {"type": "string"}}} + } + })]; + + let result = provider + .chat_with_tools(&messages, &tools, "deepseek/deepseek-chat", 0.5) + .await; + + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[test] + fn native_response_deserializes_with_tool_calls() { + let json = r#"{ + "choices":[{ + "message":{ + "content":null, + "tool_calls":[ + {"id":"call_123","type":"function","function":{"name":"get_price","arguments":"{\"symbol\":\"BTC\"}"}} + ] + } + }] + }"#; + + let response: NativeChatResponse = serde_json::from_str(json).unwrap(); + + assert_eq!(response.choices.len(), 1); + let message = &response.choices[0].message; + assert!(message.content.is_none()); + let tool_calls = message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].id.as_deref(), Some("call_123")); + assert_eq!(tool_calls[0].function.name, "get_price"); + assert_eq!(tool_calls[0].function.arguments, "{\"symbol\":\"BTC\"}"); + } + + #[test] + fn native_response_deserializes_with_text_and_tool_calls() { + let json = r#"{ + "choices":[{ + "message":{ + "content":"I'll get that for you.", + "tool_calls":[ + {"id":"call_456","type":"function","function":{"name":"shell","arguments":"{\"command\":\"date\"}"}} + ] + } + }] + }"#; + + let response: NativeChatResponse = serde_json::from_str(json).unwrap(); + + assert_eq!(response.choices.len(), 1); + let message = &response.choices[0].message; + assert_eq!(message.content.as_deref(), Some("I'll get that for you.")); + let tool_calls = message.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!(tool_calls[0].function.name, "shell"); + } + + #[test] + fn parse_native_response_converts_to_chat_response() { + let message = NativeResponseMessage { + content: Some("Here you go.".into()), + tool_calls: Some(vec![NativeToolCall { + id: Some("call_789".into()), + kind: Some("function".into()), + function: NativeFunctionCall { + name: "file_read".into(), + arguments: r#"{"path":"test.txt"}"#.into(), + }, + }]), + }; + + let response = OpenRouterProvider::parse_native_response(message); + + assert_eq!(response.text.as_deref(), Some("Here you go.")); + assert_eq!(response.tool_calls.len(), 1); + assert_eq!(response.tool_calls[0].id, "call_789"); + assert_eq!(response.tool_calls[0].name, "file_read"); + } } diff --git a/src/providers/traits.rs b/src/providers/traits.rs index 2117e57..7c61769 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -170,6 +170,23 @@ pub trait Provider: Send + Sync { async fn warmup(&self) -> anyhow::Result<()> { Ok(()) } + + /// Chat with tool definitions for native function calling support. + /// The default implementation falls back to chat_with_history and returns + /// an empty tool_calls vector (prompt-based tool use only). + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + _tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let text = self.chat_with_history(messages, model, temperature).await?; + Ok(ChatResponse { + text: Some(text), + tool_calls: Vec::new(), + }) + } } #[cfg(test)]