feat: add agent structure and improve tooling for provider
This commit is contained in:
parent
e2c966d31e
commit
b341fdb368
21 changed files with 2567 additions and 443 deletions
|
|
@ -2,7 +2,10 @@
|
|||
//! Most LLM APIs follow the same `/v1/chat/completions` format.
|
||||
//! This module provides a single implementation that works for all of them.
|
||||
|
||||
use crate::providers::traits::{ChatMessage, ChatResponse, Provider, ToolCall};
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
|
@ -163,12 +166,11 @@ struct ResponseMessage {
|
|||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<ApiToolCall>>,
|
||||
tool_calls: Option<Vec<ToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, Serialize)]
|
||||
struct ApiToolCall {
|
||||
id: Option<String>,
|
||||
struct ToolCall {
|
||||
#[serde(rename = "type")]
|
||||
kind: Option<String>,
|
||||
function: Option<Function>,
|
||||
|
|
@ -254,44 +256,6 @@ fn extract_responses_text(response: ResponsesResponse) -> Option<String> {
|
|||
None
|
||||
}
|
||||
|
||||
fn map_response_message(message: ResponseMessage) -> ChatResponse {
|
||||
let text = first_nonempty(message.content.as_deref());
|
||||
let tool_calls = message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.enumerate()
|
||||
.filter_map(|(index, call)| map_api_tool_call(call, index))
|
||||
.collect();
|
||||
|
||||
ChatResponse { text, tool_calls }
|
||||
}
|
||||
|
||||
fn map_api_tool_call(call: ApiToolCall, index: usize) -> Option<ToolCall> {
|
||||
if call.kind.as_deref().is_some_and(|kind| kind != "function") {
|
||||
return None;
|
||||
}
|
||||
|
||||
let function = call.function?;
|
||||
let name = function
|
||||
.name
|
||||
.and_then(|value| first_nonempty(Some(value.as_str())))?;
|
||||
let arguments = function
|
||||
.arguments
|
||||
.and_then(|value| first_nonempty(Some(value.as_str())))
|
||||
.unwrap_or_else(|| "{}".to_string());
|
||||
let id = call
|
||||
.id
|
||||
.and_then(|value| first_nonempty(Some(value.as_str())))
|
||||
.unwrap_or_else(|| format!("call_{}", index + 1));
|
||||
|
||||
Some(ToolCall {
|
||||
id,
|
||||
name,
|
||||
arguments,
|
||||
})
|
||||
}
|
||||
|
||||
impl OpenAiCompatibleProvider {
|
||||
fn apply_auth_header(
|
||||
&self,
|
||||
|
|
@ -311,7 +275,7 @@ impl OpenAiCompatibleProvider {
|
|||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
) -> anyhow::Result<String> {
|
||||
let request = ResponsesRequest {
|
||||
model: model.to_string(),
|
||||
input: vec![ResponsesInput {
|
||||
|
|
@ -337,7 +301,6 @@ impl OpenAiCompatibleProvider {
|
|||
let responses: ResponsesResponse = response.json().await?;
|
||||
|
||||
extract_responses_text(responses)
|
||||
.map(ChatResponse::with_text)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {} Responses API", self.name))
|
||||
}
|
||||
}
|
||||
|
|
@ -350,7 +313,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
|
|
@ -408,13 +371,27 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
let chat_response: ApiChatResponse = response.json().await?;
|
||||
|
||||
let choice = chat_response
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
|
||||
|
||||
Ok(map_response_message(choice.message))
|
||||
.map(|c| {
|
||||
// If tool_calls are present, serialize the full message as JSON
|
||||
// so parse_tool_calls can handle the OpenAI-style format
|
||||
if c.message.tool_calls.is_some()
|
||||
&& c.message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map_or(false, |t| !t.is_empty())
|
||||
{
|
||||
serde_json::to_string(&c.message)
|
||||
.unwrap_or_else(|_| c.message.content.unwrap_or_default())
|
||||
} else {
|
||||
// No tool calls, return content as-is
|
||||
c.message.content.unwrap_or_default()
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
|
|
@ -422,7 +399,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
|
|
@ -482,13 +459,71 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
let chat_response: ApiChatResponse = response.json().await?;
|
||||
|
||||
let choice = chat_response
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))?;
|
||||
.map(|c| {
|
||||
// If tool_calls are present, serialize the full message as JSON
|
||||
// so parse_tool_calls can handle the OpenAI-style format
|
||||
if c.message.tool_calls.is_some()
|
||||
&& c.message
|
||||
.tool_calls
|
||||
.as_ref()
|
||||
.map_or(false, |t| !t.is_empty())
|
||||
{
|
||||
serde_json::to_string(&c.message)
|
||||
.unwrap_or_else(|_| c.message.content.unwrap_or_default())
|
||||
} else {
|
||||
// No tool calls, return content as-is
|
||||
c.message.content.unwrap_or_default()
|
||||
}
|
||||
})
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
|
||||
}
|
||||
|
||||
Ok(map_response_message(choice.message))
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let text = self
|
||||
.chat_with_history(request.messages, model, temperature)
|
||||
.await?;
|
||||
|
||||
// Backward compatible path: chat_with_history may serialize tool_calls JSON into content.
|
||||
if let Ok(message) = serde_json::from_str::<ResponseMessage>(&text) {
|
||||
let tool_calls = message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.filter_map(|tc| {
|
||||
let function = tc.function?;
|
||||
let name = function.name?;
|
||||
let arguments = function.arguments.unwrap_or_else(|| "{}".to_string());
|
||||
Some(ProviderToolCall {
|
||||
id: uuid::Uuid::new_v4().to_string(),
|
||||
name,
|
||||
arguments,
|
||||
})
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
return Ok(ProviderChatResponse {
|
||||
text: message.content,
|
||||
tool_calls,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(ProviderChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: vec![],
|
||||
})
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -573,20 +608,6 @@ mod tests {
|
|||
assert!(resp.choices.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_tool_calls_maps_structured_data() {
|
||||
let json = r#"{"choices":[{"message":{"content":"Running checks","tool_calls":[{"id":"call_1","type":"function","function":{"name":"shell","arguments":"{\"command\":\"pwd\"}"}}]}}]}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
let choice = resp.choices.into_iter().next().unwrap();
|
||||
|
||||
let mapped = map_response_message(choice.message);
|
||||
assert_eq!(mapped.text.as_deref(), Some("Running checks"));
|
||||
assert_eq!(mapped.tool_calls.len(), 1);
|
||||
assert_eq!(mapped.tool_calls[0].id, "call_1");
|
||||
assert_eq!(mapped.tool_calls[0].name, "shell");
|
||||
assert_eq!(mapped.tool_calls[0].arguments, r#"{"command":"pwd"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn x_api_key_auth_style() {
|
||||
let p = OpenAiCompatibleProvider::new(
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue