fix(providers): use Bearer auth for Gemini CLI OAuth tokens

* fix(providers): use Bearer auth for Gemini CLI OAuth tokens

When credentials come from ~/.gemini/oauth_creds.json (Gemini CLI),
send them as Authorization: Bearer header instead of ?key= query
parameter. API keys from env vars or config continue using ?key=.

Fixes #194

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* refactor(gemini): harden OAuth bearer auth flow and tests

* fix(gemini): granular auth source tracking and review fixes

Build on chumyin's auth model refactor with:
- Expand GeminiAuth to 4 variants (ExplicitKey/EnvGeminiKey/EnvGoogleKey/
  OAuthToken) so auth_source() uses stored discriminant without re-reading
  env vars at call time
- Add is_api_key()/credential() helpers on the enum
- Upgrade expired OAuth token log from debug to warn
- Add tests: provider_rejects_empty_key, auth_source_explicit_key,
  auth_source_none_without_credentials

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

* style: apply rustfmt to fix CI lint failures

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>

---------

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: root <root@instance-20220913-1738.vcn09131738.oraclevcn.com>
Co-authored-by: argenis de la rosa <theonlyhennygod@gmail.com>
This commit is contained in:
Edvard Schøyen 2026-02-15 14:32:33 -05:00 committed by GitHub
parent e057bf4128
commit 49bb20f961
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
15 changed files with 358 additions and 148 deletions

View file

@ -154,7 +154,8 @@ mod tests {
#[test]
fn creates_with_custom_base_url() {
let p = AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
let p =
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
assert_eq!(p.base_url, "https://api.example.com");
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
}

View file

@ -452,14 +452,20 @@ mod tests {
fn chat_completions_url_standard_openai() {
// Standard OpenAI-compatible providers get /chat/completions appended
let p = make_provider("openai", "https://api.openai.com/v1", None);
assert_eq!(p.chat_completions_url(), "https://api.openai.com/v1/chat/completions");
assert_eq!(
p.chat_completions_url(),
"https://api.openai.com/v1/chat/completions"
);
}
#[test]
fn chat_completions_url_trailing_slash() {
// Trailing slash is stripped, then /chat/completions appended
let p = make_provider("test", "https://api.example.com/v1/", None);
assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions");
assert_eq!(
p.chat_completions_url(),
"https://api.example.com/v1/chat/completions"
);
}
#[test]
@ -515,14 +521,20 @@ mod tests {
fn chat_completions_url_without_v1() {
// Provider configured without /v1 in base URL
let p = make_provider("test", "https://api.example.com", None);
assert_eq!(p.chat_completions_url(), "https://api.example.com/chat/completions");
assert_eq!(
p.chat_completions_url(),
"https://api.example.com/chat/completions"
);
}
#[test]
fn chat_completions_url_base_with_v1() {
// Provider configured with /v1 in base URL
let p = make_provider("test", "https://api.example.com/v1", None);
assert_eq!(p.chat_completions_url(), "https://api.example.com/v1/chat/completions");
assert_eq!(
p.chat_completions_url(),
"https://api.example.com/v1/chat/completions"
);
}
// ══════════════════════════════════════════════════════════

View file

@ -12,10 +12,44 @@ use std::path::PathBuf;
/// Gemini provider supporting multiple authentication methods.
pub struct GeminiProvider {
api_key: Option<String>,
auth: Option<GeminiAuth>,
client: Client,
}
/// Resolved credential — the variant determines both the HTTP auth method
/// and the diagnostic label returned by `auth_source()`.
#[derive(Debug)]
enum GeminiAuth {
/// Explicit API key from config: sent as `?key=` query parameter.
ExplicitKey(String),
/// API key from `GEMINI_API_KEY` env var: sent as `?key=`.
EnvGeminiKey(String),
/// API key from `GOOGLE_API_KEY` env var: sent as `?key=`.
EnvGoogleKey(String),
/// OAuth access token from Gemini CLI: sent as `Authorization: Bearer`.
OAuthToken(String),
}
impl GeminiAuth {
/// Whether this credential is an API key (sent as `?key=` query param).
fn is_api_key(&self) -> bool {
matches!(
self,
GeminiAuth::ExplicitKey(_) | GeminiAuth::EnvGeminiKey(_) | GeminiAuth::EnvGoogleKey(_)
)
}
/// The raw credential string.
fn credential(&self) -> &str {
match self {
GeminiAuth::ExplicitKey(s)
| GeminiAuth::EnvGeminiKey(s)
| GeminiAuth::EnvGoogleKey(s)
| GeminiAuth::OAuthToken(s) => s,
}
}
}
// ══════════════════════════════════════════════════════════════════════════════
// API REQUEST/RESPONSE TYPES
// ══════════════════════════════════════════════════════════════════════════════
@ -82,17 +116,9 @@ struct ApiError {
#[derive(Debug, Deserialize)]
struct GeminiCliOAuthCreds {
access_token: Option<String>,
refresh_token: Option<String>,
expiry: Option<String>,
}
/// Settings stored by Gemini CLI in ~/.gemini/settings.json
#[derive(Debug, Deserialize)]
struct GeminiCliSettings {
#[serde(rename = "selectedAuthType")]
selected_auth_type: Option<String>,
}
impl GeminiProvider {
/// Create a new Gemini provider.
///
@ -102,14 +128,15 @@ impl GeminiProvider {
/// 3. `GOOGLE_API_KEY` environment variable
/// 4. Gemini CLI OAuth tokens (`~/.gemini/oauth_creds.json`)
pub fn new(api_key: Option<&str>) -> Self {
let resolved_key = api_key
.map(String::from)
.or_else(|| std::env::var("GEMINI_API_KEY").ok())
.or_else(|| std::env::var("GOOGLE_API_KEY").ok())
.or_else(Self::try_load_gemini_cli_token);
let resolved_auth = api_key
.and_then(Self::normalize_non_empty)
.map(GeminiAuth::ExplicitKey)
.or_else(|| Self::load_non_empty_env("GEMINI_API_KEY").map(GeminiAuth::EnvGeminiKey))
.or_else(|| Self::load_non_empty_env("GOOGLE_API_KEY").map(GeminiAuth::EnvGoogleKey))
.or_else(|| Self::try_load_gemini_cli_token().map(GeminiAuth::OAuthToken));
Self {
api_key: resolved_key,
auth: resolved_auth,
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -118,6 +145,21 @@ impl GeminiProvider {
}
}
fn normalize_non_empty(value: &str) -> Option<String> {
let trimmed = value.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
}
fn load_non_empty_env(name: &str) -> Option<String> {
std::env::var(name)
.ok()
.and_then(|value| Self::normalize_non_empty(&value))
}
/// Try to load OAuth access token from Gemini CLI's cached credentials.
/// Location: `~/.gemini/oauth_creds.json`
fn try_load_gemini_cli_token() -> Option<String> {
@ -135,13 +177,15 @@ impl GeminiProvider {
if let Some(ref expiry) = creds.expiry {
if let Ok(expiry_time) = chrono::DateTime::parse_from_rfc3339(expiry) {
if expiry_time < chrono::Utc::now() {
tracing::debug!("Gemini CLI OAuth token expired, skipping");
tracing::warn!("Gemini CLI OAuth token expired — re-run `gemini` to refresh");
return None;
}
}
}
creds.access_token
creds
.access_token
.and_then(|token| Self::normalize_non_empty(&token))
}
/// Get the Gemini CLI config directory (~/.gemini)
@ -156,26 +200,55 @@ impl GeminiProvider {
/// Check if any Gemini authentication is available
pub fn has_any_auth() -> bool {
std::env::var("GEMINI_API_KEY").is_ok()
|| std::env::var("GOOGLE_API_KEY").is_ok()
Self::load_non_empty_env("GEMINI_API_KEY").is_some()
|| Self::load_non_empty_env("GOOGLE_API_KEY").is_some()
|| Self::has_cli_credentials()
}
/// Get authentication source description for diagnostics
/// Get authentication source description for diagnostics.
/// Uses the stored enum variant — no env var re-reading at call time.
pub fn auth_source(&self) -> &'static str {
if self.api_key.is_none() {
return "none";
match self.auth.as_ref() {
Some(GeminiAuth::ExplicitKey(_)) => "config",
Some(GeminiAuth::EnvGeminiKey(_)) => "GEMINI_API_KEY env var",
Some(GeminiAuth::EnvGoogleKey(_)) => "GOOGLE_API_KEY env var",
Some(GeminiAuth::OAuthToken(_)) => "Gemini CLI OAuth",
None => "none",
}
if std::env::var("GEMINI_API_KEY").is_ok() {
return "GEMINI_API_KEY env var";
}
fn format_model_name(model: &str) -> String {
if model.starts_with("models/") {
model.to_string()
} else {
format!("models/{model}")
}
if std::env::var("GOOGLE_API_KEY").is_ok() {
return "GOOGLE_API_KEY env var";
}
fn build_generate_content_url(model: &str, auth: &GeminiAuth) -> String {
let model_name = Self::format_model_name(model);
let base_url = format!(
"https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent"
);
if auth.is_api_key() {
format!("{base_url}?key={}", auth.credential())
} else {
base_url
}
if Self::has_cli_credentials() {
return "Gemini CLI OAuth";
}
fn build_generate_content_request(
&self,
auth: &GeminiAuth,
url: &str,
request: &GenerateContentRequest,
) -> reqwest::RequestBuilder {
let req = self.client.post(url).json(request);
match auth {
GeminiAuth::OAuthToken(token) => req.bearer_auth(token),
_ => req,
}
"config"
}
}
@ -188,7 +261,7 @@ impl Provider for GeminiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let auth = self.auth.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Gemini API key not found. Options:\n\
1. Set GEMINI_API_KEY env var\n\
@ -220,19 +293,12 @@ impl Provider for GeminiProvider {
},
};
// Gemini API endpoint
// Model format: gemini-2.0-flash, gemini-1.5-pro, etc.
let model_name = if model.starts_with("models/") {
model.to_string()
} else {
format!("models/{model}")
};
let url = Self::build_generate_content_url(model, auth);
let url = format!(
"https://generativelanguage.googleapis.com/v1beta/{model_name}:generateContent?key={api_key}"
);
let response = self.client.post(&url).json(&request).send().await?;
let response = self
.build_generate_content_request(auth, &url, &request)
.send()
.await?;
if !response.status().is_success() {
let status = response.status();
@ -260,19 +326,38 @@ impl Provider for GeminiProvider {
#[cfg(test)]
mod tests {
use super::*;
use reqwest::header::AUTHORIZATION;
#[test]
fn normalize_non_empty_trims_and_filters() {
assert_eq!(
GeminiProvider::normalize_non_empty(" value "),
Some("value".into())
);
assert_eq!(GeminiProvider::normalize_non_empty(""), None);
assert_eq!(GeminiProvider::normalize_non_empty(" \t\n"), None);
}
#[test]
fn provider_creates_without_key() {
let provider = GeminiProvider::new(None);
// Should not panic, just have no key
assert!(provider.api_key.is_none() || provider.api_key.is_some());
// May pick up env vars; just verify it doesn't panic
let _ = provider.auth_source();
}
#[test]
fn provider_creates_with_key() {
let provider = GeminiProvider::new(Some("test-api-key"));
assert!(provider.api_key.is_some());
assert_eq!(provider.api_key.as_deref(), Some("test-api-key"));
assert!(matches!(
provider.auth,
Some(GeminiAuth::ExplicitKey(ref key)) if key == "test-api-key"
));
}
#[test]
fn provider_rejects_empty_key() {
let provider = GeminiProvider::new(Some(""));
assert!(!matches!(provider.auth, Some(GeminiAuth::ExplicitKey(_))));
}
#[test]
@ -286,33 +371,123 @@ mod tests {
}
#[test]
fn auth_source_reports_correctly() {
let provider = GeminiProvider::new(Some("explicit-key"));
// With explicit key, should report "config" (unless CLI credentials exist)
let source = provider.auth_source();
// Should be either "config" or "Gemini CLI OAuth" if CLI is configured
assert!(source == "config" || source == "Gemini CLI OAuth");
fn auth_source_explicit_key() {
let provider = GeminiProvider {
auth: Some(GeminiAuth::ExplicitKey("key".into())),
client: Client::new(),
};
assert_eq!(provider.auth_source(), "config");
}
#[test]
fn auth_source_none_without_credentials() {
let provider = GeminiProvider {
auth: None,
client: Client::new(),
};
assert_eq!(provider.auth_source(), "none");
}
#[test]
fn auth_source_oauth() {
let provider = GeminiProvider {
auth: Some(GeminiAuth::OAuthToken("ya29.mock".into())),
client: Client::new(),
};
assert_eq!(provider.auth_source(), "Gemini CLI OAuth");
}
#[test]
fn model_name_formatting() {
// Test that model names are formatted correctly
let model = "gemini-2.0-flash";
let formatted = if model.starts_with("models/") {
model.to_string()
} else {
format!("models/{model}")
};
assert_eq!(formatted, "models/gemini-2.0-flash");
assert_eq!(
GeminiProvider::format_model_name("gemini-2.0-flash"),
"models/gemini-2.0-flash"
);
assert_eq!(
GeminiProvider::format_model_name("models/gemini-1.5-pro"),
"models/gemini-1.5-pro"
);
}
// Already prefixed
let model2 = "models/gemini-1.5-pro";
let formatted2 = if model2.starts_with("models/") {
model2.to_string()
} else {
format!("models/{model2}")
#[test]
fn api_key_url_includes_key_query_param() {
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
assert!(url.contains(":generateContent?key=api-key-123"));
}
#[test]
fn oauth_url_omits_key_query_param() {
let auth = GeminiAuth::OAuthToken("ya29.test-token".into());
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
assert!(url.ends_with(":generateContent"));
assert!(!url.contains("?key="));
}
#[test]
fn oauth_request_uses_bearer_auth_header() {
let provider = GeminiProvider {
auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())),
client: Client::new(),
};
assert_eq!(formatted2, "models/gemini-1.5-pro");
let auth = GeminiAuth::OAuthToken("ya29.mock-token".into());
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
let body = GenerateContentRequest {
contents: vec![Content {
role: Some("user".into()),
parts: vec![Part {
text: "hello".into(),
}],
}],
system_instruction: None,
generation_config: GenerationConfig {
temperature: 0.7,
max_output_tokens: 8192,
},
};
let request = provider
.build_generate_content_request(&auth, &url, &body)
.build()
.unwrap();
assert_eq!(
request
.headers()
.get(AUTHORIZATION)
.and_then(|h| h.to_str().ok()),
Some("Bearer ya29.mock-token")
);
}
#[test]
fn api_key_request_does_not_set_bearer_header() {
let provider = GeminiProvider {
auth: Some(GeminiAuth::ExplicitKey("api-key-123".into())),
client: Client::new(),
};
let auth = GeminiAuth::ExplicitKey("api-key-123".into());
let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth);
let body = GenerateContentRequest {
contents: vec![Content {
role: Some("user".into()),
parts: vec![Part {
text: "hello".into(),
}],
}],
system_instruction: None,
generation_config: GenerationConfig {
temperature: 0.7,
max_output_tokens: 8192,
},
};
let request = provider
.build_generate_content_request(&auth, &url, &body)
.build()
.unwrap();
assert!(request.headers().get(AUTHORIZATION).is_none());
}
#[test]

View file

@ -281,9 +281,7 @@ mod tests {
"API error with 400 Bad Request"
)));
// Retryable: 429 Too Many Requests
assert!(!is_non_retryable(&anyhow::anyhow!(
"429 Too Many Requests"
)));
assert!(!is_non_retryable(&anyhow::anyhow!("429 Too Many Requests")));
// Retryable: 408 Request Timeout
assert!(!is_non_retryable(&anyhow::anyhow!("408 Request Timeout")));
// Retryable: 5xx server errors

View file

@ -181,7 +181,10 @@ mod tests {
.iter()
.zip(mocks.iter())
.map(|((name, _), mock)| {
(name.to_string(), Box::new(Arc::clone(mock)) as Box<dyn Provider>)
(
name.to_string(),
Box::new(Arc::clone(mock)) as Box<dyn Provider>,
)
})
.collect();
@ -198,11 +201,7 @@ mod tests {
})
.collect();
let router = RouterProvider::new(
provider_list,
route_list,
"default-model".to_string(),
);
let router = RouterProvider::new(provider_list, route_list, "default-model".to_string());
(router, mocks)
}
@ -270,7 +269,10 @@ mod tests {
#[tokio::test]
async fn non_hint_model_uses_default_provider() {
let (router, mocks) = make_router(
vec![("primary", "primary-response"), ("secondary", "secondary-response")],
vec![
("primary", "primary-response"),
("secondary", "secondary-response"),
],
vec![("code", "secondary", "codellama")],
);
@ -285,10 +287,7 @@ mod tests {
#[test]
fn resolve_preserves_model_for_non_hints() {
let (router, _) = make_router(
vec![("default", "ok")],
vec![],
);
let (router, _) = make_router(vec![("default", "ok")], vec![]);
let (idx, model) = router.resolve("gpt-4o");
assert_eq!(idx, 0);
@ -320,10 +319,7 @@ mod tests {
#[tokio::test]
async fn warmup_calls_all_providers() {
let (router, _) = make_router(
vec![("a", "ok"), ("b", "ok")],
vec![],
);
let (router, _) = make_router(vec![("a", "ok"), ("b", "ok")], vec![]);
// Warmup should not error
assert!(router.warmup().await.is_ok());
@ -333,7 +329,10 @@ mod tests {
async fn chat_with_system_passes_system_prompt() {
let mock = Arc::new(MockProvider::new("response"));
let router = RouterProvider::new(
vec![("default".into(), Box::new(Arc::clone(&mock)) as Box<dyn Provider>)],
vec![(
"default".into(),
Box::new(Arc::clone(&mock)) as Box<dyn Provider>,
)],
vec![],
"model".into(),
);