Merge remote-tracking branch 'origin/main' into feat/glm-provider
Resolved conflicts in: - Cargo.toml: kept both `ring` (JWT auth) and `prost` (protobuf) dependencies - src/onboard/wizard.rs: accepted main branch version - src/providers/mod.rs: accepted main branch version - Cargo.lock: accepted main branch version Note: The custom `glm::GlmProvider` from this PR was replaced with main's OpenAiCompatibleProvider approach for GLM, which uses base URLs. The main purpose of this PR is Windows daemon support via Task Scheduler. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
commit
34af6a223a
269 changed files with 68574 additions and 2541 deletions
|
|
@ -1,10 +1,15 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct AnthropicProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
base_url: String,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -31,13 +36,91 @@ struct ChatResponse {
|
|||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ContentBlock {
|
||||
text: String,
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
#[serde(default)]
|
||||
text: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeChatRequest {
|
||||
model: String,
|
||||
max_tokens: u32,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
system: Option<String>,
|
||||
messages: Vec<NativeMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeMessage {
|
||||
role: String,
|
||||
content: Vec<NativeContentOut>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
#[serde(tag = "type")]
|
||||
enum NativeContentOut {
|
||||
#[serde(rename = "text")]
|
||||
Text { text: String },
|
||||
#[serde(rename = "tool_use")]
|
||||
ToolUse {
|
||||
id: String,
|
||||
name: String,
|
||||
input: serde_json::Value,
|
||||
},
|
||||
#[serde(rename = "tool_result")]
|
||||
ToolResult {
|
||||
tool_use_id: String,
|
||||
content: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
input_schema: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeChatResponse {
|
||||
#[serde(default)]
|
||||
content: Vec<NativeContentIn>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeContentIn {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
#[serde(default)]
|
||||
text: Option<String>,
|
||||
#[serde(default)]
|
||||
id: Option<String>,
|
||||
#[serde(default)]
|
||||
name: Option<String>,
|
||||
#[serde(default)]
|
||||
input: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl AnthropicProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self::with_base_url(credential, None)
|
||||
}
|
||||
|
||||
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
|
||||
let base_url = base_url
|
||||
.map(|u| u.trim_end_matches('/'))
|
||||
.unwrap_or("https://api.anthropic.com")
|
||||
.to_string();
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential
|
||||
.map(str::trim)
|
||||
.filter(|k| !k.is_empty())
|
||||
.map(ToString::to_string),
|
||||
base_url,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -45,6 +128,192 @@ impl AnthropicProvider {
|
|||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn is_setup_token(token: &str) -> bool {
|
||||
token.starts_with("sk-ant-oat01-")
|
||||
}
|
||||
|
||||
fn apply_auth(
|
||||
&self,
|
||||
request: reqwest::RequestBuilder,
|
||||
credential: &str,
|
||||
) -> reqwest::RequestBuilder {
|
||||
if Self::is_setup_token(credential) {
|
||||
request
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header("anthropic-beta", "oauth-2025-04-20")
|
||||
} else {
|
||||
request.header("x-api-key", credential)
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
let items = tools?;
|
||||
if items.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
input_schema: tool.parameters.clone(),
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn parse_assistant_tool_call_message(content: &str) -> Option<Vec<NativeContentOut>> {
|
||||
let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
|
||||
let tool_calls = value
|
||||
.get("tool_calls")
|
||||
.and_then(|v| serde_json::from_value::<Vec<ProviderToolCall>>(v.clone()).ok())?;
|
||||
|
||||
let mut blocks = Vec::new();
|
||||
if let Some(text) = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(str::trim)
|
||||
.filter(|t| !t.is_empty())
|
||||
{
|
||||
blocks.push(NativeContentOut::Text {
|
||||
text: text.to_string(),
|
||||
});
|
||||
}
|
||||
for call in tool_calls {
|
||||
let input = serde_json::from_str::<serde_json::Value>(&call.arguments)
|
||||
.unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new()));
|
||||
blocks.push(NativeContentOut::ToolUse {
|
||||
id: call.id,
|
||||
name: call.name,
|
||||
input,
|
||||
});
|
||||
}
|
||||
Some(blocks)
|
||||
}
|
||||
|
||||
fn parse_tool_result_message(content: &str) -> Option<NativeMessage> {
|
||||
let value = serde_json::from_str::<serde_json::Value>(content).ok()?;
|
||||
let tool_use_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)?
|
||||
.to_string();
|
||||
let result = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("")
|
||||
.to_string();
|
||||
Some(NativeMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![NativeContentOut::ToolResult {
|
||||
tool_use_id,
|
||||
content: result,
|
||||
}],
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> (Option<String>, Vec<NativeMessage>) {
|
||||
let mut system_prompt = None;
|
||||
let mut native_messages = Vec::new();
|
||||
|
||||
for msg in messages {
|
||||
match msg.role.as_str() {
|
||||
"system" => {
|
||||
if system_prompt.is_none() {
|
||||
system_prompt = Some(msg.content.clone());
|
||||
}
|
||||
}
|
||||
"assistant" => {
|
||||
if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) {
|
||||
native_messages.push(NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: blocks,
|
||||
});
|
||||
} else {
|
||||
native_messages.push(NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content: vec![NativeContentOut::Text {
|
||||
text: msg.content.clone(),
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
"tool" => {
|
||||
if let Some(tool_result) = Self::parse_tool_result_message(&msg.content) {
|
||||
native_messages.push(tool_result);
|
||||
} else {
|
||||
native_messages.push(NativeMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![NativeContentOut::Text {
|
||||
text: msg.content.clone(),
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
native_messages.push(NativeMessage {
|
||||
role: "user".to_string(),
|
||||
content: vec![NativeContentOut::Text {
|
||||
text: msg.content.clone(),
|
||||
}],
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
(system_prompt, native_messages)
|
||||
}
|
||||
|
||||
fn parse_text_response(response: ChatResponse) -> anyhow::Result<String> {
|
||||
response
|
||||
.content
|
||||
.into_iter()
|
||||
.find(|c| c.kind == "text")
|
||||
.and_then(|c| c.text)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
|
||||
}
|
||||
|
||||
fn parse_native_response(response: NativeChatResponse) -> ProviderChatResponse {
|
||||
let mut text_parts = Vec::new();
|
||||
let mut tool_calls = Vec::new();
|
||||
|
||||
for block in response.content {
|
||||
match block.kind.as_str() {
|
||||
"text" => {
|
||||
if let Some(text) = block.text.map(|t| t.trim().to_string()) {
|
||||
if !text.is_empty() {
|
||||
text_parts.push(text);
|
||||
}
|
||||
}
|
||||
}
|
||||
"tool_use" => {
|
||||
let name = block.name.unwrap_or_default();
|
||||
if name.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let arguments = block
|
||||
.input
|
||||
.unwrap_or_else(|| serde_json::Value::Object(serde_json::Map::new()));
|
||||
tool_calls.push(ProviderToolCall {
|
||||
id: block.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name,
|
||||
arguments: arguments.to_string(),
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
|
||||
ProviderChatResponse {
|
||||
text: if text_parts.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(text_parts.join("\n"))
|
||||
},
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -56,8 +325,10 @@ impl Provider for AnthropicProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("Anthropic API key not set. Set ANTHROPIC_API_KEY or edit config.toml.")
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)."
|
||||
)
|
||||
})?;
|
||||
|
||||
let request = ChatRequest {
|
||||
|
|
@ -71,29 +342,65 @@ impl Provider for AnthropicProvider {
|
|||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
let mut request = self
|
||||
.client
|
||||
.post("https://api.anthropic.com/v1/messages")
|
||||
.header("x-api-key", api_key)
|
||||
.post(format!("{}/v1/messages", self.base_url))
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
.json(&request);
|
||||
|
||||
request = self.apply_auth(request, credential);
|
||||
|
||||
let response = request.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("Anthropic API error: {error}");
|
||||
return Err(super::api_error("Anthropic", response).await);
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
Self::parse_text_response(chat_response)
|
||||
}
|
||||
|
||||
chat_response
|
||||
.content
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.text)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Anthropic credentials not set. Set ANTHROPIC_API_KEY or ANTHROPIC_OAUTH_TOKEN (setup-token)."
|
||||
)
|
||||
})?;
|
||||
|
||||
let (system_prompt, messages) = Self::convert_messages(request.messages);
|
||||
let native_request = NativeChatRequest {
|
||||
model: model.to_string(),
|
||||
max_tokens: 4096,
|
||||
system: system_prompt,
|
||||
messages,
|
||||
temperature,
|
||||
tools: Self::convert_tools(request.tools),
|
||||
};
|
||||
|
||||
let req = self
|
||||
.client
|
||||
.post(format!("{}/v1/messages", self.base_url))
|
||||
.header("anthropic-version", "2023-06-01")
|
||||
.header("content-type", "application/json")
|
||||
.json(&native_request);
|
||||
|
||||
let response = self.apply_auth(req, credential).send().await?;
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("Anthropic", response).await);
|
||||
}
|
||||
|
||||
let native_response: NativeChatResponse = response.json().await?;
|
||||
Ok(Self::parse_native_response(native_response))
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -103,22 +410,52 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = AnthropicProvider::new(Some("sk-ant-test123"));
|
||||
assert!(p.api_key.is_some());
|
||||
assert_eq!(p.api_key.as_deref(), Some("sk-ant-test123"));
|
||||
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = AnthropicProvider::new(None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_key() {
|
||||
let p = AnthropicProvider::new(Some(""));
|
||||
assert!(p.api_key.is_some());
|
||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_whitespace_key() {
|
||||
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_custom_base_url() {
|
||||
let p = AnthropicProvider::with_base_url(
|
||||
Some("anthropic-credential"),
|
||||
Some("https://api.example.com"),
|
||||
);
|
||||
assert_eq!(p.base_url, "https://api.example.com");
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_base_url_trims_trailing_slash() {
|
||||
let p = AnthropicProvider::with_base_url(None, Some("https://api.example.com/"));
|
||||
assert_eq!(p.base_url, "https://api.example.com");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_base_url_when_none_provided() {
|
||||
let p = AnthropicProvider::with_base_url(None, None);
|
||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -130,11 +467,67 @@ mod tests {
|
|||
assert!(result.is_err());
|
||||
let err = result.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("API key not set"),
|
||||
err.contains("credentials not set"),
|
||||
"Expected key error, got: {err}"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn setup_token_detection_works() {
|
||||
assert!(AnthropicProvider::is_setup_token("sk-ant-oat01-abcdef"));
|
||||
assert!(!AnthropicProvider::is_setup_token("sk-ant-api-key"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_auth_uses_bearer_and_beta_for_setup_tokens() {
|
||||
let provider = AnthropicProvider::new(None);
|
||||
let request = provider
|
||||
.apply_auth(
|
||||
provider.client.get("https://api.anthropic.com/v1/models"),
|
||||
"sk-ant-oat01-test-token",
|
||||
)
|
||||
.build()
|
||||
.expect("request should build");
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.headers()
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok()),
|
||||
Some("Bearer sk-ant-oat01-test-token")
|
||||
);
|
||||
assert_eq!(
|
||||
request
|
||||
.headers()
|
||||
.get("anthropic-beta")
|
||||
.and_then(|v| v.to_str().ok()),
|
||||
Some("oauth-2025-04-20")
|
||||
);
|
||||
assert!(request.headers().get("x-api-key").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn apply_auth_uses_x_api_key_for_regular_tokens() {
|
||||
let provider = AnthropicProvider::new(None);
|
||||
let request = provider
|
||||
.apply_auth(
|
||||
provider.client.get("https://api.anthropic.com/v1/models"),
|
||||
"sk-ant-api-key",
|
||||
)
|
||||
.build()
|
||||
.expect("request should build");
|
||||
|
||||
assert_eq!(
|
||||
request
|
||||
.headers()
|
||||
.get("x-api-key")
|
||||
.and_then(|v| v.to_str().ok()),
|
||||
Some("sk-ant-api-key")
|
||||
);
|
||||
assert!(request.headers().get("authorization").is_none());
|
||||
assert!(request.headers().get("anthropic-beta").is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_system_fails_without_key() {
|
||||
let p = AnthropicProvider::new(None);
|
||||
|
|
@ -186,7 +579,8 @@ mod tests {
|
|||
let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.content.len(), 1);
|
||||
assert_eq!(resp.content[0].text, "Hello there!");
|
||||
assert_eq!(resp.content[0].kind, "text");
|
||||
assert_eq!(resp.content[0].text.as_deref(), Some("Hello there!"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -202,8 +596,8 @@ mod tests {
|
|||
r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.content.len(), 2);
|
||||
assert_eq!(resp.content[0].text, "First");
|
||||
assert_eq!(resp.content[1].text, "Second");
|
||||
assert_eq!(resp.content[0].text.as_deref(), Some("First"));
|
||||
assert_eq!(resp.content[1].text.as_deref(), Some("Second"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
705
src/providers/copilot.rs
Normal file
705
src/providers/copilot.rs
Normal file
|
|
@ -0,0 +1,705 @@
|
|||
//! GitHub Copilot provider with OAuth device-flow authentication.
|
||||
//!
|
||||
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
|
||||
//! then exchanges the OAuth token for short-lived Copilot API keys.
|
||||
//! Tokens are cached to disk and auto-refreshed.
|
||||
//!
|
||||
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
|
||||
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
|
||||
//! and other third-party Copilot integrations. The Copilot token endpoint is
|
||||
//! private; there is no public OAuth scope or app registration for it.
|
||||
//! GitHub could change or revoke this at any time, which would break all
|
||||
//! third-party integrations simultaneously.
|
||||
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
/// GitHub OAuth client ID for Copilot (VS Code extension).
|
||||
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
|
||||
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
|
||||
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
|
||||
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
|
||||
const DEFAULT_API: &str = "https://api.githubcopilot.com";
|
||||
|
||||
// ── Token types ──────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default = "default_interval")]
|
||||
interval: u64,
|
||||
#[serde(default = "default_expires_in")]
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
fn default_interval() -> u64 {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_expires_in() -> u64 {
|
||||
900
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AccessTokenResponse {
|
||||
access_token: Option<String>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiKeyInfo {
|
||||
token: String,
|
||||
expires_at: i64,
|
||||
#[serde(default)]
|
||||
endpoints: Option<ApiEndpoints>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiEndpoints {
|
||||
api: Option<String>,
|
||||
}
|
||||
|
||||
struct CachedApiKey {
|
||||
token: String,
|
||||
api_endpoint: String,
|
||||
expires_at: i64,
|
||||
}
|
||||
|
||||
// ── Chat completions types ───────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiChatRequest {
|
||||
model: String,
|
||||
messages: Vec<ApiMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolSpec {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: NativeToolFunctionSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolFunctionSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
function: NativeFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
// ── Provider ─────────────────────────────────────────────────────
|
||||
|
||||
/// GitHub Copilot provider with automatic OAuth and token refresh.
|
||||
///
|
||||
/// On first use, prompts the user to visit github.com/login/device.
|
||||
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
|
||||
/// automatically.
|
||||
pub struct CopilotProvider {
|
||||
github_token: Option<String>,
|
||||
/// Mutex ensures only one caller refreshes tokens at a time,
|
||||
/// preventing duplicate device flow prompts or redundant API calls.
|
||||
refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
|
||||
http: Client,
|
||||
token_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl CopilotProvider {
|
||||
pub fn new(github_token: Option<&str>) -> Self {
|
||||
let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
|
||||
.map(|dir| dir.config_dir().join("copilot"))
|
||||
.unwrap_or_else(|| {
|
||||
// Fall back to a user-specific temp directory to avoid
|
||||
// shared-directory symlink attacks.
|
||||
let user = std::env::var("USER")
|
||||
.or_else(|_| std::env::var("USERNAME"))
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
|
||||
});
|
||||
|
||||
if let Err(err) = std::fs::create_dir_all(&token_dir) {
|
||||
warn!(
|
||||
"Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
|
||||
token_dir
|
||||
);
|
||||
} else {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
if let Err(err) =
|
||||
std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
|
||||
{
|
||||
warn!(
|
||||
"Failed to set Copilot token directory permissions on {:?}: {err}",
|
||||
token_dir
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
github_token: github_token
|
||||
.filter(|token| !token.is_empty())
|
||||
.map(String::from),
|
||||
refresh_lock: Arc::new(Mutex::new(None)),
|
||||
http: Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.connect_timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
token_dir,
|
||||
}
|
||||
}
|
||||
|
||||
/// Required headers for Copilot API requests (editor identification).
|
||||
const COPILOT_HEADERS: [(&str, &str); 4] = [
|
||||
("Editor-Version", "vscode/1.85.1"),
|
||||
("Editor-Plugin-Version", "copilot/1.155.0"),
|
||||
("User-Agent", "GithubCopilot/1.155.0"),
|
||||
("Accept", "application/json"),
|
||||
];
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
if message.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| NativeToolCall {
|
||||
id: Some(tool_call.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: NativeFunctionCall {
|
||||
name: tool_call.name,
|
||||
arguments: tool_call.arguments,
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if message.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
ApiMessage {
|
||||
role: message.role.clone(),
|
||||
content: Some(message.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Send a chat completions request with required Copilot headers.
|
||||
async fn send_chat_request(
|
||||
&self,
|
||||
messages: Vec<ApiMessage>,
|
||||
tools: Option<&[ToolSpec]>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let (token, endpoint) = self.get_api_key().await?;
|
||||
let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
|
||||
|
||||
let native_tools = Self::convert_tools(tools);
|
||||
let request = ApiChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools: native_tools,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&request);
|
||||
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
req = req.header(*header, *value);
|
||||
}
|
||||
|
||||
let response = req.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("GitHub Copilot", response).await);
|
||||
}
|
||||
|
||||
let api_response: ApiChatResponse = response.json().await?;
|
||||
let choice = api_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
|
||||
|
||||
let tool_calls = choice
|
||||
.message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|tool_call| ProviderToolCall {
|
||||
id: tool_call
|
||||
.id
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name: tool_call.function.name,
|
||||
arguments: tool_call.function.arguments,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ProviderChatResponse {
|
||||
text: choice.message.content,
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
|
||||
/// Uses a Mutex to ensure only one caller refreshes at a time.
|
||||
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
|
||||
let mut cached = self.refresh_lock.lock().await;
|
||||
|
||||
if let Some(cached_key) = cached.as_ref() {
|
||||
if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
|
||||
return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(info) = self.load_api_key_from_disk().await {
|
||||
if chrono::Utc::now().timestamp() + 120 < info.expires_at {
|
||||
let endpoint = info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
let token = info.token;
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: info.expires_at,
|
||||
});
|
||||
return Ok((token, endpoint));
|
||||
}
|
||||
}
|
||||
|
||||
let access_token = self.get_github_access_token().await?;
|
||||
let api_key_info = self.exchange_for_api_key(&access_token).await?;
|
||||
self.save_api_key_to_disk(&api_key_info).await;
|
||||
|
||||
let endpoint = api_key_info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: api_key_info.token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: api_key_info.expires_at,
|
||||
});
|
||||
|
||||
Ok((api_key_info.token, endpoint))
|
||||
}
|
||||
|
||||
/// Get a GitHub access token from config, cache, or device flow.
|
||||
async fn get_github_access_token(&self) -> anyhow::Result<String> {
|
||||
if let Some(token) = &self.github_token {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
|
||||
let token = cached.trim();
|
||||
if !token.is_empty() {
|
||||
return Ok(token.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let token = self.device_code_login().await?;
|
||||
write_file_secure(&access_token_path, &token).await;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Run GitHub OAuth device code flow.
|
||||
async fn device_code_login(&self) -> anyhow::Result<String> {
|
||||
let response: DeviceCodeResponse = self
|
||||
.http
|
||||
.post(GITHUB_DEVICE_CODE_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"scope": "read:user"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let mut poll_interval = Duration::from_secs(response.interval.max(5));
|
||||
let expires_in = response.expires_in.max(1);
|
||||
let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
|
||||
|
||||
eprintln!(
|
||||
"\nGitHub Copilot authentication is required.\n\
|
||||
Visit: {}\n\
|
||||
Code: {}\n\
|
||||
Waiting for authorization...\n",
|
||||
response.verification_uri, response.user_code
|
||||
);
|
||||
|
||||
while tokio::time::Instant::now() < expires_at {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let token_response: AccessTokenResponse = self
|
||||
.http
|
||||
.post(GITHUB_ACCESS_TOKEN_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"device_code": response.device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if let Some(token) = token_response.access_token {
|
||||
eprintln!("Authentication succeeded.\n");
|
||||
return Ok(token);
|
||||
}
|
||||
|
||||
match token_response.error.as_deref() {
|
||||
Some("slow_down") => {
|
||||
poll_interval += Duration::from_secs(5);
|
||||
}
|
||||
Some("authorization_pending") | None => {}
|
||||
Some("expired_token") => {
|
||||
anyhow::bail!("GitHub device authorization expired")
|
||||
}
|
||||
Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("Timed out waiting for GitHub authorization")
|
||||
}
|
||||
|
||||
/// Exchange a GitHub access token for a Copilot API key.
|
||||
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
|
||||
let mut request = self.http.get(GITHUB_API_KEY_URL);
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
request = request.header(*header, *value);
|
||||
}
|
||||
request = request.header("Authorization", format!("token {access_token}"));
|
||||
|
||||
let response = request.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let sanitized = super::sanitize_api_error(&body);
|
||||
|
||||
if status.as_u16() == 401 || status.as_u16() == 403 {
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
tokio::fs::remove_file(&access_token_path).await.ok();
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to get Copilot API key ({status}): {sanitized}. \
|
||||
Ensure your GitHub account has an active Copilot subscription."
|
||||
);
|
||||
}
|
||||
|
||||
let info: ApiKeyInfo = response.json().await?;
|
||||
Ok(info)
|
||||
}
|
||||
|
||||
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
let data = tokio::fs::read_to_string(&path).await.ok()?;
|
||||
serde_json::from_str(&data).ok()
|
||||
}
|
||||
|
||||
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
if let Ok(json) = serde_json::to_string_pretty(info) {
|
||||
write_file_secure(&path, &json).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a file with 0600 permissions (owner read/write only).
|
||||
/// Uses `spawn_blocking` to avoid blocking the async runtime.
|
||||
async fn write_file_secure(path: &Path, content: &str) {
|
||||
let path = path.to_path_buf();
|
||||
let content = content.to_string();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::io::Write;
|
||||
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
|
||||
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600)
|
||||
.open(&path)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
|
||||
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
std::fs::write(&path, &content)?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
|
||||
Err(err) => warn!("Failed to spawn blocking write: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CopilotProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut messages = Vec::new();
|
||||
if let Some(system) = system_prompt {
|
||||
messages.push(ApiMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
messages.push(ApiMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(message.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.send_chat_request(messages, None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let response = self
|
||||
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
self.send_chat_request(
|
||||
Self::convert_messages(request.messages),
|
||||
request.tools,
|
||||
model,
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
let _ = self.get_api_key().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Offline construction/header tests — no network calls.
    use super::*;

    #[test]
    fn new_without_token() {
        assert!(CopilotProvider::new(None).github_token.is_none());
    }

    #[test]
    fn new_with_token() {
        let provider = CopilotProvider::new(Some("ghp_test"));
        assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
    }

    #[test]
    fn empty_token_treated_as_none() {
        assert!(CopilotProvider::new(Some("")).github_token.is_none());
    }

    #[tokio::test]
    async fn cache_starts_empty() {
        let provider = CopilotProvider::new(None);
        assert!(provider.refresh_lock.lock().await.is_none());
    }

    #[test]
    fn copilot_headers_include_required_fields() {
        let names: Vec<&str> = CopilotProvider::COPILOT_HEADERS
            .iter()
            .map(|(header, _)| *header)
            .collect();
        // All editor-identification headers must be present.
        for required in ["Editor-Version", "Editor-Plugin-Version", "User-Agent"] {
            assert!(names.contains(&required));
        }
    }

    #[test]
    fn default_interval_and_expiry() {
        assert_eq!(default_interval(), 5);
        assert_eq!(default_expires_in(), 900);
    }

    #[test]
    fn supports_native_tools() {
        assert!(CopilotProvider::new(None).supports_native_tools());
    }
}
|
||||
560
src/providers/gemini.rs
Normal file
560
src/providers/gemini.rs
Normal file
|
|
@ -0,0 +1,560 @@
|
|||
//! Google Gemini provider with support for:
|
||||
//! - Direct API key (`GEMINI_API_KEY` env var or config)
|
||||
//! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication)
|
||||
//! - Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`)
|
||||
|
||||
use crate::providers::traits::Provider;
|
||||
use async_trait::async_trait;
|
||||
use directories::UserDirs;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::PathBuf;
|
||||
|
||||
/// Gemini provider supporting multiple authentication methods.
pub struct GeminiProvider {
    /// Resolved credential, or `None` when no auth source was found
    /// (config key, env vars, or Gemini CLI OAuth cache).
    auth: Option<GeminiAuth>,
    /// Shared HTTP client (120s request timeout, 10s connect timeout).
    client: Client,
}
|
||||
|
||||
/// Resolved credential — the variant determines both the HTTP auth method
/// and the diagnostic label returned by `auth_source()`.
#[derive(Debug)]
enum GeminiAuth {
    /// Explicit API key from config: sent as `?key=` query parameter.
    ExplicitKey(String),
    /// API key from `GEMINI_API_KEY` env var: sent as `?key=`.
    EnvGeminiKey(String),
    /// API key from `GOOGLE_API_KEY` env var: sent as `?key=`.
    EnvGoogleKey(String),
    /// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`.
    OAuthToken(String),
}
|
||||
|
||||
impl GeminiAuth {
|
||||
/// Whether this credential is an API key (sent as `?key=` query param).
|
||||
fn is_api_key(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
GeminiAuth::ExplicitKey(_) | GeminiAuth::EnvGeminiKey(_) | GeminiAuth::EnvGoogleKey(_)
|
||||
)
|
||||
}
|
||||
|
||||
/// The raw credential string.
|
||||
fn credential(&self) -> &str {
|
||||
match self {
|
||||
GeminiAuth::ExplicitKey(s)
|
||||
| GeminiAuth::EnvGeminiKey(s)
|
||||
| GeminiAuth::EnvGoogleKey(s)
|
||||
| GeminiAuth::OAuthToken(s) => s,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
// API REQUEST/RESPONSE TYPES
// ══════════════════════════════════════════════════════════════════════════════

/// Request body for the Gemini `generateContent` endpoint.
#[derive(Debug, Serialize)]
struct GenerateContentRequest {
    /// Conversation turns (role + text parts).
    contents: Vec<Content>,
    /// Optional system prompt; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    system_instruction: Option<Content>,
    #[serde(rename = "generationConfig")]
    generation_config: GenerationConfig,
}

/// A single conversation turn or system instruction.
#[derive(Debug, Serialize)]
struct Content {
    /// Turn role (e.g. "user"); left `None` for system instructions.
    #[serde(skip_serializing_if = "Option::is_none")]
    role: Option<String>,
    parts: Vec<Part>,
}

/// A text fragment within a turn.
#[derive(Debug, Serialize)]
struct Part {
    text: String,
}

/// Sampling/limit settings forwarded to the API.
#[derive(Debug, Serialize)]
struct GenerationConfig {
    temperature: f64,
    #[serde(rename = "maxOutputTokens")]
    max_output_tokens: u32,
}

/// `generateContent` response: either candidates or an error payload.
#[derive(Debug, Deserialize)]
struct GenerateContentResponse {
    candidates: Option<Vec<Candidate>>,
    /// Error object the API may embed even in a 200-status body.
    error: Option<ApiError>,
}

#[derive(Debug, Deserialize)]
struct Candidate {
    content: CandidateContent,
}

#[derive(Debug, Deserialize)]
struct CandidateContent {
    parts: Vec<ResponsePart>,
}

#[derive(Debug, Deserialize)]
struct ResponsePart {
    /// Text of this part; may be absent for non-text parts.
    text: Option<String>,
}

/// Error payload returned by the Gemini API.
#[derive(Debug, Deserialize)]
struct ApiError {
    message: String,
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════════════════════════
// GEMINI CLI TOKEN STRUCTURES
// ══════════════════════════════════════════════════════════════════════════════

/// OAuth token stored by Gemini CLI in `~/.gemini/oauth_creds.json`
#[derive(Debug, Deserialize)]
struct GeminiCliOAuthCreds {
    /// Bearer token; may be absent in a malformed credentials file.
    access_token: Option<String>,
    /// Expiry timestamp (parsed as RFC 3339), when present.
    expiry: Option<String>,
}
|
||||
|
||||
impl GeminiProvider {
|
||||
/// Create a new Gemini provider.
|
||||
///
|
||||
/// Authentication priority:
|
||||
/// 1. Explicit API key passed in
|
||||
/// 2. `GEMINI_API_KEY` environment variable
|
||||
/// 3. `GOOGLE_API_KEY` environment variable
|
||||
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
let resolved_auth = api_key
|
||||
.and_then(Self::normalize_non_empty)
|
||||
.map(GeminiAuth::ExplicitKey)
|
||||
.or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
|
||||
.or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
|
||||
.or_else(|| Self::try_load_gemini_cli_token().map(GeminiAuth::OAuthToken));
|
||||
|
||||
Self {
|
||||
auth: resolved_auth,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn normalize_non_empty(value: &str) -> Option<String> {
|
||||
let trimmed = value.trim();
|
||||
if trimmed.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(trimmed.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn load_non_empty_env(name: &str) -> Option<String> {
|
||||
std::env::var(name)
|
||||
.ok()
|
||||
.and_then(|value| Self::normalize_non_empty(&value))
|
||||
}
|
||||
|
||||
/// Try to load OAuth access token from Gemini CLI's cached credentials.
|
||||
/// Location: `~/.gemini/oauth_creds.json`
|
||||
fn try_load_gemini_cli_token() -> Option<String> {
|
||||
let gemini_dir = Self::gemini_cli_dir()?;
|
||||
let creds_path = gemini_dir.join("oauth_creds.json");
|
||||
|
||||
if !creds_path.exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let content = std::fs::read_to_string(&creds_path).ok()?;
|
||||
let creds: GeminiCliOAuthCreds = serde_json::from_str(&content).ok()?;
|
||||
|
||||
// Check if token is expired (basic check)
|
||||
if let Some(ref expiry) = creds.expiry {
|
||||
if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) {
|
||||
if expiry_time < chrono::Utc::now() {
|
||||
tracing::warn!("Gemini CLI OAuth token expired — re-run `gemini` to refresh");
|
||||
return None;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
creds
|
||||
.access_token
|
||||
.and_then(|token| Self::normalize_non_empty(&token))
|
||||
}
|
||||
|
||||
/// Get the Gemini CLI config directory (~/.gemini)
|
||||
fn gemini_cli_dir() -> Option<PathBuf> {
|
||||
UserDirs::new().map(|u| u.home_dir().join(".gemini"))
|
||||
}
|
||||
|
||||
/// Check if Gemini CLI is configured and has valid credentials
|
||||
pub fn has_cli_credentials() -> bool {
|
||||
Self::try_load_gemini_cli_token().is_some()
|
||||
}
|
||||
|
||||
/// Check if any Gemini authentication is available
|
||||
pub fn has_any_auth() -> bool {
|
||||
Self::load_non_empty_env("GEMINI_API_KEY").is_some()
|
||||
|| Self::load_non_empty_env("GOOGLE_API_KEY").is_some()
|
||||
|| Self::has_cli_credentials()
|
||||
}
|
||||
|
||||
/// Get authentication source description for diagnostics.
|
||||
/// Uses the stored enum variant — no env var re-reading at call time.
|
||||
pub fn auth_source(&self) -> &'static str {
|
||||
match self.auth.as_ref() {
|
||||
Some(GeminiAuth::ExplicitKey(_)) => "config",
|
||||
Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var",
|
||||
Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var",
|
||||
Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
|
||||
None => "none",
|
||||
}
|
||||
}
|
||||
|
||||
fn format_model_name(model: &str) -> String {
|
||||
if model.starts_with("models/") {
|
||||
model.to_string()
|
||||
} else {
|
||||
format!("models/{model}")
|
||||
}
|
||||
}
|
||||
|
||||
fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
|
||||
let model_name = Self::format_model_name(model);
|
||||
let base_url = format!(
|
||||
"https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent"
|
||||
);
|
||||
|
||||
if auth.is_api_key() {
|
||||
format!("{base_url}?key={}", auth.credential())
|
||||
} else {
|
||||
base_url
|
||||
}
|
||||
}
|
||||
|
||||
fn build_generate_content_request(
|
||||
&self,
|
||||
auth: &GeminiAuth,
|
||||
url: &str,
|
||||
request: &GenerateContentRequest,
|
||||
) -> reqwest::RequestBuilder {
|
||||
let req = self.client.post(url).json(request);
|
||||
match auth {
|
||||
GeminiAuth::OAuthToken(token) => req.bearer_auth(token),
|
||||
_ => req,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for GeminiProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let auth = self.auth.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"Gemini API key not found. Options:\n\
|
||||
1. Set GEMINI_API_KEY env var\n\
|
||||
2. Run `gemini` CLI to authenticate (tokens will be reused)\n\
|
||||
3. Get an API key from https://aistudio.google.com/app/apikey\n\
|
||||
4. Run `zeroclaw onboard` to configure"
|
||||
)
|
||||
})?;
|
||||
|
||||
// Build request
|
||||
let system_instruction = system_prompt.map(|sys| Content {
|
||||
role: None,
|
||||
parts: vec![Part {
|
||||
text: sys.to_string(),
|
||||
}],
|
||||
});
|
||||
|
||||
let request = GenerateContentRequest {
|
||||
contents: vec![Content {
|
||||
role: Some("user".to_string()),
|
||||
parts: vec![Part {
|
||||
text: message.to_string(),
|
||||
}],
|
||||
}],
|
||||
system_instruction,
|
||||
generation_config: GenerationConfig {
|
||||
temperature,
|
||||
max_output_tokens: 8192,
|
||||
},
|
||||
};
|
||||
|
||||
let url = Self::build_generate_content_url(model, auth);
|
||||
|
||||
let response = self
|
||||
.build_generate_content_request(auth, &url, &request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let error_text = response.text().await.unwrap_or_default();
|
||||
anyhow::bail!("Gemini API error ({status}): {error_text}");
|
||||
}
|
||||
|
||||
let result: GenerateContentResponse = response.json().await?;
|
||||
|
||||
// Check for API error in response body
|
||||
if let Some(err) = result.error {
|
||||
anyhow::bail!("Gemini API error: {}", err.message);
|
||||
}
|
||||
|
||||
// Extract text from response
|
||||
result
|
||||
.candidates
|
||||
.and_then(|c| c.into_iter().next())
|
||||
.and_then(|c| c.content.parts.into_iter().next())
|
||||
.and_then(|p| p.text)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from Gemini"))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Offline tests for auth resolution, URL building, header selection,
    //! and serde round-trips — no network access required.
    use super::*;
    use reqwest::header::AUTHORIZATION;

    #[test]
    fn normalize_non_empty_trims_and_filters() {
        assert_eq!(
            GeminiProvider::normalize_non_empty(" value "),
            Some("value".into())
        );
        assert_eq!(GeminiProvider::normalize_non_empty(""), None);
        assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None);
    }

    #[test]
    fn provider_creates_without_key() {
        let provider = GeminiProvider::new(None);
        // May pick up env vars; just verify it doesn't panic
        let _ = provider.auth_source();
    }

    #[test]
    fn provider_creates_with_key() {
        let provider = GeminiProvider::new(Some("test-api-key"));
        assert!(matches!(
            provider.auth,
            Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
        ));
    }

    #[test]
    fn provider_rejects_empty_key() {
        let provider = GeminiProvider::new(Some(""));
        // A blank key must never be stored as an explicit credential.
        assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_))));
    }

    #[test]
    fn gemini_cli_dir_returns_path() {
        let dir = GeminiProvider::gemini_cli_dir();
        // Should return Some on systems with home dir
        if UserDirs::new().is_some() {
            assert!(dir.is_some());
            assert!(dir.unwrap().ends_with(".gemini"));
        }
    }

    #[test]
    fn auth_source_explicit_key() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::ExplicitKey("key".into())),
            client: Client::new(),
        };
        assert_eq!(provider.auth_source(), "config");
    }

    #[test]
    fn auth_source_none_without_credentials() {
        let provider = GeminiProvider {
            auth: None,
            client: Client::new(),
        };
        assert_eq!(provider.auth_source(), "none");
    }

    #[test]
    fn auth_source_oauth() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::OAuthToken("ya29.mock".into())),
            client: Client::new(),
        };
        assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
    }

    #[test]
    fn model_name_formatting() {
        // Bare names get the "models/" prefix.
        assert_eq!(
            GeminiProvider::format_model_name("gemini-2.0-flash"),
            "models/gemini-2.0-flash"
        );
        // Already-prefixed names pass through unchanged.
        assert_eq!(
            GeminiProvider::format_model_name("models/gemini-1.5-pro"),
            "models/gemini-1.5-pro"
        );
    }

    #[test]
    fn api_key_url_includes_key_query_param() {
        let auth = GeminiAuth::ExplicitKey("api-key-123".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        assert!(url.contains(":generateContent?key=api-key-123"));
    }

    #[test]
    fn oauth_url_omits_key_query_param() {
        let auth = GeminiAuth::OAuthToken("ya29.test-token".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        assert!(url.ends_with(":generateContent"));
        assert!(!url.contains("?key="));
    }

    #[test]
    fn oauth_request_uses_bearer_auth_header() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())),
            client: Client::new(),
        };
        let auth = GeminiAuth::OAuthToken("ya29.mock-token".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        let body = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".into()),
                parts: vec![Part {
                    text: "hello".into(),
                }],
            }],
            system_instruction: None,
            generation_config: GenerationConfig {
                temperature: 0.7,
                max_output_tokens: 8192,
            },
        };

        let request = provider
            .build_generate_content_request(&auth, &url, &body)
            .build()
            .unwrap();

        // OAuth tokens must travel via the Authorization header.
        assert_eq!(
            request
                .headers()
                .get(AUTHORIZATION)
                .and_then(|h| h.to_str().ok()),
            Some("Bearer ya29.mock-token")
        );
    }

    #[test]
    fn api_key_request_does_not_set_bearer_header() {
        let provider = GeminiProvider {
            auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())),
            client: Client::new(),
        };
        let auth = GeminiAuth::ExplicitKey("api-key-123".into());
        let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
        let body = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".into()),
                parts: vec![Part {
                    text: "hello".into(),
                }],
            }],
            system_instruction: None,
            generation_config: GenerationConfig {
                temperature: 0.7,
                max_output_tokens: 8192,
            },
        };

        let request = provider
            .build_generate_content_request(&auth, &url, &body)
            .build()
            .unwrap();

        // API keys ride in the URL query string, never as a bearer header.
        assert!(request.headers().get(AUTHORIZATION).is_none());
    }

    #[test]
    fn request_serialization() {
        let request = GenerateContentRequest {
            contents: vec![Content {
                role: Some("user".to_string()),
                parts: vec![Part {
                    text: "Hello".to_string(),
                }],
            }],
            system_instruction: Some(Content {
                role: None,
                parts: vec![Part {
                    text: "You are helpful".to_string(),
                }],
            }),
            generation_config: GenerationConfig {
                temperature: 0.7,
                max_output_tokens: 8192,
            },
        };

        let json = serde_json::to_string(&request).unwrap();
        assert!(json.contains("\"role\":\"user\""));
        assert!(json.contains("\"text\":\"Hello\""));
        assert!(json.contains("\"temperature\":0.7"));
        // Confirms the `maxOutputTokens` serde rename is applied.
        assert!(json.contains("\"maxOutputTokens\":8192"));
    }

    #[test]
    fn response_deserialization() {
        let json = r#"{
            "candidates": [{
                "content": {
                    "parts": [{"text": "Hello there!"}]
                }
            }]
        }"#;

        let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
        assert!(response.candidates.is_some());
        let text = response
            .candidates
            .unwrap()
            .into_iter()
            .next()
            .unwrap()
            .content
            .parts
            .into_iter()
            .next()
            .unwrap()
            .text;
        assert_eq!(text, Some("Hello there!".to_string()));
    }

    #[test]
    fn error_response_deserialization() {
        let json = r#"{
            "error": {
                "message": "Invalid API key"
            }
        }"#;

        let response: GenerateContentResponse = serde_json::from_str(json).unwrap();
        assert!(response.error.is_some());
        assert_eq!(response.error.unwrap().message, "Invalid API key");
    }
}
|
||||
1225
src/providers/mod.rs
1225
src/providers/mod.rs
File diff suppressed because it is too large
Load diff
|
|
@ -5,9 +5,12 @@ use serde::{Deserialize, Serialize};
|
|||
|
||||
pub struct OllamaProvider {
    /// Normalized endpoint (trailing slash stripped); defaults to
    /// `http://localhost:11434`.
    base_url: String,
    /// Optional bearer token — only sent to non-local endpoints.
    api_key: Option<String>,
    client: Client,
}
|
||||
|
||||
// ─── Request Structures ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
|
|
@ -27,30 +30,231 @@ struct Options {
|
|||
temperature: f64,
|
||||
}
|
||||
|
||||
// ─── Response Structures ──────────────────────────────────────────────────────

/// Top-level `/api/chat` (non-streaming) response envelope.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    message: ResponseMessage,
}

#[derive(Debug, Deserialize)]
struct ResponseMessage {
    /// Assistant text; defaults to empty (e.g. tool-call-only replies).
    #[serde(default)]
    content: String,
    /// Native tool calls, when the model emitted any.
    #[serde(default)]
    tool_calls: Vec<OllamaToolCall>,
    /// Some models return a "thinking" field with internal reasoning
    #[serde(default)]
    thinking: Option<String>,
}
|
||||
|
||||
/// A tool call as emitted by Ollama in a chat response.
#[derive(Debug, Deserialize)]
struct OllamaToolCall {
    /// Ollama may omit call ids; downstream formatting tolerates `None`.
    id: Option<String>,
    function: OllamaFunction,
}

#[derive(Debug, Deserialize)]
struct OllamaFunction {
    name: String,
    /// Arbitrary JSON arguments; defaults to `null` when absent. May be a
    /// nested wrapper object for quirky models (see extract_tool_name_and_args).
    #[serde(default)]
    arguments: serde_json::Value,
}
|
||||
|
||||
// ─── Implementation ───────────────────────────────────────────────────────────

impl OllamaProvider {
    /// Create a provider for `base_url` (default `http://localhost:11434`).
    /// Blank API keys are treated as absent.
    pub fn new(base_url: Option<&str>, api_key: Option<&str>) -> Self {
        let api_key = api_key.and_then(|value| {
            let trimmed = value.trim();
            (!trimmed.is_empty()).then(|| trimmed.to_string())
        });

        Self {
            base_url: base_url
                .unwrap_or("http://localhost:11434")
                .trim_end_matches('/')
                .to_string(),
            api_key,
            client: Client::builder()
                // Generous timeout: local models can take a long time to reply.
                .timeout(std::time::Duration::from_secs(300))
                .connect_timeout(std::time::Duration::from_secs(10))
                .build()
                .unwrap_or_else(|_| Client::new()),
        }
    }

    /// True when the configured endpoint points at this machine
    /// (localhost / 127.0.0.1 / ::1).
    fn is_local_endpoint(&self) -> bool {
        reqwest::Url::parse(&self.base_url)
            .ok()
            .and_then(|url| url.host_str().map(|host| host.to_string()))
            .is_some_and(|host| matches!(host.as_str(), "localhost" | "127.0.0.1" | "::1"))
    }

    /// Strip an optional `:cloud` suffix from `model` and decide whether the
    /// request should carry bearer auth.
    ///
    /// Errors when `:cloud` is requested against a local endpoint or without
    /// a configured API key.
    fn resolve_request_details(&self, model: &str) -> anyhow::Result<(String, bool)> {
        let requests_cloud = model.ends_with(":cloud");
        let normalized_model = model.strip_suffix(":cloud").unwrap_or(model).to_string();

        if requests_cloud && self.is_local_endpoint() {
            anyhow::bail!(
                "Model '{}' requested cloud routing, but Ollama endpoint is local. Configure api_url with a remote Ollama endpoint.",
                model
            );
        }

        if requests_cloud && self.api_key.is_none() {
            anyhow::bail!(
                "Model '{}' requested cloud routing, but no API key is configured. Set OLLAMA_API_KEY or config api_key.",
                model
            );
        }

        // Never send the key to a local endpoint — auth is only meaningful
        // against remote Ollama deployments.
        let should_auth = self.api_key.is_some() && !self.is_local_endpoint();

        Ok((normalized_model, should_auth))
    }

    /// Send a request to Ollama and get the parsed response
    async fn send_request(
        &self,
        messages: Vec<Message>,
        model: &str,
        temperature: f64,
        should_auth: bool,
    ) -> anyhow::Result<ApiChatResponse> {
        let request = ChatRequest {
            model: model.to_string(),
            messages,
            stream: false,
            options: Options { temperature },
        };

        let url = format!("{}/api/chat", self.base_url);

        tracing::debug!(
            "Ollama request: url={} model={} message_count={} temperature={}",
            url,
            model,
            request.messages.len(),
            temperature
        );

        let mut request_builder = self.client.post(&url).json(&request);

        if should_auth {
            if let Some(key) = self.api_key.as_ref() {
                request_builder = request_builder.bearer_auth(key);
            }
        }

        let response = request_builder.send().await?;
        let status = response.status();
        tracing::debug!("Ollama response status: {}", status);

        // Read the raw bytes first so failures can log a sanitized excerpt.
        let body = response.bytes().await?;
        tracing::debug!("Ollama response body length: {} bytes", body.len());

        if !status.is_success() {
            let raw = String::from_utf8_lossy(&body);
            let sanitized = super::sanitize_api_error(&raw);
            tracing::error!(
                "Ollama error response: status={} body_excerpt={}",
                status,
                sanitized
            );
            anyhow::bail!(
                "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
                status,
                sanitized
            );
        }

        let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
            Ok(r) => r,
            Err(e) => {
                let raw = String::from_utf8_lossy(&body);
                let sanitized = super::sanitize_api_error(&raw);
                tracing::error!(
                    "Ollama response deserialization failed: {e}. body_excerpt={}",
                    sanitized
                );
                anyhow::bail!("Failed to parse Ollama response: {e}");
            }
        };

        Ok(chat_response)
    }

    /// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
    ///
    /// Handles quirky model behavior where tool calls are wrapped:
    /// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
    /// - `{"name": "tool.shell", "arguments": {...}}`
    fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
        let formatted_calls: Vec<serde_json::Value> = tool_calls
            .iter()
            .map(|tc| {
                let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);

                // Arguments must be a JSON string for parse_tool_calls compatibility
                let args_str =
                    serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());

                serde_json::json!({
                    "id": tc.id,
                    "type": "function",
                    "function": {
                        "name": tool_name,
                        "arguments": args_str
                    }
                })
            })
            .collect();

        serde_json::json!({
            "content": "",
            "tool_calls": formatted_calls
        })
        .to_string()
    }

    /// Extract the actual tool name and arguments from potentially nested structures
    fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
        let name = &tc.function.name;
        let args = &tc.function.arguments;

        // Pattern 1: Nested tool_call wrapper (various malformed versions)
        // {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
        // {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
        // {"name": "tool.call", "arguments": {"name": "shell", ...}}
        if name == "tool_call"
            || name == "tool.call"
            || name.starts_with("tool_call>")
            || name.starts_with("tool_call<")
        {
            if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
                let nested_args = args
                    .get("arguments")
                    .cloned()
                    .unwrap_or(serde_json::json!({}));
                tracing::debug!(
                    "Unwrapped nested tool call: {} -> {} with args {:?}",
                    name,
                    nested_name,
                    nested_args
                );
                return (nested_name.to_string(), nested_args);
            }
        }

        // Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
        if let Some(stripped) = name.strip_prefix("tool.") {
            return (stripped.to_string(), args.clone());
        }

        // Pattern 3: Normal tool call
        (name.clone(), args.clone())
    }
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -62,6 +266,8 @@ impl Provider for OllamaProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
||||
if let Some(sys) = system_prompt {
|
||||
|
|
@ -76,115 +282,281 @@ impl Provider for OllamaProvider {
|
|||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
let response = self
|
||||
.send_request(messages, &normalized_model, temperature, should_auth)
|
||||
.await?;
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!(
|
||||
"Ollama error: {error}. Is Ollama running? (brew install ollama && ollama serve)"
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
Ok(chat_response.message.content)
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[crate::providers::ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let (normalized_model, should_auth) = self.resolve_request_details(model)?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let response = self
|
||||
.send_request(api_messages, &normalized_model, temperature, should_auth)
|
||||
.await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
// This is a model quirk - it stopped after reasoning without producing output
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
// Return a message indicating the model's thought process but no action
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
||||
// The model may return native tool_calls but we convert them to JSON format
|
||||
// that parse_tool_calls() understands
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_url() {
|
||||
let p = OllamaProvider::new(None);
|
||||
let p = OllamaProvider::new(None, None);
|
||||
assert_eq!(p.base_url, "http://localhost:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_url_trailing_slash() {
|
||||
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"));
|
||||
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"), None);
|
||||
assert_eq!(p.base_url, "http://192.168.1.100:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn custom_url_no_trailing_slash() {
|
||||
let p = OllamaProvider::new(Some("http://myserver:11434"));
|
||||
let p = OllamaProvider::new(Some("http://myserver:11434"), None);
|
||||
assert_eq!(p.base_url, "http://myserver:11434");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_url_uses_empty() {
|
||||
let p = OllamaProvider::new(Some(""));
|
||||
let p = OllamaProvider::new(Some(""), None);
|
||||
assert_eq!(p.base_url, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system() {
|
||||
let req = ChatRequest {
|
||||
model: "llama3".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: "You are ZeroClaw".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
},
|
||||
],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.7 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"stream\":false"));
|
||||
assert!(json.contains("llama3"));
|
||||
assert!(json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.7"));
|
||||
fn cloud_suffix_strips_model_name() {
|
||||
let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key"));
|
||||
let (model, should_auth) = p.resolve_request_details("qwen3:cloud").unwrap();
|
||||
assert_eq!(model, "qwen3");
|
||||
assert!(should_auth);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "mistral".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "test".to_string(),
|
||||
}],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.0 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("mistral"));
|
||||
fn cloud_suffix_with_local_endpoint_errors() {
|
||||
let p = OllamaProvider::new(None, Some("ollama-key"));
|
||||
let error = p
|
||||
.resolve_request_details("qwen3:cloud")
|
||||
.expect_err("cloud suffix should fail on local endpoint");
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("requested cloud routing, but Ollama endpoint is local"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn cloud_suffix_without_api_key_errors() {
|
||||
let p = OllamaProvider::new(Some("https://ollama.com"), None);
|
||||
let error = p
|
||||
.resolve_request_details("qwen3:cloud")
|
||||
.expect_err("cloud suffix should require API key");
|
||||
assert!(error
|
||||
.to_string()
|
||||
.contains("requested cloud routing, but no API key is configured"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn remote_endpoint_auth_enabled_when_key_present() {
|
||||
let p = OllamaProvider::new(Some("https://ollama.com"), Some("ollama-key"));
|
||||
let (_model, should_auth) = p.resolve_request_details("qwen3").unwrap();
|
||||
assert!(should_auth);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn local_endpoint_auth_disabled_even_with_key() {
|
||||
let p = OllamaProvider::new(None, Some("ollama-key"));
|
||||
let (_model, should_auth) = p.resolve_request_details("llama3").unwrap();
|
||||
assert!(!should_auth);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.content, "Hello from Ollama!");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_empty_content() {
|
||||
let json = r#"{"message":{"role":"assistant","content":""}}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_multiline() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
||||
let resp: ChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.contains("line1"));
|
||||
fn response_with_missing_content_defaults_to_empty() {
|
||||
let json = r#"{"message":{"role":"assistant"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_thinking_field_extracts_content() {
|
||||
let json =
|
||||
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.content, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_tool_calls_parses_correctly() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_nested_tool_call() {
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool_call".into(),
|
||||
arguments: serde_json::json!({
|
||||
"name": "shell",
|
||||
"arguments": {"command": "date"}
|
||||
}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "date");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_prefixed_name() {
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool.shell".into(),
|
||||
arguments: serde_json::json!({"command": "ls"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "ls");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_normal_call() {
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "file_read".into(),
|
||||
arguments: serde_json::json!({"path": "/tmp/test"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "file_read");
|
||||
assert_eq!(args.get("path").unwrap(), "/tmp/test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_tool_calls_produces_valid_json() {
|
||||
let provider = OllamaProvider::new(None, None);
|
||||
let tool_calls = vec![OllamaToolCall {
|
||||
id: Some("call_abc".into()),
|
||||
function: OllamaFunction {
|
||||
name: "shell".into(),
|
||||
arguments: serde_json::json!({"command": "date"}),
|
||||
},
|
||||
}];
|
||||
|
||||
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
|
||||
|
||||
assert!(parsed.get("tool_calls").is_some());
|
||||
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
|
||||
assert_eq!(calls.len(), 1);
|
||||
|
||||
let func = calls[0].get("function").unwrap();
|
||||
assert_eq!(func.get("name").unwrap(), "shell");
|
||||
// arguments should be a string (JSON-encoded)
|
||||
assert!(func.get("arguments").unwrap().is_string());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenAiProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -36,10 +40,79 @@ struct ResponseMessage {
|
|||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeChatRequest {
|
||||
model: String,
|
||||
messages: Vec<NativeMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolSpec {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: NativeToolFunctionSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolFunctionSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
function: NativeFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeChatResponse {
|
||||
choices: Vec<NativeChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeChoice {
|
||||
message: NativeResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
impl OpenAiProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -47,6 +120,107 @@ impl OpenAiProvider {
|
|||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
if m.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(
|
||||
tool_calls_value.clone(),
|
||||
)
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tc| NativeToolCall {
|
||||
id: Some(tc.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: NativeFunctionCall {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
return NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if m.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
return NativeMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
NativeMessage {
|
||||
role: m.role.clone(),
|
||||
content: Some(m.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
|
||||
let tool_calls = message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|tc| ProviderToolCall {
|
||||
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ProviderChatResponse {
|
||||
text: message.content,
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -58,7 +232,7 @@ impl Provider for OpenAiProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
|
|
@ -85,14 +259,13 @@ impl Provider for OpenAiProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("OpenAI API error: {error}");
|
||||
return Err(super::api_error("OpenAI", response).await);
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
|
|
@ -104,6 +277,51 @@ impl Provider for OpenAiProvider {
|
|||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
let tools = Self::convert_tools(request.tools);
|
||||
let native_request = NativeChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: Self::convert_messages(request.messages),
|
||||
temperature,
|
||||
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("OpenAI", response).await);
|
||||
}
|
||||
|
||||
let native_response: NativeChatResponse = response.json().await?;
|
||||
let message = native_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?;
|
||||
Ok(Self::parse_native_response(message))
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -112,20 +330,20 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
|
||||
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
|
||||
let p = OpenAiProvider::new(Some("openai-test-credential"));
|
||||
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = OpenAiProvider::new(None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_key() {
|
||||
let p = OpenAiProvider::new(Some(""));
|
||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
||||
assert_eq!(p.credential.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -1,10 +1,14 @@
|
|||
use crate::providers::traits::Provider;
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenRouterProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -22,7 +26,7 @@ struct Message {
|
|||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ChatResponse {
|
||||
struct ApiChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
|
|
@ -36,10 +40,79 @@ struct ResponseMessage {
|
|||
content: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeChatRequest {
|
||||
model: String,
|
||||
messages: Vec<NativeMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolSpec {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: NativeToolFunctionSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolFunctionSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
function: NativeFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeChatResponse {
|
||||
choices: Vec<NativeChoice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeChoice {
|
||||
message: NativeResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct NativeResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
impl OpenRouterProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -47,10 +120,129 @@ impl OpenRouterProvider {
|
|||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
let items = tools?;
|
||||
if items.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<NativeMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|m| {
|
||||
if m.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(
|
||||
tool_calls_value.clone(),
|
||||
)
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tc| NativeToolCall {
|
||||
id: Some(tc.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: NativeFunctionCall {
|
||||
name: tc.name,
|
||||
arguments: tc.arguments,
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
return NativeMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if m.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&m.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
return NativeMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
NativeMessage {
|
||||
role: m.role.clone(),
|
||||
content: Some(m.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn parse_native_response(message: NativeResponseMessage) -> ProviderChatResponse {
|
||||
let tool_calls = message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|tc| ProviderToolCall {
|
||||
id: tc.id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name: tc.function.name,
|
||||
arguments: tc.function.arguments,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
ProviderChatResponse {
|
||||
text: message.content,
|
||||
tool_calls,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for OpenRouterProvider {
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
||||
// This prevents the first real chat request from timing out on cold start.
|
||||
if let Some(credential) = self.credential.as_ref() {
|
||||
self.client
|
||||
.get("https://openrouter.ai/api/v1/auth/key")
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
|
|
@ -58,7 +250,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
|
@ -84,7 +276,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -95,11 +287,10 @@ impl Provider for OpenRouterProvider {
|
|||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let error = response.text().await?;
|
||||
anyhow::bail!("OpenRouter API error: {error}");
|
||||
return Err(super::api_error("OpenRouter", response).await);
|
||||
}
|
||||
|
||||
let chat_response: ChatResponse = response.json().await?;
|
||||
let chat_response: ApiChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
|
|
@ -108,4 +299,455 @@ impl Provider for OpenRouterProvider {
|
|||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: api_messages,
|
||||
temperature,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
)
|
||||
.header("X-Title", "ZeroClaw")
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("OpenRouter", response).await);
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = response.json().await?;
|
||||
|
||||
chat_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message.content)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
})?;
|
||||
|
||||
let tools = Self::convert_tools(request.tools);
|
||||
let native_request = NativeChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: Self::convert_messages(request.messages),
|
||||
temperature,
|
||||
tool_choice: tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
)
|
||||
.header("X-Title", "ZeroClaw")
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("OpenRouter", response).await);
|
||||
}
|
||||
|
||||
let native_response: NativeChatResponse = response.json().await?;
|
||||
let message = native_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?;
|
||||
Ok(Self::parse_native_response(message))
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn chat_with_tools(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
tools: &[serde_json::Value],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
})?;
|
||||
|
||||
// Convert tool JSON values to NativeToolSpec
|
||||
let native_tools: Option<Vec<NativeToolSpec>> = if tools.is_empty() {
|
||||
None
|
||||
} else {
|
||||
let specs: Vec<NativeToolSpec> = tools
|
||||
.iter()
|
||||
.filter_map(|t| {
|
||||
let func = t.get("function")?;
|
||||
Some(NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: func.get("name")?.as_str()?.to_string(),
|
||||
description: func
|
||||
.get("description")
|
||||
.and_then(|d| d.as_str())
|
||||
.unwrap_or("")
|
||||
.to_string(),
|
||||
parameters: func
|
||||
.get("parameters")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({})),
|
||||
},
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
if specs.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(specs)
|
||||
}
|
||||
};
|
||||
|
||||
// Convert ChatMessage to NativeMessage, preserving structured assistant/tool entries
|
||||
// when history contains native tool-call metadata.
|
||||
let native_messages = Self::convert_messages(messages);
|
||||
|
||||
let native_request = NativeChatRequest {
|
||||
model: model.to_string(),
|
||||
messages: native_messages,
|
||||
temperature,
|
||||
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools: native_tools,
|
||||
};
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
)
|
||||
.header("X-Title", "ZeroClaw")
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("OpenRouter", response).await);
|
||||
}
|
||||
|
||||
let native_response: NativeChatResponse = response.json().await?;
|
||||
let message = native_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.map(|c| c.message)
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))?;
|
||||
Ok(Self::parse_native_response(message))
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::providers::traits::{ChatMessage, Provider};
|
||||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
|
||||
assert_eq!(
|
||||
provider.credential.as_deref(),
|
||||
Some("openrouter-test-credential")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
assert!(provider.credential.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_without_key_is_noop() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
let result = provider.warmup().await;
|
||||
assert!(result.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_system_fails_without_key() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
let result = provider
|
||||
.chat_with_system(Some("system"), "hello", "openai/gpt-4o", 0.2)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("API key not set"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_fails_without_key() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
let messages = vec![
|
||||
ChatMessage {
|
||||
role: "system".into(),
|
||||
content: "be concise".into(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: "hello".into(),
|
||||
},
|
||||
];
|
||||
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "anthropic/claude-sonnet-4", 0.7)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("API key not set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_request_serializes_with_system_and_user() {
|
||||
let request = ChatRequest {
|
||||
model: "anthropic/claude-sonnet-4".into(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "system".into(),
|
||||
content: "You are helpful".into(),
|
||||
},
|
||||
Message {
|
||||
role: "user".into(),
|
||||
content: "Summarize this".into(),
|
||||
},
|
||||
],
|
||||
temperature: 0.5,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&request).unwrap();
|
||||
|
||||
assert!(json.contains("anthropic/claude-sonnet-4"));
|
||||
assert!(json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("\"role\":\"user\""));
|
||||
assert!(json.contains("\"temperature\":0.5"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_request_serializes_history_messages() {
|
||||
let messages = [
|
||||
ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: "Previous answer".into(),
|
||||
},
|
||||
ChatMessage {
|
||||
role: "user".into(),
|
||||
content: "Follow-up".into(),
|
||||
},
|
||||
];
|
||||
|
||||
let request = ChatRequest {
|
||||
model: "google/gemini-2.5-pro".into(),
|
||||
messages: messages
|
||||
.iter()
|
||||
.map(|msg| Message {
|
||||
role: msg.role.clone(),
|
||||
content: msg.content.clone(),
|
||||
})
|
||||
.collect(),
|
||||
temperature: 0.0,
|
||||
};
|
||||
|
||||
let json = serde_json::to_string(&request).unwrap();
|
||||
assert!(json.contains("\"role\":\"assistant\""));
|
||||
assert!(json.contains("\"role\":\"user\""));
|
||||
assert!(json.contains("google/gemini-2.5-pro"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_single_choice() {
|
||||
let json = r#"{"choices":[{"message":{"content":"Hi from OpenRouter"}}]}"#;
|
||||
|
||||
let response: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(response.choices.len(), 1);
|
||||
assert_eq!(response.choices[0].message.content, "Hi from OpenRouter");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes_empty_choices() {
|
||||
let json = r#"{"choices":[]}"#;
|
||||
|
||||
let response: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert!(response.choices.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_tools_fails_without_key() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
let messages = vec![ChatMessage {
|
||||
role: "user".into(),
|
||||
content: "What is the date?".into(),
|
||||
}];
|
||||
let tools = vec![serde_json::json!({
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "shell",
|
||||
"description": "Run a shell command",
|
||||
"parameters": {"type": "object", "properties": {"command": {"type": "string"}}}
|
||||
}
|
||||
})];
|
||||
|
||||
let result = provider
|
||||
.chat_with_tools(&messages, &tools, "deepseek/deepseek-chat", 0.5)
|
||||
.await;
|
||||
|
||||
assert!(result.is_err());
|
||||
assert!(result.unwrap_err().to_string().contains("API key not set"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_response_deserializes_with_tool_calls() {
|
||||
let json = r#"{
|
||||
"choices":[{
|
||||
"message":{
|
||||
"content":null,
|
||||
"tool_calls":[
|
||||
{"id":"call_123","type":"function","function":{"name":"get_price","arguments":"{\"symbol\":\"BTC\"}"}}
|
||||
]
|
||||
}
|
||||
}]
|
||||
}"#;
|
||||
|
||||
let response: NativeChatResponse = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(response.choices.len(), 1);
|
||||
let message = &response.choices[0].message;
|
||||
assert!(message.content.is_none());
|
||||
let tool_calls = message.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id.as_deref(), Some("call_123"));
|
||||
assert_eq!(tool_calls[0].function.name, "get_price");
|
||||
assert_eq!(tool_calls[0].function.arguments, "{\"symbol\":\"BTC\"}");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn native_response_deserializes_with_text_and_tool_calls() {
|
||||
let json = r#"{
|
||||
"choices":[{
|
||||
"message":{
|
||||
"content":"I'll get that for you.",
|
||||
"tool_calls":[
|
||||
{"id":"call_456","type":"function","function":{"name":"shell","arguments":"{\"command\":\"date\"}"}}
|
||||
]
|
||||
}
|
||||
}]
|
||||
}"#;
|
||||
|
||||
let response: NativeChatResponse = serde_json::from_str(json).unwrap();
|
||||
|
||||
assert_eq!(response.choices.len(), 1);
|
||||
let message = &response.choices[0].message;
|
||||
assert_eq!(message.content.as_deref(), Some("I'll get that for you."));
|
||||
let tool_calls = message.tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].function.name, "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_native_response_converts_to_chat_response() {
|
||||
let message = NativeResponseMessage {
|
||||
content: Some("Here you go.".into()),
|
||||
tool_calls: Some(vec![NativeToolCall {
|
||||
id: Some("call_789".into()),
|
||||
kind: Some("function".into()),
|
||||
function: NativeFunctionCall {
|
||||
name: "file_read".into(),
|
||||
arguments: r#"{"path":"test.txt"}"#.into(),
|
||||
},
|
||||
}]),
|
||||
};
|
||||
|
||||
let response = OpenRouterProvider::parse_native_response(message);
|
||||
|
||||
assert_eq!(response.text.as_deref(), Some("Here you go."));
|
||||
assert_eq!(response.tool_calls.len(), 1);
|
||||
assert_eq!(response.tool_calls[0].id, "call_789");
|
||||
assert_eq!(response.tool_calls[0].name, "file_read");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_messages_parses_assistant_tool_call_payload() {
|
||||
let messages = vec![ChatMessage {
|
||||
role: "assistant".into(),
|
||||
content: r#"{"content":"Using tool","tool_calls":[{"id":"call_abc","name":"shell","arguments":"{\"command\":\"pwd\"}"}]}"#
|
||||
.into(),
|
||||
}];
|
||||
|
||||
let converted = OpenRouterProvider::convert_messages(&messages);
|
||||
assert_eq!(converted.len(), 1);
|
||||
assert_eq!(converted[0].role, "assistant");
|
||||
assert_eq!(converted[0].content.as_deref(), Some("Using tool"));
|
||||
|
||||
let tool_calls = converted[0].tool_calls.as_ref().unwrap();
|
||||
assert_eq!(tool_calls.len(), 1);
|
||||
assert_eq!(tool_calls[0].id.as_deref(), Some("call_abc"));
|
||||
assert_eq!(tool_calls[0].function.name, "shell");
|
||||
assert_eq!(tool_calls[0].function.arguments, r#"{"command":"pwd"}"#);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn convert_messages_parses_tool_result_payload() {
|
||||
let messages = vec![ChatMessage {
|
||||
role: "tool".into(),
|
||||
content: r#"{"tool_call_id":"call_xyz","content":"done"}"#.into(),
|
||||
}];
|
||||
|
||||
let converted = OpenRouterProvider::convert_messages(&messages);
|
||||
assert_eq!(converted.len(), 1);
|
||||
assert_eq!(converted[0].role, "tool");
|
||||
assert_eq!(converted[0].tool_call_id.as_deref(), Some("call_xyz"));
|
||||
assert_eq!(converted[0].content.as_deref(), Some("done"));
|
||||
assert!(converted[0].tool_calls.is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,85 @@
|
|||
use super::traits::{ChatMessage, StreamChunk, StreamOptions, StreamResult};
|
||||
use super::Provider;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{stream, StreamExt};
|
||||
use std::collections::HashMap;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::time::Duration;
|
||||
|
||||
/// Provider wrapper with retry + fallback behavior.
|
||||
/// Check if an error is non-retryable (client errors that won't resolve with retries).
|
||||
fn is_non_retryable(err: &anyhow::Error) -> bool {
|
||||
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
|
||||
if let Some(status) = reqwest_err.status() {
|
||||
let code = status.as_u16();
|
||||
return status.is_client_error() && code != 429 && code != 408;
|
||||
}
|
||||
}
|
||||
let msg = err.to_string();
|
||||
for word in msg.split(|c: char| !c.is_ascii_digit()) {
|
||||
if let Ok(code) = word.parse::<u16>() {
|
||||
if (400..500).contains(&code) {
|
||||
return code != 429 && code != 408;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Check if an error is a rate-limit (429) error.
|
||||
fn is_rate_limited(err: &anyhow::Error) -> bool {
|
||||
if let Some(reqwest_err) = err.downcast_ref::<reqwest::Error>() {
|
||||
if let Some(status) = reqwest_err.status() {
|
||||
return status.as_u16() == 429;
|
||||
}
|
||||
}
|
||||
let msg = err.to_string();
|
||||
msg.contains("429")
|
||||
&& (msg.contains("Too Many") || msg.contains("rate") || msg.contains("limit"))
|
||||
}
|
||||
|
||||
/// Try to extract a Retry-After value (in milliseconds) from an error message.
|
||||
/// Looks for patterns like `Retry-After: 5` or `retry_after: 2.5` in the error string.
|
||||
fn parse_retry_after_ms(err: &anyhow::Error) -> Option<u64> {
|
||||
let msg = err.to_string();
|
||||
let lower = msg.to_lowercase();
|
||||
|
||||
// Look for "retry-after: <number>" or "retry_after: <number>"
|
||||
for prefix in &[
|
||||
"retry-after:",
|
||||
"retry_after:",
|
||||
"retry-after ",
|
||||
"retry_after ",
|
||||
] {
|
||||
if let Some(pos) = lower.find(prefix) {
|
||||
let after = &msg[pos + prefix.len()..];
|
||||
let num_str: String = after
|
||||
.trim()
|
||||
.chars()
|
||||
.take_while(|c| c.is_ascii_digit() || *c == '.')
|
||||
.collect();
|
||||
if let Ok(secs) = num_str.parse::<f64>() {
|
||||
if secs.is_finite() && secs >= 0.0 {
|
||||
let millis = Duration::from_secs_f64(secs).as_millis();
|
||||
if let Ok(value) = u64::try_from(millis) {
|
||||
return Some(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
/// Provider wrapper with retry, fallback, auth rotation, and model failover.
|
||||
pub struct ReliableProvider {
|
||||
providers: Vec<(String, Box<dyn Provider>)>,
|
||||
max_retries: u32,
|
||||
base_backoff_ms: u64,
|
||||
/// Extra API keys for rotation (index tracks round-robin position).
|
||||
api_keys: Vec<String>,
|
||||
key_index: AtomicUsize,
|
||||
/// Per-model fallback chains: model_name → [fallback_model_1, fallback_model_2, ...]
|
||||
model_fallbacks: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
impl ReliableProvider {
|
||||
|
|
@ -19,12 +92,65 @@ impl ReliableProvider {
|
|||
providers,
|
||||
max_retries,
|
||||
base_backoff_ms: base_backoff_ms.max(50),
|
||||
api_keys: Vec::new(),
|
||||
key_index: AtomicUsize::new(0),
|
||||
model_fallbacks: HashMap::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Set additional API keys for round-robin rotation on rate-limit errors.
|
||||
pub fn with_api_keys(mut self, keys: Vec<String>) -> Self {
|
||||
self.api_keys = keys;
|
||||
self
|
||||
}
|
||||
|
||||
/// Set per-model fallback chains.
|
||||
pub fn with_model_fallbacks(mut self, fallbacks: HashMap<String, Vec<String>>) -> Self {
|
||||
self.model_fallbacks = fallbacks;
|
||||
self
|
||||
}
|
||||
|
||||
/// Build the list of models to try: [original, fallback1, fallback2, ...]
|
||||
fn model_chain<'a>(&'a self, model: &'a str) -> Vec<&'a str> {
|
||||
let mut chain = vec![model];
|
||||
if let Some(fallbacks) = self.model_fallbacks.get(model) {
|
||||
chain.extend(fallbacks.iter().map(|s| s.as_str()));
|
||||
}
|
||||
chain
|
||||
}
|
||||
|
||||
/// Advance to the next API key and return it, or None if no extra keys configured.
|
||||
fn rotate_key(&self) -> Option<&str> {
|
||||
if self.api_keys.is_empty() {
|
||||
return None;
|
||||
}
|
||||
let idx = self.key_index.fetch_add(1, Ordering::Relaxed) % self.api_keys.len();
|
||||
Some(&self.api_keys[idx])
|
||||
}
|
||||
|
||||
/// Compute backoff duration, respecting Retry-After if present.
|
||||
fn compute_backoff(&self, base: u64, err: &anyhow::Error) -> u64 {
|
||||
if let Some(retry_after) = parse_retry_after_ms(err) {
|
||||
// Use Retry-After but cap at 30s to avoid indefinite waits
|
||||
retry_after.min(30_000).max(base)
|
||||
} else {
|
||||
base
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for ReliableProvider {
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
for (name, provider) in &self.providers {
|
||||
tracing::info!(provider = name, "Warming up provider connection pool");
|
||||
if provider.warmup().await.is_err() {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
|
|
@ -32,58 +158,278 @@ impl Provider for ReliableProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let models = self.model_chain(model);
|
||||
let mut failures = Vec::new();
|
||||
|
||||
for (provider_name, provider) in &self.providers {
|
||||
let mut backoff_ms = self.base_backoff_ms;
|
||||
for current_model in &models {
|
||||
for (provider_name, provider) in &self.providers {
|
||||
let mut backoff_ms = self.base_backoff_ms;
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
match provider
|
||||
.chat_with_system(system_prompt, message, model, temperature)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
if attempt > 0 {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
attempt,
|
||||
"Provider recovered after retries"
|
||||
);
|
||||
for attempt in 0..=self.max_retries {
|
||||
match provider
|
||||
.chat_with_system(system_prompt, message, current_model, temperature)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
if attempt > 0 || *current_model != model {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt,
|
||||
original_model = model,
|
||||
"Provider recovered (failover/retry)"
|
||||
);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
failures.push(format!(
|
||||
"{provider_name} attempt {}/{}: {e}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
Err(e) => {
|
||||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
if attempt < self.max_retries {
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
attempt = attempt + 1,
|
||||
max_retries = self.max_retries,
|
||||
"Provider call failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
||||
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
||||
// On rate-limit, try rotating API key
|
||||
if rate_limited {
|
||||
if let Some(new_key) = self.rotate_key() {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
"Rate limited, rotated API key (key ending ...{})",
|
||||
&new_key[new_key.len().saturating_sub(4)..]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if non_retryable {
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
"Non-retryable error, moving on"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
if attempt < self.max_retries {
|
||||
let wait = self.compute_backoff(backoff_ms, &e);
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt = attempt + 1,
|
||||
backoff_ms = wait,
|
||||
"Provider call failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(wait)).await;
|
||||
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
"Exhausted retries, trying next provider/model"
|
||||
);
|
||||
}
|
||||
|
||||
tracing::warn!(provider = provider_name, "Switching to fallback provider");
|
||||
if *current_model != model {
|
||||
tracing::warn!(
|
||||
original_model = model,
|
||||
fallback_model = *current_model,
|
||||
"Model fallback exhausted all providers, trying next fallback model"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
|
||||
anyhow::bail!(
|
||||
"All providers/models failed. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let models = self.model_chain(model);
|
||||
let mut failures = Vec::new();
|
||||
|
||||
for current_model in &models {
|
||||
for (provider_name, provider) in &self.providers {
|
||||
let mut backoff_ms = self.base_backoff_ms;
|
||||
|
||||
for attempt in 0..=self.max_retries {
|
||||
match provider
|
||||
.chat_with_history(messages, current_model, temperature)
|
||||
.await
|
||||
{
|
||||
Ok(resp) => {
|
||||
if attempt > 0 || *current_model != model {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt,
|
||||
original_model = model,
|
||||
"Provider recovered (failover/retry)"
|
||||
);
|
||||
}
|
||||
return Ok(resp);
|
||||
}
|
||||
Err(e) => {
|
||||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
||||
if rate_limited {
|
||||
if let Some(new_key) = self.rotate_key() {
|
||||
tracing::info!(
|
||||
provider = provider_name,
|
||||
"Rate limited, rotated API key (key ending ...{})",
|
||||
&new_key[new_key.len().saturating_sub(4)..]
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
if non_retryable {
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
"Non-retryable error, moving on"
|
||||
);
|
||||
break;
|
||||
}
|
||||
|
||||
if attempt < self.max_retries {
|
||||
let wait = self.compute_backoff(backoff_ms, &e);
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
attempt = attempt + 1,
|
||||
backoff_ms = wait,
|
||||
"Provider call failed, retrying"
|
||||
);
|
||||
tokio::time::sleep(Duration::from_millis(wait)).await;
|
||||
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
tracing::warn!(
|
||||
provider = provider_name,
|
||||
model = *current_model,
|
||||
"Exhausted retries, trying next provider/model"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"All providers/models failed. Attempts:\n{}",
|
||||
failures.join("\n")
|
||||
)
|
||||
}
|
||||
|
||||
fn supports_streaming(&self) -> bool {
|
||||
self.providers.iter().any(|(_, p)| p.supports_streaming())
|
||||
}
|
||||
|
||||
fn stream_chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
options: StreamOptions,
|
||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||
// Try each provider/model combination for streaming
|
||||
// For streaming, we use the first provider that supports it and has streaming enabled
|
||||
for (provider_name, provider) in &self.providers {
|
||||
if !provider.supports_streaming() || !options.enabled {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Clone provider data for the stream
|
||||
let provider_clone = provider_name.clone();
|
||||
|
||||
// Try the first model in the chain for streaming
|
||||
let current_model = match self.model_chain(model).first() {
|
||||
Some(m) => m.to_string(),
|
||||
None => model.to_string(),
|
||||
};
|
||||
|
||||
// For streaming, we attempt once and propagate errors
|
||||
// The caller can retry the entire request if needed
|
||||
let stream = provider.stream_chat_with_system(
|
||||
system_prompt,
|
||||
message,
|
||||
¤t_model,
|
||||
temperature,
|
||||
options,
|
||||
);
|
||||
|
||||
// Use a channel to bridge the stream with logging
|
||||
let (tx, rx) = tokio::sync::mpsc::channel::<StreamResult<StreamChunk>>(100);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut stream = stream;
|
||||
while let Some(chunk) = stream.next().await {
|
||||
if let Err(ref e) = chunk {
|
||||
tracing::warn!(
|
||||
provider = provider_clone,
|
||||
model = current_model,
|
||||
"Streaming error: {e}"
|
||||
);
|
||||
}
|
||||
if tx.send(chunk).await.is_err() {
|
||||
break; // Receiver dropped
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Convert channel receiver to stream
|
||||
return stream::unfold(rx, |mut rx| async move {
|
||||
rx.recv().await.map(|chunk| (chunk, rx))
|
||||
})
|
||||
.boxed();
|
||||
}
|
||||
|
||||
// No streaming support available
|
||||
stream::once(async move {
|
||||
Err(super::traits::StreamError::Provider(
|
||||
"No provider supports streaming".to_string(),
|
||||
))
|
||||
})
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct MockProvider {
|
||||
|
|
@ -108,8 +454,49 @@ mod tests {
|
|||
}
|
||||
Ok(self.response.to_string())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
_messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
||||
if attempt <= self.fail_until_attempt {
|
||||
anyhow::bail!(self.error);
|
||||
}
|
||||
Ok(self.response.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
/// Mock that records which model was used for each call.
|
||||
struct ModelAwareMock {
|
||||
calls: Arc<AtomicUsize>,
|
||||
models_seen: parking_lot::Mutex<Vec<String>>,
|
||||
fail_models: Vec<&'static str>,
|
||||
response: &'static str,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for ModelAwareMock {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
self.models_seen.lock().push(model.to_string());
|
||||
if self.fail_models.contains(&model) {
|
||||
anyhow::bail!("500 model {} unavailable", model);
|
||||
}
|
||||
Ok(self.response.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
// ── Existing tests (preserved) ──
|
||||
|
||||
#[tokio::test]
|
||||
async fn succeeds_without_retry() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
|
|
@ -127,7 +514,7 @@ mod tests {
|
|||
1,
|
||||
);
|
||||
|
||||
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
||||
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||
assert_eq!(result, "ok");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
|
@ -149,7 +536,7 @@ mod tests {
|
|||
1,
|
||||
);
|
||||
|
||||
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
||||
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||
assert_eq!(result, "recovered");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
|
@ -184,7 +571,7 @@ mod tests {
|
|||
1,
|
||||
);
|
||||
|
||||
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
||||
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||
assert_eq!(result, "from fallback");
|
||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||
|
|
@ -218,12 +605,323 @@ mod tests {
|
|||
);
|
||||
|
||||
let err = provider
|
||||
.chat("hello", "test", 0.0)
|
||||
.simple_chat("hello", "test", 0.0)
|
||||
.await
|
||||
.expect_err("all providers should fail");
|
||||
let msg = err.to_string();
|
||||
assert!(msg.contains("All providers failed"));
|
||||
assert!(msg.contains("p1 attempt 1/1"));
|
||||
assert!(msg.contains("p2 attempt 1/1"));
|
||||
assert!(msg.contains("All providers/models failed"));
|
||||
assert!(msg.contains("p1"));
|
||||
assert!(msg.contains("p2"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn non_retryable_detects_common_patterns() {
|
||||
assert!(is_non_retryable(&anyhow::anyhow!("400 Bad Request")));
|
||||
assert!(is_non_retryable(&anyhow::anyhow!("401 Unauthorized")));
|
||||
assert!(is_non_retryable(&anyhow::anyhow!("403 Forbidden")));
|
||||
assert!(is_non_retryable(&anyhow::anyhow!("404 Not Found")));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests")));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout")));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!(
|
||||
"500 Internal Server Error"
|
||||
)));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!("502 Bad Gateway")));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!("timeout")));
|
||||
assert!(!is_non_retryable(&anyhow::anyhow!("connection reset")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn skips_retries_on_non_retryable_error() {
|
||||
let primary_calls = Arc::new(AtomicUsize::new(0));
|
||||
let fallback_calls = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let provider = ReliableProvider::new(
|
||||
vec![
|
||||
(
|
||||
"primary".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::clone(&primary_calls),
|
||||
fail_until_attempt: usize::MAX,
|
||||
response: "never",
|
||||
error: "401 Unauthorized",
|
||||
}),
|
||||
),
|
||||
(
|
||||
"fallback".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::clone(&fallback_calls),
|
||||
fail_until_attempt: 0,
|
||||
response: "from fallback",
|
||||
error: "fallback err",
|
||||
}),
|
||||
),
|
||||
],
|
||||
3,
|
||||
1,
|
||||
);
|
||||
|
||||
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||
assert_eq!(result, "from fallback");
|
||||
// Primary should have been called only once (no retries)
|
||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 1);
|
||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_retries_then_recovers() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"primary".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
fail_until_attempt: 1,
|
||||
response: "history ok",
|
||||
error: "temporary",
|
||||
}),
|
||||
)],
|
||||
2,
|
||||
1,
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::system("system"), ChatMessage::user("hello")];
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "test", 0.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "history ok");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_history_falls_back() {
|
||||
let primary_calls = Arc::new(AtomicUsize::new(0));
|
||||
let fallback_calls = Arc::new(AtomicUsize::new(0));
|
||||
|
||||
let provider = ReliableProvider::new(
|
||||
vec![
|
||||
(
|
||||
"primary".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::clone(&primary_calls),
|
||||
fail_until_attempt: usize::MAX,
|
||||
response: "never",
|
||||
error: "primary down",
|
||||
}),
|
||||
),
|
||||
(
|
||||
"fallback".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::clone(&fallback_calls),
|
||||
fail_until_attempt: 0,
|
||||
response: "fallback ok",
|
||||
error: "fallback err",
|
||||
}),
|
||||
),
|
||||
],
|
||||
1,
|
||||
1,
|
||||
);
|
||||
|
||||
let messages = vec![ChatMessage::user("hello")];
|
||||
let result = provider
|
||||
.chat_with_history(&messages, "test", 0.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "fallback ok");
|
||||
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
||||
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
// ── New tests: model failover ──
|
||||
|
||||
#[tokio::test]
|
||||
async fn model_failover_tries_fallback_model() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let mock = Arc::new(ModelAwareMock {
|
||||
calls: Arc::clone(&calls),
|
||||
models_seen: parking_lot::Mutex::new(Vec::new()),
|
||||
fail_models: vec!["claude-opus"],
|
||||
response: "ok from sonnet",
|
||||
});
|
||||
|
||||
let mut fallbacks = HashMap::new();
|
||||
fallbacks.insert("claude-opus".to_string(), vec!["claude-sonnet".to_string()]);
|
||||
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"anthropic".into(),
|
||||
Box::new(mock.clone()) as Box<dyn Provider>,
|
||||
)],
|
||||
0, // no retries — force immediate model failover
|
||||
1,
|
||||
)
|
||||
.with_model_fallbacks(fallbacks);
|
||||
|
||||
let result = provider
|
||||
.simple_chat("hello", "claude-opus", 0.0)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "ok from sonnet");
|
||||
|
||||
let seen = mock.models_seen.lock();
|
||||
assert_eq!(seen.len(), 2);
|
||||
assert_eq!(seen[0], "claude-opus");
|
||||
assert_eq!(seen[1], "claude-sonnet");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn model_failover_all_models_fail() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let mock = Arc::new(ModelAwareMock {
|
||||
calls: Arc::clone(&calls),
|
||||
models_seen: parking_lot::Mutex::new(Vec::new()),
|
||||
fail_models: vec!["model-a", "model-b", "model-c"],
|
||||
response: "never",
|
||||
});
|
||||
|
||||
let mut fallbacks = HashMap::new();
|
||||
fallbacks.insert(
|
||||
"model-a".to_string(),
|
||||
vec!["model-b".to_string(), "model-c".to_string()],
|
||||
);
|
||||
|
||||
let provider = ReliableProvider::new(
|
||||
vec![("p1".into(), Box::new(mock.clone()) as Box<dyn Provider>)],
|
||||
0,
|
||||
1,
|
||||
)
|
||||
.with_model_fallbacks(fallbacks);
|
||||
|
||||
let err = provider
|
||||
.simple_chat("hello", "model-a", 0.0)
|
||||
.await
|
||||
.expect_err("all models should fail");
|
||||
assert!(err.to_string().contains("All providers/models failed"));
|
||||
|
||||
let seen = mock.models_seen.lock();
|
||||
assert_eq!(seen.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn no_model_fallbacks_behaves_like_before() {
|
||||
let calls = Arc::new(AtomicUsize::new(0));
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"primary".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::clone(&calls),
|
||||
fail_until_attempt: 0,
|
||||
response: "ok",
|
||||
error: "boom",
|
||||
}),
|
||||
)],
|
||||
2,
|
||||
1,
|
||||
);
|
||||
// No model_fallbacks set — should work exactly as before
|
||||
let result = provider.simple_chat("hello", "test", 0.0).await.unwrap();
|
||||
assert_eq!(result, "ok");
|
||||
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
// ── New tests: auth rotation ──
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_rotation_cycles_keys() {
|
||||
let provider = ReliableProvider::new(
|
||||
vec![(
|
||||
"p".into(),
|
||||
Box::new(MockProvider {
|
||||
calls: Arc::new(AtomicUsize::new(0)),
|
||||
fail_until_attempt: 0,
|
||||
response: "ok",
|
||||
error: "",
|
||||
}),
|
||||
)],
|
||||
0,
|
||||
1,
|
||||
)
|
||||
.with_api_keys(vec!["key-a".into(), "key-b".into(), "key-c".into()]);
|
||||
|
||||
// Rotate 5 times, verify round-robin
|
||||
let keys: Vec<&str> = (0..5).map(|_| provider.rotate_key().unwrap()).collect();
|
||||
assert_eq!(keys, vec!["key-a", "key-b", "key-c", "key-a", "key-b"]);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn auth_rotation_returns_none_when_empty() {
|
||||
let provider = ReliableProvider::new(vec![], 0, 1);
|
||||
assert!(provider.rotate_key().is_none());
|
||||
}
|
||||
|
||||
// ── New tests: Retry-After parsing ──
|
||||
|
||||
#[test]
|
||||
fn parse_retry_after_integer() {
|
||||
let err = anyhow::anyhow!("429 Too Many Requests, Retry-After: 5");
|
||||
assert_eq!(parse_retry_after_ms(&err), Some(5000));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_retry_after_float() {
|
||||
let err = anyhow::anyhow!("Rate limited. retry_after: 2.5 seconds");
|
||||
assert_eq!(parse_retry_after_ms(&err), Some(2500));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_retry_after_missing() {
|
||||
let err = anyhow::anyhow!("500 Internal Server Error");
|
||||
assert_eq!(parse_retry_after_ms(&err), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn rate_limited_detection() {
|
||||
assert!(is_rate_limited(&anyhow::anyhow!("429 Too Many Requests")));
|
||||
assert!(is_rate_limited(&anyhow::anyhow!(
|
||||
"HTTP 429 rate limit exceeded"
|
||||
)));
|
||||
assert!(!is_rate_limited(&anyhow::anyhow!("401 Unauthorized")));
|
||||
assert!(!is_rate_limited(&anyhow::anyhow!(
|
||||
"500 Internal Server Error"
|
||||
)));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_backoff_uses_retry_after() {
|
||||
let provider = ReliableProvider::new(vec![], 0, 500);
|
||||
let err = anyhow::anyhow!("429 Retry-After: 3");
|
||||
assert_eq!(provider.compute_backoff(500, &err), 3000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_backoff_caps_at_30s() {
|
||||
let provider = ReliableProvider::new(vec![], 0, 500);
|
||||
let err = anyhow::anyhow!("429 Retry-After: 120");
|
||||
assert_eq!(provider.compute_backoff(500, &err), 30_000);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn compute_backoff_falls_back_to_base() {
|
||||
let provider = ReliableProvider::new(vec![], 0, 500);
|
||||
let err = anyhow::anyhow!("500 Server Error");
|
||||
assert_eq!(provider.compute_backoff(500, &err), 500);
|
||||
}
|
||||
|
||||
// ── Arc<ModelAwareMock> Provider impl for test ──
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for Arc<ModelAwareMock> {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.as_ref()
|
||||
.chat_with_system(system_prompt, message, model, temperature)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
385
src/providers/router.rs
Normal file
385
src/providers/router.rs
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
use super::traits::{ChatMessage, ChatRequest, ChatResponse};
|
||||
use super::Provider;
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
|
||||
/// A single route: maps a task hint to a provider + model combo.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Route {
|
||||
pub provider_name: String,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
/// Multi-model router — routes requests to different provider+model combos
|
||||
/// based on a task hint encoded in the model parameter.
|
||||
///
|
||||
/// The model parameter can be:
|
||||
/// - A regular model name (e.g. "anthropic/claude-sonnet-4") → uses default provider
|
||||
/// - A hint-prefixed string (e.g. "hint:reasoning") → resolves via route table
|
||||
///
|
||||
/// This wraps multiple pre-created providers and selects the right one per request.
|
||||
pub struct RouterProvider {
|
||||
routes: HashMap<String, (usize, String)>, // hint → (provider_index, model)
|
||||
providers: Vec<(String, Box<dyn Provider>)>,
|
||||
default_index: usize,
|
||||
default_model: String,
|
||||
}
|
||||
|
||||
impl RouterProvider {
|
||||
/// Create a new router with a default provider and optional routes.
|
||||
///
|
||||
/// `providers` is a list of (name, provider) pairs. The first one is the default.
|
||||
/// `routes` maps hint names to Route structs containing provider_name and model.
|
||||
pub fn new(
|
||||
providers: Vec<(String, Box<dyn Provider>)>,
|
||||
routes: Vec<(String, Route)>,
|
||||
default_model: String,
|
||||
) -> Self {
|
||||
// Build provider name → index lookup
|
||||
let name_to_index: HashMap<&str, usize> = providers
|
||||
.iter()
|
||||
.enumerate()
|
||||
.map(|(i, (name, _))| (name.as_str(), i))
|
||||
.collect();
|
||||
|
||||
// Resolve routes to provider indices
|
||||
let resolved_routes: HashMap<String, (usize, String)> = routes
|
||||
.into_iter()
|
||||
.filter_map(|(hint, route)| {
|
||||
let index = name_to_index.get(route.provider_name.as_str()).copied();
|
||||
match index {
|
||||
Some(i) => Some((hint, (i, route.model))),
|
||||
None => {
|
||||
tracing::warn!(
|
||||
hint = hint,
|
||||
provider = route.provider_name,
|
||||
"Route references unknown provider, skipping"
|
||||
);
|
||||
None
|
||||
}
|
||||
}
|
||||
})
|
||||
.collect();
|
||||
|
||||
Self {
|
||||
routes: resolved_routes,
|
||||
providers,
|
||||
default_index: 0,
|
||||
default_model,
|
||||
}
|
||||
}
|
||||
|
||||
/// Resolve a model parameter to a (provider, actual_model) pair.
|
||||
///
|
||||
/// If the model starts with "hint:", look up the hint in the route table.
|
||||
/// Otherwise, use the default provider with the given model name.
|
||||
/// Resolve a model parameter to a (provider_index, actual_model) pair.
|
||||
fn resolve(&self, model: &str) -> (usize, String) {
|
||||
if let Some(hint) = model.strip_prefix("hint:") {
|
||||
if let Some((idx, resolved_model)) = self.routes.get(hint) {
|
||||
return (*idx, resolved_model.clone());
|
||||
}
|
||||
tracing::warn!(
|
||||
hint = hint,
|
||||
"Unknown route hint, falling back to default provider"
|
||||
);
|
||||
}
|
||||
|
||||
// Not a hint or hint not found — use default provider with the model as-is
|
||||
(self.default_index, model.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for RouterProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let (provider_idx, resolved_model) = self.resolve(model);
|
||||
|
||||
let (provider_name, provider) = &self.providers[provider_idx];
|
||||
tracing::info!(
|
||||
provider = provider_name.as_str(),
|
||||
model = resolved_model.as_str(),
|
||||
"Router dispatching request"
|
||||
);
|
||||
|
||||
provider
|
||||
.chat_with_system(system_prompt, message, &resolved_model, temperature)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let (provider_idx, resolved_model) = self.resolve(model);
|
||||
let (_, provider) = &self.providers[provider_idx];
|
||||
provider
|
||||
.chat_with_history(messages, &resolved_model, temperature)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let (provider_idx, resolved_model) = self.resolve(model);
|
||||
let (_, provider) = &self.providers[provider_idx];
|
||||
provider.chat(request, &resolved_model, temperature).await
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
self.providers
|
||||
.get(self.default_index)
|
||||
.map(|(_, p)| p.supports_native_tools())
|
||||
.unwrap_or(false)
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
for (name, provider) in &self.providers {
|
||||
tracing::info!(provider = name, "Warming up routed provider");
|
||||
if let Err(e) = provider.warmup().await {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||
use std::sync::Arc;
|
||||
|
||||
struct MockProvider {
|
||||
calls: Arc<AtomicUsize>,
|
||||
response: &'static str,
|
||||
last_model: parking_lot::Mutex<String>,
|
||||
}
|
||||
|
||||
impl MockProvider {
|
||||
fn new(response: &'static str) -> Self {
|
||||
Self {
|
||||
calls: Arc::new(AtomicUsize::new(0)),
|
||||
response,
|
||||
last_model: parking_lot::Mutex::new(String::new()),
|
||||
}
|
||||
}
|
||||
|
||||
fn call_count(&self) -> usize {
|
||||
self.calls.load(Ordering::SeqCst)
|
||||
}
|
||||
|
||||
fn last_model(&self) -> String {
|
||||
self.last_model.lock().clone()
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for MockProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||
*self.last_model.lock() = model.to_string();
|
||||
Ok(self.response.to_string())
|
||||
}
|
||||
}
|
||||
|
||||
fn make_router(
|
||||
providers: Vec<(&'static str, &'static str)>,
|
||||
routes: Vec<(&str, &str, &str)>,
|
||||
) -> (RouterProvider, Vec<Arc<MockProvider>>) {
|
||||
let mocks: Vec<Arc<MockProvider>> = providers
|
||||
.iter()
|
||||
.map(|(_, response)| Arc::new(MockProvider::new(response)))
|
||||
.collect();
|
||||
|
||||
let provider_list: Vec<(String, Box<dyn Provider>)> = providers
|
||||
.iter()
|
||||
.zip(mocks.iter())
|
||||
.map(|((name, _), mock)| {
|
||||
(
|
||||
name.to_string(),
|
||||
Box::new(Arc::clone(mock)) as Box<dyn Provider>,
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let route_list: Vec<(String, Route)> = routes
|
||||
.iter()
|
||||
.map(|(hint, provider_name, model)| {
|
||||
(
|
||||
hint.to_string(),
|
||||
Route {
|
||||
provider_name: provider_name.to_string(),
|
||||
model: model.to_string(),
|
||||
},
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
|
||||
let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());
|
||||
|
||||
(router, mocks)
|
||||
}
|
||||
|
||||
// Arc<MockProvider> should also be a Provider
|
||||
#[async_trait]
|
||||
impl Provider for Arc<MockProvider> {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.as_ref()
|
||||
.chat_with_system(system_prompt, message, model, temperature)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn routes_hint_to_correct_provider() {
|
||||
let (router, mocks) = make_router(
|
||||
vec![("fast", "fast-response"), ("smart", "smart-response")],
|
||||
vec![
|
||||
("fast", "fast", "llama-3-70b"),
|
||||
("reasoning", "smart", "claude-opus"),
|
||||
],
|
||||
);
|
||||
|
||||
let result = router
|
||||
.simple_chat("hello", "hint:reasoning", 0.5)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "smart-response");
|
||||
assert_eq!(mocks[1].call_count(), 1);
|
||||
assert_eq!(mocks[1].last_model(), "claude-opus");
|
||||
assert_eq!(mocks[0].call_count(), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn routes_fast_hint() {
|
||||
let (router, mocks) = make_router(
|
||||
vec![("fast", "fast-response"), ("smart", "smart-response")],
|
||||
vec![("fast", "fast", "llama-3-70b")],
|
||||
);
|
||||
|
||||
let result = router.simple_chat("hello", "hint:fast", 0.5).await.unwrap();
|
||||
assert_eq!(result, "fast-response");
|
||||
assert_eq!(mocks[0].call_count(), 1);
|
||||
assert_eq!(mocks[0].last_model(), "llama-3-70b");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_hint_falls_back_to_default() {
|
||||
let (router, mocks) = make_router(
|
||||
vec![("default", "default-response"), ("other", "other-response")],
|
||||
vec![],
|
||||
);
|
||||
|
||||
let result = router
|
||||
.simple_chat("hello", "hint:nonexistent", 0.5)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "default-response");
|
||||
assert_eq!(mocks[0].call_count(), 1);
|
||||
// Falls back to default with the hint as model name
|
||||
assert_eq!(mocks[0].last_model(), "hint:nonexistent");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn non_hint_model_uses_default_provider() {
|
||||
let (router, mocks) = make_router(
|
||||
vec![
|
||||
("primary", "primary-response"),
|
||||
("secondary", "secondary-response"),
|
||||
],
|
||||
vec![("code", "secondary", "codellama")],
|
||||
);
|
||||
|
||||
let result = router
|
||||
.simple_chat("hello", "anthropic/claude-sonnet-4-20250514", 0.5)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "primary-response");
|
||||
assert_eq!(mocks[0].call_count(), 1);
|
||||
assert_eq!(mocks[0].last_model(), "anthropic/claude-sonnet-4-20250514");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_preserves_model_for_non_hints() {
|
||||
let (router, _) = make_router(vec![("default", "ok")], vec![]);
|
||||
|
||||
let (idx, model) = router.resolve("gpt-4o");
|
||||
assert_eq!(idx, 0);
|
||||
assert_eq!(model, "gpt-4o");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resolve_strips_hint_prefix() {
|
||||
let (router, _) = make_router(
|
||||
vec![("fast", "ok"), ("smart", "ok")],
|
||||
vec![("reasoning", "smart", "claude-opus")],
|
||||
);
|
||||
|
||||
let (idx, model) = router.resolve("hint:reasoning");
|
||||
assert_eq!(idx, 1);
|
||||
assert_eq!(model, "claude-opus");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn skips_routes_with_unknown_provider() {
|
||||
let (router, _) = make_router(
|
||||
vec![("default", "ok")],
|
||||
vec![("broken", "nonexistent", "model")],
|
||||
);
|
||||
|
||||
// Route should not exist
|
||||
assert!(!router.routes.contains_key("broken"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn warmup_calls_all_providers() {
|
||||
let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);
|
||||
|
||||
// Warmup should not error
|
||||
assert!(router.warmup().await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn chat_with_system_passes_system_prompt() {
|
||||
let mock = Arc::new(MockProvider::new("response"));
|
||||
let router = RouterProvider::new(
|
||||
vec![(
|
||||
"default".into(),
|
||||
Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
|
||||
)],
|
||||
vec![],
|
||||
"model".into(),
|
||||
);
|
||||
|
||||
let result = router
|
||||
.chat_with_system(Some("system"), "hello", "model", 0.5)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result, "response");
|
||||
assert_eq!(mock.call_count(), 1);
|
||||
}
|
||||
}
|
||||
|
|
@ -1,12 +1,269 @@
|
|||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{stream, StreamExt};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::fmt::Write;
|
||||
|
||||
/// A single message in a conversation.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChatMessage {
|
||||
pub role: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
impl ChatMessage {
|
||||
pub fn system(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "system".into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn user(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "user".into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn assistant(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "assistant".into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn tool(content: impl Into<String>) -> Self {
|
||||
Self {
|
||||
role: "tool".into(),
|
||||
content: content.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A tool call requested by the LLM.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolCall {
|
||||
pub id: String,
|
||||
pub name: String,
|
||||
pub arguments: String,
|
||||
}
|
||||
|
||||
/// An LLM response that may contain text, tool calls, or both.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChatResponse {
|
||||
/// Text content of the response (may be empty if only tool calls).
|
||||
pub text: Option<String>,
|
||||
/// Tool calls requested by the LLM.
|
||||
pub tool_calls: Vec<ToolCall>,
|
||||
}
|
||||
|
||||
impl ChatResponse {
|
||||
/// True when the LLM wants to invoke at least one tool.
|
||||
pub fn has_tool_calls(&self) -> bool {
|
||||
!self.tool_calls.is_empty()
|
||||
}
|
||||
|
||||
/// Convenience: return text content or empty string.
|
||||
pub fn text_or_empty(&self) -> &str {
|
||||
self.text.as_deref().unwrap_or("")
|
||||
}
|
||||
}
|
||||
|
||||
/// Request payload for provider chat calls.
|
||||
#[derive(Debug, Clone, Copy)]
|
||||
pub struct ChatRequest<'a> {
|
||||
pub messages: &'a [ChatMessage],
|
||||
pub tools: Option<&'a [ToolSpec]>,
|
||||
}
|
||||
|
||||
/// A tool result to feed back to the LLM.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ToolResultMessage {
|
||||
pub tool_call_id: String,
|
||||
pub content: String,
|
||||
}
|
||||
|
||||
/// A message in a multi-turn conversation, including tool interactions.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
#[serde(tag = "type", content = "data")]
|
||||
pub enum ConversationMessage {
|
||||
/// Regular chat message (system, user, assistant).
|
||||
Chat(ChatMessage),
|
||||
/// Tool calls from the assistant (stored for history fidelity).
|
||||
AssistantToolCalls {
|
||||
text: Option<String>,
|
||||
tool_calls: Vec<ToolCall>,
|
||||
},
|
||||
/// Results of tool executions, fed back to the LLM.
|
||||
ToolResults(Vec<ToolResultMessage>),
|
||||
}
|
||||
|
||||
/// A chunk of content from a streaming response.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct StreamChunk {
|
||||
/// Text delta for this chunk.
|
||||
pub delta: String,
|
||||
/// Whether this is the final chunk.
|
||||
pub is_final: bool,
|
||||
/// Approximate token count for this chunk (estimated).
|
||||
pub token_count: usize,
|
||||
}
|
||||
|
||||
impl StreamChunk {
|
||||
/// Create a new non-final chunk.
|
||||
pub fn delta(text: impl Into<String>) -> Self {
|
||||
Self {
|
||||
delta: text.into(),
|
||||
is_final: false,
|
||||
token_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create a final chunk.
|
||||
pub fn final_chunk() -> Self {
|
||||
Self {
|
||||
delta: String::new(),
|
||||
is_final: true,
|
||||
token_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Create an error chunk.
|
||||
pub fn error(message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
delta: message.into(),
|
||||
is_final: true,
|
||||
token_count: 0,
|
||||
}
|
||||
}
|
||||
|
||||
/// Estimate tokens (rough approximation: ~4 chars per token).
|
||||
pub fn with_token_estimate(mut self) -> Self {
|
||||
self.token_count = self.delta.len().div_ceil(4);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Options for streaming chat requests.
|
||||
#[derive(Debug, Clone, Copy, Default)]
|
||||
pub struct StreamOptions {
|
||||
/// Whether to enable streaming (default: true).
|
||||
pub enabled: bool,
|
||||
/// Whether to include token counts in chunks.
|
||||
pub count_tokens: bool,
|
||||
}
|
||||
|
||||
impl StreamOptions {
|
||||
/// Create new streaming options with enabled flag.
|
||||
pub fn new(enabled: bool) -> Self {
|
||||
Self {
|
||||
enabled,
|
||||
count_tokens: false,
|
||||
}
|
||||
}
|
||||
|
||||
/// Enable token counting.
|
||||
pub fn with_token_count(mut self) -> Self {
|
||||
self.count_tokens = true;
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
/// Result type for streaming operations.
|
||||
pub type StreamResult<T> = std::result::Result<T, StreamError>;
|
||||
|
||||
/// Errors that can occur during streaming.
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum StreamError {
|
||||
#[error("HTTP error: {0}")]
|
||||
Http(reqwest::Error),
|
||||
|
||||
#[error("JSON parse error: {0}")]
|
||||
Json(serde_json::Error),
|
||||
|
||||
#[error("Invalid SSE format: {0}")]
|
||||
InvalidSse(String),
|
||||
|
||||
#[error("Provider error: {0}")]
|
||||
Provider(String),
|
||||
|
||||
#[error("IO error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
}
|
||||
|
||||
/// Provider capabilities declaration.
|
||||
///
|
||||
/// Describes what features a provider supports, enabling intelligent
|
||||
/// adaptation of tool calling modes and request formatting.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub struct ProviderCapabilities {
|
||||
/// Whether the provider supports native tool calling via API primitives.
|
||||
///
|
||||
/// When `true`, the provider can convert tool definitions to API-native
|
||||
/// formats (e.g., Gemini's functionDeclarations, Anthropic's input_schema).
|
||||
///
|
||||
/// When `false`, tools must be injected via system prompt as text.
|
||||
pub native_tool_calling: bool,
|
||||
}
|
||||
|
||||
/// Provider-specific tool payload formats.
|
||||
///
|
||||
/// Different LLM providers require different formats for tool definitions.
|
||||
/// This enum encapsulates those variations, enabling providers to convert
|
||||
/// from the unified `ToolSpec` format to their native API requirements.
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum ToolsPayload {
|
||||
/// Gemini API format (functionDeclarations).
|
||||
Gemini {
|
||||
function_declarations: Vec<serde_json::Value>,
|
||||
},
|
||||
/// Anthropic Messages API format (tools with input_schema).
|
||||
Anthropic { tools: Vec<serde_json::Value> },
|
||||
/// OpenAI Chat Completions API format (tools with function).
|
||||
OpenAI { tools: Vec<serde_json::Value> },
|
||||
/// Prompt-guided fallback (tools injected as text in system prompt).
|
||||
PromptGuided { instructions: String },
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
async fn chat(&self, message: &str, model: &str, temperature: f64) -> anyhow::Result<String> {
|
||||
/// Query provider capabilities.
|
||||
///
|
||||
/// Default implementation returns minimal capabilities (no native tool calling).
|
||||
/// Providers should override this to declare their actual capabilities.
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities::default()
|
||||
}
|
||||
|
||||
/// Convert tool specifications to provider-native format.
|
||||
///
|
||||
/// Default implementation returns `PromptGuided` payload, which injects
|
||||
/// tool documentation into the system prompt as text. Providers with
|
||||
/// native tool calling support should override this to return their
|
||||
/// specific format (Gemini, Anthropic, OpenAI).
|
||||
fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload {
|
||||
ToolsPayload::PromptGuided {
|
||||
instructions: build_tool_instructions_text(tools),
|
||||
}
|
||||
}
|
||||
|
||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||
///
|
||||
/// This is the preferred API for non-agentic direct interactions.
|
||||
async fn simple_chat(
|
||||
&self,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
self.chat_with_system(None, message, model, temperature)
|
||||
.await
|
||||
}
|
||||
|
||||
/// One-shot chat with optional system prompt.
|
||||
///
|
||||
/// Kept for compatibility and advanced one-shot prompting.
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
|
|
@ -14,4 +271,605 @@ pub trait Provider: Send + Sync {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String>;
|
||||
|
||||
/// Multi-turn conversation. Default implementation extracts the last user
|
||||
/// message and delegates to `chat_with_system`.
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let system = messages
|
||||
.iter()
|
||||
.find(|m| m.role == "system")
|
||||
.map(|m| m.content.as_str());
|
||||
let last_user = messages
|
||||
.iter()
|
||||
.rfind(|m| m.role == "user")
|
||||
.map(|m| m.content.as_str())
|
||||
.unwrap_or("");
|
||||
self.chat_with_system(system, last_user, model, temperature)
|
||||
.await
|
||||
}
|
||||
|
||||
    /// Structured chat API for agent loop callers.
    ///
    /// Default behavior:
    /// * Tools present AND `supports_native_tools()` is false: tool
    ///   instructions from `convert_tools` are spliced into the system
    ///   message (prompt-guided fallback) and the conversation goes through
    ///   `chat_with_history`; `tool_calls` in the response is always empty.
    /// * Otherwise: a plain `chat_with_history` call. NOTE: when the
    ///   provider DOES support native tools, this default still ignores
    ///   `request.tools` entirely — such providers are expected to override
    ///   `chat` with their native implementation.
    ///
    /// # Errors
    /// Fails if `convert_tools` returns a non-`PromptGuided` payload while
    /// `supports_native_tools()` is false (inconsistent provider impl), or
    /// if the underlying chat call fails.
    async fn chat(
        &self,
        request: ChatRequest<'_>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ChatResponse> {
        // If tools are provided but provider doesn't support native tools,
        // inject tool instructions into system prompt as fallback.
        if let Some(tools) = request.tools {
            if !tools.is_empty() && !self.supports_native_tools() {
                let tool_instructions = match self.convert_tools(tools) {
                    ToolsPayload::PromptGuided { instructions } => instructions,
                    payload => {
                        anyhow::bail!(
                            "Provider returned non-prompt-guided tools payload ({payload:?}) while supports_native_tools() is false"
                        )
                    }
                };
                // Owned copy: we must mutate a system message (or add one).
                let mut modified_messages = request.messages.to_vec();

                // Inject tool instructions into an existing system message.
                // If none exists, prepend one to the conversation.
                if let Some(system_message) =
                    modified_messages.iter_mut().find(|m| m.role == "system")
                {
                    if !system_message.content.is_empty() {
                        // Blank line separates the original prompt from the
                        // appended tool protocol text.
                        system_message.content.push_str("\n\n");
                    }
                    system_message.content.push_str(&tool_instructions);
                } else {
                    modified_messages.insert(0, ChatMessage::system(tool_instructions));
                }

                let text = self
                    .chat_with_history(&modified_messages, model, temperature)
                    .await?;
                return Ok(ChatResponse {
                    text: Some(text),
                    // Prompt-guided mode never yields structured tool calls;
                    // callers must parse <tool_call> tags out of `text`.
                    tool_calls: Vec::new(),
                });
            }
        }

        // No tools, or native-tools provider (tools ignored by this default).
        let text = self
            .chat_with_history(request.messages, model, temperature)
            .await?;
        Ok(ChatResponse {
            text: Some(text),
            tool_calls: Vec::new(),
        })
    }
|
||||
|
||||
    /// Whether provider supports native tool calls over API.
    ///
    /// Derived from `capabilities()`; override either method, not both.
    fn supports_native_tools(&self) -> bool {
        self.capabilities().native_tool_calling
    }
|
||||
|
||||
    /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
    /// Default implementation is a no-op; providers with HTTP clients should override.
    async fn warmup(&self) -> anyhow::Result<()> {
        Ok(())
    }
|
||||
|
||||
/// Chat with tool definitions for native function calling support.
|
||||
/// The default implementation falls back to chat_with_history and returns
|
||||
/// an empty tool_calls vector (prompt-based tool use only).
|
||||
async fn chat_with_tools(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
_tools: &[serde_json::Value],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ChatResponse> {
|
||||
let text = self.chat_with_history(messages, model, temperature).await?;
|
||||
Ok(ChatResponse {
|
||||
text: Some(text),
|
||||
tool_calls: Vec::new(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Whether provider supports streaming responses.
|
||||
/// Default implementation returns false.
|
||||
fn supports_streaming(&self) -> bool {
|
||||
false
|
||||
}
|
||||
|
||||
    /// Streaming chat with optional system prompt.
    /// Returns an async stream of text chunks.
    ///
    /// Default implementation returns an EMPTY stream (it does not fall back
    /// to non-streaming chat). Providers reporting
    /// `supports_streaming() == true` must override this; callers should
    /// check `supports_streaming()` before relying on this method.
    fn stream_chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
        _options: StreamOptions,
    ) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
        // Default: return an empty stream (not supported)
        stream::empty().boxed()
    }
|
||||
|
||||
/// Streaming chat with history.
|
||||
/// Default implementation falls back to stream_chat_with_system with last user message.
|
||||
fn stream_chat_with_history(
|
||||
&self,
|
||||
_messages: &[ChatMessage],
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
_options: StreamOptions,
|
||||
) -> stream::BoxStream<'static, StreamResult<StreamChunk>> {
|
||||
// For default implementation, we need to convert to owned strings
|
||||
// This is a limitation of the default implementation
|
||||
let provider_name = "unknown".to_string();
|
||||
|
||||
// Create a single empty chunk to indicate not supported
|
||||
let chunk = StreamChunk::error(format!("{} does not support streaming", provider_name));
|
||||
stream::once(async move { Ok(chunk) }).boxed()
|
||||
}
|
||||
}
|
||||
|
||||
/// Build tool instructions text for prompt-guided tool calling.
|
||||
///
|
||||
/// Generates a formatted text block describing available tools and how to
|
||||
/// invoke them using XML-style tags. This is used as a fallback when the
|
||||
/// provider doesn't support native tool calling.
|
||||
pub fn build_tool_instructions_text(tools: &[ToolSpec]) -> String {
|
||||
let mut instructions = String::new();
|
||||
|
||||
instructions.push_str("## Tool Use Protocol\n\n");
|
||||
instructions.push_str("To use a tool, wrap a JSON object in <tool_call></tool_call> tags:\n\n");
|
||||
instructions.push_str("<tool_call>\n");
|
||||
instructions.push_str(r#"{"name": "tool_name", "arguments": {"param": "value"}}"#);
|
||||
instructions.push_str("\n</tool_call>\n\n");
|
||||
instructions.push_str("You may use multiple tool calls in a single response. ");
|
||||
instructions.push_str("After tool execution, results appear in <tool_result> tags. ");
|
||||
instructions
|
||||
.push_str("Continue reasoning with the results until you can give a final answer.\n\n");
|
||||
instructions.push_str("### Available Tools\n\n");
|
||||
|
||||
for tool in tools {
|
||||
writeln!(&mut instructions, "**{}**: {}", tool.name, tool.description)
|
||||
.expect("writing to String cannot fail");
|
||||
|
||||
let parameters =
|
||||
serde_json::to_string(&tool.parameters).unwrap_or_else(|_| "{}".to_string());
|
||||
writeln!(&mut instructions, "Parameters: `{parameters}`")
|
||||
.expect("writing to String cannot fail");
|
||||
instructions.push('\n');
|
||||
}
|
||||
|
||||
instructions
|
||||
}
|
||||
|
||||
// Unit tests for the Provider trait's default implementations, the message /
// response data types, and the prompt-guided tool-instruction builder.
#[cfg(test)]
mod tests {
    use super::*;

    // Provider that only declares native tool calling via `capabilities()`,
    // used to verify the default `supports_native_tools()` mapping.
    struct CapabilityMockProvider;

    #[async_trait]
    impl Provider for CapabilityMockProvider {
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                native_tool_calling: true,
            }
        }

        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".into())
        }
    }

    #[test]
    fn chat_message_constructors() {
        let sys = ChatMessage::system("Be helpful");
        assert_eq!(sys.role, "system");
        assert_eq!(sys.content, "Be helpful");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, "user");

        let asst = ChatMessage::assistant("Hi there");
        assert_eq!(asst.role, "assistant");

        let tool = ChatMessage::tool("{}");
        assert_eq!(tool.role, "tool");
    }

    #[test]
    fn chat_response_helpers() {
        let empty = ChatResponse {
            text: None,
            tool_calls: vec![],
        };
        assert!(!empty.has_tool_calls());
        assert_eq!(empty.text_or_empty(), "");

        let with_tools = ChatResponse {
            text: Some("Let me check".into()),
            tool_calls: vec![ToolCall {
                id: "1".into(),
                name: "shell".into(),
                arguments: "{}".into(),
            }],
        };
        assert!(with_tools.has_tool_calls());
        assert_eq!(with_tools.text_or_empty(), "Let me check");
    }

    #[test]
    fn tool_call_serialization() {
        let tc = ToolCall {
            id: "call_123".into(),
            name: "file_read".into(),
            arguments: r#"{"path":"test.txt"}"#.into(),
        };
        let json = serde_json::to_string(&tc).unwrap();
        assert!(json.contains("call_123"));
        assert!(json.contains("file_read"));
    }

    #[test]
    fn conversation_message_variants() {
        // Serde externally-tagged enum: variants appear under a "type" key.
        let chat = ConversationMessage::Chat(ChatMessage::user("hi"));
        let json = serde_json::to_string(&chat).unwrap();
        assert!(json.contains("\"type\":\"Chat\""));

        let tool_result = ConversationMessage::ToolResults(vec![ToolResultMessage {
            tool_call_id: "1".into(),
            content: "done".into(),
        }]);
        let json = serde_json::to_string(&tool_result).unwrap();
        assert!(json.contains("\"type\":\"ToolResults\""));
    }

    #[test]
    fn provider_capabilities_default() {
        let caps = ProviderCapabilities::default();
        assert!(!caps.native_tool_calling);
    }

    #[test]
    fn provider_capabilities_equality() {
        let caps1 = ProviderCapabilities {
            native_tool_calling: true,
        };
        let caps2 = ProviderCapabilities {
            native_tool_calling: true,
        };
        let caps3 = ProviderCapabilities {
            native_tool_calling: false,
        };

        assert_eq!(caps1, caps2);
        assert_ne!(caps1, caps3);
    }

    #[test]
    fn supports_native_tools_reflects_capabilities_default_mapping() {
        let provider = CapabilityMockProvider;
        assert!(provider.supports_native_tools());
    }

    #[test]
    fn tools_payload_variants() {
        // Test Gemini variant
        let gemini = ToolsPayload::Gemini {
            function_declarations: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(gemini, ToolsPayload::Gemini { .. }));

        // Test Anthropic variant
        let anthropic = ToolsPayload::Anthropic {
            tools: vec![serde_json::json!({"name": "test"})],
        };
        assert!(matches!(anthropic, ToolsPayload::Anthropic { .. }));

        // Test OpenAI variant
        let openai = ToolsPayload::OpenAI {
            tools: vec![serde_json::json!({"type": "function"})],
        };
        assert!(matches!(openai, ToolsPayload::OpenAI { .. }));

        // Test PromptGuided variant
        let prompt_guided = ToolsPayload::PromptGuided {
            instructions: "Use tools...".to_string(),
        };
        assert!(matches!(prompt_guided, ToolsPayload::PromptGuided { .. }));
    }

    #[test]
    fn build_tool_instructions_text_format() {
        let tools = vec![
            ToolSpec {
                name: "shell".to_string(),
                description: "Execute commands".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "command": {"type": "string"}
                    }
                }),
            },
            ToolSpec {
                name: "file_read".to_string(),
                description: "Read files".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "path": {"type": "string"}
                    }
                }),
            },
        ];

        let instructions = build_tool_instructions_text(&tools);

        // Check for protocol description
        assert!(instructions.contains("Tool Use Protocol"));
        assert!(instructions.contains("<tool_call>"));
        assert!(instructions.contains("</tool_call>"));

        // Check for tool listings
        assert!(instructions.contains("**shell**"));
        assert!(instructions.contains("Execute commands"));
        assert!(instructions.contains("**file_read**"));
        assert!(instructions.contains("Read files"));

        // Check for parameters
        assert!(instructions.contains("Parameters:"));
        assert!(instructions.contains(r#""type":"object""#));
    }

    #[test]
    fn build_tool_instructions_text_empty() {
        let instructions = build_tool_instructions_text(&[]);

        // Should still have protocol description
        assert!(instructions.contains("Tool Use Protocol"));

        // Should have empty tools section
        assert!(instructions.contains("Available Tools"));
    }

    // Mock provider for testing.
    struct MockProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl Provider for MockProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }

        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("response".to_string())
        }
    }

    #[test]
    fn provider_convert_tools_default() {
        let provider = MockProvider {
            supports_native: false,
        };

        let tools = vec![ToolSpec {
            name: "test_tool".to_string(),
            description: "A test tool".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        let payload = provider.convert_tools(&tools);

        // Default implementation should return PromptGuided.
        assert!(matches!(payload, ToolsPayload::PromptGuided { .. }));

        if let ToolsPayload::PromptGuided { instructions } = payload {
            assert!(instructions.contains("test_tool"));
            assert!(instructions.contains("A test tool"));
        }
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_fallback() {
        let provider = MockProvider {
            supports_native: false,
        };

        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();

        // Should return a response (default impl calls chat_with_history).
        assert!(response.text.is_some());
    }

    #[tokio::test]
    async fn provider_chat_without_tools() {
        let provider = MockProvider {
            supports_native: true,
        };

        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: None,
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();

        // Should work normally without tools.
        assert!(response.text.is_some());
    }

    // Provider that echoes the system prompt for assertions.
    struct EchoSystemProvider {
        supports_native: bool,
    }

    #[async_trait]
    impl Provider for EchoSystemProvider {
        fn supports_native_tools(&self) -> bool {
            self.supports_native
        }

        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }

    // Provider with custom prompt-guided conversion.
    struct CustomConvertProvider;

    #[async_trait]
    impl Provider for CustomConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }

        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::PromptGuided {
                instructions: "CUSTOM_TOOL_INSTRUCTIONS".to_string(),
            }
        }

        async fn chat_with_system(
            &self,
            system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok(system.unwrap_or_default().to_string())
        }
    }

    // Provider returning an invalid payload for non-native mode.
    struct InvalidConvertProvider;

    #[async_trait]
    impl Provider for InvalidConvertProvider {
        fn supports_native_tools(&self) -> bool {
            false
        }

        fn convert_tools(&self, _tools: &[ToolSpec]) -> ToolsPayload {
            ToolsPayload::OpenAI {
                tools: vec![serde_json::json!({"type": "function"})],
            }
        }

        async fn chat_with_system(
            &self,
            _system: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("should_not_reach".to_string())
        }
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_preserves_existing_system_not_first() {
        let provider = EchoSystemProvider {
            supports_native: false,
        };

        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        // System message deliberately placed AFTER the user message: the
        // fallback must find and augment it wherever it sits.
        let request = ChatRequest {
            messages: &[
                ChatMessage::user("Hello"),
                ChatMessage::system("BASE_SYSTEM_PROMPT"),
            ],
            tools: Some(&tools),
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();

        assert!(text.contains("BASE_SYSTEM_PROMPT"));
        assert!(text.contains("Tool Use Protocol"));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_uses_convert_tools_override() {
        let provider = CustomConvertProvider;

        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        let request = ChatRequest {
            messages: &[ChatMessage::system("BASE"), ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let response = provider.chat(request, "model", 0.7).await.unwrap();
        let text = response.text.unwrap_or_default();

        assert!(text.contains("BASE"));
        assert!(text.contains("CUSTOM_TOOL_INSTRUCTIONS"));
    }

    #[tokio::test]
    async fn provider_chat_prompt_guided_rejects_non_prompt_payload() {
        let provider = InvalidConvertProvider;

        let tools = vec![ToolSpec {
            name: "shell".to_string(),
            description: "Run commands".to_string(),
            parameters: serde_json::json!({"type": "object"}),
        }];

        let request = ChatRequest {
            messages: &[ChatMessage::user("Hello")],
            tools: Some(&tools),
        };

        let err = provider.chat(request, "model", 0.7).await.unwrap_err();
        let message = err.to_string();

        assert!(message.contains("non-prompt-guided"));
    }
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue