diff --git a/Cargo.lock b/Cargo.lock index 456c55a..4b43472 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -385,6 +385,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", + "axum-macros", "base64", "bytes", "form_urlencoded", @@ -431,6 +432,17 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + [[package]] name = "backon" version = "1.6.0" @@ -7734,6 +7746,7 @@ dependencies = [ "tokio", "tokio-rustls", "tokio-serial", + "tokio-stream", "tokio-tungstenite 0.24.0", "tokio-util", "toml 1.0.1+spec-1.1.0", diff --git a/Cargo.toml b/Cargo.toml index 2f23c40..4f42ec0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -120,7 +120,7 @@ mail-parser = "0.11.2" async-imap = { version = "0.11",features = ["runtime-tokio"], default-features = false } # HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance -axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] } +axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws", "macros"] } tower = { version = "0.5", default-features = false } tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] } http-body-util = "0.1" @@ -141,6 +141,7 @@ probe-rs = { version = "0.30", optional = true } # PDF extraction for datasheet RAG (optional, enable with --features rag-pdf) pdf-extract = { version = "0.10", optional = true } +tokio-stream = { version = "0.1.18", features = ["full"] } # WhatsApp Web client (wa-rs) — optional, enable with --features whatsapp-web # Uses wa-rs for Bot and Client, wa-rs-core for storage traits, custom rusqlite backend avoids Diesel conflict. diff --git a/docs/frictionless-security.md b/docs/frictionless-security.md index 2f5fde6..f62046d 100644 --- a/docs/frictionless-security.md +++ b/docs/frictionless-security.md @@ -26,7 +26,7 @@ pub fn run_wizard() -> Result { security: SecurityConfig::autodetect(), // Silent! }; - config.save()?; + config.save().await?; Ok(config) } ``` diff --git a/src/channels/matrix.rs b/src/channels/matrix.rs index 6dc74ad..5d0c207 100644 --- a/src/channels/matrix.rs +++ b/src/channels/matrix.rs @@ -726,7 +726,7 @@ mod tests { "!r:m".to_string(), vec![], Some(" ".to_string()), - Some("".to_string()), + Some(String::new()), ); assert!(ch.session_owner_hint.is_none()); diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 5d76861..ada5803 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1117,7 +1117,7 @@ fn normalize_telegram_identity(value: &str) -> String { value.trim().trim_start_matches('@').to_string() } -fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { +async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { let normalized = normalize_telegram_identity(identity); if normalized.is_empty() { anyhow::bail!("Telegram identity cannot be empty"); @@ -1147,7 +1147,7 @@ fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { } telegram.allowed_users.push(normalized.clone()); - updated.save()?; + updated.save().await?; println!("✅ Bound Telegram identity: {normalized}"); println!(" Saved to {}", updated.config_path.display()); match maybe_restart_managed_daemon_service() { @@ -1243,7 +1243,7 @@ fn maybe_restart_managed_daemon_service() -> Result { Ok(false) } -pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> { +pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> { match command { crate::ChannelCommands::Start => { anyhow::bail!("Start must be handled in main.rs (requires async runtime)") @@ -1290,7 +1290,7 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly"); } crate::ChannelCommands::BindTelegram { identity } => { - bind_telegram_identity(config, &identity) + bind_telegram_identity(config, &identity).await } } } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index b05dc56..dcd5b7d 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -6,10 +6,10 @@ use async_trait::async_trait; use directories::UserDirs; use parking_lot::Mutex; use reqwest::multipart::{Form, Part}; -use std::fs; use std::path::Path; use std::sync::{Arc, RwLock}; use std::time::Duration; +use tokio::fs; /// Telegram's maximum message length for text messages const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096; @@ -373,7 +373,7 @@ impl TelegramChannel { .collect() } - fn load_config_without_env() -> anyhow::Result { + async fn load_config_without_env() -> anyhow::Result { let home = UserDirs::new() .map(|u| u.home_dir().to_path_buf()) .context("Could not find home directory")?; @@ -381,6 +381,7 @@ impl TelegramChannel { let config_path = zeroclaw_dir.join("config.toml"); let contents = fs::read_to_string(&config_path) + .await .with_context(|| format!("Failed to read config file: {}", config_path.display()))?; let mut config: Config = toml::from_str(&contents) .context("Failed to parse config file for Telegram binding")?; @@ -389,8 +390,8 @@ impl TelegramChannel { Ok(config) } - fn persist_allowed_identity_blocking(identity: &str) -> anyhow::Result<()> { - let mut config = Self::load_config_without_env()?; + async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> { + let mut config = Self::load_config_without_env().await?; let Some(telegram) = config.channels_config.telegram.as_mut() else { anyhow::bail!("Telegram channel config is missing in config.toml"); }; @@ -404,20 +405,13 @@ impl TelegramChannel { telegram.allowed_users.push(normalized); config .save() + .await .context("Failed to persist Telegram allowlist to config.toml")?; } Ok(()) } - async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> { - let identity = identity.to_string(); - tokio::task::spawn_blocking(move || Self::persist_allowed_identity_blocking(&identity)) - .await - .map_err(|e| anyhow::anyhow!("Failed to join Telegram bind save task: {e}"))??; - Ok(()) - } - fn add_allowed_identity_runtime(&self, identity: &str) { let normalized = Self::normalize_identity(identity); if normalized.is_empty() { @@ -629,7 +623,7 @@ impl TelegramChannel { if let Some(code) = Self::extract_bind_code(text) { if let Some(pairing) = self.pairing.as_ref() { - match pairing.try_pair(code) { + match pairing.try_pair(code).await { Ok(Some(_token)) => { let bind_identity = normalized_sender_id.clone().or_else(|| { if normalized_username.is_empty() || normalized_username == "unknown" { diff --git a/src/config/schema.rs b/src/config/schema.rs index f9946f3..c3d3d03 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -5,10 +5,10 @@ use directories::UserDirs; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::fs::{self, File, OpenOptions}; -use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::{OnceLock, RwLock}; +use tokio::fs::{self, File, OpenOptions}; +use tokio::io::AsyncWriteExt; const SUPPORTED_PROXY_SERVICE_KEYS: &[&str] = &[ "provider.anthropic", @@ -2526,13 +2526,15 @@ fn active_workspace_state_path(default_dir: &Path) -> PathBuf { default_dir.join(ACTIVE_WORKSPACE_STATE_FILE) } -fn load_persisted_workspace_dirs(default_config_dir: &Path) -> Result> { +async fn load_persisted_workspace_dirs( + default_config_dir: &Path, +) -> Result> { let state_path = active_workspace_state_path(default_config_dir); if !state_path.exists() { return Ok(None); } - let contents = match fs::read_to_string(&state_path) { + let contents = match fs::read_to_string(&state_path).await { Ok(contents) => contents, Err(error) => { tracing::warn!( @@ -2572,13 +2574,13 @@ fn load_persisted_workspace_dirs(default_config_dir: &Path) -> Result Result<()> { +pub(crate) async fn persist_active_workspace_config_dir(config_dir: &Path) -> Result<()> { let default_config_dir = default_config_dir()?; let state_path = active_workspace_state_path(&default_config_dir); if config_dir == default_config_dir { if state_path.exists() { - fs::remove_file(&state_path).with_context(|| { + fs::remove_file(&state_path).await.with_context(|| { format!( "Failed to clear active workspace marker: {}", state_path.display() @@ -2588,12 +2590,14 @@ pub(crate) fn persist_active_workspace_config_dir(config_dir: &Path) -> Result<( return Ok(()); } - fs::create_dir_all(&default_config_dir).with_context(|| { - format!( - "Failed to create default config directory: {}", - default_config_dir.display() - ) - })?; + fs::create_dir_all(&default_config_dir) + .await + .with_context(|| { + format!( + "Failed to create default config directory: {}", + default_config_dir.display() + ) + })?; let state = ActiveWorkspaceState { config_dir: config_dir.to_string_lossy().into_owned(), @@ -2605,22 +2609,22 @@ pub(crate) fn persist_active_workspace_config_dir(config_dir: &Path) -> Result<( ".{ACTIVE_WORKSPACE_STATE_FILE}.tmp-{}", uuid::Uuid::new_v4() )); - fs::write(&temp_path, serialized).with_context(|| { + fs::write(&temp_path, serialized).await.with_context(|| { format!( "Failed to write temporary active workspace marker: {}", temp_path.display() ) })?; - if let Err(error) = fs::rename(&temp_path, &state_path) { - let _ = fs::remove_file(&temp_path); + if let Err(error) = fs::rename(&temp_path, &state_path).await { + let _ = fs::remove_file(&temp_path).await; anyhow::bail!( "Failed to atomically persist active workspace marker {}: {error}", state_path.display() ); } - sync_directory(&default_config_dir)?; + sync_directory(&default_config_dir).await?; Ok(()) } @@ -2690,7 +2694,7 @@ fn encrypt_optional_secret( } impl Config { - pub fn load_or_init() -> Result { + pub async fn load_or_init() -> Result { let (default_zeroclaw_dir, default_workspace_dir) = default_config_and_workspace_dirs()?; // Resolution priority: @@ -2701,21 +2705,26 @@ impl Config { Ok(custom_workspace) if !custom_workspace.is_empty() => { resolve_config_dir_for_workspace(&PathBuf::from(custom_workspace)) } - _ => load_persisted_workspace_dirs(&default_zeroclaw_dir)? + _ => load_persisted_workspace_dirs(&default_zeroclaw_dir) + .await? .unwrap_or((default_zeroclaw_dir, default_workspace_dir)), }; let config_path = zeroclaw_dir.join("config.toml"); - fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?; - fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?; + fs::create_dir_all(&zeroclaw_dir) + .await + .context("Failed to create config directory")?; + fs::create_dir_all(&workspace_dir) + .await + .context("Failed to create workspace directory")?; if config_path.exists() { // Warn if config file is world-readable (may contain API keys) #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; - if let Ok(meta) = fs::metadata(&config_path) { + if let Ok(meta) = fs::metadata(&config_path).await { if meta.permissions().mode() & 0o004 != 0 { tracing::warn!( "Config file {:?} is world-readable (mode {:o}). \ @@ -2728,8 +2737,9 @@ impl Config { } } - let contents = - fs::read_to_string(&config_path).context("Failed to read config file")?; + let contents = fs::read_to_string(&config_path) + .await + .context("Failed to read config file")?; let mut config: Config = toml::from_str(&contents).context("Failed to parse config file")?; // Set computed paths that are skipped during serialization @@ -2770,13 +2780,13 @@ impl Config { let mut config = Config::default(); config.config_path = config_path.clone(); config.workspace_dir = workspace_dir; - config.save()?; + config.save().await?; // Restrict permissions on newly created config file (may contain API keys) #[cfg(unix)] { - use std::os::unix::fs::PermissionsExt; - let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600)); + use std::{fs::Permissions, os::unix::fs::PermissionsExt}; + let _ = fs::set_permissions(&config_path, Permissions::from_mode(0o600)).await; } config.apply_env_overrides(); @@ -3019,7 +3029,7 @@ impl Config { set_runtime_proxy_config(self.proxy.clone()); } - pub fn save(&self) -> Result<()> { + pub async fn save(&self) -> Result<()> { // Encrypt secrets before serialization let mut config_to_save = self.clone(); let zeroclaw_dir = self @@ -3064,7 +3074,8 @@ impl Config { .config_path .parent() .context("Config path must have a parent directory")?; - fs::create_dir_all(parent_dir).with_context(|| { + + fs::create_dir_all(parent_dir).await.with_context(|| { format!( "Failed to create config directory: {}", parent_dir.display() @@ -3083,6 +3094,7 @@ impl Config { .create_new(true) .write(true) .open(&temp_path) + .await .with_context(|| { format!( "Failed to create temporary config file: {}", @@ -3091,34 +3103,40 @@ impl Config { })?; temp_file .write_all(toml_str.as_bytes()) + .await .context("Failed to write temporary config contents")?; temp_file .sync_all() + .await .context("Failed to fsync temporary config file")?; drop(temp_file); let had_existing_config = self.config_path.exists(); if had_existing_config { - fs::copy(&self.config_path, &backup_path).with_context(|| { - format!( - "Failed to create config backup before atomic replace: {}", - backup_path.display() - ) - })?; + fs::copy(&self.config_path, &backup_path) + .await + .with_context(|| { + format!( + "Failed to create config backup before atomic replace: {}", + backup_path.display() + ) + })?; } - if let Err(e) = fs::rename(&temp_path, &self.config_path) { - let _ = fs::remove_file(&temp_path); + if let Err(e) = fs::rename(&temp_path, &self.config_path).await { + let _ = fs::remove_file(&temp_path).await; if had_existing_config && backup_path.exists() { - let _ = fs::copy(&backup_path, &self.config_path); + fs::copy(&backup_path, &self.config_path) + .await + .context("Failed to restore config backup")?; } anyhow::bail!("Failed to atomically replace config file: {e}"); } - sync_directory(parent_dir)?; + sync_directory(parent_dir).await?; if had_existing_config { - let _ = fs::remove_file(&backup_path); + let _ = fs::remove_file(&backup_path).await; } Ok(()) @@ -3126,10 +3144,12 @@ impl Config { } #[cfg(unix)] -fn sync_directory(path: &Path) -> Result<()> { +async fn sync_directory(path: &Path) -> Result<()> { let dir = File::open(path) + .await .with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?; dir.sync_all() + .await .with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?; Ok(()) } @@ -3142,12 +3162,16 @@ fn sync_directory(_path: &Path) -> Result<()> { #[cfg(test)] mod tests { use super::*; - use std::path::PathBuf; + use std::{fs::Permissions, os::unix::fs::PermissionsExt, path::PathBuf}; + use tokio::sync::{Mutex, MutexGuard}; + use tokio::test; + use tokio_stream::wrappers::ReadDirStream; + use tokio_stream::StreamExt; // ── Defaults ───────────────────────────────────────────── #[test] - fn config_default_has_sane_values() { + async fn config_default_has_sane_values() { let c = Config::default(); assert_eq!(c.default_provider.as_deref(), Some("openrouter")); assert!(c.default_model.as_deref().unwrap().contains("claude")); @@ -3190,13 +3214,13 @@ mod tests { } #[test] - fn observability_config_default() { + async fn observability_config_default() { let o = ObservabilityConfig::default(); assert_eq!(o.backend, "none"); } #[test] - fn autonomy_config_default() { + async fn autonomy_config_default() { let a = AutonomyConfig::default(); assert_eq!(a.level, AutonomyLevel::Supervised); assert!(a.workspace_only); @@ -3210,7 +3234,7 @@ mod tests { } #[test] - fn runtime_config_default() { + async fn runtime_config_default() { let r = RuntimeConfig::default(); assert_eq!(r.kind, "native"); assert_eq!(r.docker.image, "alpine:3.20"); @@ -3222,21 +3246,21 @@ mod tests { } #[test] - fn heartbeat_config_default() { + async fn heartbeat_config_default() { let h = HeartbeatConfig::default(); assert!(!h.enabled); assert_eq!(h.interval_minutes, 30); } #[test] - fn cron_config_default() { + async fn cron_config_default() { let c = CronConfig::default(); assert!(c.enabled); assert_eq!(c.max_run_history, 50); } #[test] - fn cron_config_serde_roundtrip() { + async fn cron_config_serde_roundtrip() { let c = CronConfig { enabled: false, max_run_history: 100, @@ -3248,7 +3272,7 @@ mod tests { } #[test] - fn config_defaults_cron_when_section_missing() { + async fn config_defaults_cron_when_section_missing() { let toml_str = r#" workspace_dir = "/tmp/workspace" config_path = "/tmp/config.toml" @@ -3261,7 +3285,7 @@ default_temperature = 0.7 } #[test] - fn memory_config_default_hygiene_settings() { + async fn memory_config_default_hygiene_settings() { let m = MemoryConfig::default(); assert_eq!(m.backend, "sqlite"); assert!(m.auto_save); @@ -3273,7 +3297,7 @@ default_temperature = 0.7 } #[test] - fn storage_provider_config_defaults() { + async fn storage_provider_config_defaults() { let storage = StorageConfig::default(); assert!(storage.provider.config.provider.is_empty()); assert!(storage.provider.config.db_url.is_none()); @@ -3283,7 +3307,7 @@ default_temperature = 0.7 } #[test] - fn channels_config_default() { + async fn channels_config_default() { let c = ChannelsConfig::default(); assert!(c.cli); assert!(c.telegram.is_none()); @@ -3293,7 +3317,7 @@ default_temperature = 0.7 // ── Serde round-trip ───────────────────────────────────── #[test] - fn config_toml_roundtrip() { + async fn config_toml_roundtrip() { let config = Config { workspace_dir: PathBuf::from("/tmp/test/workspace"), config_path: PathBuf::from("/tmp/test/config.toml"), @@ -3395,7 +3419,7 @@ default_temperature = 0.7 } #[test] - fn config_minimal_toml_uses_defaults() { + async fn config_minimal_toml_uses_defaults() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -3416,7 +3440,7 @@ default_temperature = 0.7 } #[test] - fn storage_provider_dburl_alias_deserializes() { + async fn storage_provider_dburl_alias_deserializes() { let raw = r#" default_temperature = 0.7 @@ -3443,7 +3467,7 @@ connect_timeout_secs = 12 } #[test] - fn agent_config_defaults() { + async fn agent_config_defaults() { let cfg = AgentConfig::default(); assert!(!cfg.compact_context); assert_eq!(cfg.max_tool_iterations, 10); @@ -3453,7 +3477,7 @@ connect_timeout_secs = 12 } #[test] - fn agent_config_deserializes() { + async fn agent_config_deserializes() { let raw = r#" default_temperature = 0.7 [agent] @@ -3474,8 +3498,8 @@ tool_dispatcher = "xml" #[tokio::test] async fn config_save_and_load_tmpdir() { let dir = std::env::temp_dir().join("zeroclaw_test_config"); - let _ = fs::remove_dir_all(&dir); - fs::create_dir_all(&dir).unwrap(); + let _ = fs::remove_dir_all(&dir).await; + fs::create_dir_all(&dir).await.unwrap(); let config_path = dir.join("config.toml"); let config = Config { @@ -3514,7 +3538,7 @@ tool_dispatcher = "xml" hardware: HardwareConfig::default(), }; - config.save().unwrap(); + config.save().await.unwrap(); assert!(config_path.exists()); let contents = tokio::fs::read_to_string(&config_path).await.unwrap(); @@ -3529,7 +3553,7 @@ tool_dispatcher = "xml" assert_eq!(loaded.default_model.as_deref(), Some("test-model")); assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON); - let _ = fs::remove_dir_all(&dir); + let _ = fs::remove_dir_all(&dir).await; } #[tokio::test] @@ -3538,7 +3562,7 @@ tool_dispatcher = "xml" "zeroclaw_test_nested_credentials_{}", uuid::Uuid::new_v4() )); - fs::create_dir_all(&dir).unwrap(); + fs::create_dir_all(&dir).await.unwrap(); let mut config = Config::default(); config.workspace_dir = dir.join("workspace"); @@ -3561,7 +3585,7 @@ tool_dispatcher = "xml" }, ); - config.save().unwrap(); + config.save().await.unwrap(); let contents = tokio::fs::read_to_string(config.config_path.clone()) .await @@ -3612,44 +3636,43 @@ tool_dispatcher = "xml" "postgres://user:pw@host/db" ); - let _ = fs::remove_dir_all(&dir); + let _ = fs::remove_dir_all(&dir).await; } #[tokio::test] async fn config_save_atomic_cleanup() { let dir = std::env::temp_dir().join(format!("zeroclaw_test_config_{}", uuid::Uuid::new_v4())); - fs::create_dir_all(&dir).unwrap(); + fs::create_dir_all(&dir).await.unwrap(); let config_path = dir.join("config.toml"); let mut config = Config::default(); config.workspace_dir = dir.join("workspace"); config.config_path = config_path.clone(); config.default_model = Some("model-a".into()); - - config.save().unwrap(); + config.save().await.unwrap(); assert!(config_path.exists()); config.default_model = Some("model-b".into()); - config.save().unwrap(); + config.save().await.unwrap(); let contents = tokio::fs::read_to_string(&config_path).await.unwrap(); assert!(contents.contains("model-b")); - let names: Vec = fs::read_dir(&dir) - .unwrap() + let names: Vec = ReadDirStream::new(fs::read_dir(&dir).await.unwrap()) .map(|entry| entry.unwrap().file_name().to_string_lossy().to_string()) - .collect(); + .collect() + .await; assert!(!names.iter().any(|name| name.contains(".tmp-"))); assert!(!names.iter().any(|name| name.ends_with(".bak"))); - let _ = fs::remove_dir_all(&dir); + let _ = fs::remove_dir_all(&dir).await; } // ── Telegram / Discord config ──────────────────────────── #[test] - fn telegram_config_serde() { + async fn telegram_config_serde() { let tc = TelegramConfig { bot_token: "123:XYZ".into(), allowed_users: vec!["alice".into(), "bob".into()], @@ -3666,7 +3689,7 @@ tool_dispatcher = "xml" } #[test] - fn telegram_config_defaults_stream_off() { + async fn telegram_config_defaults_stream_off() { let json = r#"{"bot_token":"tok","allowed_users":[]}"#; let parsed: TelegramConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.stream_mode, StreamMode::Off); @@ -3674,7 +3697,7 @@ tool_dispatcher = "xml" } #[test] - fn discord_config_serde() { + async fn discord_config_serde() { let dc = DiscordConfig { bot_token: "discord-token".into(), guild_id: Some("12345".into()), @@ -3689,7 +3712,7 @@ tool_dispatcher = "xml" } #[test] - fn discord_config_optional_guild() { + async fn discord_config_optional_guild() { let dc = DiscordConfig { bot_token: "tok".into(), guild_id: None, @@ -3705,7 +3728,7 @@ tool_dispatcher = "xml" // ── iMessage / Matrix config ──────────────────────────── #[test] - fn imessage_config_serde() { + async fn imessage_config_serde() { let ic = IMessageConfig { allowed_contacts: vec!["+1234567890".into(), "user@icloud.com".into()], }; @@ -3716,7 +3739,7 @@ tool_dispatcher = "xml" } #[test] - fn imessage_config_empty_contacts() { + async fn imessage_config_empty_contacts() { let ic = IMessageConfig { allowed_contacts: vec![], }; @@ -3726,7 +3749,7 @@ tool_dispatcher = "xml" } #[test] - fn imessage_config_wildcard() { + async fn imessage_config_wildcard() { let ic = IMessageConfig { allowed_contacts: vec!["*".into()], }; @@ -3736,7 +3759,7 @@ tool_dispatcher = "xml" } #[test] - fn matrix_config_serde() { + async fn matrix_config_serde() { let mc = MatrixConfig { homeserver: "https://matrix.org".into(), access_token: "syt_token_abc".into(), @@ -3756,7 +3779,7 @@ tool_dispatcher = "xml" } #[test] - fn matrix_config_toml_roundtrip() { + async fn matrix_config_toml_roundtrip() { let mc = MatrixConfig { homeserver: "https://synapse.local:8448".into(), access_token: "tok".into(), @@ -3772,7 +3795,7 @@ tool_dispatcher = "xml" } #[test] - fn matrix_config_backward_compatible_without_session_hints() { + async fn matrix_config_backward_compatible_without_session_hints() { let toml = r#" homeserver = "https://matrix.org" access_token = "tok" @@ -3787,7 +3810,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn signal_config_serde() { + async fn signal_config_serde() { let sc = SignalConfig { http_url: "http://127.0.0.1:8686".into(), account: "+1234567890".into(), @@ -3807,7 +3830,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn signal_config_toml_roundtrip() { + async fn signal_config_toml_roundtrip() { let sc = SignalConfig { http_url: "http://localhost:8080".into(), account: "+9876543210".into(), @@ -3825,7 +3848,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn signal_config_defaults() { + async fn signal_config_defaults() { let json = r#"{"http_url":"http://127.0.0.1:8686","account":"+1234567890"}"#; let parsed: SignalConfig = serde_json::from_str(json).unwrap(); assert!(parsed.group_id.is_none()); @@ -3835,7 +3858,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn channels_config_with_imessage_and_matrix() { + async fn channels_config_with_imessage_and_matrix() { let c = ChannelsConfig { cli: true, telegram: None, @@ -3873,7 +3896,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn channels_config_default_has_no_imessage_matrix() { + async fn channels_config_default_has_no_imessage_matrix() { let c = ChannelsConfig::default(); assert!(c.imessage.is_none()); assert!(c.matrix.is_none()); @@ -3882,7 +3905,7 @@ allowed_users = ["@ops:matrix.org"] // ── Edge cases: serde(default) for allowed_users ───────── #[test] - fn discord_config_deserializes_without_allowed_users() { + async fn discord_config_deserializes_without_allowed_users() { // Old configs won't have allowed_users — serde(default) should fill vec![] let json = r#"{"bot_token":"tok","guild_id":"123"}"#; let parsed: DiscordConfig = serde_json::from_str(json).unwrap(); @@ -3890,28 +3913,28 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn discord_config_deserializes_with_allowed_users() { + async fn discord_config_deserializes_with_allowed_users() { let json = r#"{"bot_token":"tok","guild_id":"123","allowed_users":["111","222"]}"#; let parsed: DiscordConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.allowed_users, vec!["111", "222"]); } #[test] - fn slack_config_deserializes_without_allowed_users() { + async fn slack_config_deserializes_without_allowed_users() { let json = r#"{"bot_token":"xoxb-tok"}"#; let parsed: SlackConfig = serde_json::from_str(json).unwrap(); assert!(parsed.allowed_users.is_empty()); } #[test] - fn slack_config_deserializes_with_allowed_users() { + async fn slack_config_deserializes_with_allowed_users() { let json = r#"{"bot_token":"xoxb-tok","allowed_users":["U111"]}"#; let parsed: SlackConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.allowed_users, vec!["U111"]); } #[test] - fn discord_config_toml_backward_compat() { + async fn discord_config_toml_backward_compat() { let toml_str = r#" bot_token = "tok" guild_id = "123" @@ -3922,7 +3945,7 @@ guild_id = "123" } #[test] - fn slack_config_toml_backward_compat() { + async fn slack_config_toml_backward_compat() { let toml_str = r#" bot_token = "xoxb-tok" channel_id = "C123" @@ -3933,14 +3956,14 @@ channel_id = "C123" } #[test] - fn webhook_config_with_secret() { + async fn webhook_config_with_secret() { let json = r#"{"port":8080,"secret":"my-secret-key"}"#; let parsed: WebhookConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.secret.as_deref(), Some("my-secret-key")); } #[test] - fn webhook_config_without_secret() { + async fn webhook_config_without_secret() { let json = r#"{"port":8080}"#; let parsed: WebhookConfig = serde_json::from_str(json).unwrap(); assert!(parsed.secret.is_none()); @@ -3950,7 +3973,7 @@ channel_id = "C123" // ── WhatsApp config ────────────────────────────────────── #[test] - fn whatsapp_config_serde() { + async fn whatsapp_config_serde() { let wc = WhatsAppConfig { access_token: Some("EAABx...".into()), phone_number_id: Some("123456789".into()), @@ -3970,7 +3993,7 @@ channel_id = "C123" } #[test] - fn whatsapp_config_toml_roundtrip() { + async fn whatsapp_config_toml_roundtrip() { let wc = WhatsAppConfig { access_token: Some("tok".into()), phone_number_id: Some("12345".into()), @@ -3988,14 +4011,14 @@ channel_id = "C123" } #[test] - fn whatsapp_config_deserializes_without_allowed_numbers() { + async fn whatsapp_config_deserializes_without_allowed_numbers() { let json = r#"{"access_token":"tok","phone_number_id":"123","verify_token":"ver"}"#; let parsed: WhatsAppConfig = serde_json::from_str(json).unwrap(); assert!(parsed.allowed_numbers.is_empty()); } #[test] - fn whatsapp_config_wildcard_allowed() { + async fn whatsapp_config_wildcard_allowed() { let wc = WhatsAppConfig { access_token: Some("tok".into()), phone_number_id: Some("123".into()), @@ -4012,7 +4035,7 @@ channel_id = "C123" } #[test] - fn channels_config_with_whatsapp() { + async fn channels_config_with_whatsapp() { let c = ChannelsConfig { cli: true, telegram: None, @@ -4050,7 +4073,7 @@ channel_id = "C123" } #[test] - fn channels_config_default_has_no_whatsapp() { + async fn channels_config_default_has_no_whatsapp() { let c = ChannelsConfig::default(); assert!(c.whatsapp.is_none()); } @@ -4060,13 +4083,13 @@ channel_id = "C123" // ══════════════════════════════════════════════════════════ #[test] - fn checklist_gateway_default_requires_pairing() { + async fn checklist_gateway_default_requires_pairing() { let g = GatewayConfig::default(); assert!(g.require_pairing, "Pairing must be required by default"); } #[test] - fn checklist_gateway_default_blocks_public_bind() { + async fn checklist_gateway_default_blocks_public_bind() { let g = GatewayConfig::default(); assert!( !g.allow_public_bind, @@ -4075,7 +4098,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_default_no_tokens() { + async fn checklist_gateway_default_no_tokens() { let g = GatewayConfig::default(); assert!( g.paired_tokens.is_empty(), @@ -4090,7 +4113,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_cli_default_host_is_localhost() { + async fn checklist_gateway_cli_default_host_is_localhost() { // The CLI default for --host is 127.0.0.1 (checked in main.rs) // Here we verify the config default matches let c = Config::default(); @@ -4105,7 +4128,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_serde_roundtrip() { + async fn checklist_gateway_serde_roundtrip() { let g = GatewayConfig { port: 3000, host: "127.0.0.1".into(), @@ -4133,7 +4156,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_backward_compat_no_gateway_section() { + async fn checklist_gateway_backward_compat_no_gateway_section() { // Old configs without [gateway] should get secure defaults let minimal = r#" workspace_dir = "/tmp/ws" @@ -4152,7 +4175,7 @@ default_temperature = 0.7 } #[test] - fn checklist_autonomy_default_is_workspace_scoped() { + async fn checklist_autonomy_default_is_workspace_scoped() { let a = AutonomyConfig::default(); assert!(a.workspace_only, "Default autonomy must be workspace_only"); assert!( @@ -4174,7 +4197,7 @@ default_temperature = 0.7 // ══════════════════════════════════════════════════════════ #[test] - fn composio_config_default_disabled() { + async fn composio_config_default_disabled() { let c = ComposioConfig::default(); assert!(!c.enabled, "Composio must be disabled by default"); assert!(c.api_key.is_none(), "No API key by default"); @@ -4182,7 +4205,7 @@ default_temperature = 0.7 } #[test] - fn composio_config_serde_roundtrip() { + async fn composio_config_serde_roundtrip() { let c = ComposioConfig { enabled: true, api_key: Some("comp-key-123".into()), @@ -4196,7 +4219,7 @@ default_temperature = 0.7 } #[test] - fn composio_config_backward_compat_missing_section() { + async fn composio_config_backward_compat_missing_section() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -4211,7 +4234,7 @@ default_temperature = 0.7 } #[test] - fn composio_config_partial_toml() { + async fn composio_config_partial_toml() { let toml_str = r" enabled = true "; @@ -4226,13 +4249,13 @@ enabled = true // ══════════════════════════════════════════════════════════ #[test] - fn secrets_config_default_encrypts() { + async fn secrets_config_default_encrypts() { let s = SecretsConfig::default(); assert!(s.encrypt, "Encryption must be enabled by default"); } #[test] - fn secrets_config_serde_roundtrip() { + async fn secrets_config_serde_roundtrip() { let s = SecretsConfig { encrypt: false }; let toml_str = toml::to_string(&s).unwrap(); let parsed: SecretsConfig = toml::from_str(&toml_str).unwrap(); @@ -4240,7 +4263,7 @@ enabled = true } #[test] - fn secrets_config_backward_compat_missing_section() { + async fn secrets_config_backward_compat_missing_section() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -4254,7 +4277,7 @@ default_temperature = 0.7 } #[test] - fn config_default_has_composio_and_secrets() { + async fn config_default_has_composio_and_secrets() { let c = Config::default(); assert!(!c.composio.enabled); assert!(c.composio.api_key.is_none()); @@ -4264,7 +4287,7 @@ default_temperature = 0.7 } #[test] - fn browser_config_default_disabled() { + async fn browser_config_default_disabled() { let b = BrowserConfig::default(); assert!(!b.enabled); assert!(b.allowed_domains.is_empty()); @@ -4281,7 +4304,7 @@ default_temperature = 0.7 } #[test] - fn browser_config_serde_roundtrip() { + async fn browser_config_serde_roundtrip() { let b = BrowserConfig { enabled: true, allowed_domains: vec!["example.com".into(), "docs.example.com".into()], @@ -4325,7 +4348,7 @@ default_temperature = 0.7 } #[test] - fn browser_config_backward_compat_missing_section() { + async fn browser_config_backward_compat_missing_section() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -4338,11 +4361,9 @@ default_temperature = 0.7 // ── Environment variable overrides (Docker support) ───────── - fn env_override_test_guard() -> std::sync::MutexGuard<'static, ()> { - static ENV_OVERRIDE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); - ENV_OVERRIDE_TEST_LOCK - .lock() - .expect("env override test lock poisoned") + async fn env_override_lock() -> MutexGuard<'static, ()> { + static ENV_OVERRIDE_TEST_LOCK: Mutex<()> = Mutex::const_new(()); + ENV_OVERRIDE_TEST_LOCK.lock().await } fn clear_proxy_env_test_vars() { @@ -4368,8 +4389,8 @@ default_temperature = 0.7 } #[test] - fn env_override_api_key() { - let _env_guard = env_override_test_guard(); + async fn env_override_api_key() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); assert!(config.api_key.is_none()); @@ -4381,8 +4402,8 @@ default_temperature = 0.7 } #[test] - fn env_override_api_key_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_api_key_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_API_KEY"); @@ -4394,8 +4415,8 @@ default_temperature = 0.7 } #[test] - fn env_override_provider() { - let _env_guard = env_override_test_guard(); + async fn env_override_provider() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_PROVIDER", "anthropic"); @@ -4406,8 +4427,8 @@ default_temperature = 0.7 } #[test] - fn env_override_provider_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_provider_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_PROVIDER"); @@ -4419,8 +4440,8 @@ default_temperature = 0.7 } #[test] - fn env_override_provider_fallback_does_not_replace_non_default_provider() { - let _env_guard = env_override_test_guard(); + async fn env_override_provider_fallback_does_not_replace_non_default_provider() { + let _env_guard = env_override_lock().await; let mut config = Config { default_provider: Some("custom:https://proxy.example.com/v1".to_string()), ..Config::default() @@ -4438,8 +4459,8 @@ default_temperature = 0.7 } #[test] - fn env_override_zero_claw_provider_overrides_non_default_provider() { - let _env_guard = env_override_test_guard(); + async fn env_override_zero_claw_provider_overrides_non_default_provider() { + let _env_guard = env_override_lock().await; let mut config = Config { default_provider: Some("custom:https://proxy.example.com/v1".to_string()), ..Config::default() @@ -4455,8 +4476,8 @@ default_temperature = 0.7 } #[test] - fn env_override_glm_api_key_for_regional_aliases() { - let _env_guard = env_override_test_guard(); + async fn env_override_glm_api_key_for_regional_aliases() { + let _env_guard = env_override_lock().await; let mut config = Config { default_provider: Some("glm-cn".to_string()), ..Config::default() @@ -4470,8 +4491,8 @@ default_temperature = 0.7 } #[test] - fn env_override_zai_api_key_for_regional_aliases() { - let _env_guard = env_override_test_guard(); + async fn env_override_zai_api_key_for_regional_aliases() { + let _env_guard = env_override_lock().await; let mut config = Config { default_provider: Some("zai-cn".to_string()), ..Config::default() @@ -4485,8 +4506,8 @@ default_temperature = 0.7 } #[test] - fn env_override_model() { - let _env_guard = env_override_test_guard(); + async fn env_override_model() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_MODEL", "gpt-4o"); @@ -4497,8 +4518,8 @@ default_temperature = 0.7 } #[test] - fn env_override_model_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_model_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_MODEL"); @@ -4513,8 +4534,8 @@ default_temperature = 0.7 } #[test] - fn env_override_workspace() { - let _env_guard = env_override_test_guard(); + async fn env_override_workspace() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_WORKSPACE", "/custom/workspace"); @@ -4525,8 +4546,8 @@ default_temperature = 0.7 } #[test] - fn load_or_init_workspace_override_uses_workspace_root_for_config() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_workspace_override_uses_workspace_root_for_config() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let workspace_dir = temp_home.join("profile-a"); @@ -4535,7 +4556,7 @@ default_temperature = 0.7 std::env::set_var("HOME", &temp_home); std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, workspace_dir.join("workspace")); assert_eq!(config.config_path, workspace_dir.join("config.toml")); @@ -4547,12 +4568,12 @@ default_temperature = 0.7 } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_workspace_suffix_uses_legacy_config_layout() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_workspace_suffix_uses_legacy_config_layout() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let workspace_dir = temp_home.join("workspace"); @@ -4562,7 +4583,7 @@ default_temperature = 0.7 std::env::set_var("HOME", &temp_home); std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, workspace_dir); assert_eq!(config.config_path, legacy_config_path); @@ -4574,32 +4595,33 @@ default_temperature = 0.7 } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_workspace_override_keeps_existing_legacy_config() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_workspace_override_keeps_existing_legacy_config() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let workspace_dir = temp_home.join("custom-workspace"); let legacy_config_dir = temp_home.join(".zeroclaw"); let legacy_config_path = legacy_config_dir.join("config.toml"); - fs::create_dir_all(&legacy_config_dir).unwrap(); + fs::create_dir_all(&legacy_config_dir).await.unwrap(); fs::write( &legacy_config_path, r#"default_temperature = 0.7 default_model = "legacy-model" "#, ) + .await .unwrap(); let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, workspace_dir); assert_eq!(config.config_path, legacy_config_path); @@ -4611,30 +4633,33 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_uses_persisted_active_workspace_marker() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_uses_persisted_active_workspace_marker() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let custom_config_dir = temp_home.join("profiles").join("agent-alpha"); - fs::create_dir_all(&custom_config_dir).unwrap(); + fs::create_dir_all(&custom_config_dir).await.unwrap(); fs::write( custom_config_dir.join("config.toml"), "default_temperature = 0.7\ndefault_model = \"persisted-profile\"\n", ) + .await .unwrap(); let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); std::env::remove_var("ZEROCLAW_WORKSPACE"); - persist_active_workspace_config_dir(&custom_config_dir).unwrap(); + persist_active_workspace_config_dir(&custom_config_dir) + .await + .unwrap(); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.config_path, custom_config_dir.join("config.toml")); assert_eq!(config.workspace_dir, custom_config_dir.join("workspace")); @@ -4645,30 +4670,33 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_env_workspace_override_takes_priority_over_marker() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_env_workspace_override_takes_priority_over_marker() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let marker_config_dir = temp_home.join("profiles").join("persisted-profile"); let env_workspace_dir = temp_home.join("env-workspace"); - fs::create_dir_all(&marker_config_dir).unwrap(); + fs::create_dir_all(&marker_config_dir).await.unwrap(); fs::write( marker_config_dir.join("config.toml"), "default_temperature = 0.7\ndefault_model = \"marker-model\"\n", ) + .await .unwrap(); let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); - persist_active_workspace_config_dir(&marker_config_dir).unwrap(); + persist_active_workspace_config_dir(&marker_config_dir) + .await + .unwrap(); std::env::set_var("ZEROCLAW_WORKSPACE", &env_workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, env_workspace_dir.join("workspace")); assert_eq!(config.config_path, env_workspace_dir.join("config.toml")); @@ -4679,12 +4707,12 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn persist_active_workspace_marker_is_cleared_for_default_config_dir() { - let _env_guard = env_override_test_guard(); + async fn persist_active_workspace_marker_is_cleared_for_default_config_dir() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let default_config_dir = temp_home.join(".zeroclaw"); @@ -4694,10 +4722,14 @@ default_model = "legacy-model" let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); - persist_active_workspace_config_dir(&custom_config_dir).unwrap(); + persist_active_workspace_config_dir(&custom_config_dir) + .await + .unwrap(); assert!(marker_path.exists()); - persist_active_workspace_config_dir(&default_config_dir).unwrap(); + persist_active_workspace_config_dir(&default_config_dir) + .await + .unwrap(); assert!(!marker_path.exists()); if let Some(home) = original_home { @@ -4705,12 +4737,12 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn env_override_empty_values_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_empty_values_ignored() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); let original_provider = config.default_provider.clone(); @@ -4722,8 +4754,8 @@ default_model = "legacy-model" } #[test] - fn env_override_gateway_port() { - let _env_guard = env_override_test_guard(); + async fn env_override_gateway_port() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); assert_eq!(config.gateway.port, 3000); @@ -4735,8 +4767,8 @@ default_model = "legacy-model" } #[test] - fn env_override_port_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_port_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_GATEWAY_PORT"); @@ -4748,8 +4780,8 @@ default_model = "legacy-model" } #[test] - fn env_override_gateway_host() { - let _env_guard = env_override_test_guard(); + async fn env_override_gateway_host() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); assert_eq!(config.gateway.host, "127.0.0.1"); @@ -4761,8 +4793,8 @@ default_model = "legacy-model" } #[test] - fn env_override_host_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_host_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_GATEWAY_HOST"); @@ -4774,8 +4806,8 @@ default_model = "legacy-model" } #[test] - fn env_override_temperature() { - let _env_guard = env_override_test_guard(); + async fn env_override_temperature() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_TEMPERATURE", "0.5"); @@ -4786,8 +4818,8 @@ default_model = "legacy-model" } #[test] - fn env_override_temperature_out_of_range_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_temperature_out_of_range_ignored() { + let _env_guard = env_override_lock().await; // Clean up any leftover env vars from other tests std::env::remove_var("ZEROCLAW_TEMPERATURE"); @@ -4806,8 +4838,8 @@ default_model = "legacy-model" } #[test] - fn env_override_invalid_port_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_invalid_port_ignored() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); let original_port = config.gateway.port; @@ -4819,8 +4851,8 @@ default_model = "legacy-model" } #[test] - fn env_override_web_search_config() { - let _env_guard = env_override_test_guard(); + async fn env_override_web_search_config() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("WEB_SEARCH_ENABLED", "false"); @@ -4848,8 +4880,8 @@ default_model = "legacy-model" } #[test] - fn env_override_web_search_invalid_values_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_web_search_invalid_values_ignored() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); let original_max_results = config.web_search.max_results; let original_timeout = config.web_search.timeout_secs; @@ -4867,8 +4899,8 @@ default_model = "legacy-model" } #[test] - fn env_override_storage_provider_config() { - let _env_guard = env_override_test_guard(); + async fn env_override_storage_provider_config() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_STORAGE_PROVIDER", "postgres"); @@ -4893,7 +4925,7 @@ default_model = "legacy-model" } #[test] - fn proxy_config_scope_services_requires_entries_when_enabled() { + async fn proxy_config_scope_services_requires_entries_when_enabled() { let proxy = ProxyConfig { enabled: true, http_proxy: Some("http://127.0.0.1:7890".into()), @@ -4909,8 +4941,8 @@ default_model = "legacy-model" } #[test] - fn env_override_proxy_scope_services() { - let _env_guard = env_override_test_guard(); + async fn env_override_proxy_scope_services() { + let _env_guard = env_override_lock().await; clear_proxy_env_test_vars(); let mut config = Config::default(); @@ -4938,8 +4970,8 @@ default_model = "legacy-model" } #[test] - fn env_override_proxy_scope_environment_applies_process_env() { - let _env_guard = env_override_test_guard(); + async fn env_override_proxy_scope_environment_applies_process_env() { + let _env_guard = env_override_lock().await; clear_proxy_env_test_vars(); let mut config = Config::default(); @@ -4975,7 +5007,7 @@ default_model = "legacy-model" } #[test] - fn runtime_proxy_client_cache_reuses_default_profile_key() { + async fn runtime_proxy_client_cache_reuses_default_profile_key() { let service_key = format!( "provider.cache_test.{}", std::time::SystemTime::now() @@ -4996,7 +5028,7 @@ default_model = "legacy-model" } #[test] - fn set_runtime_proxy_config_clears_runtime_proxy_client_cache() { + async fn set_runtime_proxy_config_clears_runtime_proxy_client_cache() { let service_key = format!( "provider.cache_timeout_test.{}", std::time::SystemTime::now() @@ -5015,7 +5047,7 @@ default_model = "legacy-model" } #[test] - fn gateway_config_default_values() { + async fn gateway_config_default_values() { let g = GatewayConfig::default(); assert_eq!(g.port, 3000); assert_eq!(g.host, "127.0.0.1"); @@ -5030,14 +5062,14 @@ default_model = "legacy-model" // ── Peripherals config ─────────────────────────────────────── #[test] - fn peripherals_config_default_disabled() { + async fn peripherals_config_default_disabled() { let p = PeripheralsConfig::default(); assert!(!p.enabled); assert!(p.boards.is_empty()); } #[test] - fn peripheral_board_config_defaults() { + async fn peripheral_board_config_defaults() { let b = PeripheralBoardConfig::default(); assert!(b.board.is_empty()); assert_eq!(b.transport, "serial"); @@ -5046,7 +5078,7 @@ default_model = "legacy-model" } #[test] - fn peripherals_config_toml_roundtrip() { + async fn peripherals_config_toml_roundtrip() { let p = PeripheralsConfig { enabled: true, boards: vec![PeripheralBoardConfig { @@ -5066,7 +5098,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_serde() { + async fn lark_config_serde() { let lc = LarkConfig { app_id: "cli_123456".into(), app_secret: "secret_abc".into(), @@ -5088,7 +5120,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_toml_roundtrip() { + async fn lark_config_toml_roundtrip() { let lc = LarkConfig { app_id: "cli_123456".into(), app_secret: "secret_abc".into(), @@ -5107,7 +5139,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_deserializes_without_optional_fields() { + async fn lark_config_deserializes_without_optional_fields() { let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; let parsed: LarkConfig = serde_json::from_str(json).unwrap(); assert!(parsed.encrypt_key.is_none()); @@ -5117,7 +5149,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_defaults_to_lark_endpoint() { + async fn lark_config_defaults_to_lark_endpoint() { let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; let parsed: LarkConfig = serde_json::from_str(json).unwrap(); assert!( @@ -5127,7 +5159,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_with_wildcard_allowed_users() { + async fn lark_config_with_wildcard_allowed_users() { let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#; let parsed: LarkConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.allowed_users, vec!["*"]); @@ -5137,21 +5169,21 @@ default_model = "legacy-model" #[cfg(unix)] #[test] - fn new_config_file_has_restricted_permissions() { - use std::os::unix::fs::PermissionsExt; - + async fn new_config_file_has_restricted_permissions() { let tmp = tempfile::TempDir::new().unwrap(); let config_path = tmp.path().join("config.toml"); // Create a config and save it let mut config = Config::default(); config.config_path = config_path.clone(); - config.save().unwrap(); + config.save().await.unwrap(); // Apply the same permission logic as load_or_init - let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600)); + fs::set_permissions(&config_path, Permissions::from_mode(0o600)) + .await + .expect("Failed to set permissions"); - let meta = std::fs::metadata(&config_path).unwrap(); + let meta = fs::metadata(&config_path).await.unwrap(); let mode = meta.permissions().mode() & 0o777; assert_eq!( mode, 0o600, @@ -5161,7 +5193,7 @@ default_model = "legacy-model" #[cfg(unix)] #[test] - fn world_readable_config_is_detectable() { + async fn world_readable_config_is_detectable() { use std::os::unix::fs::PermissionsExt; let tmp = tempfile::TempDir::new().unwrap(); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index c3c4e31..9847806 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -587,6 +587,7 @@ async fn handle_metrics(State(state): State) -> impl IntoResponse { } /// POST /pair — exchange one-time code for bearer token +#[axum::debug_handler] async fn handle_pair( State(state): State, ConnectInfo(peer_addr): ConnectInfo, @@ -608,10 +609,10 @@ async fn handle_pair( .and_then(|v| v.to_str().ok()) .unwrap_or(""); - match state.pairing.try_pair(code) { + match state.pairing.try_pair(code).await { Ok(Some(token)) => { tracing::info!("🔐 New client paired successfully"); - if let Err(err) = persist_pairing_tokens(&state.config, &state.pairing) { + if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await { tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}"); let body = serde_json::json!({ "paired": true, @@ -648,11 +649,14 @@ async fn handle_pair( } } -fn persist_pairing_tokens(config: &Arc>, pairing: &PairingGuard) -> Result<()> { +async fn persist_pairing_tokens(config: Arc>, pairing: &PairingGuard) -> Result<()> { let paired_tokens = pairing.tokens(); - let mut cfg = config.lock(); + // This is needed because parking_lot's guard is not Send so we clone the inner + // this should be removed once async mutexes are used everywhere + let mut cfg = { config.lock().clone() }; cfg.gateway.paired_tokens = paired_tokens; cfg.save() + .await .context("Failed to persist paired tokens to config.toml") } @@ -1398,15 +1402,15 @@ mod tests { let mut config = Config::default(); config.config_path = config_path.clone(); config.workspace_dir = workspace_path; - config.save().unwrap(); + config.save().await.unwrap(); let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap(); - let token = guard.try_pair(&code).unwrap().unwrap(); + let token = guard.try_pair(&code).await.unwrap().unwrap(); assert!(guard.is_authenticated(&token)); let shared_config = Arc::new(Mutex::new(config)); - persist_pairing_tokens(&shared_config, &guard).unwrap(); + persist_pairing_tokens(shared_config, &guard).await.unwrap(); let saved = tokio::fs::read_to_string(config_path).await.unwrap(); let parsed: Config = toml::from_str(&saved).unwrap(); diff --git a/src/main.rs b/src/main.rs index d93b099..8e21ae8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -553,21 +553,19 @@ async fn main() -> Result<()> { { bail!("--channels-only does not accept --api-key, --provider, --model, or --memory"); } - let config = tokio::task::spawn_blocking(move || { - if channels_only { - onboard::run_channels_repair_wizard() - } else if interactive { - onboard::run_wizard() - } else { - onboard::run_quick_setup( - api_key.as_deref(), - provider.as_deref(), - model.as_deref(), - memory.as_deref(), - ) - } - }) - .await??; + let config = if channels_only { + onboard::run_channels_repair_wizard().await + } else if interactive { + onboard::run_wizard().await + } else { + onboard::run_quick_setup( + api_key.as_deref(), + provider.as_deref(), + model.as_deref(), + memory.as_deref(), + ) + .await + }?; // Auto-start channels if user said yes during wizard if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") { channels::start_channels(config).await?; @@ -576,7 +574,7 @@ async fn main() -> Result<()> { } // All other commands need config loaded first - let mut config = Config::load_or_init()?; + let mut config = Config::load_or_init().await?; config.apply_env_overrides(); match cli.command { @@ -764,7 +762,7 @@ async fn main() -> Result<()> { Commands::Channel { channel_command } => match channel_command { ChannelCommands::Start => channels::start_channels(config).await, ChannelCommands::Doctor => channels::doctor_channels(config).await, - other => channels::handle_command(other, &config), + other => channels::handle_command(other, &config).await, }, Commands::Integrations { @@ -786,7 +784,7 @@ async fn main() -> Result<()> { } Commands::Peripheral { peripheral_command } => { - peripherals::handle_command(peripheral_command.clone(), &config) + peripherals::handle_command(peripheral_command.clone(), &config).await } Commands::Config { config_command } => match config_command { diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 40dd7f7..8522540 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -95,7 +95,7 @@ fn has_launchable_channels(channels: &ChannelsConfig) -> bool { // ── Main wizard entry point ────────────────────────────────────── -pub fn run_wizard() -> Result { +pub async fn run_wizard() -> Result { println!("{}", style(BANNER).cyan().bold()); println!( @@ -191,8 +191,8 @@ pub fn run_wizard() -> Result { if config.memory.auto_save { "on" } else { "off" } ); - config.save()?; - persist_workspace_selection(&config.config_path)?; + config.save().await?; + persist_workspace_selection(&config.config_path).await?; // ── Final summary ──────────────────────────────────────────── print_summary(&config); @@ -226,7 +226,7 @@ pub fn run_wizard() -> Result { } /// Interactive repair flow: rerun channel setup only without redoing full onboarding. -pub fn run_channels_repair_wizard() -> Result { +pub async fn run_channels_repair_wizard() -> Result { println!("{}", style(BANNER).cyan().bold()); println!( " {}", @@ -236,12 +236,12 @@ pub fn run_channels_repair_wizard() -> Result { ); println!(); - let mut config = Config::load_or_init()?; + let mut config = Config::load_or_init().await?; print_step(1, 1, "Channels (How You Talk to ZeroClaw)"); config.channels_config = setup_channels()?; - config.save()?; - persist_workspace_selection(&config.config_path)?; + config.save().await?; + persist_workspace_selection(&config.config_path).await?; println!(); println!( @@ -321,7 +321,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig { } #[allow(clippy::too_many_lines)] -pub fn run_quick_setup( +pub async fn run_quick_setup( credential_override: Option<&str>, provider: Option<&str>, model_override: Option<&str>, @@ -396,8 +396,8 @@ pub fn run_quick_setup( query_classification: crate::config::QueryClassificationConfig::default(), }; - config.save()?; - persist_workspace_selection(&config.config_path)?; + config.save().await?; + persist_workspace_selection(&config.config_path).await?; // Scaffold minimal workspace files let default_ctx = ProjectContext { @@ -1459,16 +1459,18 @@ fn print_bullet(text: &str) { println!(" {} {}", style("›").cyan(), text); } -fn persist_workspace_selection(config_path: &Path) -> Result<()> { +async fn persist_workspace_selection(config_path: &Path) -> Result<()> { let config_dir = config_path .parent() .context("Config path must have a parent directory")?; - crate::config::schema::persist_active_workspace_config_dir(config_dir).with_context(|| { - format!( - "Failed to persist active workspace selection for {}", - config_dir.display() - ) - }) + crate::config::schema::persist_active_workspace_config_dir(config_dir) + .await + .with_context(|| { + format!( + "Failed to persist active workspace selection for {}", + config_dir.display() + ) + }) } // ── Step 1: Workspace ──────────────────────────────────────────── diff --git a/src/peripherals/mod.rs b/src/peripherals/mod.rs index f3f8a8a..cfcb785 100644 --- a/src/peripherals/mod.rs +++ b/src/peripherals/mod.rs @@ -42,7 +42,7 @@ pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoar /// Handle `zeroclaw peripheral` subcommands. #[allow(clippy::module_name_repetitions)] -pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> { +pub async fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> { match cmd { crate::PeripheralCommands::List => { let boards = list_configured_boards(&config.peripherals); @@ -76,7 +76,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result Some(path.clone()) }; - let mut cfg = crate::config::Config::load_or_init()?; + let mut cfg = crate::config::Config::load_or_init().await?; cfg.peripherals.enabled = true; if cfg @@ -95,7 +95,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result path: path_opt, baud: 115_200, }); - cfg.save()?; + cfg.save().await?; println!("Added {} at {}. Restart daemon to apply.", board, path); } #[cfg(feature = "hardware")] diff --git a/src/security/pairing.rs b/src/security/pairing.rs index e4030d5..c6a14e5 100644 --- a/src/security/pairing.rs +++ b/src/security/pairing.rs @@ -11,6 +11,7 @@ use parking_lot::Mutex; use sha2::{Digest, Sha256}; use std::collections::HashSet; +use std::sync::Arc; use std::time::Instant; /// Maximum failed pairing attempts before lockout. @@ -23,16 +24,17 @@ const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes /// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure /// in config files. When a new token is generated, the plaintext is returned /// to the client once, and only the hash is retained. -#[derive(Debug)] +// TODO: I've just made this work with parking_lot but it should use either flume or tokio's async mutexes +#[derive(Debug, Clone)] pub struct PairingGuard { /// Whether pairing is required at all. require_pairing: bool, /// One-time pairing code (generated on startup, consumed on first pair). - pairing_code: Mutex>, + pairing_code: Arc>>, /// Set of SHA-256 hashed bearer tokens (persisted across restarts). - paired_tokens: Mutex>, + paired_tokens: Arc>>, /// Brute-force protection: failed attempt counter + lockout time. - failed_attempts: Mutex<(u32, Option)>, + failed_attempts: Arc)>>, } impl PairingGuard { @@ -62,9 +64,9 @@ impl PairingGuard { }; Self { require_pairing, - pairing_code: Mutex::new(code), - paired_tokens: Mutex::new(tokens), - failed_attempts: Mutex::new((0, None)), + pairing_code: Arc::new(Mutex::new(code)), + paired_tokens: Arc::new(Mutex::new(tokens)), + failed_attempts: Arc::new(Mutex::new((0, None))), } } @@ -78,9 +80,7 @@ impl PairingGuard { self.require_pairing } - /// Attempt to pair with the given code. Returns a bearer token on success. - /// Returns `Err(lockout_seconds)` if locked out due to brute force. - pub fn try_pair(&self, code: &str) -> Result, u64> { + fn try_pair_blocking(&self, code: &str) -> Result, u64> { // Check brute force lockout { let attempts = self.failed_attempts.lock(); @@ -127,6 +127,19 @@ impl PairingGuard { Ok(None) } + /// Attempt to pair with the given code. Returns a bearer token on success. + /// Returns `Err(lockout_seconds)` if locked out due to brute force. + pub async fn try_pair(&self, code: &str) -> Result, u64> { + let this = self.clone(); + let code = code.to_string(); + // TODO: make this function the main one without spawning a task + let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code)); + + handle + .await + .expect("failed to spawn blocking task this should not happen") + } + /// Check if a bearer token is valid (compares against stored hashes). pub fn is_authenticated(&self, token: &str) -> bool { if !self.require_pairing { @@ -232,63 +245,64 @@ pub fn is_public_bind(host: &str) -> bool { #[cfg(test)] mod tests { use super::*; + use tokio::test; // ── PairingGuard ───────────────────────────────────────── #[test] - fn new_guard_generates_code_when_no_tokens() { + async fn new_guard_generates_code_when_no_tokens() { let guard = PairingGuard::new(true, &[]); assert!(guard.pairing_code().is_some()); assert!(!guard.is_paired()); } #[test] - fn new_guard_no_code_when_tokens_exist() { + async fn new_guard_no_code_when_tokens_exist() { let guard = PairingGuard::new(true, &["zc_existing".into()]); assert!(guard.pairing_code().is_none()); assert!(guard.is_paired()); } #[test] - fn new_guard_no_code_when_pairing_disabled() { + async fn new_guard_no_code_when_pairing_disabled() { let guard = PairingGuard::new(false, &[]); assert!(guard.pairing_code().is_none()); } #[test] - fn try_pair_correct_code() { + async fn try_pair_correct_code() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); - let token = guard.try_pair(&code).unwrap(); + let token = guard.try_pair(&code).await.unwrap(); assert!(token.is_some()); assert!(token.unwrap().starts_with("zc_")); assert!(guard.is_paired()); } #[test] - fn try_pair_wrong_code() { + async fn try_pair_wrong_code() { let guard = PairingGuard::new(true, &[]); - let result = guard.try_pair("000000").unwrap(); + let result = guard.try_pair("000000").await.unwrap(); // Might succeed if code happens to be 000000, but extremely unlikely // Just check it returns Ok(None) normally let _ = result; } #[test] - fn try_pair_empty_code() { + async fn try_pair_empty_code() { let guard = PairingGuard::new(true, &[]); - assert!(guard.try_pair("").unwrap().is_none()); + assert!(guard.try_pair("").await.unwrap().is_none()); } #[test] - fn is_authenticated_with_valid_token() { + async fn is_authenticated_with_valid_token() { // Pass plaintext token — PairingGuard hashes it on load let guard = PairingGuard::new(true, &["zc_valid".into()]); assert!(guard.is_authenticated("zc_valid")); } #[test] - fn is_authenticated_with_prehashed_token() { + async fn is_authenticated_with_prehashed_token() { // Pass an already-hashed token (64 hex chars) let hashed = hash_token("zc_valid"); let guard = PairingGuard::new(true, &[hashed]); @@ -296,20 +310,20 @@ mod tests { } #[test] - fn is_authenticated_with_invalid_token() { + async fn is_authenticated_with_invalid_token() { let guard = PairingGuard::new(true, &["zc_valid".into()]); assert!(!guard.is_authenticated("zc_invalid")); } #[test] - fn is_authenticated_when_pairing_disabled() { + async fn is_authenticated_when_pairing_disabled() { let guard = PairingGuard::new(false, &[]); assert!(guard.is_authenticated("anything")); assert!(guard.is_authenticated("")); } #[test] - fn tokens_returns_hashes() { + async fn tokens_returns_hashes() { let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]); let tokens = guard.tokens(); assert_eq!(tokens.len(), 2); @@ -322,10 +336,10 @@ mod tests { } #[test] - fn pair_then_authenticate() { + async fn pair_then_authenticate() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); - let token = guard.try_pair(&code).unwrap().unwrap(); + let token = guard.try_pair(&code).await.unwrap().unwrap(); assert!(guard.is_authenticated(&token)); assert!(!guard.is_authenticated("wrong")); } @@ -333,24 +347,24 @@ mod tests { // ── Token hashing ──────────────────────────────────────── #[test] - fn hash_token_produces_64_hex_chars() { + async fn hash_token_produces_64_hex_chars() { let hash = hash_token("zc_test_token"); assert_eq!(hash.len(), 64); assert!(hash.chars().all(|c| c.is_ascii_hexdigit())); } #[test] - fn hash_token_is_deterministic() { + async fn hash_token_is_deterministic() { assert_eq!(hash_token("zc_abc"), hash_token("zc_abc")); } #[test] - fn hash_token_differs_for_different_inputs() { + async fn hash_token_differs_for_different_inputs() { assert_ne!(hash_token("zc_a"), hash_token("zc_b")); } #[test] - fn is_token_hash_detects_hash_vs_plaintext() { + async fn is_token_hash_detects_hash_vs_plaintext() { assert!(is_token_hash(&hash_token("zc_test"))); assert!(!is_token_hash("zc_test_token")); assert!(!is_token_hash("too_short")); @@ -360,7 +374,7 @@ mod tests { // ── is_public_bind ─────────────────────────────────────── #[test] - fn localhost_variants_not_public() { + async fn localhost_variants_not_public() { assert!(!is_public_bind("127.0.0.1")); assert!(!is_public_bind("localhost")); assert!(!is_public_bind("::1")); @@ -368,12 +382,12 @@ mod tests { } #[test] - fn zero_zero_is_public() { + async fn zero_zero_is_public() { assert!(is_public_bind("0.0.0.0")); } #[test] - fn real_ip_is_public() { + async fn real_ip_is_public() { assert!(is_public_bind("192.168.1.100")); assert!(is_public_bind("10.0.0.1")); } @@ -381,13 +395,13 @@ mod tests { // ── constant_time_eq ───────────────────────────────────── #[test] - fn constant_time_eq_same() { + async fn constant_time_eq_same() { assert!(constant_time_eq("abc", "abc")); assert!(constant_time_eq("", "")); } #[test] - fn constant_time_eq_different() { + async fn constant_time_eq_different() { assert!(!constant_time_eq("abc", "abd")); assert!(!constant_time_eq("abc", "ab")); assert!(!constant_time_eq("a", "")); @@ -396,14 +410,14 @@ mod tests { // ── generate helpers ───────────────────────────────────── #[test] - fn generate_code_is_6_digits() { + async fn generate_code_is_6_digits() { let code = generate_code(); assert_eq!(code.len(), 6); assert!(code.chars().all(|c| c.is_ascii_digit())); } #[test] - fn generate_code_is_not_deterministic() { + async fn generate_code_is_not_deterministic() { // Two codes should differ with overwhelming probability. We try // multiple pairs so a single 1-in-10^6 collision doesn't cause // a flaky CI failure. All 10 pairs colliding is ~1-in-10^60. @@ -416,7 +430,7 @@ mod tests { } #[test] - fn generate_token_has_prefix_and_hex_payload() { + async fn generate_token_has_prefix_and_hex_payload() { let token = generate_token(); let payload = token .strip_prefix("zc_") @@ -434,15 +448,15 @@ mod tests { // ── Brute force protection ─────────────────────────────── #[test] - fn brute_force_lockout_after_max_attempts() { + async fn brute_force_lockout_after_max_attempts() { let guard = PairingGuard::new(true, &[]); // Exhaust all attempts with wrong codes for i in 0..MAX_PAIR_ATTEMPTS { - let result = guard.try_pair(&format!("wrong_{i}")); + let result = guard.try_pair(&format!("wrong_{i}")).await; assert!(result.is_ok(), "Attempt {i} should not be locked out yet"); } // Next attempt should be locked out - let result = guard.try_pair("another_wrong"); + let result = guard.try_pair("another_wrong").await; assert!( result.is_err(), "Should be locked out after {MAX_PAIR_ATTEMPTS} attempts" @@ -456,25 +470,25 @@ mod tests { } #[test] - fn correct_code_resets_failed_attempts() { + async fn correct_code_resets_failed_attempts() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); // Fail a few times for _ in 0..3 { - let _ = guard.try_pair("wrong"); + let _ = guard.try_pair("wrong").await; } // Correct code should still work (under MAX_PAIR_ATTEMPTS) - let result = guard.try_pair(&code).unwrap(); + let result = guard.try_pair(&code).await.unwrap(); assert!(result.is_some(), "Correct code should work before lockout"); } #[test] - fn lockout_returns_remaining_seconds() { + async fn lockout_returns_remaining_seconds() { let guard = PairingGuard::new(true, &[]); for _ in 0..MAX_PAIR_ATTEMPTS { - let _ = guard.try_pair("wrong"); + let _ = guard.try_pair("wrong").await; } - let err = guard.try_pair("wrong").unwrap_err(); + let err = guard.try_pair("wrong").await.unwrap_err(); // Should be close to PAIR_LOCKOUT_SECS (within a second) assert!( err >= PAIR_LOCKOUT_SECS - 1, diff --git a/src/tools/proxy_config.rs b/src/tools/proxy_config.rs index 3ddde9e..5f4183d 100644 --- a/src/tools/proxy_config.rs +++ b/src/tools/proxy_config.rs @@ -3,6 +3,7 @@ use crate::config::{ runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope, }; use crate::security::SecurityPolicy; +use crate::util::MaybeSet; use async_trait::async_trait; use serde_json::{json, Value}; use std::fs; @@ -93,16 +94,13 @@ impl ProxyConfigTool { anyhow::bail!("'{field}' must be a string or string[]") } - fn parse_optional_string_update( - args: &Value, - field: &str, - ) -> anyhow::Result>> { + fn parse_optional_string_update(args: &Value, field: &str) -> anyhow::Result> { let Some(raw) = args.get(field) else { - return Ok(None); + return Ok(MaybeSet::Unset); }; if raw.is_null() { - return Ok(Some(None)); + return Ok(MaybeSet::Null); } let value = raw @@ -110,7 +108,13 @@ impl ProxyConfigTool { .ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))? .trim() .to_string(); - Ok(Some((!value.is_empty()).then_some(value))) + + let output = if value.is_empty() { + MaybeSet::Null + } else { + MaybeSet::Set(value) + }; + Ok(output) } fn env_snapshot() -> Value { @@ -164,7 +168,7 @@ impl ProxyConfigTool { }) } - fn handle_set(&self, args: &Value) -> anyhow::Result { + async fn handle_set(&self, args: &Value) -> anyhow::Result { let mut cfg = self.load_config_without_env()?; let previous_scope = cfg.proxy.scope; let mut proxy = cfg.proxy.clone(); @@ -185,23 +189,24 @@ impl ProxyConfigTool { })?; } - if let Some(update) = Self::parse_optional_string_update(args, "http_proxy")? { - proxy.http_proxy = update; + if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "http_proxy")? { + proxy.http_proxy = Some(update); touched_proxy_url = true; } - if let Some(update) = Self::parse_optional_string_update(args, "https_proxy")? { - proxy.https_proxy = update; + if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "https_proxy")? { + proxy.https_proxy = Some(update); touched_proxy_url = true; } - if let Some(update) = Self::parse_optional_string_update(args, "all_proxy")? { - proxy.all_proxy = update; + if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "all_proxy")? { + proxy.all_proxy = Some(update); touched_proxy_url = true; } if let Some(no_proxy_raw) = args.get("no_proxy") { proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?; + touched_proxy_url = true; } if let Some(services_raw) = args.get("services") { @@ -217,7 +222,7 @@ impl ProxyConfigTool { proxy.validate()?; cfg.proxy = proxy.clone(); - cfg.save()?; + cfg.save().await?; set_runtime_proxy_config(proxy.clone()); if proxy.enabled && proxy.scope == ProxyScope::Environment { @@ -237,11 +242,11 @@ impl ProxyConfigTool { }) } - fn handle_disable(&self, args: &Value) -> anyhow::Result { + async fn handle_disable(&self, args: &Value) -> anyhow::Result { let mut cfg = self.load_config_without_env()?; let clear_env_default = cfg.proxy.scope == ProxyScope::Environment; cfg.proxy.enabled = false; - cfg.save()?; + cfg.save().await?; set_runtime_proxy_config(cfg.proxy.clone()); @@ -384,8 +389,8 @@ impl Tool for ProxyConfigTool { } match action.as_str() { - "set" => self.handle_set(&args), - "disable" => self.handle_disable(&args), + "set" => self.handle_set(&args).await, + "disable" => self.handle_disable(&args).await, "apply_env" => self.handle_apply_env(), "clear_env" => self.handle_clear_env(), _ => unreachable!("handled above"), @@ -421,20 +426,20 @@ mod tests { }) } - fn test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - config.save().unwrap(); + config.save().await.unwrap(); Arc::new(config) } #[tokio::test] async fn list_services_action_returns_known_keys() { let tmp = TempDir::new().unwrap(); - let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); + let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security()); let result = tool .execute(json!({"action": "list_services"})) @@ -448,7 +453,7 @@ mod tests { #[tokio::test] async fn set_scope_services_requires_services_entries() { let tmp = TempDir::new().unwrap(); - let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); + let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security()); let result = tool .execute(json!({ @@ -471,7 +476,7 @@ mod tests { #[tokio::test] async fn set_and_get_round_trip_proxy_scope() { let tmp = TempDir::new().unwrap(); - let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); + let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security()); let set_result = tool .execute(json!({ diff --git a/src/util.rs b/src/util.rs index 9a218e7..85c7856 100644 --- a/src/util.rs +++ b/src/util.rs @@ -43,6 +43,13 @@ pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String { } } +/// Utility enum for handling optional values. +pub enum MaybeSet { + Set(T), + Unset, + Null, +} + #[cfg(test)] mod tests { use super::*;