chore: Remove more blocking io calls

This commit is contained in:
Jayson Reis 2026-02-19 06:30:43 +00:00 committed by Chummy
parent 1aec9ad9c0
commit f1ca73d3d2
14 changed files with 427 additions and 357 deletions

13
Cargo.lock generated
View file

@ -385,6 +385,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"axum-macros",
"base64", "base64",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
@ -431,6 +432,17 @@ dependencies = [
"tower-service", "tower-service",
] ]
[[package]]
name = "axum-macros"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.116",
]
[[package]] [[package]]
name = "backon" name = "backon"
version = "1.6.0" version = "1.6.0"
@ -7734,6 +7746,7 @@ dependencies = [
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tokio-serial", "tokio-serial",
"tokio-stream",
"tokio-tungstenite 0.24.0", "tokio-tungstenite 0.24.0",
"tokio-util", "tokio-util",
"toml 1.0.1+spec-1.1.0", "toml 1.0.1+spec-1.1.0",

View file

@ -120,7 +120,7 @@ mail-parser = "0.11.2"
async-imap = { version = "0.11",features = ["runtime-tokio"], default-features = false } async-imap = { version = "0.11",features = ["runtime-tokio"], default-features = false }
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance # HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] } axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws", "macros"] }
tower = { version = "0.5", default-features = false } tower = { version = "0.5", default-features = false }
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] } tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
http-body-util = "0.1" http-body-util = "0.1"
@ -141,6 +141,7 @@ probe-rs = { version = "0.30", optional = true }
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf) # PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
pdf-extract = { version = "0.10", optional = true } pdf-extract = { version = "0.10", optional = true }
tokio-stream = { version = "0.1.18", features = ["full"] }
# WhatsApp Web client (wa-rs) — optional, enable with --features whatsapp-web # WhatsApp Web client (wa-rs) — optional, enable with --features whatsapp-web
# Uses wa-rs for Bot and Client, wa-rs-core for storage traits, custom rusqlite backend avoids Diesel conflict. # Uses wa-rs for Bot and Client, wa-rs-core for storage traits, custom rusqlite backend avoids Diesel conflict.

View file

@ -26,7 +26,7 @@ pub fn run_wizard() -> Result<Config> {
security: SecurityConfig::autodetect(), // Silent! security: SecurityConfig::autodetect(), // Silent!
}; };
config.save()?; config.save().await?;
Ok(config) Ok(config)
} }
``` ```

View file

@ -726,7 +726,7 @@ mod tests {
"!r:m".to_string(), "!r:m".to_string(),
vec![], vec![],
Some(" ".to_string()), Some(" ".to_string()),
Some("".to_string()), Some(String::new()),
); );
assert!(ch.session_owner_hint.is_none()); assert!(ch.session_owner_hint.is_none());

View file

@ -1117,7 +1117,7 @@ fn normalize_telegram_identity(value: &str) -> String {
value.trim().trim_start_matches('@').to_string() value.trim().trim_start_matches('@').to_string()
} }
fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
let normalized = normalize_telegram_identity(identity); let normalized = normalize_telegram_identity(identity);
if normalized.is_empty() { if normalized.is_empty() {
anyhow::bail!("Telegram identity cannot be empty"); anyhow::bail!("Telegram identity cannot be empty");
@ -1147,7 +1147,7 @@ fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
} }
telegram.allowed_users.push(normalized.clone()); telegram.allowed_users.push(normalized.clone());
updated.save()?; updated.save().await?;
println!("✅ Bound Telegram identity: {normalized}"); println!("✅ Bound Telegram identity: {normalized}");
println!(" Saved to {}", updated.config_path.display()); println!(" Saved to {}", updated.config_path.display());
match maybe_restart_managed_daemon_service() { match maybe_restart_managed_daemon_service() {
@ -1243,7 +1243,7 @@ fn maybe_restart_managed_daemon_service() -> Result<bool> {
Ok(false) Ok(false)
} }
pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> { pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
match command { match command {
crate::ChannelCommands::Start => { crate::ChannelCommands::Start => {
anyhow::bail!("Start must be handled in main.rs (requires async runtime)") anyhow::bail!("Start must be handled in main.rs (requires async runtime)")
@ -1290,7 +1290,7 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly"); anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
} }
crate::ChannelCommands::BindTelegram { identity } => { crate::ChannelCommands::BindTelegram { identity } => {
bind_telegram_identity(config, &identity) bind_telegram_identity(config, &identity).await
} }
} }
} }

View file

@ -6,10 +6,10 @@ use async_trait::async_trait;
use directories::UserDirs; use directories::UserDirs;
use parking_lot::Mutex; use parking_lot::Mutex;
use reqwest::multipart::{Form, Part}; use reqwest::multipart::{Form, Part};
use std::fs;
use std::path::Path; use std::path::Path;
use std::sync::{Arc, RwLock}; use std::sync::{Arc, RwLock};
use std::time::Duration; use std::time::Duration;
use tokio::fs;
/// Telegram's maximum message length for text messages /// Telegram's maximum message length for text messages
const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096; const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096;
@ -373,7 +373,7 @@ impl TelegramChannel {
.collect() .collect()
} }
fn load_config_without_env() -> anyhow::Result<Config> { async fn load_config_without_env() -> anyhow::Result<Config> {
let home = UserDirs::new() let home = UserDirs::new()
.map(|u| u.home_dir().to_path_buf()) .map(|u| u.home_dir().to_path_buf())
.context("Could not find home directory")?; .context("Could not find home directory")?;
@ -381,6 +381,7 @@ impl TelegramChannel {
let config_path = zeroclaw_dir.join("config.toml"); let config_path = zeroclaw_dir.join("config.toml");
let contents = fs::read_to_string(&config_path) let contents = fs::read_to_string(&config_path)
.await
.with_context(|| format!("Failed to read config file: {}", config_path.display()))?; .with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
let mut config: Config = toml::from_str(&contents) let mut config: Config = toml::from_str(&contents)
.context("Failed to parse config file for Telegram binding")?; .context("Failed to parse config file for Telegram binding")?;
@ -389,8 +390,8 @@ impl TelegramChannel {
Ok(config) Ok(config)
} }
fn persist_allowed_identity_blocking(identity: &str) -> anyhow::Result<()> { async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
let mut config = Self::load_config_without_env()?; let mut config = Self::load_config_without_env().await?;
let Some(telegram) = config.channels_config.telegram.as_mut() else { let Some(telegram) = config.channels_config.telegram.as_mut() else {
anyhow::bail!("Telegram channel config is missing in config.toml"); anyhow::bail!("Telegram channel config is missing in config.toml");
}; };
@ -404,20 +405,13 @@ impl TelegramChannel {
telegram.allowed_users.push(normalized); telegram.allowed_users.push(normalized);
config config
.save() .save()
.await
.context("Failed to persist Telegram allowlist to config.toml")?; .context("Failed to persist Telegram allowlist to config.toml")?;
} }
Ok(()) Ok(())
} }
async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
let identity = identity.to_string();
tokio::task::spawn_blocking(move || Self::persist_allowed_identity_blocking(&identity))
.await
.map_err(|e| anyhow::anyhow!("Failed to join Telegram bind save task: {e}"))??;
Ok(())
}
fn add_allowed_identity_runtime(&self, identity: &str) { fn add_allowed_identity_runtime(&self, identity: &str) {
let normalized = Self::normalize_identity(identity); let normalized = Self::normalize_identity(identity);
if normalized.is_empty() { if normalized.is_empty() {
@ -629,7 +623,7 @@ impl TelegramChannel {
if let Some(code) = Self::extract_bind_code(text) { if let Some(code) = Self::extract_bind_code(text) {
if let Some(pairing) = self.pairing.as_ref() { if let Some(pairing) = self.pairing.as_ref() {
match pairing.try_pair(code) { match pairing.try_pair(code).await {
Ok(Some(_token)) => { Ok(Some(_token)) => {
let bind_identity = normalized_sender_id.clone().or_else(|| { let bind_identity = normalized_sender_id.clone().or_else(|| {
if normalized_username.is_empty() || normalized_username == "unknown" { if normalized_username.is_empty() || normalized_username == "unknown" {

File diff suppressed because it is too large Load diff

View file

@ -587,6 +587,7 @@ async fn handle_metrics(State(state): State<AppState>) -> impl IntoResponse {
} }
/// POST /pair — exchange one-time code for bearer token /// POST /pair — exchange one-time code for bearer token
#[axum::debug_handler]
async fn handle_pair( async fn handle_pair(
State(state): State<AppState>, State(state): State<AppState>,
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>, ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
@ -608,10 +609,10 @@ async fn handle_pair(
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
.unwrap_or(""); .unwrap_or("");
match state.pairing.try_pair(code) { match state.pairing.try_pair(code).await {
Ok(Some(token)) => { Ok(Some(token)) => {
tracing::info!("🔐 New client paired successfully"); tracing::info!("🔐 New client paired successfully");
if let Err(err) = persist_pairing_tokens(&state.config, &state.pairing) { if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await {
tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}"); tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}");
let body = serde_json::json!({ let body = serde_json::json!({
"paired": true, "paired": true,
@ -648,11 +649,14 @@ async fn handle_pair(
} }
} }
fn persist_pairing_tokens(config: &Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> { async fn persist_pairing_tokens(config: Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> {
let paired_tokens = pairing.tokens(); let paired_tokens = pairing.tokens();
let mut cfg = config.lock(); // This is needed because parking_lot's guard is not Send so we clone the inner
// this should be removed once async mutexes are used everywhere
let mut cfg = { config.lock().clone() };
cfg.gateway.paired_tokens = paired_tokens; cfg.gateway.paired_tokens = paired_tokens;
cfg.save() cfg.save()
.await
.context("Failed to persist paired tokens to config.toml") .context("Failed to persist paired tokens to config.toml")
} }
@ -1398,15 +1402,15 @@ mod tests {
let mut config = Config::default(); let mut config = Config::default();
config.config_path = config_path.clone(); config.config_path = config_path.clone();
config.workspace_dir = workspace_path; config.workspace_dir = workspace_path;
config.save().unwrap(); config.save().await.unwrap();
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap(); let code = guard.pairing_code().unwrap();
let token = guard.try_pair(&code).unwrap().unwrap(); let token = guard.try_pair(&code).await.unwrap().unwrap();
assert!(guard.is_authenticated(&token)); assert!(guard.is_authenticated(&token));
let shared_config = Arc::new(Mutex::new(config)); let shared_config = Arc::new(Mutex::new(config));
persist_pairing_tokens(&shared_config, &guard).unwrap(); persist_pairing_tokens(shared_config, &guard).await.unwrap();
let saved = tokio::fs::read_to_string(config_path).await.unwrap(); let saved = tokio::fs::read_to_string(config_path).await.unwrap();
let parsed: Config = toml::from_str(&saved).unwrap(); let parsed: Config = toml::from_str(&saved).unwrap();

View file

@ -553,21 +553,19 @@ async fn main() -> Result<()> {
{ {
bail!("--channels-only does not accept --api-key, --provider, --model, or --memory"); bail!("--channels-only does not accept --api-key, --provider, --model, or --memory");
} }
let config = tokio::task::spawn_blocking(move || { let config = if channels_only {
if channels_only { onboard::run_channels_repair_wizard().await
onboard::run_channels_repair_wizard() } else if interactive {
} else if interactive { onboard::run_wizard().await
onboard::run_wizard() } else {
} else { onboard::run_quick_setup(
onboard::run_quick_setup( api_key.as_deref(),
api_key.as_deref(), provider.as_deref(),
provider.as_deref(), model.as_deref(),
model.as_deref(), memory.as_deref(),
memory.as_deref(), )
) .await
} }?;
})
.await??;
// Auto-start channels if user said yes during wizard // Auto-start channels if user said yes during wizard
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") { if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
channels::start_channels(config).await?; channels::start_channels(config).await?;
@ -576,7 +574,7 @@ async fn main() -> Result<()> {
} }
// All other commands need config loaded first // All other commands need config loaded first
let mut config = Config::load_or_init()?; let mut config = Config::load_or_init().await?;
config.apply_env_overrides(); config.apply_env_overrides();
match cli.command { match cli.command {
@ -764,7 +762,7 @@ async fn main() -> Result<()> {
Commands::Channel { channel_command } => match channel_command { Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await, ChannelCommands::Start => channels::start_channels(config).await,
ChannelCommands::Doctor => channels::doctor_channels(config).await, ChannelCommands::Doctor => channels::doctor_channels(config).await,
other => channels::handle_command(other, &config), other => channels::handle_command(other, &config).await,
}, },
Commands::Integrations { Commands::Integrations {
@ -786,7 +784,7 @@ async fn main() -> Result<()> {
} }
Commands::Peripheral { peripheral_command } => { Commands::Peripheral { peripheral_command } => {
peripherals::handle_command(peripheral_command.clone(), &config) peripherals::handle_command(peripheral_command.clone(), &config).await
} }
Commands::Config { config_command } => match config_command { Commands::Config { config_command } => match config_command {

View file

@ -95,7 +95,7 @@ fn has_launchable_channels(channels: &ChannelsConfig) -> bool {
// ── Main wizard entry point ────────────────────────────────────── // ── Main wizard entry point ──────────────────────────────────────
pub fn run_wizard() -> Result<Config> { pub async fn run_wizard() -> Result<Config> {
println!("{}", style(BANNER).cyan().bold()); println!("{}", style(BANNER).cyan().bold());
println!( println!(
@ -191,8 +191,8 @@ pub fn run_wizard() -> Result<Config> {
if config.memory.auto_save { "on" } else { "off" } if config.memory.auto_save { "on" } else { "off" }
); );
config.save()?; config.save().await?;
persist_workspace_selection(&config.config_path)?; persist_workspace_selection(&config.config_path).await?;
// ── Final summary ──────────────────────────────────────────── // ── Final summary ────────────────────────────────────────────
print_summary(&config); print_summary(&config);
@ -226,7 +226,7 @@ pub fn run_wizard() -> Result<Config> {
} }
/// Interactive repair flow: rerun channel setup only without redoing full onboarding. /// Interactive repair flow: rerun channel setup only without redoing full onboarding.
pub fn run_channels_repair_wizard() -> Result<Config> { pub async fn run_channels_repair_wizard() -> Result<Config> {
println!("{}", style(BANNER).cyan().bold()); println!("{}", style(BANNER).cyan().bold());
println!( println!(
" {}", " {}",
@ -236,12 +236,12 @@ pub fn run_channels_repair_wizard() -> Result<Config> {
); );
println!(); println!();
let mut config = Config::load_or_init()?; let mut config = Config::load_or_init().await?;
print_step(1, 1, "Channels (How You Talk to ZeroClaw)"); print_step(1, 1, "Channels (How You Talk to ZeroClaw)");
config.channels_config = setup_channels()?; config.channels_config = setup_channels()?;
config.save()?; config.save().await?;
persist_workspace_selection(&config.config_path)?; persist_workspace_selection(&config.config_path).await?;
println!(); println!();
println!( println!(
@ -321,7 +321,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
} }
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub fn run_quick_setup( pub async fn run_quick_setup(
credential_override: Option<&str>, credential_override: Option<&str>,
provider: Option<&str>, provider: Option<&str>,
model_override: Option<&str>, model_override: Option<&str>,
@ -396,8 +396,8 @@ pub fn run_quick_setup(
query_classification: crate::config::QueryClassificationConfig::default(), query_classification: crate::config::QueryClassificationConfig::default(),
}; };
config.save()?; config.save().await?;
persist_workspace_selection(&config.config_path)?; persist_workspace_selection(&config.config_path).await?;
// Scaffold minimal workspace files // Scaffold minimal workspace files
let default_ctx = ProjectContext { let default_ctx = ProjectContext {
@ -1459,16 +1459,18 @@ fn print_bullet(text: &str) {
println!(" {} {}", style("").cyan(), text); println!(" {} {}", style("").cyan(), text);
} }
fn persist_workspace_selection(config_path: &Path) -> Result<()> { async fn persist_workspace_selection(config_path: &Path) -> Result<()> {
let config_dir = config_path let config_dir = config_path
.parent() .parent()
.context("Config path must have a parent directory")?; .context("Config path must have a parent directory")?;
crate::config::schema::persist_active_workspace_config_dir(config_dir).with_context(|| { crate::config::schema::persist_active_workspace_config_dir(config_dir)
format!( .await
"Failed to persist active workspace selection for {}", .with_context(|| {
config_dir.display() format!(
) "Failed to persist active workspace selection for {}",
}) config_dir.display()
)
})
} }
// ── Step 1: Workspace ──────────────────────────────────────────── // ── Step 1: Workspace ────────────────────────────────────────────

View file

@ -42,7 +42,7 @@ pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoar
/// Handle `zeroclaw peripheral` subcommands. /// Handle `zeroclaw peripheral` subcommands.
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> { pub async fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> {
match cmd { match cmd {
crate::PeripheralCommands::List => { crate::PeripheralCommands::List => {
let boards = list_configured_boards(&config.peripherals); let boards = list_configured_boards(&config.peripherals);
@ -76,7 +76,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
Some(path.clone()) Some(path.clone())
}; };
let mut cfg = crate::config::Config::load_or_init()?; let mut cfg = crate::config::Config::load_or_init().await?;
cfg.peripherals.enabled = true; cfg.peripherals.enabled = true;
if cfg if cfg
@ -95,7 +95,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
path: path_opt, path: path_opt,
baud: 115_200, baud: 115_200,
}); });
cfg.save()?; cfg.save().await?;
println!("Added {} at {}. Restart daemon to apply.", board, path); println!("Added {} at {}. Restart daemon to apply.", board, path);
} }
#[cfg(feature = "hardware")] #[cfg(feature = "hardware")]

View file

@ -11,6 +11,7 @@
use parking_lot::Mutex; use parking_lot::Mutex;
use sha2::{Digest, Sha256}; use sha2::{Digest, Sha256};
use std::collections::HashSet; use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant; use std::time::Instant;
/// Maximum failed pairing attempts before lockout. /// Maximum failed pairing attempts before lockout.
@ -23,16 +24,17 @@ const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes
/// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure /// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure
/// in config files. When a new token is generated, the plaintext is returned /// in config files. When a new token is generated, the plaintext is returned
/// to the client once, and only the hash is retained. /// to the client once, and only the hash is retained.
#[derive(Debug)] // TODO: I've just made this work with parking_lot but it should use either flume or tokio's async mutexes
#[derive(Debug, Clone)]
pub struct PairingGuard { pub struct PairingGuard {
/// Whether pairing is required at all. /// Whether pairing is required at all.
require_pairing: bool, require_pairing: bool,
/// One-time pairing code (generated on startup, consumed on first pair). /// One-time pairing code (generated on startup, consumed on first pair).
pairing_code: Mutex<Option<String>>, pairing_code: Arc<Mutex<Option<String>>>,
/// Set of SHA-256 hashed bearer tokens (persisted across restarts). /// Set of SHA-256 hashed bearer tokens (persisted across restarts).
paired_tokens: Mutex<HashSet<String>>, paired_tokens: Arc<Mutex<HashSet<String>>>,
/// Brute-force protection: failed attempt counter + lockout time. /// Brute-force protection: failed attempt counter + lockout time.
failed_attempts: Mutex<(u32, Option<Instant>)>, failed_attempts: Arc<Mutex<(u32, Option<Instant>)>>,
} }
impl PairingGuard { impl PairingGuard {
@ -62,9 +64,9 @@ impl PairingGuard {
}; };
Self { Self {
require_pairing, require_pairing,
pairing_code: Mutex::new(code), pairing_code: Arc::new(Mutex::new(code)),
paired_tokens: Mutex::new(tokens), paired_tokens: Arc::new(Mutex::new(tokens)),
failed_attempts: Mutex::new((0, None)), failed_attempts: Arc::new(Mutex::new((0, None))),
} }
} }
@ -78,9 +80,7 @@ impl PairingGuard {
self.require_pairing self.require_pairing
} }
/// Attempt to pair with the given code. Returns a bearer token on success. fn try_pair_blocking(&self, code: &str) -> Result<Option<String>, u64> {
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
pub fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
// Check brute force lockout // Check brute force lockout
{ {
let attempts = self.failed_attempts.lock(); let attempts = self.failed_attempts.lock();
@ -127,6 +127,19 @@ impl PairingGuard {
Ok(None) Ok(None)
} }
/// Attempt to pair with the given code. Returns a bearer token on success.
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
pub async fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
let this = self.clone();
let code = code.to_string();
// TODO: make this function the main one without spawning a task
let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code));
handle
.await
.expect("failed to spawn blocking task this should not happen")
}
/// Check if a bearer token is valid (compares against stored hashes). /// Check if a bearer token is valid (compares against stored hashes).
pub fn is_authenticated(&self, token: &str) -> bool { pub fn is_authenticated(&self, token: &str) -> bool {
if !self.require_pairing { if !self.require_pairing {
@ -232,63 +245,64 @@ pub fn is_public_bind(host: &str) -> bool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tokio::test;
// ── PairingGuard ───────────────────────────────────────── // ── PairingGuard ─────────────────────────────────────────
#[test] #[test]
fn new_guard_generates_code_when_no_tokens() { async fn new_guard_generates_code_when_no_tokens() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
assert!(guard.pairing_code().is_some()); assert!(guard.pairing_code().is_some());
assert!(!guard.is_paired()); assert!(!guard.is_paired());
} }
#[test] #[test]
fn new_guard_no_code_when_tokens_exist() { async fn new_guard_no_code_when_tokens_exist() {
let guard = PairingGuard::new(true, &["zc_existing".into()]); let guard = PairingGuard::new(true, &["zc_existing".into()]);
assert!(guard.pairing_code().is_none()); assert!(guard.pairing_code().is_none());
assert!(guard.is_paired()); assert!(guard.is_paired());
} }
#[test] #[test]
fn new_guard_no_code_when_pairing_disabled() { async fn new_guard_no_code_when_pairing_disabled() {
let guard = PairingGuard::new(false, &[]); let guard = PairingGuard::new(false, &[]);
assert!(guard.pairing_code().is_none()); assert!(guard.pairing_code().is_none());
} }
#[test] #[test]
fn try_pair_correct_code() { async fn try_pair_correct_code() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string(); let code = guard.pairing_code().unwrap().to_string();
let token = guard.try_pair(&code).unwrap(); let token = guard.try_pair(&code).await.unwrap();
assert!(token.is_some()); assert!(token.is_some());
assert!(token.unwrap().starts_with("zc_")); assert!(token.unwrap().starts_with("zc_"));
assert!(guard.is_paired()); assert!(guard.is_paired());
} }
#[test] #[test]
fn try_pair_wrong_code() { async fn try_pair_wrong_code() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let result = guard.try_pair("000000").unwrap(); let result = guard.try_pair("000000").await.unwrap();
// Might succeed if code happens to be 000000, but extremely unlikely // Might succeed if code happens to be 000000, but extremely unlikely
// Just check it returns Ok(None) normally // Just check it returns Ok(None) normally
let _ = result; let _ = result;
} }
#[test] #[test]
fn try_pair_empty_code() { async fn try_pair_empty_code() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
assert!(guard.try_pair("").unwrap().is_none()); assert!(guard.try_pair("").await.unwrap().is_none());
} }
#[test] #[test]
fn is_authenticated_with_valid_token() { async fn is_authenticated_with_valid_token() {
// Pass plaintext token — PairingGuard hashes it on load // Pass plaintext token — PairingGuard hashes it on load
let guard = PairingGuard::new(true, &["zc_valid".into()]); let guard = PairingGuard::new(true, &["zc_valid".into()]);
assert!(guard.is_authenticated("zc_valid")); assert!(guard.is_authenticated("zc_valid"));
} }
#[test] #[test]
fn is_authenticated_with_prehashed_token() { async fn is_authenticated_with_prehashed_token() {
// Pass an already-hashed token (64 hex chars) // Pass an already-hashed token (64 hex chars)
let hashed = hash_token("zc_valid"); let hashed = hash_token("zc_valid");
let guard = PairingGuard::new(true, &[hashed]); let guard = PairingGuard::new(true, &[hashed]);
@ -296,20 +310,20 @@ mod tests {
} }
#[test] #[test]
fn is_authenticated_with_invalid_token() { async fn is_authenticated_with_invalid_token() {
let guard = PairingGuard::new(true, &["zc_valid".into()]); let guard = PairingGuard::new(true, &["zc_valid".into()]);
assert!(!guard.is_authenticated("zc_invalid")); assert!(!guard.is_authenticated("zc_invalid"));
} }
#[test] #[test]
fn is_authenticated_when_pairing_disabled() { async fn is_authenticated_when_pairing_disabled() {
let guard = PairingGuard::new(false, &[]); let guard = PairingGuard::new(false, &[]);
assert!(guard.is_authenticated("anything")); assert!(guard.is_authenticated("anything"));
assert!(guard.is_authenticated("")); assert!(guard.is_authenticated(""));
} }
#[test] #[test]
fn tokens_returns_hashes() { async fn tokens_returns_hashes() {
let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]); let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]);
let tokens = guard.tokens(); let tokens = guard.tokens();
assert_eq!(tokens.len(), 2); assert_eq!(tokens.len(), 2);
@ -322,10 +336,10 @@ mod tests {
} }
#[test] #[test]
fn pair_then_authenticate() { async fn pair_then_authenticate() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string(); let code = guard.pairing_code().unwrap().to_string();
let token = guard.try_pair(&code).unwrap().unwrap(); let token = guard.try_pair(&code).await.unwrap().unwrap();
assert!(guard.is_authenticated(&token)); assert!(guard.is_authenticated(&token));
assert!(!guard.is_authenticated("wrong")); assert!(!guard.is_authenticated("wrong"));
} }
@ -333,24 +347,24 @@ mod tests {
// ── Token hashing ──────────────────────────────────────── // ── Token hashing ────────────────────────────────────────
#[test] #[test]
fn hash_token_produces_64_hex_chars() { async fn hash_token_produces_64_hex_chars() {
let hash = hash_token("zc_test_token"); let hash = hash_token("zc_test_token");
assert_eq!(hash.len(), 64); assert_eq!(hash.len(), 64);
assert!(hash.chars().all(|c| c.is_ascii_hexdigit())); assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
} }
#[test] #[test]
fn hash_token_is_deterministic() { async fn hash_token_is_deterministic() {
assert_eq!(hash_token("zc_abc"), hash_token("zc_abc")); assert_eq!(hash_token("zc_abc"), hash_token("zc_abc"));
} }
#[test] #[test]
fn hash_token_differs_for_different_inputs() { async fn hash_token_differs_for_different_inputs() {
assert_ne!(hash_token("zc_a"), hash_token("zc_b")); assert_ne!(hash_token("zc_a"), hash_token("zc_b"));
} }
#[test] #[test]
fn is_token_hash_detects_hash_vs_plaintext() { async fn is_token_hash_detects_hash_vs_plaintext() {
assert!(is_token_hash(&hash_token("zc_test"))); assert!(is_token_hash(&hash_token("zc_test")));
assert!(!is_token_hash("zc_test_token")); assert!(!is_token_hash("zc_test_token"));
assert!(!is_token_hash("too_short")); assert!(!is_token_hash("too_short"));
@ -360,7 +374,7 @@ mod tests {
// ── is_public_bind ─────────────────────────────────────── // ── is_public_bind ───────────────────────────────────────
#[test] #[test]
fn localhost_variants_not_public() { async fn localhost_variants_not_public() {
assert!(!is_public_bind("127.0.0.1")); assert!(!is_public_bind("127.0.0.1"));
assert!(!is_public_bind("localhost")); assert!(!is_public_bind("localhost"));
assert!(!is_public_bind("::1")); assert!(!is_public_bind("::1"));
@ -368,12 +382,12 @@ mod tests {
} }
#[test] #[test]
fn zero_zero_is_public() { async fn zero_zero_is_public() {
assert!(is_public_bind("0.0.0.0")); assert!(is_public_bind("0.0.0.0"));
} }
#[test] #[test]
fn real_ip_is_public() { async fn real_ip_is_public() {
assert!(is_public_bind("192.168.1.100")); assert!(is_public_bind("192.168.1.100"));
assert!(is_public_bind("10.0.0.1")); assert!(is_public_bind("10.0.0.1"));
} }
@ -381,13 +395,13 @@ mod tests {
// ── constant_time_eq ───────────────────────────────────── // ── constant_time_eq ─────────────────────────────────────
#[test] #[test]
fn constant_time_eq_same() { async fn constant_time_eq_same() {
assert!(constant_time_eq("abc", "abc")); assert!(constant_time_eq("abc", "abc"));
assert!(constant_time_eq("", "")); assert!(constant_time_eq("", ""));
} }
#[test] #[test]
fn constant_time_eq_different() { async fn constant_time_eq_different() {
assert!(!constant_time_eq("abc", "abd")); assert!(!constant_time_eq("abc", "abd"));
assert!(!constant_time_eq("abc", "ab")); assert!(!constant_time_eq("abc", "ab"));
assert!(!constant_time_eq("a", "")); assert!(!constant_time_eq("a", ""));
@ -396,14 +410,14 @@ mod tests {
// ── generate helpers ───────────────────────────────────── // ── generate helpers ─────────────────────────────────────
#[test] #[test]
fn generate_code_is_6_digits() { async fn generate_code_is_6_digits() {
let code = generate_code(); let code = generate_code();
assert_eq!(code.len(), 6); assert_eq!(code.len(), 6);
assert!(code.chars().all(|c| c.is_ascii_digit())); assert!(code.chars().all(|c| c.is_ascii_digit()));
} }
#[test] #[test]
fn generate_code_is_not_deterministic() { async fn generate_code_is_not_deterministic() {
// Two codes should differ with overwhelming probability. We try // Two codes should differ with overwhelming probability. We try
// multiple pairs so a single 1-in-10^6 collision doesn't cause // multiple pairs so a single 1-in-10^6 collision doesn't cause
// a flaky CI failure. All 10 pairs colliding is ~1-in-10^60. // a flaky CI failure. All 10 pairs colliding is ~1-in-10^60.
@ -416,7 +430,7 @@ mod tests {
} }
#[test] #[test]
fn generate_token_has_prefix_and_hex_payload() { async fn generate_token_has_prefix_and_hex_payload() {
let token = generate_token(); let token = generate_token();
let payload = token let payload = token
.strip_prefix("zc_") .strip_prefix("zc_")
@ -434,15 +448,15 @@ mod tests {
// ── Brute force protection ─────────────────────────────── // ── Brute force protection ───────────────────────────────
#[test] #[test]
fn brute_force_lockout_after_max_attempts() { async fn brute_force_lockout_after_max_attempts() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
// Exhaust all attempts with wrong codes // Exhaust all attempts with wrong codes
for i in 0..MAX_PAIR_ATTEMPTS { for i in 0..MAX_PAIR_ATTEMPTS {
let result = guard.try_pair(&format!("wrong_{i}")); let result = guard.try_pair(&format!("wrong_{i}")).await;
assert!(result.is_ok(), "Attempt {i} should not be locked out yet"); assert!(result.is_ok(), "Attempt {i} should not be locked out yet");
} }
// Next attempt should be locked out // Next attempt should be locked out
let result = guard.try_pair("another_wrong"); let result = guard.try_pair("another_wrong").await;
assert!( assert!(
result.is_err(), result.is_err(),
"Should be locked out after {MAX_PAIR_ATTEMPTS} attempts" "Should be locked out after {MAX_PAIR_ATTEMPTS} attempts"
@ -456,25 +470,25 @@ mod tests {
} }
#[test] #[test]
fn correct_code_resets_failed_attempts() { async fn correct_code_resets_failed_attempts() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string(); let code = guard.pairing_code().unwrap().to_string();
// Fail a few times // Fail a few times
for _ in 0..3 { for _ in 0..3 {
let _ = guard.try_pair("wrong"); let _ = guard.try_pair("wrong").await;
} }
// Correct code should still work (under MAX_PAIR_ATTEMPTS) // Correct code should still work (under MAX_PAIR_ATTEMPTS)
let result = guard.try_pair(&code).unwrap(); let result = guard.try_pair(&code).await.unwrap();
assert!(result.is_some(), "Correct code should work before lockout"); assert!(result.is_some(), "Correct code should work before lockout");
} }
#[test] #[test]
fn lockout_returns_remaining_seconds() { async fn lockout_returns_remaining_seconds() {
let guard = PairingGuard::new(true, &[]); let guard = PairingGuard::new(true, &[]);
for _ in 0..MAX_PAIR_ATTEMPTS { for _ in 0..MAX_PAIR_ATTEMPTS {
let _ = guard.try_pair("wrong"); let _ = guard.try_pair("wrong").await;
} }
let err = guard.try_pair("wrong").unwrap_err(); let err = guard.try_pair("wrong").await.unwrap_err();
// Should be close to PAIR_LOCKOUT_SECS (within a second) // Should be close to PAIR_LOCKOUT_SECS (within a second)
assert!( assert!(
err >= PAIR_LOCKOUT_SECS - 1, err >= PAIR_LOCKOUT_SECS - 1,

View file

@ -3,6 +3,7 @@ use crate::config::{
runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope, runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope,
}; };
use crate::security::SecurityPolicy; use crate::security::SecurityPolicy;
use crate::util::MaybeSet;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::{json, Value}; use serde_json::{json, Value};
use std::fs; use std::fs;
@ -93,16 +94,13 @@ impl ProxyConfigTool {
anyhow::bail!("'{field}' must be a string or string[]") anyhow::bail!("'{field}' must be a string or string[]")
} }
fn parse_optional_string_update( fn parse_optional_string_update(args: &Value, field: &str) -> anyhow::Result<MaybeSet<String>> {
args: &Value,
field: &str,
) -> anyhow::Result<Option<Option<String>>> {
let Some(raw) = args.get(field) else { let Some(raw) = args.get(field) else {
return Ok(None); return Ok(MaybeSet::Unset);
}; };
if raw.is_null() { if raw.is_null() {
return Ok(Some(None)); return Ok(MaybeSet::Null);
} }
let value = raw let value = raw
@ -110,7 +108,13 @@ impl ProxyConfigTool {
.ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))? .ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))?
.trim() .trim()
.to_string(); .to_string();
Ok(Some((!value.is_empty()).then_some(value)))
let output = if value.is_empty() {
MaybeSet::Null
} else {
MaybeSet::Set(value)
};
Ok(output)
} }
fn env_snapshot() -> Value { fn env_snapshot() -> Value {
@ -164,7 +168,7 @@ impl ProxyConfigTool {
}) })
} }
fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> { async fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> {
let mut cfg = self.load_config_without_env()?; let mut cfg = self.load_config_without_env()?;
let previous_scope = cfg.proxy.scope; let previous_scope = cfg.proxy.scope;
let mut proxy = cfg.proxy.clone(); let mut proxy = cfg.proxy.clone();
@ -185,23 +189,24 @@ impl ProxyConfigTool {
})?; })?;
} }
if let Some(update) = Self::parse_optional_string_update(args, "http_proxy")? { if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "http_proxy")? {
proxy.http_proxy = update; proxy.http_proxy = Some(update);
touched_proxy_url = true; touched_proxy_url = true;
} }
if let Some(update) = Self::parse_optional_string_update(args, "https_proxy")? { if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "https_proxy")? {
proxy.https_proxy = update; proxy.https_proxy = Some(update);
touched_proxy_url = true; touched_proxy_url = true;
} }
if let Some(update) = Self::parse_optional_string_update(args, "all_proxy")? { if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "all_proxy")? {
proxy.all_proxy = update; proxy.all_proxy = Some(update);
touched_proxy_url = true; touched_proxy_url = true;
} }
if let Some(no_proxy_raw) = args.get("no_proxy") { if let Some(no_proxy_raw) = args.get("no_proxy") {
proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?; proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?;
touched_proxy_url = true;
} }
if let Some(services_raw) = args.get("services") { if let Some(services_raw) = args.get("services") {
@ -217,7 +222,7 @@ impl ProxyConfigTool {
proxy.validate()?; proxy.validate()?;
cfg.proxy = proxy.clone(); cfg.proxy = proxy.clone();
cfg.save()?; cfg.save().await?;
set_runtime_proxy_config(proxy.clone()); set_runtime_proxy_config(proxy.clone());
if proxy.enabled && proxy.scope == ProxyScope::Environment { if proxy.enabled && proxy.scope == ProxyScope::Environment {
@ -237,11 +242,11 @@ impl ProxyConfigTool {
}) })
} }
fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> { async fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> {
let mut cfg = self.load_config_without_env()?; let mut cfg = self.load_config_without_env()?;
let clear_env_default = cfg.proxy.scope == ProxyScope::Environment; let clear_env_default = cfg.proxy.scope == ProxyScope::Environment;
cfg.proxy.enabled = false; cfg.proxy.enabled = false;
cfg.save()?; cfg.save().await?;
set_runtime_proxy_config(cfg.proxy.clone()); set_runtime_proxy_config(cfg.proxy.clone());
@ -384,8 +389,8 @@ impl Tool for ProxyConfigTool {
} }
match action.as_str() { match action.as_str() {
"set" => self.handle_set(&args), "set" => self.handle_set(&args).await,
"disable" => self.handle_disable(&args), "disable" => self.handle_disable(&args).await,
"apply_env" => self.handle_apply_env(), "apply_env" => self.handle_apply_env(),
"clear_env" => self.handle_clear_env(), "clear_env" => self.handle_clear_env(),
_ => unreachable!("handled above"), _ => unreachable!("handled above"),
@ -421,20 +426,20 @@ mod tests {
}) })
} }
fn test_config(tmp: &TempDir) -> Arc<Config> { async fn test_config(tmp: &TempDir) -> Arc<Config> {
let config = Config { let config = Config {
workspace_dir: tmp.path().join("workspace"), workspace_dir: tmp.path().join("workspace"),
config_path: tmp.path().join("config.toml"), config_path: tmp.path().join("config.toml"),
..Config::default() ..Config::default()
}; };
config.save().unwrap(); config.save().await.unwrap();
Arc::new(config) Arc::new(config)
} }
#[tokio::test] #[tokio::test]
async fn list_services_action_returns_known_keys() { async fn list_services_action_returns_known_keys() {
let tmp = TempDir::new().unwrap(); let tmp = TempDir::new().unwrap();
let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
let result = tool let result = tool
.execute(json!({"action": "list_services"})) .execute(json!({"action": "list_services"}))
@ -448,7 +453,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn set_scope_services_requires_services_entries() { async fn set_scope_services_requires_services_entries() {
let tmp = TempDir::new().unwrap(); let tmp = TempDir::new().unwrap();
let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
let result = tool let result = tool
.execute(json!({ .execute(json!({
@ -471,7 +476,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn set_and_get_round_trip_proxy_scope() { async fn set_and_get_round_trip_proxy_scope() {
let tmp = TempDir::new().unwrap(); let tmp = TempDir::new().unwrap();
let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
let set_result = tool let set_result = tool
.execute(json!({ .execute(json!({

View file

@ -43,6 +43,13 @@ pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String {
} }
} }
/// Utility enum for handling optional values.
pub enum MaybeSet<T> {
Set(T),
Unset,
Null,
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;