fix(provider): enable native tool calling for OllamaProvider
This commit is contained in:
parent
d548caa5f3
commit
cd59dc65c4
1 changed files with 95 additions and 11 deletions
|
|
@ -1,4 +1,4 @@
|
||||||
use crate::providers::traits::Provider;
|
use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
@ -16,6 +16,8 @@ struct ChatRequest {
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
stream: bool,
|
stream: bool,
|
||||||
options: Options,
|
options: Options,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<serde_json::Value>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
|
@ -112,29 +114,33 @@ impl OllamaProvider {
|
||||||
Ok((normalized_model, should_auth))
|
Ok((normalized_model, should_auth))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Send a request to Ollama and get the parsed response
|
/// Send a request to Ollama and get the parsed response.
|
||||||
|
/// Pass `tools` to enable native function-calling for models that support it.
|
||||||
async fn send_request(
|
async fn send_request(
|
||||||
&self,
|
&self,
|
||||||
messages: Vec<Message>,
|
messages: Vec<Message>,
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
should_auth: bool,
|
should_auth: bool,
|
||||||
|
tools: Option<&[serde_json::Value]>,
|
||||||
) -> anyhow::Result<ApiChatResponse> {
|
) -> anyhow::Result<ApiChatResponse> {
|
||||||
let request = ChatRequest {
|
let request = ChatRequest {
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
messages,
|
messages,
|
||||||
stream: false,
|
stream: false,
|
||||||
options: Options { temperature },
|
options: Options { temperature },
|
||||||
|
tools: tools.map(|t| t.to_vec()),
|
||||||
};
|
};
|
||||||
|
|
||||||
let url = format!("{}/api/chat", self.base_url);
|
let url = format!("{}/api/chat", self.base_url);
|
||||||
|
|
||||||
tracing::debug!(
|
tracing::debug!(
|
||||||
"Ollama request: url={} model={} message_count={} temperature={}",
|
"Ollama request: url={} model={} message_count={} temperature={} tool_count={}",
|
||||||
url,
|
url,
|
||||||
model,
|
model,
|
||||||
request.messages.len(),
|
request.messages.len(),
|
||||||
temperature
|
temperature,
|
||||||
|
request.tools.as_ref().map_or(0, |t| t.len()),
|
||||||
);
|
);
|
||||||
|
|
||||||
let mut request_builder = self.http_client().post(&url).json(&request);
|
let mut request_builder = self.http_client().post(&url).json(&request);
|
||||||
|
|
@ -281,7 +287,7 @@ impl Provider for OllamaProvider {
|
||||||
});
|
});
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.send_request(messages, &normalized_model, temperature, should_auth)
|
.send_request(messages, &normalized_model, temperature, should_auth, None)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||||
|
|
@ -331,7 +337,7 @@ impl Provider for OllamaProvider {
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.send_request(api_messages, &normalized_model, temperature, should_auth)
|
.send_request(api_messages, &normalized_model, temperature, should_auth, None)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||||
|
|
@ -343,6 +349,7 @@ impl Provider for OllamaProvider {
|
||||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Plain text response
|
// Plain text response
|
||||||
let content = response.message.content;
|
let content = response.message.content;
|
||||||
|
|
||||||
|
|
@ -366,11 +373,88 @@ impl Provider for OllamaProvider {
|
||||||
Ok(content)
|
Ok(content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
async fn chat_with_tools(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
tools: &[serde_json::Value],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ChatResponse> {
|
||||||
|
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||||
|
|
||||||
|
let api_messages: Vec<Message> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: m.content.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from
|
||||||
|
// tools_to_openai_format() in loop_.rs — pass them through directly.
|
||||||
|
let tools_opt = if tools.is_empty() { None } else { Some(tools) };
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.send_request(api_messages, &normalized_model, temperature, should_auth, tools_opt)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
// Native tool calls returned by the model.
|
||||||
|
if !response.message.tool_calls.is_empty() {
|
||||||
|
let tool_calls: Vec<ToolCall> = response
|
||||||
|
.message
|
||||||
|
.tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| {
|
||||||
|
let (name, args) = self.extract_tool_name_and_args(tc);
|
||||||
|
ToolCall {
|
||||||
|
id: tc
|
||||||
|
.id
|
||||||
|
.clone()
|
||||||
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
name,
|
||||||
|
arguments: serde_json::to_string(&args)
|
||||||
|
.unwrap_or_else(|_| "{}".to_string()),
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let text = if response.message.content.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(response.message.content)
|
||||||
|
};
|
||||||
|
return Ok(ChatResponse { text, tool_calls });
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
// Plain text response.
|
||||||
|
let content = response.message.content;
|
||||||
|
if content.is_empty() {
|
||||||
|
if let Some(thinking) = &response.message.thinking {
|
||||||
|
tracing::warn!(
|
||||||
|
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||||
|
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||||
|
);
|
||||||
|
return Ok(ChatResponse {
|
||||||
|
text: Some(format!(
|
||||||
|
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||||
|
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||||
|
)),
|
||||||
|
tool_calls: vec![],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||||
|
}
|
||||||
|
Ok(ChatResponse {
|
||||||
|
text: Some(content),
|
||||||
|
tool_calls: vec![],
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
fn supports_native_tools(&self) -> bool {
|
fn supports_native_tools(&self) -> bool {
|
||||||
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
// Ollama's /api/chat supports native function-calling for capable models
|
||||||
// The model may return native tool_calls but we convert them to JSON format
|
// (qwen2.5, llama3.1, mistral-nemo, etc.). chat_with_tools() sends tool
|
||||||
// that parse_tool_calls() understands
|
// definitions in the request and returns structured ToolCall objects.
|
||||||
false
|
true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -557,4 +641,4 @@ mod tests {
|
||||||
// arguments should be a string (JSON-encoded)
|
// arguments should be a string (JSON-encoded)
|
||||||
assert!(func.get("arguments").unwrap().is_string());
|
assert!(func.get("arguments").unwrap().is_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue