use crate::security::SecretStore; use anyhow::{Context, Result}; use chrono::{DateTime, Utc}; use serde::{Deserialize, Serialize}; use std::collections::BTreeMap; use std::fs::{self, OpenOptions}; use std::io::Write; use std::path::{Path, PathBuf}; use std::thread; use std::time::Duration; const CURRENT_SCHEMA_VERSION: u32 = 1; const PROFILES_FILENAME: &str = "auth-profiles.json"; const LOCK_FILENAME: &str = "auth-profiles.lock"; const LOCK_WAIT_MS: u64 = 50; const LOCK_TIMEOUT_MS: u64 = 10_000; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "kebab-case")] pub enum AuthProfileKind { OAuth, Token, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct TokenSet { pub access_token: String, #[serde(default)] pub refresh_token: Option, #[serde(default)] pub id_token: Option, #[serde(default)] pub expires_at: Option>, #[serde(default)] pub token_type: Option, #[serde(default)] pub scope: Option, } impl TokenSet { pub fn is_expiring_within(&self, skew: Duration) -> bool { match self.expires_at { Some(expires_at) => { let now_plus_skew = Utc::now() + chrono::Duration::from_std(skew).unwrap_or_default(); expires_at <= now_plus_skew } None => false, } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthProfile { pub id: String, pub provider: String, pub profile_name: String, pub kind: AuthProfileKind, #[serde(default)] pub account_id: Option, #[serde(default)] pub workspace_id: Option, #[serde(default)] pub token_set: Option, #[serde(default)] pub token: Option, #[serde(default)] pub metadata: BTreeMap, pub created_at: DateTime, pub updated_at: DateTime, } impl AuthProfile { pub fn new_oauth(provider: &str, profile_name: &str, token_set: TokenSet) -> Self { let now = Utc::now(); let id = profile_id(provider, profile_name); Self { id, provider: provider.to_string(), profile_name: profile_name.to_string(), kind: AuthProfileKind::OAuth, account_id: None, workspace_id: None, token_set: Some(token_set), token: None, metadata: BTreeMap::new(), created_at: now, updated_at: now, } } pub fn new_token(provider: &str, profile_name: &str, token: String) -> Self { let now = Utc::now(); let id = profile_id(provider, profile_name); Self { id, provider: provider.to_string(), profile_name: profile_name.to_string(), kind: AuthProfileKind::Token, account_id: None, workspace_id: None, token_set: None, token: Some(token), metadata: BTreeMap::new(), created_at: now, updated_at: now, } } } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AuthProfilesData { pub schema_version: u32, pub updated_at: DateTime, pub active_profiles: BTreeMap, pub profiles: BTreeMap, } impl Default for AuthProfilesData { fn default() -> Self { Self { schema_version: CURRENT_SCHEMA_VERSION, updated_at: Utc::now(), active_profiles: BTreeMap::new(), profiles: BTreeMap::new(), } } } #[derive(Debug, Clone)] pub struct AuthProfilesStore { path: PathBuf, lock_path: PathBuf, secret_store: SecretStore, } impl AuthProfilesStore { pub fn new(state_dir: &Path, encrypt_secrets: bool) -> Self { Self { path: state_dir.join(PROFILES_FILENAME), lock_path: state_dir.join(LOCK_FILENAME), secret_store: SecretStore::new(state_dir, encrypt_secrets), } } pub fn path(&self) -> &Path { &self.path } pub fn load(&self) -> Result { let _lock = self.acquire_lock()?; self.load_locked() } pub fn upsert_profile(&self, mut profile: AuthProfile, set_active: bool) -> Result<()> { let _lock = self.acquire_lock()?; let mut data = self.load_locked()?; profile.updated_at = Utc::now(); if let Some(existing) = data.profiles.get(&profile.id) { profile.created_at = existing.created_at; } if set_active { data.active_profiles .insert(profile.provider.clone(), profile.id.clone()); } data.profiles.insert(profile.id.clone(), profile); data.updated_at = Utc::now(); self.save_locked(&data) } pub fn remove_profile(&self, profile_id: &str) -> Result { let _lock = self.acquire_lock()?; let mut data = self.load_locked()?; let removed = data.profiles.remove(profile_id).is_some(); if !removed { return Ok(false); } data.active_profiles .retain(|_, active| active != profile_id); data.updated_at = Utc::now(); self.save_locked(&data)?; Ok(true) } pub fn set_active_profile(&self, provider: &str, profile_id: &str) -> Result<()> { let _lock = self.acquire_lock()?; let mut data = self.load_locked()?; if !data.profiles.contains_key(profile_id) { anyhow::bail!("Auth profile not found: {profile_id}"); } data.active_profiles .insert(provider.to_string(), profile_id.to_string()); data.updated_at = Utc::now(); self.save_locked(&data) } pub fn clear_active_profile(&self, provider: &str) -> Result<()> { let _lock = self.acquire_lock()?; let mut data = self.load_locked()?; data.active_profiles.remove(provider); data.updated_at = Utc::now(); self.save_locked(&data) } pub fn update_profile(&self, profile_id: &str, mut updater: F) -> Result where F: FnMut(&mut AuthProfile) -> Result<()>, { let _lock = self.acquire_lock()?; let mut data = self.load_locked()?; let profile = data .profiles .get_mut(profile_id) .ok_or_else(|| anyhow::anyhow!("Auth profile not found: {profile_id}"))?; updater(profile)?; profile.updated_at = Utc::now(); let updated_profile = profile.clone(); data.updated_at = Utc::now(); self.save_locked(&data)?; Ok(updated_profile) } fn load_locked(&self) -> Result { let mut persisted = self.read_persisted_locked()?; let mut migrated = false; let mut profiles = BTreeMap::new(); for (id, p) in &mut persisted.profiles { let (access_token, access_migrated) = self.decrypt_optional(p.access_token.as_deref())?; let (refresh_token, refresh_migrated) = self.decrypt_optional(p.refresh_token.as_deref())?; let (id_token, id_migrated) = self.decrypt_optional(p.id_token.as_deref())?; let (token, token_migrated) = self.decrypt_optional(p.token.as_deref())?; if let Some(value) = access_migrated { p.access_token = Some(value); migrated = true; } if let Some(value) = refresh_migrated { p.refresh_token = Some(value); migrated = true; } if let Some(value) = id_migrated { p.id_token = Some(value); migrated = true; } if let Some(value) = token_migrated { p.token = Some(value); migrated = true; } let kind = parse_profile_kind(&p.kind)?; let token_set = match kind { AuthProfileKind::OAuth => { let access = access_token.ok_or_else(|| { anyhow::anyhow!("OAuth profile missing access_token: {id}") })?; Some(TokenSet { access_token: access, refresh_token, id_token, expires_at: parse_optional_datetime(p.expires_at.as_deref())?, token_type: p.token_type.clone(), scope: p.scope.clone(), }) } AuthProfileKind::Token => None, }; profiles.insert( id.clone(), AuthProfile { id: id.clone(), provider: p.provider.clone(), profile_name: p.profile_name.clone(), kind, account_id: p.account_id.clone(), workspace_id: p.workspace_id.clone(), token_set, token, metadata: p.metadata.clone(), created_at: parse_datetime_with_fallback(&p.created_at), updated_at: parse_datetime_with_fallback(&p.updated_at), }, ); } if migrated { self.write_persisted_locked(&persisted)?; } Ok(AuthProfilesData { schema_version: persisted.schema_version, updated_at: parse_datetime_with_fallback(&persisted.updated_at), active_profiles: persisted.active_profiles, profiles, }) } fn save_locked(&self, data: &AuthProfilesData) -> Result<()> { let mut persisted = PersistedAuthProfiles { schema_version: CURRENT_SCHEMA_VERSION, updated_at: data.updated_at.to_rfc3339(), active_profiles: data.active_profiles.clone(), profiles: BTreeMap::new(), }; for (id, profile) in &data.profiles { let (access_token, refresh_token, id_token, expires_at, token_type, scope) = match (&profile.kind, &profile.token_set) { (AuthProfileKind::OAuth, Some(token_set)) => ( self.encrypt_optional(Some(&token_set.access_token))?, self.encrypt_optional(token_set.refresh_token.as_deref())?, self.encrypt_optional(token_set.id_token.as_deref())?, token_set.expires_at.as_ref().map(DateTime::to_rfc3339), token_set.token_type.clone(), token_set.scope.clone(), ), _ => (None, None, None, None, None, None), }; let token = self.encrypt_optional(profile.token.as_deref())?; persisted.profiles.insert( id.clone(), PersistedAuthProfile { provider: profile.provider.clone(), profile_name: profile.profile_name.clone(), kind: profile_kind_to_string(profile.kind).to_string(), account_id: profile.account_id.clone(), workspace_id: profile.workspace_id.clone(), access_token, refresh_token, id_token, token, expires_at, token_type, scope, metadata: profile.metadata.clone(), created_at: profile.created_at.to_rfc3339(), updated_at: profile.updated_at.to_rfc3339(), }, ); } self.write_persisted_locked(&persisted) } fn read_persisted_locked(&self) -> Result { if !self.path.exists() { return Ok(PersistedAuthProfiles::default()); } let bytes = fs::read(&self.path).with_context(|| { format!( "Failed to read auth profile store at {}", self.path.display() ) })?; if bytes.is_empty() { return Ok(PersistedAuthProfiles::default()); } let mut persisted: PersistedAuthProfiles = serde_json::from_slice(&bytes).with_context(|| { format!( "Failed to parse auth profile store at {}", self.path.display() ) })?; if persisted.schema_version == 0 { persisted.schema_version = CURRENT_SCHEMA_VERSION; } if persisted.schema_version > CURRENT_SCHEMA_VERSION { anyhow::bail!( "Unsupported auth profile schema version {} (max supported: {})", persisted.schema_version, CURRENT_SCHEMA_VERSION ); } Ok(persisted) } fn write_persisted_locked(&self, persisted: &PersistedAuthProfiles) -> Result<()> { if let Some(parent) = self.path.parent() { fs::create_dir_all(parent).with_context(|| { format!( "Failed to create auth profile directory at {}", parent.display() ) })?; } let json = serde_json::to_vec_pretty(persisted).context("Failed to serialize auth profiles")?; let tmp_name = format!( "{}.tmp.{}.{}", PROFILES_FILENAME, std::process::id(), Utc::now().timestamp_nanos_opt().unwrap_or_default() ); let tmp_path = self.path.with_file_name(tmp_name); fs::write(&tmp_path, &json).with_context(|| { format!( "Failed to write temporary auth profile file at {}", tmp_path.display() ) })?; fs::rename(&tmp_path, &self.path).with_context(|| { format!( "Failed to replace auth profile store at {}", self.path.display() ) })?; Ok(()) } fn encrypt_optional(&self, value: Option<&str>) -> Result> { match value { Some(value) if !value.is_empty() => self.secret_store.encrypt(value).map(Some), Some(_) | None => Ok(None), } } fn decrypt_optional(&self, value: Option<&str>) -> Result<(Option, Option)> { match value { Some(value) if !value.is_empty() => { let (plaintext, migrated) = self.secret_store.decrypt_and_migrate(value)?; Ok((Some(plaintext), migrated)) } Some(_) | None => Ok((None, None)), } } fn acquire_lock(&self) -> Result { if let Some(parent) = self.lock_path.parent() { fs::create_dir_all(parent).with_context(|| { format!("Failed to create lock directory at {}", parent.display()) })?; } let mut waited = 0_u64; loop { match OpenOptions::new() .create_new(true) .write(true) .open(&self.lock_path) { Ok(mut file) => { let _ = writeln!(file, "pid={}", std::process::id()); return Ok(AuthProfileLockGuard { lock_path: self.lock_path.clone(), }); } Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => { if waited >= LOCK_TIMEOUT_MS { anyhow::bail!( "Timed out waiting for auth profile lock at {}", self.lock_path.display() ); } thread::sleep(Duration::from_millis(LOCK_WAIT_MS)); waited = waited.saturating_add(LOCK_WAIT_MS); } Err(e) => { return Err(e).with_context(|| { format!( "Failed to create auth profile lock at {}", self.lock_path.display() ) }); } } } } } struct AuthProfileLockGuard { lock_path: PathBuf, } impl Drop for AuthProfileLockGuard { fn drop(&mut self) { let _ = fs::remove_file(&self.lock_path); } } #[derive(Debug, Clone, Serialize, Deserialize)] struct PersistedAuthProfiles { #[serde(default = "default_schema_version")] schema_version: u32, #[serde(default = "default_now_rfc3339")] updated_at: String, #[serde(default)] active_profiles: BTreeMap, #[serde(default)] profiles: BTreeMap, } impl Default for PersistedAuthProfiles { fn default() -> Self { Self { schema_version: CURRENT_SCHEMA_VERSION, updated_at: default_now_rfc3339(), active_profiles: BTreeMap::new(), profiles: BTreeMap::new(), } } } #[derive(Debug, Clone, Serialize, Deserialize, Default)] struct PersistedAuthProfile { provider: String, profile_name: String, kind: String, #[serde(default)] account_id: Option, #[serde(default)] workspace_id: Option, #[serde(default)] access_token: Option, #[serde(default)] refresh_token: Option, #[serde(default)] id_token: Option, #[serde(default)] token: Option, #[serde(default)] expires_at: Option, #[serde(default)] token_type: Option, #[serde(default)] scope: Option, #[serde(default = "default_now_rfc3339")] created_at: String, #[serde(default = "default_now_rfc3339")] updated_at: String, #[serde(default)] metadata: BTreeMap, } fn default_schema_version() -> u32 { CURRENT_SCHEMA_VERSION } fn default_now_rfc3339() -> String { Utc::now().to_rfc3339() } fn parse_profile_kind(value: &str) -> Result { match value { "oauth" => Ok(AuthProfileKind::OAuth), "token" => Ok(AuthProfileKind::Token), other => anyhow::bail!("Unsupported auth profile kind: {other}"), } } fn profile_kind_to_string(kind: AuthProfileKind) -> &'static str { match kind { AuthProfileKind::OAuth => "oauth", AuthProfileKind::Token => "token", } } fn parse_optional_datetime(value: Option<&str>) -> Result>> { value.map(parse_datetime).transpose() } fn parse_datetime(value: &str) -> Result> { DateTime::parse_from_rfc3339(value) .map(|dt| dt.with_timezone(&Utc)) .with_context(|| format!("Invalid RFC3339 timestamp: {value}")) } fn parse_datetime_with_fallback(value: &str) -> DateTime { parse_datetime(value).unwrap_or_else(|_| Utc::now()) } pub fn profile_id(provider: &str, profile_name: &str) -> String { format!("{}:{}", provider.trim(), profile_name.trim()) } #[cfg(test)] mod tests { use super::*; use tempfile::TempDir; #[test] fn profile_id_format() { assert_eq!( profile_id("openai-codex", "default"), "openai-codex:default" ); } #[test] fn token_expiry_math() { let token_set = TokenSet { access_token: "token".into(), refresh_token: Some("refresh".into()), id_token: None, expires_at: Some(Utc::now() + chrono::Duration::seconds(10)), token_type: Some("Bearer".into()), scope: None, }; assert!(token_set.is_expiring_within(Duration::from_secs(15))); assert!(!token_set.is_expiring_within(Duration::from_secs(1))); } #[tokio::test] async fn store_roundtrip_with_encryption() { let tmp = TempDir::new().unwrap(); let store = AuthProfilesStore::new(tmp.path(), true); let mut profile = AuthProfile::new_oauth( "openai-codex", "default", TokenSet { access_token: "access-123".into(), refresh_token: Some("refresh-123".into()), id_token: None, expires_at: Some(Utc::now() + chrono::Duration::hours(1)), token_type: Some("Bearer".into()), scope: Some("openid offline_access".into()), }, ); profile.account_id = Some("acct_123".into()); store.upsert_profile(profile.clone(), true).unwrap(); let data = store.load().unwrap(); let loaded = data.profiles.get(&profile.id).unwrap(); assert_eq!(loaded.provider, "openai-codex"); assert_eq!(loaded.profile_name, "default"); assert_eq!(loaded.account_id.as_deref(), Some("acct_123")); assert_eq!( loaded .token_set .as_ref() .and_then(|t| t.refresh_token.as_deref()), Some("refresh-123") ); let raw = tokio::fs::read_to_string(store.path()).await.unwrap(); assert!(raw.contains("enc2:")); assert!(!raw.contains("refresh-123")); assert!(!raw.contains("access-123")); } #[tokio::test] async fn atomic_write_replaces_file() { let tmp = TempDir::new().unwrap(); let store = AuthProfilesStore::new(tmp.path(), false); let profile = AuthProfile::new_token("anthropic", "default", "token-abc".into()); store.upsert_profile(profile, true).unwrap(); let path = store.path().to_path_buf(); assert!(path.exists()); let contents = tokio::fs::read_to_string(path).await.unwrap(); assert!(contents.contains("\"schema_version\": 1")); } }