fix(provider): enable native tool calling for OllamaProvider

reidliu41 2026-02-18 22:04:24 +08:00 committed by Chummy
parent d548caa5f3
commit cd59dc65c4


@@ -1,4 +1,4 @@
-use crate::providers::traits::Provider;
+use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
 use async_trait::async_trait;
 use reqwest::Client;
 use serde::{Deserialize, Serialize};
@@ -16,6 +16,8 @@ struct ChatRequest {
     messages: Vec<Message>,
     stream: bool,
     options: Options,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    tools: Option<Vec<serde_json::Value>>,
 }
 
 #[derive(Debug, Serialize)]
@@ -112,29 +114,33 @@ impl OllamaProvider {
         Ok((normalized_model, should_auth))
     }
 
-    /// Send a request to Ollama and get the parsed response
+    /// Send a request to Ollama and get the parsed response.
+    /// Pass `tools` to enable native function-calling for models that support it.
     async fn send_request(
         &self,
         messages: Vec<Message>,
         model: &str,
         temperature: f64,
         should_auth: bool,
+        tools: Option<&[serde_json::Value]>,
     ) -> anyhow::Result<ApiChatResponse> {
         let request = ChatRequest {
             model: model.to_string(),
             messages,
             stream: false,
             options: Options { temperature },
+            tools: tools.map(|t| t.to_vec()),
         };
 
         let url = format!("{}/api/chat", self.base_url);
         tracing::debug!(
-            "Ollama request: url={} model={} message_count={} temperature={}",
+            "Ollama request: url={} model={} message_count={} temperature={} tool_count={}",
             url,
             model,
             request.messages.len(),
-            temperature
+            temperature,
+            request.tools.as_ref().map_or(0, |t| t.len()),
         );
 
         let mut request_builder = self.http_client().post(&url).json(&request);
@@ -281,7 +287,7 @@ impl Provider for OllamaProvider {
         });
 
         let response = self
-            .send_request(messages, &normalized_model, temperature, should_auth)
+            .send_request(messages, &normalized_model, temperature, should_auth, None)
             .await?;
 
         // If model returned tool calls, format them for loop_.rs's parse_tool_calls
@@ -331,7 +337,7 @@
             .collect();
 
         let response = self
-            .send_request(api_messages, &normalized_model, temperature, should_auth)
+            .send_request(api_messages, &normalized_model, temperature, should_auth, None)
             .await?;
 
         // If model returned tool calls, format them for loop_.rs's parse_tool_calls
@@ -343,6 +349,7 @@
             return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
         }
+
         // Plain text response
         let content = response.message.content;
@@ -366,11 +373,88 @@
         Ok(content)
     }
 
+    async fn chat_with_tools(
+        &self,
+        messages: &[ChatMessage],
+        tools: &[serde_json::Value],
+        model: &str,
+        temperature: f64,
+    ) -> anyhow::Result<ChatResponse> {
+        let (normalized_model, should_auth) = self.resolve_request_details(model)?;
+
+        let api_messages: Vec<Message> = messages
+            .iter()
+            .map(|m| Message {
+                role: m.role.clone(),
+                content: m.content.clone(),
+            })
+            .collect();
+
+        // Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from
+        // tools_to_openai_format() in loop_.rs — pass them through directly.
+        let tools_opt = if tools.is_empty() { None } else { Some(tools) };
+
+        let response = self
+            .send_request(api_messages, &normalized_model, temperature, should_auth, tools_opt)
+            .await?;
+
+        // Native tool calls returned by the model.
+        if !response.message.tool_calls.is_empty() {
+            let tool_calls: Vec<ToolCall> = response
+                .message
+                .tool_calls
+                .iter()
+                .map(|tc| {
+                    let (name, args) = self.extract_tool_name_and_args(tc);
+                    ToolCall {
+                        id: tc
+                            .id
+                            .clone()
+                            .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
+                        name,
+                        arguments: serde_json::to_string(&args)
+                            .unwrap_or_else(|_| "{}".to_string()),
+                    }
+                })
+                .collect();
+
+            let text = if response.message.content.is_empty() {
+                None
+            } else {
+                Some(response.message.content)
+            };
+
+            return Ok(ChatResponse { text, tool_calls });
+        }
+
+        // Plain text response.
+        let content = response.message.content;
+        if content.is_empty() {
+            if let Some(thinking) = &response.message.thinking {
+                tracing::warn!(
+                    "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
+                    if thinking.len() > 100 { &thinking[..100] } else { thinking }
+                );
+                return Ok(ChatResponse {
+                    text: Some(format!(
+                        "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
+                        if thinking.len() > 200 { &thinking[..200] } else { thinking }
+                    )),
+                    tool_calls: vec![],
+                });
+            }
+            tracing::warn!("Ollama returned empty content with no tool calls");
+        }
+
+        Ok(ChatResponse {
+            text: Some(content),
+            tool_calls: vec![],
+        })
+    }
+
     fn supports_native_tools(&self) -> bool {
-        // Return false since loop_.rs uses XML-style tool parsing via system prompt
-        // The model may return native tool_calls but we convert them to JSON format
-        // that parse_tool_calls() understands
-        false
+        // Ollama's /api/chat supports native function-calling for capable models
+        // (qwen2.5, llama3.1, mistral-nemo, etc.). chat_with_tools() sends tool
+        // definitions in the request and returns structured ToolCall objects.
+        true
     }
 }
@@ -557,4 +641,4 @@ mod tests {
         // arguments should be a string (JSON-encoded)
         assert!(func.get("arguments").unwrap().is_string());
     }
-}
\ No newline at end of file
+}
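
For reference, a minimal sketch of how a caller might exercise the new path. This is not part of the commit: the `example` function, the `get_weather` tool, and the `ChatMessage { role, content }` construction are illustrative assumptions; the tool definition follows the OpenAI-compatible `function` schema that Ollama's /api/chat accepts, the same shape the in-diff comment says tools_to_openai_format() in loop_.rs produces.

use serde_json::json;

// Illustrative only: assumes `ChatMessage` is a plain { role, content } struct
// and that an `OllamaProvider` instance is already configured.
async fn example(provider: &OllamaProvider) -> anyhow::Result<()> {
    // OpenAI/Ollama-compatible tool definition passed through verbatim
    // to Ollama's /api/chat `tools` field.
    let tools = vec![json!({
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get the current weather for a city",
            "parameters": {
                "type": "object",
                "properties": { "city": { "type": "string" } },
                "required": ["city"]
            }
        }
    })];

    let messages = vec![ChatMessage {
        role: "user".to_string(),
        content: "What's the weather in Paris?".to_string(),
    }];

    let response = provider
        .chat_with_tools(&messages, &tools, "qwen2.5", 0.7)
        .await?;

    // Tool calls come back as structured ToolCall values whose `arguments`
    // field is a JSON-encoded string.
    for call in &response.tool_calls {
        println!("model requested {}({})", call.name, call.arguments);
    }
    Ok(())
}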