Merge branch 'main' into pr-484-clean
This commit is contained in:
commit
ee05d62ce4
90 changed files with 6937 additions and 1403 deletions
|
|
@ -106,17 +106,17 @@ struct NativeContentIn {
|
|||
}
|
||||
|
||||
impl AnthropicProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self::with_base_url(api_key, None)
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self::with_base_url(credential, None)
|
||||
}
|
||||
|
||||
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
|
||||
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
|
||||
let base_url = base_url
|
||||
.map(|u| u.trim_end_matches('/'))
|
||||
.unwrap_or("https://api.anthropic.com")
|
||||
.to_string();
|
||||
Self {
|
||||
credential: api_key
|
||||
credential: credential
|
||||
.map(str::trim)
|
||||
.filter(|k| !k.is_empty())
|
||||
.map(ToString::to_string),
|
||||
|
|
@ -410,9 +410,9 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = AnthropicProvider::new(Some("sk-ant-test123"));
|
||||
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||
}
|
||||
|
||||
|
|
@ -431,17 +431,19 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_whitespace_key() {
|
||||
let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
|
||||
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_custom_base_url() {
|
||||
let p =
|
||||
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
|
||||
let p = AnthropicProvider::with_base_url(
|
||||
Some("anthropic-credential"),
|
||||
Some("https://api.example.com"),
|
||||
);
|
||||
assert_eq!(p.base_url, "https://api.example.com");
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
|
|||
pub struct OpenAiCompatibleProvider {
|
||||
pub(crate) name: String,
|
||||
pub(crate) base_url: String,
|
||||
pub(crate) api_key: Option<String>,
|
||||
pub(crate) credential: Option<String>,
|
||||
pub(crate) auth_header: AuthStyle,
|
||||
/// When false, do not fall back to /v1/responses on chat completions 404.
|
||||
/// GLM/Zhipu does not support the responses API.
|
||||
|
|
@ -37,11 +37,16 @@ pub enum AuthStyle {
|
|||
}
|
||||
|
||||
impl OpenAiCompatibleProvider {
|
||||
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
|
||||
pub fn new(
|
||||
name: &str,
|
||||
base_url: &str,
|
||||
credential: Option<&str>,
|
||||
auth_style: AuthStyle,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
auth_header: auth_style,
|
||||
supports_responses_fallback: true,
|
||||
client: Client::builder()
|
||||
|
|
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
|
|||
pub fn new_no_responses_fallback(
|
||||
name: &str,
|
||||
base_url: &str,
|
||||
api_key: Option<&str>,
|
||||
credential: Option<&str>,
|
||||
auth_style: AuthStyle,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
auth_header: auth_style,
|
||||
supports_responses_fallback: false,
|
||||
client: Client::builder()
|
||||
|
|
@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
|
|||
fn apply_auth_header(
|
||||
&self,
|
||||
req: reqwest::RequestBuilder,
|
||||
api_key: &str,
|
||||
credential: &str,
|
||||
) -> reqwest::RequestBuilder {
|
||||
match &self.auth_header {
|
||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
|
||||
AuthStyle::XApiKey => req.header("x-api-key", api_key),
|
||||
AuthStyle::Custom(header) => req.header(header, api_key),
|
||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
|
||||
AuthStyle::XApiKey => req.header("x-api-key", credential),
|
||||
AuthStyle::Custom(header) => req.header(header, credential),
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_via_responses(
|
||||
&self,
|
||||
api_key: &str,
|
||||
credential: &str,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
|
|
@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
|
|||
let url = self.responses_url();
|
||||
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
|
|
@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
let url = self.chat_completions_url();
|
||||
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||
return self
|
||||
.chat_via_responses(api_key, system_prompt, message, model)
|
||||
.chat_via_responses(credential, system_prompt, message, model)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
|
|
@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
|
|
@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
let url = self.chat_completions_url();
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
if let Some(user_msg) = last_user {
|
||||
return self
|
||||
.chat_via_responses(
|
||||
api_key,
|
||||
credential,
|
||||
system.map(|m| m.content.as_str()),
|
||||
&user_msg.content,
|
||||
model,
|
||||
|
|
@ -791,16 +796,20 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
|
||||
let p = make_provider(
|
||||
"venice",
|
||||
"https://api.venice.ai",
|
||||
Some("venice-test-credential"),
|
||||
);
|
||||
assert_eq!(p.name, "venice");
|
||||
assert_eq!(p.base_url, "https://api.venice.ai");
|
||||
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
|
||||
assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = make_provider("test", "https://example.com", None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -894,6 +903,7 @@ mod tests {
|
|||
make_provider("Groq", "https://api.groq.com/openai", None),
|
||||
make_provider("Mistral", "https://api.mistral.ai", None),
|
||||
make_provider("xAI", "https://api.x.ai", None),
|
||||
make_provider("Astrai", "https://as-trai.com/v1", None),
|
||||
];
|
||||
|
||||
for p in providers {
|
||||
|
|
|
|||
705
src/providers/copilot.rs
Normal file
705
src/providers/copilot.rs
Normal file
|
|
@ -0,0 +1,705 @@
|
|||
//! GitHub Copilot provider with OAuth device-flow authentication.
|
||||
//!
|
||||
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
|
||||
//! then exchanges the OAuth token for short-lived Copilot API keys.
|
||||
//! Tokens are cached to disk and auto-refreshed.
|
||||
//!
|
||||
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
|
||||
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
|
||||
//! and other third-party Copilot integrations. The Copilot token endpoint is
|
||||
//! private; there is no public OAuth scope or app registration for it.
|
||||
//! GitHub could change or revoke this at any time, which would break all
|
||||
//! third-party integrations simultaneously.
|
||||
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
/// GitHub OAuth client ID for Copilot (VS Code extension).
|
||||
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
|
||||
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
|
||||
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
|
||||
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
|
||||
const DEFAULT_API: &str = "https://api.githubcopilot.com";
|
||||
|
||||
// ── Token types ──────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default = "default_interval")]
|
||||
interval: u64,
|
||||
#[serde(default = "default_expires_in")]
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
fn default_interval() -> u64 {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_expires_in() -> u64 {
|
||||
900
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AccessTokenResponse {
|
||||
access_token: Option<String>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiKeyInfo {
|
||||
token: String,
|
||||
expires_at: i64,
|
||||
#[serde(default)]
|
||||
endpoints: Option<ApiEndpoints>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiEndpoints {
|
||||
api: Option<String>,
|
||||
}
|
||||
|
||||
struct CachedApiKey {
|
||||
token: String,
|
||||
api_endpoint: String,
|
||||
expires_at: i64,
|
||||
}
|
||||
|
||||
// ── Chat completions types ───────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiChatRequest {
|
||||
model: String,
|
||||
messages: Vec<ApiMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolSpec {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: NativeToolFunctionSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolFunctionSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
function: NativeFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
// ── Provider ─────────────────────────────────────────────────────
|
||||
|
||||
/// GitHub Copilot provider with automatic OAuth and token refresh.
|
||||
///
|
||||
/// On first use, prompts the user to visit github.com/login/device.
|
||||
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
|
||||
/// automatically.
|
||||
pub struct CopilotProvider {
|
||||
github_token: Option<String>,
|
||||
/// Mutex ensures only one caller refreshes tokens at a time,
|
||||
/// preventing duplicate device flow prompts or redundant API calls.
|
||||
refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
|
||||
http: Client,
|
||||
token_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl CopilotProvider {
|
||||
pub fn new(github_token: Option<&str>) -> Self {
|
||||
let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
|
||||
.map(|dir| dir.config_dir().join("copilot"))
|
||||
.unwrap_or_else(|| {
|
||||
// Fall back to a user-specific temp directory to avoid
|
||||
// shared-directory symlink attacks.
|
||||
let user = std::env::var("USER")
|
||||
.or_else(|_| std::env::var("USERNAME"))
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
|
||||
});
|
||||
|
||||
if let Err(err) = std::fs::create_dir_all(&token_dir) {
|
||||
warn!(
|
||||
"Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
|
||||
token_dir
|
||||
);
|
||||
} else {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
if let Err(err) =
|
||||
std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
|
||||
{
|
||||
warn!(
|
||||
"Failed to set Copilot token directory permissions on {:?}: {err}",
|
||||
token_dir
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
github_token: github_token
|
||||
.filter(|token| !token.is_empty())
|
||||
.map(String::from),
|
||||
refresh_lock: Arc::new(Mutex::new(None)),
|
||||
http: Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.connect_timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
token_dir,
|
||||
}
|
||||
}
|
||||
|
||||
/// Required headers for Copilot API requests (editor identification).
|
||||
const COPILOT_HEADERS: [(&str, &str); 4] = [
|
||||
("Editor-Version", "vscode/1.85.1"),
|
||||
("Editor-Plugin-Version", "copilot/1.155.0"),
|
||||
("User-Agent", "GithubCopilot/1.155.0"),
|
||||
("Accept", "application/json"),
|
||||
];
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
if message.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| NativeToolCall {
|
||||
id: Some(tool_call.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: NativeFunctionCall {
|
||||
name: tool_call.name,
|
||||
arguments: tool_call.arguments,
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if message.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
ApiMessage {
|
||||
role: message.role.clone(),
|
||||
content: Some(message.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Send a chat completions request with required Copilot headers.
|
||||
async fn send_chat_request(
|
||||
&self,
|
||||
messages: Vec<ApiMessage>,
|
||||
tools: Option<&[ToolSpec]>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let (token, endpoint) = self.get_api_key().await?;
|
||||
let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
|
||||
|
||||
let native_tools = Self::convert_tools(tools);
|
||||
let request = ApiChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools: native_tools,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&request);
|
||||
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
req = req.header(*header, *value);
|
||||
}
|
||||
|
||||
let response = req.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("GitHub Copilot", response).await);
|
||||
}
|
||||
|
||||
let api_response: ApiChatResponse = response.json().await?;
|
||||
let choice = api_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
|
||||
|
||||
let tool_calls = choice
|
||||
.message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|tool_call| ProviderToolCall {
|
||||
id: tool_call
|
||||
.id
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name: tool_call.function.name,
|
||||
arguments: tool_call.function.arguments,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ProviderChatResponse {
|
||||
text: choice.message.content,
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
|
||||
/// Uses a Mutex to ensure only one caller refreshes at a time.
|
||||
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
|
||||
let mut cached = self.refresh_lock.lock().await;
|
||||
|
||||
if let Some(cached_key) = cached.as_ref() {
|
||||
if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
|
||||
return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(info) = self.load_api_key_from_disk().await {
|
||||
if chrono::Utc::now().timestamp() + 120 < info.expires_at {
|
||||
let endpoint = info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
let token = info.token;
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: info.expires_at,
|
||||
});
|
||||
return Ok((token, endpoint));
|
||||
}
|
||||
}
|
||||
|
||||
let access_token = self.get_github_access_token().await?;
|
||||
let api_key_info = self.exchange_for_api_key(&access_token).await?;
|
||||
self.save_api_key_to_disk(&api_key_info).await;
|
||||
|
||||
let endpoint = api_key_info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: api_key_info.token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: api_key_info.expires_at,
|
||||
});
|
||||
|
||||
Ok((api_key_info.token, endpoint))
|
||||
}
|
||||
|
||||
/// Get a GitHub access token from config, cache, or device flow.
|
||||
async fn get_github_access_token(&self) -> anyhow::Result<String> {
|
||||
if let Some(token) = &self.github_token {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
|
||||
let token = cached.trim();
|
||||
if !token.is_empty() {
|
||||
return Ok(token.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let token = self.device_code_login().await?;
|
||||
write_file_secure(&access_token_path, &token).await;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Run GitHub OAuth device code flow.
|
||||
async fn device_code_login(&self) -> anyhow::Result<String> {
|
||||
let response: DeviceCodeResponse = self
|
||||
.http
|
||||
.post(GITHUB_DEVICE_CODE_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"scope": "read:user"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let mut poll_interval = Duration::from_secs(response.interval.max(5));
|
||||
let expires_in = response.expires_in.max(1);
|
||||
let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
|
||||
|
||||
eprintln!(
|
||||
"\nGitHub Copilot authentication is required.\n\
|
||||
Visit: {}\n\
|
||||
Code: {}\n\
|
||||
Waiting for authorization...\n",
|
||||
response.verification_uri, response.user_code
|
||||
);
|
||||
|
||||
while tokio::time::Instant::now() < expires_at {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let token_response: AccessTokenResponse = self
|
||||
.http
|
||||
.post(GITHUB_ACCESS_TOKEN_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"device_code": response.device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if let Some(token) = token_response.access_token {
|
||||
eprintln!("Authentication succeeded.\n");
|
||||
return Ok(token);
|
||||
}
|
||||
|
||||
match token_response.error.as_deref() {
|
||||
Some("slow_down") => {
|
||||
poll_interval += Duration::from_secs(5);
|
||||
}
|
||||
Some("authorization_pending") | None => {}
|
||||
Some("expired_token") => {
|
||||
anyhow::bail!("GitHub device authorization expired")
|
||||
}
|
||||
Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("Timed out waiting for GitHub authorization")
|
||||
}
|
||||
|
||||
/// Exchange a GitHub access token for a Copilot API key.
|
||||
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
|
||||
let mut request = self.http.get(GITHUB_API_KEY_URL);
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
request = request.header(*header, *value);
|
||||
}
|
||||
request = request.header("Authorization", format!("token {access_token}"));
|
||||
|
||||
let response = request.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let sanitized = super::sanitize_api_error(&body);
|
||||
|
||||
if status.as_u16() == 401 || status.as_u16() == 403 {
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
tokio::fs::remove_file(&access_token_path).await.ok();
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to get Copilot API key ({status}): {sanitized}. \
|
||||
Ensure your GitHub account has an active Copilot subscription."
|
||||
);
|
||||
}
|
||||
|
||||
let info: ApiKeyInfo = response.json().await?;
|
||||
Ok(info)
|
||||
}
|
||||
|
||||
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
let data = tokio::fs::read_to_string(&path).await.ok()?;
|
||||
serde_json::from_str(&data).ok()
|
||||
}
|
||||
|
||||
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
if let Ok(json) = serde_json::to_string_pretty(info) {
|
||||
write_file_secure(&path, &json).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a file with 0600 permissions (owner read/write only).
|
||||
/// Uses `spawn_blocking` to avoid blocking the async runtime.
|
||||
async fn write_file_secure(path: &Path, content: &str) {
|
||||
let path = path.to_path_buf();
|
||||
let content = content.to_string();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::io::Write;
|
||||
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
|
||||
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600)
|
||||
.open(&path)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
|
||||
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
std::fs::write(&path, &content)?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
|
||||
Err(err) => warn!("Failed to spawn blocking write: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CopilotProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut messages = Vec::new();
|
||||
if let Some(system) = system_prompt {
|
||||
messages.push(ApiMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
messages.push(ApiMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(message.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.send_chat_request(messages, None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let response = self
|
||||
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
self.send_chat_request(
|
||||
Self::convert_messages(request.messages),
|
||||
request.tools,
|
||||
model,
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
let _ = self.get_api_key().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_without_token() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
assert!(provider.github_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_with_token() {
|
||||
let provider = CopilotProvider::new(Some("ghp_test"));
|
||||
assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_token_treated_as_none() {
|
||||
let provider = CopilotProvider::new(Some(""));
|
||||
assert!(provider.github_token.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cache_starts_empty() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
let cached = provider.refresh_lock.lock().await;
|
||||
assert!(cached.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copilot_headers_include_required_fields() {
|
||||
let headers = CopilotProvider::COPILOT_HEADERS;
|
||||
assert!(headers
|
||||
.iter()
|
||||
.any(|(header, _)| *header == "Editor-Version"));
|
||||
assert!(headers
|
||||
.iter()
|
||||
.any(|(header, _)| *header == "Editor-Plugin-Version"));
|
||||
assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_interval_and_expiry() {
|
||||
assert_eq!(default_interval(), 5);
|
||||
assert_eq!(default_expires_in(), 900);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_native_tools() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
assert!(provider.supports_native_tools());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod anthropic;
|
||||
pub mod compatible;
|
||||
pub mod copilot;
|
||||
pub mod gemini;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
|
|
@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
|
|||
|
||||
/// Scrub known secret-like token prefixes from provider error strings.
|
||||
///
|
||||
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`.
|
||||
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
|
||||
/// `ghu_`, and `github_pat_`.
|
||||
pub fn scrub_secret_patterns(input: &str) -> String {
|
||||
const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
|
||||
const PREFIXES: [&str; 7] = [
|
||||
"sk-",
|
||||
"xoxb-",
|
||||
"xoxp-",
|
||||
"ghp_",
|
||||
"gho_",
|
||||
"ghu_",
|
||||
"github_pat_",
|
||||
];
|
||||
|
||||
let mut scrubbed = input.to_string();
|
||||
|
||||
|
|
@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
|
|||
///
|
||||
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
|
||||
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
|
||||
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
||||
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
|
||||
return Some(key.to_string());
|
||||
fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
|
||||
if let Some(raw_override) = credential_override {
|
||||
let trimmed_override = raw_override.trim();
|
||||
if !trimmed_override.is_empty() {
|
||||
return Some(trimmed_override.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let provider_env_candidates: Vec<&str> = match name {
|
||||
|
|
@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
|||
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
|
||||
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
|
||||
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
|
||||
"astrai" => vec!["ASTRAI_API_KEY"],
|
||||
_ => vec![],
|
||||
};
|
||||
|
||||
|
|
@ -182,19 +196,28 @@ fn parse_custom_provider_url(
|
|||
}
|
||||
}
|
||||
|
||||
/// Factory: create the right provider from config
|
||||
#[allow(clippy::too_many_lines)]
|
||||
/// Factory: create the right provider from config (without custom URL)
|
||||
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let resolved_key = resolve_api_key(name, api_key);
|
||||
let key = resolved_key.as_deref();
|
||||
create_provider_with_url(name, api_key, None)
|
||||
}
|
||||
|
||||
/// Factory: create the right provider from config with optional custom base URL
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn create_provider_with_url(
|
||||
name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let resolved_credential = resolve_provider_credential(name, api_key);
|
||||
#[allow(clippy::option_as_ref_deref)]
|
||||
let key = resolved_credential.as_ref().map(String::as_str);
|
||||
match name {
|
||||
// ── Primary providers (custom implementations) ───────
|
||||
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
|
||||
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
|
||||
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
|
||||
// Ollama is a local service that doesn't use API keys.
|
||||
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
|
||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
|
||||
// Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
|
||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
|
||||
"gemini" | "google" | "google-gemini" => {
|
||||
Ok(Box::new(gemini::GeminiProvider::new(key)))
|
||||
}
|
||||
|
|
@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
|
||||
"Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
|
||||
|
|
@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer,
|
||||
"copilot" | "github-copilot" => {
|
||||
Ok(Box::new(copilot::CopilotProvider::new(api_key)))
|
||||
},
|
||||
"lmstudio" | "lm-studio" => {
|
||||
let lm_studio_key = api_key
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("lm-studio");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"LM Studio",
|
||||
"http://localhost:1234/v1",
|
||||
Some(lm_studio_key),
|
||||
AuthStyle::Bearer,
|
||||
)))
|
||||
}
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
|
||||
OpenAiCompatibleProvider::new(
|
||||
"NVIDIA NIM",
|
||||
"https://integrate.api.nvidia.com/v1",
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
),
|
||||
)),
|
||||
|
||||
// ── AI inference routers ─────────────────────────────
|
||||
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
|
||||
// ── Bring Your Own Provider (custom URL) ───────────
|
||||
|
|
@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
pub fn create_resilient_provider(
|
||||
primary_name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
reliability: &crate::config::ReliabilityConfig,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||
|
||||
providers.push((
|
||||
primary_name.to_string(),
|
||||
create_provider(primary_name, api_key)?,
|
||||
create_provider_with_url(primary_name, api_key, api_url)?,
|
||||
));
|
||||
|
||||
for fallback in &reliability.fallback_providers {
|
||||
|
|
@ -340,21 +386,13 @@ pub fn create_resilient_provider(
|
|||
continue;
|
||||
}
|
||||
|
||||
if api_key.is_some() && fallback != "ollama" {
|
||||
tracing::warn!(
|
||||
fallback_provider = fallback,
|
||||
primary_provider = primary_name,
|
||||
"Fallback provider will use the primary provider's API key — \
|
||||
this will fail if the providers require different keys"
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback providers don't use the custom api_url (it's specific to primary)
|
||||
match create_provider(fallback, api_key) {
|
||||
Ok(provider) => providers.push((fallback.clone(), provider)),
|
||||
Err(e) => {
|
||||
Err(_error) => {
|
||||
tracing::warn!(
|
||||
fallback_provider = fallback,
|
||||
"Ignoring invalid fallback provider: {e}"
|
||||
"Ignoring invalid fallback provider during initialization"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -377,12 +415,13 @@ pub fn create_resilient_provider(
|
|||
pub fn create_routed_provider(
|
||||
primary_name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
reliability: &crate::config::ReliabilityConfig,
|
||||
model_routes: &[crate::config::ModelRouteConfig],
|
||||
default_model: &str,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
if model_routes.is_empty() {
|
||||
return create_resilient_provider(primary_name, api_key, reliability);
|
||||
return create_resilient_provider(primary_name, api_key, api_url, reliability);
|
||||
}
|
||||
|
||||
// Collect unique provider names needed
|
||||
|
|
@ -396,12 +435,19 @@ pub fn create_routed_provider(
|
|||
// Create each provider (with its own resilience wrapper)
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||
for name in &needed {
|
||||
let key = model_routes
|
||||
let routed_credential = model_routes
|
||||
.iter()
|
||||
.find(|r| &r.provider == name)
|
||||
.and_then(|r| r.api_key.as_deref())
|
||||
.or(api_key);
|
||||
match create_resilient_provider(name, key, reliability) {
|
||||
.and_then(|r| {
|
||||
r.api_key.as_ref().and_then(|raw_key| {
|
||||
let trimmed_key = raw_key.trim();
|
||||
(!trimmed_key.is_empty()).then_some(trimmed_key)
|
||||
})
|
||||
});
|
||||
let key = routed_credential.or(api_key);
|
||||
// Only use api_url for the primary provider
|
||||
let url = if name == primary_name { api_url } else { None };
|
||||
match create_resilient_provider(name, key, url, reliability) {
|
||||
Ok(provider) => providers.push((name.clone(), provider)),
|
||||
Err(e) => {
|
||||
if name == primary_name {
|
||||
|
|
@ -409,7 +455,7 @@ pub fn create_routed_provider(
|
|||
}
|
||||
tracing::warn!(
|
||||
provider = name.as_str(),
|
||||
"Ignoring routed provider that failed to create: {e}"
|
||||
"Ignoring routed provider that failed to initialize"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -441,27 +487,27 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_prefers_explicit_argument() {
|
||||
let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
|
||||
assert_eq!(resolved.as_deref(), Some("explicit-key"));
|
||||
fn resolve_provider_credential_prefers_explicit_argument() {
|
||||
let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
|
||||
assert_eq!(resolved, Some("explicit-key".to_string()));
|
||||
}
|
||||
|
||||
// ── Primary providers ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_openrouter() {
|
||||
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
|
||||
assert!(create_provider("openrouter", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_anthropic() {
|
||||
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openai() {
|
||||
assert!(create_provider("openai", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -556,6 +602,13 @@ mod tests {
|
|||
assert!(create_provider("dashscope-us", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_lmstudio() {
|
||||
assert!(create_provider("lmstudio", Some("key")).is_ok());
|
||||
assert!(create_provider("lm-studio", Some("key")).is_ok());
|
||||
assert!(create_provider("lmstudio", None).is_ok());
|
||||
}
|
||||
|
||||
// ── Extended ecosystem ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
@ -614,6 +667,13 @@ mod tests {
|
|||
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
|
||||
}
|
||||
|
||||
// ── AI inference routers ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_astrai() {
|
||||
assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok());
|
||||
}
|
||||
|
||||
// ── Custom / BYOP provider ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
@ -761,17 +821,33 @@ mod tests {
|
|||
scheduler_retries: 2,
|
||||
};
|
||||
|
||||
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
|
||||
let provider = create_resilient_provider(
|
||||
"openrouter",
|
||||
Some("provider-test-credential"),
|
||||
None,
|
||||
&reliability,
|
||||
);
|
||||
assert!(provider.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resilient_provider_errors_for_invalid_primary() {
|
||||
let reliability = crate::config::ReliabilityConfig::default();
|
||||
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
|
||||
let provider = create_resilient_provider(
|
||||
"totally-invalid",
|
||||
Some("provider-test-credential"),
|
||||
None,
|
||||
&reliability,
|
||||
);
|
||||
assert!(provider.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ollama_with_custom_url() {
|
||||
let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434"));
|
||||
assert!(provider.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_all_providers_create_successfully() {
|
||||
let providers = [
|
||||
|
|
@ -794,6 +870,7 @@ mod tests {
|
|||
"qwen",
|
||||
"qwen-intl",
|
||||
"qwen-us",
|
||||
"lmstudio",
|
||||
"groq",
|
||||
"mistral",
|
||||
"xai",
|
||||
|
|
@ -888,7 +965,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn sanitize_preserves_unicode_boundaries() {
|
||||
let input = format!("{} sk-abcdef123", "こんにちは".repeat(80));
|
||||
let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
|
||||
let result = sanitize_api_error(&input);
|
||||
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
|
||||
assert!(!result.contains("sk-abcdef123"));
|
||||
|
|
@ -900,4 +977,32 @@ mod tests {
|
|||
let result = sanitize_api_error(input);
|
||||
assert_eq!(result, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_personal_access_token() {
|
||||
let input = "auth failed with token ghp_abc123def456";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "auth failed with token [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_oauth_token() {
|
||||
let input = "Bearer gho_1234567890abcdef";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "Bearer [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_user_token() {
|
||||
let input = "token ghu_sessiontoken123";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "token [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_fine_grained_pat() {
|
||||
let input = "failed: github_pat_11AABBC_xyzzy789";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "failed: [REDACTED]");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ pub struct OllamaProvider {
|
|||
client: Client,
|
||||
}
|
||||
|
||||
// ─── Request Structures ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
|
|
@ -27,6 +29,8 @@ struct Options {
|
|||
temperature: f64,
|
||||
}
|
||||
|
||||
// ─── Response Structures ──────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
message: ResponseMessage,
|
||||
|
|
@ -34,9 +38,30 @@ struct ApiChatResponse {
|
|||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<OllamaToolCall>,
|
||||
/// Some models return a "thinking" field with internal reasoning
|
||||
#[serde(default)]
|
||||
thinking: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaToolCall {
|
||||
id: Option<String>,
|
||||
function: OllamaFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaFunction {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
// ─── Implementation ───────────────────────────────────────────────────────────
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(base_url: Option<&str>) -> Self {
|
||||
Self {
|
||||
|
|
@ -45,12 +70,145 @@ impl OllamaProvider {
|
|||
.trim_end_matches('/')
|
||||
.to_string(),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response
|
||||
async fn send_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
tracing::debug!(
|
||||
"Ollama request: url={} model={} message_count={} temperature={}",
|
||||
url,
|
||||
model,
|
||||
request.messages.len(),
|
||||
temperature
|
||||
);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
let status = response.status();
|
||||
tracing::debug!("Ollama response status: {}", status);
|
||||
|
||||
let body = response.bytes().await?;
|
||||
tracing::debug!("Ollama response body length: {} bytes", body.len());
|
||||
|
||||
if !status.is_success() {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
let sanitized = super::sanitize_api_error(&raw);
|
||||
tracing::error!(
|
||||
"Ollama error response: status={} body_excerpt={}",
|
||||
status,
|
||||
sanitized
|
||||
);
|
||||
anyhow::bail!(
|
||||
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
|
||||
status,
|
||||
sanitized
|
||||
);
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
let sanitized = super::sanitize_api_error(&raw);
|
||||
tracing::error!(
|
||||
"Ollama response deserialization failed: {e}. body_excerpt={}",
|
||||
sanitized
|
||||
);
|
||||
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
Ok(chat_response)
|
||||
}
|
||||
|
||||
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
|
||||
///
|
||||
/// Handles quirky model behavior where tool calls are wrapped:
|
||||
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
|
||||
/// - `{"name": "tool.shell", "arguments": {...}}`
|
||||
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
|
||||
let formatted_calls: Vec<serde_json::Value> = tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
|
||||
|
||||
// Arguments must be a JSON string for parse_tool_calls compatibility
|
||||
let args_str =
|
||||
serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
serde_json::json!({
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": args_str
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"content": "",
|
||||
"tool_calls": formatted_calls
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Extract the actual tool name and arguments from potentially nested structures
|
||||
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
|
||||
let name = &tc.function.name;
|
||||
let args = &tc.function.arguments;
|
||||
|
||||
// Pattern 1: Nested tool_call wrapper (various malformed versions)
|
||||
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
|
||||
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
|
||||
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
|
||||
if name == "tool_call"
|
||||
|| name == "tool.call"
|
||||
|| name.starts_with("tool_call>")
|
||||
|| name.starts_with("tool_call<")
|
||||
{
|
||||
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
|
||||
let nested_args = args
|
||||
.get("arguments")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
tracing::debug!(
|
||||
"Unwrapped nested tool call: {} -> {} with args {:?}",
|
||||
name,
|
||||
nested_name,
|
||||
nested_args
|
||||
);
|
||||
return (nested_name.to_string(), nested_args);
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
|
||||
if let Some(stripped) = name.strip_prefix("tool.") {
|
||||
return (stripped.to_string(), args.clone());
|
||||
}
|
||||
|
||||
// Pattern 3: Normal tool call
|
||||
(name.clone(), args.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
|
|||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
let response = self.send_request(messages, model, temperature).await?;
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let err = super::api_error("Ollama", response).await;
|
||||
anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)");
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = response.json().await?;
|
||||
Ok(chat_response.message.content)
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[crate::providers::ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let response = self.send_request(api_messages, model, temperature).await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
// This is a model quirk - it stopped after reasoning without producing output
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
// Return a message indicating the model's thought process but no action
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
||||
// The model may return native tool_calls but we convert them to JSON format
|
||||
// that parse_tool_calls() understands
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -125,46 +352,6 @@ mod tests {
|
|||
assert_eq!(p.base_url, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system() {
|
||||
let req = ChatRequest {
|
||||
model: "llama3".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: "You are ZeroClaw".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
},
|
||||
],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.7 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"stream\":false"));
|
||||
assert!(json.contains("llama3"));
|
||||
assert!(json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.7"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "mistral".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "test".to_string(),
|
||||
}],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.0 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("mistral"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
|
|
@ -180,9 +367,98 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_multiline() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
||||
fn response_with_missing_content_defaults_to_empty() {
|
||||
let json = r#"{"message":{"role":"assistant"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.contains("line1"));
|
||||
assert!(resp.message.content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_thinking_field_extracts_content() {
|
||||
let json =
|
||||
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.content, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_tool_calls_parses_correctly() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_nested_tool_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool_call".into(),
|
||||
arguments: serde_json::json!({
|
||||
"name": "shell",
|
||||
"arguments": {"command": "date"}
|
||||
}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "date");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_prefixed_name() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool.shell".into(),
|
||||
arguments: serde_json::json!({"command": "ls"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "ls");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_normal_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "file_read".into(),
|
||||
arguments: serde_json::json!({"path": "/tmp/test"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "file_read");
|
||||
assert_eq!(args.get("path").unwrap(), "/tmp/test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_tool_calls_produces_valid_json() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tool_calls = vec![OllamaToolCall {
|
||||
id: Some("call_abc".into()),
|
||||
function: OllamaFunction {
|
||||
name: "shell".into(),
|
||||
arguments: serde_json::json!({"command": "date"}),
|
||||
},
|
||||
}];
|
||||
|
||||
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
|
||||
|
||||
assert!(parsed.get("tool_calls").is_some());
|
||||
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
|
||||
assert_eq!(calls.len(), 1);
|
||||
|
||||
let func = calls[0].get("function").unwrap();
|
||||
assert_eq!(func.get("name").unwrap(), "shell");
|
||||
// arguments should be a string (JSON-encoded)
|
||||
assert!(func.get("arguments").unwrap().is_string());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenAiProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
|||
}
|
||||
|
||||
impl OpenAiProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
|
|
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
|
|
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -330,20 +330,20 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
|
||||
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
|
||||
let p = OpenAiProvider::new(Some("openai-test-credential"));
|
||||
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = OpenAiProvider::new(None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_key() {
|
||||
let p = OpenAiProvider::new(Some(""));
|
||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
||||
assert_eq!(p.credential.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenRouterProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
|||
}
|
||||
|
||||
impl OpenRouterProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
|
|||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
||||
// This prevents the first real chat request from timing out on cold start.
|
||||
if let Some(api_key) = self.api_key.as_ref() {
|
||||
if let Some(credential) = self.credential.as_ref() {
|
||||
self.client
|
||||
.get("https://openrouter.ai/api/v1/auth/key")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
|
|
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
|
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
|
|
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
|
|
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
|
|
@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -494,14 +494,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let provider = OpenRouterProvider::new(Some("sk-or-123"));
|
||||
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
|
||||
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
|
||||
assert_eq!(
|
||||
provider.credential.as_deref(),
|
||||
Some("openrouter-test-credential")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
assert!(provider.api_key.is_none());
|
||||
assert!(provider.credential.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
|
|||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
for (name, provider) in &self.providers {
|
||||
tracing::info!(provider = name, "Warming up provider connection pool");
|
||||
if let Err(e) = provider.warmup().await {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
|
||||
if provider.warmup().await.is_err() {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
|
|||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
|
@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
|
|||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
|
|
|||
|
|
@ -193,6 +193,13 @@ pub enum StreamError {
|
|||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Query provider capabilities.
|
||||
///
|
||||
/// Default implementation returns minimal capabilities (no native tool calling).
|
||||
/// Providers should override this to declare their actual capabilities.
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities::default()
|
||||
}
|
||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||
///
|
||||
/// This is the preferred API for non-agentic direct interactions.
|
||||
|
|
@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
|
|||
|
||||
/// Whether provider supports native tool calls over API.
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
false
|
||||
self.capabilities().native_tool_calling
|
||||
}
|
||||
|
||||
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
||||
|
|
@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
struct CapabilityMockProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CapabilityMockProvider {
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("ok".into())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_message_constructors() {
|
||||
let sys = ChatMessage::system("Be helpful");
|
||||
|
|
@ -398,4 +426,32 @@ mod tests {
|
|||
let json = serde_json::to_string(&tool_result).unwrap();
|
||||
assert!(json.contains("\"type\":\"ToolResults\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_capabilities_default() {
|
||||
let caps = ProviderCapabilities::default();
|
||||
assert!(!caps.native_tool_calling);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_capabilities_equality() {
|
||||
let caps1 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
};
|
||||
let caps2 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
};
|
||||
let caps3 = ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
};
|
||||
|
||||
assert_eq!(caps1, caps2);
|
||||
assert_ne!(caps1, caps3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_native_tools_reflects_capabilities_default_mapping() {
|
||||
let provider = CapabilityMockProvider;
|
||||
assert!(provider.supports_native_tools());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue