feat: ollama tool calls
This commit is contained in:
parent
b828873426
commit
c4c1272580
1 changed files with 151 additions and 2 deletions
|
|
@ -36,6 +36,21 @@ struct ApiChatResponse {
|
||||||
struct ResponseMessage {
|
struct ResponseMessage {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
content: String,
|
content: String,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Vec<OllamaToolCall>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OllamaToolCall {
|
||||||
|
id: Option<String>,
|
||||||
|
function: OllamaFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OllamaFunction {
|
||||||
|
name: String,
|
||||||
|
#[serde(default)]
|
||||||
|
arguments: serde_json::Value,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OllamaProvider {
|
impl OllamaProvider {
|
||||||
|
|
@ -149,13 +164,127 @@ impl Provider for OllamaProvider {
|
||||||
}
|
}
|
||||||
);
|
);
|
||||||
|
|
||||||
if content.is_empty() {
|
if content.is_empty() && chat_response.message.tool_calls.is_empty() {
|
||||||
let raw = String::from_utf8_lossy(&body);
|
let raw = String::from_utf8_lossy(&body);
|
||||||
tracing::warn!("Ollama returned empty content. Raw response: {}", raw);
|
tracing::warn!("Ollama returned empty content with no tool calls. Raw response: {}", raw);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(content)
|
Ok(content)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: crate::providers::ChatRequest<'_>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<crate::providers::ChatResponse> {
|
||||||
|
let messages: Vec<Message> = request
|
||||||
|
.messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: m.content.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let api_request = ChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages,
|
||||||
|
stream: false,
|
||||||
|
options: Options { temperature },
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = format!("{}/api/chat", self.base_url);
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Ollama chat request: url={} model={} message_count={} temperature={}",
|
||||||
|
url,
|
||||||
|
model,
|
||||||
|
api_request.messages.len(),
|
||||||
|
temperature
|
||||||
|
);
|
||||||
|
if tracing::enabled!(tracing::Level::TRACE) {
|
||||||
|
if let Ok(req_json) = serde_json::to_string(&api_request) {
|
||||||
|
tracing::trace!("Ollama chat request body: {}", req_json);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self.client.post(&url).json(&api_request).send().await?;
|
||||||
|
let status = response.status();
|
||||||
|
tracing::debug!("Ollama chat response status: {}", status);
|
||||||
|
|
||||||
|
let body = response.bytes().await?;
|
||||||
|
tracing::debug!("Ollama chat response body length: {} bytes", body.len());
|
||||||
|
|
||||||
|
if tracing::enabled!(tracing::Level::TRACE) {
|
||||||
|
let raw = String::from_utf8_lossy(&body);
|
||||||
|
tracing::trace!(
|
||||||
|
"Ollama chat raw response: {}",
|
||||||
|
if raw.len() > 2000 { &raw[..2000] } else { &raw }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let raw = String::from_utf8_lossy(&body);
|
||||||
|
tracing::error!("Ollama chat error response: status={} body={}", status, raw);
|
||||||
|
anyhow::bail!(
|
||||||
|
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
|
||||||
|
status,
|
||||||
|
if raw.len() > 200 { &raw[..200] } else { &raw }
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let raw = String::from_utf8_lossy(&body);
|
||||||
|
tracing::error!(
|
||||||
|
"Ollama chat response deserialization failed: {e}. Raw body: {}",
|
||||||
|
if raw.len() > 500 { &raw[..500] } else { &raw }
|
||||||
|
);
|
||||||
|
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let content = chat_response.message.content;
|
||||||
|
let tool_calls: Vec<crate::providers::ToolCall> = chat_response
|
||||||
|
.message
|
||||||
|
.tool_calls
|
||||||
|
.into_iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, tc)| {
|
||||||
|
let args_str = match &tc.function.arguments {
|
||||||
|
serde_json::Value::String(s) => s.clone(),
|
||||||
|
other => other.to_string(),
|
||||||
|
};
|
||||||
|
crate::providers::ToolCall {
|
||||||
|
id: tc.id.unwrap_or_else(|| format!("call_{}", i)),
|
||||||
|
name: tc.function.name,
|
||||||
|
arguments: args_str,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Ollama chat response parsed: content_length={} tool_calls_count={}",
|
||||||
|
content.len(),
|
||||||
|
tool_calls.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
if content.is_empty() && tool_calls.is_empty() {
|
||||||
|
let raw = String::from_utf8_lossy(&body);
|
||||||
|
tracing::warn!("Ollama returned empty content with no tool calls. Raw response: {}", raw);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(crate::providers::ChatResponse {
|
||||||
|
text: if content.is_empty() { None } else { Some(content) },
|
||||||
|
tool_calls,
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -256,6 +385,26 @@ mod tests {
|
||||||
assert_eq!(resp.message.content, "hello");
|
assert_eq!(resp.message.content, "hello");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_with_tool_calls_parses_correctly() {
|
||||||
|
// Models may return tool_calls with empty content
|
||||||
|
let json = r#"{"message":{"role":"assistant","content":"","thinking":"some thinking","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"cmd":["ls","-la"]}}}]}}"#;
|
||||||
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(resp.message.content.is_empty());
|
||||||
|
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||||
|
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
|
||||||
|
assert_eq!(resp.message.tool_calls[0].id, Some("call_123".to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_with_tool_calls_no_id() {
|
||||||
|
// Some models may not include an id field
|
||||||
|
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"function":{"name":"test_tool","arguments":{}}}]}}"#;
|
||||||
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||||
|
assert!(resp.message.tool_calls[0].id.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_with_multiline() {
|
fn response_with_multiline() {
|
||||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue