fix(provider): preserve native Ollama tool history structure
This commit is contained in:
parent
cd59dc65c4
commit
cf476a81c1
1 changed files with 181 additions and 22 deletions
|
|
@ -2,6 +2,7 @@ use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use reqwest::Client;
|
use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::collections::HashMap;
|
||||||
|
|
||||||
pub struct OllamaProvider {
|
pub struct OllamaProvider {
|
||||||
base_url: String,
|
base_url: String,
|
||||||
|
|
@ -23,7 +24,25 @@ struct ChatRequest {
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct Message {
|
struct Message {
|
||||||
role: String,
|
role: String,
|
||||||
content: String,
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<OutgoingToolCall>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_name: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct OutgoingToolCall {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
function: OutgoingFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct OutgoingFunction {
|
||||||
|
name: String,
|
||||||
|
arguments: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
|
|
@ -114,6 +133,98 @@ impl OllamaProvider {
|
||||||
Ok((normalized_model, should_auth))
|
Ok((normalized_model, should_auth))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_tool_arguments(arguments: &str) -> serde_json::Value {
|
||||||
|
serde_json::from_str(arguments).unwrap_or_else(|_| serde_json::json!({}))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert internal chat history format to Ollama's native tool-call message schema.
|
||||||
|
///
|
||||||
|
/// `run_tool_call_loop` stores native assistant/tool entries as JSON strings in
|
||||||
|
/// `ChatMessage.content`. We decode those payloads here so follow-up requests send
|
||||||
|
/// structured `assistant.tool_calls` and `tool.tool_name`, as expected by Ollama.
|
||||||
|
fn convert_messages(&self, messages: &[ChatMessage]) -> Vec<Message> {
|
||||||
|
let mut tool_name_by_id: HashMap<String, String> = HashMap::new();
|
||||||
|
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|message| {
|
||||||
|
if message.role == "assistant" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||||
|
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||||
|
if let Ok(parsed_calls) =
|
||||||
|
serde_json::from_value::<Vec<ToolCall>>(tool_calls_value.clone())
|
||||||
|
{
|
||||||
|
let outgoing_calls: Vec<OutgoingToolCall> = parsed_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|call| {
|
||||||
|
tool_name_by_id.insert(call.id.clone(), call.name.clone());
|
||||||
|
OutgoingToolCall {
|
||||||
|
kind: "function".to_string(),
|
||||||
|
function: OutgoingFunction {
|
||||||
|
name: call.name,
|
||||||
|
arguments: Self::parse_tool_arguments(
|
||||||
|
&call.arguments,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
return Message {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content,
|
||||||
|
tool_calls: Some(outgoing_calls),
|
||||||
|
tool_name: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.role == "tool" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||||
|
let tool_name = value
|
||||||
|
.get("tool_name")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string)
|
||||||
|
.or_else(|| {
|
||||||
|
value
|
||||||
|
.get("tool_call_id")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.and_then(|id| tool_name_by_id.get(id))
|
||||||
|
.cloned()
|
||||||
|
});
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string)
|
||||||
|
.or_else(|| {
|
||||||
|
(!message.content.trim().is_empty())
|
||||||
|
.then_some(message.content.clone())
|
||||||
|
});
|
||||||
|
|
||||||
|
return Message {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content,
|
||||||
|
tool_calls: None,
|
||||||
|
tool_name,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Message {
|
||||||
|
role: message.role.clone(),
|
||||||
|
content: Some(message.content.clone()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_name: None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
/// Send a request to Ollama and get the parsed response.
|
/// Send a request to Ollama and get the parsed response.
|
||||||
/// Pass `tools` to enable native function-calling for models that support it.
|
/// Pass `tools` to enable native function-calling for models that support it.
|
||||||
async fn send_request(
|
async fn send_request(
|
||||||
|
|
@ -277,13 +388,17 @@ impl Provider for OllamaProvider {
|
||||||
if let Some(sys) = system_prompt {
|
if let Some(sys) = system_prompt {
|
||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "system".to_string(),
|
role: "system".to_string(),
|
||||||
content: sys.to_string(),
|
content: Some(sys.to_string()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_name: None,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
messages.push(Message {
|
messages.push(Message {
|
||||||
role: "user".to_string(),
|
role: "user".to_string(),
|
||||||
content: message.to_string(),
|
content: Some(message.to_string()),
|
||||||
|
tool_calls: None,
|
||||||
|
tool_name: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
|
|
@ -328,16 +443,16 @@ impl Provider for OllamaProvider {
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||||
|
|
||||||
let api_messages: Vec<Message> = messages
|
let api_messages = self.convert_messages(messages);
|
||||||
.iter()
|
|
||||||
.map(|m| Message {
|
|
||||||
role: m.role.clone(),
|
|
||||||
content: m.content.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.send_request(api_messages, &normalized_model, temperature, should_auth, None)
|
.send_request(
|
||||||
|
api_messages,
|
||||||
|
&normalized_model,
|
||||||
|
temperature,
|
||||||
|
should_auth,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||||
|
|
@ -349,7 +464,6 @@ impl Provider for OllamaProvider {
|
||||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Plain text response
|
// Plain text response
|
||||||
let content = response.message.content;
|
let content = response.message.content;
|
||||||
|
|
||||||
|
|
@ -382,20 +496,20 @@ impl Provider for OllamaProvider {
|
||||||
) -> anyhow::Result<ChatResponse> {
|
) -> anyhow::Result<ChatResponse> {
|
||||||
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||||
|
|
||||||
let api_messages: Vec<Message> = messages
|
let api_messages = self.convert_messages(messages);
|
||||||
.iter()
|
|
||||||
.map(|m| Message {
|
|
||||||
role: m.role.clone(),
|
|
||||||
content: m.content.clone(),
|
|
||||||
})
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
// Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from
|
// Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from
|
||||||
// tools_to_openai_format() in loop_.rs — pass them through directly.
|
// tools_to_openai_format() in loop_.rs — pass them through directly.
|
||||||
let tools_opt = if tools.is_empty() { None } else { Some(tools) };
|
let tools_opt = if tools.is_empty() { None } else { Some(tools) };
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.send_request(api_messages, &normalized_model, temperature, should_auth, tools_opt)
|
.send_request(
|
||||||
|
api_messages,
|
||||||
|
&normalized_model,
|
||||||
|
temperature,
|
||||||
|
should_auth,
|
||||||
|
tools_opt,
|
||||||
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
// Native tool calls returned by the model.
|
// Native tool calls returned by the model.
|
||||||
|
|
@ -425,7 +539,6 @@ impl Provider for OllamaProvider {
|
||||||
return Ok(ChatResponse { text, tool_calls });
|
return Ok(ChatResponse { text, tool_calls });
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
// Plain text response.
|
// Plain text response.
|
||||||
let content = response.message.content;
|
let content = response.message.content;
|
||||||
if content.is_empty() {
|
if content.is_empty() {
|
||||||
|
|
@ -641,4 +754,50 @@ mod tests {
|
||||||
// arguments should be a string (JSON-encoded)
|
// arguments should be a string (JSON-encoded)
|
||||||
assert!(func.get("arguments").unwrap().is_string());
|
assert!(func.get("arguments").unwrap().is_string());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_messages_parses_native_assistant_tool_calls() {
|
||||||
|
let provider = OllamaProvider::new(None, None);
|
||||||
|
let messages = vec![ChatMessage {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: r#"{"content":null,"tool_calls":[{"id":"call_1","name":"shell","arguments":"{\"command\":\"ls\"}"}]}"#.into(),
|
||||||
|
}];
|
||||||
|
|
||||||
|
let converted = provider.convert_messages(&messages);
|
||||||
|
|
||||||
|
assert_eq!(converted.len(), 1);
|
||||||
|
assert_eq!(converted[0].role, "assistant");
|
||||||
|
assert!(converted[0].content.is_none());
|
||||||
|
let calls = converted[0]
|
||||||
|
.tool_calls
|
||||||
|
.as_ref()
|
||||||
|
.expect("tool calls expected");
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].kind, "function");
|
||||||
|
assert_eq!(calls[0].function.name, "shell");
|
||||||
|
assert_eq!(calls[0].function.arguments.get("command").unwrap(), "ls");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn convert_messages_maps_tool_result_call_id_to_tool_name() {
|
||||||
|
let provider = OllamaProvider::new(None, None);
|
||||||
|
let messages = vec![
|
||||||
|
ChatMessage {
|
||||||
|
role: "assistant".into(),
|
||||||
|
content: r#"{"content":null,"tool_calls":[{"id":"call_7","name":"file_read","arguments":"{\"path\":\"README.md\"}"}]}"#.into(),
|
||||||
|
},
|
||||||
|
ChatMessage {
|
||||||
|
role: "tool".into(),
|
||||||
|
content: r#"{"tool_call_id":"call_7","content":"ok"}"#.into(),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let converted = provider.convert_messages(&messages);
|
||||||
|
|
||||||
|
assert_eq!(converted.len(), 2);
|
||||||
|
assert_eq!(converted[1].role, "tool");
|
||||||
|
assert_eq!(converted[1].tool_name.as_deref(), Some("file_read"));
|
||||||
|
assert_eq!(converted[1].content.as_deref(), Some("ok"));
|
||||||
|
assert!(converted[1].tool_calls.is_none());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue