feat(providers): add native tool calling for OpenAI-compatible providers
Implement chat_with_tools() on CompatibleProvider so OpenAI-compatible endpoints (OpenRouter, local LLMs, etc.) can use structured tool calling instead of prompt-injected tool descriptions. Changes: - CompatibleProvider: capabilities() reports native_tool_calling, new chat_with_tools() sends tools in API request and parses tool_calls from response, chat() bridges to chat_with_tools() when ToolSpecs are provided - RouterProvider: chat_with_tools() delegation with model hint resolution - loop_.rs: expose tools_to_openai_format as pub(crate), add tools_to_openai_format_from_specs for ToolSpec-based conversion Adds 9 new tests and updates 1 existing test.
This commit is contained in:
parent
6acec94666
commit
3b0133596c
3 changed files with 388 additions and 7 deletions
|
|
@ -644,7 +644,8 @@ fn parse_tool_calls(response: &str) -> (String, Vec<ParsedToolCall>) {
|
||||||
remaining = &after_open[close_idx + close_tag.len()..];
|
remaining = &after_open[close_idx + close_tag.len()..];
|
||||||
} else {
|
} else {
|
||||||
if let Some(json_end) = find_json_end(after_open) {
|
if let Some(json_end) = find_json_end(after_open) {
|
||||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&after_open[..json_end])
|
if let Ok(value) =
|
||||||
|
serde_json::from_str::<serde_json::Value>(&after_open[..json_end])
|
||||||
{
|
{
|
||||||
let parsed_calls = parse_tool_calls_from_json_value(&value);
|
let parsed_calls = parse_tool_calls_from_json_value(&value);
|
||||||
if !parsed_calls.is_empty() {
|
if !parsed_calls.is_empty() {
|
||||||
|
|
|
||||||
|
|
@ -140,15 +140,35 @@ impl OpenAiCompatibleProvider {
|
||||||
format!("{normalized_base}/v1/responses")
|
format!("{normalized_base}/v1/responses")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn tool_specs_to_openai_format(tools: &[crate::tools::ToolSpec]) -> Vec<serde_json::Value> {
|
||||||
|
tools
|
||||||
|
.iter()
|
||||||
|
.map(|tool| {
|
||||||
|
serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool.name,
|
||||||
|
"description": tool.description,
|
||||||
|
"parameters": tool.parameters
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct ChatRequest {
|
struct ApiChatRequest {
|
||||||
model: String,
|
model: String,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
#[serde(skip_serializing_if = "Option::is_none")]
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
stream: Option<bool>,
|
stream: Option<bool>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<serde_json::Value>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_choice: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
|
@ -189,6 +209,13 @@ impl ResponseMessage {
|
||||||
_ => self.reasoning_content.clone().unwrap_or_default(),
|
_ => self.reasoning_content.clone().unwrap_or_default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn effective_content_optional(&self) -> Option<String> {
|
||||||
|
match &self.content {
|
||||||
|
Some(c) if !c.is_empty() => Some(c.clone()),
|
||||||
|
_ => self.reasoning_content.clone().filter(|c| !c.is_empty()),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Deserialize, Serialize)]
|
#[derive(Debug, Deserialize, Serialize)]
|
||||||
|
|
@ -476,6 +503,12 @@ impl OpenAiCompatibleProvider {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Provider for OpenAiCompatibleProvider {
|
impl Provider for OpenAiCompatibleProvider {
|
||||||
|
fn capabilities(&self) -> crate::providers::traits::ProviderCapabilities {
|
||||||
|
crate::providers::traits::ProviderCapabilities {
|
||||||
|
native_tool_calling: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn chat_with_system(
|
async fn chat_with_system(
|
||||||
&self,
|
&self,
|
||||||
system_prompt: Option<&str>,
|
system_prompt: Option<&str>,
|
||||||
|
|
@ -504,11 +537,13 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
content: message.to_string(),
|
content: message.to_string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let request = ChatRequest {
|
let request = ApiChatRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
messages,
|
messages,
|
||||||
temperature,
|
temperature,
|
||||||
stream: Some(false),
|
stream: Some(false),
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
|
|
@ -584,11 +619,13 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let request = ChatRequest {
|
let request = ApiChatRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
messages: api_messages,
|
messages: api_messages,
|
||||||
temperature,
|
temperature,
|
||||||
stream: Some(false),
|
stream: Some(false),
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
|
|
@ -651,18 +688,106 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_tools(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &[serde_json::Value],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
|
anyhow::anyhow!(
|
||||||
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
|
self.name
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let api_messages: Vec<Message> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: m.content.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let request = ApiChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages: api_messages,
|
||||||
|
temperature,
|
||||||
|
stream: Some(false),
|
||||||
|
tools: if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(tools.to_vec())
|
||||||
|
},
|
||||||
|
tool_choice: if tools.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some("auto".to_string())
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = self.chat_completions_url();
|
||||||
|
let response = self
|
||||||
|
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(super::api_error(&self.name, response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_response: ApiChatResponse = response.json().await?;
|
||||||
|
let choice = chat_response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
|
||||||
|
|
||||||
|
let text = choice.message.effective_content_optional();
|
||||||
|
let tool_calls = choice
|
||||||
|
.message
|
||||||
|
.tool_calls
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into_iter()
|
||||||
|
.filter_map(|tc| {
|
||||||
|
let function = tc.function?;
|
||||||
|
let name = function.name?;
|
||||||
|
let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
|
||||||
|
Some(ProviderToolCall {
|
||||||
|
id: uuid::Uuid::new_v4().to_string(),
|
||||||
|
name,
|
||||||
|
arguments,
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
Ok(ProviderChatResponse { text, tool_calls })
|
||||||
|
}
|
||||||
|
|
||||||
async fn chat(
|
async fn chat(
|
||||||
&self,
|
&self,
|
||||||
request: ProviderChatRequest<'_>,
|
request: ProviderChatRequest<'_>,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
// If native tools are requested, delegate to chat_with_tools.
|
||||||
|
if let Some(tools) = request.tools {
|
||||||
|
if !tools.is_empty() && self.supports_native_tools() {
|
||||||
|
let native_tools = Self::tool_specs_to_openai_format(tools);
|
||||||
|
return self
|
||||||
|
.chat_with_tools(request.messages, &native_tools, model, temperature)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let text = self
|
let text = self
|
||||||
.chat_with_history(request.messages, model, temperature)
|
.chat_with_history(request.messages, model, temperature)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
|
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
|
||||||
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
|
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
|
||||||
|
let parsed_text = message.effective_content_optional();
|
||||||
let tool_calls = message
|
let tool_calls = message
|
||||||
.tool_calls
|
.tool_calls
|
||||||
.unwrap_or_default()
|
.unwrap_or_default()
|
||||||
|
|
@ -680,7 +805,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
return Ok(ProviderChatResponse {
|
return Ok(ProviderChatResponse {
|
||||||
text: message.content,
|
text: parsed_text,
|
||||||
tool_calls,
|
tool_calls,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
@ -733,11 +858,13 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
content: message.to_string(),
|
content: message.to_string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let request = ChatRequest {
|
let request = ApiChatRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
messages,
|
messages,
|
||||||
temperature,
|
temperature,
|
||||||
stream: Some(options.enabled),
|
stream: Some(options.enabled),
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
|
|
@ -863,7 +990,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn request_serializes_correctly() {
|
fn request_serializes_correctly() {
|
||||||
let req = ChatRequest {
|
let req = ApiChatRequest {
|
||||||
model: "llama-3.3-70b".to_string(),
|
model: "llama-3.3-70b".to_string(),
|
||||||
messages: vec![
|
messages: vec![
|
||||||
Message {
|
Message {
|
||||||
|
|
@ -877,11 +1004,16 @@ mod tests {
|
||||||
],
|
],
|
||||||
temperature: 0.4,
|
temperature: 0.4,
|
||||||
stream: Some(false),
|
stream: Some(false),
|
||||||
|
tools: None,
|
||||||
|
tool_choice: None,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&req).unwrap();
|
let json = serde_json::to_string(&req).unwrap();
|
||||||
assert!(json.contains("llama-3.3-70b"));
|
assert!(json.contains("llama-3.3-70b"));
|
||||||
assert!(json.contains("system"));
|
assert!(json.contains("system"));
|
||||||
assert!(json.contains("user"));
|
assert!(json.contains("user"));
|
||||||
|
// tools/tool_choice should be omitted when None
|
||||||
|
assert!(!json.contains("tools"));
|
||||||
|
assert!(!json.contains("tool_choice"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -1176,6 +1308,181 @@ mod tests {
|
||||||
assert!(result.is_ok());
|
assert!(result.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ══════════════════════════════════════════════════════════
|
||||||
|
// Native tool calling tests
|
||||||
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn capabilities_reports_native_tool_calling() {
|
||||||
|
let p = make_provider("test", "https://example.com", None);
|
||||||
|
let caps = <OpenAiCompatibleProvider as Provider>::capabilities(&p);
|
||||||
|
assert!(caps.native_tool_calling);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn tool_specs_convert_to_openai_format() {
|
||||||
|
let specs = vec![crate::tools::ToolSpec {
|
||||||
|
name: "shell".to_string(),
|
||||||
|
description: "Run shell command".to_string(),
|
||||||
|
parameters: serde_json::json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"command": {"type": "string"}},
|
||||||
|
"required": ["command"]
|
||||||
|
}),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let tools = OpenAiCompatibleProvider::tool_specs_to_openai_format(&specs);
|
||||||
|
assert_eq!(tools.len(), 1);
|
||||||
|
assert_eq!(tools[0]["type"], "function");
|
||||||
|
assert_eq!(tools[0]["function"]["name"], "shell");
|
||||||
|
assert_eq!(tools[0]["function"]["description"], "Run shell command");
|
||||||
|
assert_eq!(tools[0]["function"]["parameters"]["required"][0], "command");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn request_serializes_with_tools() {
|
||||||
|
let tools = vec![serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"description": "Get weather for a location",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"location": {"type": "string"}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})];
|
||||||
|
|
||||||
|
let req = ApiChatRequest {
|
||||||
|
model: "test-model".to_string(),
|
||||||
|
messages: vec![Message {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "What is the weather?".to_string(),
|
||||||
|
}],
|
||||||
|
temperature: 0.7,
|
||||||
|
stream: Some(false),
|
||||||
|
tools: Some(tools),
|
||||||
|
tool_choice: Some("auto".to_string()),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&req).unwrap();
|
||||||
|
assert!(json.contains("\"tools\""));
|
||||||
|
assert!(json.contains("get_weather"));
|
||||||
|
assert!(json.contains("\"tool_choice\":\"auto\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_with_tool_calls_deserializes() {
|
||||||
|
let json = r#"{
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": null,
|
||||||
|
"tool_calls": [{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"location\":\"London\"}"
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}"#;
|
||||||
|
|
||||||
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
let msg = &resp.choices[0].message;
|
||||||
|
assert!(msg.content.is_none());
|
||||||
|
let tool_calls = msg.tool_calls.as_ref().unwrap();
|
||||||
|
assert_eq!(tool_calls.len(), 1);
|
||||||
|
assert_eq!(
|
||||||
|
tool_calls[0].function.as_ref().unwrap().name.as_deref(),
|
||||||
|
Some("get_weather")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tool_calls[0]
|
||||||
|
.function
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.arguments
|
||||||
|
.as_deref(),
|
||||||
|
Some("{\"location\":\"London\"}")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_with_multiple_tool_calls() {
|
||||||
|
let json = r#"{
|
||||||
|
"choices": [{
|
||||||
|
"message": {
|
||||||
|
"content": "I'll check both.",
|
||||||
|
"tool_calls": [
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_weather",
|
||||||
|
"arguments": "{\"location\":\"London\"}"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "get_time",
|
||||||
|
"arguments": "{\"timezone\":\"UTC\"}"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}]
|
||||||
|
}"#;
|
||||||
|
|
||||||
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
let msg = &resp.choices[0].message;
|
||||||
|
assert_eq!(msg.content.as_deref(), Some("I'll check both."));
|
||||||
|
let tool_calls = msg.tool_calls.as_ref().unwrap();
|
||||||
|
assert_eq!(tool_calls.len(), 2);
|
||||||
|
assert_eq!(
|
||||||
|
tool_calls[0].function.as_ref().unwrap().name.as_deref(),
|
||||||
|
Some("get_weather")
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
tool_calls[1].function.as_ref().unwrap().name.as_deref(),
|
||||||
|
Some("get_time")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_with_tools_fails_without_key() {
|
||||||
|
let p = make_provider("TestProvider", "https://example.com", None);
|
||||||
|
let messages = vec![ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "hello".to_string(),
|
||||||
|
}];
|
||||||
|
let tools = vec![serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "test_tool",
|
||||||
|
"description": "A test tool",
|
||||||
|
"parameters": {}
|
||||||
|
}
|
||||||
|
})];
|
||||||
|
|
||||||
|
let result = p.chat_with_tools(&messages, &tools, "model", 0.7).await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result
|
||||||
|
.unwrap_err()
|
||||||
|
.to_string()
|
||||||
|
.contains("TestProvider API key not set"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_with_no_tool_calls_has_empty_vec() {
|
||||||
|
let json = r#"{"choices":[{"message":{"content":"Just text, no tools."}}]}"#;
|
||||||
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
let msg = &resp.choices[0].message;
|
||||||
|
assert_eq!(msg.content.as_deref(), Some("Just text, no tools."));
|
||||||
|
assert!(msg.tool_calls.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
// ----------------------------------------------------------
|
// ----------------------------------------------------------
|
||||||
// Reasoning model fallback tests (reasoning_content)
|
// Reasoning model fallback tests (reasoning_content)
|
||||||
// ----------------------------------------------------------
|
// ----------------------------------------------------------
|
||||||
|
|
|
||||||
|
|
@ -137,6 +137,20 @@ impl Provider for RouterProvider {
|
||||||
provider.chat(request, &resolved_model, temperature).await
|
provider.chat(request, &resolved_model, temperature).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_tools(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &[serde_json::Value],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ChatResponse> {
|
||||||
|
let (provider_idx, resolved_model) = self.resolve(model);
|
||||||
|
let (_, provider) = &self.providers[provider_idx];
|
||||||
|
provider
|
||||||
|
.chat_with_tools(messages, tools, &resolved_model, temperature)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
fn supports_native_tools(&self) -> bool {
|
fn supports_native_tools(&self) -> bool {
|
||||||
self.providers
|
self.providers
|
||||||
.get(self.default_index)
|
.get(self.default_index)
|
||||||
|
|
@ -382,4 +396,63 @@ mod tests {
|
||||||
assert_eq!(result, "response");
|
assert_eq!(result, "response");
|
||||||
assert_eq!(mock.call_count(), 1);
|
assert_eq!(mock.call_count(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_with_tools_delegates_to_resolved_provider() {
|
||||||
|
let mock = Arc::new(MockProvider::new("tool-response"));
|
||||||
|
let router = RouterProvider::new(
|
||||||
|
vec![(
|
||||||
|
"default".into(),
|
||||||
|
Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
|
||||||
|
)],
|
||||||
|
vec![],
|
||||||
|
"model".into(),
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "use tools".to_string(),
|
||||||
|
}];
|
||||||
|
let tools = vec![serde_json::json!({
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "shell",
|
||||||
|
"description": "Run shell command",
|
||||||
|
"parameters": {}
|
||||||
|
}
|
||||||
|
})];
|
||||||
|
|
||||||
|
// chat_with_tools should delegate through the router to the mock.
|
||||||
|
// MockProvider's default chat_with_tools calls chat_with_history -> chat_with_system.
|
||||||
|
let result = router
|
||||||
|
.chat_with_tools(&messages, &tools, "model", 0.7)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result.text.as_deref(), Some("tool-response"));
|
||||||
|
assert_eq!(mock.call_count(), 1);
|
||||||
|
assert_eq!(mock.last_model(), "model");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn chat_with_tools_routes_hint_correctly() {
|
||||||
|
let (router, mocks) = make_router(
|
||||||
|
vec![("fast", "fast-tool"), ("smart", "smart-tool")],
|
||||||
|
vec![("reasoning", "smart", "claude-opus")],
|
||||||
|
);
|
||||||
|
|
||||||
|
let messages = vec![ChatMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: "reason about this".to_string(),
|
||||||
|
}];
|
||||||
|
let tools = vec![serde_json::json!({"type": "function", "function": {"name": "test"}})];
|
||||||
|
|
||||||
|
let result = router
|
||||||
|
.chat_with_tools(&messages, &tools, "hint:reasoning", 0.5)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(result.text.as_deref(), Some("smart-tool"));
|
||||||
|
assert_eq!(mocks[1].call_count(), 1);
|
||||||
|
assert_eq!(mocks[1].last_model(), "claude-opus");
|
||||||
|
assert_eq!(mocks[0].call_count(), 0);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue