chore: Remove more blocking io calls
This commit is contained in:
parent
1aec9ad9c0
commit
f1ca73d3d2
14 changed files with 427 additions and 357 deletions
13
Cargo.lock
generated
13
Cargo.lock
generated
|
|
@ -385,6 +385,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum-core",
|
"axum-core",
|
||||||
|
"axum-macros",
|
||||||
"base64",
|
"base64",
|
||||||
"bytes",
|
"bytes",
|
||||||
"form_urlencoded",
|
"form_urlencoded",
|
||||||
|
|
@ -431,6 +432,17 @@ dependencies = [
|
||||||
"tower-service",
|
"tower-service",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-macros"
|
||||||
|
version = "0.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.116",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "backon"
|
name = "backon"
|
||||||
version = "1.6.0"
|
version = "1.6.0"
|
||||||
|
|
@ -7734,6 +7746,7 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
"tokio-serial",
|
"tokio-serial",
|
||||||
|
"tokio-stream",
|
||||||
"tokio-tungstenite 0.24.0",
|
"tokio-tungstenite 0.24.0",
|
||||||
"tokio-util",
|
"tokio-util",
|
||||||
"toml 1.0.1+spec-1.1.0",
|
"toml 1.0.1+spec-1.1.0",
|
||||||
|
|
|
||||||
|
|
@ -120,7 +120,7 @@ mail-parser = "0.11.2"
|
||||||
async-imap = { version = "0.11",features = ["runtime-tokio"], default-features = false }
|
async-imap = { version = "0.11",features = ["runtime-tokio"], default-features = false }
|
||||||
|
|
||||||
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
|
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
|
||||||
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] }
|
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws", "macros"] }
|
||||||
tower = { version = "0.5", default-features = false }
|
tower = { version = "0.5", default-features = false }
|
||||||
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
|
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
|
||||||
http-body-util = "0.1"
|
http-body-util = "0.1"
|
||||||
|
|
@ -141,6 +141,7 @@ probe-rs = { version = "0.30", optional = true }
|
||||||
|
|
||||||
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
||||||
pdf-extract = { version = "0.10", optional = true }
|
pdf-extract = { version = "0.10", optional = true }
|
||||||
|
tokio-stream = { version = "0.1.18", features = ["full"] }
|
||||||
|
|
||||||
# WhatsApp Web client (wa-rs) — optional, enable with --features whatsapp-web
|
# WhatsApp Web client (wa-rs) — optional, enable with --features whatsapp-web
|
||||||
# Uses wa-rs for Bot and Client, wa-rs-core for storage traits, custom rusqlite backend avoids Diesel conflict.
|
# Uses wa-rs for Bot and Client, wa-rs-core for storage traits, custom rusqlite backend avoids Diesel conflict.
|
||||||
|
|
|
||||||
|
|
@ -26,7 +26,7 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
security: SecurityConfig::autodetect(), // Silent!
|
security: SecurityConfig::autodetect(), // Silent!
|
||||||
};
|
};
|
||||||
|
|
||||||
config.save()?;
|
config.save().await?;
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
|
||||||
|
|
@ -726,7 +726,7 @@ mod tests {
|
||||||
"!r:m".to_string(),
|
"!r:m".to_string(),
|
||||||
vec![],
|
vec![],
|
||||||
Some(" ".to_string()),
|
Some(" ".to_string()),
|
||||||
Some("".to_string()),
|
Some(String::new()),
|
||||||
);
|
);
|
||||||
|
|
||||||
assert!(ch.session_owner_hint.is_none());
|
assert!(ch.session_owner_hint.is_none());
|
||||||
|
|
|
||||||
|
|
@ -1117,7 +1117,7 @@ fn normalize_telegram_identity(value: &str) -> String {
|
||||||
value.trim().trim_start_matches('@').to_string()
|
value.trim().trim_start_matches('@').to_string()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
||||||
let normalized = normalize_telegram_identity(identity);
|
let normalized = normalize_telegram_identity(identity);
|
||||||
if normalized.is_empty() {
|
if normalized.is_empty() {
|
||||||
anyhow::bail!("Telegram identity cannot be empty");
|
anyhow::bail!("Telegram identity cannot be empty");
|
||||||
|
|
@ -1147,7 +1147,7 @@ fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
telegram.allowed_users.push(normalized.clone());
|
telegram.allowed_users.push(normalized.clone());
|
||||||
updated.save()?;
|
updated.save().await?;
|
||||||
println!("✅ Bound Telegram identity: {normalized}");
|
println!("✅ Bound Telegram identity: {normalized}");
|
||||||
println!(" Saved to {}", updated.config_path.display());
|
println!(" Saved to {}", updated.config_path.display());
|
||||||
match maybe_restart_managed_daemon_service() {
|
match maybe_restart_managed_daemon_service() {
|
||||||
|
|
@ -1243,7 +1243,7 @@ fn maybe_restart_managed_daemon_service() -> Result<bool> {
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
|
pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> {
|
||||||
match command {
|
match command {
|
||||||
crate::ChannelCommands::Start => {
|
crate::ChannelCommands::Start => {
|
||||||
anyhow::bail!("Start must be handled in main.rs (requires async runtime)")
|
anyhow::bail!("Start must be handled in main.rs (requires async runtime)")
|
||||||
|
|
@ -1290,7 +1290,7 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul
|
||||||
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
|
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
|
||||||
}
|
}
|
||||||
crate::ChannelCommands::BindTelegram { identity } => {
|
crate::ChannelCommands::BindTelegram { identity } => {
|
||||||
bind_telegram_identity(config, &identity)
|
bind_telegram_identity(config, &identity).await
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -6,10 +6,10 @@ use async_trait::async_trait;
|
||||||
use directories::UserDirs;
|
use directories::UserDirs;
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use reqwest::multipart::{Form, Part};
|
use reqwest::multipart::{Form, Part};
|
||||||
use std::fs;
|
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
use std::sync::{Arc, RwLock};
|
use std::sync::{Arc, RwLock};
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
use tokio::fs;
|
||||||
|
|
||||||
/// Telegram's maximum message length for text messages
|
/// Telegram's maximum message length for text messages
|
||||||
const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096;
|
const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096;
|
||||||
|
|
@ -373,7 +373,7 @@ impl TelegramChannel {
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn load_config_without_env() -> anyhow::Result<Config> {
|
async fn load_config_without_env() -> anyhow::Result<Config> {
|
||||||
let home = UserDirs::new()
|
let home = UserDirs::new()
|
||||||
.map(|u| u.home_dir().to_path_buf())
|
.map(|u| u.home_dir().to_path_buf())
|
||||||
.context("Could not find home directory")?;
|
.context("Could not find home directory")?;
|
||||||
|
|
@ -381,6 +381,7 @@ impl TelegramChannel {
|
||||||
let config_path = zeroclaw_dir.join("config.toml");
|
let config_path = zeroclaw_dir.join("config.toml");
|
||||||
|
|
||||||
let contents = fs::read_to_string(&config_path)
|
let contents = fs::read_to_string(&config_path)
|
||||||
|
.await
|
||||||
.with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
|
.with_context(|| format!("Failed to read config file: {}", config_path.display()))?;
|
||||||
let mut config: Config = toml::from_str(&contents)
|
let mut config: Config = toml::from_str(&contents)
|
||||||
.context("Failed to parse config file for Telegram binding")?;
|
.context("Failed to parse config file for Telegram binding")?;
|
||||||
|
|
@ -389,8 +390,8 @@ impl TelegramChannel {
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn persist_allowed_identity_blocking(identity: &str) -> anyhow::Result<()> {
|
async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
|
||||||
let mut config = Self::load_config_without_env()?;
|
let mut config = Self::load_config_without_env().await?;
|
||||||
let Some(telegram) = config.channels_config.telegram.as_mut() else {
|
let Some(telegram) = config.channels_config.telegram.as_mut() else {
|
||||||
anyhow::bail!("Telegram channel config is missing in config.toml");
|
anyhow::bail!("Telegram channel config is missing in config.toml");
|
||||||
};
|
};
|
||||||
|
|
@ -404,20 +405,13 @@ impl TelegramChannel {
|
||||||
telegram.allowed_users.push(normalized);
|
telegram.allowed_users.push(normalized);
|
||||||
config
|
config
|
||||||
.save()
|
.save()
|
||||||
|
.await
|
||||||
.context("Failed to persist Telegram allowlist to config.toml")?;
|
.context("Failed to persist Telegram allowlist to config.toml")?;
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> {
|
|
||||||
let identity = identity.to_string();
|
|
||||||
tokio::task::spawn_blocking(move || Self::persist_allowed_identity_blocking(&identity))
|
|
||||||
.await
|
|
||||||
.map_err(|e| anyhow::anyhow!("Failed to join Telegram bind save task: {e}"))??;
|
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn add_allowed_identity_runtime(&self, identity: &str) {
|
fn add_allowed_identity_runtime(&self, identity: &str) {
|
||||||
let normalized = Self::normalize_identity(identity);
|
let normalized = Self::normalize_identity(identity);
|
||||||
if normalized.is_empty() {
|
if normalized.is_empty() {
|
||||||
|
|
@ -629,7 +623,7 @@ impl TelegramChannel {
|
||||||
|
|
||||||
if let Some(code) = Self::extract_bind_code(text) {
|
if let Some(code) = Self::extract_bind_code(text) {
|
||||||
if let Some(pairing) = self.pairing.as_ref() {
|
if let Some(pairing) = self.pairing.as_ref() {
|
||||||
match pairing.try_pair(code) {
|
match pairing.try_pair(code).await {
|
||||||
Ok(Some(_token)) => {
|
Ok(Some(_token)) => {
|
||||||
let bind_identity = normalized_sender_id.clone().or_else(|| {
|
let bind_identity = normalized_sender_id.clone().or_else(|| {
|
||||||
if normalized_username.is_empty() || normalized_username == "unknown" {
|
if normalized_username.is_empty() || normalized_username == "unknown" {
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -587,6 +587,7 @@ async fn handle_metrics(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// POST /pair — exchange one-time code for bearer token
|
/// POST /pair — exchange one-time code for bearer token
|
||||||
|
#[axum::debug_handler]
|
||||||
async fn handle_pair(
|
async fn handle_pair(
|
||||||
State(state): State<AppState>,
|
State(state): State<AppState>,
|
||||||
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
ConnectInfo(peer_addr): ConnectInfo<SocketAddr>,
|
||||||
|
|
@ -608,10 +609,10 @@ async fn handle_pair(
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
.unwrap_or("");
|
.unwrap_or("");
|
||||||
|
|
||||||
match state.pairing.try_pair(code) {
|
match state.pairing.try_pair(code).await {
|
||||||
Ok(Some(token)) => {
|
Ok(Some(token)) => {
|
||||||
tracing::info!("🔐 New client paired successfully");
|
tracing::info!("🔐 New client paired successfully");
|
||||||
if let Err(err) = persist_pairing_tokens(&state.config, &state.pairing) {
|
if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await {
|
||||||
tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}");
|
tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}");
|
||||||
let body = serde_json::json!({
|
let body = serde_json::json!({
|
||||||
"paired": true,
|
"paired": true,
|
||||||
|
|
@ -648,11 +649,14 @@ async fn handle_pair(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn persist_pairing_tokens(config: &Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> {
|
async fn persist_pairing_tokens(config: Arc<Mutex<Config>>, pairing: &PairingGuard) -> Result<()> {
|
||||||
let paired_tokens = pairing.tokens();
|
let paired_tokens = pairing.tokens();
|
||||||
let mut cfg = config.lock();
|
// This is needed because parking_lot's guard is not Send so we clone the inner
|
||||||
|
// this should be removed once async mutexes are used everywhere
|
||||||
|
let mut cfg = { config.lock().clone() };
|
||||||
cfg.gateway.paired_tokens = paired_tokens;
|
cfg.gateway.paired_tokens = paired_tokens;
|
||||||
cfg.save()
|
cfg.save()
|
||||||
|
.await
|
||||||
.context("Failed to persist paired tokens to config.toml")
|
.context("Failed to persist paired tokens to config.toml")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1398,15 +1402,15 @@ mod tests {
|
||||||
let mut config = Config::default();
|
let mut config = Config::default();
|
||||||
config.config_path = config_path.clone();
|
config.config_path = config_path.clone();
|
||||||
config.workspace_dir = workspace_path;
|
config.workspace_dir = workspace_path;
|
||||||
config.save().unwrap();
|
config.save().await.unwrap();
|
||||||
|
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
let code = guard.pairing_code().unwrap();
|
let code = guard.pairing_code().unwrap();
|
||||||
let token = guard.try_pair(&code).unwrap().unwrap();
|
let token = guard.try_pair(&code).await.unwrap().unwrap();
|
||||||
assert!(guard.is_authenticated(&token));
|
assert!(guard.is_authenticated(&token));
|
||||||
|
|
||||||
let shared_config = Arc::new(Mutex::new(config));
|
let shared_config = Arc::new(Mutex::new(config));
|
||||||
persist_pairing_tokens(&shared_config, &guard).unwrap();
|
persist_pairing_tokens(shared_config, &guard).await.unwrap();
|
||||||
|
|
||||||
let saved = tokio::fs::read_to_string(config_path).await.unwrap();
|
let saved = tokio::fs::read_to_string(config_path).await.unwrap();
|
||||||
let parsed: Config = toml::from_str(&saved).unwrap();
|
let parsed: Config = toml::from_str(&saved).unwrap();
|
||||||
|
|
|
||||||
18
src/main.rs
18
src/main.rs
|
|
@ -553,11 +553,10 @@ async fn main() -> Result<()> {
|
||||||
{
|
{
|
||||||
bail!("--channels-only does not accept --api-key, --provider, --model, or --memory");
|
bail!("--channels-only does not accept --api-key, --provider, --model, or --memory");
|
||||||
}
|
}
|
||||||
let config = tokio::task::spawn_blocking(move || {
|
let config = if channels_only {
|
||||||
if channels_only {
|
onboard::run_channels_repair_wizard().await
|
||||||
onboard::run_channels_repair_wizard()
|
|
||||||
} else if interactive {
|
} else if interactive {
|
||||||
onboard::run_wizard()
|
onboard::run_wizard().await
|
||||||
} else {
|
} else {
|
||||||
onboard::run_quick_setup(
|
onboard::run_quick_setup(
|
||||||
api_key.as_deref(),
|
api_key.as_deref(),
|
||||||
|
|
@ -565,9 +564,8 @@ async fn main() -> Result<()> {
|
||||||
model.as_deref(),
|
model.as_deref(),
|
||||||
memory.as_deref(),
|
memory.as_deref(),
|
||||||
)
|
)
|
||||||
}
|
.await
|
||||||
})
|
}?;
|
||||||
.await??;
|
|
||||||
// Auto-start channels if user said yes during wizard
|
// Auto-start channels if user said yes during wizard
|
||||||
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
|
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
|
||||||
channels::start_channels(config).await?;
|
channels::start_channels(config).await?;
|
||||||
|
|
@ -576,7 +574,7 @@ async fn main() -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// All other commands need config loaded first
|
// All other commands need config loaded first
|
||||||
let mut config = Config::load_or_init()?;
|
let mut config = Config::load_or_init().await?;
|
||||||
config.apply_env_overrides();
|
config.apply_env_overrides();
|
||||||
|
|
||||||
match cli.command {
|
match cli.command {
|
||||||
|
|
@ -764,7 +762,7 @@ async fn main() -> Result<()> {
|
||||||
Commands::Channel { channel_command } => match channel_command {
|
Commands::Channel { channel_command } => match channel_command {
|
||||||
ChannelCommands::Start => channels::start_channels(config).await,
|
ChannelCommands::Start => channels::start_channels(config).await,
|
||||||
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
||||||
other => channels::handle_command(other, &config),
|
other => channels::handle_command(other, &config).await,
|
||||||
},
|
},
|
||||||
|
|
||||||
Commands::Integrations {
|
Commands::Integrations {
|
||||||
|
|
@ -786,7 +784,7 @@ async fn main() -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
Commands::Peripheral { peripheral_command } => {
|
Commands::Peripheral { peripheral_command } => {
|
||||||
peripherals::handle_command(peripheral_command.clone(), &config)
|
peripherals::handle_command(peripheral_command.clone(), &config).await
|
||||||
}
|
}
|
||||||
|
|
||||||
Commands::Config { config_command } => match config_command {
|
Commands::Config { config_command } => match config_command {
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,7 @@ fn has_launchable_channels(channels: &ChannelsConfig) -> bool {
|
||||||
|
|
||||||
// ── Main wizard entry point ──────────────────────────────────────
|
// ── Main wizard entry point ──────────────────────────────────────
|
||||||
|
|
||||||
pub fn run_wizard() -> Result<Config> {
|
pub async fn run_wizard() -> Result<Config> {
|
||||||
println!("{}", style(BANNER).cyan().bold());
|
println!("{}", style(BANNER).cyan().bold());
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
|
|
@ -191,8 +191,8 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
if config.memory.auto_save { "on" } else { "off" }
|
if config.memory.auto_save { "on" } else { "off" }
|
||||||
);
|
);
|
||||||
|
|
||||||
config.save()?;
|
config.save().await?;
|
||||||
persist_workspace_selection(&config.config_path)?;
|
persist_workspace_selection(&config.config_path).await?;
|
||||||
|
|
||||||
// ── Final summary ────────────────────────────────────────────
|
// ── Final summary ────────────────────────────────────────────
|
||||||
print_summary(&config);
|
print_summary(&config);
|
||||||
|
|
@ -226,7 +226,7 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Interactive repair flow: rerun channel setup only without redoing full onboarding.
|
/// Interactive repair flow: rerun channel setup only without redoing full onboarding.
|
||||||
pub fn run_channels_repair_wizard() -> Result<Config> {
|
pub async fn run_channels_repair_wizard() -> Result<Config> {
|
||||||
println!("{}", style(BANNER).cyan().bold());
|
println!("{}", style(BANNER).cyan().bold());
|
||||||
println!(
|
println!(
|
||||||
" {}",
|
" {}",
|
||||||
|
|
@ -236,12 +236,12 @@ pub fn run_channels_repair_wizard() -> Result<Config> {
|
||||||
);
|
);
|
||||||
println!();
|
println!();
|
||||||
|
|
||||||
let mut config = Config::load_or_init()?;
|
let mut config = Config::load_or_init().await?;
|
||||||
|
|
||||||
print_step(1, 1, "Channels (How You Talk to ZeroClaw)");
|
print_step(1, 1, "Channels (How You Talk to ZeroClaw)");
|
||||||
config.channels_config = setup_channels()?;
|
config.channels_config = setup_channels()?;
|
||||||
config.save()?;
|
config.save().await?;
|
||||||
persist_workspace_selection(&config.config_path)?;
|
persist_workspace_selection(&config.config_path).await?;
|
||||||
|
|
||||||
println!();
|
println!();
|
||||||
println!(
|
println!(
|
||||||
|
|
@ -321,7 +321,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub fn run_quick_setup(
|
pub async fn run_quick_setup(
|
||||||
credential_override: Option<&str>,
|
credential_override: Option<&str>,
|
||||||
provider: Option<&str>,
|
provider: Option<&str>,
|
||||||
model_override: Option<&str>,
|
model_override: Option<&str>,
|
||||||
|
|
@ -396,8 +396,8 @@ pub fn run_quick_setup(
|
||||||
query_classification: crate::config::QueryClassificationConfig::default(),
|
query_classification: crate::config::QueryClassificationConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
config.save()?;
|
config.save().await?;
|
||||||
persist_workspace_selection(&config.config_path)?;
|
persist_workspace_selection(&config.config_path).await?;
|
||||||
|
|
||||||
// Scaffold minimal workspace files
|
// Scaffold minimal workspace files
|
||||||
let default_ctx = ProjectContext {
|
let default_ctx = ProjectContext {
|
||||||
|
|
@ -1459,11 +1459,13 @@ fn print_bullet(text: &str) {
|
||||||
println!(" {} {}", style("›").cyan(), text);
|
println!(" {} {}", style("›").cyan(), text);
|
||||||
}
|
}
|
||||||
|
|
||||||
fn persist_workspace_selection(config_path: &Path) -> Result<()> {
|
async fn persist_workspace_selection(config_path: &Path) -> Result<()> {
|
||||||
let config_dir = config_path
|
let config_dir = config_path
|
||||||
.parent()
|
.parent()
|
||||||
.context("Config path must have a parent directory")?;
|
.context("Config path must have a parent directory")?;
|
||||||
crate::config::schema::persist_active_workspace_config_dir(config_dir).with_context(|| {
|
crate::config::schema::persist_active_workspace_config_dir(config_dir)
|
||||||
|
.await
|
||||||
|
.with_context(|| {
|
||||||
format!(
|
format!(
|
||||||
"Failed to persist active workspace selection for {}",
|
"Failed to persist active workspace selection for {}",
|
||||||
config_dir.display()
|
config_dir.display()
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,7 @@ pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoar
|
||||||
|
|
||||||
/// Handle `zeroclaw peripheral` subcommands.
|
/// Handle `zeroclaw peripheral` subcommands.
|
||||||
#[allow(clippy::module_name_repetitions)]
|
#[allow(clippy::module_name_repetitions)]
|
||||||
pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> {
|
pub async fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> {
|
||||||
match cmd {
|
match cmd {
|
||||||
crate::PeripheralCommands::List => {
|
crate::PeripheralCommands::List => {
|
||||||
let boards = list_configured_boards(&config.peripherals);
|
let boards = list_configured_boards(&config.peripherals);
|
||||||
|
|
@ -76,7 +76,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
|
||||||
Some(path.clone())
|
Some(path.clone())
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut cfg = crate::config::Config::load_or_init()?;
|
let mut cfg = crate::config::Config::load_or_init().await?;
|
||||||
cfg.peripherals.enabled = true;
|
cfg.peripherals.enabled = true;
|
||||||
|
|
||||||
if cfg
|
if cfg
|
||||||
|
|
@ -95,7 +95,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result
|
||||||
path: path_opt,
|
path: path_opt,
|
||||||
baud: 115_200,
|
baud: 115_200,
|
||||||
});
|
});
|
||||||
cfg.save()?;
|
cfg.save().await?;
|
||||||
println!("Added {} at {}. Restart daemon to apply.", board, path);
|
println!("Added {} at {}. Restart daemon to apply.", board, path);
|
||||||
}
|
}
|
||||||
#[cfg(feature = "hardware")]
|
#[cfg(feature = "hardware")]
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@
|
||||||
use parking_lot::Mutex;
|
use parking_lot::Mutex;
|
||||||
use sha2::{Digest, Sha256};
|
use sha2::{Digest, Sha256};
|
||||||
use std::collections::HashSet;
|
use std::collections::HashSet;
|
||||||
|
use std::sync::Arc;
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
|
|
||||||
/// Maximum failed pairing attempts before lockout.
|
/// Maximum failed pairing attempts before lockout.
|
||||||
|
|
@ -23,16 +24,17 @@ const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes
|
||||||
/// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure
|
/// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure
|
||||||
/// in config files. When a new token is generated, the plaintext is returned
|
/// in config files. When a new token is generated, the plaintext is returned
|
||||||
/// to the client once, and only the hash is retained.
|
/// to the client once, and only the hash is retained.
|
||||||
#[derive(Debug)]
|
// TODO: I've just made this work with parking_lot but it should use either flume or tokio's async mutexes
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
pub struct PairingGuard {
|
pub struct PairingGuard {
|
||||||
/// Whether pairing is required at all.
|
/// Whether pairing is required at all.
|
||||||
require_pairing: bool,
|
require_pairing: bool,
|
||||||
/// One-time pairing code (generated on startup, consumed on first pair).
|
/// One-time pairing code (generated on startup, consumed on first pair).
|
||||||
pairing_code: Mutex<Option<String>>,
|
pairing_code: Arc<Mutex<Option<String>>>,
|
||||||
/// Set of SHA-256 hashed bearer tokens (persisted across restarts).
|
/// Set of SHA-256 hashed bearer tokens (persisted across restarts).
|
||||||
paired_tokens: Mutex<HashSet<String>>,
|
paired_tokens: Arc<Mutex<HashSet<String>>>,
|
||||||
/// Brute-force protection: failed attempt counter + lockout time.
|
/// Brute-force protection: failed attempt counter + lockout time.
|
||||||
failed_attempts: Mutex<(u32, Option<Instant>)>,
|
failed_attempts: Arc<Mutex<(u32, Option<Instant>)>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl PairingGuard {
|
impl PairingGuard {
|
||||||
|
|
@ -62,9 +64,9 @@ impl PairingGuard {
|
||||||
};
|
};
|
||||||
Self {
|
Self {
|
||||||
require_pairing,
|
require_pairing,
|
||||||
pairing_code: Mutex::new(code),
|
pairing_code: Arc::new(Mutex::new(code)),
|
||||||
paired_tokens: Mutex::new(tokens),
|
paired_tokens: Arc::new(Mutex::new(tokens)),
|
||||||
failed_attempts: Mutex::new((0, None)),
|
failed_attempts: Arc::new(Mutex::new((0, None))),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -78,9 +80,7 @@ impl PairingGuard {
|
||||||
self.require_pairing
|
self.require_pairing
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Attempt to pair with the given code. Returns a bearer token on success.
|
fn try_pair_blocking(&self, code: &str) -> Result<Option<String>, u64> {
|
||||||
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
|
|
||||||
pub fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
|
|
||||||
// Check brute force lockout
|
// Check brute force lockout
|
||||||
{
|
{
|
||||||
let attempts = self.failed_attempts.lock();
|
let attempts = self.failed_attempts.lock();
|
||||||
|
|
@ -127,6 +127,19 @@ impl PairingGuard {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Attempt to pair with the given code. Returns a bearer token on success.
|
||||||
|
/// Returns `Err(lockout_seconds)` if locked out due to brute force.
|
||||||
|
pub async fn try_pair(&self, code: &str) -> Result<Option<String>, u64> {
|
||||||
|
let this = self.clone();
|
||||||
|
let code = code.to_string();
|
||||||
|
// TODO: make this function the main one without spawning a task
|
||||||
|
let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code));
|
||||||
|
|
||||||
|
handle
|
||||||
|
.await
|
||||||
|
.expect("failed to spawn blocking task this should not happen")
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if a bearer token is valid (compares against stored hashes).
|
/// Check if a bearer token is valid (compares against stored hashes).
|
||||||
pub fn is_authenticated(&self, token: &str) -> bool {
|
pub fn is_authenticated(&self, token: &str) -> bool {
|
||||||
if !self.require_pairing {
|
if !self.require_pairing {
|
||||||
|
|
@ -232,63 +245,64 @@ pub fn is_public_bind(host: &str) -> bool {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use tokio::test;
|
||||||
|
|
||||||
// ── PairingGuard ─────────────────────────────────────────
|
// ── PairingGuard ─────────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn new_guard_generates_code_when_no_tokens() {
|
async fn new_guard_generates_code_when_no_tokens() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
assert!(guard.pairing_code().is_some());
|
assert!(guard.pairing_code().is_some());
|
||||||
assert!(!guard.is_paired());
|
assert!(!guard.is_paired());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn new_guard_no_code_when_tokens_exist() {
|
async fn new_guard_no_code_when_tokens_exist() {
|
||||||
let guard = PairingGuard::new(true, &["zc_existing".into()]);
|
let guard = PairingGuard::new(true, &["zc_existing".into()]);
|
||||||
assert!(guard.pairing_code().is_none());
|
assert!(guard.pairing_code().is_none());
|
||||||
assert!(guard.is_paired());
|
assert!(guard.is_paired());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn new_guard_no_code_when_pairing_disabled() {
|
async fn new_guard_no_code_when_pairing_disabled() {
|
||||||
let guard = PairingGuard::new(false, &[]);
|
let guard = PairingGuard::new(false, &[]);
|
||||||
assert!(guard.pairing_code().is_none());
|
assert!(guard.pairing_code().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn try_pair_correct_code() {
|
async fn try_pair_correct_code() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
let code = guard.pairing_code().unwrap().to_string();
|
let code = guard.pairing_code().unwrap().to_string();
|
||||||
let token = guard.try_pair(&code).unwrap();
|
let token = guard.try_pair(&code).await.unwrap();
|
||||||
assert!(token.is_some());
|
assert!(token.is_some());
|
||||||
assert!(token.unwrap().starts_with("zc_"));
|
assert!(token.unwrap().starts_with("zc_"));
|
||||||
assert!(guard.is_paired());
|
assert!(guard.is_paired());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn try_pair_wrong_code() {
|
async fn try_pair_wrong_code() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
let result = guard.try_pair("000000").unwrap();
|
let result = guard.try_pair("000000").await.unwrap();
|
||||||
// Might succeed if code happens to be 000000, but extremely unlikely
|
// Might succeed if code happens to be 000000, but extremely unlikely
|
||||||
// Just check it returns Ok(None) normally
|
// Just check it returns Ok(None) normally
|
||||||
let _ = result;
|
let _ = result;
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn try_pair_empty_code() {
|
async fn try_pair_empty_code() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
assert!(guard.try_pair("").unwrap().is_none());
|
assert!(guard.try_pair("").await.unwrap().is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn is_authenticated_with_valid_token() {
|
async fn is_authenticated_with_valid_token() {
|
||||||
// Pass plaintext token — PairingGuard hashes it on load
|
// Pass plaintext token — PairingGuard hashes it on load
|
||||||
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
||||||
assert!(guard.is_authenticated("zc_valid"));
|
assert!(guard.is_authenticated("zc_valid"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn is_authenticated_with_prehashed_token() {
|
async fn is_authenticated_with_prehashed_token() {
|
||||||
// Pass an already-hashed token (64 hex chars)
|
// Pass an already-hashed token (64 hex chars)
|
||||||
let hashed = hash_token("zc_valid");
|
let hashed = hash_token("zc_valid");
|
||||||
let guard = PairingGuard::new(true, &[hashed]);
|
let guard = PairingGuard::new(true, &[hashed]);
|
||||||
|
|
@ -296,20 +310,20 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn is_authenticated_with_invalid_token() {
|
async fn is_authenticated_with_invalid_token() {
|
||||||
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
let guard = PairingGuard::new(true, &["zc_valid".into()]);
|
||||||
assert!(!guard.is_authenticated("zc_invalid"));
|
assert!(!guard.is_authenticated("zc_invalid"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn is_authenticated_when_pairing_disabled() {
|
async fn is_authenticated_when_pairing_disabled() {
|
||||||
let guard = PairingGuard::new(false, &[]);
|
let guard = PairingGuard::new(false, &[]);
|
||||||
assert!(guard.is_authenticated("anything"));
|
assert!(guard.is_authenticated("anything"));
|
||||||
assert!(guard.is_authenticated(""));
|
assert!(guard.is_authenticated(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn tokens_returns_hashes() {
|
async fn tokens_returns_hashes() {
|
||||||
let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]);
|
let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]);
|
||||||
let tokens = guard.tokens();
|
let tokens = guard.tokens();
|
||||||
assert_eq!(tokens.len(), 2);
|
assert_eq!(tokens.len(), 2);
|
||||||
|
|
@ -322,10 +336,10 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn pair_then_authenticate() {
|
async fn pair_then_authenticate() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
let code = guard.pairing_code().unwrap().to_string();
|
let code = guard.pairing_code().unwrap().to_string();
|
||||||
let token = guard.try_pair(&code).unwrap().unwrap();
|
let token = guard.try_pair(&code).await.unwrap().unwrap();
|
||||||
assert!(guard.is_authenticated(&token));
|
assert!(guard.is_authenticated(&token));
|
||||||
assert!(!guard.is_authenticated("wrong"));
|
assert!(!guard.is_authenticated("wrong"));
|
||||||
}
|
}
|
||||||
|
|
@ -333,24 +347,24 @@ mod tests {
|
||||||
// ── Token hashing ────────────────────────────────────────
|
// ── Token hashing ────────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn hash_token_produces_64_hex_chars() {
|
async fn hash_token_produces_64_hex_chars() {
|
||||||
let hash = hash_token("zc_test_token");
|
let hash = hash_token("zc_test_token");
|
||||||
assert_eq!(hash.len(), 64);
|
assert_eq!(hash.len(), 64);
|
||||||
assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
|
assert!(hash.chars().all(|c| c.is_ascii_hexdigit()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn hash_token_is_deterministic() {
|
async fn hash_token_is_deterministic() {
|
||||||
assert_eq!(hash_token("zc_abc"), hash_token("zc_abc"));
|
assert_eq!(hash_token("zc_abc"), hash_token("zc_abc"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn hash_token_differs_for_different_inputs() {
|
async fn hash_token_differs_for_different_inputs() {
|
||||||
assert_ne!(hash_token("zc_a"), hash_token("zc_b"));
|
assert_ne!(hash_token("zc_a"), hash_token("zc_b"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn is_token_hash_detects_hash_vs_plaintext() {
|
async fn is_token_hash_detects_hash_vs_plaintext() {
|
||||||
assert!(is_token_hash(&hash_token("zc_test")));
|
assert!(is_token_hash(&hash_token("zc_test")));
|
||||||
assert!(!is_token_hash("zc_test_token"));
|
assert!(!is_token_hash("zc_test_token"));
|
||||||
assert!(!is_token_hash("too_short"));
|
assert!(!is_token_hash("too_short"));
|
||||||
|
|
@ -360,7 +374,7 @@ mod tests {
|
||||||
// ── is_public_bind ───────────────────────────────────────
|
// ── is_public_bind ───────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn localhost_variants_not_public() {
|
async fn localhost_variants_not_public() {
|
||||||
assert!(!is_public_bind("127.0.0.1"));
|
assert!(!is_public_bind("127.0.0.1"));
|
||||||
assert!(!is_public_bind("localhost"));
|
assert!(!is_public_bind("localhost"));
|
||||||
assert!(!is_public_bind("::1"));
|
assert!(!is_public_bind("::1"));
|
||||||
|
|
@ -368,12 +382,12 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn zero_zero_is_public() {
|
async fn zero_zero_is_public() {
|
||||||
assert!(is_public_bind("0.0.0.0"));
|
assert!(is_public_bind("0.0.0.0"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn real_ip_is_public() {
|
async fn real_ip_is_public() {
|
||||||
assert!(is_public_bind("192.168.1.100"));
|
assert!(is_public_bind("192.168.1.100"));
|
||||||
assert!(is_public_bind("10.0.0.1"));
|
assert!(is_public_bind("10.0.0.1"));
|
||||||
}
|
}
|
||||||
|
|
@ -381,13 +395,13 @@ mod tests {
|
||||||
// ── constant_time_eq ─────────────────────────────────────
|
// ── constant_time_eq ─────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn constant_time_eq_same() {
|
async fn constant_time_eq_same() {
|
||||||
assert!(constant_time_eq("abc", "abc"));
|
assert!(constant_time_eq("abc", "abc"));
|
||||||
assert!(constant_time_eq("", ""));
|
assert!(constant_time_eq("", ""));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn constant_time_eq_different() {
|
async fn constant_time_eq_different() {
|
||||||
assert!(!constant_time_eq("abc", "abd"));
|
assert!(!constant_time_eq("abc", "abd"));
|
||||||
assert!(!constant_time_eq("abc", "ab"));
|
assert!(!constant_time_eq("abc", "ab"));
|
||||||
assert!(!constant_time_eq("a", ""));
|
assert!(!constant_time_eq("a", ""));
|
||||||
|
|
@ -396,14 +410,14 @@ mod tests {
|
||||||
// ── generate helpers ─────────────────────────────────────
|
// ── generate helpers ─────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn generate_code_is_6_digits() {
|
async fn generate_code_is_6_digits() {
|
||||||
let code = generate_code();
|
let code = generate_code();
|
||||||
assert_eq!(code.len(), 6);
|
assert_eq!(code.len(), 6);
|
||||||
assert!(code.chars().all(|c| c.is_ascii_digit()));
|
assert!(code.chars().all(|c| c.is_ascii_digit()));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn generate_code_is_not_deterministic() {
|
async fn generate_code_is_not_deterministic() {
|
||||||
// Two codes should differ with overwhelming probability. We try
|
// Two codes should differ with overwhelming probability. We try
|
||||||
// multiple pairs so a single 1-in-10^6 collision doesn't cause
|
// multiple pairs so a single 1-in-10^6 collision doesn't cause
|
||||||
// a flaky CI failure. All 10 pairs colliding is ~1-in-10^60.
|
// a flaky CI failure. All 10 pairs colliding is ~1-in-10^60.
|
||||||
|
|
@ -416,7 +430,7 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn generate_token_has_prefix_and_hex_payload() {
|
async fn generate_token_has_prefix_and_hex_payload() {
|
||||||
let token = generate_token();
|
let token = generate_token();
|
||||||
let payload = token
|
let payload = token
|
||||||
.strip_prefix("zc_")
|
.strip_prefix("zc_")
|
||||||
|
|
@ -434,15 +448,15 @@ mod tests {
|
||||||
// ── Brute force protection ───────────────────────────────
|
// ── Brute force protection ───────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn brute_force_lockout_after_max_attempts() {
|
async fn brute_force_lockout_after_max_attempts() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
// Exhaust all attempts with wrong codes
|
// Exhaust all attempts with wrong codes
|
||||||
for i in 0..MAX_PAIR_ATTEMPTS {
|
for i in 0..MAX_PAIR_ATTEMPTS {
|
||||||
let result = guard.try_pair(&format!("wrong_{i}"));
|
let result = guard.try_pair(&format!("wrong_{i}")).await;
|
||||||
assert!(result.is_ok(), "Attempt {i} should not be locked out yet");
|
assert!(result.is_ok(), "Attempt {i} should not be locked out yet");
|
||||||
}
|
}
|
||||||
// Next attempt should be locked out
|
// Next attempt should be locked out
|
||||||
let result = guard.try_pair("another_wrong");
|
let result = guard.try_pair("another_wrong").await;
|
||||||
assert!(
|
assert!(
|
||||||
result.is_err(),
|
result.is_err(),
|
||||||
"Should be locked out after {MAX_PAIR_ATTEMPTS} attempts"
|
"Should be locked out after {MAX_PAIR_ATTEMPTS} attempts"
|
||||||
|
|
@ -456,25 +470,25 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn correct_code_resets_failed_attempts() {
|
async fn correct_code_resets_failed_attempts() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
let code = guard.pairing_code().unwrap().to_string();
|
let code = guard.pairing_code().unwrap().to_string();
|
||||||
// Fail a few times
|
// Fail a few times
|
||||||
for _ in 0..3 {
|
for _ in 0..3 {
|
||||||
let _ = guard.try_pair("wrong");
|
let _ = guard.try_pair("wrong").await;
|
||||||
}
|
}
|
||||||
// Correct code should still work (under MAX_PAIR_ATTEMPTS)
|
// Correct code should still work (under MAX_PAIR_ATTEMPTS)
|
||||||
let result = guard.try_pair(&code).unwrap();
|
let result = guard.try_pair(&code).await.unwrap();
|
||||||
assert!(result.is_some(), "Correct code should work before lockout");
|
assert!(result.is_some(), "Correct code should work before lockout");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lockout_returns_remaining_seconds() {
|
async fn lockout_returns_remaining_seconds() {
|
||||||
let guard = PairingGuard::new(true, &[]);
|
let guard = PairingGuard::new(true, &[]);
|
||||||
for _ in 0..MAX_PAIR_ATTEMPTS {
|
for _ in 0..MAX_PAIR_ATTEMPTS {
|
||||||
let _ = guard.try_pair("wrong");
|
let _ = guard.try_pair("wrong").await;
|
||||||
}
|
}
|
||||||
let err = guard.try_pair("wrong").unwrap_err();
|
let err = guard.try_pair("wrong").await.unwrap_err();
|
||||||
// Should be close to PAIR_LOCKOUT_SECS (within a second)
|
// Should be close to PAIR_LOCKOUT_SECS (within a second)
|
||||||
assert!(
|
assert!(
|
||||||
err >= PAIR_LOCKOUT_SECS - 1,
|
err >= PAIR_LOCKOUT_SECS - 1,
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,7 @@ use crate::config::{
|
||||||
runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope,
|
runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope,
|
||||||
};
|
};
|
||||||
use crate::security::SecurityPolicy;
|
use crate::security::SecurityPolicy;
|
||||||
|
use crate::util::MaybeSet;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::{json, Value};
|
use serde_json::{json, Value};
|
||||||
use std::fs;
|
use std::fs;
|
||||||
|
|
@ -93,16 +94,13 @@ impl ProxyConfigTool {
|
||||||
anyhow::bail!("'{field}' must be a string or string[]")
|
anyhow::bail!("'{field}' must be a string or string[]")
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_optional_string_update(
|
fn parse_optional_string_update(args: &Value, field: &str) -> anyhow::Result<MaybeSet<String>> {
|
||||||
args: &Value,
|
|
||||||
field: &str,
|
|
||||||
) -> anyhow::Result<Option<Option<String>>> {
|
|
||||||
let Some(raw) = args.get(field) else {
|
let Some(raw) = args.get(field) else {
|
||||||
return Ok(None);
|
return Ok(MaybeSet::Unset);
|
||||||
};
|
};
|
||||||
|
|
||||||
if raw.is_null() {
|
if raw.is_null() {
|
||||||
return Ok(Some(None));
|
return Ok(MaybeSet::Null);
|
||||||
}
|
}
|
||||||
|
|
||||||
let value = raw
|
let value = raw
|
||||||
|
|
@ -110,7 +108,13 @@ impl ProxyConfigTool {
|
||||||
.ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))?
|
.ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))?
|
||||||
.trim()
|
.trim()
|
||||||
.to_string();
|
.to_string();
|
||||||
Ok(Some((!value.is_empty()).then_some(value)))
|
|
||||||
|
let output = if value.is_empty() {
|
||||||
|
MaybeSet::Null
|
||||||
|
} else {
|
||||||
|
MaybeSet::Set(value)
|
||||||
|
};
|
||||||
|
Ok(output)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn env_snapshot() -> Value {
|
fn env_snapshot() -> Value {
|
||||||
|
|
@ -164,7 +168,7 @@ impl ProxyConfigTool {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
async fn handle_set(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
||||||
let mut cfg = self.load_config_without_env()?;
|
let mut cfg = self.load_config_without_env()?;
|
||||||
let previous_scope = cfg.proxy.scope;
|
let previous_scope = cfg.proxy.scope;
|
||||||
let mut proxy = cfg.proxy.clone();
|
let mut proxy = cfg.proxy.clone();
|
||||||
|
|
@ -185,23 +189,24 @@ impl ProxyConfigTool {
|
||||||
})?;
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(update) = Self::parse_optional_string_update(args, "http_proxy")? {
|
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "http_proxy")? {
|
||||||
proxy.http_proxy = update;
|
proxy.http_proxy = Some(update);
|
||||||
touched_proxy_url = true;
|
touched_proxy_url = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(update) = Self::parse_optional_string_update(args, "https_proxy")? {
|
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "https_proxy")? {
|
||||||
proxy.https_proxy = update;
|
proxy.https_proxy = Some(update);
|
||||||
touched_proxy_url = true;
|
touched_proxy_url = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(update) = Self::parse_optional_string_update(args, "all_proxy")? {
|
if let MaybeSet::Set(update) = Self::parse_optional_string_update(args, "all_proxy")? {
|
||||||
proxy.all_proxy = update;
|
proxy.all_proxy = Some(update);
|
||||||
touched_proxy_url = true;
|
touched_proxy_url = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(no_proxy_raw) = args.get("no_proxy") {
|
if let Some(no_proxy_raw) = args.get("no_proxy") {
|
||||||
proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?;
|
proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?;
|
||||||
|
touched_proxy_url = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(services_raw) = args.get("services") {
|
if let Some(services_raw) = args.get("services") {
|
||||||
|
|
@ -217,7 +222,7 @@ impl ProxyConfigTool {
|
||||||
proxy.validate()?;
|
proxy.validate()?;
|
||||||
|
|
||||||
cfg.proxy = proxy.clone();
|
cfg.proxy = proxy.clone();
|
||||||
cfg.save()?;
|
cfg.save().await?;
|
||||||
set_runtime_proxy_config(proxy.clone());
|
set_runtime_proxy_config(proxy.clone());
|
||||||
|
|
||||||
if proxy.enabled && proxy.scope == ProxyScope::Environment {
|
if proxy.enabled && proxy.scope == ProxyScope::Environment {
|
||||||
|
|
@ -237,11 +242,11 @@ impl ProxyConfigTool {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
async fn handle_disable(&self, args: &Value) -> anyhow::Result<ToolResult> {
|
||||||
let mut cfg = self.load_config_without_env()?;
|
let mut cfg = self.load_config_without_env()?;
|
||||||
let clear_env_default = cfg.proxy.scope == ProxyScope::Environment;
|
let clear_env_default = cfg.proxy.scope == ProxyScope::Environment;
|
||||||
cfg.proxy.enabled = false;
|
cfg.proxy.enabled = false;
|
||||||
cfg.save()?;
|
cfg.save().await?;
|
||||||
|
|
||||||
set_runtime_proxy_config(cfg.proxy.clone());
|
set_runtime_proxy_config(cfg.proxy.clone());
|
||||||
|
|
||||||
|
|
@ -384,8 +389,8 @@ impl Tool for ProxyConfigTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
match action.as_str() {
|
match action.as_str() {
|
||||||
"set" => self.handle_set(&args),
|
"set" => self.handle_set(&args).await,
|
||||||
"disable" => self.handle_disable(&args),
|
"disable" => self.handle_disable(&args).await,
|
||||||
"apply_env" => self.handle_apply_env(),
|
"apply_env" => self.handle_apply_env(),
|
||||||
"clear_env" => self.handle_clear_env(),
|
"clear_env" => self.handle_clear_env(),
|
||||||
_ => unreachable!("handled above"),
|
_ => unreachable!("handled above"),
|
||||||
|
|
@ -421,20 +426,20 @@ mod tests {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn test_config(tmp: &TempDir) -> Arc<Config> {
|
async fn test_config(tmp: &TempDir) -> Arc<Config> {
|
||||||
let config = Config {
|
let config = Config {
|
||||||
workspace_dir: tmp.path().join("workspace"),
|
workspace_dir: tmp.path().join("workspace"),
|
||||||
config_path: tmp.path().join("config.toml"),
|
config_path: tmp.path().join("config.toml"),
|
||||||
..Config::default()
|
..Config::default()
|
||||||
};
|
};
|
||||||
config.save().unwrap();
|
config.save().await.unwrap();
|
||||||
Arc::new(config)
|
Arc::new(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn list_services_action_returns_known_keys() {
|
async fn list_services_action_returns_known_keys() {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
|
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
|
||||||
|
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({"action": "list_services"}))
|
.execute(json!({"action": "list_services"}))
|
||||||
|
|
@ -448,7 +453,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn set_scope_services_requires_services_entries() {
|
async fn set_scope_services_requires_services_entries() {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
|
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
|
||||||
|
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({
|
.execute(json!({
|
||||||
|
|
@ -471,7 +476,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn set_and_get_round_trip_proxy_scope() {
|
async fn set_and_get_round_trip_proxy_scope() {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let tool = ProxyConfigTool::new(test_config(&tmp), test_security());
|
let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security());
|
||||||
|
|
||||||
let set_result = tool
|
let set_result = tool
|
||||||
.execute(json!({
|
.execute(json!({
|
||||||
|
|
|
||||||
|
|
@ -43,6 +43,13 @@ pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Utility enum for handling optional values.
|
||||||
|
pub enum MaybeSet<T> {
|
||||||
|
Set(T),
|
||||||
|
Unset,
|
||||||
|
Null,
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue