Merge remote-tracking branch 'origin/main' into feat/glm-provider

Resolved conflicts in:
- Cargo.toml: kept both `ring` (JWT auth) and `prost` (protobuf) dependencies
- src/onboard/wizard.rs: accepted main branch version
- src/providers/mod.rs: accepted main branch version
- Cargo.lock: accepted main branch version

Note: The custom `glm::GlmProvider` from this PR was replaced with
main's OpenAiCompatibleProvider approach for GLM, which uses base URLs.
The main purpose of this PR is Windows daemon support via Task Scheduler.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
argenis de la rosa 2026-02-17 13:27:58 -05:00
commit 34af6a223a
269 changed files with 68574 additions and 2541 deletions

View file

@ -1,10 +1,15 @@
use crate::providers::traits::Provider;
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct AnthropicProvider {
api_key: Option<String>,
credential: Option<String>,
base_url: String,
client: Client,
}
@ -31,13 +36,91 @@ struct ChatResponse {
#[derive(Debug, Deserialize)]
struct ContentBlock {
text: String,
#[serde(rename = "type")]
kind: String,
#[serde(default)]
text: Option<String>,
}
/// Request body for Anthropic's native `/v1/messages` endpoint.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    model: String,
    max_tokens: u32,
    /// Top-level system prompt (Anthropic takes it outside `messages`).
    #[serde(skip_serializing_if = "Option::is_none")]
    system: Option<String>,
    messages: Vec<NativeMessage>,
    temperature: f64,
    /// Tool definitions; omitted entirely when no tools are offered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
}

/// One conversation turn, expressed as a list of typed content blocks.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    content: Vec<NativeContentOut>,
}

/// Outgoing content block variants, tagged by `type` in the JSON.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
enum NativeContentOut {
    #[serde(rename = "text")]
    Text { text: String },
    /// Assistant-issued tool invocation.
    #[serde(rename = "tool_use")]
    ToolUse {
        id: String,
        name: String,
        input: serde_json::Value,
    },
    /// Tool output echoed back to the model.
    #[serde(rename = "tool_result")]
    ToolResult {
        tool_use_id: String,
        content: String,
    },
}

/// Tool definition in Anthropic's native schema.
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    name: String,
    description: String,
    input_schema: serde_json::Value,
}

/// Response body from `/v1/messages`.
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
    #[serde(default)]
    content: Vec<NativeContentIn>,
}

/// Incoming content block; fields are optional because which are present
/// depends on `kind` ("text" vs "tool_use").
#[derive(Debug, Deserialize)]
struct NativeContentIn {
    #[serde(rename = "type")]
    kind: String,
    #[serde(default)]
    text: Option<String>,
    #[serde(default)]
    id: Option<String>,
    #[serde(default)]
    name: Option<String>,
    #[serde(default)]
    input: Option<serde_json::Value>,
}
impl AnthropicProvider {
pub fn new(api_key: Option<&str>) -> Self {
pub fn new(credential: Option<&str>) -> Self {
Self::with_base_url(credential, None)
}
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
let base_url = base_url
.map(|u| u.trim_end_matches('/'))
.unwrap_or("https://api.anthropic.com")
.to_string();
Self {
api_key: api_key.map(ToString::to_string),
credential: credential
.map(str::trim)
.filter(|k| !k.is_empty())
.map(ToString::to_string),
base_url,
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -45,6 +128,192 @@ impl AnthropicProvider {
.unwrap_or_else(|_| Client::new()),
}
}
/// Detect Anthropic setup tokens (`sk-ant-oat01-…`), which authenticate
/// with an OAuth bearer header instead of `x-api-key`.
fn is_setup_token(token: &str) -> bool {
    token.strip_prefix("sk-ant-oat01-").is_some()
}
/// Attach credentials to a request: setup tokens use an OAuth bearer
/// header plus the required beta flag; anything else is sent as an API key.
fn apply_auth(
    &self,
    request: reqwest::RequestBuilder,
    credential: &str,
) -> reqwest::RequestBuilder {
    if Self::is_setup_token(credential) {
        request
            .header("Authorization", format!("Bearer {credential}"))
            .header("anthropic-beta", "oauth-2025-04-20")
    } else {
        request.header("x-api-key", credential)
    }
}

/// Convert provider-agnostic tool specs into Anthropic's native schema.
/// Returns `None` for a missing or empty list so the `tools` field is
/// omitted from the request entirely.
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
    let items = tools?;
    if items.is_empty() {
        return None;
    }
    Some(
        items
            .iter()
            .map(|tool| NativeToolSpec {
                name: tool.name.clone(),
                description: tool.description.clone(),
                input_schema: tool.parameters.clone(),
            })
            .collect(),
    )
}

/// Re-hydrate an assistant turn stored as JSON with a `tool_calls` array
/// back into native `tool_use` content blocks. Returns `None` when the
/// content is not such a JSON payload (treated as plain text by callers).
fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<NativeContentOut>> {
    let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
    let tool_calls = value
        .get("tool_calls")
        .and_then(|v| serde_json::from_value::<Vec<ProviderToolCall>>(v.clone()).ok())?;
    let mut blocks = Vec::new();
    // Preserve any accompanying assistant text ahead of the tool calls.
    if let Some(text) = value
        .get("content")
        .and_then(serde_json::Value::as_str)
        .map(str::trim)
        .filter(|t| !t.is_empty())
    {
        blocks.push(NativeContentOut::Text {
            text: text.to_string(),
        });
    }
    for call in tool_calls {
        // Malformed argument JSON degrades to an empty object rather than
        // dropping the call.
        let input = serde_json::from_str::<serde_json::Value>(&call.arguments)
            .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
        blocks.push(NativeContentOut::ToolUse {
            id: call.id,
            name: call.name,
            input,
        });
    }
    Some(blocks)
}

/// Convert a stored "tool" role message (JSON with `tool_call_id` and
/// `content`) into Anthropic's `tool_result` block, which the API expects
/// under the "user" role. Returns `None` if the payload doesn't parse.
fn parse_tool_result_message(content: &str) -> Option<NativeMessage> {
    let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
    let tool_use_id = value
        .get("tool_call_id")
        .and_then(serde_json::Value::as_str)?
        .to_string();
    let result = value
        .get("content")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("")
        .to_string();
    Some(NativeMessage {
        role: "user".to_string(),
        content: vec![NativeContentOut::ToolResult {
            tool_use_id,
            content: result,
        }],
    })
}

/// Split generic chat history into (system prompt, native messages).
/// Only the first system message is used; assistant/tool turns that carry
/// JSON tool payloads are re-hydrated, everything else becomes plain text.
fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<NativeMessage>) {
    let mut system_prompt = None;
    let mut native_messages = Vec::new();
    for msg in messages {
        match msg.role.as_str() {
            "system" => {
                if system_prompt.is_none() {
                    system_prompt = Some(msg.content.clone());
                }
            }
            "assistant" => {
                if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) {
                    native_messages.push(NativeMessage {
                        role: "assistant".to_string(),
                        content: blocks,
                    });
                } else {
                    native_messages.push(NativeMessage {
                        role: "assistant".to_string(),
                        content: vec![NativeContentOut::Text {
                            text: msg.content.clone(),
                        }],
                    });
                }
            }
            "tool" => {
                if let Some(tool_result) = Self::parse_tool_result_message(&msg.content) {
                    native_messages.push(tool_result);
                } else {
                    native_messages.push(NativeMessage {
                        role: "user".to_string(),
                        content: vec![NativeContentOut::Text {
                            text: msg.content.clone(),
                        }],
                    });
                }
            }
            _ => {
                // Unknown roles are forwarded as user text.
                native_messages.push(NativeMessage {
                    role: "user".to_string(),
                    content: vec![NativeContentOut::Text {
                        text: msg.content.clone(),
                    }],
                });
            }
        }
    }
    (system_prompt, native_messages)
}

/// Extract the first text block from a legacy chat response; errors when
/// the response contains no text content.
fn parse_text_response(response: ChatResponse) -> anyhow::Result<String> {
    response
        .content
        .into_iter()
        .find(|c| c.kind == "text")
        .and_then(|c| c.text)
        .ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
}

/// Fold a native response's content blocks into a `ProviderChatResponse`:
/// text blocks are trimmed and joined with newlines, `tool_use` blocks
/// become tool calls (with a generated id if the API omitted one).
fn parse_native_response(response: NativeChatResponse) -> ProviderChatResponse {
    let mut text_parts = Vec::new();
    let mut tool_calls = Vec::new();
    for block in response.content {
        match block.kind.as_str() {
            "text" => {
                if let Some(text) = block.text.map(|t| t.trim().to_string()) {
                    if !text.is_empty() {
                        text_parts.push(text);
                    }
                }
            }
            "tool_use" => {
                // A nameless tool call is unusable; skip it.
                let name = block.name.unwrap_or_default();
                if name.is_empty() {
                    continue;
                }
                let arguments = block
                    .input
                    .unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
                tool_calls.push(ProviderToolCall {
                    id: block.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                    name,
                    arguments: arguments.to_string(),
                });
            }
            _ => {}
        }
    }
    ProviderChatResponse {
        text: if text_parts.is_empty() {
            None
        } else {
            Some(text_parts.join("\n"))
        },
        tool_calls,
    }
}
}
#[async_trait]
@ -56,8 +325,10 @@ impl Provider for AnthropicProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!("Anthropic API key not set. Set ANTHROPIC_API_KEY or edit config.toml.")
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)."
)
})?;
let request = ChatRequest {
@ -71,29 +342,65 @@ impl Provider for AnthropicProvider {
temperature,
};
let response = self
let mut request = self
.client
.post("https://api.anthropic.com/v1/messages")
.header("x-api-key", api_key)
.post(format!("{}/v1/messages", self.base_url))
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&request)
.send()
.await?;
.json(&request);
request = self.apply_auth(request, credential);
let response = request.send().await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("Anthropic API error: {error}");
return Err(super::api_error("Anthropic", response).await);
}
let chat_response: ChatResponse = response.json().await?;
Self::parse_text_response(chat_response)
}
chat_response
.content
.into_iter()
.next()
.map(|c| c.text)
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
async fn chat(
&self,
request: ProviderChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)."
)
})?;
let (system_prompt, messages) = Self::convert_messages(request.messages);
let native_request = NativeChatRequest {
model: model.to_string(),
max_tokens: 4096,
system: system_prompt,
messages,
temperature,
tools: Self::convert_tools(request.tools),
};
let req = self
.client
.post(format!("{}/v1/messages", self.base_url))
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&native_request);
let response = self.apply_auth(req, credential).send().await?;
if !response.status().is_success() {
return Err(super::api_error("Anthropic", response).await);
}
let native_response: NativeChatResponse = response.json().await?;
Ok(Self::parse_native_response(native_response))
}
fn supports_native_tools(&self) -> bool {
true
}
}
@ -103,22 +410,52 @@ mod tests {
#[test]
fn creates_with_key() {
let p = AnthropicProvider::new(Some("sk-ant-test123"));
assert!(p.api_key.is_some());
assert_eq!(p.api_key.as_deref(), Some("sk-ant-test123"));
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
assert_eq!(p.base_url, "https://api.anthropic.com");
}
#[test]
fn creates_without_key() {
let p = AnthropicProvider::new(None);
assert!(p.api_key.is_none());
assert!(p.credential.is_none());
assert_eq!(p.base_url, "https://api.anthropic.com");
}
#[test]
fn creates_with_empty_key() {
let p = AnthropicProvider::new(Some(""));
assert!(p.api_key.is_some());
assert_eq!(p.api_key.as_deref(), Some(""));
assert!(p.credential.is_none());
}
#[test]
fn creates_with_whitespace_key() {
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
}
#[test]
fn creates_with_custom_base_url() {
let p = AnthropicProvider::with_base_url(
Some("anthropic-credential"),
Some("https://api.example.com"),
);
assert_eq!(p.base_url, "https://api.example.com");
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
}
#[test]
fn custom_base_url_trims_trailing_slash() {
let p = AnthropicProvider::with_base_url(None, Some("https://api.example.com/"));
assert_eq!(p.base_url, "https://api.example.com");
}
#[test]
fn default_base_url_when_none_provided() {
let p = AnthropicProvider::with_base_url(None, None);
assert_eq!(p.base_url, "https://api.anthropic.com");
}
#[tokio::test]
@ -130,11 +467,67 @@ mod tests {
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("API key not set"),
err.contains("credentials not set"),
"Expected key error, got: {err}"
);
}
#[test]
fn setup_token_detection_works() {
assert!(AnthropicProvider::is_setup_token("sk-ant-oat01-abcdef"));
assert!(!AnthropicProvider::is_setup_token("sk-ant-api-key"));
}
#[test]
fn apply_auth_uses_bearer_and_beta_for_setup_tokens() {
let provider = AnthropicProvider::new(None);
let request = provider
.apply_auth(
provider.client.get("https://api.anthropic.com/v1/models"),
"sk-ant-oat01-test-token",
)
.build()
.expect("request should build");
assert_eq!(
request
.headers()
.get("authorization")
.and_then(|v| v.to_str().ok()),
Some("Bearer sk-ant-oat01-test-token")
);
assert_eq!(
request
.headers()
.get("anthropic-beta")
.and_then(|v| v.to_str().ok()),
Some("oauth-2025-04-20")
);
assert!(request.headers().get("x-api-key").is_none());
}
#[test]
fn apply_auth_uses_x_api_key_for_regular_tokens() {
let provider = AnthropicProvider::new(None);
let request = provider
.apply_auth(
provider.client.get("https://api.anthropic.com/v1/models"),
"sk-ant-api-key",
)
.build()
.expect("request should build");
assert_eq!(
request
.headers()
.get("x-api-key")
.and_then(|v| v.to_str().ok()),
Some("sk-ant-api-key")
);
assert!(request.headers().get("authorization").is_none());
assert!(request.headers().get("anthropic-beta").is_none());
}
#[tokio::test]
async fn chat_with_system_fails_without_key() {
let p = AnthropicProvider::new(None);
@ -186,7 +579,8 @@ mod tests {
let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.content.len(), 1);
assert_eq!(resp.content[0].text, "Hello there!");
assert_eq!(resp.content[0].kind, "text");
assert_eq!(resp.content[0].text.as_deref(), Some("Hello there!"));
}
#[test]
@ -202,8 +596,8 @@ mod tests {
r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.content.len(), 2);
assert_eq!(resp.content[0].text, "First");
assert_eq!(resp.content[1].text, "Second");
assert_eq!(resp.content[0].text.as_deref(), Some("First"));
assert_eq!(resp.content[1].text.as_deref(), Some("Second"));
}
#[test]

File diff suppressed because it is too large Load diff

705
src/providers/copilot.rs Normal file
View file

@ -0,0 +1,705 @@
//! GitHub Copilot provider with OAuth device-flow authentication.
//!
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
//! then exchanges the OAuth token for short-lived Copilot API keys.
//! Tokens are cached to disk and auto-refreshed.
//!
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
//! and other third-party Copilot integrations. The Copilot token endpoint is
//! private; there is no public OAuth scope or app registration for it.
//! GitHub could change or revoke this at any time, which would break all
//! third-party integrations simultaneously.
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::warn;
/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const DEFAULT_API: &str = "https://api.githubcopilot.com";
// ── Token types ──────────────────────────────────────────────────
/// GitHub's device-code initiation response.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    device_code: String,
    /// Short code the user types at the verification URL.
    user_code: String,
    verification_uri: String,
    /// Minimum seconds between polls; defaulted when GitHub omits it.
    #[serde(default = "default_interval")]
    interval: u64,
    /// Device-code lifetime in seconds; defaulted when GitHub omits it.
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}
/// Serde fallback: poll interval (seconds) when GitHub omits `interval`.
fn default_interval() -> u64 {
    5
}

/// Serde fallback: device-code lifetime (seconds) when GitHub omits
/// `expires_in`.
fn default_expires_in() -> u64 {
    900
}
/// GitHub's OAuth token-poll response; exactly one field is populated.
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    access_token: Option<String>,
    error: Option<String>,
}

/// Copilot API key payload returned by the token-exchange endpoint and
/// cached on disk as `api-key.json`.
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
    token: String,
    /// Unix timestamp (seconds) after which the key is invalid.
    expires_at: i64,
    #[serde(default)]
    endpoints: Option<ApiEndpoints>,
}

/// Optional endpoint overrides delivered with the API key.
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
    api: Option<String>,
}

/// In-memory copy of the current API key, guarded by `refresh_lock`.
struct CachedApiKey {
    token: String,
    api_endpoint: String,
    expires_at: i64,
}
// ── Chat completions types ───────────────────────────────────────
/// OpenAI-compatible chat completions request body.
#[derive(Debug, Serialize)]
struct ApiChatRequest {
    model: String,
    messages: Vec<ApiMessage>,
    temperature: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}

/// One chat turn; the optional fields carry assistant tool calls and
/// tool-result correlation ids.
#[derive(Debug, Serialize)]
struct ApiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}

/// Tool definition wrapper: `{"type":"function","function":{…}}`.
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}

#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    parameters: serde_json::Value,
}

/// Tool call as serialized in requests and returned in responses.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}

#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    /// JSON-encoded argument object, kept as a raw string.
    arguments: String,
}

#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
}

#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}

/// Assistant message from the API: free text and/or tool calls.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
// ── Provider ─────────────────────────────────────────────────────
/// GitHub Copilot provider with automatic OAuth and token refresh.
///
/// On first use, prompts the user to visit github.com/login/device.
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
/// automatically.
pub struct CopilotProvider {
    /// Optional long-lived GitHub OAuth token from config; when absent,
    /// a disk-cached token or the device-code flow is used instead.
    github_token: Option<String>,
    /// Mutex ensures only one caller refreshes tokens at a time,
    /// preventing duplicate device flow prompts or redundant API calls.
    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
    http: Client,
    /// Directory holding the `access-token` and `api-key.json` caches.
    token_dir: PathBuf,
}
impl CopilotProvider {
/// Build a provider, resolving the on-disk token cache directory and
/// tightening its permissions where the platform allows.
pub fn new(github_token: Option<&str>) -> Self {
    let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
        .map(|dir| dir.config_dir().join("copilot"))
        .unwrap_or_else(|| {
            // Fall back to a user-specific temp directory to avoid
            // shared-directory symlink attacks.
            let user = std::env::var("USER")
                .or_else(|_| std::env::var("USERNAME"))
                .unwrap_or_else(|_| "unknown".to_string());
            std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
        });
    if let Err(err) = std::fs::create_dir_all(&token_dir) {
        warn!(
            "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
            token_dir
        );
    } else {
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // 0700: cached tokens are secrets — owner-only access.
            if let Err(err) =
                std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
            {
                warn!(
                    "Failed to set Copilot token directory permissions on {:?}: {err}",
                    token_dir
                );
            }
        }
    }
    Self {
        // An empty config value is treated the same as no token.
        github_token: github_token
            .filter(|token| !token.is_empty())
            .map(String::from),
        refresh_lock: Arc::new(Mutex::new(None)),
        http: Client::builder()
            .timeout(Duration::from_secs(120))
            .connect_timeout(Duration::from_secs(10))
            .build()
            .unwrap_or_else(|_| Client::new()),
        token_dir,
    }
}

/// Required headers for Copilot API requests (editor identification).
const COPILOT_HEADERS: [(&str, &str); 4] = [
    ("Editor-Version", "vscode/1.85.1"),
    ("Editor-Plugin-Version", "copilot/1.155.0"),
    ("User-Agent", "GithubCopilot/1.155.0"),
    ("Accept", "application/json"),
];
/// Convert provider-agnostic tool specs into OpenAI-style function tools.
///
/// Returns `None` for a missing *or empty* list so that `tools` — and the
/// dependent `tool_choice: "auto"` set in `send_chat_request` — are
/// omitted from the request. The original mapped `Some(&[])` to
/// `Some(vec![])`, which serializes an empty `tools` array that some
/// OpenAI-compatible endpoints reject, and which was inconsistent with
/// `AnthropicProvider::convert_tools`.
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
    let items = tools?;
    if items.is_empty() {
        return None;
    }
    Some(
        items
            .iter()
            .map(|tool| NativeToolSpec {
                kind: "function".to_string(),
                function: NativeToolFunctionSpec {
                    name: tool.name.clone(),
                    description: tool.description.clone(),
                    parameters: tool.parameters.clone(),
                },
            })
            .collect(),
    )
}
/// Map provider-agnostic history into OpenAI-style messages, re-hydrating
/// assistant turns stored as JSON `tool_calls` payloads and tool results
/// stored as JSON `{tool_call_id, content}` payloads.
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
    messages
        .iter()
        .map(|message| {
            if message.role == "assistant" {
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                    if let Some(tool_calls_value) = value.get("tool_calls") {
                        if let Ok(parsed_calls) =
                            serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
                        {
                            let tool_calls = parsed_calls
                                .into_iter()
                                .map(|tool_call| NativeToolCall {
                                    id: Some(tool_call.id),
                                    kind: Some("function".to_string()),
                                    function: NativeFunctionCall {
                                        name: tool_call.name,
                                        arguments: tool_call.arguments,
                                    },
                                })
                                .collect::<Vec<_>>();
                            let content = value
                                .get("content")
                                .and_then(serde_json::Value::as_str)
                                .map(ToString::to_string);
                            return ApiMessage {
                                role: "assistant".to_string(),
                                content,
                                tool_call_id: None,
                                tool_calls: Some(tool_calls),
                            };
                        }
                    }
                }
            }
            if message.role == "tool" {
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                    let tool_call_id = value
                        .get("tool_call_id")
                        .and_then(serde_json::Value::as_str)
                        .map(ToString::to_string);
                    let content = value
                        .get("content")
                        .and_then(serde_json::Value::as_str)
                        .map(ToString::to_string);
                    return ApiMessage {
                        role: "tool".to_string(),
                        content,
                        tool_call_id,
                        tool_calls: None,
                    };
                }
            }
            // Plain text message (system/user, or unparsable assistant/tool).
            ApiMessage {
                role: message.role.clone(),
                content: Some(message.content.clone()),
                tool_call_id: None,
                tool_calls: None,
            }
        })
        .collect()
}

/// Send a chat completions request with required Copilot headers.
async fn send_chat_request(
    &self,
    messages: Vec<ApiMessage>,
    tools: Option<&[ToolSpec]>,
    model: &str,
    temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
    let (token, endpoint) = self.get_api_key().await?;
    let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
    let native_tools = Self::convert_tools(tools);
    let request = ApiChatRequest {
        model: model.to_string(),
        messages,
        temperature,
        // Only request automatic tool selection when tools are present.
        tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
        tools: native_tools,
    };
    let mut req = self
        .http
        .post(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&request);
    for (header, value) in &Self::COPILOT_HEADERS {
        req = req.header(*header, *value);
    }
    let response = req.send().await?;
    if !response.status().is_success() {
        return Err(super::api_error("GitHub Copilot", response).await);
    }
    let api_response: ApiChatResponse = response.json().await?;
    let choice = api_response
        .choices
        .into_iter()
        .next()
        .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
    let tool_calls = choice
        .message
        .tool_calls
        .unwrap_or_default()
        .into_iter()
        .map(|tool_call| ProviderToolCall {
            // The API may omit call ids; synthesize one for correlation.
            id: tool_call
                .id
                .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
            name: tool_call.function.name,
            arguments: tool_call.function.arguments,
        })
        .collect();
    Ok(ProviderChatResponse {
        text: choice.message.content,
        tool_calls,
    })
}

/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
/// Uses a Mutex to ensure only one caller refreshes at a time.
/// Returns `(token, api_endpoint)`.
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
    let mut cached = self.refresh_lock.lock().await;
    // Fast path: in-memory key still valid (120 s safety margin).
    if let Some(cached_key) = cached.as_ref() {
        if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
            return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
        }
    }
    // Second chance: key cached on disk by a previous run.
    if let Some(info) = self.load_api_key_from_disk().await {
        if chrono::Utc::now().timestamp() + 120 < info.expires_at {
            let endpoint = info
                .endpoints
                .as_ref()
                .and_then(|e| e.api.clone())
                .unwrap_or_else(|| DEFAULT_API.to_string());
            let token = info.token;
            *cached = Some(CachedApiKey {
                token: token.clone(),
                api_endpoint: endpoint.clone(),
                expires_at: info.expires_at,
            });
            return Ok((token, endpoint));
        }
    }
    // Slow path: exchange a GitHub access token for a fresh API key.
    let access_token = self.get_github_access_token().await?;
    let api_key_info = self.exchange_for_api_key(&access_token).await?;
    self.save_api_key_to_disk(&api_key_info).await;
    let endpoint = api_key_info
        .endpoints
        .as_ref()
        .and_then(|e| e.api.clone())
        .unwrap_or_else(|| DEFAULT_API.to_string());
    *cached = Some(CachedApiKey {
        token: api_key_info.token.clone(),
        api_endpoint: endpoint.clone(),
        expires_at: api_key_info.expires_at,
    });
    Ok((api_key_info.token, endpoint))
}

/// Get a GitHub access token from config, cache, or device flow.
async fn get_github_access_token(&self) -> anyhow::Result<String> {
    if let Some(token) = &self.github_token {
        return Ok(token.clone());
    }
    let access_token_path = self.token_dir.join("access-token");
    if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
        let token = cached.trim();
        if !token.is_empty() {
            return Ok(token.to_string());
        }
    }
    // Interactive fallback: run the device-code flow and cache the result.
    let token = self.device_code_login().await?;
    write_file_secure(&access_token_path, &token).await;
    Ok(token)
}

/// Run GitHub OAuth device code flow.
/// Prints the verification URL and user code to stderr, then polls until
/// the user authorizes, the code expires, or GitHub reports an error.
async fn device_code_login(&self) -> anyhow::Result<String> {
    let response: DeviceCodeResponse = self
        .http
        .post(GITHUB_DEVICE_CODE_URL)
        .header("Accept", "application/json")
        .json(&serde_json::json!({
            "client_id": GITHUB_CLIENT_ID,
            "scope": "read:user"
        }))
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;
    // Respect GitHub's minimum 5 s poll interval.
    let mut poll_interval = Duration::from_secs(response.interval.max(5));
    let expires_in = response.expires_in.max(1);
    let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
    eprintln!(
        "\nGitHub Copilot authentication is required.\n\
         Visit: {}\n\
         Code: {}\n\
         Waiting for authorization...\n",
        response.verification_uri, response.user_code
    );
    while tokio::time::Instant::now() < expires_at {
        tokio::time::sleep(poll_interval).await;
        let token_response: AccessTokenResponse = self
            .http
            .post(GITHUB_ACCESS_TOKEN_URL)
            .header("Accept", "application/json")
            .json(&serde_json::json!({
                "client_id": GITHUB_CLIENT_ID,
                "device_code": response.device_code,
                "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
            }))
            .send()
            .await?
            .json()
            .await?;
        if let Some(token) = token_response.access_token {
            eprintln!("Authentication succeeded.\n");
            return Ok(token);
        }
        match token_response.error.as_deref() {
            // Back off by 5 s when GitHub asks us to slow down.
            Some("slow_down") => {
                poll_interval += Duration::from_secs(5);
            }
            Some("authorization_pending") | None => {}
            Some("expired_token") => {
                anyhow::bail!("GitHub device authorization expired")
            }
            Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
        }
    }
    anyhow::bail!("Timed out waiting for GitHub authorization")
}

/// Exchange a GitHub access token for a Copilot API key.
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
    let mut request = self.http.get(GITHUB_API_KEY_URL);
    for (header, value) in &Self::COPILOT_HEADERS {
        request = request.header(*header, *value);
    }
    request = request.header("Authorization", format!("token {access_token}"));
    let response = request.send().await?;
    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        let sanitized = super::sanitize_api_error(&body);
        // A rejected token is stale — drop the disk cache so the next
        // attempt re-runs the device flow instead of failing forever.
        if status.as_u16() == 401 || status.as_u16() == 403 {
            let access_token_path = self.token_dir.join("access-token");
            tokio::fs::remove_file(&access_token_path).await.ok();
        }
        anyhow::bail!(
            "Failed to get Copilot API key ({status}): {sanitized}. \
             Ensure your GitHub account has an active Copilot subscription."
        );
    }
    let info: ApiKeyInfo = response.json().await?;
    Ok(info)
}

/// Load a previously cached API key, if any (best-effort).
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
    let path = self.token_dir.join("api-key.json");
    let data = tokio::fs::read_to_string(&path).await.ok()?;
    serde_json::from_str(&data).ok()
}

/// Persist the API key to disk with restrictive permissions (best-effort).
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
    let path = self.token_dir.join("api-key.json");
    if let Ok(json) = serde_json::to_string_pretty(info) {
        write_file_secure(&path, &json).await;
    }
}
}
/// Write a file with 0600 permissions (owner read/write only).
/// Uses `spawn_blocking` to avoid blocking the async runtime.
/// Failures are logged, not returned — callers treat caching as best-effort.
async fn write_file_secure(path: &Path, content: &str) {
    let path = path.to_path_buf();
    let content = content.to_string();
    let result = tokio::task::spawn_blocking(move || {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
            // `.mode(0o600)` only applies when the file is created; the
            // explicit set_permissions below also fixes up an existing file.
            let mut file = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&path)?;
            file.write_all(content.as_bytes())?;
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
            Ok::<(), std::io::Error>(())
        }
        #[cfg(not(unix))]
        {
            // NOTE(review): no permission tightening on non-unix targets —
            // relies on the per-user cache directory for isolation.
            std::fs::write(&path, &content)?;
            Ok::<(), std::io::Error>(())
        }
    })
    .await;
    match result {
        Ok(Ok(())) => {}
        Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
        Err(err) => warn!("Failed to spawn blocking write: {err}"),
    }
}
#[async_trait]
impl Provider for CopilotProvider {
    /// Single-turn chat with an optional system prompt; returns plain text.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let mut messages = Vec::new();
        if let Some(system) = system_prompt {
            messages.push(ApiMessage {
                role: "system".to_string(),
                content: Some(system.to_string()),
                tool_call_id: None,
                tool_calls: None,
            });
        }
        messages.push(ApiMessage {
            role: "user".to_string(),
            content: Some(message.to_string()),
            tool_call_id: None,
            tool_calls: None,
        });
        let response = self
            .send_chat_request(messages, None, model, temperature)
            .await?;
        Ok(response.text.unwrap_or_default())
    }

    /// Multi-turn chat without tools; returns plain text.
    async fn chat_with_history(
        &self,
        messages: &[ChatMessage],
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let response = self
            .send_chat_request(Self::convert_messages(messages), None, model, temperature)
            .await?;
        Ok(response.text.unwrap_or_default())
    }

    /// Full chat entry point with native tool-calling support.
    async fn chat(
        &self,
        request: ProviderChatRequest<'_>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        self.send_chat_request(
            Self::convert_messages(request.messages),
            request.tools,
            model,
            temperature,
        )
        .await
    }

    fn supports_native_tools(&self) -> bool {
        true
    }

    /// Pre-fetch the API key so the first chat call avoids auth latency.
    async fn warmup(&self) -> anyhow::Result<()> {
        let _ = self.get_api_key().await?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_without_token() {
        let provider = CopilotProvider::new(None);
        assert!(provider.github_token.is_none());
    }

    #[test]
    fn new_with_token() {
        let provider = CopilotProvider::new(Some("ghp_test"));
        assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
    }

    // Empty string in config must behave like an absent token.
    #[test]
    fn empty_token_treated_as_none() {
        let provider = CopilotProvider::new(Some(""));
        assert!(provider.github_token.is_none());
    }

    #[tokio::test]
    async fn cache_starts_empty() {
        let provider = CopilotProvider::new(None);
        let cached = provider.refresh_lock.lock().await;
        assert!(cached.is_none());
    }

    #[test]
    fn copilot_headers_include_required_fields() {
        let headers = CopilotProvider::COPILOT_HEADERS;
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Version"));
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Plugin-Version"));
        assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
    }

    #[test]
    fn default_interval_and_expiry() {
        assert_eq!(default_interval(), 5);
        assert_eq!(default_expires_in(), 900);
    }

    #[test]
    fn supports_native_tools() {
        let provider = CopilotProvider::new(None);
        assert!(provider.supports_native_tools());
    }
}

560
src/providers/gemini.rs Normal file
View file

@ -0,0 +1,560 @@
//! Google Gemini provider with support for:
//! - Direct API key (`GEMINI_API_KEY` env var or config)
//! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication)
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
use crate::providers::traits::Provider;
use async_trait::async_trait;
use directories::UserDirs;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
/// Gemini provider supporting multiple authentication methods.
pub struct GeminiProvider {
    /// Resolved credential; `None` means no auth source was found.
    auth: Option<GeminiAuth>,
    client: Client,
}
/// Resolved credential — the variant determines both the HTTP auth method
/// and the diagnostic label returned by `auth_source()`.
#[derive(Debug)]
enum GeminiAuth {
    // Variant order mirrors the resolution priority in `GeminiProvider::new`.
    /// Explicit API key from config: sent as `?key=` query parameter.
    ExplicitKey(String),
    /// API key from `GEMINI_API_KEY` env var: sent as `?key=`.
    EnvGeminiKey(String),
    /// API key from `GOOGLE_API_KEY` env var: sent as `?key=`.
    EnvGoogleKey(String),
    /// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`.
    OAuthToken(String),
}
impl GeminiAuth {
    /// True when the credential is an API key, which the Gemini API expects
    /// as a `?key=` query parameter rather than an `Authorization` header.
    fn is_api_key(&self) -> bool {
        // Every variant except the OAuth token is an API key.
        !matches!(self, GeminiAuth::OAuthToken(_))
    }

    /// Borrow the underlying secret regardless of variant.
    fn credential(&self) -> &str {
        match self {
            GeminiAuth::ExplicitKey(inner)
            | GeminiAuth::EnvGeminiKey(inner)
            | GeminiAuth::EnvGoogleKey(inner)
            | GeminiAuth::OAuthToken(inner) => inner,
        }
    }
}
// ══════════════════════════════════════════════════════════════════════════════
// API REQUEST/RESPONSE TYPES
// ══════════════════════════════════════════════════════════════════════════════
/// Body for `models/{model}:generateContent`.
#[derive(Debug, Serialize)]
struct GenerateContentRequest {
    /// Conversation turns; this provider sends a single user turn.
    contents: Vec<Content>,
    /// Optional system prompt; omitted from JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    #[serde(rename = "generationConfig")]
    generation_config: GenerationConfig,
}

#[derive(Debug, Serialize)]
struct Content {
    /// "user" for turns; `None` for the system instruction.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
    parts: Vec<Part>,
}

#[derive(Debug, Serialize)]
struct Part {
    text: String,
}

#[derive(Debug, Serialize)]
struct GenerationConfig {
    temperature: f64,
    #[serde(rename = "maxOutputTokens")]
    max_output_tokens: u32,
}

/// Response envelope: either `candidates` or an `error` object.
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
    candidates: Option<Vec<Candidate>>,
    error: Option<ApiError>,
}

#[derive(Debug, Deserialize)]
struct Candidate {
    content: CandidateContent,
}

#[derive(Debug, Deserialize)]
struct CandidateContent {
    parts: Vec<ResponsePart>,
}

#[derive(Debug, Deserialize)]
struct ResponsePart {
    /// Text payload; parts may carry no text.
    text: Option<String>,
}

#[derive(Debug, Deserialize)]
struct ApiError {
    message: String,
}

// ══════════════════════════════════════════════════════════════════════════════
// GEMINI CLI TOKEN STRUCTURES
// ══════════════════════════════════════════════════════════════════════════════

/// OAuth token stored by Gemini CLI in `~/.gemini/oauth_creds.json`
#[derive(Debug, Deserialize)]
struct GeminiCliOAuthCreds {
    access_token: Option<String>,
    /// RFC 3339 expiry timestamp, when present.
    expiry: Option<String>,
}
impl GeminiProvider {
    /// Create a new Gemini provider.
    ///
    /// Authentication priority:
    /// 1. Explicit API key passed in
    /// 2. `GEMINI_API_KEY` environment variable
    /// 3. `GOOGLE_API_KEY` environment variable
    /// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
    pub fn new(api_key: Option<&str>) -> Self {
        let auth = api_key
            .and_then(Self::normalize_non_empty)
            .map(GeminiAuth::ExplicitKey)
            .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
            .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
            .or_else(|| Self::try_load_gemini_cli_token().map(GeminiAuth::OAuthToken));

        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(120))
            .connect_timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap_or_else(|_| Client::new());

        Self { auth, client }
    }

    /// Trim `value`; return `None` when nothing remains.
    fn normalize_non_empty(value: &str) -> Option<String> {
        let trimmed = value.trim();
        (!trimmed.is_empty()).then(|| trimmed.to_string())
    }

    /// Read env var `name`, treating empty/whitespace-only values as unset.
    fn load_non_empty_env(name: &str) -> Option<String> {
        std::env::var(name)
            .ok()
            .as_deref()
            .and_then(Self::normalize_non_empty)
    }

    /// Try to load OAuth access token from Gemini CLI's cached credentials.
    /// Location: `~/.gemini/oauth_creds.json`
    fn try_load_gemini_cli_token() -> Option<String> {
        let creds_path = Self::gemini_cli_dir()?.join("oauth_creds.json");
        if !creds_path.exists() {
            return None;
        }
        let raw = std::fs::read_to_string(&creds_path).ok()?;
        let creds: GeminiCliOAuthCreds = serde_json::from_str(&raw).ok()?;

        // Reject tokens whose recorded expiry is in the past; missing or
        // unparseable expiries fall through and the token is used as-is.
        if let Some(expiry) = creds.expiry.as_deref() {
            if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) {
                if expiry_time < chrono::Utc::now() {
                    tracing::warn!("Gemini CLI OAuth token expired — re-run `gemini` to refresh");
                    return None;
                }
            }
        }

        creds
            .access_token
            .as_deref()
            .and_then(Self::normalize_non_empty)
    }

    /// Get the Gemini CLI config directory (~/.gemini)
    fn gemini_cli_dir() -> Option<PathBuf> {
        Some(UserDirs::new()?.home_dir().join(".gemini"))
    }

    /// Check if Gemini CLI is configured and has valid credentials
    pub fn has_cli_credentials() -> bool {
        Self::try_load_gemini_cli_token().is_some()
    }

    /// Check if any Gemini authentication is available
    pub fn has_any_auth() -> bool {
        Self::load_non_empty_env("GEMINI_API_KEY").is_some()
            || Self::load_non_empty_env("GOOGLE_API_KEY").is_some()
            || Self::has_cli_credentials()
    }

    /// Get authentication source description for diagnostics.
    /// Uses the stored enum variant — no env var re-reading at call time.
    pub fn auth_source(&self) -> &'static str {
        match self.auth {
            Some(GeminiAuth::ExplicitKey(_)) => "config",
            Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var",
            Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var",
            Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
            None => "none",
        }
    }

    /// Ensure the model name carries the `models/` prefix the API expects.
    fn format_model_name(model: &str) -> String {
        match model.strip_prefix("models/") {
            Some(_) => model.to_string(),
            None => format!("models/{model}"),
        }
    }

    /// Build the `generateContent` URL, appending `?key=` for API-key
    /// credentials (OAuth tokens go in a header instead).
    fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
        let endpoint = format!(
            "https://generativelanguage.googleapis.com/v1beta/{}:generateContent",
            Self::format_model_name(model)
        );
        if auth.is_api_key() {
            format!("{endpoint}?key={}", auth.credential())
        } else {
            endpoint
        }
    }

    /// Prepare the POST request, attaching bearer auth for OAuth tokens only.
    fn build_generate_content_request(
        &self,
        auth: &GeminiAuth,
        url: &str,
        request: &GenerateContentRequest,
    ) -> reqwest::RequestBuilder {
        let builder = self.client.post(url).json(request);
        if let GeminiAuth::OAuthToken(token) = auth {
            builder.bearer_auth(token)
        } else {
            builder
        }
    }
}
#[async_trait]
impl Provider for GeminiProvider {
    /// Single-turn chat against the Gemini `generateContent` endpoint.
    ///
    /// `system_prompt` is sent as the `system_instruction`; `message` as a
    /// single user turn. Returns the concatenated text of the first
    /// candidate's parts.
    ///
    /// # Errors
    /// Fails when no credential is configured, on transport errors or
    /// non-success HTTP statuses, when the API reports an error in the
    /// response body, or when the response contains no text at all.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let auth = self.auth.as_ref().ok_or_else(|| {
            anyhow::anyhow!(
                "Gemini API key not found. Options:\n\
                 1. Set GEMINI_API_KEY env var\n\
                 2. Run `gemini` CLI to authenticate (tokens will be reused)\n\
                 3. Get an API key from https://aistudio.google.com/app/apikey\n\
                 4. Run `zeroclaw onboard` to configure"
            )
        })?;

        // Build request
        let system_instruction = system_prompt.map(|sys| Content {
            role: None,
            parts: vec![Part {
                text: sys.to_string(),
            }],
        });
        let request = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: message.to_string(),
                }],
            }],
            system_instruction,
            generation_config: GenerationConfig {
                temperature,
                max_output_tokens: 8192,
            },
        };

        let url = Self::build_generate_content_url(model, auth);
        let response = self
            .build_generate_content_request(auth, &url, &request)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();
            anyhow::bail!("Gemini API error ({status}): {error_text}");
        }

        let result: GenerateContentResponse = response.json().await?;

        // Check for API error in response body
        if let Some(err) = result.error {
            anyhow::bail!("Gemini API error: {}", err.message);
        }

        // Gemini may split a reply across several parts, and the first part is
        // not guaranteed to carry text. Join every text part instead of
        // returning only the first one (which silently truncated replies).
        let text: String = result
            .candidates
            .and_then(|candidates| candidates.into_iter().next())
            .map(|candidate| {
                candidate
                    .content
                    .parts
                    .into_iter()
                    .filter_map(|part| part.text)
                    .collect()
            })
            .unwrap_or_default();

        if text.is_empty() {
            anyhow::bail!("No response from Gemini");
        }
        Ok(text)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use reqwest::header::AUTHORIZATION;

    // ── Credential normalization ──────────────────────────────────────────

    #[test]
    fn normalize_non_empty_trims_and_filters() {
        assert_eq!(
            GeminiProvider::normalize_non_empty(" value "),
            Some("value".into())
        );
        assert_eq!(GeminiProvider::normalize_non_empty(""), None);
        assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None);
    }

    // ── Construction ──────────────────────────────────────────────────────

    #[test]
    fn provider_creates_without_key() {
        let provider = GeminiProvider::new(None);
        // May pick up env vars; just verify it doesn't panic
        let _ = provider.auth_source();
    }

    #[test]
    fn provider_creates_with_key() {
        let provider = GeminiProvider::new(Some("test-api-key"));
        assert!(matches!(
            provider.auth,
            Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
        ));
    }

    #[test]
    fn provider_rejects_empty_key() {
        // Empty explicit key must not become an ExplicitKey credential.
        let provider = GeminiProvider::new(Some(""));
        assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_))));
    }

    #[test]
    fn gemini_cli_dir_returns_path() {
        let dir = GeminiProvider::gemini_cli_dir();
        // Should return Some on systems with home dir
        if UserDirs::new().is_some() {
            assert!(dir.is_some());
            assert!(dir.unwrap().ends_with(".gemini"));
        }
    }

    // ── auth_source diagnostics ───────────────────────────────────────────

    #[test]
    fn auth_source_explicit_key() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::ExplicitKey("key".into())),
            client: Client::new(),
        };
        assert_eq!(provider.auth_source(), "config");
    }

    #[test]
    fn auth_source_none_without_credentials() {
        let provider = GeminiProvider {
            auth: None,
            client: Client::new(),
        };
        assert_eq!(provider.auth_source(), "none");
    }

    #[test]
    fn auth_source_oauth() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::OAuthToken("ya29.mock".into())),
            client: Client::new(),
        };
        assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
    }

    // ── URL construction ──────────────────────────────────────────────────

    #[test]
    fn model_name_formatting() {
        assert_eq!(
            GeminiProvider::format_model_name("gemini-2.0-flash"),
            "models/gemini-2.0-flash"
        );
        assert_eq!(
            GeminiProvider::format_model_name("models/gemini-1.5-pro"),
            "models/gemini-1.5-pro"
        );
    }

    #[test]
    fn api_key_url_includes_key_query_param() {
        let auth = GeminiAuth::ExplicitKey("api-key-123".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        assert!(url.contains(":generateContent?key=api-key-123"));
    }

    #[test]
    fn oauth_url_omits_key_query_param() {
        // OAuth credentials go in a header, never in the URL.
        let auth = GeminiAuth::OAuthToken("ya29.test-token".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        assert!(url.ends_with(":generateContent"));
        assert!(!url.contains("?key="));
    }

    // ── Request building / auth headers ───────────────────────────────────

    #[test]
    fn oauth_request_uses_bearer_auth_header() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())),
            client: Client::new(),
        };
        let auth = GeminiAuth::OAuthToken("ya29.mock-token".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        let body = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".into()),
                parts: vec![Part {
                    text: "hello".into(),
                }],
            }],
            system_instruction: None,
            generation_config: GenerationConfig {
                temperature: 0.7,
                max_output_tokens: 8192,
            },
        };
        let request = provider
            .build_generate_content_request(&auth, &url, &body)
            .build()
            .unwrap();
        assert_eq!(
            request
                .headers()
                .get(AUTHORIZATION)
                .and_then(|h| h.to_str().ok()),
            Some("Bearer ya29.mock-token")
        );
    }

    #[test]
    fn api_key_request_does_not_set_bearer_header() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())),
            client: Client::new(),
        };
        let auth = GeminiAuth::ExplicitKey("api-key-123".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        let body = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".into()),
                parts: vec![Part {
                    text: "hello".into(),
                }],
            }],
            system_instruction: None,
            generation_config: GenerationConfig {
                temperature: 0.7,
                max_output_tokens: 8192,
            },
        };
        let request = provider
            .build_generate_content_request(&auth, &url, &body)
            .build()
            .unwrap();
        assert!(request.headers().get(AUTHORIZATION).is_none());
    }

    // ── Serde round-trips ─────────────────────────────────────────────────

    #[test]
    fn request_serialization() {
        let request = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: "Hello".to_string(),
                }],
            }],
            system_instruction: Some(Content {
                role: None,
                parts: vec![Part {
                    text: "You are helpful".to_string(),
                }],
            }),
            generation_config: GenerationConfig {
                temperature: 0.7,
                max_output_tokens: 8192,
            },
        };
        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"text\":\"Hello\""));
        assert!(json.contains("\"temperature\":0.7"));
        assert!(json.contains("\"maxOutputTokens\":8192"));
    }

    #[test]
    fn response_deserialization() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello there!"}]
                }
            }]
        }"#;
        let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
        assert!(response.candidates.is_some());
        let text = response
            .candidates
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
            .content
            .parts
            .into_iter()
            .next()
            .unwrap()
            .text;
        assert_eq!(text, Some("Hello there!".to_string()));
    }

    #[test]
    fn error_response_deserialization() {
        let json = r#"{
            "error": {
                "message": "Invalid API key"
            }
        }"#;
        let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().message, "Invalid API key");
    }
}

File diff suppressed because it is too large Load diff

View file

@ -5,9 +5,12 @@ use serde::{Deserialize, Serialize};
pub struct OllamaProvider {
base_url: String,
api_key: Option<String>,
client: Client,
}
// ─── Request Structures ───────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
@ -27,30 +30,231 @@ struct Options {
temperature: f64,
}
// ─── Response Structures ──────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
struct ChatResponse {
struct ApiChatResponse {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    /// Assistant text; may be absent/empty, e.g. for tool-call-only replies.
    #[serde(default)]
    content: String,
    /// Native tool calls; empty when the model answered with plain text.
    #[serde(default)]
    tool_calls: Vec<OllamaToolCall>,
    /// Some models return a "thinking" field with internal reasoning
    #[serde(default)]
    thinking: Option<String>,
}

#[derive(Debug, Deserialize)]
struct OllamaToolCall {
    /// Call id as reported by the model; not all models provide one.
    id: Option<String>,
    function: OllamaFunction,
}

#[derive(Debug, Deserialize)]
struct OllamaFunction {
    name: String,
    /// Arbitrary JSON arguments; defaults to JSON `null` when missing.
    #[serde(default)]
    arguments: serde_json::Value,
}
// ─── Implementation ───────────────────────────────────────────────────────────
impl OllamaProvider {
pub fn new(base_url: Option<&str>) -> Self {
pub fn new(base_url: Option<&str>, api_key: Option<&str>) -> Self {
let api_key = api_key.and_then(|value| {
let trimmed = value.trim();
(!trimmed.is_empty()).then(|| trimmed.to_string())
});
Self {
base_url: base_url
.unwrap_or("http://localhost:11434")
.trim_end_matches('/')
.to_string(),
api_key,
client: Client::builder()
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
.timeout(std::time::Duration::from_secs(300))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| Client::new()),
}
}
fn is_local_endpoint(&self) -> bool {
reqwest::Url::parse(&self.base_url)
.ok()
.and_then(|url| url.host_str().map(|host| host.to_string()))
.is_some_and(|host| matches!(host.as_str(), "localhost" | "127.0.0.1" | "::1"))
}
fn resolve_request_details(&self, model: &str) -> anyhow::Result<(String, bool)> {
let requests_cloud = model.ends_with(":cloud");
let normalized_model = model.strip_suffix(":cloud").unwrap_or(model).to_string();
if requests_cloud && self.is_local_endpoint() {
anyhow::bail!(
"Model '{}' requested cloud routing, but Ollama endpoint is local. Configure api_url with a remote Ollama endpoint.",
model
);
}
if requests_cloud && self.api_key.is_none() {
anyhow::bail!(
"Model '{}' requested cloud routing, but no API key is configured. Set OLLAMA_API_KEY or config api_key.",
model
);
}
let should_auth = self.api_key.is_some() && !self.is_local_endpoint();
Ok((normalized_model, should_auth))
}
    /// Send a request to Ollama and get the parsed response
    ///
    /// POSTs to `{base_url}/api/chat` with streaming disabled; attaches the
    /// configured API key as a bearer token only when `should_auth` is true.
    ///
    /// # Errors
    /// Fails on transport errors, on non-success HTTP statuses (logging a
    /// sanitized body excerpt), or when the body does not deserialize into
    /// `ApiChatResponse`.
    async fn send_request(
        &self,
        messages: Vec<Message>,
        model: &str,
        temperature: f64,
        should_auth: bool,
    ) -> anyhow::Result<ApiChatResponse> {
        let request = ChatRequest {
            model: model.to_string(),
            messages,
            stream: false,
            options: Options { temperature },
        };
        let url = format!("{}/api/chat", self.base_url);
        tracing::debug!(
            "Ollama request: url={} model={} message_count={} temperature={}",
            url,
            model,
            request.messages.len(),
            temperature
        );
        let mut request_builder = self.client.post(&url).json(&request);
        if should_auth {
            if let Some(key) = self.api_key.as_ref() {
                request_builder = request_builder.bearer_auth(key);
            }
        }
        let response = request_builder.send().await?;
        let status = response.status();
        tracing::debug!("Ollama response status: {}", status);
        // Read raw bytes first so both error paths below can log a sanitized
        // excerpt of the same body.
        let body = response.bytes().await?;
        tracing::debug!("Ollama response body length: {} bytes", body.len());
        if !status.is_success() {
            let raw = String::from_utf8_lossy(&body);
            let sanitized = super::sanitize_api_error(&raw);
            tracing::error!(
                "Ollama error response: status={} body_excerpt={}",
                status,
                sanitized
            );
            anyhow::bail!(
                "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
                status,
                sanitized
            );
        }
        let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
            Ok(r) => r,
            Err(e) => {
                let raw = String::from_utf8_lossy(&body);
                let sanitized = super::sanitize_api_error(&raw);
                tracing::error!(
                    "Ollama response deserialization failed: {e}. body_excerpt={}",
                    sanitized
                );
                anyhow::bail!("Failed to parse Ollama response: {e}");
            }
        };
        Ok(chat_response)
    }
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
///
/// Handles quirky model behavior where tool calls are wrapped:
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
/// - `{"name": "tool.shell", "arguments": {...}}`
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
let formatted_calls: Vec<serde_json::Value> = tool_calls
.iter()
.map(|tc| {
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
// Arguments must be a JSON string for parse_tool_calls compatibility
let args_str =
serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
serde_json::json!({
"id": tc.id,
"type": "function",
"function": {
"name": tool_name,
"arguments": args_str
}
})
})
.collect();
serde_json::json!({
"content": "",
"tool_calls": formatted_calls
})
.to_string()
}
/// Extract the actual tool name and arguments from potentially nested structures
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
let name = &tc.function.name;
let args = &tc.function.arguments;
// Pattern 1: Nested tool_call wrapper (various malformed versions)
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
if name == "tool_call"
|| name == "tool.call"
|| name.starts_with("tool_call>")
|| name.starts_with("tool_call<")
{
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
let nested_args = args
.get("arguments")
.cloned()
.unwrap_or(serde_json::json!({}));
tracing::debug!(
"Unwrapped nested tool call: {} -> {} with args {:?}",
name,
nested_name,
nested_args
);
return (nested_name.to_string(), nested_args);
}
}
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
if let Some(stripped) = name.strip_prefix("tool.") {
return (stripped.to_string(), args.clone());
}
// Pattern 3: Normal tool call
(name.clone(), args.clone())
}
}
#[async_trait]
@ -62,6 +266,8 @@ impl Provider for OllamaProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
@ -76,115 +282,281 @@ impl Provider for OllamaProvider {
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
stream: false,
options: Options { temperature },
};
let response = self
.send_request(messages, &normalized_model, temperature, should_auth)
.await?;
let url = format!("{}/api/chat", self.base_url);
let response = self.client.post(&url).json(&request).send().await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!(
"Ollama error: {error}. Is Ollama running? (brew install ollama && ollama serve)"
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
if !response.message.tool_calls.is_empty() {
tracing::debug!(
"Ollama returned {} tool call(s), formatting for loop parser",
response.message.tool_calls.len()
);
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
let chat_response: ChatResponse = response.json().await?;
Ok(chat_response.message.content)
// Plain text response
let content = response.message.content;
// Handle edge case: model returned only "thinking" with no content or tool calls
if content.is_empty() {
if let Some(thinking) = &response.message.thinking {
tracing::warn!(
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
if thinking.len() > 100 { &thinking[..100] } else { thinking }
);
return Ok(format!(
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
if thinking.len() > 200 { &thinking[..200] } else { thinking }
));
}
tracing::warn!("Ollama returned empty content with no tool calls");
}
Ok(content)
}
async fn chat_with_history(
&self,
messages: &[crate::providers::ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let response = self
.send_request(api_messages, &normalized_model, temperature, should_auth)
.await?;
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
if !response.message.tool_calls.is_empty() {
tracing::debug!(
"Ollama returned {} tool call(s), formatting for loop parser",
response.message.tool_calls.len()
);
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
// Plain text response
let content = response.message.content;
// Handle edge case: model returned only "thinking" with no content or tool calls
// This is a model quirk - it stopped after reasoning without producing output
if content.is_empty() {
if let Some(thinking) = &response.message.thinking {
tracing::warn!(
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
if thinking.len() > 100 { &thinking[..100] } else { thinking }
);
// Return a message indicating the model's thought process but no action
return Ok(format!(
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
if thinking.len() > 200 { &thinking[..200] } else { thinking }
));
}
tracing::warn!("Ollama returned empty content with no tool calls");
}
Ok(content)
}
fn supports_native_tools(&self) -> bool {
// Return false since loop_.rs uses XML-style tool parsing via system prompt
// The model may return native tool_calls but we convert them to JSON format
// that parse_tool_calls() understands
false
}
}
// ─── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_url() {
let p = OllamaProvider::new(None);
let p = OllamaProvider::new(None, None);
assert_eq!(p.base_url, "http://localhost:11434");
}
#[test]
fn custom_url_trailing_slash() {
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"));
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"), None);
assert_eq!(p.base_url, "http://192.168.1.100:11434");
}
#[test]
fn custom_url_no_trailing_slash() {
let p = OllamaProvider::new(Some("http://myserver:11434"));
let p = OllamaProvider::new(Some("http://myserver:11434"), None);
assert_eq!(p.base_url, "http://myserver:11434");
}
#[test]
fn empty_url_uses_empty() {
let p = OllamaProvider::new(Some(""));
let p = OllamaProvider::new(Some(""), None);
assert_eq!(p.base_url, "");
}
#[test]
fn request_serializes_with_system() {
let req = ChatRequest {
model: "llama3".to_string(),
messages: vec![
Message {
role: "system".to_string(),
content: "You are ZeroClaw".to_string(),
},
Message {
role: "user".to_string(),
content: "hello".to_string(),
},
],
stream: false,
options: Options { temperature: 0.7 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"stream\":false"));
assert!(json.contains("llama3"));
assert!(json.contains("system"));
assert!(json.contains("\"temperature\":0.7"));
fn cloud_suffix_strips_model_name() {
let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key"));
let (model, should_auth) = p.resolve_request_details("qwen3:cloud").unwrap();
assert_eq!(model, "qwen3");
assert!(should_auth);
}
#[test]
fn request_serializes_without_system() {
let req = ChatRequest {
model: "mistral".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: "test".to_string(),
}],
stream: false,
options: Options { temperature: 0.0 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("\"role\":\"system\""));
assert!(json.contains("mistral"));
fn cloud_suffix_with_local_endpoint_errors() {
let p = OllamaProvider::new(None, Some("ollama-key"));
let error = p
.resolve_request_details("qwen3:cloud")
.expect_err("cloud suffix should fail on local endpoint");
assert!(error
.to_string()
.contains("requested cloud routing, but Ollama endpoint is local"));
}
#[test]
fn cloud_suffix_without_api_key_errors() {
let p = OllamaProvider::new(Some("https://ollama.com"), None);
let error = p
.resolve_request_details("qwen3:cloud")
.expect_err("cloud suffix should require API key");
assert!(error
.to_string()
.contains("requested cloud routing, but no API key is configured"));
}
#[test]
fn remote_endpoint_auth_enabled_when_key_present() {
let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key"));
let (_model, should_auth) = p.resolve_request_details("qwen3").unwrap();
assert!(should_auth);
}
#[test]
fn local_endpoint_auth_disabled_even_with_key() {
let p = OllamaProvider::new(None, Some("ollama-key"));
let (_model, should_auth) = p.resolve_request_details("llama3").unwrap();
assert!(!should_auth);
}
#[test]
fn response_deserializes() {
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.message.content, "Hello from Ollama!");
}
#[test]
fn response_with_empty_content() {
let json = r#"{"message":{"role":"assistant","content":""}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.is_empty());
}
#[test]
fn response_with_multiline() {
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.contains("line1"));
fn response_with_missing_content_defaults_to_empty() {
let json = r#"{"message":{"role":"assistant"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.is_empty());
}
#[test]
fn response_with_thinking_field_extracts_content() {
let json =
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.message.content, "hello");
}
#[test]
fn response_with_tool_calls_parses_correctly() {
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.is_empty());
assert_eq!(resp.message.tool_calls.len(), 1);
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
}
#[test]
fn extract_tool_name_handles_nested_tool_call() {
let provider = OllamaProvider::new(None, None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "tool_call".into(),
arguments: serde_json::json!({
"name": "shell",
"arguments": {"command": "date"}
}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "shell");
assert_eq!(args.get("command").unwrap(), "date");
}
#[test]
fn extract_tool_name_handles_prefixed_name() {
let provider = OllamaProvider::new(None, None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "tool.shell".into(),
arguments: serde_json::json!({"command": "ls"}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "shell");
assert_eq!(args.get("command").unwrap(), "ls");
}
#[test]
fn extract_tool_name_handles_normal_call() {
let provider = OllamaProvider::new(None, None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "file_read".into(),
arguments: serde_json::json!({"path": "/tmp/test"}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "file_read");
assert_eq!(args.get("path").unwrap(), "/tmp/test");
}
#[test]
fn format_tool_calls_produces_valid_json() {
let provider = OllamaProvider::new(None, None);
let tool_calls = vec![OllamaToolCall {
id: Some("call_abc".into()),
function: OllamaFunction {
name: "shell".into(),
arguments: serde_json::json!({"command": "date"}),
},
}];
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
assert!(parsed.get("tool_calls").is_some());
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
assert_eq!(calls.len(), 1);
let func = calls[0].get("function").unwrap();
assert_eq!(func.get("name").unwrap(), "shell");
// arguments should be a string (JSON-encoded)
assert!(func.get("arguments").unwrap().is_string());
}
}

View file

@ -1,10 +1,14 @@
use crate::providers::traits::Provider;
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenAiProvider {
api_key: Option<String>,
credential: Option<String>,
client: Client,
}
@ -36,10 +40,79 @@ struct ResponseMessage {
content: String,
}
/// Chat-completions request body for native (function-calling) mode.
#[derive(Debug, Serialize)]
struct NativeChatRequest {
    model: String,
    messages: Vec<NativeMessage>,
    temperature: f64,
    /// Tool declarations; omitted from JSON when no tools are offered.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}

/// One conversation turn; `tool_call_id`/`tool_calls` are only populated for
/// "tool" and "assistant" roles respectively.
#[derive(Debug, Serialize)]
struct NativeMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}

/// Tool declaration wrapper; `kind` is always "function".
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}

#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    /// JSON Schema describing the tool's parameters.
    parameters: serde_json::Value,
}

/// Tool invocation, used both in requests (assistant history) and responses.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}

#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    /// Arguments as a JSON-encoded string, per the chat-completions schema.
    arguments: String,
}

#[derive(Debug, Deserialize)]
struct NativeChatResponse {
    choices: Vec<NativeChoice>,
}

#[derive(Debug, Deserialize)]
struct NativeChoice {
    message: NativeResponseMessage,
}

/// Assistant reply: plain text, tool calls, or both may be present.
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
impl OpenAiProvider {
pub fn new(api_key: Option<&str>) -> Self {
pub fn new(credential: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -47,6 +120,107 @@ impl OpenAiProvider {
.unwrap_or_else(|_| Client::new()),
}
}
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
tools.map(|items| {
items
.iter()
.map(|tool| NativeToolSpec {
kind: "function".to_string(),
function: NativeToolFunctionSpec {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters.clone(),
},
})
.collect()
})
}
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
messages
.iter()
.map(|m| {
if m.role == "assistant" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
if let Some(tool_calls_value) = value.get("tool_calls") {
if let Ok(parsed_calls) =
serde_json::from_value::<Vec<ProviderToolCall>>(
tool_calls_value.clone(),
)
{
let tool_calls = parsed_calls
.into_iter()
.map(|tc| NativeToolCall {
id: Some(tc.id),
kind: Some("function".to_string()),
function: NativeFunctionCall {
name: tc.name,
arguments: tc.arguments,
},
})
.collect::<Vec<_>>();
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return NativeMessage {
role: "assistant".to_string(),
content,
tool_call_id: None,
tool_calls: Some(tool_calls),
};
}
}
}
}
if m.role == "tool" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
let tool_call_id = value
.get("tool_call_id")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return NativeMessage {
role: "tool".to_string(),
content,
tool_call_id,
tool_calls: None,
};
}
}
NativeMessage {
role: m.role.clone(),
content: Some(m.content.clone()),
tool_call_id: None,
tool_calls: None,
}
})
.collect()
}
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
let tool_calls = message
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|tc| ProviderToolCall {
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
name: tc.function.name,
arguments: tc.function.arguments,
})
.collect::<Vec<_>>();
ProviderChatResponse {
text: message.content,
tool_calls,
}
}
}
#[async_trait]
@ -58,7 +232,7 @@ impl Provider for OpenAiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
@ -85,14 +259,13 @@ impl Provider for OpenAiProvider {
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("OpenAI API error: {error}");
return Err(super::api_error("OpenAI", response).await);
}
let chat_response: ChatResponse = response.json().await?;
@ -104,6 +277,51 @@ impl Provider for OpenAiProvider {
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
}
async fn chat(
&self,
request: ProviderChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
let tools = Self::convert_tools(request.tools);
let native_request = NativeChatRequest {
model: model.to_string(),
messages: Self::convert_messages(request.messages),
temperature,
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
tools,
};
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {credential}"))
.json(&native_request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("OpenAI", response).await);
}
let native_response: NativeChatResponse = response.json().await?;
let message = native_response
.choices
.into_iter()
.next()
.map(|c| c.message)
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
Ok(Self::parse_native_response(message))
}
fn supports_native_tools(&self) -> bool {
true
}
}
#[cfg(test)]
@ -112,20 +330,20 @@ mod tests {
#[test]
fn creates_with_key() {
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
let p = OpenAiProvider::new(Some("openai-test-credential"));
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
}
#[test]
fn creates_without_key() {
let p = OpenAiProvider::new(None);
assert!(p.api_key.is_none());
assert!(p.credential.is_none());
}
#[test]
fn creates_with_empty_key() {
let p = OpenAiProvider::new(Some(""));
assert_eq!(p.api_key.as_deref(), Some(""));
assert_eq!(p.credential.as_deref(), Some(""));
}
#[tokio::test]

View file

@ -1,10 +1,14 @@
use crate::providers::traits::Provider;
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenRouterProvider {
api_key: Option<String>,
credential: Option<String>,
client: Client,
}
@ -22,7 +26,7 @@ struct Message {
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
struct ApiChatResponse {
choices: Vec<Choice>,
}
@ -36,10 +40,79 @@ struct ResponseMessage {
content: String,
}
#[derive(Debug, Serialize)]
struct NativeChatRequest {
model: String,
messages: Vec<NativeMessage>,
temperature: f64,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<NativeToolSpec>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<String>,
}
#[derive(Debug, Serialize)]
struct NativeMessage {
role: String,
#[serde(skip_serializing_if = "Option::is_none")]
content: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_call_id: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_calls: Option<Vec<NativeToolCall>>,
}
#[derive(Debug, Serialize)]
struct NativeToolSpec {
#[serde(rename = "type")]
kind: String,
function: NativeToolFunctionSpec,
}
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
name: String,
description: String,
parameters: serde_json::Value,
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
#[serde(skip_serializing_if = "Option::is_none")]
id: Option<String>,
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
kind: Option<String>,
function: NativeFunctionCall,
}
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
name: String,
arguments: String,
}
#[derive(Debug, Deserialize)]
struct NativeChatResponse {
choices: Vec<NativeChoice>,
}
#[derive(Debug, Deserialize)]
struct NativeChoice {
message: NativeResponseMessage,
}
#[derive(Debug, Deserialize)]
struct NativeResponseMessage {
#[serde(default)]
content: Option<String>,
#[serde(default)]
tool_calls: Option<Vec<NativeToolCall>>,
}
impl OpenRouterProvider {
pub fn new(api_key: Option<&str>) -> Self {
pub fn new(credential: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -47,10 +120,129 @@ impl OpenRouterProvider {
.unwrap_or_else(|_| Client::new()),
}
}
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
let items = tools?;
if items.is_empty() {
return None;
}
Some(
items
.iter()
.map(|tool| NativeToolSpec {
kind: "function".to_string(),
function: NativeToolFunctionSpec {
name: tool.name.clone(),
description: tool.description.clone(),
parameters: tool.parameters.clone(),
},
})
.collect(),
)
}
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
messages
.iter()
.map(|m| {
if m.role == "assistant" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
if let Some(tool_calls_value) = value.get("tool_calls") {
if let Ok(parsed_calls) =
serde_json::from_value::<Vec<ProviderToolCall>>(
tool_calls_value.clone(),
)
{
let tool_calls = parsed_calls
.into_iter()
.map(|tc| NativeToolCall {
id: Some(tc.id),
kind: Some("function".to_string()),
function: NativeFunctionCall {
name: tc.name,
arguments: tc.arguments,
},
})
.collect::<Vec<_>>();
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return NativeMessage {
role: "assistant".to_string(),
content,
tool_call_id: None,
tool_calls: Some(tool_calls),
};
}
}
}
}
if m.role == "tool" {
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
let tool_call_id = value
.get("tool_call_id")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
let content = value
.get("content")
.and_then(serde_json::Value::as_str)
.map(ToString::to_string);
return NativeMessage {
role: "tool".to_string(),
content,
tool_call_id,
tool_calls: None,
};
}
}
NativeMessage {
role: m.role.clone(),
content: Some(m.content.clone()),
tool_call_id: None,
tool_calls: None,
}
})
.collect()
}
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
let tool_calls = message
.tool_calls
.unwrap_or_default()
.into_iter()
.map(|tc| ProviderToolCall {
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
name: tc.function.name,
arguments: tc.function.arguments,
})
.collect::<Vec<_>>();
ProviderChatResponse {
text: message.content,
tool_calls,
}
}
}
#[async_trait]
impl Provider for OpenRouterProvider {
async fn warmup(&self) -> anyhow::Result<()> {
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
// This prevents the first real chat request from timing out on cold start.
if let Some(credential) = self.credential.as_ref() {
self.client
.get("https://openrouter.ai/api/v1/auth/key")
.header("Authorization", format!("Bearer {credential}"))
.send()
.await?
.error_for_status()?;
}
Ok(())
}
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
@ -58,7 +250,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref()
let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let mut messages = Vec::new();
@ -84,7 +276,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -95,11 +287,10 @@ impl Provider for OpenRouterProvider {
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("OpenRouter API error: {error}");
return Err(super::api_error("OpenRouter", response).await);
}
let chat_response: ChatResponse = response.json().await?;
let chat_response: ApiChatResponse = response.json().await?;
chat_response
.choices
@ -108,4 +299,455 @@ impl Provider for OpenRouterProvider {
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let request = ChatRequest {
model: model.to_string(),
messages: api_messages,
temperature,
};
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
)
.header("X-Title", "ZeroClaw")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("OpenRouter", response).await);
}
let chat_response: ApiChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
}
async fn chat(
&self,
request: ProviderChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
})?;
let tools = Self::convert_tools(request.tools);
let native_request = NativeChatRequest {
model: model.to_string(),
messages: Self::convert_messages(request.messages),
temperature,
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
tools,
};
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
)
.header("X-Title", "ZeroClaw")
.json(&native_request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("OpenRouter", response).await);
}
let native_response: NativeChatResponse = response.json().await?;
let message = native_response
.choices
.into_iter()
.next()
.map(|c| c.message)
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?;
Ok(Self::parse_native_response(message))
}
fn supports_native_tools(&self) -> bool {
true
}
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
tools: &[serde_json::Value],
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
})?;
// Convert tool JSON values to NativeToolSpec
let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
None
} else {
let specs: Vec<NativeToolSpec> = tools
.iter()
.filter_map(|t| {
let func = t.get("function")?;
Some(NativeToolSpec {
kind: "function".to_string(),
function: NativeToolFunctionSpec {
name: func.get("name")?.as_str()?.to_string(),
description: func
.get("description")
.and_then(|d| d.as_str())
.unwrap_or("")
.to_string(),
parameters: func
.get("parameters")
.cloned()
.unwrap_or(serde_json::json!({})),
},
})
})
.collect();
if specs.is_empty() {
None
} else {
Some(specs)
}
};
// Convert ChatMessage to NativeMessage, preserving structured assistant/tool entries
// when history contains native tool-call metadata.
let native_messages = Self::convert_messages(messages);
let native_request = NativeChatRequest {
model: model.to_string(),
messages: native_messages,
temperature,
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
tools: native_tools,
};
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
)
.header("X-Title", "ZeroClaw")
.json(&native_request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("OpenRouter", response).await);
}
let native_response: NativeChatResponse = response.json().await?;
let message = native_response
.choices
.into_iter()
.next()
.map(|c| c.message)
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?;
Ok(Self::parse_native_response(message))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::providers::traits::{ChatMessage, Provider};
#[test]
fn creates_with_key() {
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
assert_eq!(
provider.credential.as_deref(),
Some("openrouter-test-credential")
);
}
#[test]
fn creates_without_key() {
let provider = OpenRouterProvider::new(None);
assert!(provider.credential.is_none());
}
#[tokio::test]
async fn warmup_without_key_is_noop() {
let provider = OpenRouterProvider::new(None);
let result = provider.warmup().await;
assert!(result.is_ok());
}
#[tokio::test]
async fn chat_with_system_fails_without_key() {
let provider = OpenRouterProvider::new(None);
let result = provider
.chat_with_system(Some("system"), "hello", "openai/gpt-4o", 0.2)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("API key not set"));
}
#[tokio::test]
async fn chat_with_history_fails_without_key() {
let provider = OpenRouterProvider::new(None);
let messages = vec![
ChatMessage {
role: "system".into(),
content: "be concise".into(),
},
ChatMessage {
role: "user".into(),
content: "hello".into(),
},
];
let result = provider
.chat_with_history(&messages, "anthropic/claude-sonnet-4", 0.7)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("API key not set"));
}
#[test]
fn chat_request_serializes_with_system_and_user() {
let request = ChatRequest {
model: "anthropic/claude-sonnet-4".into(),
messages: vec![
Message {
role: "system".into(),
content: "You are helpful".into(),
},
Message {
role: "user".into(),
content: "Summarize this".into(),
},
],
temperature: 0.5,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("anthropic/claude-sonnet-4"));
assert!(json.contains("\"role\":\"system\""));
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("\"temperature\":0.5"));
}
#[test]
fn chat_request_serializes_history_messages() {
let messages = [
ChatMessage {
role: "assistant".into(),
content: "Previous answer".into(),
},
ChatMessage {
role: "user".into(),
content: "Follow-up".into(),
},
];
let request = ChatRequest {
model: "google/gemini-2.5-pro".into(),
messages: messages
.iter()
.map(|msg| Message {
role: msg.role.clone(),
content: msg.content.clone(),
})
.collect(),
temperature: 0.0,
};
let json = serde_json::to_string(&request).unwrap();
assert!(json.contains("\"role\":\"assistant\""));
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("google/gemini-2.5-pro"));
}
#[test]
fn response_deserializes_single_choice() {
let json = r#"{"choices":[{"message":{"content":"Hi from OpenRouter"}}]}"#;
let response: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.choices.len(), 1);
assert_eq!(response.choices[0].message.content, "Hi from OpenRouter");
}
#[test]
fn response_deserializes_empty_choices() {
let json = r#"{"choices":[]}"#;
let response: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(response.choices.is_empty());
}
#[tokio::test]
async fn chat_with_tools_fails_without_key() {
let provider = OpenRouterProvider::new(None);
let messages = vec![ChatMessage {
role: "user".into(),
content: "What is the date?".into(),
}];
let tools = vec![serde_json::json!({
"type": "function",
"function": {
"name": "shell",
"description": "Run a shell command",
"parameters": {"type": "object", "properties": {"command": {"type": "string"}}}
}
})];
let result = provider
.chat_with_tools(&messages, &tools, "deepseek/deepseek-chat", 0.5)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("API key not set"));
}
#[test]
fn native_response_deserializes_with_tool_calls() {
let json = r#"{
"choices":[{
"message":{
"content":null,
"tool_calls":[
{"id":"call_123","type":"function","function":{"name":"get_price","arguments":"{\"symbol\":\"BTC\"}"}}
]
}
}]
}"#;
let response: NativeChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.choices.len(), 1);
let message = &response.choices[0].message;
assert!(message.content.is_none());
let tool_calls = message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id.as_deref(), Some("call_123"));
assert_eq!(tool_calls[0].function.name, "get_price");
assert_eq!(tool_calls[0].function.arguments, "{\"symbol\":\"BTC\"}");
}
#[test]
fn native_response_deserializes_with_text_and_tool_calls() {
let json = r#"{
"choices":[{
"message":{
"content":"I'll get that for you.",
"tool_calls":[
{"id":"call_456","type":"function","function":{"name":"shell","arguments":"{\"command\":\"date\"}"}}
]
}
}]
}"#;
let response: NativeChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(response.choices.len(), 1);
let message = &response.choices[0].message;
assert_eq!(message.content.as_deref(), Some("I'll get that for you."));
let tool_calls = message.tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].function.name, "shell");
}
#[test]
fn parse_native_response_converts_to_chat_response() {
let message = NativeResponseMessage {
content: Some("Here you go.".into()),
tool_calls: Some(vec![NativeToolCall {
id: Some("call_789".into()),
kind: Some("function".into()),
function: NativeFunctionCall {
name: "file_read".into(),
arguments: r#"{"path":"test.txt"}"#.into(),
},
}]),
};
let response = OpenRouterProvider::parse_native_response(message);
assert_eq!(response.text.as_deref(), Some("Here you go."));
assert_eq!(response.tool_calls.len(), 1);
assert_eq!(response.tool_calls[0].id, "call_789");
assert_eq!(response.tool_calls[0].name, "file_read");
}
#[test]
fn convert_messages_parses_assistant_tool_call_payload() {
let messages = vec![ChatMessage {
role: "assistant".into(),
content: r#"{"content":"Using tool","tool_calls":[{"id":"call_abc","name":"shell","arguments":"{\"command\":\"pwd\"}"}]}"#
.into(),
}];
let converted = OpenRouterProvider::convert_messages(&messages);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "assistant");
assert_eq!(converted[0].content.as_deref(), Some("Using tool"));
let tool_calls = converted[0].tool_calls.as_ref().unwrap();
assert_eq!(tool_calls.len(), 1);
assert_eq!(tool_calls[0].id.as_deref(), Some("call_abc"));
assert_eq!(tool_calls[0].function.name, "shell");
assert_eq!(tool_calls[0].function.arguments, r#"{"command":"pwd"}"#);
}
#[test]
fn convert_messages_parses_tool_result_payload() {
let messages = vec![ChatMessage {
role: "tool".into(),
content: r#"{"tool_call_id":"call_xyz","content":"done"}"#.into(),
}];
let converted = OpenRouterProvider::convert_messages(&messages);
assert_eq!(converted.len(), 1);
assert_eq!(converted[0].role, "tool");
assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_xyz"));
assert_eq!(converted[0].content.as_deref(), Some("done"));
assert!(converted[0].tool_calls.is_none());
}
}

View file

@ -1,12 +1,85 @@
use super::traits::{ChatMessage, StreamChunk, StreamOptions, StreamResult};
use super::Provider;
use async_trait::async_trait;
use futures_util::{stream, StreamExt};
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::time::Duration;
/// Provider wrapper with retry + fallback behavior.
/// Check if an error is non-retryable (client errors that won't resolve with retries).
fn is_non_retryable(err: &anyhow::Error) -> bool {
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
if let Some(status) = reqwest_err.status() {
let code = status.as_u16();
return status.is_client_error() && code != 429 && code != 408;
}
}
let msg = err.to_string();
for word in msg.split(|c: char| !c.is_ascii_digit()) {
if let Ok(code) = word.parse::<u16>() {
if (400..500).contains(&code) {
return code != 429 && code != 408;
}
}
}
false
}
/// Check if an error is a rate-limit (429) error.
fn is_rate_limited(err: &anyhow::Error) -> bool {
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
if let Some(status) = reqwest_err.status() {
return status.as_u16() == 429;
}
}
let msg = err.to_string();
msg.contains("429")
&& (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
}
/// Try to extract a Retry-After value (in milliseconds) from an error message.
/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
let msg = err.to_string();
let lower = msg.to_lowercase();
// Look for "retry-after: <number>" or "retry_after: <number>"
for prefix in &[
"retry-after:",
"retry_after:",
"retry-after ",
"retry_after ",
] {
if let Some(pos) = lower.find(prefix) {
let after = &msg[pos + prefix.len()..];
let num_str: String = after
.trim()
.chars()
.take_while(|c| c.is_ascii_digit() || *c == '.')
.collect();
if let Ok(secs) = num_str.parse::<f64>() {
if secs.is_finite() && secs >= 0.0 {
let millis = Duration::from_secs_f64(secs).as_millis();
if let Ok(value) = u64::try_from(millis) {
return Some(value);
}
}
}
}
}
None
}
/// Provider wrapper with retry, fallback, auth rotation, and model failover.
pub struct ReliableProvider {
providers: Vec<(String, Box<dyn Provider>)>,
max_retries: u32,
base_backoff_ms: u64,
/// Extra API keys for rotation (index tracks round-robin position).
api_keys: Vec<String>,
key_index: AtomicUsize,
/// Per-model fallback chains: model_name → [fallback_model_1, fallback_model_2, ...]
model_fallbacks: HashMap<String, Vec<String>>,
}
impl ReliableProvider {
@ -19,12 +92,65 @@ impl ReliableProvider {
providers,
max_retries,
base_backoff_ms: base_backoff_ms.max(50),
api_keys: Vec::new(),
key_index: AtomicUsize::new(0),
model_fallbacks: HashMap::new(),
}
}
/// Set additional API keys for round-robin rotation on rate-limit errors.
pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
self.api_keys = keys;
self
}
/// Set per-model fallback chains.
pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
self.model_fallbacks = fallbacks;
self
}
/// Build the list of models to try: [original, fallback1, fallback2, ...]
fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
let mut chain = vec![model];
if let Some(fallbacks) = self.model_fallbacks.get(model) {
chain.extend(fallbacks.iter().map(|s| s.as_str()));
}
chain
}
/// Advance to the next API key and return it, or None if no extra keys configured.
fn rotate_key(&self) -> Option<&str> {
if self.api_keys.is_empty() {
return None;
}
let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
Some(&self.api_keys[idx])
}
/// Compute backoff duration, respecting Retry-After if present.
fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
if let Some(retry_after) = parse_retry_after_ms(err) {
// Use Retry-After but cap at 30s to avoid indefinite waits
retry_after.min(30_000).max(base)
} else {
base
}
}
}
#[async_trait]
impl Provider for ReliableProvider {
async fn warmup(&self) -> anyhow::Result<()> {
for (name, provider) in &self.providers {
tracing::info!(provider = name, "Warming up provider connection pool");
if provider.warmup().await.is_err() {
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
}
}
Ok(())
}
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
@ -32,58 +158,278 @@ impl Provider for ReliableProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let models = self.model_chain(model);
let mut failures = Vec::new();
for (provider_name, provider) in &self.providers {
let mut backoff_ms = self.base_backoff_ms;
for current_model in &models {
for (provider_name, provider) in &self.providers {
let mut backoff_ms = self.base_backoff_ms;
for attempt in 0..=self.max_retries {
match provider
.chat_with_system(system_prompt, message, model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 {
tracing::info!(
provider = provider_name,
attempt,
"Provider recovered after retries"
);
for attempt in 0..=self.max_retries {
match provider
.chat_with_system(system_prompt, message, current_model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 || *current_model != model {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
return Ok(resp);
}
Err(e) => {
failures.push(format!(
"{provider_name} attempt {}/{}: {e}",
attempt + 1,
self.max_retries + 1
));
Err(e) => {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
if attempt < self.max_retries {
tracing::warn!(
provider = provider_name,
attempt = attempt + 1,
max_retries = self.max_retries,
"Provider call failed, retrying"
);
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));
// On rate-limit, try rotating API key
if rate_limited {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
"Rate limited, rotated API key (key ending ...{})",
&new_key[new_key.len().saturating_sub(4)..]
);
}
}
if non_retryable {
tracing::warn!(
provider = provider_name,
model = *current_model,
"Non-retryable error, moving on"
);
break;
}
if attempt < self.max_retries {
let wait = self.compute_backoff(backoff_ms, &e);
tracing::warn!(
provider = provider_name,
model = *current_model,
attempt = attempt + 1,
backoff_ms = wait,
"Provider call failed, retrying"
);
tokio::time::sleep(Duration::from_millis(wait)).await;
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
}
}
}
}
tracing::warn!(
provider = provider_name,
model = *current_model,
"Exhausted retries, trying next provider/model"
);
}
tracing::warn!(provider = provider_name, "Switching to fallback provider");
if *current_model != model {
tracing::warn!(
original_model = model,
fallback_model = *current_model,
"Model fallback exhausted all providers, trying next fallback model"
);
}
}
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
anyhow::bail!(
"All providers/models failed. Attempts:\n{}",
failures.join("\n")
)
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let models = self.model_chain(model);
let mut failures = Vec::new();
for current_model in &models {
for (provider_name, provider) in &self.providers {
let mut backoff_ms = self.base_backoff_ms;
for attempt in 0..=self.max_retries {
match provider
.chat_with_history(messages, current_model, temperature)
.await
{
Ok(resp) => {
if attempt > 0 || *current_model != model {
tracing::info!(
provider = provider_name,
model = *current_model,
attempt,
original_model = model,
"Provider recovered (failover/retry)"
);
}
return Ok(resp);
}
Err(e) => {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));
if rate_limited {
if let Some(new_key) = self.rotate_key() {
tracing::info!(
provider = provider_name,
"Rate limited, rotated API key (key ending ...{})",
&new_key[new_key.len().saturating_sub(4)..]
);
}
}
if non_retryable {
tracing::warn!(
provider = provider_name,
model = *current_model,
"Non-retryable error, moving on"
);
break;
}
if attempt < self.max_retries {
let wait = self.compute_backoff(backoff_ms, &e);
tracing::warn!(
provider = provider_name,
model = *current_model,
attempt = attempt + 1,
backoff_ms = wait,
"Provider call failed, retrying"
);
tokio::time::sleep(Duration::from_millis(wait)).await;
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
}
}
}
}
tracing::warn!(
provider = provider_name,
model = *current_model,
"Exhausted retries, trying next provider/model"
);
}
}
anyhow::bail!(
"All providers/models failed. Attempts:\n{}",
failures.join("\n")
)
}
/// True when at least one wrapped provider can stream responses.
fn supports_streaming(&self) -> bool {
    for (_name, provider) in &self.providers {
        if provider.supports_streaming() {
            return true;
        }
    }
    false
}
/// Streaming variant of chat. Unlike the non-streaming paths there is no
/// retry/failover here: the first provider that supports streaming handles
/// the request, and any mid-stream errors are surfaced as stream items for
/// the caller to handle (the caller may retry the whole request).
fn stream_chat_with_system(
    &self,
    system_prompt: Option<&str>,
    message: &str,
    model: &str,
    temperature: f64,
    options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
    // Try each provider/model combination for streaming
    // For streaming, we use the first provider that supports it and has streaming enabled
    for (provider_name, provider) in &self.providers {
        // NOTE(review): when options.enabled is false every provider is skipped
        // and the fallthrough error still says "No provider supports streaming",
        // which is misleading for the disabled case — confirm intended.
        if !provider.supports_streaming() || !options.enabled {
            continue;
        }
        // Clone provider data for the stream
        let provider_clone = provider_name.clone();
        // Only the first model of the chain is tried; model failover does not
        // apply to the streaming path.
        let current_model = match self.model_chain(model).first() {
            Some(m) => m.to_string(),
            None => model.to_string(),
        };
        // For streaming, we attempt once and propagate errors
        // The caller can retry the entire request if needed
        let stream = provider.stream_chat_with_system(
            system_prompt,
            message,
            &current_model,
            temperature,
            options,
        );
        // Use a channel to bridge the stream with logging
        // (bounded at 100 chunks of buffering between producer and consumer).
        let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
        tokio::spawn(async move {
            let mut stream = stream;
            while let Some(chunk) = stream.next().await {
                // Log errors as they pass through; they are still forwarded.
                if let Err(ref e) = chunk {
                    tracing::warn!(
                        provider = provider_clone,
                        model = current_model,
                        "Streaming error: {e}"
                    );
                }
                if tx.send(chunk).await.is_err() {
                    break; // Receiver dropped
                }
            }
        });
        // Convert channel receiver to stream
        return stream::unfold(rx, |mut rx| async move {
            rx.recv().await.map(|chunk| (chunk, rx))
        })
        .boxed();
    }
    // No streaming support available
    stream::once(async move {
        Err(super::traits::StreamError::Provider(
            "No provider supports streaming".to_string(),
        ))
    })
    .boxed()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
struct MockProvider {
@ -108,8 +454,49 @@ mod tests {
}
Ok(self.response.to_string())
}
/// History-path mock: mirrors the chat_with_system counting/failure logic so
/// retry behavior can be asserted for chat_with_history as well.
async fn chat_with_history(
    &self,
    _messages: &[ChatMessage],
    _model: &str,
    _temperature: f64,
) -> anyhow::Result<String> {
    // fetch_add returns the previous count, so `attempt` is 1-based.
    let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
    if attempt <= self.fail_until_attempt {
        anyhow::bail!(self.error);
    }
    Ok(self.response.to_string())
}
}
/// Mock that records which model was used for each call.
struct ModelAwareMock {
    calls: Arc<AtomicUsize>,                      // total invocation count
    models_seen: parking_lot::Mutex<Vec<String>>, // model argument of each call, in order
    fail_models: Vec<&'static str>,               // models that should fail with a 500-style error
    response: &'static str,                       // canned reply for non-failing models
}
#[async_trait]
impl Provider for ModelAwareMock {
    /// Records the call and the model used, then fails for configured models
    /// (retryable 500-style error) or returns the canned response.
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        self.calls.fetch_add(1, Ordering::SeqCst);
        self.models_seen.lock().push(model.to_string());
        if self.fail_models.contains(&model) {
            anyhow::bail!("500 model {} unavailable", model);
        }
        Ok(self.response.to_string())
    }
}
}
// ── Existing tests (preserved) ──
#[tokio::test]
async fn succeeds_without_retry() {
let calls = Arc::new(AtomicUsize::new(0));
@ -127,7 +514,7 @@ mod tests {
1,
);
let result = provider.chat("hello", "test", 0.0).await.unwrap();
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result, "ok");
assert_eq!(calls.load(Ordering::SeqCst), 1);
}
@ -149,7 +536,7 @@ mod tests {
1,
);
let result = provider.chat("hello", "test", 0.0).await.unwrap();
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result, "recovered");
assert_eq!(calls.load(Ordering::SeqCst), 2);
}
@ -184,7 +571,7 @@ mod tests {
1,
);
let result = provider.chat("hello", "test", 0.0).await.unwrap();
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
assert_eq!(result, "from fallback");
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
@ -218,12 +605,323 @@ mod tests {
);
let err = provider
.chat("hello", "test", 0.0)
.simple_chat("hello", "test", 0.0)
.await
.expect_err("all providers should fail");
let msg = err.to_string();
assert!(msg.contains("All providers failed"));
assert!(msg.contains("p1 attempt 1/1"));
assert!(msg.contains("p2 attempt 1/1"));
assert!(msg.contains("All providers/models failed"));
assert!(msg.contains("p1"));
assert!(msg.contains("p2"));
}
#[test]
fn non_retryable_detects_common_patterns() {
    // Client errors (4xx other than 408/429) must not be retried.
    for msg in ["400 Bad Request", "401 Unauthorized", "403 Forbidden", "404 Not Found"] {
        assert!(is_non_retryable(&anyhow::anyhow!(msg)), "{msg}");
    }
    // Rate limits, timeouts, server errors and transport failures stay retryable.
    for msg in [
        "429 Too Many Requests",
        "408 Request Timeout",
        "500 Internal Server Error",
        "502 Bad Gateway",
        "timeout",
        "connection reset",
    ] {
        assert!(!is_non_retryable(&anyhow::anyhow!(msg)), "{msg}");
    }
}
#[tokio::test]
async fn skips_retries_on_non_retryable_error() {
    // A 401 is classified non-retryable: the primary must be called exactly
    // once (no retries despite max_retries = 3) before failover succeeds.
    let primary_calls = Arc::new(AtomicUsize::new(0));
    let fallback_calls = Arc::new(AtomicUsize::new(0));
    let provider = ReliableProvider::new(
        vec![
            (
                "primary".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&primary_calls),
                    fail_until_attempt: usize::MAX,
                    response: "never",
                    error: "401 Unauthorized",
                }),
            ),
            (
                "fallback".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&fallback_calls),
                    fail_until_attempt: 0,
                    response: "from fallback",
                    error: "fallback err",
                }),
            ),
        ],
        3,
        1,
    );
    let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
    assert_eq!(result, "from fallback");
    // Primary should have been called only once (no retries)
    assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
    assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
#[tokio::test]
async fn chat_with_history_retries_then_recovers() {
    // First attempt fails with a retryable error; second attempt succeeds,
    // so exactly two calls are expected on the history path.
    let calls = Arc::new(AtomicUsize::new(0));
    let provider = ReliableProvider::new(
        vec![(
            "primary".into(),
            Box::new(MockProvider {
                calls: Arc::clone(&calls),
                fail_until_attempt: 1,
                response: "history ok",
                error: "temporary",
            }),
        )],
        2,
        1,
    );
    let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
    let result = provider
        .chat_with_history(&messages, "test", 0.0)
        .await
        .unwrap();
    assert_eq!(result, "history ok");
    assert_eq!(calls.load(Ordering::SeqCst), 2);
}
#[tokio::test]
async fn chat_with_history_falls_back() {
    // Primary always fails: with max_retries = 1 it is called twice
    // (initial + one retry), then the fallback answers on its first call.
    let primary_calls = Arc::new(AtomicUsize::new(0));
    let fallback_calls = Arc::new(AtomicUsize::new(0));
    let provider = ReliableProvider::new(
        vec![
            (
                "primary".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&primary_calls),
                    fail_until_attempt: usize::MAX,
                    response: "never",
                    error: "primary down",
                }),
            ),
            (
                "fallback".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&fallback_calls),
                    fail_until_attempt: 0,
                    response: "fallback ok",
                    error: "fallback err",
                }),
            ),
        ],
        1,
        1,
    );
    let messages = vec![ChatMessage::user("hello")];
    let result = provider
        .chat_with_history(&messages, "test", 0.0)
        .await
        .unwrap();
    assert_eq!(result, "fallback ok");
    assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
    assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
}
// ── New tests: model failover ──
#[tokio::test]
async fn model_failover_tries_fallback_model() {
    // With retries disabled, a failure on the requested model should fail
    // over directly to the configured fallback model, in order.
    let calls = Arc::new(AtomicUsize::new(0));
    let mock = Arc::new(ModelAwareMock {
        calls: Arc::clone(&calls),
        models_seen: parking_lot::Mutex::new(Vec::new()),
        fail_models: vec!["claude-opus"],
        response: "ok from sonnet",
    });
    let mut fallbacks = HashMap::new();
    fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
    let provider = ReliableProvider::new(
        vec![(
            "anthropic".into(),
            Box::new(mock.clone()) as Box<dyn Provider>,
        )],
        0, // no retries — force immediate model failover
        1,
    )
    .with_model_fallbacks(fallbacks);
    let result = provider
        .simple_chat("hello", "claude-opus", 0.0)
        .await
        .unwrap();
    assert_eq!(result, "ok from sonnet");
    // The mock records the model of every call: requested model first,
    // then the fallback.
    let seen = mock.models_seen.lock();
    assert_eq!(seen.len(), 2);
    assert_eq!(seen[0], "claude-opus");
    assert_eq!(seen[1], "claude-sonnet");
}
#[tokio::test]
async fn model_failover_all_models_fail() {
    // Every model in the chain fails: all three should be attempted once
    // (no retries) and the aggregate error returned.
    let calls = Arc::new(AtomicUsize::new(0));
    let mock = Arc::new(ModelAwareMock {
        calls: Arc::clone(&calls),
        models_seen: parking_lot::Mutex::new(Vec::new()),
        fail_models: vec!["model-a", "model-b", "model-c"],
        response: "never",
    });
    let mut fallbacks = HashMap::new();
    fallbacks.insert(
        "model-a".to_string(),
        vec!["model-b".to_string(), "model-c".to_string()],
    );
    let provider = ReliableProvider::new(
        vec![("p1".into(), Box::new(mock.clone()) as Box<dyn Provider>)],
        0,
        1,
    )
    .with_model_fallbacks(fallbacks);
    let err = provider
        .simple_chat("hello", "model-a", 0.0)
        .await
        .expect_err("all models should fail");
    assert!(err.to_string().contains("All providers/models failed"));
    let seen = mock.models_seen.lock();
    assert_eq!(seen.len(), 3);
}
#[tokio::test]
async fn no_model_fallbacks_behaves_like_before() {
    // Without with_model_fallbacks, the chain contains only the requested
    // model, so a first-call success means exactly one invocation.
    let calls = Arc::new(AtomicUsize::new(0));
    let provider = ReliableProvider::new(
        vec![(
            "primary".into(),
            Box::new(MockProvider {
                calls: Arc::clone(&calls),
                fail_until_attempt: 0,
                response: "ok",
                error: "boom",
            }),
        )],
        2,
        1,
    );
    // No model_fallbacks set — should work exactly as before
    let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
    assert_eq!(result, "ok");
    assert_eq!(calls.load(Ordering::SeqCst), 1);
}
// ── New tests: auth rotation ──
#[tokio::test]
async fn auth_rotation_cycles_keys() {
    // rotate_key should hand out the configured keys round-robin, wrapping
    // back to the first after the last.
    let provider = ReliableProvider::new(
        vec![(
            "p".into(),
            Box::new(MockProvider {
                calls: Arc::new(AtomicUsize::new(0)),
                fail_until_attempt: 0,
                response: "ok",
                error: "",
            }),
        )],
        0,
        1,
    )
    .with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
    // Rotate 5 times, verify round-robin
    let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect();
    assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
}
#[tokio::test]
async fn auth_rotation_returns_none_when_empty() {
    // With no API keys configured, rotation has nothing to hand out.
    let provider = ReliableProvider::new(Vec::new(), 0, 1);
    let rotated = provider.rotate_key();
    assert!(rotated.is_none());
}
// ── New tests: Retry-After parsing ──
#[test]
fn parse_retry_after_integer() {
    // An integer "Retry-After: 5" (seconds) parses to 5000 ms.
    assert_eq!(
        parse_retry_after_ms(&anyhow::anyhow!("429 Too Many Requests, Retry-After: 5")),
        Some(5000)
    );
}
#[test]
fn parse_retry_after_float() {
    // Fractional seconds ("retry_after: 2.5") convert to 2500 ms.
    assert_eq!(
        parse_retry_after_ms(&anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds")),
        Some(2500)
    );
}
#[test]
fn parse_retry_after_missing() {
    // No Retry-After hint in the message → nothing to parse.
    assert_eq!(
        parse_retry_after_ms(&anyhow::anyhow!("500 Internal Server Error")),
        None
    );
}
#[test]
fn rate_limited_detection() {
    // 429-style messages are rate limits; other statuses are not.
    for msg in ["429 Too Many Requests", "HTTP 429 rate limit exceeded"] {
        assert!(is_rate_limited(&anyhow::anyhow!(msg)), "{msg}");
    }
    for msg in ["401 Unauthorized", "500 Internal Server Error"] {
        assert!(!is_rate_limited(&anyhow::anyhow!(msg)), "{msg}");
    }
}
#[test]
fn compute_backoff_uses_retry_after() {
    // An explicit Retry-After hint overrides the exponential backoff value.
    let provider = ReliableProvider::new(vec![], 0, 500);
    let wait = provider.compute_backoff(500, &anyhow::anyhow!("429 Retry-After: 3"));
    assert_eq!(wait, 3000);
}
#[test]
fn compute_backoff_caps_at_30s() {
    // Even a huge Retry-After (120 s) is clamped to the 30-second ceiling.
    let provider = ReliableProvider::new(vec![], 0, 500);
    let wait = provider.compute_backoff(500, &anyhow::anyhow!("429 Retry-After: 120"));
    assert_eq!(wait, 30_000);
}
#[test]
fn compute_backoff_falls_back_to_base() {
    // No Retry-After hint → the supplied base backoff is used unchanged.
    let provider = ReliableProvider::new(vec![], 0, 500);
    let wait = provider.compute_backoff(500, &anyhow::anyhow!("500 Server Error"));
    assert_eq!(wait, 500);
}
// ── Arc<ModelAwareMock> Provider impl for test ──
// Delegating impl so tests can keep an Arc handle to the mock (to inspect
// models_seen afterwards) while the provider list owns a Box<dyn Provider>.
#[async_trait]
impl Provider for Arc<ModelAwareMock> {
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        self.as_ref()
            .chat_with_system(system_prompt, message, model, temperature)
            .await
    }
}
}

385
src/providers/router.rs Normal file
View file

@ -0,0 +1,385 @@
use super::traits::{ChatMessage, ChatRequest, ChatResponse};
use super::Provider;
use async_trait::async_trait;
use std::collections::HashMap;
/// A single route: maps a task hint to a provider + model combo.
#[derive(Debug, Clone)]
pub struct Route {
    /// Name of the provider to dispatch to; must match a name registered
    /// with RouterProvider::new, otherwise the route is dropped with a warning.
    pub provider_name: String,
    /// Concrete model passed to that provider for this route.
    pub model: String,
}
/// Multi-model router — routes requests to different provider+model combos
/// based on a task hint encoded in the model parameter.
///
/// The model parameter can be:
/// - A regular model name (e.g. "anthropic/claude-sonnet-4") → uses default provider
/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table
///
/// This wraps multiple pre-created providers and selects the right one per request.
pub struct RouterProvider {
    routes: HashMap<String, (usize, String)>, // hint → (provider_index, model)
    providers: Vec<(String, Box<dyn Provider>)>, // (name, provider); index 0 is the default
    default_index: usize, // index into `providers` used when no route matches
    // NOTE(review): stored but not read by resolve() — presumably consumed
    // elsewhere or dead; verify.
    default_model: String,
}
impl RouterProvider {
    /// Create a new router with a default provider and optional routes.
    ///
    /// `providers` is a list of (name, provider) pairs. The first one is the default.
    /// `routes` maps hint names to Route structs containing provider_name and model.
    ///
    /// NOTE(review): `default_index` is fixed to 0, so an empty `providers`
    /// list will panic on the first dispatch — confirm callers guarantee at
    /// least one provider.
    pub fn new(
        providers: Vec<(String, Box<dyn Provider>)>,
        routes: Vec<(String, Route)>,
        default_model: String,
    ) -> Self {
        // Build provider name → index lookup
        let name_to_index: HashMap<&str, usize> = providers
            .iter()
            .enumerate()
            .map(|(i, (name, _))| (name.as_str(), i))
            .collect();
        // Resolve routes to provider indices; routes naming an unknown
        // provider are dropped (with a warning) rather than failing startup.
        let resolved_routes: HashMap<String, (usize, String)> = routes
            .into_iter()
            .filter_map(|(hint, route)| {
                let index = name_to_index.get(route.provider_name.as_str()).copied();
                match index {
                    Some(i) => Some((hint, (i, route.model))),
                    None => {
                        tracing::warn!(
                            hint = hint,
                            provider = route.provider_name,
                            "Route references unknown provider, skipping"
                        );
                        None
                    }
                }
            })
            .collect();
        Self {
            routes: resolved_routes,
            providers,
            default_index: 0,
            default_model,
        }
    }

    /// Resolve a model parameter to a (provider_index, actual_model) pair.
    ///
    /// If the model starts with "hint:", look up the hint in the route table;
    /// an unknown hint falls back to the default provider with the original
    /// (still "hint:"-prefixed) string as the model name.
    /// Otherwise, use the default provider with the given model name.
    fn resolve(&self, model: &str) -> (usize, String) {
        if let Some(hint) = model.strip_prefix("hint:") {
            if let Some((idx, resolved_model)) = self.routes.get(hint) {
                return (*idx, resolved_model.clone());
            }
            tracing::warn!(
                hint = hint,
                "Unknown route hint, falling back to default provider"
            );
        }
        // Not a hint or hint not found — use default provider with the model as-is
        (self.default_index, model.to_string())
    }
}
#[async_trait]
impl Provider for RouterProvider {
    /// Resolve the route and dispatch a one-shot chat, logging the chosen
    /// provider/model pair.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let (provider_idx, resolved_model) = self.resolve(model);
        let (provider_name, provider) = &self.providers[provider_idx];
        tracing::info!(
            provider = provider_name.as_str(),
            model = resolved_model.as_str(),
            "Router dispatching request"
        );
        provider
            .chat_with_system(system_prompt, message, &resolved_model, temperature)
            .await
    }

    /// Route-aware multi-turn chat. (Unlike chat_with_system, dispatch is
    /// not logged here.)
    async fn chat_with_history(
        &self,
        messages: &[ChatMessage],
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let (provider_idx, resolved_model) = self.resolve(model);
        let (_, provider) = &self.providers[provider_idx];
        provider
            .chat_with_history(messages, &resolved_model, temperature)
            .await
    }

    /// Route-aware structured chat (tool-calling capable request/response).
    async fn chat(
        &self,
        request: ChatRequest<'_>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ChatResponse> {
        let (provider_idx, resolved_model) = self.resolve(model);
        let (_, provider) = &self.providers[provider_idx];
        provider.chat(request, &resolved_model, temperature).await
    }

    /// Reports the DEFAULT provider's native-tool support only; routed
    /// providers may differ per request.
    fn supports_native_tools(&self) -> bool {
        self.providers
            .get(self.default_index)
            .map(|(_, p)| p.supports_native_tools())
            .unwrap_or(false)
    }

    /// Warm up every routed provider; individual failures are logged and
    /// swallowed so one cold provider doesn't block startup.
    async fn warmup(&self) -> anyhow::Result<()> {
        for (name, provider) in &self.providers {
            tracing::info!(provider = name, "Warming up routed provider");
            if let Err(e) = provider.warmup().await {
                tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
            }
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    /// Mock provider that counts calls and remembers the last model it saw.
    struct MockProvider {
        calls: Arc<AtomicUsize>,               // invocation counter
        response: &'static str,                // canned reply
        last_model: parking_lot::Mutex<String>, // model arg of the most recent call
    }

    impl MockProvider {
        fn new(response: &'static str) -> Self {
            Self {
                calls: Arc::new(AtomicUsize::new(0)),
                response,
                last_model: parking_lot::Mutex::new(String::new()),
            }
        }

        fn call_count(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }

        fn last_model(&self) -> String {
            self.last_model.lock().clone()
        }
    }

    #[async_trait]
    impl Provider for MockProvider {
        /// Record the call and the model, then return the canned response.
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_model.lock() = model.to_string();
            Ok(self.response.to_string())
        }
    }

    /// Build a RouterProvider from (name, response) provider specs and
    /// (hint, provider_name, model) route specs, returning the router plus
    /// Arc handles to each mock for later inspection.
    fn make_router(
        providers: Vec<(&'static str, &'static str)>,
        routes: Vec<(&str, &str, &str)>,
    ) -> (RouterProvider, Vec<Arc<MockProvider>>) {
        let mocks: Vec<Arc<MockProvider>> = providers
            .iter()
            .map(|(_, response)| Arc::new(MockProvider::new(response)))
            .collect();
        let provider_list: Vec<(String, Box<dyn Provider>)> = providers
            .iter()
            .zip(mocks.iter())
            .map(|((name, _), mock)| {
                (
                    name.to_string(),
                    Box::new(Arc::clone(mock)) as Box<dyn Provider>,
                )
            })
            .collect();
        let route_list: Vec<(String, Route)> = routes
            .iter()
            .map(|(hint, provider_name, model)| {
                (
                    hint.to_string(),
                    Route {
                        provider_name: provider_name.to_string(),
                        model: model.to_string(),
                    },
                )
            })
            .collect();
        let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());
        (router, mocks)
    }

    // Arc<MockProvider> should also be a Provider (the router owns the Box,
    // the test keeps the Arc for assertions).
    #[async_trait]
    impl Provider for Arc<MockProvider> {
        async fn chat_with_system(
            &self,
            system_prompt: Option<&str>,
            message: &str,
            model: &str,
            temperature: f64,
        ) -> anyhow::Result<String> {
            self.as_ref()
                .chat_with_system(system_prompt, message, model, temperature)
                .await
        }
    }

    #[tokio::test]
    async fn routes_hint_to_correct_provider() {
        // A "hint:" prefix selects the routed provider and its model.
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![
                ("fast", "fast", "llama-3-70b"),
                ("reasoning", "smart", "claude-opus"),
            ],
        );
        let result = router
            .simple_chat("hello", "hint:reasoning", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "smart-response");
        assert_eq!(mocks[1].call_count(), 1);
        assert_eq!(mocks[1].last_model(), "claude-opus");
        assert_eq!(mocks[0].call_count(), 0);
    }

    #[tokio::test]
    async fn routes_fast_hint() {
        let (router, mocks) = make_router(
            vec![("fast", "fast-response"), ("smart", "smart-response")],
            vec![("fast", "fast", "llama-3-70b")],
        );
        let result = router.simple_chat("hello", "hint:fast", 0.5).await.unwrap();
        assert_eq!(result, "fast-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "llama-3-70b");
    }

    #[tokio::test]
    async fn unknown_hint_falls_back_to_default() {
        // An unresolvable hint goes to the default provider, keeping the raw
        // "hint:..." string as the model.
        let (router, mocks) = make_router(
            vec![("default", "default-response"), ("other", "other-response")],
            vec![],
        );
        let result = router
            .simple_chat("hello", "hint:nonexistent", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "default-response");
        assert_eq!(mocks[0].call_count(), 1);
        // Falls back to default with the hint as model name
        assert_eq!(mocks[0].last_model(), "hint:nonexistent");
    }

    #[tokio::test]
    async fn non_hint_model_uses_default_provider() {
        // A plain model name bypasses routing entirely.
        let (router, mocks) = make_router(
            vec![
                ("primary", "primary-response"),
                ("secondary", "secondary-response"),
            ],
            vec![("code", "secondary", "codellama")],
        );
        let result = router
            .simple_chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "primary-response");
        assert_eq!(mocks[0].call_count(), 1);
        assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
    }

    #[test]
    fn resolve_preserves_model_for_non_hints() {
        let (router, _) = make_router(vec![("default", "ok")], vec![]);
        let (idx, model) = router.resolve("gpt-4o");
        assert_eq!(idx, 0);
        assert_eq!(model, "gpt-4o");
    }

    #[test]
    fn resolve_strips_hint_prefix() {
        let (router, _) = make_router(
            vec![("fast", "ok"), ("smart", "ok")],
            vec![("reasoning", "smart", "claude-opus")],
        );
        let (idx, model) = router.resolve("hint:reasoning");
        assert_eq!(idx, 1);
        assert_eq!(model, "claude-opus");
    }

    #[test]
    fn skips_routes_with_unknown_provider() {
        let (router, _) = make_router(
            vec![("default", "ok")],
            vec![("broken", "nonexistent", "model")],
        );
        // Route should not exist
        assert!(!router.routes.contains_key("broken"));
    }

    #[tokio::test]
    async fn warmup_calls_all_providers() {
        let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);
        // Warmup should not error
        assert!(router.warmup().await.is_ok());
    }

    #[tokio::test]
    async fn chat_with_system_passes_system_prompt() {
        let mock = Arc::new(MockProvider::new("response"));
        let router = RouterProvider::new(
            vec![(
                "default".into(),
                Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
            )],
            vec![],
            "model".into(),
        );
        let result = router
            .chat_with_system(Some("system"), "hello", "model", 0.5)
            .await
            .unwrap();
        assert_eq!(result, "response");
        assert_eq!(mock.call_count(), 1);
    }
}

View file

@ -1,12 +1,269 @@
use crate::tools::ToolSpec;
use async_trait::async_trait;
use futures_util::{stream, StreamExt};
use serde::{Deserialize, Serialize};
use std::fmt::Write;
/// A single message in a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Conversation role: "system", "user", "assistant", or "tool"
    /// (see the constructors below).
    pub role: String,
    /// Plain-text body of the message.
    pub content: String,
}
impl ChatMessage {
    /// Shared constructor used by the role-specific helpers below.
    fn with_role(role: &str, content: impl Into<String>) -> Self {
        Self {
            role: role.to_string(),
            content: content.into(),
        }
    }

    /// Build a "system"-role message.
    pub fn system(content: impl Into<String>) -> Self {
        Self::with_role("system", content)
    }

    /// Build a "user"-role message.
    pub fn user(content: impl Into<String>) -> Self {
        Self::with_role("user", content)
    }

    /// Build an "assistant"-role message.
    pub fn assistant(content: impl Into<String>) -> Self {
        Self::with_role("assistant", content)
    }

    /// Build a "tool"-role message.
    pub fn tool(content: impl Into<String>) -> Self {
        Self::with_role("tool", content)
    }
}
/// A tool call requested by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Provider-assigned call id, echoed back in the matching tool result.
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Arguments as a raw string — presumably JSON-encoded; verify against
    /// each provider's parsing before relying on the format.
    pub arguments: String,
}
/// An LLM response that may contain text, tool calls, or both.
#[derive(Debug, Clone)]
pub struct ChatResponse {
    /// Text content of the response (may be empty if only tool calls).
    pub text: Option<String>,
    /// Tool calls requested by the LLM; empty when the response is text-only.
    pub tool_calls: Vec<ToolCall>,
}
impl ChatResponse {
    /// True when the LLM wants to invoke at least one tool.
    pub fn has_tool_calls(&self) -> bool {
        self.tool_calls.first().is_some()
    }

    /// Convenience: the text content, or `""` when there was none.
    pub fn text_or_empty(&self) -> &str {
        match self.text.as_deref() {
            Some(text) => text,
            None => "",
        }
    }
}
/// Request payload for provider chat calls.
#[derive(Debug, Clone, Copy)]
pub struct ChatRequest<'a> {
    /// Full conversation so far, in order.
    pub messages: &'a [ChatMessage],
    /// Tool definitions to expose to the model; `None` disables tool use.
    pub tools: Option<&'a [ToolSpec]>,
}
/// A tool result to feed back to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMessage {
    /// Id of the `ToolCall` this result answers.
    pub tool_call_id: String,
    /// Tool output to feed back to the LLM.
    pub content: String,
}
/// A message in a multi-turn conversation, including tool interactions.
#[derive(Debug, Clone, Serialize, Deserialize)]
// Serialized as {"type": "...", "data": ...} so histories round-trip through
// storage without losing the variant.
#[serde(tag = "type", content = "data")]
pub enum ConversationMessage {
    /// Regular chat message (system, user, assistant).
    Chat(ChatMessage),
    /// Tool calls from the assistant (stored for history fidelity).
    AssistantToolCalls {
        text: Option<String>,
        tool_calls: Vec<ToolCall>,
    },
    /// Results of tool executions, fed back to the LLM.
    ToolResults(Vec<ToolResultMessage>),
}
/// A chunk of content from a streaming response.
#[derive(Debug, Clone)]
pub struct StreamChunk {
    /// Text delta for this chunk.
    pub delta: String,
    /// Whether this is the final chunk.
    pub is_final: bool,
    /// Approximate token count for this chunk (estimated; 0 unless
    /// `with_token_estimate` was applied).
    pub token_count: usize,
}
impl StreamChunk {
/// Create a new non-final chunk.
pub fn delta(text: impl Into<String>) -> Self {
Self {
delta: text.into(),
is_final: false,
token_count: 0,
}
}
/// Create a final chunk.
pub fn final_chunk() -> Self {
Self {
delta: String::new(),
is_final: true,
token_count: 0,
}
}
/// Create an error chunk.
pub fn error(message: impl Into<String>) -> Self {
Self {
delta: message.into(),
is_final: true,
token_count: 0,
}
}
/// Estimate tokens (rough approximation: ~4 chars per token).
pub fn with_token_estimate(mut self) -> Self {
self.token_count = self.delta.len().div_ceil(4);
self
}
}
/// Options for streaming chat requests.
#[derive(Debug, Clone, Copy, Default)]
pub struct StreamOptions {
    /// Whether to enable streaming. Note: the derived `Default` yields
    /// `false`; callers that want streaming must set this explicitly.
    pub enabled: bool,
    /// Whether to include token counts in chunks (off by default).
    pub count_tokens: bool,
}
impl StreamOptions {
/// Create new streaming options with enabled flag.
pub fn new(enabled: bool) -> Self {
Self {
enabled,
count_tokens: false,
}
}
/// Enable token counting.
pub fn with_token_count(mut self) -> Self {
self.count_tokens = true;
self
}
}
/// Result type for streaming operations.
pub type StreamResult<T> = std::result::Result<T, StreamError>;
/// Errors that can occur during streaming.
#[derive(Debug, thiserror::Error)]
pub enum StreamError {
#[error("HTTP error: {0}")]
Http(reqwest::Error),
#[error("JSON parse error: {0}")]
Json(serde_json::Error),
#[error("Invalid SSE format: {0}")]
InvalidSse(String),
#[error("Provider error: {0}")]
Provider(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
}
/// Provider capabilities declaration.
///
/// Describes what features a provider supports, enabling intelligent
/// adaptation of tool calling modes and request formatting.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct ProviderCapabilities {
    /// Whether the provider supports native tool calling via API primitives.
    ///
    /// When `true`, the provider can convert tool definitions to API-native
    /// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
    ///
    /// When `false`, tools must be injected via system prompt as text.
    ///
    /// The derived `Default` is `false` (prompt-guided tool calling).
    pub native_tool_calling: bool,
}
/// Provider-specific tool payload formats.
///
/// Different LLM providers require different formats for tool definitions.
/// This enum encapsulates those variations, enabling providers to convert
/// from the unified `ToolSpec` format to their native API requirements.
#[derive(Debug, Clone)]
pub enum ToolsPayload {
    /// Gemini API format (functionDeclarations).
    Gemini {
        function_declarations: Vec<serde_json::Value>,
    },
    /// Anthropic Messages API format (tools with input_schema).
    Anthropic { tools: Vec<serde_json::Value> },
    /// OpenAI Chat Completions API format (tools with function).
    OpenAI { tools: Vec<serde_json::Value> },
    /// Prompt-guided fallback (tools injected as text in system prompt);
    /// `instructions` is the pre-rendered text block.
    PromptGuided { instructions: String },
}
#[async_trait]
pub trait Provider: Send + Sync {
async fn chat(&self, message: &str, model: &str, temperature: f64) -> anyhow::Result<String> {
/// Query provider capabilities.
///
/// Default implementation returns minimal capabilities (no native tool calling).
/// Providers should override this to declare their actual capabilities.
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities::default()
}
/// Convert tool specifications to provider-native format.
///
/// Default implementation returns `PromptGuided` payload, which injects
/// tool documentation into the system prompt as text. Providers with
/// native tool calling support should override this to return their
/// specific format (Gemini, Anthropic, OpenAI).
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
ToolsPayload::PromptGuided {
instructions: build_tool_instructions_text(tools),
}
}
/// Simple one-shot chat (single user message, no explicit system prompt).
///
/// This is the preferred API for non-agentic direct interactions.
async fn simple_chat(
&self,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
self.chat_with_system(None, message, model, temperature)
.await
}
/// One-shot chat with optional system prompt.
///
/// Kept for compatibility and advanced one-shot prompting.
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
@ -14,4 +271,605 @@ pub trait Provider: Send + Sync {
model: &str,
temperature: f64,
) -> anyhow::Result<String>;
/// Multi-turn conversation. Default implementation extracts the last user
/// message and delegates to `chat_with_system`.
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let system = messages
.iter()
.find(|m| m.role == "system")
.map(|m| m.content.as_str());
let last_user = messages
.iter()
.rfind(|m| m.role == "user")
.map(|m| m.content.as_str())
.unwrap_or("");
self.chat_with_system(system, last_user, model, temperature)
.await
}
/// Structured chat API for agent loop callers.
async fn chat(
&self,
request: ChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
// If tools are provided but provider doesn't support native tools,
// inject tool instructions into system prompt as fallback.
if let Some(tools) = request.tools {
if !tools.is_empty() && !self.supports_native_tools() {
let tool_instructions = match self.convert_tools(tools) {
ToolsPayload::PromptGuided { instructions } => instructions,
payload => {
anyhow::bail!(
"Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
)
}
};
let mut modified_messages = request.messages.to_vec();
// Inject tool instructions into an existing system message.
// If none exists, prepend one to the conversation.
if let Some(system_message) =
modified_messages.iter_mut().find(|m| m.role == "system")
{
if !system_message.content.is_empty() {
system_message.content.push_str("\n\n");
}
system_message.content.push_str(&tool_instructions);
} else {
modified_messages.insert(0, ChatMessage::system(tool_instructions));
}
let text = self
.chat_with_history(&modified_messages, model, temperature)
.await?;
return Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
});
}
}
let text = self
.chat_with_history(request.messages, model, temperature)
.await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
})
}
/// Whether provider supports native tool calls over API.
fn supports_native_tools(&self) -> bool {
self.capabilities().native_tool_calling
}
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
/// Default implementation is a no-op; providers with HTTP clients should override.
async fn warmup(&self) -> anyhow::Result<()> {
Ok(())
}
/// Chat with tool definitions for native function calling support.
/// The default implementation falls back to chat_with_history and returns
/// an empty tool_calls vector (prompt-based tool use only).
async fn chat_with_tools(
&self,
messages: &[ChatMessage],
_tools: &[serde_json::Value],
model: &str,
temperature: f64,
) -> anyhow::Result<ChatResponse> {
let text = self.chat_with_history(messages, model, temperature).await?;
Ok(ChatResponse {
text: Some(text),
tool_calls: Vec::new(),
})
}
/// Whether provider supports streaming responses.
/// Default implementation returns false.
fn supports_streaming(&self) -> bool {
false
}
/// Streaming chat with optional system prompt.
/// Returns an async stream of text chunks.
/// Default implementation falls back to non-streaming chat.
fn stream_chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
_options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
// Default: return an empty stream (not supported)
stream::empty().boxed()
}
/// Streaming chat with history.
/// Default implementation falls back to stream_chat_with_system with last user message.
fn stream_chat_with_history(
&self,
_messages: &[ChatMessage],
_model: &str,
_temperature: f64,
_options: StreamOptions,
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
// For default implementation, we need to convert to owned strings
// This is a limitation of the default implementation
let provider_name = "unknown".to_string();
// Create a single empty chunk to indicate not supported
let chunk = StreamChunk::error(format!("{} does not support streaming", provider_name));
stream::once(async move { Ok(chunk) }).boxed()
}
}
/// Build tool instructions text for prompt-guided tool calling.
///
/// Generates a formatted text block describing available tools and how to
/// invoke them using XML-style tags. This is used as a fallback when the
/// provider doesn't support native tool calling.
pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
    // The protocol preamble is fully static, so assemble it at compile time.
    let mut out = String::from(concat!(
        "## Tool Use Protocol\n\n",
        "To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n",
        "<tool_call>\n",
        r#"{"name": "tool_name", "arguments": {"param": "value"}}"#,
        "\n</tool_call>\n\n",
        "You may use multiple tool calls in a single response. ",
        "After tool execution, results appear in <tool_result> tags. ",
        "Continue reasoning with the results until you can give a final answer.\n\n",
        "### Available Tools\n\n",
    ));
    // One entry per tool: bold name, description, and its JSON parameter schema.
    for spec in tools {
        let schema = serde_json::to_string(&spec.parameters).unwrap_or_else(|_| "{}".to_string());
        write!(
            &mut out,
            "**{}**: {}\nParameters: `{}`\n\n",
            spec.name, spec.description, schema
        )
        .expect("writing to String cannot fail");
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;
    // Minimal provider whose only customization is advertising native
    // tool-calling support through `capabilities()`.
    struct CapabilityMockProvider;
    #[async_trait]
    impl Provider for CapabilityMockProvider {
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                native_tool_calling: true,
            }
        }
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".into())
        }
    }
    // Each ChatMessage constructor must tag the message with the matching role.
    #[test]
    fn chat_message_constructors() {
        let sys = ChatMessage::system("Be helpful");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "Be helpful");
        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");
        let asst = ChatMessage::assistant("Hi there");
        assert_eq!(asst.role, "assistant");
        let tool = ChatMessage::tool("{}");
        assert_eq!(tool.role, "tool");
    }
    // has_tool_calls() and text_or_empty() accessors on ChatResponse.
    #[test]
    fn chat_response_helpers() {
        let empty = ChatResponse {
            text: None,
            tool_calls: vec![],
        };
        assert!(!empty.has_tool_calls());
        assert_eq!(empty.text_or_empty(), "");
        let with_tools = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
        };
        assert!(with_tools.has_tool_calls());
        assert_eq!(with_tools.text_or_empty(), "Let me check");
    }
    // ToolCall must round-trip its fields through serde JSON output.
    #[test]
    fn tool_call_serialization() {
        let tc = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("call_123"));
        assert!(json.contains("file_read"));
    }
    // ConversationMessage uses an internally tagged enum representation
    // ("type" discriminator) — pin that wire format here.
    #[test]
    fn conversation_message_variants() {
        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
        let json = serde_json::to_string(&chat).unwrap();
        assert!(json.contains("\"type\":\"Chat\""));
        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }]);
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains("\"type\":\"ToolResults\""));
    }
    // By default a provider does NOT advertise native tool calling.
    #[test]
    fn provider_capabilities_default() {
        let caps = ProviderCapabilities::default();
        assert!(!caps.native_tool_calling);
    }
    #[test]
    fn provider_capabilities_equality() {
        let caps1 = ProviderCapabilities {
            native_tool_calling: true,
        };
        let caps2 = ProviderCapabilities {
            native_tool_calling: true,
        };
        let caps3 = ProviderCapabilities {
            native_tool_calling: false,
        };
        assert_eq!(caps1, caps2);
        assert_ne!(caps1, caps3);
    }
    // The default supports_native_tools() must mirror capabilities().
    #[test]
    fn supports_native_tools_reflects_capabilities_default_mapping() {
        let provider = CapabilityMockProvider;
        assert!(provider.supports_native_tools());
    }
    // Smoke-test construction of every ToolsPayload variant.
    #[test]
    fn tools_payload_variants() {
        // Test Gemini variant
        let gemini = ToolsPayload::Gemini {
            function_declarations: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(gemini, ToolsPayload::Gemini { .. }));
        // Test Anthropic variant
        let anthropic = ToolsPayload::Anthropic {
            tools: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));
        // Test OpenAI variant
        let openai = ToolsPayload::OpenAI {
            tools: vec![serde_json::json!({"type": "function"})],
        };
        assert!(matches!(openai, ToolsPayload::OpenAI { .. }));
        // Test PromptGuided variant
        let prompt_guided = ToolsPayload::PromptGuided {
            instructions: "Use tools...".to_string(),
        };
        assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
    }
    // The generated instructions must describe the protocol and list every tool.
    #[test]
    fn build_tool_instructions_text_format() {
        let tools = vec![
            ToolSpec {
                name: "shell".to_string(),
                description: "Execute commands".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": {"type": "string"}
                    }
                }),
            },
            ToolSpec {
                name: "file_read".to_string(),
                description: "Read files".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": {"type": "string"}
                    }
                }),
            },
        ];
        let instructions = build_tool_instructions_text(&tools);
        // Check for protocol description
        assert!(instructions.contains("Tool Use Protocol"));
        assert!(instructions.contains("<tool_call>"));
        assert!(instructions.contains("</tool_call>"));
        // Check for tool listings
        assert!(instructions.contains("**shell**"));
        assert!(instructions.contains("Execute commands"));
        assert!(instructions.contains("**file_read**"));
        assert!(instructions.contains("Read files"));
        // Check for parameters
        assert!(instructions.contains("Parameters:"));
        assert!(instructions.contains(r#""type":"object""#));
    }
    // Even with no tools the protocol preamble and section header remain.
    #[test]
    fn build_tool_instructions_text_empty() {
        let instructions = build_tool_instructions_text(&[]);
        // Should still have protocol description
        assert!(instructions.contains("Tool Use Protocol"));
        // Should have empty tools section
        assert!(instructions.contains("Available Tools"));
    }
    // Mock provider for testing.
    struct MockProvider {
        supports_native: bool,
    }
    #[async_trait]
    impl Provider for MockProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }
        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("response".to_string())
        }
    }
    // Without an override, convert_tools must produce prompt-guided
    // instructions containing each tool's name and description.
    #[test]
    fn provider_convert_tools_default() {
        let provider = MockProvider {
            supports_native: false,
        };
        let tools = vec![ToolSpec {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let payload = provider.convert_tools(&tools);
        // Default implementation should return PromptGuided.
        assert!(matches!(payload, ToolsPayload::PromptGuided { .. }));
        if let ToolsPayload::PromptGuided { instructions } = payload {
            assert!(instructions.contains("test_tool"));
            assert!(instructions.contains("A test tool"));
        }
    }
    // A non-native provider given tools still answers via the prompt-guided path.
    #[tokio::test]
    async fn provider_chat_prompt_guided_fallback() {
        let provider = MockProvider {
            supports_native: false,
        };
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        // Should return a response (default impl calls chat_with_history).
        assert!(response.text.is_some());
    }
    #[tokio::test]
    async fn provider_chat_without_tools() {
        let provider = MockProvider {
            supports_native: true,
        };
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: None,
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        // Should work normally without tools.
        assert!(response.text.is_some());
    }
    // Provider that echoes the system prompt for assertions.
    struct EchoSystemProvider {
        supports_native: bool,
    }
    #[async_trait]
    impl Provider for EchoSystemProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }
        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }
    // Provider with custom prompt-guided conversion.
    struct CustomConvertProvider;
    #[async_trait]
    impl Provider for CustomConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }
        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::PromptGuided {
                instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
            }
        }
        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }
    // Provider returning an invalid payload for non-native mode.
    struct InvalidConvertProvider;
    #[async_trait]
    impl Provider for InvalidConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }
        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::OpenAI {
                tools: vec![serde_json::json!({"type": "function"})],
            }
        }
        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("should_not_reach".to_string())
        }
    }
    // A pre-existing system message is merged with the tool instructions even
    // when it is not the first message in the conversation.
    #[tokio::test]
    async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
        let provider = EchoSystemProvider {
            supports_native: false,
        };
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[
                ChatMessage::user("Hello"),
                ChatMessage::system("BASE_SYSTEM_PROMPT"),
            ],
            tools: Some(&tools),
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();
        assert!(text.contains("BASE_SYSTEM_PROMPT"));
        assert!(text.contains("Tool Use Protocol"));
    }
    // A convert_tools override supplies the instructions used for the
    // prompt-guided path instead of the built-in text.
    #[tokio::test]
    async fn provider_chat_prompt_guided_uses_convert_tools_override() {
        let provider = CustomConvertProvider;
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
            tools: Some(&tools),
        };
        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();
        assert!(text.contains("BASE"));
        assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
    }
    // convert_tools returning a native payload while supports_native_tools()
    // is false must surface an error, not silently proceed.
    #[tokio::test]
    async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
        let provider = InvalidConvertProvider;
        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];
        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };
        let err = provider.chat(request, "model", 0.7).await.unwrap_err();
        let message = err.to_string();
        assert!(message.contains("non-prompt-guided"));
    }
}