chore: Remove more blocking io calls

This commit is contained in:
Jayson Reis 2026-02-19 06:30:43 +00:00 committed by Chummy
parent 1aec9ad9c0
commit f1ca73d3d2
14 changed files with 427 additions and 357 deletions

View file

@ -726,7 +726,7 @@ mod tests {
"!r:m".to_string(),
vec![],
Some(" ".to_string()),
Some("".to_string()),
Some(String::new()),
);
assert!(ch.session_owner_hint.is_none());

View file

@ -1117,7 +1117,7 @@ fn normalize_telegram_identity(value: &str) -> String {
value.trim().trim_start_matches('@').to_string()
}
fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
let normalized = normalize_telegram_identity(identity);
if normalized.is_empty() {
anyhow::bail!("Telegram identity cannot be empty");
@ -1147,7 +1147,7 @@ fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
}
telegram.allowed_users.push(normalized.clone());
updated.save()?;
updated.save().await?;
println!("✅ Bound Telegram identity: {normalized}");
println!(" Saved to {}", updated.config_path.display());
match maybe_restart_managed_daemon_service() {
@ -1243,7 +1243,7 @@ fn maybe_restart_managed_daemon_service() -> Result<bool> {
Ok(false)
}
pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
match command {
crate::ChannelCommands::Start => {
anyhow::bail!("Start must be handled in main.rs (requires async runtime)")
@ -1290,7 +1290,7 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
}
crate::ChannelCommands::BindTelegram { identity } => {
bind_telegram_identity(config, &identity)
bind_telegram_identity(config, &identity).await
}
}
}

View file

@ -6,10 +6,10 @@ use async_trait::async_trait;
use directories::UserDirs;
use parking_lot::Mutex;
use reqwest::multipart::{Form, Part};
use std::fs;
use std::path::Path;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use tokio::fs;
/// Telegram's maximum message length for text messages
const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096;
@ -373,7 +373,7 @@ impl TelegramChannel {
.collect()
}
fn load_config_without_env() -> anyhow::Result<Config> {
async fn load_config_without_env() -> anyhow::Result<Config> {
let home = UserDirs::new()
.map(|u| u.home_dir().to_path_buf())
.context("Could not find home directory")?;
@ -381,6 +381,7 @@ impl TelegramChannel {
let config_path = zeroclaw_dir.join("config.toml");
let contents = fs::read_to_string(&config_path)
.await
.with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
let mut config: Config = toml::from_str(&contents)
.context("Failed to parse config file for Telegram binding")?;
@ -389,8 +390,8 @@ impl TelegramChannel {
Ok(config)
}
fn persist_allowed_identity_blocking(identity: &str) -> anyhow::Result<()> {
let mut config = Self::load_config_without_env()?;
async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
let mut config = Self::load_config_without_env().await?;
let Some(telegram) = config.channels_config.telegram.as_mut() else {
anyhow::bail!("Telegram channel config is missing in config.toml");
};
@ -404,20 +405,13 @@ impl TelegramChannel {
telegram.allowed_users.push(normalized);
config
.save()
.await
.context("Failed to persist Telegram allowlist to config.toml")?;
}
Ok(())
}
async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
let identity = identity.to_string();
tokio::task::spawn_blocking(move || Self::persist_allowed_identity_blocking(&identity))
.await
.map_err(|e| anyhow::anyhow!("Failed to join Telegram bind save task: {e}"))??;
Ok(())
}
fn add_allowed_identity_runtime(&self, identity: &str) {
let normalized = Self::normalize_identity(identity);
if normalized.is_empty() {
@ -629,7 +623,7 @@ impl TelegramChannel {
if let Some(code) = Self::extract_bind_code(text) {
if let Some(pairing) = self.pairing.as_ref() {
match pairing.try_pair(code) {
match pairing.try_pair(code).await {
Ok(Some(_token)) => {
let bind_identity = normalized_sender_id.clone().or_else(|| {
if normalized_username.is_empty() || normalized_username == "unknown" {

File diff suppressed because it is too large Load diff

View file

@ -587,6 +587,7 @@ async fn handle_metrics(State(state): State<AppState>) -> impl IntoResponse {
}
/// POST /pair — exchange one-time code for bearer token
#[axum::debug_handler]
async fn handle_pair(
State(state): State<AppState>,
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
@ -608,10 +609,10 @@ async fn handle_pair(
.and_then(|v| v.to_str().ok())
.unwrap_or("");
match state.pairing.try_pair(code) {
match state.pairing.try_pair(code).await {
Ok(Some(token)) => {
tracing::info!("🔐 New client paired successfully");
if let Err(err) = persist_pairing_tokens(&state.config, &state.pairing) {
if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await {
tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}");
let body = serde_json::json!({
"paired": true,
@ -648,11 +649,14 @@ async fn handle_pair(
}
}
fn persist_pairing_tokens(config: &Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> {
async fn persist_pairing_tokens(config: Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> {
let paired_tokens = pairing.tokens();
let mut cfg = config.lock();
// This is needed because parking_lot's guard is not Send so we clone the inner
// this should be removed once async mutexes are used everywhere
let mut cfg = { config.lock().clone() };
cfg.gateway.paired_tokens = paired_tokens;
cfg.save()
.await
.context("Failed to persist paired tokens to config.toml")
}
@ -1398,15 +1402,15 @@ mod tests {
let mut config = Config::default();
config.config_path = config_path.clone();
config.workspace_dir = workspace_path;
config.save().unwrap();
config.save().await.unwrap();
let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap();
let token = guard.try_pair(&code).unwrap().unwrap();
let token = guard.try_pair(&code).await.unwrap().unwrap();
assert!(guard.is_authenticated(&token));
let shared_config = Arc::new(Mutex::new(config));
persist_pairing_tokens(&shared_config, &guard).unwrap();
persist_pairing_tokens(shared_config, &guard).await.unwrap();
let saved = tokio::fs::read_to_string(config_path).await.unwrap();
let parsed: Config = toml::from_str(&saved).unwrap();

View file

@ -553,21 +553,19 @@ async fn main() -> Result<()> {
{
bail!("--channels-only does not accept --api-key, --provider, --model, or --memory");
}
let config = tokio::task::spawn_blocking(move || {
if channels_only {
onboard::run_channels_repair_wizard()
} else if interactive {
onboard::run_wizard()
} else {
onboard::run_quick_setup(
api_key.as_deref(),
provider.as_deref(),
model.as_deref(),
memory.as_deref(),
)
}
})
.await??;
let config = if channels_only {
onboard::run_channels_repair_wizard().await
} else if interactive {
onboard::run_wizard().await
} else {
onboard::run_quick_setup(
api_key.as_deref(),
provider.as_deref(),
model.as_deref(),
memory.as_deref(),
)
.await
}?;
// Auto-start channels if user said yes during wizard
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
channels::start_channels(config).await?;
@ -576,7 +574,7 @@ async fn main() -> Result<()> {
}
// All other commands need config loaded first
let mut config = Config::load_or_init()?;
let mut config = Config::load_or_init().await?;
config.apply_env_overrides();
match cli.command {
@ -764,7 +762,7 @@ async fn main() -> Result<()> {
Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await,
ChannelCommands::Doctor => channels::doctor_channels(config).await,
other => channels::handle_command(other, &config),
other => channels::handle_command(other, &config).await,
},
Commands::Integrations {
@ -786,7 +784,7 @@ async fn main() -> Result<()> {
}
Commands::Peripheral { peripheral_command } => {
peripherals::handle_command(peripheral_command.clone(), &config)
peripherals::handle_command(peripheral_command.clone(), &config).await
}
Commands::Config { config_command } => match config_command {

View file

@ -95,7 +95,7 @@ fn has_launchable_channels(channels: &ChannelsConfig) -> bool {
// ── Main wizard entry point ──────────────────────────────────────
pub fn run_wizard() -> Result<Config> {
pub async fn run_wizard() -> Result<Config> {
println!("{}", style(BANNER).cyan().bold());
println!(
@ -191,8 +191,8 @@ pub fn run_wizard() -> Result<Config> {
if config.memory.auto_save { "on" } else { "off" }
);
config.save()?;
persist_workspace_selection(&config.config_path)?;
config.save().await?;
persist_workspace_selection(&config.config_path).await?;
// ── Final summary ────────────────────────────────────────────
print_summary(&config);
@ -226,7 +226,7 @@ pub fn run_wizard() -> Result<Config> {
}
/// Interactive repair flow: rerun channel setup only without redoing full onboarding.
pub fn run_channels_repair_wizard() -> Result<Config> {
pub async fn run_channels_repair_wizard() -> Result<Config> {
println!("{}", style(BANNER).cyan().bold());
println!(
" {}",
@ -236,12 +236,12 @@ pub fn run_channels_repair_wizard() -> Result<Config> {
);
println!();
let mut config = Config::load_or_init()?;
let mut config = Config::load_or_init().await?;
print_step(1, 1, "Channels (How You Talk to ZeroClaw)");
config.channels_config = setup_channels()?;
config.save()?;
persist_workspace_selection(&config.config_path)?;
config.save().await?;
persist_workspace_selection(&config.config_path).await?;
println!();
println!(
@ -321,7 +321,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
}
#[allow(clippy::too_many_lines)]
pub fn run_quick_setup(
pub async fn run_quick_setup(
credential_override: Option<&str>,
provider: Option<&str>,
model_override: Option<&str>,
@ -396,8 +396,8 @@ pub fn run_quick_setup(
query_classification: crate::config::QueryClassificationConfig::default(),
};
config.save()?;
persist_workspace_selection(&config.config_path)?;
config.save().await?;
persist_workspace_selection(&config.config_path).await?;
// Scaffold minimal workspace files
let default_ctx = ProjectContext {
@ -1459,16 +1459,18 @@ fn print_bullet(text: &str) {
println!(" {} {}", style("").cyan(), text);
}
fn persist_workspace_selection(config_path: &Path) -> Result<()> {
async fn persist_workspace_selection(config_path: &Path) -> Result<()> {
let config_dir = config_path
.parent()
.context("Config path must have a parent directory")?;
crate::config::schema::persist_active_workspace_config_dir(config_dir).with_context(|| {
format!(
"Failed to persist active workspace selection for {}",
config_dir.display()
)
})
crate::config::schema::persist_active_workspace_config_dir(config_dir)
.await
.with_context(|| {
format!(
"Failed to persist active workspace selection for {}",
config_dir.display()
)
})
}
// ── Step 1: Workspace ────────────────────────────────────────────

View file

@ -42,7 +42,7 @@ pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoar
/// Handle `zeroclaw peripheral` subcommands.
#[allow(clippy::module_name_repetitions)]
pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> {
pub async fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> {
match cmd {
crate::PeripheralCommands::List => {
let boards = list_configured_boards(&config.peripherals);
@ -76,7 +76,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
Some(path.clone())
};
let mut cfg = crate::config::Config::load_or_init()?;
let mut cfg = crate::config::Config::load_or_init().await?;
cfg.peripherals.enabled = true;
if cfg
@ -95,7 +95,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
path: path_opt,
baud: 115_200,
});
cfg.save()?;
cfg.save().await?;
println!("Added {} at {}. Restart daemon to apply.", board, path);
}
#[cfg(feature = "hardware")]

View file

@ -11,6 +11,7 @@
use parking_lot::Mutex;
use sha2::{Digest, Sha256};
use std::collections::HashSet;
use std::sync::Arc;
use std::time::Instant;
/// Maximum failed pairing attempts before lockout.
@ -23,16 +24,17 @@ const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes
/// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure
/// in config files. When a new token is generated, the plaintext is returned
/// to the client once, and only the hash is retained.
#[derive(Debug)]
// TODO: I've just made this work with parking_lot but it should use either flume or tokio's async mutexes
#[derive(Debug, Clone)]
pub struct PairingGuard {
/// Whether pairing is required at all.
require_pairing: bool,
/// One-time pairing code (generated on startup, consumed on first pair).
pairing_code: Mutex<Option<String>>,
pairing_code: Arc<Mutex<Option<String>>>,
/// Set of SHA-256 hashed bearer tokens (persisted across restarts).
paired_tokens: Mutex<HashSet<String>>,
paired_tokens: Arc<Mutex<HashSet<String>>>,
/// Brute-force protection: failed attempt counter + lockout time.
failed_attempts: Mutex<(u32, Option<Instant>)>,
failed_attempts: Arc<Mutex<(u32, Option<Instant>)>>,
}
impl PairingGuard {
@ -62,9 +64,9 @@ impl PairingGuard {
};
Self {
require_pairing,
pairing_code: Mutex::new(code),
paired_tokens: Mutex::new(tokens),
failed_attempts: Mutex::new((0, None)),
pairing_code: Arc::new(Mutex::new(code)),
paired_tokens: Arc::new(Mutex::new(tokens)),
failed_attempts: Arc::new(Mutex::new((0, None))),
}
}
@ -78,9 +80,7 @@ impl PairingGuard {
self.require_pairing
}
/// Attempt to pair with the given code. Returns a bearer token on success.
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
pub fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
fn try_pair_blocking(&self, code: &str) -> Result<Option<String>, u64> {
// Check brute force lockout
{
let attempts = self.failed_attempts.lock();
@ -127,6 +127,19 @@ impl PairingGuard {
Ok(None)
}
/// Attempt to pair with the given code. Returns a bearer token on success.
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
pub async fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
let this = self.clone();
let code = code.to_string();
// TODO: make this function the main one without spawning a task
let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code));
handle
.await
.expect("failed to spawn blocking task this should not happen")
}
/// Check if a bearer token is valid (compares against stored hashes).
pub fn is_authenticated(&self, token: &str) -> bool {
if !self.require_pairing {
@ -232,63 +245,64 @@ pub fn is_public_bind(host: &str) -> bool {
#[cfg(test)]
mod tests {
use super::*;
use tokio::test;
// ── PairingGuard ─────────────────────────────────────────
#[test]
fn new_guard_generates_code_when_no_tokens() {
async fn new_guard_generates_code_when_no_tokens() {
let guard = PairingGuard::new(true, &[]);
assert!(guard.pairing_code().is_some());
assert!(!guard.is_paired());
}
#[test]
fn new_guard_no_code_when_tokens_exist() {
async fn new_guard_no_code_when_tokens_exist() {
let guard = PairingGuard::new(true, &["zc_existing".into()]);
assert!(guard.pairing_code().is_none());
assert!(guard.is_paired());
}
#[test]
fn new_guard_no_code_when_pairing_disabled() {
async fn new_guard_no_code_when_pairing_disabled() {
let guard = PairingGuard::new(false, &[]);
assert!(guard.pairing_code().is_none());
}
#[test]
fn try_pair_correct_code() {
async fn try_pair_correct_code() {
let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string();
let token = guard.try_pair(&code).unwrap();
let token = guard.try_pair(&code).await.unwrap();
assert!(token.is_some());
assert!(token.unwrap().starts_with("zc_"));
assert!(guard.is_paired());
}
#[test]
fn try_pair_wrong_code() {
async fn try_pair_wrong_code() {
let guard = PairingGuard::new(true, &[]);
let result = guard.try_pair("000000").unwrap();
let result = guard.try_pair("000000").await.unwrap();
// Might succeed if code happens to be 000000, but extremely unlikely
// Just check it returns Ok(None) normally
let _ = result;
}
#[test]
fn try_pair_empty_code() {
async fn try_pair_empty_code() {
let guard = PairingGuard::new(true, &[]);
assert!(guard.try_pair("").unwrap().is_none());
assert!(guard.try_pair("").await.unwrap().is_none());
}
#[test]
fn is_authenticated_with_valid_token() {
async fn is_authenticated_with_valid_token() {
// Pass plaintext token — PairingGuard hashes it on load
let guard = PairingGuard::new(true, &["zc_valid".into()]);
assert!(guard.is_authenticated("zc_valid"));
}
#[test]
fn is_authenticated_with_prehashed_token() {
async fn is_authenticated_with_prehashed_token() {
// Pass an already-hashed token (64 hex chars)
let hashed = hash_token("zc_valid");
let guard = PairingGuard::new(true, &[hashed]);
@ -296,20 +310,20 @@ mod tests {
}
#[test]
fn is_authenticated_with_invalid_token() {
async fn is_authenticated_with_invalid_token() {
let guard = PairingGuard::new(true, &["zc_valid".into()]);
assert!(!guard.is_authenticated("zc_invalid"));
}
#[test]
fn is_authenticated_when_pairing_disabled() {
async fn is_authenticated_when_pairing_disabled() {
let guard = PairingGuard::new(false, &[]);
assert!(guard.is_authenticated("anything"));
assert!(guard.is_authenticated(""));
}
#[test]
fn tokens_returns_hashes() {
async fn tokens_returns_hashes() {
let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]);
let tokens = guard.tokens();
assert_eq!(tokens.len(), 2);
@ -322,10 +336,10 @@ mod tests {
}
#[test]
fn pair_then_authenticate() {
async fn pair_then_authenticate() {
let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string();
let token = guard.try_pair(&code).unwrap().unwrap();
let token = guard.try_pair(&code).await.unwrap().unwrap();
assert!(guard.is_authenticated(&token));
assert!(!guard.is_authenticated("wrong"));
}
@ -333,24 +347,24 @@ mod tests {
// ── Token hashing ────────────────────────────────────────
#[test]
fn hash_token_produces_64_hex_chars() {
async fn hash_token_produces_64_hex_chars() {
let hash = hash_token("zc_test_token");
assert_eq!(hash.len(), 64);
assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
}
#[test]
fn hash_token_is_deterministic() {
async fn hash_token_is_deterministic() {
assert_eq!(hash_token("zc_abc"), hash_token("zc_abc"));
}
#[test]
fn hash_token_differs_for_different_inputs() {
async fn hash_token_differs_for_different_inputs() {
assert_ne!(hash_token("zc_a"), hash_token("zc_b"));
}
#[test]
fn is_token_hash_detects_hash_vs_plaintext() {
async fn is_token_hash_detects_hash_vs_plaintext() {
assert!(is_token_hash(&hash_token("zc_test")));
assert!(!is_token_hash("zc_test_token"));
assert!(!is_token_hash("too_short"));
@ -360,7 +374,7 @@ mod tests {
// ── is_public_bind ───────────────────────────────────────
#[test]
fn localhost_variants_not_public() {
async fn localhost_variants_not_public() {
assert!(!is_public_bind("127.0.0.1"));
assert!(!is_public_bind("localhost"));
assert!(!is_public_bind("::1"));
@ -368,12 +382,12 @@ mod tests {
}
#[test]
fn zero_zero_is_public() {
async fn zero_zero_is_public() {
assert!(is_public_bind("0.0.0.0"));
}
#[test]
fn real_ip_is_public() {
async fn real_ip_is_public() {
assert!(is_public_bind("192.168.1.100"));
assert!(is_public_bind("10.0.0.1"));
}
@ -381,13 +395,13 @@ mod tests {
// ── constant_time_eq ─────────────────────────────────────
#[test]
fn constant_time_eq_same() {
async fn constant_time_eq_same() {
assert!(constant_time_eq("abc", "abc"));
assert!(constant_time_eq("", ""));
}
#[test]
fn constant_time_eq_different() {
async fn constant_time_eq_different() {
assert!(!constant_time_eq("abc", "abd"));
assert!(!constant_time_eq("abc", "ab"));
assert!(!constant_time_eq("a", ""));
@ -396,14 +410,14 @@ mod tests {
// ── generate helpers ─────────────────────────────────────
#[test]
fn generate_code_is_6_digits() {
async fn generate_code_is_6_digits() {
let code = generate_code();
assert_eq!(code.len(), 6);
assert!(code.chars().all(|c| c.is_ascii_digit()));
}
#[test]
fn generate_code_is_not_deterministic() {
async fn generate_code_is_not_deterministic() {
// Two codes should differ with overwhelming probability. We try
// multiple pairs so a single 1-in-10^6 collision doesn't cause
// a flaky CI failure. All 10 pairs colliding is ~1-in-10^60.
@ -416,7 +430,7 @@ mod tests {
}
#[test]
fn generate_token_has_prefix_and_hex_payload() {
async fn generate_token_has_prefix_and_hex_payload() {
let token = generate_token();
let payload = token
.strip_prefix("zc_")
@ -434,15 +448,15 @@ mod tests {
// ── Brute force protection ───────────────────────────────
#[test]
fn brute_force_lockout_after_max_attempts() {
async fn brute_force_lockout_after_max_attempts() {
let guard = PairingGuard::new(true, &[]);
// Exhaust all attempts with wrong codes
for i in 0..MAX_PAIR_ATTEMPTS {
let result = guard.try_pair(&format!("wrong_{i}"));
let result = guard.try_pair(&format!("wrong_{i}")).await;
assert!(result.is_ok(), "Attempt {i} should not be locked out yet");
}
// Next attempt should be locked out
let result = guard.try_pair("another_wrong");
let result = guard.try_pair("another_wrong").await;
assert!(
result.is_err(),
"Should be locked out after {MAX_PAIR_ATTEMPTS} attempts"
@ -456,25 +470,25 @@ mod tests {
}
#[test]
fn correct_code_resets_failed_attempts() {
async fn correct_code_resets_failed_attempts() {
let guard = PairingGuard::new(true, &[]);
let code = guard.pairing_code().unwrap().to_string();
// Fail a few times
for _ in 0..3 {
let _ = guard.try_pair("wrong");
let _ = guard.try_pair("wrong").await;
}
// Correct code should still work (under MAX_PAIR_ATTEMPTS)
let result = guard.try_pair(&code).unwrap();
let result = guard.try_pair(&code).await.unwrap();
assert!(result.is_some(), "Correct code should work before lockout");
}
#[test]
fn lockout_returns_remaining_seconds() {
async fn lockout_returns_remaining_seconds() {
let guard = PairingGuard::new(true, &[]);
for _ in 0..MAX_PAIR_ATTEMPTS {
let _ = guard.try_pair("wrong");
let _ = guard.try_pair("wrong").await;
}
let err = guard.try_pair("wrong").unwrap_err();
let err = guard.try_pair("wrong").await.unwrap_err();
// Should be close to PAIR_LOCKOUT_SECS (within a second)
assert!(
err >= PAIR_LOCKOUT_SECS - 1,

View file

@ -3,6 +3,7 @@ use crate::config::{
runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope,
};
use crate::security::SecurityPolicy;
use crate::util::MaybeSet;
use async_trait::async_trait;
use serde_json::{json, Value};
use std::fs;
@ -93,16 +94,13 @@ impl ProxyConfigTool {
anyhow::bail!("'{field}' must be a string or string[]")
}
fn parse_optional_string_update(
args: &Value,
field: &str,
) -> anyhow::Result<Option<Option<String>>> {
fn parse_optional_string_update(args: &Value, field: &str) -> anyhow::Result<MaybeSet<String>> {
let Some(raw) = args.get(field) else {
return Ok(None);
return Ok(MaybeSet::Unset);
};
if raw.is_null() {
return Ok(Some(None));
return Ok(MaybeSet::Null);
}
let value = raw
@ -110,7 +108,13 @@ impl ProxyConfigTool {
.ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))?
.trim()
.to_string();
Ok(Some((!value.is_empty()).then_some(value)))
let output = if value.is_empty() {
MaybeSet::Null
} else {
MaybeSet::Set(value)
};
Ok(output)
}
fn env_snapshot() -> Value {
@ -164,7 +168,7 @@ impl ProxyConfigTool {
})
}
fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> {
async fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> {
let mut cfg = self.load_config_without_env()?;
let previous_scope = cfg.proxy.scope;
let mut proxy = cfg.proxy.clone();
@ -185,23 +189,24 @@ impl ProxyConfigTool {
})?;
}
if let Some(update) = Self::parse_optional_string_update(args, "http_proxy")? {
proxy.http_proxy = update;
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "http_proxy")? {
proxy.http_proxy = Some(update);
touched_proxy_url = true;
}
if let Some(update) = Self::parse_optional_string_update(args, "https_proxy")? {
proxy.https_proxy = update;
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "https_proxy")? {
proxy.https_proxy = Some(update);
touched_proxy_url = true;
}
if let Some(update) = Self::parse_optional_string_update(args, "all_proxy")? {
proxy.all_proxy = update;
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "all_proxy")? {
proxy.all_proxy = Some(update);
touched_proxy_url = true;
}
if let Some(no_proxy_raw) = args.get("no_proxy") {
proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?;
touched_proxy_url = true;
}
if let Some(services_raw) = args.get("services") {
@ -217,7 +222,7 @@ impl ProxyConfigTool {
proxy.validate()?;
cfg.proxy = proxy.clone();
cfg.save()?;
cfg.save().await?;
set_runtime_proxy_config(proxy.clone());
if proxy.enabled && proxy.scope == ProxyScope::Environment {
@ -237,11 +242,11 @@ impl ProxyConfigTool {
})
}
fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> {
async fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> {
let mut cfg = self.load_config_without_env()?;
let clear_env_default = cfg.proxy.scope == ProxyScope::Environment;
cfg.proxy.enabled = false;
cfg.save()?;
cfg.save().await?;
set_runtime_proxy_config(cfg.proxy.clone());
@ -384,8 +389,8 @@ impl Tool for ProxyConfigTool {
}
match action.as_str() {
"set" => self.handle_set(&args),
"disable" => self.handle_disable(&args),
"set" => self.handle_set(&args).await,
"disable" => self.handle_disable(&args).await,
"apply_env" => self.handle_apply_env(),
"clear_env" => self.handle_clear_env(),
_ => unreachable!("handled above"),
@ -421,20 +426,20 @@ mod tests {
})
}
fn test_config(tmp: &TempDir) -> Arc<Config> {
async fn test_config(tmp: &TempDir) -> Arc<Config> {
let config = Config {
workspace_dir: tmp.path().join("workspace"),
config_path: tmp.path().join("config.toml"),
..Config::default()
};
config.save().unwrap();
config.save().await.unwrap();
Arc::new(config)
}
#[tokio::test]
async fn list_services_action_returns_known_keys() {
let tmp = TempDir::new().unwrap();
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
let result = tool
.execute(json!({"action": "list_services"}))
@ -448,7 +453,7 @@ mod tests {
#[tokio::test]
async fn set_scope_services_requires_services_entries() {
let tmp = TempDir::new().unwrap();
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
let result = tool
.execute(json!({
@ -471,7 +476,7 @@ mod tests {
#[tokio::test]
async fn set_and_get_round_trip_proxy_scope() {
let tmp = TempDir::new().unwrap();
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
let set_result = tool
.execute(json!({

View file

@ -43,6 +43,13 @@ pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String {
}
}
/// Utility enum for handling optional values.
pub enum MaybeSet<T> {
Set(T),
Unset,
Null,
}
#[cfg(test)]
mod tests {
use super::*;