feat: ollama tools
This commit is contained in:
parent
808450c48e
commit
1c0d7bbcb8
1 changed files with 241 additions and 187 deletions
|
|
@ -8,6 +8,8 @@ pub struct OllamaProvider {
|
|||
client: Client,
|
||||
}
|
||||
|
||||
// ─── Request Structures ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
|
|
@ -27,6 +29,8 @@ struct Options {
|
|||
temperature: f64,
|
||||
}
|
||||
|
||||
// ─── Response Structures ──────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
message: ResponseMessage,
|
||||
|
|
@ -38,6 +42,9 @@ struct ResponseMessage {
|
|||
content: String,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<OllamaToolCall>,
|
||||
/// Some models return a "thinking" field with internal reasoning
|
||||
#[serde(default)]
|
||||
thinking: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
|
|
@ -53,6 +60,8 @@ struct OllamaFunction {
|
|||
arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
// ─── Implementation ───────────────────────────────────────────────────────────
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(base_url: Option<&str>) -> Self {
|
||||
Self {
|
||||
|
|
@ -61,37 +70,20 @@ impl OllamaProvider {
|
|||
.trim_end_matches('/')
|
||||
.to_string(),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OllamaProvider {
|
||||
async fn chat_with_system(
|
||||
/// Send a request to Ollama and get the parsed response
|
||||
async fn send_request(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
|
|
@ -108,6 +100,7 @@ impl Provider for OllamaProvider {
|
|||
request.messages.len(),
|
||||
temperature
|
||||
);
|
||||
|
||||
if tracing::enabled!(tracing::Level::TRACE) {
|
||||
if let Ok(req_json) = serde_json::to_string(&request) {
|
||||
tracing::trace!("Ollama request body: {}", req_json);
|
||||
|
|
@ -118,11 +111,9 @@ impl Provider for OllamaProvider {
|
|||
let status = response.status();
|
||||
tracing::debug!("Ollama response status: {}", status);
|
||||
|
||||
// Read raw body first to enable debugging if deserialization fails
|
||||
let body = response.bytes().await?;
|
||||
let body_len = body.len();
|
||||
tracing::debug!("Ollama response body length: {} bytes", body.len());
|
||||
|
||||
tracing::debug!("Ollama response body length: {} bytes", body_len);
|
||||
if tracing::enabled!(tracing::Level::TRACE) {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
tracing::trace!(
|
||||
|
|
@ -153,37 +144,140 @@ impl Provider for OllamaProvider {
|
|||
}
|
||||
};
|
||||
|
||||
let content = chat_response.message.content;
|
||||
tracing::debug!(
|
||||
"Ollama response parsed: content_length={} content_preview='{}'",
|
||||
content.len(),
|
||||
if content.len() > 100 {
|
||||
format!("{}...", &content[..100])
|
||||
} else {
|
||||
content.clone()
|
||||
Ok(chat_response)
|
||||
}
|
||||
);
|
||||
|
||||
if content.is_empty() && chat_response.message.tool_calls.is_empty() {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
tracing::warn!("Ollama returned empty content with no tool calls. Raw response: {}", raw);
|
||||
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
|
||||
///
|
||||
/// Handles quirky model behavior where tool calls are wrapped:
|
||||
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
|
||||
/// - `{"name": "tool.shell", "arguments": {...}}`
|
||||
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
|
||||
let formatted_calls: Vec<serde_json::Value> = tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
|
||||
|
||||
// Arguments must be a JSON string for parse_tool_calls compatibility
|
||||
let args_str = serde_json::to_string(&tool_args)
|
||||
.unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
serde_json::json!({
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": args_str
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"content": "",
|
||||
"tool_calls": formatted_calls
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Extract the actual tool name and arguments from potentially nested structures
|
||||
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
|
||||
let name = &tc.function.name;
|
||||
let args = &tc.function.arguments;
|
||||
|
||||
// Pattern 1: Nested tool_call wrapper (various malformed versions)
|
||||
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
|
||||
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
|
||||
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
|
||||
if name == "tool_call"
|
||||
|| name == "tool.call"
|
||||
|| name.starts_with("tool_call>")
|
||||
|| name.starts_with("tool_call<")
|
||||
{
|
||||
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
|
||||
let nested_args = args.get("arguments").cloned().unwrap_or(serde_json::json!({}));
|
||||
tracing::debug!(
|
||||
"Unwrapped nested tool call: {} -> {} with args {:?}",
|
||||
name,
|
||||
nested_name,
|
||||
nested_args
|
||||
);
|
||||
return (nested_name.to_string(), nested_args);
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
|
||||
if let Some(stripped) = name.strip_prefix("tool.") {
|
||||
return (stripped.to_string(), args.clone());
|
||||
}
|
||||
|
||||
// Pattern 3: Normal tool call
|
||||
(name.clone(), args.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OllamaProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
messages.push(Message {
|
||||
role: "system".to_string(),
|
||||
content: sys.to_string(),
|
||||
});
|
||||
}
|
||||
|
||||
messages.push(Message {
|
||||
role: "user".to_string(),
|
||||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let response = self.send_request(messages, model, temperature).await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
request: crate::providers::ChatRequest<'_>,
|
||||
messages: &[crate::providers::ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<crate::providers::ChatResponse> {
|
||||
let messages: Vec<Message> = request
|
||||
.messages
|
||||
) -> anyhow::Result<String> {
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
|
|
@ -191,101 +285,49 @@ impl Provider for OllamaProvider {
|
|||
})
|
||||
.collect();
|
||||
|
||||
let api_request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
let response = self.send_request(api_messages, model, temperature).await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama chat request: url={} model={} message_count={} temperature={}",
|
||||
url,
|
||||
model,
|
||||
api_request.messages.len(),
|
||||
temperature
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
if tracing::enabled!(tracing::Level::TRACE) {
|
||||
if let Ok(req_json) = serde_json::to_string(&api_request) {
|
||||
tracing::trace!("Ollama chat request body: {}", req_json);
|
||||
}
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
let response = self.client.post(&url).json(&api_request).send().await?;
|
||||
let status = response.status();
|
||||
tracing::debug!("Ollama chat response status: {}", status);
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
let body = response.bytes().await?;
|
||||
tracing::debug!("Ollama chat response body length: {} bytes", body.len());
|
||||
|
||||
if tracing::enabled!(tracing::Level::TRACE) {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
tracing::trace!(
|
||||
"Ollama chat raw response: {}",
|
||||
if raw.len() > 2000 { &raw[..2000] } else { &raw }
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
// This is a model quirk - it stopped after reasoning without producing output
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
// Return a message indicating the model's thought process but no action
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
if !status.is_success() {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
tracing::error!("Ollama chat error response: status={} body={}", status, raw);
|
||||
anyhow::bail!(
|
||||
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
|
||||
status,
|
||||
if raw.len() > 200 { &raw[..200] } else { &raw }
|
||||
);
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
tracing::error!(
|
||||
"Ollama chat response deserialization failed: {e}. Raw body: {}",
|
||||
if raw.len() > 500 { &raw[..500] } else { &raw }
|
||||
);
|
||||
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
||||
// The model may return native tool_calls but we convert them to JSON format
|
||||
// that parse_tool_calls() understands
|
||||
false
|
||||
}
|
||||
};
|
||||
|
||||
let content = chat_response.message.content;
|
||||
let tool_calls: Vec<crate::providers::ToolCall> = chat_response
|
||||
.message
|
||||
.tool_calls
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.map(|(i, tc)| {
|
||||
let args_str = match &tc.function.arguments {
|
||||
serde_json::Value::String(s) => s.clone(),
|
||||
other => other.to_string(),
|
||||
};
|
||||
crate::providers::ToolCall {
|
||||
id: tc.id.unwrap_or_else(|| format!("call_{}", i)),
|
||||
name: tc.function.name,
|
||||
arguments: args_str,
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
tracing::debug!(
|
||||
"Ollama chat response parsed: content_length={} tool_calls_count={}",
|
||||
content.len(),
|
||||
tool_calls.len()
|
||||
);
|
||||
|
||||
if content.is_empty() && tool_calls.is_empty() {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
tracing::warn!("Ollama returned empty content with no tool calls. Raw response: {}", raw);
|
||||
}
|
||||
|
||||
Ok(crate::providers::ChatResponse {
|
||||
text: if content.is_empty() { None } else { Some(content) },
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
}
|
||||
// ─── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
|
|
@ -315,46 +357,6 @@ mod tests {
|
|||
assert_eq!(p.base_url, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system() {
|
||||
let req = ChatRequest {
|
||||
model: "llama3".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: "You are ZeroClaw".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
},
|
||||
],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.7 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"stream\":false"));
|
||||
assert!(json.contains("llama3"));
|
||||
assert!(json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.7"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "mistral".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "test".to_string(),
|
||||
}],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.0 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("mistral"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
|
|
@ -371,7 +373,6 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn response_with_missing_content_defaults_to_empty() {
|
||||
// Some models/versions may omit content field entirely
|
||||
let json = r#"{"message":{"role":"assistant"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
|
|
@ -379,7 +380,6 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn response_with_thinking_field_extracts_content() {
|
||||
// Models with thinking capability return additional fields
|
||||
let json = r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.content, "hello");
|
||||
|
|
@ -387,28 +387,82 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn response_with_tool_calls_parses_correctly() {
|
||||
// Models may return tool_calls with empty content
|
||||
let json = r#"{"message":{"role":"assistant","content":"","thinking":"some thinking","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"cmd":["ls","-la"]}}}]}}"#;
|
||||
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
|
||||
assert_eq!(resp.message.tool_calls[0].id, Some("call_123".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_tool_calls_no_id() {
|
||||
// Some models may not include an id field
|
||||
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"test_tool","arguments":{}}}]}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||
assert!(resp.message.tool_calls[0].id.is_none());
|
||||
fn extract_tool_name_handles_nested_tool_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool_call".into(),
|
||||
arguments: serde_json::json!({
|
||||
"name": "shell",
|
||||
"arguments": {"command": "date"}
|
||||
}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "date");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_multiline() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.contains("line1"));
|
||||
fn extract_tool_name_handles_prefixed_name() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool.shell".into(),
|
||||
arguments: serde_json::json!({"command": "ls"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "ls");
|
||||
}
|
||||
|
||||
#[test]
fn extract_tool_name_handles_normal_call() {
    // A well-formed call must come back with its name and arguments unchanged.
    let plain = OllamaFunction {
        name: String::from("file_read"),
        arguments: serde_json::json!({"path": "/tmp/test"}),
    };
    let call = OllamaToolCall {
        id: Some(String::from("call_123")),
        function: plain,
    };

    let (resolved_name, resolved_args) =
        OllamaProvider::new(None).extract_tool_name_and_args(&call);

    assert_eq!(resolved_name, "file_read");
    assert_eq!(resolved_args.get("path").unwrap(), "/tmp/test");
}
|
||||
|
||||
#[test]
fn format_tool_calls_produces_valid_json() {
    let shell_call = OllamaToolCall {
        id: Some(String::from("call_abc")),
        function: OllamaFunction {
            name: String::from("shell"),
            arguments: serde_json::json!({"command": "date"}),
        },
    };

    let rendered = OllamaProvider::new(None).format_tool_calls_for_loop(&[shell_call]);
    let envelope: serde_json::Value = serde_json::from_str(&rendered).unwrap();

    // The envelope must expose exactly one entry under "tool_calls".
    assert!(envelope.get("tool_calls").is_some());
    let entries = envelope.get("tool_calls").unwrap().as_array().unwrap();
    assert_eq!(entries.len(), 1);

    let function = entries[0].get("function").unwrap();
    assert_eq!(function.get("name").unwrap(), "shell");
    // arguments should be a string (JSON-encoded), not a nested object
    assert!(function.get("arguments").unwrap().is_string());
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue