chore: Remove more blocking io calls
This commit is contained in:
parent
1aec9ad9c0
commit
f1ca73d3d2
14 changed files with 427 additions and 357 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
|
@ -385,6 +385,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||
dependencies = [
|
||||
"axum-core",
|
||||
"axum-macros",
|
||||
"base64",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
|
|
@ -431,6 +432,17 @@ dependencies = [
|
|||
"tower-service",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-macros"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.116",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backon"
|
||||
version = "1.6.0"
|
||||
|
|
@ -7734,6 +7746,7 @@ dependencies = [
|
|||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tokio-serial",
|
||||
"tokio-stream",
|
||||
"tokio-tungstenite 0.24.0",
|
||||
"tokio-util",
|
||||
"toml 1.0.1+spec-1.1.0",
|
||||
|
|
|
|||
|
|
@ -120,7 +120,7 @@ mail-parser = "0.11.2"
|
|||
async-imap = { version = "0.11",features = ["runtime-tokio"], default-features = false }
|
||||
|
||||
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
|
||||
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] }
|
||||
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws", "macros"] }
|
||||
tower = { version = "0.5", default-features = false }
|
||||
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
|
||||
http-body-util = "0.1"
|
||||
|
|
@ -141,6 +141,7 @@ probe-rs = { version = "0.30", optional = true }
|
|||
|
||||
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
||||
pdf-extract = { version = "0.10", optional = true }
|
||||
tokio-stream = { version = "0.1.18", features = ["full"] }
|
||||
|
||||
# WhatsApp Web client (wa-rs) — optional, enable with --features whatsapp-web
|
||||
# Uses wa-rs for Bot and Client, wa-rs-core for storage traits, custom rusqlite backend avoids Diesel conflict.
|
||||
|
|
|
|||
|
|
@ -26,7 +26,7 @@ pub fn run_wizard() -> Result<Config> {
|
|||
security: SecurityConfig::autodetect(), // Silent!
|
||||
};
|
||||
|
||||
config.save()?;
|
||||
config.save().await?;
|
||||
Ok(config)
|
||||
}
|
||||
```
|
||||
|
|
|
|||
|
|
@ -726,7 +726,7 @@ mod tests {
|
|||
"!r:m".to_string(),
|
||||
vec![],
|
||||
Some(" ".to_string()),
|
||||
Some("".to_string()),
|
||||
Some(String::new()),
|
||||
);
|
||||
|
||||
assert!(ch.session_owner_hint.is_none());
|
||||
|
|
|
|||
|
|
@ -1117,7 +1117,7 @@ fn normalize_telegram_identity(value: &str) -> String {
|
|||
value.trim().trim_start_matches('@').to_string()
|
||||
}
|
||||
|
||||
fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
||||
async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
||||
let normalized = normalize_telegram_identity(identity);
|
||||
if normalized.is_empty() {
|
||||
anyhow::bail!("Telegram identity cannot be empty");
|
||||
|
|
@ -1147,7 +1147,7 @@ fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
|||
}
|
||||
|
||||
telegram.allowed_users.push(normalized.clone());
|
||||
updated.save()?;
|
||||
updated.save().await?;
|
||||
println!("✅ Bound Telegram identity: {normalized}");
|
||||
println!(" Saved to {}", updated.config_path.display());
|
||||
match maybe_restart_managed_daemon_service() {
|
||||
|
|
@ -1243,7 +1243,7 @@ fn maybe_restart_managed_daemon_service() -> Result<bool> {
|
|||
Ok(false)
|
||||
}
|
||||
|
||||
pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
|
||||
pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
|
||||
match command {
|
||||
crate::ChannelCommands::Start => {
|
||||
anyhow::bail!("Start must be handled in main.rs (requires async runtime)")
|
||||
|
|
@ -1290,7 +1290,7 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul
|
|||
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
|
||||
}
|
||||
crate::ChannelCommands::BindTelegram { identity } => {
|
||||
bind_telegram_identity(config, &identity)
|
||||
bind_telegram_identity(config, &identity).await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -6,10 +6,10 @@ use async_trait::async_trait;
|
|||
use directories::UserDirs;
|
||||
use parking_lot::Mutex;
|
||||
use reqwest::multipart::{Form, Part};
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::sync::{Arc, RwLock};
|
||||
use std::time::Duration;
|
||||
use tokio::fs;
|
||||
|
||||
/// Telegram's maximum message length for text messages
|
||||
const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096;
|
||||
|
|
@ -373,7 +373,7 @@ impl TelegramChannel {
|
|||
.collect()
|
||||
}
|
||||
|
||||
fn load_config_without_env() -> anyhow::Result<Config> {
|
||||
async fn load_config_without_env() -> anyhow::Result<Config> {
|
||||
let home = UserDirs::new()
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.context("Could not find home directory")?;
|
||||
|
|
@ -381,6 +381,7 @@ impl TelegramChannel {
|
|||
let config_path = zeroclaw_dir.join("config.toml");
|
||||
|
||||
let contents = fs::read_to_string(&config_path)
|
||||
.await
|
||||
.with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
|
||||
let mut config: Config = toml::from_str(&contents)
|
||||
.context("Failed to parse config file for Telegram binding")?;
|
||||
|
|
@ -389,8 +390,8 @@ impl TelegramChannel {
|
|||
Ok(config)
|
||||
}
|
||||
|
||||
fn persist_allowed_identity_blocking(identity: &str) -> anyhow::Result<()> {
|
||||
let mut config = Self::load_config_without_env()?;
|
||||
async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
|
||||
let mut config = Self::load_config_without_env().await?;
|
||||
let Some(telegram) = config.channels_config.telegram.as_mut() else {
|
||||
anyhow::bail!("Telegram channel config is missing in config.toml");
|
||||
};
|
||||
|
|
@ -404,20 +405,13 @@ impl TelegramChannel {
|
|||
telegram.allowed_users.push(normalized);
|
||||
config
|
||||
.save()
|
||||
.await
|
||||
.context("Failed to persist Telegram allowlist to config.toml")?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
|
||||
let identity = identity.to_string();
|
||||
tokio::task::spawn_blocking(move || Self::persist_allowed_identity_blocking(&identity))
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("Failed to join Telegram bind save task: {e}"))??;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_allowed_identity_runtime(&self, identity: &str) {
|
||||
let normalized = Self::normalize_identity(identity);
|
||||
if normalized.is_empty() {
|
||||
|
|
@ -629,7 +623,7 @@ impl TelegramChannel {
|
|||
|
||||
if let Some(code) = Self::extract_bind_code(text) {
|
||||
if let Some(pairing) = self.pairing.as_ref() {
|
||||
match pairing.try_pair(code) {
|
||||
match pairing.try_pair(code).await {
|
||||
Ok(Some(_token)) => {
|
||||
let bind_identity = normalized_sender_id.clone().or_else(|| {
|
||||
if normalized_username.is_empty() || normalized_username == "unknown" {
|
||||
|
|
|
|||
File diff suppressed because it is too large
Load diff
|
|
@ -587,6 +587,7 @@ async fn handle_metrics(State(state): State<AppState>) -> impl IntoResponse {
|
|||
}
|
||||
|
||||
/// POST /pair — exchange one-time code for bearer token
|
||||
#[axum::debug_handler]
|
||||
async fn handle_pair(
|
||||
State(state): State<AppState>,
|
||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||
|
|
@ -608,10 +609,10 @@ async fn handle_pair(
|
|||
.and_then(|v| v.to_str().ok())
|
||||
.unwrap_or("");
|
||||
|
||||
match state.pairing.try_pair(code) {
|
||||
match state.pairing.try_pair(code).await {
|
||||
Ok(Some(token)) => {
|
||||
tracing::info!("🔐 New client paired successfully");
|
||||
if let Err(err) = persist_pairing_tokens(&state.config, &state.pairing) {
|
||||
if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await {
|
||||
tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}");
|
||||
let body = serde_json::json!({
|
||||
"paired": true,
|
||||
|
|
@ -648,11 +649,14 @@ async fn handle_pair(
|
|||
}
|
||||
}
|
||||
|
||||
fn persist_pairing_tokens(config: &Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> {
|
||||
async fn persist_pairing_tokens(config: Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> {
|
||||
let paired_tokens = pairing.tokens();
|
||||
let mut cfg = config.lock();
|
||||
// This is needed because parking_lot's guard is not Send so we clone the inner
|
||||
// this should be removed once async mutexes are used everywhere
|
||||
let mut cfg = { config.lock().clone() };
|
||||
cfg.gateway.paired_tokens = paired_tokens;
|
||||
cfg.save()
|
||||
.await
|
||||
.context("Failed to persist paired tokens to config.toml")
|
||||
}
|
||||
|
||||
|
|
@ -1398,15 +1402,15 @@ mod tests {
|
|||
let mut config = Config::default();
|
||||
config.config_path = config_path.clone();
|
||||
config.workspace_dir = workspace_path;
|
||||
config.save().unwrap();
|
||||
config.save().await.unwrap();
|
||||
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let code = guard.pairing_code().unwrap();
|
||||
let token = guard.try_pair(&code).unwrap().unwrap();
|
||||
let token = guard.try_pair(&code).await.unwrap().unwrap();
|
||||
assert!(guard.is_authenticated(&token));
|
||||
|
||||
let shared_config = Arc::new(Mutex::new(config));
|
||||
persist_pairing_tokens(&shared_config, &guard).unwrap();
|
||||
persist_pairing_tokens(shared_config, &guard).await.unwrap();
|
||||
|
||||
let saved = tokio::fs::read_to_string(config_path).await.unwrap();
|
||||
let parsed: Config = toml::from_str(&saved).unwrap();
|
||||
|
|
|
|||
34
src/main.rs
34
src/main.rs
|
|
@ -553,21 +553,19 @@ async fn main() -> Result<()> {
|
|||
{
|
||||
bail!("--channels-only does not accept --api-key, --provider, --model, or --memory");
|
||||
}
|
||||
let config = tokio::task::spawn_blocking(move || {
|
||||
if channels_only {
|
||||
onboard::run_channels_repair_wizard()
|
||||
} else if interactive {
|
||||
onboard::run_wizard()
|
||||
} else {
|
||||
onboard::run_quick_setup(
|
||||
api_key.as_deref(),
|
||||
provider.as_deref(),
|
||||
model.as_deref(),
|
||||
memory.as_deref(),
|
||||
)
|
||||
}
|
||||
})
|
||||
.await??;
|
||||
let config = if channels_only {
|
||||
onboard::run_channels_repair_wizard().await
|
||||
} else if interactive {
|
||||
onboard::run_wizard().await
|
||||
} else {
|
||||
onboard::run_quick_setup(
|
||||
api_key.as_deref(),
|
||||
provider.as_deref(),
|
||||
model.as_deref(),
|
||||
memory.as_deref(),
|
||||
)
|
||||
.await
|
||||
}?;
|
||||
// Auto-start channels if user said yes during wizard
|
||||
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
|
||||
channels::start_channels(config).await?;
|
||||
|
|
@ -576,7 +574,7 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
// All other commands need config loaded first
|
||||
let mut config = Config::load_or_init()?;
|
||||
let mut config = Config::load_or_init().await?;
|
||||
config.apply_env_overrides();
|
||||
|
||||
match cli.command {
|
||||
|
|
@ -764,7 +762,7 @@ async fn main() -> Result<()> {
|
|||
Commands::Channel { channel_command } => match channel_command {
|
||||
ChannelCommands::Start => channels::start_channels(config).await,
|
||||
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
||||
other => channels::handle_command(other, &config),
|
||||
other => channels::handle_command(other, &config).await,
|
||||
},
|
||||
|
||||
Commands::Integrations {
|
||||
|
|
@ -786,7 +784,7 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
Commands::Peripheral { peripheral_command } => {
|
||||
peripherals::handle_command(peripheral_command.clone(), &config)
|
||||
peripherals::handle_command(peripheral_command.clone(), &config).await
|
||||
}
|
||||
|
||||
Commands::Config { config_command } => match config_command {
|
||||
|
|
|
|||
|
|
@ -95,7 +95,7 @@ fn has_launchable_channels(channels: &ChannelsConfig) -> bool {
|
|||
|
||||
// ── Main wizard entry point ──────────────────────────────────────
|
||||
|
||||
pub fn run_wizard() -> Result<Config> {
|
||||
pub async fn run_wizard() -> Result<Config> {
|
||||
println!("{}", style(BANNER).cyan().bold());
|
||||
|
||||
println!(
|
||||
|
|
@ -191,8 +191,8 @@ pub fn run_wizard() -> Result<Config> {
|
|||
if config.memory.auto_save { "on" } else { "off" }
|
||||
);
|
||||
|
||||
config.save()?;
|
||||
persist_workspace_selection(&config.config_path)?;
|
||||
config.save().await?;
|
||||
persist_workspace_selection(&config.config_path).await?;
|
||||
|
||||
// ── Final summary ────────────────────────────────────────────
|
||||
print_summary(&config);
|
||||
|
|
@ -226,7 +226,7 @@ pub fn run_wizard() -> Result<Config> {
|
|||
}
|
||||
|
||||
/// Interactive repair flow: rerun channel setup only without redoing full onboarding.
|
||||
pub fn run_channels_repair_wizard() -> Result<Config> {
|
||||
pub async fn run_channels_repair_wizard() -> Result<Config> {
|
||||
println!("{}", style(BANNER).cyan().bold());
|
||||
println!(
|
||||
" {}",
|
||||
|
|
@ -236,12 +236,12 @@ pub fn run_channels_repair_wizard() -> Result<Config> {
|
|||
);
|
||||
println!();
|
||||
|
||||
let mut config = Config::load_or_init()?;
|
||||
let mut config = Config::load_or_init().await?;
|
||||
|
||||
print_step(1, 1, "Channels (How You Talk to ZeroClaw)");
|
||||
config.channels_config = setup_channels()?;
|
||||
config.save()?;
|
||||
persist_workspace_selection(&config.config_path)?;
|
||||
config.save().await?;
|
||||
persist_workspace_selection(&config.config_path).await?;
|
||||
|
||||
println!();
|
||||
println!(
|
||||
|
|
@ -321,7 +321,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
|||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn run_quick_setup(
|
||||
pub async fn run_quick_setup(
|
||||
credential_override: Option<&str>,
|
||||
provider: Option<&str>,
|
||||
model_override: Option<&str>,
|
||||
|
|
@ -396,8 +396,8 @@ pub fn run_quick_setup(
|
|||
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||
};
|
||||
|
||||
config.save()?;
|
||||
persist_workspace_selection(&config.config_path)?;
|
||||
config.save().await?;
|
||||
persist_workspace_selection(&config.config_path).await?;
|
||||
|
||||
// Scaffold minimal workspace files
|
||||
let default_ctx = ProjectContext {
|
||||
|
|
@ -1459,16 +1459,18 @@ fn print_bullet(text: &str) {
|
|||
println!(" {} {}", style("›").cyan(), text);
|
||||
}
|
||||
|
||||
fn persist_workspace_selection(config_path: &Path) -> Result<()> {
|
||||
async fn persist_workspace_selection(config_path: &Path) -> Result<()> {
|
||||
let config_dir = config_path
|
||||
.parent()
|
||||
.context("Config path must have a parent directory")?;
|
||||
crate::config::schema::persist_active_workspace_config_dir(config_dir).with_context(|| {
|
||||
format!(
|
||||
"Failed to persist active workspace selection for {}",
|
||||
config_dir.display()
|
||||
)
|
||||
})
|
||||
crate::config::schema::persist_active_workspace_config_dir(config_dir)
|
||||
.await
|
||||
.with_context(|| {
|
||||
format!(
|
||||
"Failed to persist active workspace selection for {}",
|
||||
config_dir.display()
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
// ── Step 1: Workspace ────────────────────────────────────────────
|
||||
|
|
|
|||
|
|
@ -42,7 +42,7 @@ pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoar
|
|||
|
||||
/// Handle `zeroclaw peripheral` subcommands.
|
||||
#[allow(clippy::module_name_repetitions)]
|
||||
pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> {
|
||||
pub async fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> {
|
||||
match cmd {
|
||||
crate::PeripheralCommands::List => {
|
||||
let boards = list_configured_boards(&config.peripherals);
|
||||
|
|
@ -76,7 +76,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
|
|||
Some(path.clone())
|
||||
};
|
||||
|
||||
let mut cfg = crate::config::Config::load_or_init()?;
|
||||
let mut cfg = crate::config::Config::load_or_init().await?;
|
||||
cfg.peripherals.enabled = true;
|
||||
|
||||
if cfg
|
||||
|
|
@ -95,7 +95,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
|
|||
path: path_opt,
|
||||
baud: 115_200,
|
||||
});
|
||||
cfg.save()?;
|
||||
cfg.save().await?;
|
||||
println!("Added {} at {}. Restart daemon to apply.", board, path);
|
||||
}
|
||||
#[cfg(feature = "hardware")]
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@
|
|||
use parking_lot::Mutex;
|
||||
use sha2::{Digest, Sha256};
|
||||
use std::collections::HashSet;
|
||||
use std::sync::Arc;
|
||||
use std::time::Instant;
|
||||
|
||||
/// Maximum failed pairing attempts before lockout.
|
||||
|
|
@ -23,16 +24,17 @@ const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes
|
|||
/// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure
|
||||
/// in config files. When a new token is generated, the plaintext is returned
|
||||
/// to the client once, and only the hash is retained.
|
||||
#[derive(Debug)]
|
||||
// TODO: I've just made this work with parking_lot but it should use either flume or tokio's async mutexes
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PairingGuard {
|
||||
/// Whether pairing is required at all.
|
||||
require_pairing: bool,
|
||||
/// One-time pairing code (generated on startup, consumed on first pair).
|
||||
pairing_code: Mutex<Option<String>>,
|
||||
pairing_code: Arc<Mutex<Option<String>>>,
|
||||
/// Set of SHA-256 hashed bearer tokens (persisted across restarts).
|
||||
paired_tokens: Mutex<HashSet<String>>,
|
||||
paired_tokens: Arc<Mutex<HashSet<String>>>,
|
||||
/// Brute-force protection: failed attempt counter + lockout time.
|
||||
failed_attempts: Mutex<(u32, Option<Instant>)>,
|
||||
failed_attempts: Arc<Mutex<(u32, Option<Instant>)>>,
|
||||
}
|
||||
|
||||
impl PairingGuard {
|
||||
|
|
@ -62,9 +64,9 @@ impl PairingGuard {
|
|||
};
|
||||
Self {
|
||||
require_pairing,
|
||||
pairing_code: Mutex::new(code),
|
||||
paired_tokens: Mutex::new(tokens),
|
||||
failed_attempts: Mutex::new((0, None)),
|
||||
pairing_code: Arc::new(Mutex::new(code)),
|
||||
paired_tokens: Arc::new(Mutex::new(tokens)),
|
||||
failed_attempts: Arc::new(Mutex::new((0, None))),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -78,9 +80,7 @@ impl PairingGuard {
|
|||
self.require_pairing
|
||||
}
|
||||
|
||||
/// Attempt to pair with the given code. Returns a bearer token on success.
|
||||
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
|
||||
pub fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
|
||||
fn try_pair_blocking(&self, code: &str) -> Result<Option<String>, u64> {
|
||||
// Check brute force lockout
|
||||
{
|
||||
let attempts = self.failed_attempts.lock();
|
||||
|
|
@ -127,6 +127,19 @@ impl PairingGuard {
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
/// Attempt to pair with the given code. Returns a bearer token on success.
|
||||
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
|
||||
pub async fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
|
||||
let this = self.clone();
|
||||
let code = code.to_string();
|
||||
// TODO: make this function the main one without spawning a task
|
||||
let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code));
|
||||
|
||||
handle
|
||||
.await
|
||||
.expect("failed to spawn blocking task this should not happen")
|
||||
}
|
||||
|
||||
/// Check if a bearer token is valid (compares against stored hashes).
|
||||
pub fn is_authenticated(&self, token: &str) -> bool {
|
||||
if !self.require_pairing {
|
||||
|
|
@ -232,63 +245,64 @@ pub fn is_public_bind(host: &str) -> bool {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use tokio::test;
|
||||
|
||||
// ── PairingGuard ─────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn new_guard_generates_code_when_no_tokens() {
|
||||
async fn new_guard_generates_code_when_no_tokens() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
assert!(guard.pairing_code().is_some());
|
||||
assert!(!guard.is_paired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_guard_no_code_when_tokens_exist() {
|
||||
async fn new_guard_no_code_when_tokens_exist() {
|
||||
let guard = PairingGuard::new(true, &["zc_existing".into()]);
|
||||
assert!(guard.pairing_code().is_none());
|
||||
assert!(guard.is_paired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_guard_no_code_when_pairing_disabled() {
|
||||
async fn new_guard_no_code_when_pairing_disabled() {
|
||||
let guard = PairingGuard::new(false, &[]);
|
||||
assert!(guard.pairing_code().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_pair_correct_code() {
|
||||
async fn try_pair_correct_code() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let code = guard.pairing_code().unwrap().to_string();
|
||||
let token = guard.try_pair(&code).unwrap();
|
||||
let token = guard.try_pair(&code).await.unwrap();
|
||||
assert!(token.is_some());
|
||||
assert!(token.unwrap().starts_with("zc_"));
|
||||
assert!(guard.is_paired());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_pair_wrong_code() {
|
||||
async fn try_pair_wrong_code() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let result = guard.try_pair("000000").unwrap();
|
||||
let result = guard.try_pair("000000").await.unwrap();
|
||||
// Might succeed if code happens to be 000000, but extremely unlikely
|
||||
// Just check it returns Ok(None) normally
|
||||
let _ = result;
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn try_pair_empty_code() {
|
||||
async fn try_pair_empty_code() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
assert!(guard.try_pair("").unwrap().is_none());
|
||||
assert!(guard.try_pair("").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_with_valid_token() {
|
||||
async fn is_authenticated_with_valid_token() {
|
||||
// Pass plaintext token — PairingGuard hashes it on load
|
||||
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
||||
assert!(guard.is_authenticated("zc_valid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_with_prehashed_token() {
|
||||
async fn is_authenticated_with_prehashed_token() {
|
||||
// Pass an already-hashed token (64 hex chars)
|
||||
let hashed = hash_token("zc_valid");
|
||||
let guard = PairingGuard::new(true, &[hashed]);
|
||||
|
|
@ -296,20 +310,20 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_with_invalid_token() {
|
||||
async fn is_authenticated_with_invalid_token() {
|
||||
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
||||
assert!(!guard.is_authenticated("zc_invalid"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_authenticated_when_pairing_disabled() {
|
||||
async fn is_authenticated_when_pairing_disabled() {
|
||||
let guard = PairingGuard::new(false, &[]);
|
||||
assert!(guard.is_authenticated("anything"));
|
||||
assert!(guard.is_authenticated(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn tokens_returns_hashes() {
|
||||
async fn tokens_returns_hashes() {
|
||||
let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]);
|
||||
let tokens = guard.tokens();
|
||||
assert_eq!(tokens.len(), 2);
|
||||
|
|
@ -322,10 +336,10 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn pair_then_authenticate() {
|
||||
async fn pair_then_authenticate() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let code = guard.pairing_code().unwrap().to_string();
|
||||
let token = guard.try_pair(&code).unwrap().unwrap();
|
||||
let token = guard.try_pair(&code).await.unwrap().unwrap();
|
||||
assert!(guard.is_authenticated(&token));
|
||||
assert!(!guard.is_authenticated("wrong"));
|
||||
}
|
||||
|
|
@ -333,24 +347,24 @@ mod tests {
|
|||
// ── Token hashing ────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn hash_token_produces_64_hex_chars() {
|
||||
async fn hash_token_produces_64_hex_chars() {
|
||||
let hash = hash_token("zc_test_token");
|
||||
assert_eq!(hash.len(), 64);
|
||||
assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hash_token_is_deterministic() {
|
||||
async fn hash_token_is_deterministic() {
|
||||
assert_eq!(hash_token("zc_abc"), hash_token("zc_abc"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn hash_token_differs_for_different_inputs() {
|
||||
async fn hash_token_differs_for_different_inputs() {
|
||||
assert_ne!(hash_token("zc_a"), hash_token("zc_b"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_token_hash_detects_hash_vs_plaintext() {
|
||||
async fn is_token_hash_detects_hash_vs_plaintext() {
|
||||
assert!(is_token_hash(&hash_token("zc_test")));
|
||||
assert!(!is_token_hash("zc_test_token"));
|
||||
assert!(!is_token_hash("too_short"));
|
||||
|
|
@ -360,7 +374,7 @@ mod tests {
|
|||
// ── is_public_bind ───────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn localhost_variants_not_public() {
|
||||
async fn localhost_variants_not_public() {
|
||||
assert!(!is_public_bind("127.0.0.1"));
|
||||
assert!(!is_public_bind("localhost"));
|
||||
assert!(!is_public_bind("::1"));
|
||||
|
|
@ -368,12 +382,12 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn zero_zero_is_public() {
|
||||
async fn zero_zero_is_public() {
|
||||
assert!(is_public_bind("0.0.0.0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn real_ip_is_public() {
|
||||
async fn real_ip_is_public() {
|
||||
assert!(is_public_bind("192.168.1.100"));
|
||||
assert!(is_public_bind("10.0.0.1"));
|
||||
}
|
||||
|
|
@ -381,13 +395,13 @@ mod tests {
|
|||
// ── constant_time_eq ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn constant_time_eq_same() {
|
||||
async fn constant_time_eq_same() {
|
||||
assert!(constant_time_eq("abc", "abc"));
|
||||
assert!(constant_time_eq("", ""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn constant_time_eq_different() {
|
||||
async fn constant_time_eq_different() {
|
||||
assert!(!constant_time_eq("abc", "abd"));
|
||||
assert!(!constant_time_eq("abc", "ab"));
|
||||
assert!(!constant_time_eq("a", ""));
|
||||
|
|
@ -396,14 +410,14 @@ mod tests {
|
|||
// ── generate helpers ─────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn generate_code_is_6_digits() {
|
||||
async fn generate_code_is_6_digits() {
|
||||
let code = generate_code();
|
||||
assert_eq!(code.len(), 6);
|
||||
assert!(code.chars().all(|c| c.is_ascii_digit()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn generate_code_is_not_deterministic() {
|
||||
async fn generate_code_is_not_deterministic() {
|
||||
// Two codes should differ with overwhelming probability. We try
|
||||
// multiple pairs so a single 1-in-10^6 collision doesn't cause
|
||||
// a flaky CI failure. All 10 pairs colliding is ~1-in-10^60.
|
||||
|
|
@ -416,7 +430,7 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn generate_token_has_prefix_and_hex_payload() {
|
||||
async fn generate_token_has_prefix_and_hex_payload() {
|
||||
let token = generate_token();
|
||||
let payload = token
|
||||
.strip_prefix("zc_")
|
||||
|
|
@ -434,15 +448,15 @@ mod tests {
|
|||
// ── Brute force protection ───────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn brute_force_lockout_after_max_attempts() {
|
||||
async fn brute_force_lockout_after_max_attempts() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
// Exhaust all attempts with wrong codes
|
||||
for i in 0..MAX_PAIR_ATTEMPTS {
|
||||
let result = guard.try_pair(&format!("wrong_{i}"));
|
||||
let result = guard.try_pair(&format!("wrong_{i}")).await;
|
||||
assert!(result.is_ok(), "Attempt {i} should not be locked out yet");
|
||||
}
|
||||
// Next attempt should be locked out
|
||||
let result = guard.try_pair("another_wrong");
|
||||
let result = guard.try_pair("another_wrong").await;
|
||||
assert!(
|
||||
result.is_err(),
|
||||
"Should be locked out after {MAX_PAIR_ATTEMPTS} attempts"
|
||||
|
|
@ -456,25 +470,25 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn correct_code_resets_failed_attempts() {
|
||||
async fn correct_code_resets_failed_attempts() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
let code = guard.pairing_code().unwrap().to_string();
|
||||
// Fail a few times
|
||||
for _ in 0..3 {
|
||||
let _ = guard.try_pair("wrong");
|
||||
let _ = guard.try_pair("wrong").await;
|
||||
}
|
||||
// Correct code should still work (under MAX_PAIR_ATTEMPTS)
|
||||
let result = guard.try_pair(&code).unwrap();
|
||||
let result = guard.try_pair(&code).await.unwrap();
|
||||
assert!(result.is_some(), "Correct code should work before lockout");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lockout_returns_remaining_seconds() {
|
||||
async fn lockout_returns_remaining_seconds() {
|
||||
let guard = PairingGuard::new(true, &[]);
|
||||
for _ in 0..MAX_PAIR_ATTEMPTS {
|
||||
let _ = guard.try_pair("wrong");
|
||||
let _ = guard.try_pair("wrong").await;
|
||||
}
|
||||
let err = guard.try_pair("wrong").unwrap_err();
|
||||
let err = guard.try_pair("wrong").await.unwrap_err();
|
||||
// Should be close to PAIR_LOCKOUT_SECS (within a second)
|
||||
assert!(
|
||||
err >= PAIR_LOCKOUT_SECS - 1,
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ use crate::config::{
|
|||
runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope,
|
||||
};
|
||||
use crate::security::SecurityPolicy;
|
||||
use crate::util::MaybeSet;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::{json, Value};
|
||||
use std::fs;
|
||||
|
|
@ -93,16 +94,13 @@ impl ProxyConfigTool {
|
|||
anyhow::bail!("'{field}' must be a string or string[]")
|
||||
}
|
||||
|
||||
fn parse_optional_string_update(
|
||||
args: &Value,
|
||||
field: &str,
|
||||
) -> anyhow::Result<Option<Option<String>>> {
|
||||
fn parse_optional_string_update(args: &Value, field: &str) -> anyhow::Result<MaybeSet<String>> {
|
||||
let Some(raw) = args.get(field) else {
|
||||
return Ok(None);
|
||||
return Ok(MaybeSet::Unset);
|
||||
};
|
||||
|
||||
if raw.is_null() {
|
||||
return Ok(Some(None));
|
||||
return Ok(MaybeSet::Null);
|
||||
}
|
||||
|
||||
let value = raw
|
||||
|
|
@ -110,7 +108,13 @@ impl ProxyConfigTool {
|
|||
.ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))?
|
||||
.trim()
|
||||
.to_string();
|
||||
Ok(Some((!value.is_empty()).then_some(value)))
|
||||
|
||||
let output = if value.is_empty() {
|
||||
MaybeSet::Null
|
||||
} else {
|
||||
MaybeSet::Set(value)
|
||||
};
|
||||
Ok(output)
|
||||
}
|
||||
|
||||
fn env_snapshot() -> Value {
|
||||
|
|
@ -164,7 +168,7 @@ impl ProxyConfigTool {
|
|||
})
|
||||
}
|
||||
|
||||
fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
||||
async fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
||||
let mut cfg = self.load_config_without_env()?;
|
||||
let previous_scope = cfg.proxy.scope;
|
||||
let mut proxy = cfg.proxy.clone();
|
||||
|
|
@ -185,23 +189,24 @@ impl ProxyConfigTool {
|
|||
})?;
|
||||
}
|
||||
|
||||
if let Some(update) = Self::parse_optional_string_update(args, "http_proxy")? {
|
||||
proxy.http_proxy = update;
|
||||
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "http_proxy")? {
|
||||
proxy.http_proxy = Some(update);
|
||||
touched_proxy_url = true;
|
||||
}
|
||||
|
||||
if let Some(update) = Self::parse_optional_string_update(args, "https_proxy")? {
|
||||
proxy.https_proxy = update;
|
||||
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "https_proxy")? {
|
||||
proxy.https_proxy = Some(update);
|
||||
touched_proxy_url = true;
|
||||
}
|
||||
|
||||
if let Some(update) = Self::parse_optional_string_update(args, "all_proxy")? {
|
||||
proxy.all_proxy = update;
|
||||
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "all_proxy")? {
|
||||
proxy.all_proxy = Some(update);
|
||||
touched_proxy_url = true;
|
||||
}
|
||||
|
||||
if let Some(no_proxy_raw) = args.get("no_proxy") {
|
||||
proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?;
|
||||
touched_proxy_url = true;
|
||||
}
|
||||
|
||||
if let Some(services_raw) = args.get("services") {
|
||||
|
|
@ -217,7 +222,7 @@ impl ProxyConfigTool {
|
|||
proxy.validate()?;
|
||||
|
||||
cfg.proxy = proxy.clone();
|
||||
cfg.save()?;
|
||||
cfg.save().await?;
|
||||
set_runtime_proxy_config(proxy.clone());
|
||||
|
||||
if proxy.enabled && proxy.scope == ProxyScope::Environment {
|
||||
|
|
@ -237,11 +242,11 @@ impl ProxyConfigTool {
|
|||
})
|
||||
}
|
||||
|
||||
fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
||||
async fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
||||
let mut cfg = self.load_config_without_env()?;
|
||||
let clear_env_default = cfg.proxy.scope == ProxyScope::Environment;
|
||||
cfg.proxy.enabled = false;
|
||||
cfg.save()?;
|
||||
cfg.save().await?;
|
||||
|
||||
set_runtime_proxy_config(cfg.proxy.clone());
|
||||
|
||||
|
|
@ -384,8 +389,8 @@ impl Tool for ProxyConfigTool {
|
|||
}
|
||||
|
||||
match action.as_str() {
|
||||
"set" => self.handle_set(&args),
|
||||
"disable" => self.handle_disable(&args),
|
||||
"set" => self.handle_set(&args).await,
|
||||
"disable" => self.handle_disable(&args).await,
|
||||
"apply_env" => self.handle_apply_env(),
|
||||
"clear_env" => self.handle_clear_env(),
|
||||
_ => unreachable!("handled above"),
|
||||
|
|
@ -421,20 +426,20 @@ mod tests {
|
|||
})
|
||||
}
|
||||
|
||||
fn test_config(tmp: &TempDir) -> Arc<Config> {
|
||||
async fn test_config(tmp: &TempDir) -> Arc<Config> {
|
||||
let config = Config {
|
||||
workspace_dir: tmp.path().join("workspace"),
|
||||
config_path: tmp.path().join("config.toml"),
|
||||
..Config::default()
|
||||
};
|
||||
config.save().unwrap();
|
||||
config.save().await.unwrap();
|
||||
Arc::new(config)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_services_action_returns_known_keys() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
|
||||
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
|
||||
|
||||
let result = tool
|
||||
.execute(json!({"action": "list_services"}))
|
||||
|
|
@ -448,7 +453,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn set_scope_services_requires_services_entries() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
|
||||
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
|
||||
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
|
|
@ -471,7 +476,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn set_and_get_round_trip_proxy_scope() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
|
||||
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
|
||||
|
||||
let set_result = tool
|
||||
.execute(json!({
|
||||
|
|
|
|||
|
|
@ -43,6 +43,13 @@ pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String {
|
|||
}
|
||||
}
|
||||
|
||||
/// Utility enum for handling optional values.
|
||||
pub enum MaybeSet<T> {
|
||||
Set(T),
|
||||
Unset,
|
||||
Null,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue