Merge branch 'main' into pr-484-clean

This commit is contained in:
Will Sarg 2026-02-17 08:54:24 -05:00 committed by GitHub
commit ee05d62ce4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
90 changed files with 6937 additions and 1403 deletions

View file

@ -106,17 +106,17 @@ struct NativeContentIn {
}
impl AnthropicProvider {
pub fn new(api_key: Option<&str>) -> Self {
Self::with_base_url(api_key, None)
pub fn new(credential: Option<&str>) -> Self {
Self::with_base_url(credential, None)
}
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
let base_url = base_url
.map(|u| u.trim_end_matches('/'))
.unwrap_or("https://api.anthropic.com")
.to_string();
Self {
credential: api_key
credential: credential
.map(str::trim)
.filter(|k| !k.is_empty())
.map(ToString::to_string),
@ -410,9 +410,9 @@ mod tests {
#[test]
fn creates_with_key() {
let p = AnthropicProvider::new(Some("sk-ant-test123"));
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
assert_eq!(p.base_url, "https://api.anthropic.com");
}
@ -431,17 +431,19 @@ mod tests {
#[test]
fn creates_with_whitespace_key() {
let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
}
#[test]
fn creates_with_custom_base_url() {
let p =
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
let p = AnthropicProvider::with_base_url(
Some("anthropic-credential"),
Some("https://api.example.com"),
);
assert_eq!(p.base_url, "https://api.example.com");
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
}
#[test]

View file

@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
pub struct OpenAiCompatibleProvider {
pub(crate) name: String,
pub(crate) base_url: String,
pub(crate) api_key: Option<String>,
pub(crate) credential: Option<String>,
pub(crate) auth_header: AuthStyle,
/// When false, do not fall back to /v1/responses on chat completions 404.
/// GLM/Zhipu does not support the responses API.
@ -37,11 +37,16 @@ pub enum AuthStyle {
}
impl OpenAiCompatibleProvider {
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
pub fn new(
name: &str,
base_url: &str,
credential: Option<&str>,
auth_style: AuthStyle,
) -> Self {
Self {
name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
auth_header: auth_style,
supports_responses_fallback: true,
client: Client::builder()
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
pub fn new_no_responses_fallback(
name: &str,
base_url: &str,
api_key: Option<&str>,
credential: Option<&str>,
auth_style: AuthStyle,
) -> Self {
Self {
name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
auth_header: auth_style,
supports_responses_fallback: false,
client: Client::builder()
@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
fn apply_auth_header(
&self,
req: reqwest::RequestBuilder,
api_key: &str,
credential: &str,
) -> reqwest::RequestBuilder {
match &self.auth_header {
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
AuthStyle::XApiKey => req.header("x-api-key", api_key),
AuthStyle::Custom(header) => req.header(header, api_key),
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
AuthStyle::XApiKey => req.header("x-api-key", credential),
AuthStyle::Custom(header) => req.header(header, credential),
}
}
async fn chat_via_responses(
&self,
api_key: &str,
credential: &str,
system_prompt: Option<&str>,
message: &str,
model: &str,
@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
let url = self.responses_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key)
.apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key)
.apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
return self
.chat_via_responses(api_key, system_prompt, message, model)
.chat_via_responses(credential, system_prompt, message, model)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key)
.apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
if let Some(user_msg) = last_user {
return self
.chat_via_responses(
api_key,
credential,
system.map(|m| m.content.as_str()),
&user_msg.content,
model,
@ -791,16 +796,20 @@ mod tests {
#[test]
fn creates_with_key() {
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
let p = make_provider(
"venice",
"https://api.venice.ai",
Some("venice-test-credential"),
);
assert_eq!(p.name, "venice");
assert_eq!(p.base_url, "https://api.venice.ai");
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
}
#[test]
fn creates_without_key() {
let p = make_provider("test", "https://example.com", None);
assert!(p.api_key.is_none());
assert!(p.credential.is_none());
}
#[test]
@ -894,6 +903,7 @@ mod tests {
make_provider("Groq", "https://api.groq.com/openai", None),
make_provider("Mistral", "https://api.mistral.ai", None),
make_provider("xAI", "https://api.x.ai", None),
make_provider("Astrai", "https://as-trai.com/v1", None),
];
for p in providers {

705
src/providers/copilot.rs Normal file
View file

@ -0,0 +1,705 @@
//! GitHub Copilot provider with OAuth device-flow authentication.
//!
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
//! then exchanges the OAuth token for short-lived Copilot API keys.
//! Tokens are cached to disk and auto-refreshed.
//!
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
//! and other third-party Copilot integrations. The Copilot token endpoint is
//! private; there is no public OAuth scope or app registration for it.
//! GitHub could change or revoke this at any time, which would break all
//! third-party integrations simultaneously.
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::warn;
/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const DEFAULT_API: &str = "https://api.githubcopilot.com";
// ── Token types ──────────────────────────────────────────────────
/// GitHub's response to a device-code request.
///
/// `interval` and `expires_in` are optional in the wire format; serde
/// fills in conservative defaults when they are absent.
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    // Opaque code used when polling the access-token endpoint.
    device_code: String,
    // Short human-readable code the user enters in the browser.
    user_code: String,
    // URL the user must visit to authorize this device.
    verification_uri: String,
    // Minimum number of seconds to wait between polls.
    #[serde(default = "default_interval")]
    interval: u64,
    // Lifetime of the device code, in seconds.
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}

/// Serde default for `DeviceCodeResponse::interval` (seconds).
fn default_interval() -> u64 {
    5
}

/// Serde default for `DeviceCodeResponse::expires_in` (seconds).
fn default_expires_in() -> u64 {
    900
}

/// GitHub's response when polling for an OAuth access token.
/// Either field may be absent while authorization is still pending.
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    access_token: Option<String>,
    // Error code such as "authorization_pending", "slow_down",
    // or "expired_token".
    error: Option<String>,
}

/// Short-lived Copilot API key returned by the token-exchange endpoint.
/// Also serves as the on-disk cache format (`api-key.json`).
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
    token: String,
    // Unix timestamp (seconds) after which the key is no longer valid.
    expires_at: i64,
    // Optional per-account endpoint overrides.
    #[serde(default)]
    endpoints: Option<ApiEndpoints>,
}

/// Endpoint overrides embedded in the token-exchange response.
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
    // Base URL for chat requests; callers fall back to `DEFAULT_API`
    // when this is absent.
    api: Option<String>,
}

/// In-memory copy of the current API key, guarded by `refresh_lock`.
struct CachedApiKey {
    token: String,
    // Resolved chat API base URL.
    api_endpoint: String,
    // Unix timestamp (seconds) when `token` expires.
    expires_at: i64,
}
// ── Chat completions types ───────────────────────────────────────
/// Request body for the Copilot `/chat/completions` endpoint
/// (OpenAI-compatible wire format).
#[derive(Debug, Serialize)]
struct ApiChatRequest {
    model: String,
    messages: Vec<ApiMessage>,
    temperature: f64,
    // Native tool definitions; the field is omitted from the JSON
    // entirely when no tools are supplied.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    // Set to "auto" whenever `tools` is present (see send_chat_request).
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}

/// A single message in the OpenAI-style `messages` array.
#[derive(Debug, Serialize)]
struct ApiMessage {
    role: String,
    // Text content; None for assistant messages that only carry tool calls.
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Present only on `role: "tool"` result messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    // Present only on assistant messages that requested tool invocations.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}

/// Tool definition in OpenAI function-calling format.
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    // Always set to "function" by convert_tools.
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}

/// The function portion of a tool definition.
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    // Arbitrary JSON parameters object, copied from the tool spec.
    parameters: serde_json::Value,
}

/// A tool invocation requested by the model (or echoed back to it).
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    // Optional in the wire format; a UUID is synthesized when missing.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}

/// Function name plus its JSON-encoded arguments string.
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    // Raw JSON string, passed through without parsing.
    arguments: String,
}

/// Top-level response from `/chat/completions`.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
}

/// One completion choice; only the first is consumed.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}

/// The assistant message inside a choice. Both fields may be absent.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
// ── Provider ─────────────────────────────────────────────────────
/// GitHub Copilot provider with automatic OAuth and token refresh.
///
/// On first use, prompts the user to visit github.com/login/device.
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
/// automatically.
pub struct CopilotProvider {
    // Optional pre-supplied GitHub OAuth token; when set, the device
    // flow and on-disk access-token cache are skipped.
    github_token: Option<String>,
    /// Mutex ensures only one caller refreshes tokens at a time,
    /// preventing duplicate device flow prompts or redundant API calls.
    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
    // Shared HTTP client with request/connect timeouts set in `new`.
    http: Client,
    // Directory where the access token and API key are cached on disk.
    token_dir: PathBuf,
}
impl CopilotProvider {
/// Build a provider, resolving the on-disk token cache directory.
///
/// `github_token` is an optional pre-supplied GitHub OAuth token; when
/// absent or blank, authentication falls back to the cached token or
/// the interactive device flow. The token is trimmed so that
/// whitespace-only values are treated as "not set", matching how the
/// other providers normalize credentials.
pub fn new(github_token: Option<&str>) -> Self {
    let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
        .map(|dir| dir.config_dir().join("copilot"))
        .unwrap_or_else(|| {
            // Fall back to a user-specific temp directory to avoid
            // shared-directory symlink attacks.
            let user = std::env::var("USER")
                .or_else(|_| std::env::var("USERNAME"))
                .unwrap_or_else(|_| "unknown".to_string());
            std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
        });
    if let Err(err) = std::fs::create_dir_all(&token_dir) {
        warn!(
            "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
            token_dir
        );
    } else {
        #[cfg(unix)]
        {
            use std::os::unix::fs::PermissionsExt;
            // Tokens are secrets: keep the cache directory owner-only.
            if let Err(err) =
                std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
            {
                warn!(
                    "Failed to set Copilot token directory permissions on {:?}: {err}",
                    token_dir
                );
            }
        }
    }
    Self {
        // Trim before the emptiness check so a whitespace-only token is
        // treated as unset (consistent with the other providers'
        // credential handling).
        github_token: github_token
            .map(str::trim)
            .filter(|token| !token.is_empty())
            .map(String::from),
        refresh_lock: Arc::new(Mutex::new(None)),
        http: Client::builder()
            .timeout(Duration::from_secs(120))
            .connect_timeout(Duration::from_secs(10))
            .build()
            .unwrap_or_else(|_| Client::new()),
        token_dir,
    }
}
/// Required headers for Copilot API requests (editor identification).
// These mimic VS Code's Copilot extension build; see the module-level
// note about third-party use of VS Code's OAuth client ID.
const COPILOT_HEADERS: [(&str, &str); 4] = [
    ("Editor-Version", "vscode/1.85.1"),
    ("Editor-Plugin-Version", "copilot/1.155.0"),
    ("User-Agent", "GithubCopilot/1.155.0"),
    ("Accept", "application/json"),
];
/// Map zeroclaw tool specs into OpenAI-style `tools` payload entries.
/// Returns `None` when no tools were supplied.
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
    let specs = tools?;
    let mut converted = Vec::with_capacity(specs.len());
    for spec in specs {
        converted.push(NativeToolSpec {
            kind: "function".to_string(),
            function: NativeToolFunctionSpec {
                name: spec.name.clone(),
                description: spec.description.clone(),
                parameters: spec.parameters.clone(),
            },
        });
    }
    Some(converted)
}
/// Convert provider-agnostic chat history into Copilot wire messages.
///
/// Assistant messages whose `content` is a JSON object containing a
/// `tool_calls` array are re-emitted as native assistant tool-call
/// messages; `role == "tool"` messages encoded as JSON are unpacked
/// into `tool_call_id`/`content`. Anything that fails to parse falls
/// through to a plain text message with its role preserved.
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
    messages
        .iter()
        .map(|message| {
            if message.role == "assistant" {
                // Assistant turns that requested tools were serialized as
                // {"content": ..., "tool_calls": [...]} — restore them to
                // the native wire shape.
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                    if let Some(tool_calls_value) = value.get("tool_calls") {
                        if let Ok(parsed_calls) =
                            serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
                        {
                            let tool_calls = parsed_calls
                                .into_iter()
                                .map(|tool_call| NativeToolCall {
                                    id: Some(tool_call.id),
                                    kind: Some("function".to_string()),
                                    function: NativeFunctionCall {
                                        name: tool_call.name,
                                        arguments: tool_call.arguments,
                                    },
                                })
                                .collect::<Vec<_>>();
                            // Textual content may legitimately be absent on
                            // a pure tool-call turn.
                            let content = value
                                .get("content")
                                .and_then(serde_json::Value::as_str)
                                .map(ToString::to_string);
                            return ApiMessage {
                                role: "assistant".to_string(),
                                content,
                                tool_call_id: None,
                                tool_calls: Some(tool_calls),
                            };
                        }
                    }
                }
            }
            if message.role == "tool" {
                // Tool results were serialized as
                // {"tool_call_id": ..., "content": ...}.
                if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                    let tool_call_id = value
                        .get("tool_call_id")
                        .and_then(serde_json::Value::as_str)
                        .map(ToString::to_string);
                    let content = value
                        .get("content")
                        .and_then(serde_json::Value::as_str)
                        .map(ToString::to_string);
                    return ApiMessage {
                        role: "tool".to_string(),
                        content,
                        tool_call_id,
                        tool_calls: None,
                    };
                }
            }
            // Plain system/user text, or payloads that didn't parse as JSON.
            ApiMessage {
                role: message.role.clone(),
                content: Some(message.content.clone()),
                tool_call_id: None,
                tool_calls: None,
            }
        })
        .collect()
}
/// Send a chat completions request with required Copilot headers.
///
/// Resolves a valid API key (refreshing it if necessary), POSTs the
/// OpenAI-style request to `<endpoint>/chat/completions`, and converts
/// the first choice into a provider-level response.
///
/// # Errors
/// Fails on authentication problems, transport errors, non-success
/// HTTP statuses, or an empty `choices` array.
async fn send_chat_request(
    &self,
    messages: Vec<ApiMessage>,
    tools: Option<&[ToolSpec]>,
    model: &str,
    temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
    let (token, endpoint) = self.get_api_key().await?;
    let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
    let native_tools = Self::convert_tools(tools);
    let request = ApiChatRequest {
        model: model.to_string(),
        messages,
        temperature,
        // Let the model decide when to call tools, but only if any exist.
        tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
        tools: native_tools,
    };
    let mut req = self
        .http
        .post(&url)
        .header("Authorization", format!("Bearer {token}"))
        .json(&request);
    // Attach the editor-identification headers (see `COPILOT_HEADERS`).
    for (header, value) in &Self::COPILOT_HEADERS {
        req = req.header(*header, *value);
    }
    let response = req.send().await?;
    if !response.status().is_success() {
        return Err(super::api_error("GitHub Copilot", response).await);
    }
    let api_response: ApiChatResponse = response.json().await?;
    let choice = api_response
        .choices
        .into_iter()
        .next()
        .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
    let tool_calls = choice
        .message
        .tool_calls
        .unwrap_or_default()
        .into_iter()
        .map(|tool_call| ProviderToolCall {
            // The call id is optional in the wire format; synthesize one
            // so callers can still correlate tool results.
            id: tool_call
                .id
                .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
            name: tool_call.function.name,
            arguments: tool_call.function.arguments,
        })
        .collect();
    Ok(ProviderChatResponse {
        text: choice.message.content,
        tool_calls,
    })
}
/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
/// Uses a Mutex to ensure only one caller refreshes at a time.
///
/// Returns `(api_key, api_endpoint)`. Resolution order: in-memory
/// cache, then on-disk cache, then a full token exchange (running the
/// device flow first if no GitHub token is available). Keys within
/// 120 seconds of expiry are treated as already expired.
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
    // Lock is held for the whole refresh so concurrent callers cannot
    // trigger duplicate exchanges or device-flow prompts.
    let mut cached = self.refresh_lock.lock().await;
    if let Some(cached_key) = cached.as_ref() {
        // 120s skew buffer avoids using a key that expires mid-request.
        if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
            return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
        }
    }
    // Second-level cache: key persisted by a previous process run.
    if let Some(info) = self.load_api_key_from_disk().await {
        if chrono::Utc::now().timestamp() + 120 < info.expires_at {
            let endpoint = info
                .endpoints
                .as_ref()
                .and_then(|e| e.api.clone())
                .unwrap_or_else(|| DEFAULT_API.to_string());
            let token = info.token;
            *cached = Some(CachedApiKey {
                token: token.clone(),
                api_endpoint: endpoint.clone(),
                expires_at: info.expires_at,
            });
            return Ok((token, endpoint));
        }
    }
    // Both caches stale: full exchange (may prompt for device login).
    let access_token = self.get_github_access_token().await?;
    let api_key_info = self.exchange_for_api_key(&access_token).await?;
    self.save_api_key_to_disk(&api_key_info).await;
    let endpoint = api_key_info
        .endpoints
        .as_ref()
        .and_then(|e| e.api.clone())
        .unwrap_or_else(|| DEFAULT_API.to_string());
    *cached = Some(CachedApiKey {
        token: api_key_info.token.clone(),
        api_endpoint: endpoint.clone(),
        expires_at: api_key_info.expires_at,
    });
    Ok((api_key_info.token, endpoint))
}
/// Get a GitHub access token from config, cache, or device flow.
///
/// Order: explicit token passed to `new`, then the on-disk
/// `access-token` file, then an interactive device-code login (whose
/// result is cached to disk for next time).
async fn get_github_access_token(&self) -> anyhow::Result<String> {
    if let Some(token) = &self.github_token {
        return Ok(token.clone());
    }
    let access_token_path = self.token_dir.join("access-token");
    if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
        // Ignore empty/whitespace-only cache files.
        let token = cached.trim();
        if !token.is_empty() {
            return Ok(token.to_string());
        }
    }
    let token = self.device_code_login().await?;
    // Best-effort cache; write errors are logged inside write_file_secure.
    write_file_secure(&access_token_path, &token).await;
    Ok(token)
}
/// Run GitHub OAuth device code flow.
///
/// Requests a device/user code pair, prints the verification URL and
/// code to stderr, then polls the token endpoint until the user
/// authorizes, the code expires, or GitHub reports a fatal error.
async fn device_code_login(&self) -> anyhow::Result<String> {
    let response: DeviceCodeResponse = self
        .http
        .post(GITHUB_DEVICE_CODE_URL)
        .header("Accept", "application/json")
        .json(&serde_json::json!({
            "client_id": GITHUB_CLIENT_ID,
            "scope": "read:user"
        }))
        .send()
        .await?
        .error_for_status()?
        .json()
        .await?;
    // GitHub's interval is a minimum; never poll faster than every 5s.
    let mut poll_interval = Duration::from_secs(response.interval.max(5));
    let expires_in = response.expires_in.max(1);
    let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
    eprintln!(
        "\nGitHub Copilot authentication is required.\n\
        Visit: {}\n\
        Code: {}\n\
        Waiting for authorization...\n",
        response.verification_uri, response.user_code
    );
    while tokio::time::Instant::now() < expires_at {
        // Wait the required interval before every poll.
        tokio::time::sleep(poll_interval).await;
        let token_response: AccessTokenResponse = self
            .http
            .post(GITHUB_ACCESS_TOKEN_URL)
            .header("Accept", "application/json")
            .json(&serde_json::json!({
                "client_id": GITHUB_CLIENT_ID,
                "device_code": response.device_code,
                "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
            }))
            .send()
            .await?
            .json()
            .await?;
        if let Some(token) = token_response.access_token {
            eprintln!("Authentication succeeded.\n");
            return Ok(token);
        }
        match token_response.error.as_deref() {
            // Back off an extra 5s per the device-flow `slow_down` signal.
            Some("slow_down") => {
                poll_interval += Duration::from_secs(5);
            }
            // Not authorized yet (or field omitted) — keep polling.
            Some("authorization_pending") | None => {}
            Some("expired_token") => {
                anyhow::bail!("GitHub device authorization expired")
            }
            Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
        }
    }
    anyhow::bail!("Timed out waiting for GitHub authorization")
}
/// Exchange a GitHub access token for a Copilot API key.
///
/// On 401/403 the cached access token is deleted so the next attempt
/// re-runs the device flow instead of retrying a rejected token.
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
    let mut request = self.http.get(GITHUB_API_KEY_URL);
    for (header, value) in &Self::COPILOT_HEADERS {
        request = request.header(*header, *value);
    }
    // Note: `token` authorization scheme, not `Bearer`.
    request = request.header("Authorization", format!("token {access_token}"));
    let response = request.send().await?;
    if !response.status().is_success() {
        let status = response.status();
        let body = response.text().await.unwrap_or_default();
        // Scrub secret-like tokens before surfacing the error body.
        let sanitized = super::sanitize_api_error(&body);
        if status.as_u16() == 401 || status.as_u16() == 403 {
            // Token rejected: drop the on-disk cache. Deletion failure
            // is ignored — the exchange error below is what matters.
            let access_token_path = self.token_dir.join("access-token");
            tokio::fs::remove_file(&access_token_path).await.ok();
        }
        anyhow::bail!(
            "Failed to get Copilot API key ({status}): {sanitized}. \
            Ensure your GitHub account has an active Copilot subscription."
        );
    }
    let info: ApiKeyInfo = response.json().await?;
    Ok(info)
}
/// Read the cached Copilot API key from `api-key.json`, if present
/// and parseable; any read or parse failure yields `None`.
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
    let cache_path = self.token_dir.join("api-key.json");
    match tokio::fs::read_to_string(&cache_path).await {
        Ok(raw) => serde_json::from_str(&raw).ok(),
        Err(_) => None,
    }
}
/// Persist the Copilot API key to `api-key.json` with restrictive
/// permissions. Serialization failure silently skips the write —
/// caching is best-effort.
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
    let Ok(serialized) = serde_json::to_string_pretty(info) else {
        return;
    };
    let cache_path = self.token_dir.join("api-key.json");
    write_file_secure(&cache_path, &serialized).await;
}
}
/// Write a file with 0600 permissions (owner read/write only).
/// Uses `spawn_blocking` to avoid blocking the async runtime.
///
/// Best-effort: failures are logged via `warn!` rather than returned,
/// since callers treat token caching as optional.
async fn write_file_secure(path: &Path, content: &str) {
    // Clone into owned values so they can move into the blocking task.
    let path = path.to_path_buf();
    let content = content.to_string();
    let result = tokio::task::spawn_blocking(move || {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
            // mode(0o600) only applies when the file is newly created;
            // the explicit set_permissions below also tightens a file
            // that already existed with looser permissions.
            let mut file = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&path)?;
            file.write_all(content.as_bytes())?;
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
            Ok::<(), std::io::Error>(())
        }
        #[cfg(not(unix))]
        {
            // No Unix permission bits available; plain write.
            std::fs::write(&path, &content)?;
            Ok::<(), std::io::Error>(())
        }
    })
    .await;
    match result {
        Ok(Ok(())) => {}
        Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
        // JoinError: the blocking task panicked or was cancelled.
        Err(err) => warn!("Failed to spawn blocking write: {err}"),
    }
}
#[async_trait]
impl Provider for CopilotProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let mut messages = Vec::new();
if let Some(system) = system_prompt {
messages.push(ApiMessage {
role: "system".to_string(),
content: Some(system.to_string()),
tool_call_id: None,
tool_calls: None,
});
}
messages.push(ApiMessage {
role: "user".to_string(),
content: Some(message.to_string()),
tool_call_id: None,
tool_calls: None,
});
let response = self
.send_chat_request(messages, None, model, temperature)
.await?;
Ok(response.text.unwrap_or_default())
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let response = self
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
.await?;
Ok(response.text.unwrap_or_default())
}
async fn chat(
&self,
request: ProviderChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
self.send_chat_request(
Self::convert_messages(request.messages),
request.tools,
model,
temperature,
)
.await
}
fn supports_native_tools(&self) -> bool {
true
}
async fn warmup(&self) -> anyhow::Result<()> {
let _ = self.get_api_key().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    // Constructor stores a non-empty token verbatim.
    #[test]
    fn new_without_token() {
        let provider = CopilotProvider::new(None);
        assert!(provider.github_token.is_none());
    }
    #[test]
    fn new_with_token() {
        let provider = CopilotProvider::new(Some("ghp_test"));
        assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
    }
    // An empty string must not be kept as a usable token.
    #[test]
    fn empty_token_treated_as_none() {
        let provider = CopilotProvider::new(Some(""));
        assert!(provider.github_token.is_none());
    }
    // No API key should be cached in memory before the first request.
    #[tokio::test]
    async fn cache_starts_empty() {
        let provider = CopilotProvider::new(None);
        let cached = provider.refresh_lock.lock().await;
        assert!(cached.is_none());
    }
    // The editor-identification header set must contain these fields.
    #[test]
    fn copilot_headers_include_required_fields() {
        let headers = CopilotProvider::COPILOT_HEADERS;
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Version"));
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Plugin-Version"));
        assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
    }
    // Serde fallbacks used when GitHub omits interval/expiry fields.
    #[test]
    fn default_interval_and_expiry() {
        assert_eq!(default_interval(), 5);
        assert_eq!(default_expires_in(), 900);
    }
    #[test]
    fn supports_native_tools() {
        let provider = CopilotProvider::new(None);
        assert!(provider.supports_native_tools());
    }
}

