diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 941e4d0..7aecc9a 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -644,7 +644,8 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ToolCall>) { remaining = &after_open[close_idx + close_tag.len()..]; } else { if let Some(json_end) = find_json_end(after_open) { - if let Ok(value) = serde_json::from_str::<serde_json::Value>(&after_open[..json_end]) + if let Ok(value) = + serde_json::from_str::<serde_json::Value>(&after_open[..json_end]) { let parsed_calls = parse_tool_calls_from_json_value(&value); if !parsed_calls.is_empty() { diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index e24f961..9a8cc2a 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -140,15 +140,35 @@ impl OpenAiCompatibleProvider { format!("{normalized_base}/v1/responses") } } + + fn tool_specs_to_openai_format(tools: &[crate::tools::ToolSpec]) -> Vec<serde_json::Value> { + tools + .iter() + .map(|tool| { + serde_json::json!({ + "type": "function", + "function": { + "name": tool.name, + "description": tool.description, + "parameters": tool.parameters + } + }) + }) + .collect() + } } #[derive(Debug, Serialize)] -struct ChatRequest { +struct ApiChatRequest { model: String, messages: Vec<Message>, temperature: f64, #[serde(skip_serializing_if = "Option::is_none")] stream: Option<bool>, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option<Vec<serde_json::Value>>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option<String>, } #[derive(Debug, Serialize)] @@ -189,6 +209,13 @@ impl ResponseMessage { _ => self.reasoning_content.clone().unwrap_or_default(), } } + + fn effective_content_optional(&self) -> Option<String> { + match &self.content { + Some(c) if !c.is_empty() => Some(c.clone()), + _ => self.reasoning_content.clone().filter(|c| !c.is_empty()), + } + } } #[derive(Debug, Deserialize, Serialize)] @@ -476,6 +503,12 @@ impl OpenAiCompatibleProvider { #[async_trait] impl Provider for OpenAiCompatibleProvider { + fn capabilities(&self) -> 
crate::providers::traits::ProviderCapabilities { + crate::providers::traits::ProviderCapabilities { + native_tool_calling: true, + } + } + async fn chat_with_system( &self, system_prompt: Option<&str>, @@ -504,11 +537,13 @@ impl Provider for OpenAiCompatibleProvider { content: message.to_string(), }); - let request = ChatRequest { + let request = ApiChatRequest { model: model.to_string(), messages, temperature, stream: Some(false), + tools: None, + tool_choice: None, }; let url = self.chat_completions_url(); @@ -584,11 +619,13 @@ impl Provider for OpenAiCompatibleProvider { }) .collect(); - let request = ChatRequest { + let request = ApiChatRequest { model: model.to_string(), messages: api_messages, temperature, stream: Some(false), + tools: None, + tool_choice: None, }; let url = self.chat_completions_url(); @@ -651,18 +688,106 @@ impl Provider for OpenAiCompatibleProvider { .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) } + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result<ProviderChatResponse> { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "{} API key not set. 
Run `zeroclaw onboard` or set the appropriate env var.", + self.name + ) + })?; + + let api_messages: Vec<Message> = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let request = ApiChatRequest { + model: model.to_string(), + messages: api_messages, + temperature, + stream: Some(false), + tools: if tools.is_empty() { + None + } else { + Some(tools.to_vec()) + }, + tool_choice: if tools.is_empty() { + None + } else { + Some("auto".to_string()) + }, + }; + + let url = self.chat_completions_url(); + let response = self + .apply_auth_header(self.client.post(&url).json(&request), credential) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error(&self.name, response).await); + } + + let chat_response: ApiChatResponse = response.json().await?; + let choice = chat_response + .choices + .into_iter() + .next() + .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?; + + let text = choice.message.effective_content_optional(); + let tool_calls = choice + .message + .tool_calls + .unwrap_or_default() + .into_iter() + .filter_map(|tc| { + let function = tc.function?; + let name = function.name?; + let arguments = function.arguments.unwrap_or_else(|| "{}".to_string()); + Some(ProviderToolCall { + id: uuid::Uuid::new_v4().to_string(), + name, + arguments, + }) + }) + .collect::<Vec<_>>(); + + Ok(ProviderChatResponse { text, tool_calls }) + } + async fn chat( &self, request: ProviderChatRequest<'_>, model: &str, temperature: f64, ) -> anyhow::Result<ProviderChatResponse> { + // If native tools are requested, delegate to chat_with_tools. 
+ if let Some(tools) = request.tools { + if !tools.is_empty() && self.supports_native_tools() { + let native_tools = Self::tool_specs_to_openai_format(tools); + return self + .chat_with_tools(request.messages, &native_tools, model, temperature) + .await; + } + } + let text = self .chat_with_history(request.messages, model, temperature) .await?; // Backward compatible path: chat_with_history may serialize tool_calls JSON into content. if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) { + let parsed_text = message.effective_content_optional(); let tool_calls = message .tool_calls .unwrap_or_default() @@ -680,7 +805,7 @@ impl Provider for OpenAiCompatibleProvider { .collect::<Vec<_>>(); return Ok(ProviderChatResponse { - text: message.content, + text: parsed_text, tool_calls, }); } @@ -733,11 +858,13 @@ impl Provider for OpenAiCompatibleProvider { content: message.to_string(), }); - let request = ChatRequest { + let request = ApiChatRequest { model: model.to_string(), messages, temperature, stream: Some(options.enabled), + tools: None, + tool_choice: None, }; let url = self.chat_completions_url(); @@ -863,7 +990,7 @@ mod tests { #[test] fn request_serializes_correctly() { - let req = ChatRequest { + let req = ApiChatRequest { model: "llama-3.3-70b".to_string(), messages: vec![ Message { role: "system".to_string(), content: "You are helpful.".to_string(), }, Message { role: "user".to_string(), content: "Hello!".to_string(), }, ], temperature: 0.4, stream: Some(false), + tools: None, + tool_choice: None, }; let json = serde_json::to_string(&req).unwrap(); assert!(json.contains("llama-3.3-70b")); assert!(json.contains("system")); assert!(json.contains("user")); + // tools/tool_choice should be omitted when None + assert!(!json.contains("tools")); + assert!(!json.contains("tool_choice")); } #[test] @@ -1176,6 +1308,181 @@ mod tests { assert!(result.is_ok()); } + // ══════════════════════════════════════════════════════════ + // Native tool calling tests + // ══════════════════════════════════════════════════════════ + + #[test] + fn capabilities_reports_native_tool_calling() { + let 
p = make_provider("test", "https://example.com", None); + let caps = <OpenAiCompatibleProvider as Provider>::capabilities(&p); + assert!(caps.native_tool_calling); + } + + #[test] + fn tool_specs_convert_to_openai_format() { + let specs = vec![crate::tools::ToolSpec { + name: "shell".to_string(), + description: "Run shell command".to_string(), + parameters: serde_json::json!({ + "type": "object", + "properties": {"command": {"type": "string"}}, + "required": ["command"] + }), + }]; + + let tools = OpenAiCompatibleProvider::tool_specs_to_openai_format(&specs); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0]["type"], "function"); + assert_eq!(tools[0]["function"]["name"], "shell"); + assert_eq!(tools[0]["function"]["description"], "Run shell command"); + assert_eq!(tools[0]["function"]["parameters"]["required"][0], "command"); + } + + #[test] + fn request_serializes_with_tools() { + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather for a location", + "parameters": { + "type": "object", + "properties": { + "location": {"type": "string"} + } + } + } + })]; + + let req = ApiChatRequest { + model: "test-model".to_string(), + messages: vec![Message { + role: "user".to_string(), + content: "What is the weather?".to_string(), + }], + temperature: 0.7, + stream: Some(false), + tools: Some(tools), + tool_choice: Some("auto".to_string()), + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(json.contains("\"tools\"")); + assert!(json.contains("get_weather")); + assert!(json.contains("\"tool_choice\":\"auto\"")); + } + + #[test] + fn response_with_tool_calls_deserializes() { + let json = r#"{ + "choices": [{ + "message": { + "content": null, + "tool_calls": [{ + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\":\"London\"}" + } + }] + } + }] + }"#; + + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let msg = &resp.choices[0].message; + 
assert!(msg.content.is_none()); + let tool_calls = msg.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 1); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("get_weather") + ); + assert_eq!( + tool_calls[0] + .function + .as_ref() + .unwrap() + .arguments + .as_deref(), + Some("{\"location\":\"London\"}") + ); + } + + #[test] + fn response_with_multiple_tool_calls() { + let json = r#"{ + "choices": [{ + "message": { + "content": "I'll check both.", + "tool_calls": [ + { + "type": "function", + "function": { + "name": "get_weather", + "arguments": "{\"location\":\"London\"}" + } + }, + { + "type": "function", + "function": { + "name": "get_time", + "arguments": "{\"timezone\":\"UTC\"}" + } + } + ] + } + }] + }"#; + + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let msg = &resp.choices[0].message; + assert_eq!(msg.content.as_deref(), Some("I'll check both.")); + let tool_calls = msg.tool_calls.as_ref().unwrap(); + assert_eq!(tool_calls.len(), 2); + assert_eq!( + tool_calls[0].function.as_ref().unwrap().name.as_deref(), + Some("get_weather") + ); + assert_eq!( + tool_calls[1].function.as_ref().unwrap().name.as_deref(), + Some("get_time") + ); + } + + #[tokio::test] + async fn chat_with_tools_fails_without_key() { + let p = make_provider("TestProvider", "https://example.com", None); + let messages = vec![ChatMessage { + role: "user".to_string(), + content: "hello".to_string(), + }]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "test_tool", + "description": "A test tool", + "parameters": {} + } + })]; + + let result = p.chat_with_tools(&messages, &tools, "model", 0.7).await; + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("TestProvider API key not set")); + } + + #[test] + fn response_with_no_tool_calls_has_empty_vec() { + let json = r#"{"choices":[{"message":{"content":"Just text, no tools."}}]}"#; + let resp: 
ApiChatResponse = serde_json::from_str(json).unwrap(); + let msg = &resp.choices[0].message; + assert_eq!(msg.content.as_deref(), Some("Just text, no tools.")); + assert!(msg.tool_calls.is_none()); + } + // ---------------------------------------------------------- // Reasoning model fallback tests (reasoning_content) // ---------------------------------------------------------- diff --git a/src/providers/router.rs b/src/providers/router.rs index 78edde0..2d55869 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -137,6 +137,20 @@ impl Provider for RouterProvider { provider.chat(request, &resolved_model, temperature).await } + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result<ProviderChatResponse> { + let (provider_idx, resolved_model) = self.resolve(model); + let (_, provider) = &self.providers[provider_idx]; + provider + .chat_with_tools(messages, tools, &resolved_model, temperature) + .await + } + fn supports_native_tools(&self) -> bool { self.providers .get(self.default_index) @@ -382,4 +396,63 @@ mod tests { assert_eq!(result, "response"); assert_eq!(mock.call_count(), 1); } + + #[tokio::test] + async fn chat_with_tools_delegates_to_resolved_provider() { + let mock = Arc::new(MockProvider::new("tool-response")); + let router = RouterProvider::new( + vec![( + "default".into(), + Box::new(Arc::clone(&mock)) as Box<dyn Provider>, + )], + vec![], + "model".into(), + ); + + let messages = vec![ChatMessage { + role: "user".to_string(), + content: "use tools".to_string(), + }]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run shell command", + "parameters": {} + } + })]; + + // chat_with_tools should delegate through the router to the mock. + // MockProvider's default chat_with_tools calls chat_with_history -> chat_with_system. 
+ let result = router + .chat_with_tools(&messages, &tools, "model", 0.7) + .await + .unwrap(); + assert_eq!(result.text.as_deref(), Some("tool-response")); + assert_eq!(mock.call_count(), 1); + assert_eq!(mock.last_model(), "model"); + } + + #[tokio::test] + async fn chat_with_tools_routes_hint_correctly() { + let (router, mocks) = make_router( + vec![("fast", "fast-tool"), ("smart", "smart-tool")], + vec![("reasoning", "smart", "claude-opus")], + ); + + let messages = vec![ChatMessage { + role: "user".to_string(), + content: "reason about this".to_string(), + }]; + let tools = vec![serde_json::json!({"type": "function", "function": {"name": "test"}})]; + + let result = router + .chat_with_tools(&messages, &tools, "hint:reasoning", 0.5) + .await + .unwrap(); + assert_eq!(result.text.as_deref(), Some("smart-tool")); + assert_eq!(mocks[1].call_count(), 1); + assert_eq!(mocks[1].last_model(), "claude-opus"); + assert_eq!(mocks[0].call_count(), 0); + } }