From 49bb20f961613eaf78423badbb5af09deed9f901 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Edvard=20Sch=C3=B8yen?= <99178202+ecschoye@users.noreply.github.com> Date: Sun, 15 Feb 2026 14:32:33 -0500 Subject: [PATCH] fix(providers): use Bearer auth for Gemini CLI OAuth tokens * fix(providers): use Bearer auth for Gemini CLI OAuth tokens When credentials come from ~/.gemini/oauth_creds.json (Gemini CLI), send them as Authorization: Bearer header instead of ?key= query parameter. API keys from env vars or config continue using ?key=. Fixes #194 Co-Authored-By: Claude Opus 4.6 * refactor(gemini): harden OAuth bearer auth flow and tests * fix(gemini): granular auth source tracking and review fixes Build on chumyin's auth model refactor with: - Expand GeminiAuth to 4 variants (ExplicitKey/EnvGeminiKey/EnvGoogleKey/ OAuthToken) so auth_source() uses stored discriminant without re-reading env vars at call time - Add is_api_key()/credential() helpers on the enum - Upgrade expired OAuth token log from debug to warn - Add tests: provider_rejects_empty_key, auth_source_explicit_key, auth_source_none_without_credentials Co-Authored-By: Claude Opus 4.6 * style: apply rustfmt to fix CI lint failures Co-Authored-By: Claude Opus 4.6 --------- Co-authored-by: Claude Opus 4.6 Co-authored-by: root Co-authored-by: argenis de la rosa --- src/channels/mod.rs | 14 +- src/config/schema.rs | 3 +- src/identity.rs | 9 +- src/memory/sqlite.rs | 6 +- src/observability/traits.rs | 5 +- src/onboard/wizard.rs | 9 +- src/providers/anthropic.rs | 3 +- src/providers/compatible.rs | 20 ++- src/providers/gemini.rs | 309 ++++++++++++++++++++++++++++-------- src/providers/reliable.rs | 4 +- src/providers/router.rs | 31 ++-- src/skillforge/evaluate.rs | 25 ++- src/skillforge/mod.rs | 10 +- src/skillforge/scout.rs | 48 +++--- src/tools/browser.rs | 10 +- 15 files changed, 358 insertions(+), 148 deletions(-) diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 4d5a7b8..313398e 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -21,10 +21,10 @@ pub use traits::Channel; pub use whatsapp::WhatsAppChannel; use crate::config::Config; +use crate::identity; use crate::memory::{self, Memory}; use crate::providers::{self, Provider}; use crate::util::truncate_with_ellipsis; -use crate::identity; use anyhow::Result; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -205,7 +205,9 @@ pub fn build_system_prompt( } Err(e) => { // Log error but don't fail - fall back to OpenClaw - eprintln!("Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format."); + eprintln!( + "Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format." + ); load_openclaw_bootstrap_files(&mut prompt, workspace_dir); } } @@ -534,7 +536,13 @@ pub async fn start_channels(config: Config) -> Result<()> { )); } - let system_prompt = build_system_prompt(&workspace, &model, &tool_descs, &skills, Some(&config.identity)); + let system_prompt = build_system_prompt( + &workspace, + &model, + &tool_descs, + &skills, + Some(&config.identity), + ); if !skills.is_empty() { println!( diff --git a/src/config/schema.rs b/src/config/schema.rs index a866880..84496ab 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1215,7 +1215,6 @@ default_temperature = 0.7 let _ = fs::remove_dir_all(&dir); } - #[test] fn config_save_atomic_cleanup() { let dir = @@ -1920,7 +1919,7 @@ default_temperature = 0.7 fn env_override_temperature_out_of_range_ignored() { // Clean up any leftover env vars from other tests std::env::remove_var("ZEROCLAW_TEMPERATURE"); - + let mut config = Config::default(); let original_temp = config.default_temperature; diff --git a/src/identity.rs b/src/identity.rs index 45fe630..4217f4a 100644 --- a/src/identity.rs +++ b/src/identity.rs @@ -183,8 +183,8 @@ pub fn load_aieos_identity( // Fall back to aieos_inline if let Some(ref inline) = config.aieos_inline { - let identity: AieosIdentity = serde_json::from_str(inline) - .context("Failed to parse inline AIEOS JSON")?; + let identity: AieosIdentity = + serde_json::from_str(inline).context("Failed to parse inline AIEOS JSON")?; return Ok(Some(identity)); } @@ -544,10 +544,7 @@ mod tests { // Check motivations let mot = identity.motivations.unwrap(); - assert_eq!( - mot.core_drive.unwrap(), - "Help users accomplish their goals" - ); + assert_eq!(mot.core_drive.unwrap(), "Help users accomplish their goals"); // Check capabilities let cap = identity.capabilities.unwrap(); diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index b56f337..73abff5 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -138,7 +138,11 @@ impl SqliteMemory { // First 8 bytes → 16 hex chars, matching previous format length format!( "{:016x}", - u64::from_be_bytes(hash[..8].try_into().expect("SHA-256 always produces >= 8 bytes")) + u64::from_be_bytes( + hash[..8] + .try_into() + .expect("SHA-256 always produces >= 8 bytes") + ) ) } diff --git a/src/observability/traits.rs b/src/observability/traits.rs index 3a2c5ae..08ac2ea 100644 --- a/src/observability/traits.rs +++ b/src/observability/traits.rs @@ -51,7 +51,10 @@ pub trait Observer: Send + Sync + 'static { fn name(&self) -> &str; /// Downcast to `Any` for backend-specific operations - fn as_any(&self) -> &dyn std::any::Any where Self: Sized { + fn as_any(&self) -> &dyn std::any::Any + where + Self: Sized, + { self } } diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index ec95aa3..75e253e 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -1734,9 +1734,8 @@ fn setup_channels() -> Result { } }; - let nickname: String = Input::new() - .with_prompt(" Bot nickname") - .interact_text()?; + let nickname: String = + Input::new().with_prompt(" Bot nickname").interact_text()?; if nickname.trim().is_empty() { println!(" {} Skipped — nickname required", style("→").dim()); @@ -1779,7 +1778,9 @@ fn setup_channels() -> Result { }; if allowed_users.is_empty() { - print_bullet("⚠️ Empty allowlist — only you can interact. Add nicknames above."); + print_bullet( + "⚠️ Empty allowlist — only you can interact. Add nicknames above.", + ); } println!(); diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index c81bac0..3202a01 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -154,7 +154,8 @@ mod tests { #[test] fn creates_with_custom_base_url() { - let p = AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com")); + let p = + AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com")); assert_eq!(p.base_url, "https://api.example.com"); assert_eq!(p.credential.as_deref(), Some("sk-ant-test")); } diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 4d8f868..7c2eeec 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -452,14 +452,20 @@ mod tests { fn chat_completions_url_standard_openai() { // Standard OpenAI-compatible providers get /chat/completions appended let p = make_provider("openai", "https://api.openai.com/v1", None); - assert_eq!(p.chat_completions_url(), "https://api.openai.com/v1/chat/completions"); + assert_eq!( + p.chat_completions_url(), + "https://api.openai.com/v1/chat/completions" + ); } #[test] fn chat_completions_url_trailing_slash() { // Trailing slash is stripped, then /chat/completions appended let p = make_provider("test", "https://api.example.com/v1/", None); - assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions"); + assert_eq!( + p.chat_completions_url(), + "https://api.example.com/v1/chat/completions" + ); } #[test] @@ -515,14 +521,20 @@ mod tests { fn chat_completions_url_without_v1() { // Provider configured without /v1 in base URL let p = make_provider("test", "https://api.example.com", None); - assert_eq!(p.chat_completions_url(), "https://api.example.com/chat/completions"); + assert_eq!( + p.chat_completions_url(), + "https://api.example.com/chat/completions" + ); } #[test] fn chat_completions_url_base_with_v1() { // Provider configured with /v1 in base URL let p = make_provider("test", "https://api.example.com/v1", None); - assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions"); + assert_eq!( + p.chat_completions_url(), + "https://api.example.com/v1/chat/completions" + ); } // ══════════════════════════════════════════════════════════ diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 1b64af0..a988224 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -12,10 +12,44 @@ use std::path::PathBuf; /// Gemini provider supporting multiple authentication methods. pub struct GeminiProvider { - api_key: Option, + auth: Option, client: Client, } +/// Resolved credential — the variant determines both the HTTP auth method +/// and the diagnostic label returned by `auth_source()`. +#[derive(Debug)] +enum GeminiAuth { + /// Explicit API key from config: sent as `?key=` query parameter. + ExplicitKey(String), + /// API key from `GEMINI_API_KEY` env var: sent as `?key=`. + EnvGeminiKey(String), + /// API key from `GOOGLE_API_KEY` env var: sent as `?key=`. + EnvGoogleKey(String), + /// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`. + OAuthToken(String), +} + +impl GeminiAuth { + /// Whether this credential is an API key (sent as `?key=` query param). + fn is_api_key(&self) -> bool { + matches!( + self, + GeminiAuth::ExplicitKey(_) | GeminiAuth::EnvGeminiKey(_) | GeminiAuth::EnvGoogleKey(_) + ) + } + + /// The raw credential string. + fn credential(&self) -> &str { + match self { + GeminiAuth::ExplicitKey(s) + | GeminiAuth::EnvGeminiKey(s) + | GeminiAuth::EnvGoogleKey(s) + | GeminiAuth::OAuthToken(s) => s, + } + } +} + // ══════════════════════════════════════════════════════════════════════════════ // API REQUEST/RESPONSE TYPES // ══════════════════════════════════════════════════════════════════════════════ @@ -82,17 +116,9 @@ struct ApiError { #[derive(Debug, Deserialize)] struct GeminiCliOAuthCreds { access_token: Option, - refresh_token: Option, expiry: Option, } -/// Settings stored by Gemini CLI in ~/.gemini/settings.json -#[derive(Debug, Deserialize)] -struct GeminiCliSettings { - #[serde(rename = "selectedAuthType")] - selected_auth_type: Option, -} - impl GeminiProvider { /// Create a new Gemini provider. /// @@ -102,14 +128,15 @@ impl GeminiProvider { /// 3. `GOOGLE_API_KEY` environment variable /// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`) pub fn new(api_key: Option<&str>) -> Self { - let resolved_key = api_key - .map(String::from) - .or_else(|| std::env::var("GEMINI_API_KEY").ok()) - .or_else(|| std::env::var("GOOGLE_API_KEY").ok()) - .or_else(Self::try_load_gemini_cli_token); + let resolved_auth = api_key + .and_then(Self::normalize_non_empty) + .map(GeminiAuth::ExplicitKey) + .or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey)) + .or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey)) + .or_else(|| Self::try_load_gemini_cli_token().map(GeminiAuth::OAuthToken)); Self { - api_key: resolved_key, + auth: resolved_auth, client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -118,6 +145,21 @@ impl GeminiProvider { } } + fn normalize_non_empty(value: &str) -> Option { + let trimmed = value.trim(); + if trimmed.is_empty() { + None + } else { + Some(trimmed.to_string()) + } + } + + fn load_non_empty_env(name: &str) -> Option { + std::env::var(name) + .ok() + .and_then(|value| Self::normalize_non_empty(&value)) + } + /// Try to load OAuth access token from Gemini CLI's cached credentials. /// Location: `~/.gemini/oauth_creds.json` fn try_load_gemini_cli_token() -> Option { @@ -135,13 +177,15 @@ impl GeminiProvider { if let Some(ref expiry) = creds.expiry { if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) { if expiry_time < chrono::Utc::now() { - tracing::debug!("Gemini CLI OAuth token expired, skipping"); + tracing::warn!("Gemini CLI OAuth token expired — re-run `gemini` to refresh"); return None; } } } - creds.access_token + creds + .access_token + .and_then(|token| Self::normalize_non_empty(&token)) } /// Get the Gemini CLI config directory (~/.gemini) @@ -156,26 +200,55 @@ impl GeminiProvider { /// Check if any Gemini authentication is available pub fn has_any_auth() -> bool { - std::env::var("GEMINI_API_KEY").is_ok() - || std::env::var("GOOGLE_API_KEY").is_ok() + Self::load_non_empty_env("GEMINI_API_KEY").is_some() + || Self::load_non_empty_env("GOOGLE_API_KEY").is_some() || Self::has_cli_credentials() } - /// Get authentication source description for diagnostics + /// Get authentication source description for diagnostics. + /// Uses the stored enum variant — no env var re-reading at call time. pub fn auth_source(&self) -> &'static str { - if self.api_key.is_none() { - return "none"; + match self.auth.as_ref() { + Some(GeminiAuth::ExplicitKey(_)) => "config", + Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var", + Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var", + Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth", + None => "none", } - if std::env::var("GEMINI_API_KEY").is_ok() { - return "GEMINI_API_KEY env var"; + } + + fn format_model_name(model: &str) -> String { + if model.starts_with("models/") { + model.to_string() + } else { + format!("models/{model}") } - if std::env::var("GOOGLE_API_KEY").is_ok() { - return "GOOGLE_API_KEY env var"; + } + + fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String { + let model_name = Self::format_model_name(model); + let base_url = format!( + "https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent" + ); + + if auth.is_api_key() { + format!("{base_url}?key={}", auth.credential()) + } else { + base_url } - if Self::has_cli_credentials() { - return "Gemini CLI OAuth"; + } + + fn build_generate_content_request( + &self, + auth: &GeminiAuth, + url: &str, + request: &GenerateContentRequest, + ) -> reqwest::RequestBuilder { + let req = self.client.post(url).json(request); + match auth { + GeminiAuth::OAuthToken(token) => req.bearer_auth(token), + _ => req, } - "config" } } @@ -188,7 +261,7 @@ impl Provider for GeminiProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let auth = self.auth.as_ref().ok_or_else(|| { anyhow::anyhow!( "Gemini API key not found. Options:\n\ 1. Set GEMINI_API_KEY env var\n\ @@ -220,19 +293,12 @@ impl Provider for GeminiProvider { }, }; - // Gemini API endpoint - // Model format: gemini-2.0-flash, gemini-1.5-pro, etc. - let model_name = if model.starts_with("models/") { - model.to_string() - } else { - format!("models/{model}") - }; + let url = Self::build_generate_content_url(model, auth); - let url = format!( - "https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent?key={api_key}" - ); - - let response = self.client.post(&url).json(&request).send().await?; + let response = self + .build_generate_content_request(auth, &url, &request) + .send() + .await?; if !response.status().is_success() { let status = response.status(); @@ -260,19 +326,38 @@ impl Provider for GeminiProvider { #[cfg(test)] mod tests { use super::*; + use reqwest::header::AUTHORIZATION; + + #[test] + fn normalize_non_empty_trims_and_filters() { + assert_eq!( + GeminiProvider::normalize_non_empty(" value "), + Some("value".into()) + ); + assert_eq!(GeminiProvider::normalize_non_empty(""), None); + assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None); + } #[test] fn provider_creates_without_key() { let provider = GeminiProvider::new(None); - // Should not panic, just have no key - assert!(provider.api_key.is_none() || provider.api_key.is_some()); + // May pick up env vars; just verify it doesn't panic + let _ = provider.auth_source(); } #[test] fn provider_creates_with_key() { let provider = GeminiProvider::new(Some("test-api-key")); - assert!(provider.api_key.is_some()); - assert_eq!(provider.api_key.as_deref(), Some("test-api-key")); + assert!(matches!( + provider.auth, + Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key" + )); + } + + #[test] + fn provider_rejects_empty_key() { + let provider = GeminiProvider::new(Some("")); + assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_)))); } #[test] @@ -286,33 +371,123 @@ mod tests { } #[test] - fn auth_source_reports_correctly() { - let provider = GeminiProvider::new(Some("explicit-key")); - // With explicit key, should report "config" (unless CLI credentials exist) - let source = provider.auth_source(); - // Should be either "config" or "Gemini CLI OAuth" if CLI is configured - assert!(source == "config" || source == "Gemini CLI OAuth"); + fn auth_source_explicit_key() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::ExplicitKey("key".into())), + client: Client::new(), + }; + assert_eq!(provider.auth_source(), "config"); + } + + #[test] + fn auth_source_none_without_credentials() { + let provider = GeminiProvider { + auth: None, + client: Client::new(), + }; + assert_eq!(provider.auth_source(), "none"); + } + + #[test] + fn auth_source_oauth() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::OAuthToken("ya29.mock".into())), + client: Client::new(), + }; + assert_eq!(provider.auth_source(), "Gemini CLI OAuth"); } #[test] fn model_name_formatting() { - // Test that model names are formatted correctly - let model = "gemini-2.0-flash"; - let formatted = if model.starts_with("models/") { - model.to_string() - } else { - format!("models/{model}") - }; - assert_eq!(formatted, "models/gemini-2.0-flash"); + assert_eq!( + GeminiProvider::format_model_name("gemini-2.0-flash"), + "models/gemini-2.0-flash" + ); + assert_eq!( + GeminiProvider::format_model_name("models/gemini-1.5-pro"), + "models/gemini-1.5-pro" + ); + } - // Already prefixed - let model2 = "models/gemini-1.5-pro"; - let formatted2 = if model2.starts_with("models/") { - model2.to_string() - } else { - format!("models/{model2}") + #[test] + fn api_key_url_includes_key_query_param() { + let auth = GeminiAuth::ExplicitKey("api-key-123".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + assert!(url.contains(":generateContent?key=api-key-123")); + } + + #[test] + fn oauth_url_omits_key_query_param() { + let auth = GeminiAuth::OAuthToken("ya29.test-token".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + assert!(url.ends_with(":generateContent")); + assert!(!url.contains("?key=")); + } + + #[test] + fn oauth_request_uses_bearer_auth_header() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())), + client: Client::new(), }; - assert_eq!(formatted2, "models/gemini-1.5-pro"); + let auth = GeminiAuth::OAuthToken("ya29.mock-token".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + let body = GenerateContentRequest { + contents: vec![Content { + role: Some("user".into()), + parts: vec![Part { + text: "hello".into(), + }], + }], + system_instruction: None, + generation_config: GenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + }; + + let request = provider + .build_generate_content_request(&auth, &url, &body) + .build() + .unwrap(); + + assert_eq!( + request + .headers() + .get(AUTHORIZATION) + .and_then(|h| h.to_str().ok()), + Some("Bearer ya29.mock-token") + ); + } + + #[test] + fn api_key_request_does_not_set_bearer_header() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())), + client: Client::new(), + }; + let auth = GeminiAuth::ExplicitKey("api-key-123".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + let body = GenerateContentRequest { + contents: vec![Content { + role: Some("user".into()), + parts: vec![Part { + text: "hello".into(), + }], + }], + system_instruction: None, + generation_config: GenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + }; + + let request = provider + .build_generate_content_request(&auth, &url, &body) + .build() + .unwrap(); + + assert!(request.headers().get(AUTHORIZATION).is_none()); } #[test] diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 791f13d..921eeef 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -281,9 +281,7 @@ mod tests { "API error with 400 Bad Request" ))); // Retryable: 429 Too Many Requests - assert!(!is_non_retryable(&anyhow::anyhow!( - "429 Too Many Requests" - ))); + assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests"))); // Retryable: 408 Request Timeout assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout"))); // Retryable: 5xx server errors diff --git a/src/providers/router.rs b/src/providers/router.rs index 52dab47..2085276 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -181,7 +181,10 @@ mod tests { .iter() .zip(mocks.iter()) .map(|((name, _), mock)| { - (name.to_string(), Box::new(Arc::clone(mock)) as Box) + ( + name.to_string(), + Box::new(Arc::clone(mock)) as Box, + ) }) .collect(); @@ -198,11 +201,7 @@ mod tests { }) .collect(); - let router = RouterProvider::new( - provider_list, - route_list, - "default-model".to_string(), - ); + let router = RouterProvider::new(provider_list, route_list, "default-model".to_string()); (router, mocks) } @@ -270,7 +269,10 @@ mod tests { #[tokio::test] async fn non_hint_model_uses_default_provider() { let (router, mocks) = make_router( - vec![("primary", "primary-response"), ("secondary", "secondary-response")], + vec![ + ("primary", "primary-response"), + ("secondary", "secondary-response"), + ], vec![("code", "secondary", "codellama")], ); @@ -285,10 +287,7 @@ mod tests { #[test] fn resolve_preserves_model_for_non_hints() { - let (router, _) = make_router( - vec![("default", "ok")], - vec![], - ); + let (router, _) = make_router(vec![("default", "ok")], vec![]); let (idx, model) = router.resolve("gpt-4o"); assert_eq!(idx, 0); @@ -320,10 +319,7 @@ mod tests { #[tokio::test] async fn warmup_calls_all_providers() { - let (router, _) = make_router( - vec![("a", "ok"), ("b", "ok")], - vec![], - ); + let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]); // Warmup should not error assert!(router.warmup().await.is_ok()); @@ -333,7 +329,10 @@ mod tests { async fn chat_with_system_passes_system_prompt() { let mock = Arc::new(MockProvider::new("response")); let router = RouterProvider::new( - vec![("default".into(), Box::new(Arc::clone(&mock)) as Box)], + vec![( + "default".into(), + Box::new(Arc::clone(&mock)) as Box, + )], vec![], "model".into(), ); diff --git a/src/skillforge/evaluate.rs b/src/skillforge/evaluate.rs index e9971ec..bdefd59 100644 --- a/src/skillforge/evaluate.rs +++ b/src/skillforge/evaluate.rs @@ -74,11 +74,10 @@ const BAD_PATTERNS: &[&str] = &[ /// Check if `haystack` contains `word` as a whole word (bounded by non-alphanumeric chars). fn contains_word(haystack: &str, word: &str) -> bool { for (i, _) in haystack.match_indices(word) { - let before_ok = i == 0 - || !haystack.as_bytes()[i - 1].is_ascii_alphanumeric(); + let before_ok = i == 0 || !haystack.as_bytes()[i - 1].is_ascii_alphanumeric(); let after = i + word.len(); - let after_ok = after >= haystack.len() - || !haystack.as_bytes()[after].is_ascii_alphanumeric(); + let after_ok = + after >= haystack.len() || !haystack.as_bytes()[after].is_ascii_alphanumeric(); if before_ok && after_ok { return true; } @@ -217,7 +216,11 @@ mod tests { c.name = "malware-skill".into(); let res = eval.evaluate(c); // 0.5 base + 0.3 license - 0.5 bad_pattern + 0.2 recency = 0.5 - assert!(res.scores.security <= 0.5, "security: {}", res.scores.security); + assert!( + res.scores.security <= 0.5, + "security: {}", + res.scores.security + ); } #[test] @@ -245,7 +248,11 @@ mod tests { c.description = "Tools for hackathons and lifehacks".into(); let res = eval.evaluate(c); // "hack" should NOT match "hackathon" or "lifehacks" - assert!(res.scores.security >= 0.5, "security: {}", res.scores.security); + assert!( + res.scores.security >= 0.5, + "security: {}", + res.scores.security + ); } #[test] @@ -256,6 +263,10 @@ mod tests { c.updated_at = None; let res = eval.evaluate(c); // 0.5 base + 0.0 license - 0.5 bad_pattern + 0.0 recency = 0.0 - assert!(res.scores.security < 0.5, "security: {}", res.scores.security); + assert!( + res.scores.security < 0.5, + "security: {}", + res.scores.security + ); } } diff --git a/src/skillforge/mod.rs b/src/skillforge/mod.rs index d16b8dc..17c2336 100644 --- a/src/skillforge/mod.rs +++ b/src/skillforge/mod.rs @@ -78,10 +78,7 @@ impl std::fmt::Debug for SkillForgeConfig { .field("sources", &self.sources) .field("scan_interval_hours", &self.scan_interval_hours) .field("min_score", &self.min_score) - .field( - "github_token", - &self.github_token.as_ref().map(|_| "***"), - ) + .field("github_token", &self.github_token.as_ref().map(|_| "***")) .field("output_dir", &self.output_dir) .finish() } @@ -155,7 +152,10 @@ impl SkillForge { } } ScoutSource::ClawHub | ScoutSource::HuggingFace => { - info!(source = src.as_str(), "Source not yet implemented — skipping"); + info!( + source = src.as_str(), + "Source not yet implemented — skipping" + ); } } } diff --git a/src/skillforge/scout.rs b/src/skillforge/scout.rs index df3a4a8..1ad8af4 100644 --- a/src/skillforge/scout.rs +++ b/src/skillforge/scout.rs @@ -79,9 +79,7 @@ impl GitHubScout { let mut headers = reqwest::header::HeaderMap::new(); headers.insert( reqwest::header::ACCEPT, - "application/vnd.github+json" - .parse() - .expect("valid header"), + "application/vnd.github+json".parse().expect("valid header"), ); headers.insert( reqwest::header::USER_AGENT, @@ -101,10 +99,7 @@ impl GitHubScout { Self { client, - queries: vec![ - "zeroclaw skill".into(), - "ai agent skill".into(), - ], + queries: vec!["zeroclaw skill".into(), "ai agent skill".into()], } } @@ -143,10 +138,7 @@ impl GitHubScout { .and_then(|v| v.as_str()) .unwrap_or("unknown") .to_string(); - let has_license = item - .get("license") - .map(|v| !v.is_null()) - .unwrap_or(false); + let has_license = item.get("license").map(|v| !v.is_null()).unwrap_or(false); Some(ScoutResult { name, @@ -225,9 +217,7 @@ impl Scout for GitHubScout { /// Minimal percent-encoding for query strings (space → +). fn urlencoding(s: &str) -> String { - s.replace(' ', "+") - .replace('&', "%26") - .replace('#', "%23") + s.replace(' ', "+").replace('&', "%26").replace('#', "%23") } /// Deduplicate scout results by URL (keeps first occurrence). @@ -246,13 +236,31 @@ mod tests { #[test] fn scout_source_from_str() { - assert_eq!("github".parse::().unwrap(), ScoutSource::GitHub); - assert_eq!("GitHub".parse::().unwrap(), ScoutSource::GitHub); - assert_eq!("clawhub".parse::().unwrap(), ScoutSource::ClawHub); - assert_eq!("huggingface".parse::().unwrap(), ScoutSource::HuggingFace); - assert_eq!("hf".parse::().unwrap(), ScoutSource::HuggingFace); + assert_eq!( + "github".parse::().unwrap(), + ScoutSource::GitHub + ); + assert_eq!( + "GitHub".parse::().unwrap(), + ScoutSource::GitHub + ); + assert_eq!( + "clawhub".parse::().unwrap(), + ScoutSource::ClawHub + ); + assert_eq!( + "huggingface".parse::().unwrap(), + ScoutSource::HuggingFace + ); + assert_eq!( + "hf".parse::().unwrap(), + ScoutSource::HuggingFace + ); // unknown falls back to GitHub - assert_eq!("unknown".parse::().unwrap(), ScoutSource::GitHub); + assert_eq!( + "unknown".parse::().unwrap(), + ScoutSource::GitHub + ); } #[test] diff --git a/src/tools/browser.rs b/src/tools/browser.rs index b3709f6..006a9ef 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -793,20 +793,14 @@ mod tests { #[test] fn extract_host_handles_ipv6() { // IPv6 with brackets (required for URLs with ports) - assert_eq!( - extract_host("https://[::1]/path").unwrap(), - "[::1]" - ); + assert_eq!(extract_host("https://[::1]/path").unwrap(), "[::1]"); // IPv6 with brackets and port assert_eq!( extract_host("https://[2001:db8::1]:8080/path").unwrap(), "[2001:db8::1]" ); // IPv6 with brackets, trailing slash - assert_eq!( - extract_host("https://[fe80::1]/").unwrap(), - "[fe80::1]" - ); + assert_eq!(extract_host("https://[fe80::1]/").unwrap(), "[fe80::1]"); } #[test]