fix(provider): preserve native Ollama tool history structure

This commit is contained in:
Chummy 2026-02-19 14:28:14 +08:00
parent cd59dc65c4
commit cf476a81c1

View file

@ -2,6 +2,7 @@ use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
pub struct OllamaProvider {
base_url: String,
@ -23,7 +24,25 @@ struct ChatRequest {
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<OutgoingToolCall>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_name: Option<String>,
}
#[derive(Debug, Serialize)]
struct OutgoingToolCall {
#[serde(rename = "type")]
kind: String,
function: OutgoingFunction,
}
#[derive(Debug, Serialize)]
struct OutgoingFunction {
name: String,
arguments: serde_json::Value,
}
#[derive(Debug, Serialize)]
@ -114,6 +133,98 @@ impl OllamaProvider {
Ok((normalized_model, should_auth))
}
fn parse_tool_arguments(arguments: &str) -> serde_json::Value {
serde_json::from_str(arguments).unwrap_or_else(|_| serde_json::json!({}))
}
/// Convert internal chat history format to Ollama's native tool-call message schema.
///
/// `run_tool_call_loop` stores native assistant/tool entries as JSON strings in
/// `ChatMessage.content`. We decode those payloads here so follow-up requests send
/// structured `assistant.tool_calls` and `tool.tool_name`, as expected by Ollama.
fn convert_messages(&self, messages: &[ChatMessage]) -> Vec<Message> {
let mut tool_name_by_id: HashMap<String, String> = HashMap::new();
messages
.iter()
.map(|message| {
if message.role == "assistant" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
if let Some(tool_calls_value) = value.get("tool_calls") {
if let Ok(parsed_calls) =
serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
{
let outgoing_calls: Vec<OutgoingToolCall> = parsed_calls
.into_iter()
.map(|call| {
tool_name_by_id.insert(call.id.clone(), call.name.clone());
OutgoingToolCall {
kind: "function".to_string(),
function: OutgoingFunction {
name: call.name,
arguments: Self::parse_tool_arguments(
&call.arguments,
),
},
}
})
.collect();
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return Message {
role: "assistant".to_string(),
content,
tool_calls: Some(outgoing_calls),
tool_name: None,
};
}
}
}
}
if message.role == "tool" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
let tool_name = value
.get("tool_name")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string)
.or_else(|| {
value
.get("tool_call_id")
.and_then(serde_json::Value::as_str)
.and_then(|id| tool_name_by_id.get(id))
.cloned()
});
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string)
.or_else(|| {
(!message.content.trim().is_empty())
.then_some(message.content.clone())
});
return Message {
role: "tool".to_string(),
content,
tool_calls: None,
tool_name,
};
}
}
Message {
role: message.role.clone(),
content: Some(message.content.clone()),
tool_calls: None,
tool_name: None,
}
})
.collect()
}
/// Send a request to Ollama and get the parsed response.
/// Pass `tools` to enable native function-calling for models that support it.
async fn send_request(
@ -277,13 +388,17 @@ impl Provider for OllamaProvider {
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
content: Some(sys.to_string()),
tool_calls: None,
tool_name: None,
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
content: Some(message.to_string()),
tool_calls: None,
tool_name: None,
});
let response = self
@ -328,16 +443,16 @@ impl Provider for OllamaProvider {
) -> anyhow::Result<String> {
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let api_messages = self.convert_messages(messages);
let response = self
.send_request(api_messages, &normalized_model, temperature, should_auth, None)
.send_request(
api_messages,
&normalized_model,
temperature,
should_auth,
None,
)
.await?;
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
@ -349,7 +464,6 @@ impl Provider for OllamaProvider {
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
// Plain text response
let content = response.message.content;
@ -382,20 +496,20 @@ impl Provider for OllamaProvider {
) -> anyhow::Result<ChatResponse> {
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let api_messages = self.convert_messages(messages);
// Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from
// tools_to_openai_format() in loop_.rs — pass them through directly.
let tools_opt = if tools.is_empty() { None } else { Some(tools) };
let response = self
.send_request(api_messages, &normalized_model, temperature, should_auth, tools_opt)
.send_request(
api_messages,
&normalized_model,
temperature,
should_auth,
tools_opt,
)
.await?;
// Native tool calls returned by the model.
@ -425,7 +539,6 @@ impl Provider for OllamaProvider {
return Ok(ChatResponse { text, tool_calls });
}
// Plain text response.
let content = response.message.content;
if content.is_empty() {
@ -641,4 +754,50 @@ mod tests {
// arguments should be a string (JSON-encoded)
assert!(func.get("arguments").unwrap().is_string());
}
}
#[test]
fn convert_messages_parses_native_assistant_tool_calls() {
let provider = OllamaProvider::new(None, None);
let messages = vec![ChatMessage {
role: "assistant".into(),
content: r#"{"content":null,"tool_calls":[{"id":"call_1","name":"shell","arguments":"{\"command\":\"ls\"}"}]}"#.into(),
}];
let converted = provider.convert_messages(&messages);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "assistant");
assert!(converted[0].content.is_none());
let calls = converted[0]
.tool_calls
.as_ref()
.expect("tool calls expected");
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].kind, "function");
assert_eq!(calls[0].function.name, "shell");
assert_eq!(calls[0].function.arguments.get("command").unwrap(), "ls");
}
#[test]
fn convert_messages_maps_tool_result_call_id_to_tool_name() {
let provider = OllamaProvider::new(None, None);
let messages = vec![
ChatMessage {
role: "assistant".into(),
content: r#"{"content":null,"tool_calls":[{"id":"call_7","name":"file_read","arguments":"{\"path\":\"README.md\"}"}]}"#.into(),
},
ChatMessage {
role: "tool".into(),
content: r#"{"tool_call_id":"call_7","content":"ok"}"#.into(),
},
];
let converted = provider.convert_messages(&messages);
assert_eq!(converted.len(), 2);
assert_eq!(converted[1].role, "tool");
assert_eq!(converted[1].tool_name.as_deref(), Some("file_read"));
assert_eq!(converted[1].content.as_deref(), Some("ok"));
assert!(converted[1].tool_calls.is_none());
}
}