View file

@ -1,5 +1,6 @@
pub mod anthropic;
pub mod compatible;
pub mod copilot;
pub mod gemini;
pub mod ollama;
pub mod openai;
@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
/// Scrub known secret-like token prefixes from provider error strings.
///
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`.
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
/// `ghu_`, and `github_pat_`.
pub fn scrub_secret_patterns(input: &str) -> String {
const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
const PREFIXES: [&str; 7] = [
"sk-",
"xoxb-",
"xoxp-",
"ghp_",
"gho_",
"ghu_",
"github_pat_",
];
let mut scrubbed = input.to_string();
@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
///
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
return Some(key.to_string());
fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
if let Some(raw_override) = credential_override {
let trimmed_override = raw_override.trim();
if !trimmed_override.is_empty() {
return Some(trimmed_override.to_owned());
}
}
let provider_env_candidates: Vec<&str> = match name {
@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
"astrai" => vec!["ASTRAI_API_KEY"],
_ => vec![],
};
@ -182,19 +196,28 @@ fn parse_custom_provider_url(
}
}
/// Factory: create the right provider from config
#[allow(clippy::too_many_lines)]
/// Factory: create the right provider from config (without custom URL)
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
let resolved_key = resolve_api_key(name, api_key);
let key = resolved_key.as_deref();
create_provider_with_url(name, api_key, None)
}
/// Factory: create the right provider from config with optional custom base URL
#[allow(clippy::too_many_lines)]
pub fn create_provider_with_url(
name: &str,
api_key: Option<&str>,
api_url: Option<&str>,
) -> anyhow::Result<Box<dyn Provider>> {
let resolved_credential = resolve_provider_credential(name, api_key);
#[allow(clippy::option_as_ref_deref)]
let key = resolved_credential.as_ref().map(String::as_str);
match name {
// ── Primary providers (custom implementations) ───────
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
// Ollama is a local service that doesn't use API keys.
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
// Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
"gemini" | "google" | "google-gemini" => {
Ok(Box::new(gemini::GeminiProvider::new(key)))
}
@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
))),
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
"Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
))),
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
))),
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
"GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
))),
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new(
"NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer,
"copilot" | "github-copilot" => {
Ok(Box::new(copilot::CopilotProvider::new(api_key)))
},
"lmstudio" | "lm-studio" => {
let lm_studio_key = api_key
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or("lm-studio");
Ok(Box::new(OpenAiCompatibleProvider::new(
"LM Studio",
"http://localhost:1234/v1",
Some(lm_studio_key),
AuthStyle::Bearer,
)))
}
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
OpenAiCompatibleProvider::new(
"NVIDIA NIM",
"https://integrate.api.nvidia.com/v1",
key,
AuthStyle::Bearer,
),
)),
// ── AI inference routers ─────────────────────────────
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
))),
// ── Bring Your Own Provider (custom URL) ───────────
@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
pub fn create_resilient_provider(
primary_name: &str,
api_key: Option<&str>,
api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
) -> anyhow::Result<Box<dyn Provider>> {
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
providers.push((
primary_name.to_string(),
create_provider(primary_name, api_key)?,
create_provider_with_url(primary_name, api_key, api_url)?,
));
for fallback in &reliability.fallback_providers {
@ -340,21 +386,13 @@ pub fn create_resilient_provider(
continue;
}
if api_key.is_some() && fallback != "ollama" {
tracing::warn!(
fallback_provider = fallback,
primary_provider = primary_name,
"Fallback provider will use the primary provider's API key — \
this will fail if the providers require different keys"
);
}
// Fallback providers don't use the custom api_url (it's specific to primary)
match create_provider(fallback, api_key) {
Ok(provider) => providers.push((fallback.clone(), provider)),
Err(e) => {
Err(_error) => {
tracing::warn!(
fallback_provider = fallback,
"Ignoring invalid fallback provider: {e}"
"Ignoring invalid fallback provider during initialization"
);
}
}
@ -377,12 +415,13 @@ pub fn create_resilient_provider(
pub fn create_routed_provider(
primary_name: &str,
api_key: Option<&str>,
api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
model_routes: &[crate::config::ModelRouteConfig],
default_model: &str,
) -> anyhow::Result<Box<dyn Provider>> {
if model_routes.is_empty() {
return create_resilient_provider(primary_name, api_key, reliability);
return create_resilient_provider(primary_name, api_key, api_url, reliability);
}
// Collect unique provider names needed
@ -396,12 +435,19 @@ pub fn create_routed_provider(
// Create each provider (with its own resilience wrapper)
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
for name in &needed {
let key = model_routes
let routed_credential = model_routes
.iter()
.find(|r| &r.provider == name)
.and_then(|r| r.api_key.as_deref())
.or(api_key);
match create_resilient_provider(name, key, reliability) {
.and_then(|r| {
r.api_key.as_ref().and_then(|raw_key| {
let trimmed_key = raw_key.trim();
(!trimmed_key.is_empty()).then_some(trimmed_key)
})
});
let key = routed_credential.or(api_key);
// Only use api_url for the primary provider
let url = if name == primary_name { api_url } else { None };
match create_resilient_provider(name, key, url, reliability) {
Ok(provider) => providers.push((name.clone(), provider)),
Err(e) => {
if name == primary_name {
@ -409,7 +455,7 @@ pub fn create_routed_provider(
}
tracing::warn!(
provider = name.as_str(),
"Ignoring routed provider that failed to create: {e}"
"Ignoring routed provider that failed to initialize"
);
}
}
@ -441,27 +487,27 @@ mod tests {
use super::*;
#[test]
fn resolve_api_key_prefers_explicit_argument() {
let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
assert_eq!(resolved.as_deref(), Some("explicit-key"));
fn resolve_provider_credential_prefers_explicit_argument() {
let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
assert_eq!(resolved, Some("explicit-key".to_string()));
}
// ── Primary providers ────────────────────────────────────
#[test]
fn factory_openrouter() {
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
assert!(create_provider("openrouter", None).is_ok());
}
#[test]
fn factory_anthropic() {
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
}
#[test]
fn factory_openai() {
assert!(create_provider("openai", Some("sk-test")).is_ok());
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
}
#[test]
@ -556,6 +602,13 @@ mod tests {
assert!(create_provider("dashscope-us", Some("key")).is_ok());
}
#[test]
fn factory_lmstudio() {
    // Both spellings of the provider name must resolve to the same factory.
    for alias in ["lmstudio", "lm-studio"] {
        assert!(create_provider(alias, Some("key")).is_ok());
    }
    // LM Studio is a local server, so a missing credential is also fine.
    assert!(create_provider("lmstudio", None).is_ok());
}
// ── Extended ecosystem ───────────────────────────────────
#[test]
@ -614,6 +667,13 @@ mod tests {
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
}
// ── AI inference routers ─────────────────────────────────
#[test]
fn factory_astrai() {
    // The astrai router should be constructible with a credential.
    let created = create_provider("astrai", Some("sk-astrai-test"));
    assert!(created.is_ok());
}
// ── Custom / BYOP provider ─────────────────────────────
#[test]
@ -761,17 +821,33 @@ mod tests {
scheduler_retries: 2,
};
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
let provider = create_resilient_provider(
"openrouter",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_ok());
}
#[test]
fn resilient_provider_errors_for_invalid_primary() {
let reliability = crate::config::ReliabilityConfig::default();
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
let provider = create_resilient_provider(
"totally-invalid",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_err());
}
#[test]
fn ollama_with_custom_url() {
    // Ollama accepts an explicit base URL (e.g. a LAN host) and no API key.
    assert!(
        create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434")).is_ok()
    );
}
#[test]
fn factory_all_providers_create_successfully() {
let providers = [
@ -794,6 +870,7 @@ mod tests {
"qwen",
"qwen-intl",
"qwen-us",
"lmstudio",
"groq",
"mistral",
"xai",
@ -888,7 +965,7 @@ mod tests {
#[test]
fn sanitize_preserves_unicode_boundaries() {
let input = format!("{} sk-abcdef123", "こんにちは".repeat(80));
let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
let result = sanitize_api_error(&input);
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
assert!(!result.contains("sk-abcdef123"));
@ -900,4 +977,32 @@ mod tests {
let result = sanitize_api_error(input);
assert_eq!(result, input);
}
#[test]
fn scrub_github_personal_access_token() {
    // Classic PATs (ghp_...) must be redacted from error text.
    let sanitized = scrub_secret_patterns("auth failed with token ghp_abc123def456");
    assert_eq!(sanitized, "auth failed with token [REDACTED]");
}

#[test]
fn scrub_github_oauth_token() {
    // OAuth tokens (gho_...) must be redacted, including in Bearer headers.
    let sanitized = scrub_secret_patterns("Bearer gho_1234567890abcdef");
    assert_eq!(sanitized, "Bearer [REDACTED]");
}

#[test]
fn scrub_github_user_token() {
    // User-to-server tokens (ghu_...) must be redacted.
    let sanitized = scrub_secret_patterns("token ghu_sessiontoken123");
    assert_eq!(sanitized, "token [REDACTED]");
}

#[test]
fn scrub_github_fine_grained_pat() {
    // Fine-grained PATs (github_pat_...) must be redacted.
    let sanitized = scrub_secret_patterns("failed: github_pat_11AABBC_xyzzy789");
    assert_eq!(sanitized, "failed: [REDACTED]");
}
}

View file

@ -8,6 +8,8 @@ pub struct OllamaProvider {
client: Client,
}
// ─── Request Structures ───────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
@ -27,6 +29,8 @@ struct Options {
temperature: f64,
}
// ─── Response Structures ──────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
message: ResponseMessage,
@ -34,9 +38,30 @@ struct ApiChatResponse {
#[derive(Debug, Deserialize)]
struct ResponseMessage {
#[serde(default)]
content: String,
#[serde(default)]
tool_calls: Vec<OllamaToolCall>,
/// Some models return a "thinking" field with internal reasoning
#[serde(default)]
thinking: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OllamaToolCall {
id: Option<String>,
function: OllamaFunction,
}
#[derive(Debug, Deserialize)]
struct OllamaFunction {
name: String,
#[serde(default)]
arguments: serde_json::Value,
}
// ─── Implementation ───────────────────────────────────────────────────────────
impl OllamaProvider {
pub fn new(base_url: Option<&str>) -> Self {
Self {
@ -45,12 +70,145 @@ impl OllamaProvider {
.trim_end_matches('/')
.to_string(),
client: Client::builder()
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
.timeout(std::time::Duration::from_secs(300))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| Client::new()),
}
}
/// Send a non-streaming chat request to the local Ollama server and return
/// the deserialized response.
///
/// HTTP failures are surfaced with a sanitized body excerpt and a hint that
/// the local server may not be running; deserialization failures get their
/// own log line so malformed model output is easy to diagnose.
async fn send_request(
    &self,
    messages: Vec<Message>,
    model: &str,
    temperature: f64,
) -> anyhow::Result<ApiChatResponse> {
    let endpoint = format!("{}/api/chat", self.base_url);
    let request = ChatRequest {
        model: model.to_string(),
        messages,
        stream: false,
        options: Options { temperature },
    };

    tracing::debug!(
        "Ollama request: url={} model={} message_count={} temperature={}",
        endpoint,
        model,
        request.messages.len(),
        temperature
    );

    let response = self.client.post(&endpoint).json(&request).send().await?;
    let status = response.status();
    tracing::debug!("Ollama response status: {}", status);

    let payload = response.bytes().await?;
    tracing::debug!("Ollama response body length: {} bytes", payload.len());

    if !status.is_success() {
        let raw = String::from_utf8_lossy(&payload);
        let sanitized = super::sanitize_api_error(&raw);
        tracing::error!(
            "Ollama error response: status={} body_excerpt={}",
            status,
            sanitized
        );
        anyhow::bail!(
            "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
            status,
            sanitized
        );
    }

    // Report parse failures with a sanitized excerpt of the raw body, since
    // "invalid JSON" alone is rarely actionable.
    serde_json::from_slice::<ApiChatResponse>(&payload).map_err(|e| {
        let raw = String::from_utf8_lossy(&payload);
        let sanitized = super::sanitize_api_error(&raw);
        tracing::error!(
            "Ollama response deserialization failed: {e}. body_excerpt={}",
            sanitized
        );
        anyhow::anyhow!("Failed to parse Ollama response: {e}")
    })
}
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
///
/// Handles quirky model behavior where tool calls are wrapped:
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
/// - `{"name": "tool.shell", "arguments": {...}}`
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
    let mut formatted_calls = Vec::with_capacity(tool_calls.len());
    for call in tool_calls {
        let (tool_name, tool_args) = self.extract_tool_name_and_args(call);
        // parse_tool_calls expects "arguments" as a JSON-encoded string,
        // not a nested object.
        let encoded_args =
            serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
        formatted_calls.push(serde_json::json!({
            "id": call.id,
            "type": "function",
            "function": {
                "name": tool_name,
                "arguments": encoded_args
            }
        }));
    }

    serde_json::json!({
        "content": "",
        "tool_calls": formatted_calls
    })
    .to_string()
}
/// Extract the actual tool name and arguments from potentially nested structures
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
    let raw_name = &tc.function.name;
    let raw_args = &tc.function.arguments;

    // Pattern 1: some models wrap the real call in a "tool_call" envelope,
    // occasionally with mangled suffixes:
    //   {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
    //   {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
    //   {"name": "tool.call", "arguments": {"name": "shell", ...}}
    let looks_wrapped = raw_name == "tool_call"
        || raw_name == "tool.call"
        || raw_name.starts_with("tool_call>")
        || raw_name.starts_with("tool_call<");
    if looks_wrapped {
        if let Some(inner_name) = raw_args.get("name").and_then(serde_json::Value::as_str) {
            let inner_args = raw_args
                .get("arguments")
                .cloned()
                .unwrap_or(serde_json::json!({}));
            tracing::debug!(
                "Unwrapped nested tool call: {} -> {} with args {:?}",
                raw_name,
                inner_name,
                inner_args
            );
            return (inner_name.to_string(), inner_args);
        }
    }

    // Pattern 2: a "tool." prefix (tool.shell, tool.file_read, ...) just
    // prefixes the real tool name.
    if let Some(bare_name) = raw_name.strip_prefix("tool.") {
        return (bare_name.to_string(), raw_args.clone());
    }

    // Pattern 3: already in canonical form.
    (raw_name.clone(), raw_args.clone())
}
}
#[async_trait]
@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
stream: false,
options: Options { temperature },
};
let response = self.send_request(messages, model, temperature).await?;
let url = format!("{}/api/chat", self.base_url);
let response = self.client.post(&url).json(&request).send().await?;
if !response.status().is_success() {
let err = super::api_error("Ollama", response).await;
anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)");
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
if !response.message.tool_calls.is_empty() {
tracing::debug!(
"Ollama returned {} tool call(s), formatting for loop parser",
response.message.tool_calls.len()
);
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
let chat_response: ApiChatResponse = response.json().await?;
Ok(chat_response.message.content)
// Plain text response
let content = response.message.content;
// Handle edge case: model returned only "thinking" with no content or tool calls
if content.is_empty() {
if let Some(thinking) = &response.message.thinking {
tracing::warn!(
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
if thinking.len() > 100 { &thinking[..100] } else { thinking }
);
return Ok(format!(
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
if thinking.len() > 200 { &thinking[..200] } else { thinking }
));
}
tracing::warn!("Ollama returned empty content with no tool calls");
}
Ok(content)
}
/// Chat against Ollama with a full message history.
///
/// Converts the caller's messages into Ollama's wire format, then either
/// forwards the model's tool calls (re-encoded for the loop's parser) or
/// returns the plain text content.
///
/// # Errors
/// Propagates transport/HTTP/parse errors from `send_request`.
async fn chat_with_history(
    &self,
    messages: &[crate::providers::ChatMessage],
    model: &str,
    temperature: f64,
) -> anyhow::Result<String> {
    // Truncate to at most `max_chars` characters without slicing inside a
    // UTF-8 code point. The previous `&thinking[..100]` byte-slice panicked
    // when byte 100 fell mid-character (e.g. non-ASCII reasoning text).
    fn truncate_chars(s: &str, max_chars: usize) -> &str {
        match s.char_indices().nth(max_chars) {
            Some((byte_idx, _)) => &s[..byte_idx],
            None => s,
        }
    }

    let api_messages: Vec<Message> = messages
        .iter()
        .map(|m| Message {
            role: m.role.clone(),
            content: m.content.clone(),
        })
        .collect();

    let response = self.send_request(api_messages, model, temperature).await?;

    // If model returned tool calls, format them for loop_.rs's parse_tool_calls
    if !response.message.tool_calls.is_empty() {
        tracing::debug!(
            "Ollama returned {} tool call(s), formatting for loop parser",
            response.message.tool_calls.len()
        );
        return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
    }

    // Plain text response
    let content = response.message.content;

    // Handle edge case: model returned only "thinking" with no content or tool calls
    // This is a model quirk - it stopped after reasoning without producing output
    if content.is_empty() {
        if let Some(thinking) = &response.message.thinking {
            tracing::warn!(
                "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
                truncate_chars(thinking, 100)
            );
            // Return a message indicating the model's thought process but no action
            return Ok(format!(
                "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
                truncate_chars(thinking, 200)
            ));
        }
        tracing::warn!("Ollama returned empty content with no tool calls");
    }

    Ok(content)
}
/// Report that this provider does not expose native tool calling.
///
/// Deliberately `false`: the agent loop drives tools through prompt-based
/// (XML-style) parsing. When the model emits native `tool_calls` anyway,
/// they are converted into the JSON shape `parse_tool_calls()` understands
/// (see `format_tool_calls_for_loop`), so no native path is needed.
fn supports_native_tools(&self) -> bool {
    false
}
}
// ─── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
@ -125,46 +352,6 @@ mod tests {
assert_eq!(p.base_url, "");
}
#[test]
fn request_serializes_with_system() {
let req = ChatRequest {
model: "llama3".to_string(),
messages: vec![
Message {
role: "system".to_string(),
content: "You are ZeroClaw".to_string(),
},
Message {
role: "user".to_string(),
content: "hello".to_string(),
},
],
stream: false,
options: Options { temperature: 0.7 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"stream\":false"));
assert!(json.contains("llama3"));
assert!(json.contains("system"));
assert!(json.contains("\"temperature\":0.7"));
}
#[test]
fn request_serializes_without_system() {
let req = ChatRequest {
model: "mistral".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: "test".to_string(),
}],
stream: false,
options: Options { temperature: 0.0 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("\"role\":\"system\""));
assert!(json.contains("mistral"));
}
#[test]
fn response_deserializes() {
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
@ -180,9 +367,98 @@ mod tests {
}
#[test]
fn response_with_multiline() {
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
fn response_with_missing_content_defaults_to_empty() {
let json = r#"{"message":{"role":"assistant"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.contains("line1"));
assert!(resp.message.content.is_empty());
}
#[test]
fn response_with_thinking_field_extracts_content() {
    // The optional "thinking" field must not interfere with normal content.
    let raw =
        r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
    let parsed: ApiChatResponse = serde_json::from_str(raw).unwrap();
    assert_eq!(parsed.message.content, "hello");
}

#[test]
fn response_with_tool_calls_parses_correctly() {
    // A tool-call response carries empty content plus a populated tool_calls list.
    let raw = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
    let parsed: ApiChatResponse = serde_json::from_str(raw).unwrap();
    assert_eq!(parsed.message.tool_calls.len(), 1);
    assert_eq!(parsed.message.tool_calls[0].function.name, "shell");
    assert!(parsed.message.content.is_empty());
}
#[test]
fn extract_tool_name_handles_nested_tool_call() {
    // A "tool_call" envelope must be unwrapped to the inner name/arguments.
    let provider = OllamaProvider::new(None);
    let wrapped = OllamaToolCall {
        id: Some("call_123".into()),
        function: OllamaFunction {
            name: "tool_call".into(),
            arguments: serde_json::json!({
                "name": "shell",
                "arguments": {"command": "date"}
            }),
        },
    };
    let (name, args) = provider.extract_tool_name_and_args(&wrapped);
    assert_eq!(name, "shell");
    assert_eq!(args.get("command").unwrap(), "date");
}

#[test]
fn extract_tool_name_handles_prefixed_name() {
    // A "tool." prefix is stripped, arguments pass through untouched.
    let provider = OllamaProvider::new(None);
    let prefixed = OllamaToolCall {
        id: Some("call_123".into()),
        function: OllamaFunction {
            name: "tool.shell".into(),
            arguments: serde_json::json!({"command": "ls"}),
        },
    };
    let (name, args) = provider.extract_tool_name_and_args(&prefixed);
    assert_eq!(name, "shell");
    assert_eq!(args.get("command").unwrap(), "ls");
}

#[test]
fn extract_tool_name_handles_normal_call() {
    // A canonical call is returned as-is.
    let provider = OllamaProvider::new(None);
    let plain = OllamaToolCall {
        id: Some("call_123".into()),
        function: OllamaFunction {
            name: "file_read".into(),
            arguments: serde_json::json!({"path": "/tmp/test"}),
        },
    };
    let (name, args) = provider.extract_tool_name_and_args(&plain);
    assert_eq!(name, "file_read");
    assert_eq!(args.get("path").unwrap(), "/tmp/test");
}
#[test]
fn format_tool_calls_produces_valid_json() {
    let provider = OllamaProvider::new(None);
    let calls = vec![OllamaToolCall {
        id: Some("call_abc".into()),
        function: OllamaFunction {
            name: "shell".into(),
            arguments: serde_json::json!({"command": "date"}),
        },
    }];

    // The formatted output must round-trip as JSON with a tool_calls array.
    let payload: serde_json::Value =
        serde_json::from_str(&provider.format_tool_calls_for_loop(&calls)).unwrap();
    let tool_calls = payload
        .get("tool_calls")
        .and_then(serde_json::Value::as_array)
        .expect("tool_calls array must be present");
    assert_eq!(tool_calls.len(), 1);

    let function = tool_calls[0].get("function").unwrap();
    assert_eq!(function.get("name").unwrap(), "shell");
    // arguments should be a string (JSON-encoded), not a nested object
    assert!(function.get("arguments").unwrap().is_string());
}
}

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenAiProvider {
api_key: Option<String>,
credential: Option<String>,
client: Client,
}
@ -110,9 +110,9 @@ struct NativeResponseMessage {
}
impl OpenAiProvider {
pub fn new(api_key: Option<&str>) -> Self {
pub fn new(credential: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.json(&request)
.send()
.await?;
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.json(&native_request)
.send()
.await?;
@ -330,20 +330,20 @@ mod tests {
#[test]
fn creates_with_key() {
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
let p = OpenAiProvider::new(Some("openai-test-credential"));
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
}
#[test]
fn creates_without_key() {
let p = OpenAiProvider::new(None);
assert!(p.api_key.is_none());
assert!(p.credential.is_none());
}
#[test]
fn creates_with_empty_key() {
let p = OpenAiProvider::new(Some(""));
assert_eq!(p.api_key.as_deref(), Some(""));
assert_eq!(p.credential.as_deref(), Some(""));
}
#[tokio::test]

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenRouterProvider {
api_key: Option<String>,
credential: Option<String>,
client: Client,
}
@ -110,9 +110,9 @@ struct NativeResponseMessage {
}
impl OpenRouterProvider {
pub fn new(api_key: Option<&str>) -> Self {
pub fn new(credential: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
async fn warmup(&self) -> anyhow::Result<()> {
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
// This prevents the first real chat request from timing out on cold start.
if let Some(api_key) = self.api_key.as_ref() {
if let Some(credential) = self.credential.as_ref() {
self.client
.get("https://openrouter.ai/api/v1/auth/key")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.send()
.await?
.error_for_status()?;
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref()
let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let mut messages = Vec::new();
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref()
let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let api_messages: Vec<Message> = messages
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -494,14 +494,17 @@ mod tests {
#[test]
fn creates_with_key() {
let provider = OpenRouterProvider::new(Some("sk-or-123"));
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
assert_eq!(
provider.credential.as_deref(),
Some("openrouter-test-credential")
);
}
#[test]
fn creates_without_key() {
let provider = OpenRouterProvider::new(None);
assert!(provider.api_key.is_none());
assert!(provider.credential.is_none());
}
#[tokio::test]

View file

@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
async fn warmup(&self) -> anyhow::Result<()> {
for (name, provider) in &self.providers {
tracing::info!(provider = name, "Warming up provider connection pool");
if let Err(e) = provider.warmup().await {
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
if provider.warmup().await.is_err() {
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
}
}
Ok(())
@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {e}",
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));
@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {e}",
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));

View file

@ -193,6 +193,13 @@ pub enum StreamError {
#[async_trait]
pub trait Provider: Send + Sync {
/// Query provider capabilities.
///
/// Default implementation returns minimal capabilities (no native tool calling).
/// Providers should override this to declare their actual capabilities;
/// the default `supports_native_tools` derives its answer from this value.
fn capabilities(&self) -> ProviderCapabilities {
    ProviderCapabilities::default()
}
/// Simple one-shot chat (single user message, no explicit system prompt).
///
/// This is the preferred API for non-agentic direct interactions.
@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
/// Whether provider supports native tool calls over API.
fn supports_native_tools(&self) -> bool {
false
self.capabilities().native_tool_calling
}
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
mod tests {
use super::*;
/// Minimal test double that advertises native tool calling, used to verify
/// the default `supports_native_tools` -> `capabilities` mapping.
struct CapabilityMockProvider;

#[async_trait]
impl Provider for CapabilityMockProvider {
    fn capabilities(&self) -> ProviderCapabilities {
        ProviderCapabilities {
            native_tool_calling: true,
        }
    }

    // Trivial chat implementation; capability tests never exercise it.
    async fn chat_with_system(
        &self,
        _system_prompt: Option<&str>,
        _message: &str,
        _model: &str,
        _temperature: f64,
    ) -> anyhow::Result<String> {
        Ok(String::from("ok"))
    }
}
#[test]
fn chat_message_constructors() {
let sys = ChatMessage::system("Be helpful");
@ -398,4 +426,32 @@ mod tests {
let json = serde_json::to_string(&tool_result).unwrap();
assert!(json.contains("\"type\":\"ToolResults\""));
}
#[test]
fn provider_capabilities_default() {
    // The default capability set disables native tool calling.
    assert!(!ProviderCapabilities::default().native_tool_calling);
}

#[test]
fn provider_capabilities_equality() {
    let with_tools = ProviderCapabilities {
        native_tool_calling: true,
    };
    let also_with_tools = ProviderCapabilities {
        native_tool_calling: true,
    };
    let without_tools = ProviderCapabilities {
        native_tool_calling: false,
    };
    assert_eq!(with_tools, also_with_tools);
    assert_ne!(with_tools, without_tools);
}

#[test]
fn supports_native_tools_reflects_capabilities_default_mapping() {
    // The trait's default supports_native_tools should mirror capabilities().
    assert!(CapabilityMockProvider.supports_native_tools());
}
}