zeroclaw/src/auth/profiles.rs
2026-02-19 14:52:29 +08:00

684 lines
22 KiB
Rust

use crate::security::SecretStore;
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::BTreeMap;
use std::fs::{self, OpenOptions};
use std::io::Write;
use std::path::{Path, PathBuf};
use std::thread;
use std::time::Duration;
const CURRENT_SCHEMA_VERSION: u32 = 1;
const PROFILES_FILENAME: &str = "auth-profiles.json";
const LOCK_FILENAME: &str = "auth-profiles.lock";
const LOCK_WAIT_MS: u64 = 50;
const LOCK_TIMEOUT_MS: u64 = 10_000;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum AuthProfileKind {
OAuth,
Token,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TokenSet {
pub access_token: String,
#[serde(default)]
pub refresh_token: Option<String>,
#[serde(default)]
pub id_token: Option<String>,
#[serde(default)]
pub expires_at: Option<DateTime<Utc>>,
#[serde(default)]
pub token_type: Option<String>,
#[serde(default)]
pub scope: Option<String>,
}
impl TokenSet {
pub fn is_expiring_within(&self, skew: Duration) -> bool {
match self.expires_at {
Some(expires_at) => {
let now_plus_skew =
Utc::now() + chrono::Duration::from_std(skew).unwrap_or_default();
expires_at <= now_plus_skew
}
None => false,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProfile {
pub id: String,
pub provider: String,
pub profile_name: String,
pub kind: AuthProfileKind,
#[serde(default)]
pub account_id: Option<String>,
#[serde(default)]
pub workspace_id: Option<String>,
#[serde(default)]
pub token_set: Option<TokenSet>,
#[serde(default)]
pub token: Option<String>,
#[serde(default)]
pub metadata: BTreeMap<String, String>,
pub created_at: DateTime<Utc>,
pub updated_at: DateTime<Utc>,
}
impl AuthProfile {
pub fn new_oauth(provider: &str, profile_name: &str, token_set: TokenSet) -> Self {
let now = Utc::now();
let id = profile_id(provider, profile_name);
Self {
id,
provider: provider.to_string(),
profile_name: profile_name.to_string(),
kind: AuthProfileKind::OAuth,
account_id: None,
workspace_id: None,
token_set: Some(token_set),
token: None,
metadata: BTreeMap::new(),
created_at: now,
updated_at: now,
}
}
pub fn new_token(provider: &str, profile_name: &str, token: String) -> Self {
let now = Utc::now();
let id = profile_id(provider, profile_name);
Self {
id,
provider: provider.to_string(),
profile_name: profile_name.to_string(),
kind: AuthProfileKind::Token,
account_id: None,
workspace_id: None,
token_set: None,
token: Some(token),
metadata: BTreeMap::new(),
created_at: now,
updated_at: now,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthProfilesData {
pub schema_version: u32,
pub updated_at: DateTime<Utc>,
pub active_profiles: BTreeMap<String, String>,
pub profiles: BTreeMap<String, AuthProfile>,
}
impl Default for AuthProfilesData {
fn default() -> Self {
Self {
schema_version: CURRENT_SCHEMA_VERSION,
updated_at: Utc::now(),
active_profiles: BTreeMap::new(),
profiles: BTreeMap::new(),
}
}
}
#[derive(Debug, Clone)]
pub struct AuthProfilesStore {
path: PathBuf,
lock_path: PathBuf,
secret_store: SecretStore,
}
impl AuthProfilesStore {
pub fn new(state_dir: &Path, encrypt_secrets: bool) -> Self {
Self {
path: state_dir.join(PROFILES_FILENAME),
lock_path: state_dir.join(LOCK_FILENAME),
secret_store: SecretStore::new(state_dir, encrypt_secrets),
}
}
pub fn path(&self) -> &Path {
&self.path
}
pub fn load(&self) -> Result<AuthProfilesData> {
let _lock = self.acquire_lock()?;
self.load_locked()
}
pub fn upsert_profile(&self, mut profile: AuthProfile, set_active: bool) -> Result<()> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
profile.updated_at = Utc::now();
if let Some(existing) = data.profiles.get(&profile.id) {
profile.created_at = existing.created_at;
}
if set_active {
data.active_profiles
.insert(profile.provider.clone(), profile.id.clone());
}
data.profiles.insert(profile.id.clone(), profile);
data.updated_at = Utc::now();
self.save_locked(&data)
}
pub fn remove_profile(&self, profile_id: &str) -> Result<bool> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
let removed = data.profiles.remove(profile_id).is_some();
if !removed {
return Ok(false);
}
data.active_profiles
.retain(|_, active| active != profile_id);
data.updated_at = Utc::now();
self.save_locked(&data)?;
Ok(true)
}
pub fn set_active_profile(&self, provider: &str, profile_id: &str) -> Result<()> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
if !data.profiles.contains_key(profile_id) {
anyhow::bail!("Auth profile not found: {profile_id}");
}
data.active_profiles
.insert(provider.to_string(), profile_id.to_string());
data.updated_at = Utc::now();
self.save_locked(&data)
}
pub fn clear_active_profile(&self, provider: &str) -> Result<()> {
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
data.active_profiles.remove(provider);
data.updated_at = Utc::now();
self.save_locked(&data)
}
pub fn update_profile<F>(&self, profile_id: &str, mut updater: F) -> Result<AuthProfile>
where
F: FnMut(&mut AuthProfile) -> Result<()>,
{
let _lock = self.acquire_lock()?;
let mut data = self.load_locked()?;
let profile = data
.profiles
.get_mut(profile_id)
.ok_or_else(|| anyhow::anyhow!("Auth profile not found: {profile_id}"))?;
updater(profile)?;
profile.updated_at = Utc::now();
let updated_profile = profile.clone();
data.updated_at = Utc::now();
self.save_locked(&data)?;
Ok(updated_profile)
}
fn load_locked(&self) -> Result<AuthProfilesData> {
let mut persisted = self.read_persisted_locked()?;
let mut migrated = false;
let mut profiles = BTreeMap::new();
for (id, p) in &mut persisted.profiles {
let (access_token, access_migrated) =
self.decrypt_optional(p.access_token.as_deref())?;
let (refresh_token, refresh_migrated) =
self.decrypt_optional(p.refresh_token.as_deref())?;
let (id_token, id_migrated) = self.decrypt_optional(p.id_token.as_deref())?;
let (token, token_migrated) = self.decrypt_optional(p.token.as_deref())?;
if let Some(value) = access_migrated {
p.access_token = Some(value);
migrated = true;
}
if let Some(value) = refresh_migrated {
p.refresh_token = Some(value);
migrated = true;
}
if let Some(value) = id_migrated {
p.id_token = Some(value);
migrated = true;
}
if let Some(value) = token_migrated {
p.token = Some(value);
migrated = true;
}
let kind = parse_profile_kind(&p.kind)?;
let token_set = match kind {
AuthProfileKind::OAuth => {
let access = access_token.ok_or_else(|| {
anyhow::anyhow!("OAuth profile missing access_token: {id}")
})?;
Some(TokenSet {
access_token: access,
refresh_token,
id_token,
expires_at: parse_optional_datetime(p.expires_at.as_deref())?,
token_type: p.token_type.clone(),
scope: p.scope.clone(),
})
}
AuthProfileKind::Token => None,
};
profiles.insert(
id.clone(),
AuthProfile {
id: id.clone(),
provider: p.provider.clone(),
profile_name: p.profile_name.clone(),
kind,
account_id: p.account_id.clone(),
workspace_id: p.workspace_id.clone(),
token_set,
token,
metadata: p.metadata.clone(),
created_at: parse_datetime_with_fallback(&p.created_at),
updated_at: parse_datetime_with_fallback(&p.updated_at),
},
);
}
if migrated {
self.write_persisted_locked(&persisted)?;
}
Ok(AuthProfilesData {
schema_version: persisted.schema_version,
updated_at: parse_datetime_with_fallback(&persisted.updated_at),
active_profiles: persisted.active_profiles,
profiles,
})
}
fn save_locked(&self, data: &AuthProfilesData) -> Result<()> {
let mut persisted = PersistedAuthProfiles {
schema_version: CURRENT_SCHEMA_VERSION,
updated_at: data.updated_at.to_rfc3339(),
active_profiles: data.active_profiles.clone(),
profiles: BTreeMap::new(),
};
for (id, profile) in &data.profiles {
let (access_token, refresh_token, id_token, expires_at, token_type, scope) =
match (&profile.kind, &profile.token_set) {
(AuthProfileKind::OAuth, Some(token_set)) => (
self.encrypt_optional(Some(&token_set.access_token))?,
self.encrypt_optional(token_set.refresh_token.as_deref())?,
self.encrypt_optional(token_set.id_token.as_deref())?,
token_set.expires_at.as_ref().map(DateTime::to_rfc3339),
token_set.token_type.clone(),
token_set.scope.clone(),
),
_ => (None, None, None, None, None, None),
};
let token = self.encrypt_optional(profile.token.as_deref())?;
persisted.profiles.insert(
id.clone(),
PersistedAuthProfile {
provider: profile.provider.clone(),
profile_name: profile.profile_name.clone(),
kind: profile_kind_to_string(profile.kind).to_string(),
account_id: profile.account_id.clone(),
workspace_id: profile.workspace_id.clone(),
access_token,
refresh_token,
id_token,
token,
expires_at,
token_type,
scope,
metadata: profile.metadata.clone(),
created_at: profile.created_at.to_rfc3339(),
updated_at: profile.updated_at.to_rfc3339(),
},
);
}
self.write_persisted_locked(&persisted)
}
fn read_persisted_locked(&self) -> Result<PersistedAuthProfiles> {
if !self.path.exists() {
return Ok(PersistedAuthProfiles::default());
}
let bytes = fs::read(&self.path).with_context(|| {
format!(
"Failed to read auth profile store at {}",
self.path.display()
)
})?;
if bytes.is_empty() {
return Ok(PersistedAuthProfiles::default());
}
let mut persisted: PersistedAuthProfiles =
serde_json::from_slice(&bytes).with_context(|| {
format!(
"Failed to parse auth profile store at {}",
self.path.display()
)
})?;
if persisted.schema_version == 0 {
persisted.schema_version = CURRENT_SCHEMA_VERSION;
}
if persisted.schema_version > CURRENT_SCHEMA_VERSION {
anyhow::bail!(
"Unsupported auth profile schema version {} (max supported: {})",
persisted.schema_version,
CURRENT_SCHEMA_VERSION
);
}
Ok(persisted)
}
fn write_persisted_locked(&self, persisted: &PersistedAuthProfiles) -> Result<()> {
if let Some(parent) = self.path.parent() {
fs::create_dir_all(parent).with_context(|| {
format!(
"Failed to create auth profile directory at {}",
parent.display()
)
})?;
}
let json =
serde_json::to_vec_pretty(persisted).context("Failed to serialize auth profiles")?;
let tmp_name = format!(
"{}.tmp.{}.{}",
PROFILES_FILENAME,
std::process::id(),
Utc::now().timestamp_nanos_opt().unwrap_or_default()
);
let tmp_path = self.path.with_file_name(tmp_name);
fs::write(&tmp_path, &json).with_context(|| {
format!(
"Failed to write temporary auth profile file at {}",
tmp_path.display()
)
})?;
fs::rename(&tmp_path, &self.path).with_context(|| {
format!(
"Failed to replace auth profile store at {}",
self.path.display()
)
})?;
Ok(())
}
fn encrypt_optional(&self, value: Option<&str>) -> Result<Option<String>> {
match value {
Some(value) if !value.is_empty() => self.secret_store.encrypt(value).map(Some),
Some(_) | None => Ok(None),
}
}
fn decrypt_optional(&self, value: Option<&str>) -> Result<(Option<String>, Option<String>)> {
match value {
Some(value) if !value.is_empty() => {
let (plaintext, migrated) = self.secret_store.decrypt_and_migrate(value)?;
Ok((Some(plaintext), migrated))
}
Some(_) | None => Ok((None, None)),
}
}
fn acquire_lock(&self) -> Result<AuthProfileLockGuard> {
if let Some(parent) = self.lock_path.parent() {
fs::create_dir_all(parent).with_context(|| {
format!("Failed to create lock directory at {}", parent.display())
})?;
}
let mut waited = 0_u64;
loop {
match OpenOptions::new()
.create_new(true)
.write(true)
.open(&self.lock_path)
{
Ok(mut file) => {
let _ = writeln!(file, "pid={}", std::process::id());
return Ok(AuthProfileLockGuard {
lock_path: self.lock_path.clone(),
});
}
Err(e) if e.kind() == std::io::ErrorKind::AlreadyExists => {
if waited >= LOCK_TIMEOUT_MS {
anyhow::bail!(
"Timed out waiting for auth profile lock at {}",
self.lock_path.display()
);
}
thread::sleep(Duration::from_millis(LOCK_WAIT_MS));
waited = waited.saturating_add(LOCK_WAIT_MS);
}
Err(e) => {
return Err(e).with_context(|| {
format!(
"Failed to create auth profile lock at {}",
self.lock_path.display()
)
});
}
}
}
}
}
struct AuthProfileLockGuard {
lock_path: PathBuf,
}
impl Drop for AuthProfileLockGuard {
fn drop(&mut self) {
let _ = fs::remove_file(&self.lock_path);
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct PersistedAuthProfiles {
#[serde(default = "default_schema_version")]
schema_version: u32,
#[serde(default = "default_now_rfc3339")]
updated_at: String,
#[serde(default)]
active_profiles: BTreeMap<String, String>,
#[serde(default)]
profiles: BTreeMap<String, PersistedAuthProfile>,
}
impl Default for PersistedAuthProfiles {
fn default() -> Self {
Self {
schema_version: CURRENT_SCHEMA_VERSION,
updated_at: default_now_rfc3339(),
active_profiles: BTreeMap::new(),
profiles: BTreeMap::new(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
struct PersistedAuthProfile {
provider: String,
profile_name: String,
kind: String,
#[serde(default)]
account_id: Option<String>,
#[serde(default)]
workspace_id: Option<String>,
#[serde(default)]
access_token: Option<String>,
#[serde(default)]
refresh_token: Option<String>,
#[serde(default)]
id_token: Option<String>,
#[serde(default)]
token: Option<String>,
#[serde(default)]
expires_at: Option<String>,
#[serde(default)]
token_type: Option<String>,
#[serde(default)]
scope: Option<String>,
#[serde(default = "default_now_rfc3339")]
created_at: String,
#[serde(default = "default_now_rfc3339")]
updated_at: String,
#[serde(default)]
metadata: BTreeMap<String, String>,
}
fn default_schema_version() -> u32 {
CURRENT_SCHEMA_VERSION
}
fn default_now_rfc3339() -> String {
Utc::now().to_rfc3339()
}
fn parse_profile_kind(value: &str) -> Result<AuthProfileKind> {
match value {
"oauth" => Ok(AuthProfileKind::OAuth),
"token" => Ok(AuthProfileKind::Token),
other => anyhow::bail!("Unsupported auth profile kind: {other}"),
}
}
fn profile_kind_to_string(kind: AuthProfileKind) -> &'static str {
match kind {
AuthProfileKind::OAuth => "oauth",
AuthProfileKind::Token => "token",
}
}
fn parse_optional_datetime(value: Option<&str>) -> Result<Option<DateTime<Utc>>> {
value.map(parse_datetime).transpose()
}
fn parse_datetime(value: &str) -> Result<DateTime<Utc>> {
DateTime::parse_from_rfc3339(value)
.map(|dt| dt.with_timezone(&Utc))
.with_context(|| format!("Invalid RFC3339 timestamp: {value}"))
}
fn parse_datetime_with_fallback(value: &str) -> DateTime<Utc> {
parse_datetime(value).unwrap_or_else(|_| Utc::now())
}
pub fn profile_id(provider: &str, profile_name: &str) -> String {
format!("{}:{}", provider.trim(), profile_name.trim())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn profile_id_format() {
assert_eq!(
profile_id("openai-codex", "default"),
"openai-codex:default"
);
}
#[test]
fn token_expiry_math() {
let token_set = TokenSet {
access_token: "token".into(),
refresh_token: Some("refresh".into()),
id_token: None,
expires_at: Some(Utc::now() + chrono::Duration::seconds(10)),
token_type: Some("Bearer".into()),
scope: None,
};
assert!(token_set.is_expiring_within(Duration::from_secs(15)));
assert!(!token_set.is_expiring_within(Duration::from_secs(1)));
}
#[tokio::test]
async fn store_roundtrip_with_encryption() {
let tmp = TempDir::new().unwrap();
let store = AuthProfilesStore::new(tmp.path(), true);
let mut profile = AuthProfile::new_oauth(
"openai-codex",
"default",
TokenSet {
access_token: "access-123".into(),
refresh_token: Some("refresh-123".into()),
id_token: None,
expires_at: Some(Utc::now() + chrono::Duration::hours(1)),
token_type: Some("Bearer".into()),
scope: Some("openid offline_access".into()),
},
);
profile.account_id = Some("acct_123".into());
store.upsert_profile(profile.clone(), true).unwrap();
let data = store.load().unwrap();
let loaded = data.profiles.get(&profile.id).unwrap();
assert_eq!(loaded.provider, "openai-codex");
assert_eq!(loaded.profile_name, "default");
assert_eq!(loaded.account_id.as_deref(), Some("acct_123"));
assert_eq!(
loaded
.token_set
.as_ref()
.and_then(|t| t.refresh_token.as_deref()),
Some("refresh-123")
);
let raw = tokio::fs::read_to_string(store.path()).await.unwrap();
assert!(raw.contains("enc2:"));
assert!(!raw.contains("refresh-123"));
assert!(!raw.contains("access-123"));
}
#[tokio::test]
async fn atomic_write_replaces_file() {
let tmp = TempDir::new().unwrap();
let store = AuthProfilesStore::new(tmp.path(), false);
let profile = AuthProfile::new_token("anthropic", "default", "token-abc".into());
store.upsert_profile(profile, true).unwrap();
let path = store.path().to_path_buf();
assert!(path.exists());
let contents = tokio::fs::read_to_string(path).await.unwrap();
assert!(contents.contains("\"schema_version\": 1"));
}
}