fix(providers): use Bearer auth for Gemini CLI OAuth tokens
* fix(providers): use Bearer auth for Gemini CLI OAuth tokens When credentials come from ~/.gemini/oauth_creds.json (Gemini CLI), send them as Authorization: Bearer header instead of ?key= query parameter. API keys from env vars or config continue using ?key=. Fixes #194 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * refactor(gemini): harden OAuth bearer auth flow and tests * fix(gemini): granular auth source tracking and review fixes Build on chumyin's auth model refactor with: - Expand GeminiAuth to 4 variants (ExplicitKey/EnvGeminiKey/EnvGoogleKey/ OAuthToken) so auth_source() uses stored discriminant without re-reading env vars at call time - Add is_api_key()/credential() helpers on the enum - Upgrade expired OAuth token log from debug to warn - Add tests: provider_rejects_empty_key, auth_source_explicit_key, auth_source_none_without_credentials Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * style: apply rustfmt to fix CI lint failures Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: root <root@instance-20220913-1738.vcn09131738.oraclevcn.com> Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
This commit is contained in:
parent
e057bf4128
commit
49bb20f961
15 changed files with 358 additions and 148 deletions
|
|
@ -21,10 +21,10 @@ pub use traits::Channel;
|
||||||
pub use whatsapp::WhatsAppChannel;
|
pub use whatsapp::WhatsAppChannel;
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
|
use crate::identity;
|
||||||
use crate::memory::{self, Memory};
|
use crate::memory::{self, Memory};
|
||||||
use crate::providers::{self, Provider};
|
use crate::providers::{self, Provider};
|
||||||
use crate::util::truncate_with_ellipsis;
|
use crate::util::truncate_with_ellipsis;
|
||||||
use crate::identity;
|
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::{Duration, Instant};
|
use std::time::{Duration, Instant};
|
||||||
|
|
@ -205,7 +205,9 @@ pub fn build_system_prompt(
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
// Log error but don't fail - fall back to OpenClaw
|
// Log error but don't fail - fall back to OpenClaw
|
||||||
eprintln!("Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format.");
|
eprintln!(
|
||||||
|
"Warning: Failed to load AIEOS identity: {e}. Using OpenClaw format."
|
||||||
|
);
|
||||||
load_openclaw_bootstrap_files(&mut prompt, workspace_dir);
|
load_openclaw_bootstrap_files(&mut prompt, workspace_dir);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -534,7 +536,13 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let system_prompt = build_system_prompt(&workspace, &model, &tool_descs, &skills, Some(&config.identity));
|
let system_prompt = build_system_prompt(
|
||||||
|
&workspace,
|
||||||
|
&model,
|
||||||
|
&tool_descs,
|
||||||
|
&skills,
|
||||||
|
Some(&config.identity),
|
||||||
|
);
|
||||||
|
|
||||||
if !skills.is_empty() {
|
if !skills.is_empty() {
|
||||||
println!(
|
println!(
|
||||||
|
|
|
||||||
|
|
@ -1215,7 +1215,6 @@ default_temperature = 0.7
|
||||||
let _ = fs::remove_dir_all(&dir);
|
let _ = fs::remove_dir_all(&dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn config_save_atomic_cleanup() {
|
fn config_save_atomic_cleanup() {
|
||||||
let dir =
|
let dir =
|
||||||
|
|
@ -1920,7 +1919,7 @@ default_temperature = 0.7
|
||||||
fn env_override_temperature_out_of_range_ignored() {
|
fn env_override_temperature_out_of_range_ignored() {
|
||||||
// Clean up any leftover env vars from other tests
|
// Clean up any leftover env vars from other tests
|
||||||
std::env::remove_var("ZEROCLAW_TEMPERATURE");
|
std::env::remove_var("ZEROCLAW_TEMPERATURE");
|
||||||
|
|
||||||
let mut config = Config::default();
|
let mut config = Config::default();
|
||||||
let original_temp = config.default_temperature;
|
let original_temp = config.default_temperature;
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -183,8 +183,8 @@ pub fn load_aieos_identity(
|
||||||
|
|
||||||
// Fall back to aieos_inline
|
// Fall back to aieos_inline
|
||||||
if let Some(ref inline) = config.aieos_inline {
|
if let Some(ref inline) = config.aieos_inline {
|
||||||
let identity: AieosIdentity = serde_json::from_str(inline)
|
let identity: AieosIdentity =
|
||||||
.context("Failed to parse inline AIEOS JSON")?;
|
serde_json::from_str(inline).context("Failed to parse inline AIEOS JSON")?;
|
||||||
|
|
||||||
return Ok(Some(identity));
|
return Ok(Some(identity));
|
||||||
}
|
}
|
||||||
|
|
@ -544,10 +544,7 @@ mod tests {
|
||||||
|
|
||||||
// Check motivations
|
// Check motivations
|
||||||
let mot = identity.motivations.unwrap();
|
let mot = identity.motivations.unwrap();
|
||||||
assert_eq!(
|
assert_eq!(mot.core_drive.unwrap(), "Help users accomplish their goals");
|
||||||
mot.core_drive.unwrap(),
|
|
||||||
"Help users accomplish their goals"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Check capabilities
|
// Check capabilities
|
||||||
let cap = identity.capabilities.unwrap();
|
let cap = identity.capabilities.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -138,7 +138,11 @@ impl SqliteMemory {
|
||||||
// First 8 bytes → 16 hex chars, matching previous format length
|
// First 8 bytes → 16 hex chars, matching previous format length
|
||||||
format!(
|
format!(
|
||||||
"{:016x}",
|
"{:016x}",
|
||||||
u64::from_be_bytes(hash[..8].try_into().expect("SHA-256 always produces >= 8 bytes"))
|
u64::from_be_bytes(
|
||||||
|
hash[..8]
|
||||||
|
.try_into()
|
||||||
|
.expect("SHA-256 always produces >= 8 bytes")
|
||||||
|
)
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,7 +51,10 @@ pub trait Observer: Send + Sync + 'static {
|
||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
/// Downcast to `Any` for backend-specific operations
|
/// Downcast to `Any` for backend-specific operations
|
||||||
fn as_any(&self) -> &dyn std::any::Any where Self: Sized {
|
fn as_any(&self) -> &dyn std::any::Any
|
||||||
|
where
|
||||||
|
Self: Sized,
|
||||||
|
{
|
||||||
self
|
self
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1734,9 +1734,8 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let nickname: String = Input::new()
|
let nickname: String =
|
||||||
.with_prompt(" Bot nickname")
|
Input::new().with_prompt(" Bot nickname").interact_text()?;
|
||||||
.interact_text()?;
|
|
||||||
|
|
||||||
if nickname.trim().is_empty() {
|
if nickname.trim().is_empty() {
|
||||||
println!(" {} Skipped — nickname required", style("→").dim());
|
println!(" {} Skipped — nickname required", style("→").dim());
|
||||||
|
|
@ -1779,7 +1778,9 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
};
|
};
|
||||||
|
|
||||||
if allowed_users.is_empty() {
|
if allowed_users.is_empty() {
|
||||||
print_bullet("⚠️ Empty allowlist — only you can interact. Add nicknames above.");
|
print_bullet(
|
||||||
|
"⚠️ Empty allowlist — only you can interact. Add nicknames above.",
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
println!();
|
println!();
|
||||||
|
|
|
||||||
|
|
@ -154,7 +154,8 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_custom_base_url() {
|
fn creates_with_custom_base_url() {
|
||||||
let p = AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
|
let p =
|
||||||
|
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
|
||||||
assert_eq!(p.base_url, "https://api.example.com");
|
assert_eq!(p.base_url, "https://api.example.com");
|
||||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
|
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -452,14 +452,20 @@ mod tests {
|
||||||
fn chat_completions_url_standard_openai() {
|
fn chat_completions_url_standard_openai() {
|
||||||
// Standard OpenAI-compatible providers get /chat/completions appended
|
// Standard OpenAI-compatible providers get /chat/completions appended
|
||||||
let p = make_provider("openai", "https://api.openai.com/v1", None);
|
let p = make_provider("openai", "https://api.openai.com/v1", None);
|
||||||
assert_eq!(p.chat_completions_url(), "https://api.openai.com/v1/chat/completions");
|
assert_eq!(
|
||||||
|
p.chat_completions_url(),
|
||||||
|
"https://api.openai.com/v1/chat/completions"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn chat_completions_url_trailing_slash() {
|
fn chat_completions_url_trailing_slash() {
|
||||||
// Trailing slash is stripped, then /chat/completions appended
|
// Trailing slash is stripped, then /chat/completions appended
|
||||||
let p = make_provider("test", "https://api.example.com/v1/", None);
|
let p = make_provider("test", "https://api.example.com/v1/", None);
|
||||||
assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions");
|
assert_eq!(
|
||||||
|
p.chat_completions_url(),
|
||||||
|
"https://api.example.com/v1/chat/completions"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -515,14 +521,20 @@ mod tests {
|
||||||
fn chat_completions_url_without_v1() {
|
fn chat_completions_url_without_v1() {
|
||||||
// Provider configured without /v1 in base URL
|
// Provider configured without /v1 in base URL
|
||||||
let p = make_provider("test", "https://api.example.com", None);
|
let p = make_provider("test", "https://api.example.com", None);
|
||||||
assert_eq!(p.chat_completions_url(), "https://api.example.com/chat/completions");
|
assert_eq!(
|
||||||
|
p.chat_completions_url(),
|
||||||
|
"https://api.example.com/chat/completions"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn chat_completions_url_base_with_v1() {
|
fn chat_completions_url_base_with_v1() {
|
||||||
// Provider configured with /v1 in base URL
|
// Provider configured with /v1 in base URL
|
||||||
let p = make_provider("test", "https://api.example.com/v1", None);
|
let p = make_provider("test", "https://api.example.com/v1", None);
|
||||||
assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions");
|
assert_eq!(
|
||||||
|
p.chat_completions_url(),
|
||||||
|
"https://api.example.com/v1/chat/completions"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
|
||||||
|
|
@ -12,10 +12,44 @@ use std::path::PathBuf;
|
||||||
|
|
||||||
/// Gemini provider supporting multiple authentication methods.
|
/// Gemini provider supporting multiple authentication methods.
|
||||||
pub struct GeminiProvider {
|
pub struct GeminiProvider {
|
||||||
api_key: Option<String>,
|
auth: Option<GeminiAuth>,
|
||||||
client: Client,
|
client: Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Resolved credential — the variant determines both the HTTP auth method
|
||||||
|
/// and the diagnostic label returned by `auth_source()`.
|
||||||
|
#[derive(Debug)]
|
||||||
|
enum GeminiAuth {
|
||||||
|
/// Explicit API key from config: sent as `?key=` query parameter.
|
||||||
|
ExplicitKey(String),
|
||||||
|
/// API key from `GEMINI_API_KEY` env var: sent as `?key=`.
|
||||||
|
EnvGeminiKey(String),
|
||||||
|
/// API key from `GOOGLE_API_KEY` env var: sent as `?key=`.
|
||||||
|
EnvGoogleKey(String),
|
||||||
|
/// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`.
|
||||||
|
OAuthToken(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GeminiAuth {
|
||||||
|
/// Whether this credential is an API key (sent as `?key=` query param).
|
||||||
|
fn is_api_key(&self) -> bool {
|
||||||
|
matches!(
|
||||||
|
self,
|
||||||
|
GeminiAuth::ExplicitKey(_) | GeminiAuth::EnvGeminiKey(_) | GeminiAuth::EnvGoogleKey(_)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// The raw credential string.
|
||||||
|
fn credential(&self) -> &str {
|
||||||
|
match self {
|
||||||
|
GeminiAuth::ExplicitKey(s)
|
||||||
|
| GeminiAuth::EnvGeminiKey(s)
|
||||||
|
| GeminiAuth::EnvGoogleKey(s)
|
||||||
|
| GeminiAuth::OAuthToken(s) => s,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════════════════════════
|
||||||
// API REQUEST/RESPONSE TYPES
|
// API REQUEST/RESPONSE TYPES
|
||||||
// ══════════════════════════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════════════════════════
|
||||||
|
|
@ -82,17 +116,9 @@ struct ApiError {
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct GeminiCliOAuthCreds {
|
struct GeminiCliOAuthCreds {
|
||||||
access_token: Option<String>,
|
access_token: Option<String>,
|
||||||
refresh_token: Option<String>,
|
|
||||||
expiry: Option<String>,
|
expiry: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Settings stored by Gemini CLI in ~/.gemini/settings.json
|
|
||||||
#[derive(Debug, Deserialize)]
|
|
||||||
struct GeminiCliSettings {
|
|
||||||
#[serde(rename = "selectedAuthType")]
|
|
||||||
selected_auth_type: Option<String>,
|
|
||||||
}
|
|
||||||
|
|
||||||
impl GeminiProvider {
|
impl GeminiProvider {
|
||||||
/// Create a new Gemini provider.
|
/// Create a new Gemini provider.
|
||||||
///
|
///
|
||||||
|
|
@ -102,14 +128,15 @@ impl GeminiProvider {
|
||||||
/// 3. `GOOGLE_API_KEY` environment variable
|
/// 3. `GOOGLE_API_KEY` environment variable
|
||||||
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
|
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(api_key: Option<&str>) -> Self {
|
||||||
let resolved_key = api_key
|
let resolved_auth = api_key
|
||||||
.map(String::from)
|
.and_then(Self::normalize_non_empty)
|
||||||
.or_else(|| std::env::var("GEMINI_API_KEY").ok())
|
.map(GeminiAuth::ExplicitKey)
|
||||||
.or_else(|| std::env::var("GOOGLE_API_KEY").ok())
|
.or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
|
||||||
.or_else(Self::try_load_gemini_cli_token);
|
.or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
|
||||||
|
.or_else(|| Self::try_load_gemini_cli_token().map(GeminiAuth::OAuthToken));
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
api_key: resolved_key,
|
auth: resolved_auth,
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(120))
|
.timeout(std::time::Duration::from_secs(120))
|
||||||
.connect_timeout(std::time::Duration::from_secs(10))
|
.connect_timeout(std::time::Duration::from_secs(10))
|
||||||
|
|
@ -118,6 +145,21 @@ impl GeminiProvider {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn normalize_non_empty(value: &str) -> Option<String> {
|
||||||
|
let trimmed = value.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(trimmed.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn load_non_empty_env(name: &str) -> Option<String> {
|
||||||
|
std::env::var(name)
|
||||||
|
.ok()
|
||||||
|
.and_then(|value| Self::normalize_non_empty(&value))
|
||||||
|
}
|
||||||
|
|
||||||
/// Try to load OAuth access token from Gemini CLI's cached credentials.
|
/// Try to load OAuth access token from Gemini CLI's cached credentials.
|
||||||
/// Location: `~/.gemini/oauth_creds.json`
|
/// Location: `~/.gemini/oauth_creds.json`
|
||||||
fn try_load_gemini_cli_token() -> Option<String> {
|
fn try_load_gemini_cli_token() -> Option<String> {
|
||||||
|
|
@ -135,13 +177,15 @@ impl GeminiProvider {
|
||||||
if let Some(ref expiry) = creds.expiry {
|
if let Some(ref expiry) = creds.expiry {
|
||||||
if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) {
|
if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) {
|
||||||
if expiry_time < chrono::Utc::now() {
|
if expiry_time < chrono::Utc::now() {
|
||||||
tracing::debug!("Gemini CLI OAuth token expired, skipping");
|
tracing::warn!("Gemini CLI OAuth token expired — re-run `gemini` to refresh");
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
creds.access_token
|
creds
|
||||||
|
.access_token
|
||||||
|
.and_then(|token| Self::normalize_non_empty(&token))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get the Gemini CLI config directory (~/.gemini)
|
/// Get the Gemini CLI config directory (~/.gemini)
|
||||||
|
|
@ -156,26 +200,55 @@ impl GeminiProvider {
|
||||||
|
|
||||||
/// Check if any Gemini authentication is available
|
/// Check if any Gemini authentication is available
|
||||||
pub fn has_any_auth() -> bool {
|
pub fn has_any_auth() -> bool {
|
||||||
std::env::var("GEMINI_API_KEY").is_ok()
|
Self::load_non_empty_env("GEMINI_API_KEY").is_some()
|
||||||
|| std::env::var("GOOGLE_API_KEY").is_ok()
|
|| Self::load_non_empty_env("GOOGLE_API_KEY").is_some()
|
||||||
|| Self::has_cli_credentials()
|
|| Self::has_cli_credentials()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get authentication source description for diagnostics
|
/// Get authentication source description for diagnostics.
|
||||||
|
/// Uses the stored enum variant — no env var re-reading at call time.
|
||||||
pub fn auth_source(&self) -> &'static str {
|
pub fn auth_source(&self) -> &'static str {
|
||||||
if self.api_key.is_none() {
|
match self.auth.as_ref() {
|
||||||
return "none";
|
Some(GeminiAuth::ExplicitKey(_)) => "config",
|
||||||
|
Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var",
|
||||||
|
Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var",
|
||||||
|
Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
|
||||||
|
None => "none",
|
||||||
}
|
}
|
||||||
if std::env::var("GEMINI_API_KEY").is_ok() {
|
}
|
||||||
return "GEMINI_API_KEY env var";
|
|
||||||
|
fn format_model_name(model: &str) -> String {
|
||||||
|
if model.starts_with("models/") {
|
||||||
|
model.to_string()
|
||||||
|
} else {
|
||||||
|
format!("models/{model}")
|
||||||
}
|
}
|
||||||
if std::env::var("GOOGLE_API_KEY").is_ok() {
|
}
|
||||||
return "GOOGLE_API_KEY env var";
|
|
||||||
|
fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
|
||||||
|
let model_name = Self::format_model_name(model);
|
||||||
|
let base_url = format!(
|
||||||
|
"https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent"
|
||||||
|
);
|
||||||
|
|
||||||
|
if auth.is_api_key() {
|
||||||
|
format!("{base_url}?key={}", auth.credential())
|
||||||
|
} else {
|
||||||
|
base_url
|
||||||
}
|
}
|
||||||
if Self::has_cli_credentials() {
|
}
|
||||||
return "Gemini CLI OAuth";
|
|
||||||
|
fn build_generate_content_request(
|
||||||
|
&self,
|
||||||
|
auth: &GeminiAuth,
|
||||||
|
url: &str,
|
||||||
|
request: &GenerateContentRequest,
|
||||||
|
) -> reqwest::RequestBuilder {
|
||||||
|
let req = self.client.post(url).json(request);
|
||||||
|
match auth {
|
||||||
|
GeminiAuth::OAuthToken(token) => req.bearer_auth(token),
|
||||||
|
_ => req,
|
||||||
}
|
}
|
||||||
"config"
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -188,7 +261,7 @@ impl Provider for GeminiProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let auth = self.auth.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"Gemini API key not found. Options:\n\
|
"Gemini API key not found. Options:\n\
|
||||||
1. Set GEMINI_API_KEY env var\n\
|
1. Set GEMINI_API_KEY env var\n\
|
||||||
|
|
@ -220,19 +293,12 @@ impl Provider for GeminiProvider {
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Gemini API endpoint
|
let url = Self::build_generate_content_url(model, auth);
|
||||||
// Model format: gemini-2.0-flash, gemini-1.5-pro, etc.
|
|
||||||
let model_name = if model.starts_with("models/") {
|
|
||||||
model.to_string()
|
|
||||||
} else {
|
|
||||||
format!("models/{model}")
|
|
||||||
};
|
|
||||||
|
|
||||||
let url = format!(
|
let response = self
|
||||||
"https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent?key={api_key}"
|
.build_generate_content_request(auth, &url, &request)
|
||||||
);
|
.send()
|
||||||
|
.await?;
|
||||||
let response = self.client.post(&url).json(&request).send().await?;
|
|
||||||
|
|
||||||
if !response.status().is_success() {
|
if !response.status().is_success() {
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
|
|
@ -260,19 +326,38 @@ impl Provider for GeminiProvider {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use reqwest::header::AUTHORIZATION;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn normalize_non_empty_trims_and_filters() {
|
||||||
|
assert_eq!(
|
||||||
|
GeminiProvider::normalize_non_empty(" value "),
|
||||||
|
Some("value".into())
|
||||||
|
);
|
||||||
|
assert_eq!(GeminiProvider::normalize_non_empty(""), None);
|
||||||
|
assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn provider_creates_without_key() {
|
fn provider_creates_without_key() {
|
||||||
let provider = GeminiProvider::new(None);
|
let provider = GeminiProvider::new(None);
|
||||||
// Should not panic, just have no key
|
// May pick up env vars; just verify it doesn't panic
|
||||||
assert!(provider.api_key.is_none() || provider.api_key.is_some());
|
let _ = provider.auth_source();
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn provider_creates_with_key() {
|
fn provider_creates_with_key() {
|
||||||
let provider = GeminiProvider::new(Some("test-api-key"));
|
let provider = GeminiProvider::new(Some("test-api-key"));
|
||||||
assert!(provider.api_key.is_some());
|
assert!(matches!(
|
||||||
assert_eq!(provider.api_key.as_deref(), Some("test-api-key"));
|
provider.auth,
|
||||||
|
Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_rejects_empty_key() {
|
||||||
|
let provider = GeminiProvider::new(Some(""));
|
||||||
|
assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_))));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -286,33 +371,123 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn auth_source_reports_correctly() {
|
fn auth_source_explicit_key() {
|
||||||
let provider = GeminiProvider::new(Some("explicit-key"));
|
let provider = GeminiProvider {
|
||||||
// With explicit key, should report "config" (unless CLI credentials exist)
|
auth: Some(GeminiAuth::ExplicitKey("key".into())),
|
||||||
let source = provider.auth_source();
|
client: Client::new(),
|
||||||
// Should be either "config" or "Gemini CLI OAuth" if CLI is configured
|
};
|
||||||
assert!(source == "config" || source == "Gemini CLI OAuth");
|
assert_eq!(provider.auth_source(), "config");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_none_without_credentials() {
|
||||||
|
let provider = GeminiProvider {
|
||||||
|
auth: None,
|
||||||
|
client: Client::new(),
|
||||||
|
};
|
||||||
|
assert_eq!(provider.auth_source(), "none");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn auth_source_oauth() {
|
||||||
|
let provider = GeminiProvider {
|
||||||
|
auth: Some(GeminiAuth::OAuthToken("ya29.mock".into())),
|
||||||
|
client: Client::new(),
|
||||||
|
};
|
||||||
|
assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn model_name_formatting() {
|
fn model_name_formatting() {
|
||||||
// Test that model names are formatted correctly
|
assert_eq!(
|
||||||
let model = "gemini-2.0-flash";
|
GeminiProvider::format_model_name("gemini-2.0-flash"),
|
||||||
let formatted = if model.starts_with("models/") {
|
"models/gemini-2.0-flash"
|
||||||
model.to_string()
|
);
|
||||||
} else {
|
assert_eq!(
|
||||||
format!("models/{model}")
|
GeminiProvider::format_model_name("models/gemini-1.5-pro"),
|
||||||
};
|
"models/gemini-1.5-pro"
|
||||||
assert_eq!(formatted, "models/gemini-2.0-flash");
|
);
|
||||||
|
}
|
||||||
|
|
||||||
// Already prefixed
|
#[test]
|
||||||
let model2 = "models/gemini-1.5-pro";
|
fn api_key_url_includes_key_query_param() {
|
||||||
let formatted2 = if model2.starts_with("models/") {
|
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
|
||||||
model2.to_string()
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
||||||
} else {
|
assert!(url.contains(":generateContent?key=api-key-123"));
|
||||||
format!("models/{model2}")
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_url_omits_key_query_param() {
|
||||||
|
let auth = GeminiAuth::OAuthToken("ya29.test-token".into());
|
||||||
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
||||||
|
assert!(url.ends_with(":generateContent"));
|
||||||
|
assert!(!url.contains("?key="));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn oauth_request_uses_bearer_auth_header() {
|
||||||
|
let provider = GeminiProvider {
|
||||||
|
auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())),
|
||||||
|
client: Client::new(),
|
||||||
};
|
};
|
||||||
assert_eq!(formatted2, "models/gemini-1.5-pro");
|
let auth = GeminiAuth::OAuthToken("ya29.mock-token".into());
|
||||||
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
||||||
|
let body = GenerateContentRequest {
|
||||||
|
contents: vec![Content {
|
||||||
|
role: Some("user".into()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: "hello".into(),
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system_instruction: None,
|
||||||
|
generation_config: GenerationConfig {
|
||||||
|
temperature: 0.7,
|
||||||
|
max_output_tokens: 8192,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = provider
|
||||||
|
.build_generate_content_request(&auth, &url, &body)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
request
|
||||||
|
.headers()
|
||||||
|
.get(AUTHORIZATION)
|
||||||
|
.and_then(|h| h.to_str().ok()),
|
||||||
|
Some("Bearer ya29.mock-token")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn api_key_request_does_not_set_bearer_header() {
|
||||||
|
let provider = GeminiProvider {
|
||||||
|
auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())),
|
||||||
|
client: Client::new(),
|
||||||
|
};
|
||||||
|
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
|
||||||
|
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
|
||||||
|
let body = GenerateContentRequest {
|
||||||
|
contents: vec![Content {
|
||||||
|
role: Some("user".into()),
|
||||||
|
parts: vec![Part {
|
||||||
|
text: "hello".into(),
|
||||||
|
}],
|
||||||
|
}],
|
||||||
|
system_instruction: None,
|
||||||
|
generation_config: GenerationConfig {
|
||||||
|
temperature: 0.7,
|
||||||
|
max_output_tokens: 8192,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let request = provider
|
||||||
|
.build_generate_content_request(&auth, &url, &body)
|
||||||
|
.build()
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(request.headers().get(AUTHORIZATION).is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -281,9 +281,7 @@ mod tests {
|
||||||
"API error with 400 Bad Request"
|
"API error with 400 Bad Request"
|
||||||
)));
|
)));
|
||||||
// Retryable: 429 Too Many Requests
|
// Retryable: 429 Too Many Requests
|
||||||
assert!(!is_non_retryable(&anyhow::anyhow!(
|
assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests")));
|
||||||
"429 Too Many Requests"
|
|
||||||
)));
|
|
||||||
// Retryable: 408 Request Timeout
|
// Retryable: 408 Request Timeout
|
||||||
assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout")));
|
assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout")));
|
||||||
// Retryable: 5xx server errors
|
// Retryable: 5xx server errors
|
||||||
|
|
|
||||||
|
|
@ -181,7 +181,10 @@ mod tests {
|
||||||
.iter()
|
.iter()
|
||||||
.zip(mocks.iter())
|
.zip(mocks.iter())
|
||||||
.map(|((name, _), mock)| {
|
.map(|((name, _), mock)| {
|
||||||
(name.to_string(), Box::new(Arc::clone(mock)) as Box<dyn Provider>)
|
(
|
||||||
|
name.to_string(),
|
||||||
|
Box::new(Arc::clone(mock)) as Box<dyn Provider>,
|
||||||
|
)
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
|
|
@ -198,11 +201,7 @@ mod tests {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let router = RouterProvider::new(
|
let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());
|
||||||
provider_list,
|
|
||||||
route_list,
|
|
||||||
"default-model".to_string(),
|
|
||||||
);
|
|
||||||
|
|
||||||
(router, mocks)
|
(router, mocks)
|
||||||
}
|
}
|
||||||
|
|
@ -270,7 +269,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn non_hint_model_uses_default_provider() {
|
async fn non_hint_model_uses_default_provider() {
|
||||||
let (router, mocks) = make_router(
|
let (router, mocks) = make_router(
|
||||||
vec![("primary", "primary-response"), ("secondary", "secondary-response")],
|
vec![
|
||||||
|
("primary", "primary-response"),
|
||||||
|
("secondary", "secondary-response"),
|
||||||
|
],
|
||||||
vec![("code", "secondary", "codellama")],
|
vec![("code", "secondary", "codellama")],
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -285,10 +287,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_preserves_model_for_non_hints() {
|
fn resolve_preserves_model_for_non_hints() {
|
||||||
let (router, _) = make_router(
|
let (router, _) = make_router(vec![("default", "ok")], vec![]);
|
||||||
vec![("default", "ok")],
|
|
||||||
vec![],
|
|
||||||
);
|
|
||||||
|
|
||||||
let (idx, model) = router.resolve("gpt-4o");
|
let (idx, model) = router.resolve("gpt-4o");
|
||||||
assert_eq!(idx, 0);
|
assert_eq!(idx, 0);
|
||||||
|
|
@ -320,10 +319,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn warmup_calls_all_providers() {
|
async fn warmup_calls_all_providers() {
|
||||||
let (router, _) = make_router(
|
let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);
|
||||||
vec![("a", "ok"), ("b", "ok")],
|
|
||||||
vec![],
|
|
||||||
);
|
|
||||||
|
|
||||||
// Warmup should not error
|
// Warmup should not error
|
||||||
assert!(router.warmup().await.is_ok());
|
assert!(router.warmup().await.is_ok());
|
||||||
|
|
@ -333,7 +329,10 @@ mod tests {
|
||||||
async fn chat_with_system_passes_system_prompt() {
|
async fn chat_with_system_passes_system_prompt() {
|
||||||
let mock = Arc::new(MockProvider::new("response"));
|
let mock = Arc::new(MockProvider::new("response"));
|
||||||
let router = RouterProvider::new(
|
let router = RouterProvider::new(
|
||||||
vec![("default".into(), Box::new(Arc::clone(&mock)) as Box<dyn Provider>)],
|
vec![(
|
||||||
|
"default".into(),
|
||||||
|
Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
|
||||||
|
)],
|
||||||
vec![],
|
vec![],
|
||||||
"model".into(),
|
"model".into(),
|
||||||
);
|
);
|
||||||
|
|
|
||||||
|
|
@ -74,11 +74,10 @@ const BAD_PATTERNS: &[&str] = &[
|
||||||
/// Check if `haystack` contains `word` as a whole word (bounded by non-alphanumeric chars).
|
/// Check if `haystack` contains `word` as a whole word (bounded by non-alphanumeric chars).
|
||||||
fn contains_word(haystack: &str, word: &str) -> bool {
|
fn contains_word(haystack: &str, word: &str) -> bool {
|
||||||
for (i, _) in haystack.match_indices(word) {
|
for (i, _) in haystack.match_indices(word) {
|
||||||
let before_ok = i == 0
|
let before_ok = i == 0 || !haystack.as_bytes()[i - 1].is_ascii_alphanumeric();
|
||||||
|| !haystack.as_bytes()[i - 1].is_ascii_alphanumeric();
|
|
||||||
let after = i + word.len();
|
let after = i + word.len();
|
||||||
let after_ok = after >= haystack.len()
|
let after_ok =
|
||||||
|| !haystack.as_bytes()[after].is_ascii_alphanumeric();
|
after >= haystack.len() || !haystack.as_bytes()[after].is_ascii_alphanumeric();
|
||||||
if before_ok && after_ok {
|
if before_ok && after_ok {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
@ -217,7 +216,11 @@ mod tests {
|
||||||
c.name = "malware-skill".into();
|
c.name = "malware-skill".into();
|
||||||
let res = eval.evaluate(c);
|
let res = eval.evaluate(c);
|
||||||
// 0.5 base + 0.3 license - 0.5 bad_pattern + 0.2 recency = 0.5
|
// 0.5 base + 0.3 license - 0.5 bad_pattern + 0.2 recency = 0.5
|
||||||
assert!(res.scores.security <= 0.5, "security: {}", res.scores.security);
|
assert!(
|
||||||
|
res.scores.security <= 0.5,
|
||||||
|
"security: {}",
|
||||||
|
res.scores.security
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -245,7 +248,11 @@ mod tests {
|
||||||
c.description = "Tools for hackathons and lifehacks".into();
|
c.description = "Tools for hackathons and lifehacks".into();
|
||||||
let res = eval.evaluate(c);
|
let res = eval.evaluate(c);
|
||||||
// "hack" should NOT match "hackathon" or "lifehacks"
|
// "hack" should NOT match "hackathon" or "lifehacks"
|
||||||
assert!(res.scores.security >= 0.5, "security: {}", res.scores.security);
|
assert!(
|
||||||
|
res.scores.security >= 0.5,
|
||||||
|
"security: {}",
|
||||||
|
res.scores.security
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -256,6 +263,10 @@ mod tests {
|
||||||
c.updated_at = None;
|
c.updated_at = None;
|
||||||
let res = eval.evaluate(c);
|
let res = eval.evaluate(c);
|
||||||
// 0.5 base + 0.0 license - 0.5 bad_pattern + 0.0 recency = 0.0
|
// 0.5 base + 0.0 license - 0.5 bad_pattern + 0.0 recency = 0.0
|
||||||
assert!(res.scores.security < 0.5, "security: {}", res.scores.security);
|
assert!(
|
||||||
|
res.scores.security < 0.5,
|
||||||
|
"security: {}",
|
||||||
|
res.scores.security
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -78,10 +78,7 @@ impl std::fmt::Debug for SkillForgeConfig {
|
||||||
.field("sources", &self.sources)
|
.field("sources", &self.sources)
|
||||||
.field("scan_interval_hours", &self.scan_interval_hours)
|
.field("scan_interval_hours", &self.scan_interval_hours)
|
||||||
.field("min_score", &self.min_score)
|
.field("min_score", &self.min_score)
|
||||||
.field(
|
.field("github_token", &self.github_token.as_ref().map(|_| "***"))
|
||||||
"github_token",
|
|
||||||
&self.github_token.as_ref().map(|_| "***"),
|
|
||||||
)
|
|
||||||
.field("output_dir", &self.output_dir)
|
.field("output_dir", &self.output_dir)
|
||||||
.finish()
|
.finish()
|
||||||
}
|
}
|
||||||
|
|
@ -155,7 +152,10 @@ impl SkillForge {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
ScoutSource::ClawHub | ScoutSource::HuggingFace => {
|
ScoutSource::ClawHub | ScoutSource::HuggingFace => {
|
||||||
info!(source = src.as_str(), "Source not yet implemented — skipping");
|
info!(
|
||||||
|
source = src.as_str(),
|
||||||
|
"Source not yet implemented — skipping"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -79,9 +79,7 @@ impl GitHubScout {
|
||||||
let mut headers = reqwest::header::HeaderMap::new();
|
let mut headers = reqwest::header::HeaderMap::new();
|
||||||
headers.insert(
|
headers.insert(
|
||||||
reqwest::header::ACCEPT,
|
reqwest::header::ACCEPT,
|
||||||
"application/vnd.github+json"
|
"application/vnd.github+json".parse().expect("valid header"),
|
||||||
.parse()
|
|
||||||
.expect("valid header"),
|
|
||||||
);
|
);
|
||||||
headers.insert(
|
headers.insert(
|
||||||
reqwest::header::USER_AGENT,
|
reqwest::header::USER_AGENT,
|
||||||
|
|
@ -101,10 +99,7 @@ impl GitHubScout {
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
client,
|
client,
|
||||||
queries: vec![
|
queries: vec!["zeroclaw skill".into(), "ai agent skill".into()],
|
||||||
"zeroclaw skill".into(),
|
|
||||||
"ai agent skill".into(),
|
|
||||||
],
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -143,10 +138,7 @@ impl GitHubScout {
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("unknown")
|
.unwrap_or("unknown")
|
||||||
.to_string();
|
.to_string();
|
||||||
let has_license = item
|
let has_license = item.get("license").map(|v| !v.is_null()).unwrap_or(false);
|
||||||
.get("license")
|
|
||||||
.map(|v| !v.is_null())
|
|
||||||
.unwrap_or(false);
|
|
||||||
|
|
||||||
Some(ScoutResult {
|
Some(ScoutResult {
|
||||||
name,
|
name,
|
||||||
|
|
@ -225,9 +217,7 @@ impl Scout for GitHubScout {
|
||||||
|
|
||||||
/// Minimal percent-encoding for query strings (space → +).
|
/// Minimal percent-encoding for query strings (space → +).
|
||||||
fn urlencoding(s: &str) -> String {
|
fn urlencoding(s: &str) -> String {
|
||||||
s.replace(' ', "+")
|
s.replace(' ', "+").replace('&', "%26").replace('#', "%23")
|
||||||
.replace('&', "%26")
|
|
||||||
.replace('#', "%23")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Deduplicate scout results by URL (keeps first occurrence).
|
/// Deduplicate scout results by URL (keeps first occurrence).
|
||||||
|
|
@ -246,13 +236,31 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn scout_source_from_str() {
|
fn scout_source_from_str() {
|
||||||
assert_eq!("github".parse::<ScoutSource>().unwrap(), ScoutSource::GitHub);
|
assert_eq!(
|
||||||
assert_eq!("GitHub".parse::<ScoutSource>().unwrap(), ScoutSource::GitHub);
|
"github".parse::<ScoutSource>().unwrap(),
|
||||||
assert_eq!("clawhub".parse::<ScoutSource>().unwrap(), ScoutSource::ClawHub);
|
ScoutSource::GitHub
|
||||||
assert_eq!("huggingface".parse::<ScoutSource>().unwrap(), ScoutSource::HuggingFace);
|
);
|
||||||
assert_eq!("hf".parse::<ScoutSource>().unwrap(), ScoutSource::HuggingFace);
|
assert_eq!(
|
||||||
|
"GitHub".parse::<ScoutSource>().unwrap(),
|
||||||
|
ScoutSource::GitHub
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
"clawhub".parse::<ScoutSource>().unwrap(),
|
||||||
|
ScoutSource::ClawHub
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
"huggingface".parse::<ScoutSource>().unwrap(),
|
||||||
|
ScoutSource::HuggingFace
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
"hf".parse::<ScoutSource>().unwrap(),
|
||||||
|
ScoutSource::HuggingFace
|
||||||
|
);
|
||||||
// unknown falls back to GitHub
|
// unknown falls back to GitHub
|
||||||
assert_eq!("unknown".parse::<ScoutSource>().unwrap(), ScoutSource::GitHub);
|
assert_eq!(
|
||||||
|
"unknown".parse::<ScoutSource>().unwrap(),
|
||||||
|
ScoutSource::GitHub
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -793,20 +793,14 @@ mod tests {
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_host_handles_ipv6() {
|
fn extract_host_handles_ipv6() {
|
||||||
// IPv6 with brackets (required for URLs with ports)
|
// IPv6 with brackets (required for URLs with ports)
|
||||||
assert_eq!(
|
assert_eq!(extract_host("https://[::1]/path").unwrap(), "[::1]");
|
||||||
extract_host("https://[::1]/path").unwrap(),
|
|
||||||
"[::1]"
|
|
||||||
);
|
|
||||||
// IPv6 with brackets and port
|
// IPv6 with brackets and port
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
extract_host("https://[2001:db8::1]:8080/path").unwrap(),
|
extract_host("https://[2001:db8::1]:8080/path").unwrap(),
|
||||||
"[2001:db8::1]"
|
"[2001:db8::1]"
|
||||||
);
|
);
|
||||||
// IPv6 with brackets, trailing slash
|
// IPv6 with brackets, trailing slash
|
||||||
assert_eq!(
|
assert_eq!(extract_host("https://[fe80::1]/").unwrap(), "[fe80::1]");
|
||||||
extract_host("https://[fe80::1]/").unwrap(),
|
|
||||||
"[fe80::1]"
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue