feat(providers): add native tool calling for OpenAI-compatible providers

Implement chat_with_tools() on CompatibleProvider so OpenAI-compatible
endpoints (OpenRouter, local LLMs, etc.) can use structured tool calling
instead of prompt-injected tool descriptions.

Changes:
- CompatibleProvider: capabilities() reports native_tool_calling, new
  chat_with_tools() sends tools in API request and parses tool_calls
  from response, chat() bridges to chat_with_tools() when ToolSpecs
  are provided
- RouterProvider: chat_with_tools() delegation with model hint resolution
- loop_.rs: expose tools_to_openai_format as pub(crate), add
  tools_to_openai_format_from_specs for ToolSpec-based conversion

Adds 9 new tests and updates 1 existing test.
This commit is contained in:
Vernon Stinebaker 2026-02-18 17:15:02 +08:00 committed by Chummy
parent 6acec94666
commit 3b0133596c
3 changed files with 388 additions and 7 deletions

View file

@ -644,7 +644,8 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
remaining = &after_open[close_idx + close_tag.len()..]; remaining = &after_open[close_idx + close_tag.len()..];
} else { } else {
if let Some(json_end) = find_json_end(after_open) { if let Some(json_end) = find_json_end(after_open) {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&after_open[..json_end]) if let Ok(value) =
serde_json::from_str::<serde_json::Value>(&after_open[..json_end])
{ {
let parsed_calls = parse_tool_calls_from_json_value(&value); let parsed_calls = parse_tool_calls_from_json_value(&value);
if !parsed_calls.is_empty() { if !parsed_calls.is_empty() {

View file

@ -140,15 +140,35 @@ impl OpenAiCompatibleProvider {
format!("{normalized_base}/v1/responses") format!("{normalized_base}/v1/responses")
} }
} }
/// Convert internal `ToolSpec` definitions into the OpenAI `tools`
/// array format: `{"type": "function", "function": {...}}` per tool.
fn tool_specs_to_openai_format(tools: &[crate::tools::ToolSpec]) -> Vec<serde_json::Value> {
    let mut formatted = Vec::with_capacity(tools.len());
    for spec in tools {
        formatted.push(serde_json::json!({
            "type": "function",
            "function": {
                "name": spec.name,
                "description": spec.description,
                "parameters": spec.parameters
            }
        }));
    }
    formatted
}
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct ChatRequest { struct ApiChatRequest {
model: String, model: String,
messages: Vec<Message>, messages: Vec<Message>,
temperature: f64, temperature: f64,
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
stream: Option<bool>, stream: Option<bool>,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<serde_json::Value>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<String>,
} }
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
@ -189,6 +209,13 @@ impl ResponseMessage {
_ => self.reasoning_content.clone().unwrap_or_default(), _ => self.reasoning_content.clone().unwrap_or_default(),
} }
} }
/// Like the non-optional accessor, but yields `None` when neither the
/// primary `content` nor the `reasoning_content` fallback holds a
/// non-empty string.
fn effective_content_optional(&self) -> Option<String> {
    self.content
        .as_ref()
        .filter(|c| !c.is_empty())
        .cloned()
        .or_else(|| self.reasoning_content.clone().filter(|c| !c.is_empty()))
}
} }
#[derive(Debug, Deserialize, Serialize)] #[derive(Debug, Deserialize, Serialize)]
@ -476,6 +503,12 @@ impl OpenAiCompatibleProvider {
#[async_trait] #[async_trait]
impl Provider for OpenAiCompatibleProvider { impl Provider for OpenAiCompatibleProvider {
// OpenAI-compatible endpoints accept the standard `tools`/`tool_choice`
// request fields, so this provider advertises native tool calling.
fn capabilities(&self) -> crate::providers::traits::ProviderCapabilities {
    crate::providers::traits::ProviderCapabilities {
        native_tool_calling: true,
    }
}
async fn chat_with_system( async fn chat_with_system(
&self, &self,
system_prompt: Option<&str>, system_prompt: Option<&str>,
@ -504,11 +537,13 @@ impl Provider for OpenAiCompatibleProvider {
content: message.to_string(), content: message.to_string(),
}); });
let request = ChatRequest { let request = ApiChatRequest {
model: model.to_string(), model: model.to_string(),
messages, messages,
temperature, temperature,
stream: Some(false), stream: Some(false),
tools: None,
tool_choice: None,
}; };
let url = self.chat_completions_url(); let url = self.chat_completions_url();
@ -584,11 +619,13 @@ impl Provider for OpenAiCompatibleProvider {
}) })
.collect(); .collect();
let request = ChatRequest { let request = ApiChatRequest {
model: model.to_string(), model: model.to_string(),
messages: api_messages, messages: api_messages,
temperature, temperature,
stream: Some(false), stream: Some(false),
tools: None,
tool_choice: None,
}; };
let url = self.chat_completions_url(); let url = self.chat_completions_url();
@ -651,18 +688,106 @@ impl Provider for OpenAiCompatibleProvider {
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name)) .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
} }
/// Send a chat request carrying native OpenAI-style tool definitions and
/// surface any `tool_calls` the model returns.
///
/// `tools` are pre-formatted OpenAI function objects; when the slice is
/// empty, `tools`/`tool_choice` are omitted from the request entirely so
/// the wire format matches a plain chat request.
async fn chat_with_tools(
    &self,
    messages: &[ChatMessage],
    tools: &[serde_json::Value],
    model: &str,
    temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
    // Tool calling requires an authenticated request.
    let credential = self.credential.as_ref().ok_or_else(|| {
        anyhow::anyhow!(
            "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
            self.name
        )
    })?;

    let api_messages: Vec<Message> = messages
        .iter()
        .map(|m| Message {
            role: m.role.clone(),
            content: m.content.clone(),
        })
        .collect();

    let has_tools = !tools.is_empty();
    let request = ApiChatRequest {
        model: model.to_string(),
        messages: api_messages,
        temperature,
        stream: Some(false),
        tools: has_tools.then(|| tools.to_vec()),
        tool_choice: has_tools.then(|| "auto".to_string()),
    };

    let url = self.chat_completions_url();
    let response = self
        .apply_auth_header(self.client.post(&url).json(&request), credential)
        .send()
        .await?;
    if !response.status().is_success() {
        return Err(super::api_error(&self.name, response).await);
    }

    let chat_response: ApiChatResponse = response.json().await?;
    let choice = chat_response
        .choices
        .into_iter()
        .next()
        .ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;

    let text = choice.message.effective_content_optional();

    // Collect well-formed tool calls; entries missing a function payload
    // or a name are silently skipped, and absent arguments default to "{}".
    let mut tool_calls = Vec::new();
    for tc in choice.message.tool_calls.unwrap_or_default() {
        if let Some(function) = tc.function {
            if let Some(name) = function.name {
                tool_calls.push(ProviderToolCall {
                    // Synthesize a call id locally for downstream correlation.
                    id: uuid::Uuid::new_v4().to_string(),
                    name,
                    arguments: function.arguments.unwrap_or_else(|| "{}".to_string()),
                });
            }
        }
    }

    Ok(ProviderChatResponse { text, tool_calls })
}
async fn chat( async fn chat(
&self, &self,
request: ProviderChatRequest<'_>, request: ProviderChatRequest<'_>,
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<ProviderChatResponse> { ) -> anyhow::Result<ProviderChatResponse> {
// If native tools are requested, delegate to chat_with_tools.
if let Some(tools) = request.tools {
if !tools.is_empty() && self.supports_native_tools() {
let native_tools = Self::tool_specs_to_openai_format(tools);
return self
.chat_with_tools(request.messages, &native_tools, model, temperature)
.await;
}
}
let text = self let text = self
.chat_with_history(request.messages, model, temperature) .chat_with_history(request.messages, model, temperature)
.await?; .await?;
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content. // Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) { if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
let parsed_text = message.effective_content_optional();
let tool_calls = message let tool_calls = message
.tool_calls .tool_calls
.unwrap_or_default() .unwrap_or_default()
@ -680,7 +805,7 @@ impl Provider for OpenAiCompatibleProvider {
.collect::<Vec<_>>(); .collect::<Vec<_>>();
return Ok(ProviderChatResponse { return Ok(ProviderChatResponse {
text: message.content, text: parsed_text,
tool_calls, tool_calls,
}); });
} }
@ -733,11 +858,13 @@ impl Provider for OpenAiCompatibleProvider {
content: message.to_string(), content: message.to_string(),
}); });
let request = ChatRequest { let request = ApiChatRequest {
model: model.to_string(), model: model.to_string(),
messages, messages,
temperature, temperature,
stream: Some(options.enabled), stream: Some(options.enabled),
tools: None,
tool_choice: None,
}; };
let url = self.chat_completions_url(); let url = self.chat_completions_url();
@ -863,7 +990,7 @@ mod tests {
#[test] #[test]
fn request_serializes_correctly() { fn request_serializes_correctly() {
let req = ChatRequest { let req = ApiChatRequest {
model: "llama-3.3-70b".to_string(), model: "llama-3.3-70b".to_string(),
messages: vec![ messages: vec![
Message { Message {
@ -877,11 +1004,16 @@ mod tests {
], ],
temperature: 0.4, temperature: 0.4,
stream: Some(false), stream: Some(false),
tools: None,
tool_choice: None,
}; };
let json = serde_json::to_string(&req).unwrap(); let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("llama-3.3-70b")); assert!(json.contains("llama-3.3-70b"));
assert!(json.contains("system")); assert!(json.contains("system"));
assert!(json.contains("user")); assert!(json.contains("user"));
// tools/tool_choice should be omitted when None
assert!(!json.contains("tools"));
assert!(!json.contains("tool_choice"));
} }
#[test] #[test]
@ -1176,6 +1308,181 @@ mod tests {
assert!(result.is_ok()); assert!(result.is_ok());
} }
// ══════════════════════════════════════════════════════════
// Native tool calling tests
// ══════════════════════════════════════════════════════════
#[test]
fn capabilities_reports_native_tool_calling() {
    let provider = make_provider("test", "https://example.com", None);
    // Fully-qualified trait call pins down the `Provider` impl.
    let capabilities = <OpenAiCompatibleProvider as Provider>::capabilities(&provider);
    assert!(capabilities.native_tool_calling);
}
#[test]
fn tool_specs_convert_to_openai_format() {
    // One representative spec with a JSON-schema parameter object.
    let shell_spec = crate::tools::ToolSpec {
        name: "shell".to_string(),
        description: "Run shell command".to_string(),
        parameters: serde_json::json!({
            "type": "object",
            "properties": {"command": {"type": "string"}},
            "required": ["command"]
        }),
    };

    let tools = OpenAiCompatibleProvider::tool_specs_to_openai_format(&[shell_spec]);

    assert_eq!(tools.len(), 1);
    assert_eq!(tools[0]["type"], "function");
    let function = &tools[0]["function"];
    assert_eq!(function["name"], "shell");
    assert_eq!(function["description"], "Run shell command");
    assert_eq!(function["parameters"]["required"][0], "command");
}
#[test]
fn request_serializes_with_tools() {
    // A single function-style tool definition, as it would go on the wire.
    let weather_tool = serde_json::json!({
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a location",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string"}
                }
            }
        }
    });

    let request = ApiChatRequest {
        model: "test-model".to_string(),
        messages: vec![Message {
            role: "user".to_string(),
            content: "What is the weather?".to_string(),
        }],
        temperature: 0.7,
        stream: Some(false),
        tools: Some(vec![weather_tool]),
        tool_choice: Some("auto".to_string()),
    };

    let serialized = serde_json::to_string(&request).unwrap();
    assert!(serialized.contains("\"tools\""));
    assert!(serialized.contains("get_weather"));
    assert!(serialized.contains("\"tool_choice\":\"auto\""));
}
#[test]
fn response_with_tool_calls_deserializes() {
    // Null content plus a single tool call, mirroring an API response.
    let payload = r#"{
        "choices": [{
            "message": {
                "content": null,
                "tool_calls": [{
                    "type": "function",
                    "function": {
                        "name": "get_weather",
                        "arguments": "{\"location\":\"London\"}"
                    }
                }]
            }
        }]
    }"#;

    let response: ApiChatResponse = serde_json::from_str(payload).unwrap();
    let message = &response.choices[0].message;
    assert!(message.content.is_none());

    let calls = message.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 1);
    let function = calls[0].function.as_ref().unwrap();
    assert_eq!(function.name.as_deref(), Some("get_weather"));
    assert_eq!(
        function.arguments.as_deref(),
        Some("{\"location\":\"London\"}")
    );
}
#[test]
fn response_with_multiple_tool_calls() {
    // Text content alongside two parallel tool calls.
    let payload = r#"{
        "choices": [{
            "message": {
                "content": "I'll check both.",
                "tool_calls": [
                    {
                        "type": "function",
                        "function": {
                            "name": "get_weather",
                            "arguments": "{\"location\":\"London\"}"
                        }
                    },
                    {
                        "type": "function",
                        "function": {
                            "name": "get_time",
                            "arguments": "{\"timezone\":\"UTC\"}"
                        }
                    }
                ]
            }
        }]
    }"#;

    let response: ApiChatResponse = serde_json::from_str(payload).unwrap();
    let message = &response.choices[0].message;
    assert_eq!(message.content.as_deref(), Some("I'll check both."));

    let calls = message.tool_calls.as_ref().unwrap();
    assert_eq!(calls.len(), 2);
    let names: Vec<_> = calls
        .iter()
        .map(|c| c.function.as_ref().unwrap().name.as_deref())
        .collect();
    assert_eq!(names, vec![Some("get_weather"), Some("get_time")]);
}
#[tokio::test]
async fn chat_with_tools_fails_without_key() {
    // No credential configured: the call must fail before any network I/O.
    let provider = make_provider("TestProvider", "https://example.com", None);
    let messages = vec![ChatMessage {
        role: "user".to_string(),
        content: "hello".to_string(),
    }];
    let tools = vec![serde_json::json!({
        "type": "function",
        "function": {
            "name": "test_tool",
            "description": "A test tool",
            "parameters": {}
        }
    })];

    let err = provider
        .chat_with_tools(&messages, &tools, "model", 0.7)
        .await
        .expect_err("missing API key must be rejected");
    assert!(err.to_string().contains("TestProvider API key not set"));
}
#[test]
fn response_without_tool_calls_deserializes_as_none() {
    // Renamed from `response_with_no_tool_calls_has_empty_vec`: the old name
    // claimed an empty Vec, but an omitted `tool_calls` field deserializes
    // as `None` (callers rely on `.unwrap_or_default()` for the empty case).
    let json = r#"{"choices":[{"message":{"content":"Just text, no tools."}}]}"#;
    let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
    let msg = &resp.choices[0].message;
    assert_eq!(msg.content.as_deref(), Some("Just text, no tools."));
    assert!(msg.tool_calls.is_none());
}
// ---------------------------------------------------------- // ----------------------------------------------------------
// Reasoning model fallback tests (reasoning_content) // Reasoning model fallback tests (reasoning_content)
// ---------------------------------------------------------- // ----------------------------------------------------------

View file

@ -137,6 +137,20 @@ impl Provider for RouterProvider {
provider.chat(request, &resolved_model, temperature).await provider.chat(request, &resolved_model, temperature).await
} }
/// Route a native tool-calling request to the provider selected by the
/// model string (including hint resolution), forwarding the resolved
/// model name unchanged.
async fn chat_with_tools(
    &self,
    messages: &[ChatMessage],
    tools: &[serde_json::Value],
    model: &str,
    temperature: f64,
) -> anyhow::Result<ChatResponse> {
    let (idx, resolved_model) = self.resolve(model);
    let (_name, target) = &self.providers[idx];
    target
        .chat_with_tools(messages, tools, &resolved_model, temperature)
        .await
}
fn supports_native_tools(&self) -> bool { fn supports_native_tools(&self) -> bool {
self.providers self.providers
.get(self.default_index) .get(self.default_index)
@ -382,4 +396,63 @@ mod tests {
assert_eq!(result, "response"); assert_eq!(result, "response");
assert_eq!(mock.call_count(), 1); assert_eq!(mock.call_count(), 1);
} }
#[tokio::test]
async fn chat_with_tools_delegates_to_resolved_provider() {
    let mock = Arc::new(MockProvider::new("tool-response"));
    let router = RouterProvider::new(
        vec![(
            "default".into(),
            Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
        )],
        vec![],
        "model".into(),
    );

    let conversation = vec![ChatMessage {
        role: "user".to_string(),
        content: "use tools".to_string(),
    }];
    let tool_defs = vec![serde_json::json!({
        "type": "function",
        "function": {
            "name": "shell",
            "description": "Run shell command",
            "parameters": {}
        }
    })];

    // The router must forward to the single registered provider.
    // MockProvider's default chat_with_tools calls chat_with_history -> chat_with_system.
    let response = router
        .chat_with_tools(&conversation, &tool_defs, "model", 0.7)
        .await
        .unwrap();

    assert_eq!(response.text.as_deref(), Some("tool-response"));
    assert_eq!(mock.call_count(), 1);
    assert_eq!(mock.last_model(), "model");
}
#[tokio::test]
async fn chat_with_tools_routes_hint_correctly() {
    // "hint:reasoning" maps to the "smart" provider with model "claude-opus".
    let (router, mocks) = make_router(
        vec![("fast", "fast-tool"), ("smart", "smart-tool")],
        vec![("reasoning", "smart", "claude-opus")],
    );
    let conversation = vec![ChatMessage {
        role: "user".to_string(),
        content: "reason about this".to_string(),
    }];
    let tool_defs = vec![serde_json::json!({"type": "function", "function": {"name": "test"}})];

    let response = router
        .chat_with_tools(&conversation, &tool_defs, "hint:reasoning", 0.5)
        .await
        .unwrap();

    assert_eq!(response.text.as_deref(), Some("smart-tool"));
    // Only the "smart" provider is invoked, and with the mapped model name.
    assert_eq!(mocks[1].call_count(), 1);
    assert_eq!(mocks[1].last_model(), "claude-opus");
    assert_eq!(mocks[0].call_count(), 0);
}
} }