fix(provider): enable native tool calling for OllamaProvider
This commit is contained in:
parent
d548caa5f3
commit
cd59dc65c4
1 changed files with 95 additions and 11 deletions
|
|
@ -1,4 +1,4 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -16,6 +16,8 @@ struct ChatRequest {
|
|||
messages: Vec<Message>,
|
||||
stream: bool,
|
||||
options: Options,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<serde_json::Value>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
|
|
@ -112,29 +114,33 @@ impl OllamaProvider {
|
|||
Ok((normalized_model, should_auth))
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response
|
||||
/// Send a request to Ollama and get the parsed response.
|
||||
/// Pass `tools` to enable native function-calling for models that support it.
|
||||
async fn send_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
should_auth: bool,
|
||||
tools: Option<&[serde_json::Value]>,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
tools: tools.map(|t| t.to_vec()),
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
tracing::debug!(
|
||||
"Ollama request: url={} model={} message_count={} temperature={}",
|
||||
"Ollama request: url={} model={} message_count={} temperature={} tool_count={}",
|
||||
url,
|
||||
model,
|
||||
request.messages.len(),
|
||||
temperature
|
||||
temperature,
|
||||
request.tools.as_ref().map_or(0, |t| t.len()),
|
||||
);
|
||||
|
||||
let mut request_builder = self.http_client().post(&url).json(&request);
|
||||
|
|
@ -281,7 +287,7 @@ impl Provider for OllamaProvider {
|
|||
});
|
||||
|
||||
let response = self
|
||||
.send_request(messages, &normalized_model, temperature, should_auth)
|
||||
.send_request(messages, &normalized_model, temperature, should_auth, None)
|
||||
.await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
|
|
@ -331,7 +337,7 @@ impl Provider for OllamaProvider {
|
|||
.collect();
|
||||
|
||||
let response = self
|
||||
.send_request(api_messages, &normalized_model, temperature, should_auth)
|
||||
.send_request(api_messages, &normalized_model, temperature, should_auth, None)
|
||||
.await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
|
|
@ -343,6 +349,7 @@ impl Provider for OllamaProvider {
|
|||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
|
|
@ -366,11 +373,88 @@ impl Provider for OllamaProvider {
|
|||
Ok(content)
|
||||
}
|
||||
|
||||
async fn chat_with_tools(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[serde_json::Value],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
// Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from
|
||||
// tools_to_openai_format() in loop_.rs — pass them through directly.
|
||||
let tools_opt = if tools.is_empty() { None } else { Some(tools) };
|
||||
|
||||
let response = self
|
||||
.send_request(api_messages, &normalized_model, temperature, should_auth, tools_opt)
|
||||
.await?;
|
||||
|
||||
// Native tool calls returned by the model.
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
let tool_calls: Vec<ToolCall> = response
|
||||
.message
|
||||
.tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let (name, args) = self.extract_tool_name_and_args(tc);
|
||||
ToolCall {
|
||||
id: tc
|
||||
.id
|
||||
.clone()
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name,
|
||||
arguments: serde_json::to_string(&args)
|
||||
.unwrap_or_else(|_| "{}".to_string()),
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
let text = if response.message.content.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(response.message.content)
|
||||
};
|
||||
return Ok(ChatResponse { text, tool_calls });
|
||||
}
|
||||
|
||||
|
||||
// Plain text response.
|
||||
let content = response.message.content;
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
return Ok(ChatResponse {
|
||||
text: Some(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
)),
|
||||
tool_calls: vec![],
|
||||
});
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
Ok(ChatResponse {
|
||||
text: Some(content),
|
||||
tool_calls: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
||||
// The model may return native tool_calls but we convert them to JSON format
|
||||
// that parse_tool_calls() understands
|
||||
false
|
||||
// Ollama's /api/chat supports native function-calling for capable models
|
||||
// (qwen2.5, llama3.1, mistral-nemo, etc.). chat_with_tools() sends tool
|
||||
// definitions in the request and returns structured ToolCall objects.
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -557,4 +641,4 @@ mod tests {
|
|||
// arguments should be a string (JSON-encoded)
|
||||
assert!(func.get("arguments").unwrap().is_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue