feat(auth): add subscription auth profiles and codex/claude flows

This commit is contained in:
Codex 2026-02-15 19:02:41 +03:00 committed by Chummy
parent 6d8725c9e6
commit 007368d586
13 changed files with 1981 additions and 12 deletions

View file

@ -60,6 +60,7 @@ chacha20poly1305 = "0.10"
hmac = "0.12"
sha2 = "0.10"
hex = "0.4"
base64 = "0.22"
# CSPRNG for secure token generation
rand = "0.9"
@ -169,7 +170,6 @@ strip = true
panic = "abort"
[dev-dependencies]
tokio-test = "0.4"
tempfile = "3.14"
criterion = { version = "0.5", features = ["async_tokio"] }

View file

@ -164,6 +164,7 @@ zeroclaw daemon
# Check status
zeroclaw status
zeroclaw auth status
# Run system diagnostics
zeroclaw doctor
@ -188,6 +189,51 @@ zeroclaw migrate openclaw
> **Dev fallback (no global install):** prefix commands with `cargo run --release --` (example: `cargo run --release -- status`).
## Subscription Auth (OpenAI Codex / Claude Code)
ZeroClaw now supports subscription-native auth profiles (multi-account, encrypted at rest).
- Store file: `~/.zeroclaw/auth-profiles.json`
- Encryption key: `~/.zeroclaw/.secret_key`
- Profile id format: `<provider>:<profile_name>` (example: `openai-codex:work`)
OpenAI Codex OAuth (ChatGPT subscription):
```bash
# Recommended on servers/headless
zeroclaw auth login --provider openai-codex --device-code
# Browser/callback flow with paste fallback
zeroclaw auth login --provider openai-codex --profile default
zeroclaw auth paste-redirect --provider openai-codex --profile default
# Check / refresh / switch profile
zeroclaw auth status
zeroclaw auth refresh --provider openai-codex --profile default
zeroclaw auth use --provider openai-codex --profile work
```
Claude Code / Anthropic setup-token:
```bash
# Paste subscription/setup token (Authorization header mode)
zeroclaw auth paste-token --provider anthropic --profile default --auth-kind authorization
# Alias command
zeroclaw auth setup-token --provider anthropic --profile default
```
Run the agent with subscription auth:
```bash
zeroclaw agent --provider openai-codex -m "hello"
zeroclaw agent --provider openai-codex --auth-profile openai-codex:work -m "hello"
# Anthropic supports both API key and auth token env vars:
# ANTHROPIC_AUTH_TOKEN, ANTHROPIC_OAUTH_TOKEN, ANTHROPIC_API_KEY
zeroclaw agent --provider anthropic -m "hello"
```
## Architecture
Every subsystem is a **trait** — swap implementations with a config change, zero code changes.

View file

@ -0,0 +1,88 @@
use serde::{Deserialize, Serialize};
/// How Anthropic credentials should be sent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum AnthropicAuthKind {
/// Standard Anthropic API key via `x-api-key`.
ApiKey,
/// Subscription / setup token via `Authorization: Bearer ...`.
Authorization,
}
impl AnthropicAuthKind {
pub fn as_metadata_value(self) -> &'static str {
match self {
Self::ApiKey => "api-key",
Self::Authorization => "authorization",
}
}
pub fn from_metadata_value(value: &str) -> Option<Self> {
match value.trim().to_ascii_lowercase().as_str() {
"api-key" | "x-api-key" | "apikey" => Some(Self::ApiKey),
"authorization" | "bearer" | "auth-token" | "oauth" => {
Some(Self::Authorization)
}
_ => None,
}
}
}
/// Detect auth kind with explicit override support.
pub fn detect_auth_kind(token: &str, explicit: Option<&str>) -> AnthropicAuthKind {
if let Some(kind) = explicit.and_then(AnthropicAuthKind::from_metadata_value) {
return kind;
}
let trimmed = token.trim();
// JWT-like shape strongly suggests bearer token mode.
if trimmed.matches('.').count() >= 2 {
return AnthropicAuthKind::Authorization;
}
// Anthropic platform keys commonly start with this prefix.
if trimmed.starts_with("sk-ant-api") {
return AnthropicAuthKind::ApiKey;
}
// Default to API key for backward compatibility unless explicitly configured.
AnthropicAuthKind::ApiKey
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_kind_from_metadata() {
assert_eq!(
AnthropicAuthKind::from_metadata_value("authorization"),
Some(AnthropicAuthKind::Authorization)
);
assert_eq!(
AnthropicAuthKind::from_metadata_value("x-api-key"),
Some(AnthropicAuthKind::ApiKey)
);
assert_eq!(AnthropicAuthKind::from_metadata_value("nope"), None);
}
#[test]
fn detect_prefers_override() {
let kind = detect_auth_kind("sk-ant-api-123", Some("authorization"));
assert_eq!(kind, AnthropicAuthKind::Authorization);
}
#[test]
fn detect_jwt_like_as_authorization() {
let kind = detect_auth_kind("aaa.bbb.ccc", None);
assert_eq!(kind, AnthropicAuthKind::Authorization);
}
#[test]
fn detect_default_for_api_prefix() {
let kind = detect_auth_kind("sk-ant-api-123", None);
assert_eq!(kind, AnthropicAuthKind::ApiKey);
}
}

377
src/auth/mod.rs Normal file
View file

@ -0,0 +1,377 @@
pub mod anthropic_token;
pub mod openai_oauth;
pub mod profiles;
use crate::auth::openai_oauth::refresh_access_token;
use crate::auth::profiles::{profile_id, AuthProfile, AuthProfileKind, AuthProfilesData, AuthProfilesStore};
use crate::config::Config;
use anyhow::Result;
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, OnceLock};
use std::time::{Duration, Instant};
const OPENAI_CODEX_PROVIDER: &str = "openai-codex";
const ANTHROPIC_PROVIDER: &str = "anthropic";
const DEFAULT_PROFILE_NAME: &str = "default";
const OPENAI_REFRESH_SKEW_SECS: u64 = 90;
const OPENAI_REFRESH_FAILURE_BACKOFF_SECS: u64 = 10;
static REFRESH_BACKOFFS: OnceLock<Mutex<HashMap<String, Instant>>> = OnceLock::new();
#[derive(Clone)]
pub struct AuthService {
store: AuthProfilesStore,
client: reqwest::Client,
}
impl AuthService {
pub fn from_config(config: &Config) -> Self {
let state_dir = state_dir_from_config(config);
Self::new(&state_dir, config.secrets.encrypt)
}
pub fn new(state_dir: &Path, encrypt_secrets: bool) -> Self {
Self {
store: AuthProfilesStore::new(state_dir, encrypt_secrets),
client: reqwest::Client::new(),
}
}
pub fn load_profiles(&self) -> Result<AuthProfilesData> {
self.store.load()
}
pub fn store_openai_tokens(
&self,
profile_name: &str,
token_set: crate::auth::profiles::TokenSet,
account_id: Option<String>,
set_active: bool,
) -> Result<AuthProfile> {
let mut profile = AuthProfile::new_oauth(OPENAI_CODEX_PROVIDER, profile_name, token_set);
profile.account_id = account_id;
self.store.upsert_profile(profile.clone(), set_active)?;
Ok(profile)
}
pub fn store_provider_token(
&self,
provider: &str,
profile_name: &str,
token: &str,
metadata: HashMap<String, String>,
set_active: bool,
) -> Result<AuthProfile> {
let mut profile = AuthProfile::new_token(provider, profile_name, token.to_string());
profile.metadata.extend(metadata);
self.store.upsert_profile(profile.clone(), set_active)?;
Ok(profile)
}
pub fn set_active_profile(&self, provider: &str, requested_profile: &str) -> Result<String> {
let provider = normalize_provider(provider)?;
let data = self.store.load()?;
let profile_id = resolve_requested_profile_id(&provider, requested_profile);
let profile = data
.profiles
.get(&profile_id)
.ok_or_else(|| anyhow::anyhow!("Auth profile not found: {profile_id}"))?;
if profile.provider != provider {
anyhow::bail!(
"Profile {profile_id} belongs to provider {}, not {}",
profile.provider,
provider
);
}
self.store.set_active_profile(&provider, &profile_id)?;
Ok(profile_id)
}
pub fn remove_profile(&self, provider: &str, requested_profile: &str) -> Result<bool> {
let provider = normalize_provider(provider)?;
let profile_id = resolve_requested_profile_id(&provider, requested_profile);
self.store.remove_profile(&profile_id)
}
pub fn get_profile(
&self,
provider: &str,
profile_override: Option<&str>,
) -> Result<Option<AuthProfile>> {
let provider = normalize_provider(provider)?;
let data = self.store.load()?;
let Some(profile_id) = select_profile_id(&data, &provider, profile_override) else {
return Ok(None);
};
Ok(data.profiles.get(&profile_id).cloned())
}
pub fn get_provider_bearer_token(
&self,
provider: &str,
profile_override: Option<&str>,
) -> Result<Option<String>> {
let profile = self.get_profile(provider, profile_override)?;
let Some(profile) = profile else {
return Ok(None);
};
let token = match profile.kind {
AuthProfileKind::Token => profile.token,
AuthProfileKind::OAuth => profile.token_set.map(|t| t.access_token),
};
Ok(token.filter(|t| !t.trim().is_empty()))
}
pub async fn get_valid_openai_access_token(
&self,
profile_override: Option<&str>,
) -> Result<Option<String>> {
let data = self.store.load()?;
let Some(profile_id) = select_profile_id(&data, OPENAI_CODEX_PROVIDER, profile_override)
else {
return Ok(None);
};
let Some(profile) = data.profiles.get(&profile_id) else {
return Ok(None);
};
let Some(token_set) = profile.token_set.as_ref() else {
anyhow::bail!("OpenAI Codex auth profile is not OAuth-based: {profile_id}");
};
if !token_set.is_expiring_within(Duration::from_secs(OPENAI_REFRESH_SKEW_SECS)) {
return Ok(Some(token_set.access_token.clone()));
}
let Some(refresh_token) = token_set.refresh_token.clone() else {
return Ok(Some(token_set.access_token.clone()));
};
let refresh_lock = refresh_lock_for_profile(&profile_id);
let _guard = refresh_lock.lock().await;
// Re-load after waiting for lock to avoid duplicate refreshes.
let data = self.store.load()?;
let Some(latest_profile) = data.profiles.get(&profile_id) else {
return Ok(None);
};
let Some(latest_tokens) = latest_profile.token_set.as_ref() else {
anyhow::bail!("OpenAI Codex auth profile is missing token set: {profile_id}");
};
if !latest_tokens.is_expiring_within(Duration::from_secs(OPENAI_REFRESH_SKEW_SECS)) {
return Ok(Some(latest_tokens.access_token.clone()));
}
let refresh_token = latest_tokens
.refresh_token
.clone()
.unwrap_or(refresh_token);
if let Some(remaining) = refresh_backoff_remaining(&profile_id) {
anyhow::bail!(
"OpenAI token refresh is in backoff for {remaining}s due to previous failures"
);
}
let mut refreshed = match refresh_access_token(&self.client, &refresh_token).await {
Ok(tokens) => {
clear_refresh_backoff(&profile_id);
tokens
}
Err(err) => {
set_refresh_backoff(
&profile_id,
Duration::from_secs(OPENAI_REFRESH_FAILURE_BACKOFF_SECS),
);
return Err(err);
}
};
if refreshed.refresh_token.is_none() {
refreshed
.refresh_token
.clone_from(&latest_tokens.refresh_token);
}
let refreshed_clone = refreshed.clone();
let account_id = openai_oauth::extract_account_id_from_jwt(&refreshed.access_token)
.or_else(|| latest_profile.account_id.clone());
let updated = self.store.update_profile(&profile_id, |profile| {
profile.kind = AuthProfileKind::OAuth;
profile.token_set = Some(refreshed_clone.clone());
profile.account_id.clone_from(&account_id);
Ok(())
})?;
Ok(updated.token_set.map(|t| t.access_token))
}
}
pub fn normalize_provider(provider: &str) -> Result<String> {
let normalized = provider.trim().to_ascii_lowercase();
match normalized.as_str() {
"openai-codex" | "openai_codex" | "codex" => Ok(OPENAI_CODEX_PROVIDER.to_string()),
"anthropic" | "claude" | "claude-code" => Ok(ANTHROPIC_PROVIDER.to_string()),
other if !other.is_empty() => Ok(other.to_string()),
_ => anyhow::bail!("Provider name cannot be empty"),
}
}
pub fn state_dir_from_config(config: &Config) -> PathBuf {
config
.config_path
.parent()
.map_or_else(|| PathBuf::from("."), PathBuf::from)
}
pub fn default_profile_id(provider: &str) -> String {
profile_id(provider, DEFAULT_PROFILE_NAME)
}
fn resolve_requested_profile_id(provider: &str, requested: &str) -> String {
if requested.contains(':') {
requested.to_string()
} else {
profile_id(provider, requested)
}
}
pub fn select_profile_id(
data: &AuthProfilesData,
provider: &str,
profile_override: Option<&str>,
) -> Option<String> {
if let Some(override_profile) = profile_override {
let requested = resolve_requested_profile_id(provider, override_profile);
if data.profiles.contains_key(&requested) {
return Some(requested);
}
return None;
}
if let Some(active) = data.active_profiles.get(provider) {
if data.profiles.contains_key(active) {
return Some(active.clone());
}
}
let default = default_profile_id(provider);
if data.profiles.contains_key(&default) {
return Some(default);
}
data.profiles
.iter()
.find_map(|(id, profile)| (profile.provider == provider).then(|| id.clone()))
}
fn refresh_lock_for_profile(profile_id: &str) -> Arc<tokio::sync::Mutex<()>> {
static LOCKS: OnceLock<Mutex<HashMap<String, Arc<tokio::sync::Mutex<()>>>>> = OnceLock::new();
let table = LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
let mut guard = table.lock().expect("refresh lock table poisoned");
guard
.entry(profile_id.to_string())
.or_insert_with(|| Arc::new(tokio::sync::Mutex::new(())))
.clone()
}
fn refresh_backoff_remaining(profile_id: &str) -> Option<u64> {
let map = REFRESH_BACKOFFS.get_or_init(|| Mutex::new(HashMap::new()));
let mut guard = map.lock().ok()?;
let now = Instant::now();
let deadline = guard.get(profile_id).copied()?;
if deadline <= now {
guard.remove(profile_id);
return None;
}
Some((deadline - now).as_secs().max(1))
}
fn set_refresh_backoff(profile_id: &str, duration: Duration) {
let map = REFRESH_BACKOFFS.get_or_init(|| Mutex::new(HashMap::new()));
if let Ok(mut guard) = map.lock() {
guard.insert(profile_id.to_string(), Instant::now() + duration);
}
}
fn clear_refresh_backoff(profile_id: &str) {
let map = REFRESH_BACKOFFS.get_or_init(|| Mutex::new(HashMap::new()));
if let Ok(mut guard) = map.lock() {
guard.remove(profile_id);
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::auth::profiles::{AuthProfile, AuthProfileKind};
#[test]
fn normalize_provider_aliases() {
assert_eq!(normalize_provider("codex").unwrap(), "openai-codex");
assert_eq!(normalize_provider("claude").unwrap(), "anthropic");
assert_eq!(normalize_provider("openai").unwrap(), "openai");
}
#[test]
fn select_profile_prefers_override_then_active_then_default() {
let mut data = AuthProfilesData::default();
let id_active = profile_id("openai-codex", "work");
let id_default = profile_id("openai-codex", "default");
data.profiles.insert(
id_default.clone(),
AuthProfile {
id: id_default.clone(),
provider: "openai-codex".into(),
profile_name: "default".into(),
kind: AuthProfileKind::Token,
account_id: None,
workspace_id: None,
token_set: None,
token: Some("x".into()),
metadata: std::collections::BTreeMap::default(),
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
},
);
data.profiles.insert(
id_active.clone(),
AuthProfile {
id: id_active.clone(),
provider: "openai-codex".into(),
profile_name: "work".into(),
kind: AuthProfileKind::Token,
account_id: None,
workspace_id: None,
token_set: None,
token: Some("y".into()),
metadata: std::collections::BTreeMap::default(),
created_at: chrono::Utc::now(),
updated_at: chrono::Utc::now(),
},
);
data.active_profiles
.insert("openai-codex".into(), id_active.clone());
assert_eq!(
select_profile_id(&data, "openai-codex", Some("default")),
Some(id_default)
);
assert_eq!(
select_profile_id(&data, "openai-codex", None),
Some(id_active)
);
}
}

491
src/auth/openai_oauth.rs Normal file
View file

@ -0,0 +1,491 @@
use crate::auth::profiles::TokenSet;
use anyhow::{Context, Result};
use base64::Engine;
use chrono::Utc;
use reqwest::Client;
use serde::Deserialize;
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
use std::time::{Duration, Instant};
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
pub const OPENAI_OAUTH_CLIENT_ID: &str = "app_EMoamEEZ73f0CkXaXp7hrann";
pub const OPENAI_OAUTH_AUTHORIZE_URL: &str = "https://auth.openai.com/oauth/authorize";
pub const OPENAI_OAUTH_TOKEN_URL: &str = "https://auth.openai.com/oauth/token";
pub const OPENAI_OAUTH_DEVICE_CODE_URL: &str = "https://auth.openai.com/oauth/device/code";
pub const OPENAI_OAUTH_REDIRECT_URI: &str = "http://127.0.0.1:1455/auth/callback";
#[derive(Debug, Clone)]
pub struct PkceState {
pub code_verifier: String,
pub code_challenge: String,
pub state: String,
}
#[derive(Debug, Clone)]
pub struct DeviceCodeStart {
pub device_code: String,
pub user_code: String,
pub verification_uri: String,
pub verification_uri_complete: Option<String>,
pub expires_in: u64,
pub interval: u64,
pub message: Option<String>,
}
#[derive(Debug, Deserialize)]
struct TokenResponse {
access_token: String,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default)]
id_token: Option<String>,
#[serde(default)]
expires_in: Option<i64>,
#[serde(default)]
token_type: Option<String>,
#[serde(default)]
scope: Option<String>,
}
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
verification_uri: String,
#[serde(default)]
verification_uri_complete: Option<String>,
expires_in: u64,
#[serde(default)]
interval: Option<u64>,
#[serde(default)]
message: Option<String>,
}
#[derive(Debug, Deserialize)]
struct OAuthErrorResponse {
error: String,
#[serde(default)]
error_description: Option<String>,
}
pub fn generate_pkce_state() -> PkceState {
let code_verifier = random_base64url(64);
let digest = Sha256::digest(code_verifier.as_bytes());
let code_challenge = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(digest);
PkceState {
code_verifier,
code_challenge,
state: random_base64url(24),
}
}
pub fn build_authorize_url(pkce: &PkceState) -> String {
let mut params = BTreeMap::new();
params.insert("response_type", "code");
params.insert("client_id", OPENAI_OAUTH_CLIENT_ID);
params.insert("redirect_uri", OPENAI_OAUTH_REDIRECT_URI);
params.insert("scope", "openid profile email offline_access");
params.insert("code_challenge", pkce.code_challenge.as_str());
params.insert("code_challenge_method", "S256");
params.insert("state", pkce.state.as_str());
params.insert("codex_cli_simplified_flow", "true");
params.insert("id_token_add_organizations", "true");
let mut encoded: Vec<String> = Vec::with_capacity(params.len());
for (k, v) in params {
encoded.push(format!(
"{}={}",
url_encode(k),
url_encode(v)
));
}
format!(
"{OPENAI_OAUTH_AUTHORIZE_URL}?{}",
encoded.join("&")
)
}
pub async fn exchange_code_for_tokens(client: &Client, code: &str, pkce: &PkceState) -> Result<TokenSet> {
let form = [
("grant_type", "authorization_code"),
("code", code),
("client_id", OPENAI_OAUTH_CLIENT_ID),
("redirect_uri", OPENAI_OAUTH_REDIRECT_URI),
("code_verifier", pkce.code_verifier.as_str()),
];
let response = client
.post(OPENAI_OAUTH_TOKEN_URL)
.form(&form)
.send()
.await
.context("Failed to exchange OpenAI OAuth authorization code")?;
parse_token_response(response).await
}
pub async fn refresh_access_token(client: &Client, refresh_token: &str) -> Result<TokenSet> {
let form = [
("grant_type", "refresh_token"),
("refresh_token", refresh_token),
("client_id", OPENAI_OAUTH_CLIENT_ID),
];
let response = client
.post(OPENAI_OAUTH_TOKEN_URL)
.form(&form)
.send()
.await
.context("Failed to refresh OpenAI OAuth token")?;
parse_token_response(response).await
}
pub async fn start_device_code_flow(client: &Client) -> Result<DeviceCodeStart> {
let form = [
("client_id", OPENAI_OAUTH_CLIENT_ID),
("scope", "openid profile email offline_access"),
];
let response = client
.post(OPENAI_OAUTH_DEVICE_CODE_URL)
.form(&form)
.send()
.await
.context("Failed to start OpenAI OAuth device-code flow")?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("OpenAI device-code start failed ({status}): {body}");
}
let parsed: DeviceCodeResponse = response
.json()
.await
.context("Failed to parse OpenAI device-code response")?;
Ok(DeviceCodeStart {
device_code: parsed.device_code,
user_code: parsed.user_code,
verification_uri: parsed.verification_uri,
verification_uri_complete: parsed.verification_uri_complete,
expires_in: parsed.expires_in,
interval: parsed.interval.unwrap_or(5).max(1),
message: parsed.message,
})
}
pub async fn poll_device_code_tokens(client: &Client, device: &DeviceCodeStart) -> Result<TokenSet> {
let started = Instant::now();
let mut interval_secs = device.interval.max(1);
loop {
if started.elapsed() > Duration::from_secs(device.expires_in) {
anyhow::bail!("Device-code flow timed out before authorization completed");
}
tokio::time::sleep(Duration::from_secs(interval_secs)).await;
let form = [
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
("device_code", device.device_code.as_str()),
("client_id", OPENAI_OAUTH_CLIENT_ID),
];
let response = client
.post(OPENAI_OAUTH_TOKEN_URL)
.form(&form)
.send()
.await
.context("Failed polling OpenAI device-code token endpoint")?;
if response.status().is_success() {
return parse_token_response(response).await;
}
let status = response.status();
let text = response.text().await.unwrap_or_default();
if let Ok(err) = serde_json::from_str::<OAuthErrorResponse>(&text) {
match err.error.as_str() {
"authorization_pending" => {
continue;
}
"slow_down" => {
interval_secs = interval_secs.saturating_add(5);
continue;
}
"access_denied" => {
anyhow::bail!("OpenAI device-code authorization was denied")
}
"expired_token" => {
anyhow::bail!("OpenAI device-code expired")
}
_ => {
anyhow::bail!(
"OpenAI device-code polling failed ({status}): {}",
err.error_description.unwrap_or(err.error)
)
}
}
}
anyhow::bail!("OpenAI device-code polling failed ({status}): {text}");
}
}
pub async fn receive_loopback_code(expected_state: &str, timeout: Duration) -> Result<String> {
let listener = TcpListener::bind("127.0.0.1:1455")
.await
.context("Failed to bind callback listener at 127.0.0.1:1455")?;
let accepted = tokio::time::timeout(timeout, listener.accept())
.await
.context("Timed out waiting for browser callback")?
.context("Failed to accept callback connection")?;
let (mut stream, _) = accepted;
let mut buffer = vec![0_u8; 8192];
let bytes_read = stream
.read(&mut buffer)
.await
.context("Failed to read callback request")?;
let request = String::from_utf8_lossy(&buffer[..bytes_read]);
let first_line = request
.lines()
.next()
.ok_or_else(|| anyhow::anyhow!("Malformed callback request"))?;
let path = first_line
.split_whitespace()
.nth(1)
.ok_or_else(|| anyhow::anyhow!("Callback request missing path"))?;
let code = parse_code_from_redirect(path, Some(expected_state))?;
let body = "<html><body><h2>ZeroClaw login complete</h2><p>You can close this tab.</p></body></html>";
let response = format!(
"HTTP/1.1 200 OK\r\nContent-Type: text/html; charset=utf-8\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{}",
body.len(),
body
);
let _ = stream.write_all(response.as_bytes()).await;
Ok(code)
}
pub fn parse_code_from_redirect(input: &str, expected_state: Option<&str>) -> Result<String> {
let trimmed = input.trim();
if !trimmed.contains("code=") {
if trimmed.is_empty() {
anyhow::bail!("No OAuth code provided");
}
return Ok(trimmed.to_string());
}
let query = if let Some((_, right)) = trimmed.split_once('?') {
right
} else {
trimmed
};
let params = parse_query_params(query);
if let Some(err) = params.get("error") {
let desc = params
.get("error_description")
.cloned()
.unwrap_or_else(|| "OAuth authorization failed".to_string());
anyhow::bail!("OpenAI OAuth error: {err} ({desc})");
}
if let Some(expected_state) = expected_state {
let got = params
.get("state")
.ok_or_else(|| anyhow::anyhow!("Missing OAuth state in callback"))?;
if got != expected_state {
anyhow::bail!("OAuth state mismatch");
}
}
params
.get("code")
.cloned()
.ok_or_else(|| anyhow::anyhow!("Missing OAuth code in callback"))
}
pub fn extract_account_id_from_jwt(token: &str) -> Option<String> {
let payload = token.split('.').nth(1)?;
let decoded = base64::engine::general_purpose::URL_SAFE_NO_PAD
.decode(payload)
.ok()?;
let claims: serde_json::Value = serde_json::from_slice(&decoded).ok()?;
for key in [
"account_id",
"accountId",
"acct",
"sub",
"https://api.openai.com/account_id",
] {
if let Some(value) = claims.get(key).and_then(|v| v.as_str()) {
if !value.trim().is_empty() {
return Some(value.to_string());
}
}
}
None
}
async fn parse_token_response(response: reqwest::Response) -> Result<TokenSet> {
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("OpenAI OAuth token request failed ({status}): {body}");
}
let token: TokenResponse = response
.json()
.await
.context("Failed to parse OpenAI token response")?;
let expires_at = token.expires_in.and_then(|seconds| {
if seconds <= 0 {
None
} else {
Some(Utc::now() + chrono::Duration::seconds(seconds))
}
});
Ok(TokenSet {
access_token: token.access_token,
refresh_token: token.refresh_token,
id_token: token.id_token,
expires_at,
token_type: token.token_type,
scope: token.scope,
})
}
fn parse_query_params(input: &str) -> BTreeMap<String, String> {
let mut out = BTreeMap::new();
for pair in input.split('&') {
if pair.is_empty() {
continue;
}
let (key, value) = match pair.split_once('=') {
Some((k, v)) => (k, v),
None => (pair, ""),
};
out.insert(url_decode(key), url_decode(value));
}
out
}
fn random_base64url(byte_len: usize) -> String {
use chacha20poly1305::aead::{rand_core::RngCore, OsRng};
let mut bytes = vec![0_u8; byte_len];
OsRng.fill_bytes(&mut bytes);
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(bytes)
}
fn url_encode(input: &str) -> String {
input
.bytes()
.map(|b| match b {
b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
(b as char).to_string()
}
_ => format!("%{b:02X}"),
})
.collect::<String>()
}
fn url_decode(input: &str) -> String {
let bytes = input.as_bytes();
let mut out = Vec::with_capacity(bytes.len());
let mut i = 0;
while i < bytes.len() {
match bytes[i] {
b'%' if i + 2 < bytes.len() => {
let hi = bytes[i + 1] as char;
let lo = bytes[i + 2] as char;
if let (Some(h), Some(l)) = (hi.to_digit(16), lo.to_digit(16)) {
if let Ok(value) = u8::try_from(h * 16 + l) {
out.push(value);
i += 3;
continue;
}
}
out.push(bytes[i]);
i += 1;
}
b'+' => {
out.push(b' ');
i += 1;
}
b => {
out.push(b);
i += 1;
}
}
}
String::from_utf8_lossy(&out).to_string()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn pkce_generation_is_valid() {
let pkce = generate_pkce_state();
assert!(pkce.code_verifier.len() >= 43);
assert!(!pkce.code_challenge.is_empty());
assert!(!pkce.state.is_empty());
}
#[test]
fn parse_redirect_url_extracts_code() {
let code = parse_code_from_redirect(
"http://127.0.0.1:1455/auth/callback?code=abc123&state=xyz",
Some("xyz"),
)
.unwrap();
assert_eq!(code, "abc123");
}
#[test]
fn parse_redirect_accepts_raw_code() {
let code = parse_code_from_redirect("raw-code", None).unwrap();
assert_eq!(code, "raw-code");
}
#[test]
fn parse_redirect_rejects_state_mismatch() {
let err = parse_code_from_redirect("/auth/callback?code=x&state=a", Some("b")).unwrap_err();
assert!(err.to_string().contains("state mismatch"));
}
#[test]
fn extract_account_id_from_jwt_payload() {
let header = base64::engine::general_purpose::URL_SAFE_NO_PAD.encode("{}");
let payload = base64::engine::general_purpose::URL_SAFE_NO_PAD
.encode("{\"account_id\":\"acct_123\"}");
let token = format!("{header}.{payload}.sig");
let account = extract_account_id_from_jwt(&token);
assert_eq!(account.as_deref(), Some("acct_123"));
}
}

678
src/auth/profiles.rs Normal file
View file

@ -0,0 +1,678 @@
use crate::security::SecretStore;
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs::{self, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::thread;
use std::time::Duration;
const CURRENT_SCHEMA_VERSION: u32 = 1;
const PROFILES_FILENAME: &str = "auth-profiles.json";
const LOCK_FILENAME: &str = "auth-profiles.lock";
const LOCK_WAIT_MS: u64 = 50;
const LOCK_TIMEOUT_MS: u64 = 10_000;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum AuthProfileKind {
OAuth,
Token,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenSet {
pub access_token: String,
#[serde(default)]
pub refresh_token: Option<String>,
#[serde(default)]
pub id_token: Option<String>,
#[serde(default)]
pub expires_at: Option<DateTime<Utc>>,
#[serde(default)]
pub token_type: Option<String>,
#[serde(default)]
pub scope: Option<String>,
}
impl TokenSet {
pub fn is_expiring_within(&self, skew: Duration) -> bool {
match self.expires_at {
Some(expires_at) => {
let now_plus_skew = Utc::now() + chrono::Duration::from_std(skew).unwrap_or_default();
expires_at <= now_plus_skew
}
None => false,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProfile {
pub id: String,
pub provider: String,
pub profile_name: String,
pub kind: AuthProfileKind,
#[serde(default)]
pub account_id: Option<String>,
#[serde(default)]
pub workspace_id: Option<String>,
#[serde(default)]
pub token_set: Option<TokenSet>,
#[serde(default)]
pub token: Option<String>,
#[serde(default)]
pub metadata: BTreeMap<String, String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl AuthProfile {
pub fn new_oauth(provider: &str, profile_name: &str, token_set: TokenSet) -> Self {
let now = Utc::now();
let id = profile_id(provider, profile_name);
Self {
id,
provider: provider.to_string(),
profile_name: profile_name.to_string(),
kind: AuthProfileKind::OAuth,
account_id: None,
workspace_id: None,
token_set: Some(token_set),
token: None,
metadata: BTreeMap::new(),
created_at: now,
updated_at: now,
}
}
pub fn new_token(provider: &str, profile_name: &str, token: String) -> Self {
let now = Utc::now();
let id = profile_id(provider, profile_name);
Self {
id,
provider: provider.to_string(),
profile_name: profile_name.to_string(),
kind: AuthProfileKind::Token,
account_id: None,
workspace_id: None,
token_set: None,
token: Some(token),
metadata: BTreeMap::new(),
created_at: now,
updated_at: now,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProfilesData {
pub schema_version: u32,
pub updated_at: DateTime<Utc>,
pub active_profiles: BTreeMap<String, String>,
pub profiles: BTreeMap<String, AuthProfile>,
}
impl Default for AuthProfilesData {
fn default() -> Self {
Self {
schema_version: CURRENT_SCHEMA_VERSION,
updated_at: Utc::now(),
active_profiles: BTreeMap::new(),
profiles: BTreeMap::new(),
}
}
}
#[derive(Debug, Clone)]
pub struct AuthProfilesStore {
path: PathBuf,
lock_path: PathBuf,
secret_store: SecretStore,
}
impl AuthProfilesStore {
pub fn new(state_dir: &Path, encrypt_secrets: bool) -> Self {
Self {
path: state_dir.join(PROFILES_FILENAME),
lock_path: state_dir.join(LOCK_FILENAME),
secret_store: SecretStore::new(state_dir, encrypt_secrets),
}
}
pub fn path(&self) -> &Path {
&self.path
}
pub fn load(&self) -> Result<AuthProfilesData> {
let _lock = self.acquire_lock()?;
self.load_locked()
}
pub fn upsert_profile(&self, mut profile: AuthProfile, set_active: bool) -> Result<()> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
profile.updated_at = Utc::now();
if let Some(existing) = data.profiles.get(&profile.id) {
profile.created_at = existing.created_at;
}
if set_active {
data.active_profiles
.insert(profile.provider.clone(), profile.id.clone());
}
data.profiles.insert(profile.id.clone(), profile);
data.updated_at = Utc::now();
self.save_locked(&data)
}
pub fn remove_profile(&self, profile_id: &str) -> Result<bool> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
let removed = data.profiles.remove(profile_id).is_some();
if !removed {
return Ok(false);
}
data.active_profiles.retain(|_, active| active != profile_id);
data.updated_at = Utc::now();
self.save_locked(&data)?;
Ok(true)
}
pub fn set_active_profile(&self, provider: &str, profile_id: &str) -> Result<()> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
if !data.profiles.contains_key(profile_id) {
anyhow::bail!("Auth profile not found: {profile_id}");
}
data.active_profiles
.insert(provider.to_string(), profile_id.to_string());
data.updated_at = Utc::now();
self.save_locked(&data)
}
pub fn clear_active_profile(&self, provider: &str) -> Result<()> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
data.active_profiles.remove(provider);
data.updated_at = Utc::now();
self.save_locked(&data)
}
pub fn update_profile<F>(&self, profile_id: &str, mut updater: F) -> Result<AuthProfile>
where
F: FnMut(&mut AuthProfile) -> Result<()>,
{
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
let profile = data
.profiles
.get_mut(profile_id)
.ok_or_else(|| anyhow::anyhow!("Auth profile not found: {profile_id}"))?;
updater(profile)?;
profile.updated_at = Utc::now();
let updated_profile = profile.clone();
data.updated_at = Utc::now();
self.save_locked(&data)?;
Ok(updated_profile)
}
fn load_locked(&self) -> Result<AuthProfilesData> {
let mut persisted = self.read_persisted_locked()?;
let mut migrated = false;
let mut profiles = BTreeMap::new();
for (id, p) in &mut persisted.profiles {
let (access_token, access_migrated) = self.decrypt_optional(p.access_token.as_deref())?;
let (refresh_token, refresh_migrated) =
self.decrypt_optional(p.refresh_token.as_deref())?;
let (id_token, id_migrated) = self.decrypt_optional(p.id_token.as_deref())?;
let (token, token_migrated) = self.decrypt_optional(p.token.as_deref())?;
if let Some(value) = access_migrated {
p.access_token = Some(value);
migrated = true;
}
if let Some(value) = refresh_migrated {
p.refresh_token = Some(value);
migrated = true;
}
if let Some(value) = id_migrated {
p.id_token = Some(value);
migrated = true;
}
if let Some(value) = token_migrated {
p.token = Some(value);
migrated = true;
}
let kind = parse_profile_kind(&p.kind)?;
let token_set = match kind {
AuthProfileKind::OAuth => {
let access = access_token.ok_or_else(|| {
anyhow::anyhow!("OAuth profile missing access_token: {id}")
})?;
Some(TokenSet {
access_token: access,
refresh_token,
id_token,
expires_at: parse_optional_datetime(p.expires_at.as_deref())?,
token_type: p.token_type.clone(),
scope: p.scope.clone(),
})
}
AuthProfileKind::Token => None,
};
profiles.insert(
id.clone(),
AuthProfile {
id: id.clone(),
provider: p.provider.clone(),
profile_name: p.profile_name.clone(),
kind,
account_id: p.account_id.clone(),
workspace_id: p.workspace_id.clone(),
token_set,
token,
metadata: p.metadata.clone(),
created_at: parse_datetime_with_fallback(&p.created_at),
updated_at: parse_datetime_with_fallback(&p.updated_at),
},
);
}
if migrated {
self.write_persisted_locked(&persisted)?;
}
Ok(AuthProfilesData {
schema_version: persisted.schema_version,
updated_at: parse_datetime_with_fallback(&persisted.updated_at),
active_profiles: persisted.active_profiles,
profiles,
})
}
fn save_locked(&self, data: &AuthProfilesData) -> Result<()> {
let mut persisted = PersistedAuthProfiles {
schema_version: CURRENT_SCHEMA_VERSION,
updated_at: data.updated_at.to_rfc3339(),
active_profiles: data.active_profiles.clone(),
profiles: BTreeMap::new(),
};
for (id, profile) in &data.profiles {
let (access_token, refresh_token, id_token, expires_at, token_type, scope) =
match (&profile.kind, &profile.token_set) {
(AuthProfileKind::OAuth, Some(token_set)) => (
self.encrypt_optional(Some(&token_set.access_token))?,
self.encrypt_optional(token_set.refresh_token.as_deref())?,
self.encrypt_optional(token_set.id_token.as_deref())?,
token_set.expires_at.as_ref().map(DateTime::to_rfc3339),
token_set.token_type.clone(),
token_set.scope.clone(),
),
_ => (None, None, None, None, None, None),
};
let token = self.encrypt_optional(profile.token.as_deref())?;
persisted.profiles.insert(
id.clone(),
PersistedAuthProfile {
provider: profile.provider.clone(),
profile_name: profile.profile_name.clone(),
kind: profile_kind_to_string(profile.kind).to_string(),
account_id: profile.account_id.clone(),
workspace_id: profile.workspace_id.clone(),
access_token,
refresh_token,
id_token,
token,
expires_at,
token_type,
scope,
metadata: profile.metadata.clone(),
created_at: profile.created_at.to_rfc3339(),
updated_at: profile.updated_at.to_rfc3339(),
},
);
}
self.write_persisted_locked(&persisted)
}
fn read_persisted_locked(&self) -> Result<PersistedAuthProfiles> {
if !self.path.exists() {
return Ok(PersistedAuthProfiles::default());
}
let bytes = fs::read(&self.path).with_context(|| {
format!(
"Failed to read auth profile store at {}",
self.path.display()
)
})?;
if bytes.is_empty() {
return Ok(PersistedAuthProfiles::default());
}
let mut persisted: PersistedAuthProfiles = serde_json::from_slice(&bytes).with_context(|| {
format!(
"Failed to parse auth profile store at {}",
self.path.display()
)
})?;
if persisted.schema_version == 0 {
persisted.schema_version = CURRENT_SCHEMA_VERSION;
}
if persisted.schema_version > CURRENT_SCHEMA_VERSION {
anyhow::bail!(
"Unsupported auth profile schema version {} (max supported: {})",
persisted.schema_version,
CURRENT_SCHEMA_VERSION
);
}
Ok(persisted)
}
fn write_persisted_locked(&self, persisted: &PersistedAuthProfiles) -> Result<()> {
if let Some(parent) = self.path.parent() {
fs::create_dir_all(parent).with_context(|| {
format!(
"Failed to create auth profile directory at {}",
parent.display()
)
})?;
}
let json = serde_json::to_vec_pretty(persisted).context("Failed to serialize auth profiles")?;
let tmp_name = format!(
"{}.tmp.{}.{}",
PROFILES_FILENAME,
std::process::id(),
Utc::now().timestamp_nanos_opt().unwrap_or_default()
);
let tmp_path = self.path.with_file_name(tmp_name);
fs::write(&tmp_path, &json).with_context(|| {
format!(
"Failed to write temporary auth profile file at {}",
tmp_path.display()
)
})?;
fs::rename(&tmp_path, &self.path).with_context(|| {
format!(
"Failed to replace auth profile store at {}",
self.path.display()
)
})?;
Ok(())
}
fn encrypt_optional(&self, value: Option<&str>) -> Result<Option<String>> {
match value {
Some(value) if !value.is_empty() => self.secret_store.encrypt(value).map(Some),
Some(_) | None => Ok(None),
}
}
fn decrypt_optional(&self, value: Option<&str>) -> Result<(Option<String>, Option<String>)> {
match value {
Some(value) if !value.is_empty() => {
let (plaintext, migrated) = self.secret_store.decrypt_and_migrate(value)?;
Ok((Some(plaintext), migrated))
}
Some(_) | None => Ok((None, None)),
}
}
fn acquire_lock(&self) -> Result<AuthProfileLockGuard> {
if let Some(parent) = self.lock_path.parent() {
fs::create_dir_all(parent).with_context(|| {
format!("Failed to create lock directory at {}", parent.display())
})?;
}
let mut waited = 0_u64;
loop {
match OpenOptions::new()
.create_new(true)
.write(true)
.open(&self.lock_path)
{
Ok(mut file) => {
let _ = writeln!(file, "pid={}", std::process::id());
return Ok(AuthProfileLockGuard {
lock_path: self.lock_path.clone(),
});
}
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
if waited >= LOCK_TIMEOUT_MS {
anyhow::bail!(
"Timed out waiting for auth profile lock at {}",
self.lock_path.display()
);
}
thread::sleep(Duration::from_millis(LOCK_WAIT_MS));
waited = waited.saturating_add(LOCK_WAIT_MS);
}
Err(e) => {
return Err(e).with_context(|| {
format!(
"Failed to create auth profile lock at {}",
self.lock_path.display()
)
});
}
}
}
}
}
struct AuthProfileLockGuard {
lock_path: PathBuf,
}
impl Drop for AuthProfileLockGuard {
fn drop(&mut self) {
let _ = fs::remove_file(&self.lock_path);
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedAuthProfiles {
#[serde(default = "default_schema_version")]
schema_version: u32,
#[serde(default = "default_now_rfc3339")]
updated_at: String,
#[serde(default)]
active_profiles: BTreeMap<String, String>,
#[serde(default)]
profiles: BTreeMap<String, PersistedAuthProfile>,
}
impl Default for PersistedAuthProfiles {
fn default() -> Self {
Self {
schema_version: CURRENT_SCHEMA_VERSION,
updated_at: default_now_rfc3339(),
active_profiles: BTreeMap::new(),
profiles: BTreeMap::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct PersistedAuthProfile {
provider: String,
profile_name: String,
kind: String,
#[serde(default)]
account_id: Option<String>,
#[serde(default)]
workspace_id: Option<String>,
#[serde(default)]
access_token: Option<String>,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default)]
id_token: Option<String>,
#[serde(default)]
token: Option<String>,
#[serde(default)]
expires_at: Option<String>,
#[serde(default)]
token_type: Option<String>,
#[serde(default)]
scope: Option<String>,
#[serde(default = "default_now_rfc3339")]
created_at: String,
#[serde(default = "default_now_rfc3339")]
updated_at: String,
#[serde(default)]
metadata: BTreeMap<String, String>,
}
fn default_schema_version() -> u32 {
CURRENT_SCHEMA_VERSION
}
fn default_now_rfc3339() -> String {
Utc::now().to_rfc3339()
}
fn parse_profile_kind(value: &str) -> Result<AuthProfileKind> {
match value {
"oauth" => Ok(AuthProfileKind::OAuth),
"token" => Ok(AuthProfileKind::Token),
other => anyhow::bail!("Unsupported auth profile kind: {other}"),
}
}
fn profile_kind_to_string(kind: AuthProfileKind) -> &'static str {
match kind {
AuthProfileKind::OAuth => "oauth",
AuthProfileKind::Token => "token",
}
}
fn parse_optional_datetime(value: Option<&str>) -> Result<Option<DateTime<Utc>>> {
value
.map(parse_datetime)
.transpose()
}
fn parse_datetime(value: &str) -> Result<DateTime<Utc>> {
DateTime::parse_from_rfc3339(value)
.map(|dt| dt.with_timezone(&Utc))
.with_context(|| format!("Invalid RFC3339 timestamp: {value}"))
}
fn parse_datetime_with_fallback(value: &str) -> DateTime<Utc> {
parse_datetime(value).unwrap_or_else(|_| Utc::now())
}
pub fn profile_id(provider: &str, profile_name: &str) -> String {
format!("{}:{}", provider.trim(), profile_name.trim())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn profile_id_format() {
assert_eq!(profile_id("openai-codex", "default"), "openai-codex:default");
}
#[test]
fn token_expiry_math() {
let token_set = TokenSet {
access_token: "token".into(),
refresh_token: Some("refresh".into()),
id_token: None,
expires_at: Some(Utc::now() + chrono::Duration::seconds(10)),
token_type: Some("Bearer".into()),
scope: None,
};
assert!(token_set.is_expiring_within(Duration::from_secs(15)));
assert!(!token_set.is_expiring_within(Duration::from_secs(1)));
}
#[test]
fn store_roundtrip_with_encryption() {
let tmp = TempDir::new().unwrap();
let store = AuthProfilesStore::new(tmp.path(), true);
let mut profile = AuthProfile::new_oauth(
"openai-codex",
"default",
TokenSet {
access_token: "access-123".into(),
refresh_token: Some("refresh-123".into()),
id_token: None,
expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
token_type: Some("Bearer".into()),
scope: Some("openid offline_access".into()),
},
);
profile.account_id = Some("acct_123".into());
store.upsert_profile(profile.clone(), true).unwrap();
let data = store.load().unwrap();
let loaded = data.profiles.get(&profile.id).unwrap();
assert_eq!(loaded.provider, "openai-codex");
assert_eq!(loaded.profile_name, "default");
assert_eq!(loaded.account_id.as_deref(), Some("acct_123"));
assert_eq!(
loaded
.token_set
.as_ref()
.and_then(|t| t.refresh_token.as_deref()),
Some("refresh-123")
);
let raw = fs::read_to_string(store.path()).unwrap();
assert!(raw.contains("enc2:"));
assert!(!raw.contains("refresh-123"));
assert!(!raw.contains("access-123"));
}
#[test]
fn atomic_write_replaces_file() {
let tmp = TempDir::new().unwrap();
let store = AuthProfilesStore::new(tmp.path(), false);
let profile = AuthProfile::new_token("anthropic", "default", "token-abc".into());
store.upsert_profile(profile, true).unwrap();
let path = store.path().to_path_buf();
assert!(path.exists());
let contents = fs::read_to_string(path).unwrap();
assert!(contents.contains("\"schema_version\": 1"));
}
}

View file

@ -960,6 +960,11 @@ pub async fn start_channels(config: Config) -> Result<()> {
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&providers::ProviderRuntimeOptions {
auth_profile_override: None,
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
},
)?);
// Warm up the provider connection pool (TLS handshake, DNS, HTTP/2 setup)

View file

@ -297,11 +297,16 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let actual_port = listener.local_addr()?.port();
let display_addr = format!("{host}:{actual_port}");
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider_with_options(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&providers::ProviderRuntimeOptions {
auth_profile_override: None,
zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from),
secrets_encrypt: config.secrets.encrypt,
},
)?);
let model = config
.default_model

View file

@ -129,7 +129,7 @@ enum Commands {
#[arg(short, long)]
message: Option<String>,
/// Provider to use (openrouter, anthropic, openai)
/// Provider to use (openrouter, anthropic, openai, openai-codex)
#[arg(short, long)]
provider: Option<String>,

View file

@ -1385,6 +1385,10 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio
("venice", "Venice AI — privacy-first (Llama, Opus)"),
("anthropic", "Anthropic — Claude Sonnet & Opus (direct)"),
("openai", "OpenAI — GPT-4o, o1, GPT-5 (direct)"),
(
"openai-codex",
"OpenAI Codex (ChatGPT subscription OAuth, no API key)",
),
("deepseek", "DeepSeek — V3 & R1 (affordable)"),
("mistral", "Mistral — Large & Codestral"),
("xai", "xAI — Grok 3 & 4"),
@ -1719,6 +1723,10 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio
("gpt-4o-mini", "GPT-4o Mini (fast, cheap)"),
("o1-mini", "o1-mini (reasoning)"),
],
"openai-codex" => vec![
("gpt-5-codex", "GPT-5 Codex (recommended)"),
("o4-mini", "o4-mini (fallback)"),
],
"venice" => vec![
("llama-3.3-70b", "Llama 3.3 70B (default, fast)"),
("claude-opus-45", "Claude Opus 4.5 via Venice (strongest)"),
@ -4054,15 +4062,41 @@ fn print_summary(config: &Config) {
let mut step = 1u8;
if config.api_key.is_none() {
let env_var = provider_env_var(config.default_provider.as_deref().unwrap_or("openrouter"));
println!(
" {} Set your API key:",
style(format!("{step}.")).cyan().bold()
);
println!(
" {}",
style(format!("export {env_var}=\"sk-...\"")).yellow()
);
let provider = config.default_provider.as_deref().unwrap_or("openrouter");
if provider == "openai-codex" {
println!(
" {} Authenticate OpenAI Codex:",
style(format!("{step}.")).cyan().bold()
);
println!(
" {}",
style("zeroclaw auth login --provider openai-codex --device-code").yellow()
);
} else if provider == "anthropic" {
println!(
" {} Configure Anthropic auth:",
style(format!("{step}.")).cyan().bold()
);
println!(
" {}",
style("export ANTHROPIC_API_KEY=\"sk-ant-...\"").yellow()
);
println!(
" {}",
style("or: zeroclaw auth paste-token --provider anthropic --auth-kind authorization")
.yellow()
);
} else {
let env_var = provider_env_var(provider);
println!(
" {} Set your API key:",
style(format!("{step}.")).cyan().bold()
);
println!(
" {}",
style(format!("export {env_var}=\"sk-...\"")).yellow()
);
}
println!();
step += 1;
}

View file

@ -6,6 +6,7 @@ use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
pub struct AnthropicProvider {
credential: Option<String>,
@ -614,4 +615,10 @@ mod tests {
assert!(json.contains(&format!("{temp}")));
}
}
#[test]
fn detects_auth_from_jwt_shape() {
let kind = detect_auth_kind("a.b.c", None);
assert_eq!(kind, AnthropicAuthKind::Authorization);
}
}

View file

@ -4,6 +4,7 @@ pub mod copilot;
pub mod gemini;
pub mod ollama;
pub mod openai;
pub mod openai_codex;
pub mod openrouter;
pub mod reliable;
pub mod router;
@ -17,6 +18,7 @@ pub use traits::{
use compatible::{AuthStyle, OpenAiCompatibleProvider};
use reliable::ReliableProvider;
use std::path::PathBuf;
const MAX_API_ERROR_CHARS: usize = 200;
const MINIMAX_INTL_BASE_URL: &str = "https://api.minimax.io/v1";
@ -178,6 +180,23 @@ fn zai_base_url(name: &str) -> Option<&'static str> {
}
}
#[derive(Debug, Clone)]
pub struct ProviderRuntimeOptions {
pub auth_profile_override: Option<String>,
pub zeroclaw_dir: Option<PathBuf>,
pub secrets_encrypt: bool,
}
impl Default for ProviderRuntimeOptions {
fn default() -> Self {
Self {
auth_profile_override: None,
zeroclaw_dir: None,
secrets_encrypt: true,
}
}
}
fn is_secret_char(c: char) -> bool {
c.is_ascii_alphanumeric() || matches!(c, '-' | '_' | '.' | ':')
}
@ -538,6 +557,21 @@ pub fn create_resilient_provider(
api_key: Option<&str>,
api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
) -> anyhow::Result<Box<dyn Provider>> {
create_resilient_provider_with_options(
primary_name,
api_key,
reliability,
&ProviderRuntimeOptions::default(),
)
}
/// Create provider chain with retry/fallback behavior and auth runtime options.
pub fn create_resilient_provider_with_options(
primary_name: &str,
api_key: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
options: &ProviderRuntimeOptions,
) -> anyhow::Result<Box<dyn Provider>> {
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
@ -943,6 +977,12 @@ mod tests {
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
}
#[test]
fn factory_openai_codex() {
let options = ProviderRuntimeOptions::default();
assert!(create_provider_with_options("openai-codex", None, &options).is_ok());
}
#[test]
fn factory_ollama() {
assert!(create_provider("ollama", None).is_ok());

View file

@ -0,0 +1,198 @@
use crate::auth::AuthService;
use crate::providers::traits::Provider;
use crate::providers::ProviderRuntimeOptions;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::PathBuf;
const CODEX_RESPONSES_URL: &str = "https://chatgpt.com/backend-api/codex/responses";
pub struct OpenAiCodexProvider {
auth: AuthService,
auth_profile_override: Option<String>,
client: Client,
}
#[derive(Debug, Serialize)]
struct ResponsesRequest {
model: String,
input: Vec<ResponsesInput>,
#[serde(skip_serializing_if = "Option::is_none")]
instructions: Option<String>,
stream: bool,
}
#[derive(Debug, Serialize)]
struct ResponsesInput {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ResponsesResponse {
#[serde(default)]
output: Vec<ResponsesOutput>,
#[serde(default)]
output_text: Option<String>,
}
#[derive(Debug, Deserialize)]
struct ResponsesOutput {
#[serde(default)]
content: Vec<ResponsesContent>,
}
#[derive(Debug, Deserialize)]
struct ResponsesContent {
#[serde(rename = "type")]
kind: Option<String>,
text: Option<String>,
}
impl OpenAiCodexProvider {
pub fn new(options: &ProviderRuntimeOptions) -> Self {
let state_dir = options
.zeroclaw_dir
.clone()
.unwrap_or_else(default_zeroclaw_dir);
let auth = AuthService::new(&state_dir, options.secrets_encrypt);
Self {
auth,
auth_profile_override: options.auth_profile_override.clone(),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| Client::new()),
}
}
}
fn default_zeroclaw_dir() -> PathBuf {
directories::UserDirs::new().map_or_else(
|| PathBuf::from(".zeroclaw"),
|dirs| dirs.home_dir().join(".zeroclaw"),
)
}
fn first_nonempty(text: Option<&str>) -> Option<String> {
text.and_then(|value| {
let trimmed = value.trim();
if trimmed.is_empty() {
None
} else {
Some(trimmed.to_string())
}
})
}
fn extract_responses_text(response: &ResponsesResponse) -> Option<String> {
if let Some(text) = first_nonempty(response.output_text.as_deref()) {
return Some(text);
}
for item in &response.output {
for content in &item.content {
if content.kind.as_deref() == Some("output_text") {
if let Some(text) = first_nonempty(content.text.as_deref()) {
return Some(text);
}
}
}
}
for item in &response.output {
for content in &item.content {
if let Some(text) = first_nonempty(content.text.as_deref()) {
return Some(text);
}
}
}
None
}
#[async_trait]
impl Provider for OpenAiCodexProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
let access_token = self
.auth
.get_valid_openai_access_token(self.auth_profile_override.as_deref())
.await?
.ok_or_else(|| {
anyhow::anyhow!(
"OpenAI Codex auth profile not found. Run `zeroclaw auth login --provider openai-codex`."
)
})?;
let request = ResponsesRequest {
model: model.to_string(),
input: vec![ResponsesInput {
role: "user".to_string(),
content: message.to_string(),
}],
instructions: system_prompt.map(str::to_string),
stream: false,
};
let response = self
.client
.post(CODEX_RESPONSES_URL)
.header("Authorization", format!("Bearer {access_token}"))
.header("Content-Type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
return Err(super::api_error("OpenAI Codex", response).await);
}
let parsed: ResponsesResponse = response.json().await?;
extract_responses_text(&parsed)
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI Codex"))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extracts_output_text_first() {
let response = ResponsesResponse {
output: vec![],
output_text: Some("hello".into()),
};
assert_eq!(extract_responses_text(&response).as_deref(), Some("hello"));
}
#[test]
fn extracts_nested_output_text() {
let response = ResponsesResponse {
output: vec![ResponsesOutput {
content: vec![ResponsesContent {
kind: Some("output_text".into()),
text: Some("nested".into()),
}],
}],
output_text: None,
};
assert_eq!(extract_responses_text(&response).as_deref(), Some("nested"));
}
#[test]
fn default_state_dir_is_non_empty() {
let path = default_zeroclaw_dir();
assert!(!path.as_os_str().is_empty());
}
}