diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 57f983c..54b88f4 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -40,7 +40,8 @@ pub async fn run( // ── Wire up agnostic subsystems ────────────────────────────── let observer: Arc = Arc::from(observability::create_observer(&config.observability)); - let _runtime = runtime::create_runtime(&config.runtime)?; + let runtime: Arc = + Arc::from(runtime::create_runtime(&config.runtime)?); let security = Arc::new(SecurityPolicy::from_config( &config.autonomy, &config.workspace_dir, @@ -60,7 +61,13 @@ pub async fn run( } else { None }; - let _tools = tools::all_tools(&security, mem.clone(), composio_key, &config.browser); + let _tools = tools::all_tools_with_runtime( + &security, + runtime, + mem.clone(), + composio_key, + &config.browser, + ); // ── Resolve provider ───────────────────────────────────────── let provider_name = provider_override diff --git a/src/config/mod.rs b/src/config/mod.rs index e5a6521..b442538 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,7 +2,7 @@ pub mod schema; pub use schema::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, - GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig, - ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, - SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig, + DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, + MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, + RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index 764ba69..a866880 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -2,8 +2,9 @@ use crate::security::AutonomyLevel; use anyhow::{Context, Result}; use directories::UserDirs; use serde::{Deserialize, Serialize}; -use std::fs; -use std::path::PathBuf; +use std::fs::{self, File, OpenOptions}; +use std::io::Write; +use std::path::{Path, PathBuf}; // ── Top-level config ────────────────────────────────────────────── @@ -112,6 +113,18 @@ pub struct GatewayConfig { /// Paired bearer tokens (managed automatically, not user-edited) #[serde(default)] pub paired_tokens: Vec, + + /// Max `/pair` requests per minute per client key. + #[serde(default = "default_pair_rate_limit")] + pub pair_rate_limit_per_minute: u32, + + /// Max `/webhook` requests per minute per client key. + #[serde(default = "default_webhook_rate_limit")] + pub webhook_rate_limit_per_minute: u32, + + /// TTL for webhook idempotency keys. + #[serde(default = "default_idempotency_ttl_secs")] + pub idempotency_ttl_secs: u64, } fn default_gateway_port() -> u16 { @@ -122,6 +135,18 @@ fn default_gateway_host() -> String { "127.0.0.1".into() } +fn default_pair_rate_limit() -> u32 { + 10 +} + +fn default_webhook_rate_limit() -> u32 { + 60 +} + +fn default_idempotency_ttl_secs() -> u64 { + 300 +} + fn default_true() -> bool { true } @@ -134,6 +159,9 @@ impl Default for GatewayConfig { require_pairing: true, allow_public_bind: false, paired_tokens: Vec::new(), + pair_rate_limit_per_minute: default_pair_rate_limit(), + webhook_rate_limit_per_minute: default_webhook_rate_limit(), + idempotency_ttl_secs: default_idempotency_ttl_secs(), } } } @@ -320,6 +348,14 @@ pub struct AutonomyConfig { pub forbidden_paths: Vec, pub max_actions_per_hour: u32, pub max_cost_per_day_cents: u32, + + /// Require explicit approval for medium-risk shell commands. + #[serde(default = "default_true")] + pub require_approval_for_medium_risk: bool, + + /// Block high-risk shell commands even if allowlisted. + #[serde(default = "default_true")] + pub block_high_risk_commands: bool, } impl Default for AutonomyConfig { @@ -363,6 +399,8 @@ impl Default for AutonomyConfig { ], max_actions_per_hour: 20, max_cost_per_day_cents: 500, + require_approval_for_medium_risk: true, + block_high_risk_commands: true, } } } @@ -371,16 +409,85 @@ impl Default for AutonomyConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RuntimeConfig { - /// Runtime kind (currently supported: "native"). - /// - /// Reserved values (not implemented yet): "docker", "cloudflare". + /// Runtime kind (`native` | `docker`). + #[serde(default = "default_runtime_kind")] pub kind: String, + + /// Docker runtime settings (used when `kind = "docker"`). + #[serde(default)] + pub docker: DockerRuntimeConfig, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct DockerRuntimeConfig { + /// Runtime image used to execute shell commands. + #[serde(default = "default_docker_image")] + pub image: String, + + /// Docker network mode (`none`, `bridge`, etc.). + #[serde(default = "default_docker_network")] + pub network: String, + + /// Optional memory limit in MB (`None` = no explicit limit). + #[serde(default = "default_docker_memory_limit_mb")] + pub memory_limit_mb: Option, + + /// Optional CPU limit (`None` = no explicit limit). + #[serde(default = "default_docker_cpu_limit")] + pub cpu_limit: Option, + + /// Mount root filesystem as read-only. + #[serde(default = "default_true")] + pub read_only_rootfs: bool, + + /// Mount configured workspace into `/workspace`. + #[serde(default = "default_true")] + pub mount_workspace: bool, + + /// Optional workspace root allowlist for Docker mount validation. + #[serde(default)] + pub allowed_workspace_roots: Vec, +} + +fn default_runtime_kind() -> String { + "native".into() +} + +fn default_docker_image() -> String { + "alpine:3.20".into() +} + +fn default_docker_network() -> String { + "none".into() +} + +fn default_docker_memory_limit_mb() -> Option { + Some(512) +} + +fn default_docker_cpu_limit() -> Option { + Some(1.0) +} + +impl Default for DockerRuntimeConfig { + fn default() -> Self { + Self { + image: default_docker_image(), + network: default_docker_network(), + memory_limit_mb: default_docker_memory_limit_mb(), + cpu_limit: default_docker_cpu_limit(), + read_only_rootfs: true, + mount_workspace: true, + allowed_workspace_roots: Vec::new(), + } + } } impl Default for RuntimeConfig { fn default() -> Self { Self { - kind: "native".into(), + kind: default_runtime_kind(), + docker: DockerRuntimeConfig::default(), } } } @@ -811,11 +918,86 @@ impl Config { pub fn save(&self) -> Result<()> { let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?; - fs::write(&self.config_path, toml_str).context("Failed to write config file")?; + + let parent_dir = self + .config_path + .parent() + .context("Config path must have a parent directory")?; + fs::create_dir_all(parent_dir).with_context(|| { + format!( + "Failed to create config directory: {}", + parent_dir.display() + ) + })?; + + let file_name = self + .config_path + .file_name() + .and_then(|v| v.to_str()) + .unwrap_or("config.toml"); + let temp_path = parent_dir.join(format!(".{file_name}.tmp-{}", uuid::Uuid::new_v4())); + let backup_path = parent_dir.join(format!("{file_name}.bak")); + + let mut temp_file = OpenOptions::new() + .create_new(true) + .write(true) + .open(&temp_path) + .with_context(|| { + format!( + "Failed to create temporary config file: {}", + temp_path.display() + ) + })?; + temp_file + .write_all(toml_str.as_bytes()) + .context("Failed to write temporary config contents")?; + temp_file + .sync_all() + .context("Failed to fsync temporary config file")?; + drop(temp_file); + + let had_existing_config = self.config_path.exists(); + if had_existing_config { + fs::copy(&self.config_path, &backup_path).with_context(|| { + format!( + "Failed to create config backup before atomic replace: {}", + backup_path.display() + ) + })?; + } + + if let Err(e) = fs::rename(&temp_path, &self.config_path) { + let _ = fs::remove_file(&temp_path); + if had_existing_config && backup_path.exists() { + let _ = fs::copy(&backup_path, &self.config_path); + } + anyhow::bail!("Failed to atomically replace config file: {e}"); + } + + sync_directory(parent_dir)?; + + if had_existing_config { + let _ = fs::remove_file(&backup_path); + } + Ok(()) } } +#[cfg(unix)] +fn sync_directory(path: &Path) -> Result<()> { + let dir = File::open(path) + .with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?; + dir.sync_all() + .with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?; + Ok(()) +} + +#[cfg(not(unix))] +fn sync_directory(_path: &Path) -> Result<()> { + Ok(()) +} + #[cfg(test)] mod tests { use super::*; @@ -850,12 +1032,20 @@ mod tests { assert!(a.forbidden_paths.contains(&"/etc".to_string())); assert_eq!(a.max_actions_per_hour, 20); assert_eq!(a.max_cost_per_day_cents, 500); + assert!(a.require_approval_for_medium_risk); + assert!(a.block_high_risk_commands); } #[test] fn runtime_config_default() { let r = RuntimeConfig::default(); assert_eq!(r.kind, "native"); + assert_eq!(r.docker.image, "alpine:3.20"); + assert_eq!(r.docker.network, "none"); + assert_eq!(r.docker.memory_limit_mb, Some(512)); + assert_eq!(r.docker.cpu_limit, Some(1.0)); + assert!(r.docker.read_only_rootfs); + assert!(r.docker.mount_workspace); } #[test] @@ -905,9 +1095,12 @@ mod tests { forbidden_paths: vec!["/secret".into()], max_actions_per_hour: 50, max_cost_per_day_cents: 1000, + require_approval_for_medium_risk: false, + block_high_risk_commands: true, }, runtime: RuntimeConfig { kind: "docker".into(), + ..RuntimeConfig::default() }, reliability: ReliabilityConfig::default(), model_routes: Vec::new(), @@ -1022,6 +1215,38 @@ default_temperature = 0.7 let _ = fs::remove_dir_all(&dir); } + + #[test] + fn config_save_atomic_cleanup() { + let dir = + std::env::temp_dir().join(format!("zeroclaw_test_config_{}", uuid::Uuid::new_v4())); + fs::create_dir_all(&dir).unwrap(); + + let config_path = dir.join("config.toml"); + let mut config = Config::default(); + config.workspace_dir = dir.join("workspace"); + config.config_path = config_path.clone(); + config.default_model = Some("model-a".into()); + + config.save().unwrap(); + assert!(config_path.exists()); + + config.default_model = Some("model-b".into()); + config.save().unwrap(); + + let contents = fs::read_to_string(&config_path).unwrap(); + assert!(contents.contains("model-b")); + + let names: Vec = fs::read_dir(&dir) + .unwrap() + .map(|entry| entry.unwrap().file_name().to_string_lossy().to_string()) + .collect(); + assert!(!names.iter().any(|name| name.contains(".tmp-"))); + assert!(!names.iter().any(|name| name.ends_with(".bak"))); + + let _ = fs::remove_dir_all(&dir); + } + // ── Telegram / Discord config ──────────────────────────── #[test] @@ -1343,6 +1568,9 @@ channel_id = "C123" g.paired_tokens.is_empty(), "No pre-paired tokens by default" ); + assert_eq!(g.pair_rate_limit_per_minute, 10); + assert_eq!(g.webhook_rate_limit_per_minute, 60); + assert_eq!(g.idempotency_ttl_secs, 300); } #[test] @@ -1368,12 +1596,18 @@ channel_id = "C123" require_pairing: true, allow_public_bind: false, paired_tokens: vec!["zc_test_token".into()], + pair_rate_limit_per_minute: 12, + webhook_rate_limit_per_minute: 80, + idempotency_ttl_secs: 600, }; let toml_str = toml::to_string(&g).unwrap(); let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap(); assert!(parsed.require_pairing); assert!(!parsed.allow_public_bind); assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]); + assert_eq!(parsed.pair_rate_limit_per_minute, 12); + assert_eq!(parsed.webhook_rate_limit_per_minute, 80); + assert_eq!(parsed.idempotency_ttl_secs, 600); } #[test] diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index bede685..4f85437 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -22,9 +22,10 @@ use axum::{ routing::{get, post}, Router, }; +use std::collections::HashMap; use std::net::SocketAddr; -use std::sync::Arc; -use std::time::Duration; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; use tower_http::limit::RequestBodyLimitLayer; use tower_http::timeout::TimeoutLayer; @@ -32,6 +33,118 @@ use tower_http::timeout::TimeoutLayer; pub const MAX_BODY_SIZE: usize = 65_536; /// Request timeout (30s) — prevents slow-loris attacks pub const REQUEST_TIMEOUT_SECS: u64 = 30; +/// Sliding window used by gateway rate limiting. +pub const RATE_LIMIT_WINDOW_SECS: u64 = 60; + +#[derive(Debug)] +struct SlidingWindowRateLimiter { + limit_per_window: u32, + window: Duration, + requests: Mutex>>, +} + +impl SlidingWindowRateLimiter { + fn new(limit_per_window: u32, window: Duration) -> Self { + Self { + limit_per_window, + window, + requests: Mutex::new(HashMap::new()), + } + } + + fn allow(&self, key: &str) -> bool { + if self.limit_per_window == 0 { + return true; + } + + let now = Instant::now(); + let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now); + + let mut requests = self + .requests + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + let entry = requests.entry(key.to_owned()).or_default(); + entry.retain(|instant| *instant > cutoff); + + if entry.len() >= self.limit_per_window as usize { + return false; + } + + entry.push(now); + true + } +} + +#[derive(Debug)] +pub struct GatewayRateLimiter { + pair: SlidingWindowRateLimiter, + webhook: SlidingWindowRateLimiter, +} + +impl GatewayRateLimiter { + fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self { + let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS); + Self { + pair: SlidingWindowRateLimiter::new(pair_per_minute, window), + webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window), + } + } + + fn allow_pair(&self, key: &str) -> bool { + self.pair.allow(key) + } + + fn allow_webhook(&self, key: &str) -> bool { + self.webhook.allow(key) + } +} + +#[derive(Debug)] +pub struct IdempotencyStore { + ttl: Duration, + keys: Mutex>, +} + +impl IdempotencyStore { + fn new(ttl: Duration) -> Self { + Self { + ttl, + keys: Mutex::new(HashMap::new()), + } + } + + /// Returns true if this key is new and is now recorded. + fn record_if_new(&self, key: &str) -> bool { + let now = Instant::now(); + let mut keys = self + .keys + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner); + + keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl); + + if keys.contains_key(key) { + return false; + } + + keys.insert(key.to_owned(), now); + true + } +} + +fn client_key_from_headers(headers: &HeaderMap) -> String { + for header_name in ["X-Forwarded-For", "X-Real-IP"] { + if let Some(value) = headers.get(header_name).and_then(|v| v.to_str().ok()) { + let first = value.split(',').next().unwrap_or("").trim(); + if !first.is_empty() { + return first.to_owned(); + } + } + } + "unknown".into() +} /// Shared state for all axum handlers #[derive(Clone)] @@ -43,6 +156,8 @@ pub struct AppState { pub auto_save: bool, pub webhook_secret: Option>, pub pairing: Arc, + pub rate_limiter: Arc, + pub idempotency_store: Arc, pub whatsapp: Option>, /// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`) pub whatsapp_app_secret: Option>, @@ -66,17 +181,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let actual_port = listener.local_addr()?.port(); let display_addr = format!("{host}:{actual_port}"); + let provider: Arc = Arc::from(providers::create_resilient_provider( + config.default_provider.as_deref().unwrap_or("openrouter"), + config.api_key.as_deref(), + &config.reliability, + )?); let model = config .default_model .clone() .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); - let provider: Arc = Arc::from(providers::create_routed_provider( - config.default_provider.as_deref().unwrap_or("openrouter"), - config.api_key.as_deref(), - &config.reliability, - &config.model_routes, - &model, - )?); let temperature = config.default_temperature; let mem: Arc = Arc::from(memory::create_memory( &config.memory, @@ -127,6 +240,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { config.gateway.require_pairing, &config.gateway.paired_tokens, )); + let rate_limiter = Arc::new(GatewayRateLimiter::new( + config.gateway.pair_rate_limit_per_minute, + config.gateway.webhook_rate_limit_per_minute, + )); + let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs( + config.gateway.idempotency_ttl_secs.max(1), + ))); // ── Tunnel ──────────────────────────────────────────────── let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?; @@ -185,6 +305,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { auto_save: config.memory.auto_save, webhook_secret, pairing, + rate_limiter, + idempotency_store, whatsapp: whatsapp_channel, whatsapp_app_secret, }; @@ -225,6 +347,16 @@ async fn handle_health(State(state): State) -> impl IntoResponse { /// POST /pair — exchange one-time code for bearer token async fn handle_pair(State(state): State, headers: HeaderMap) -> impl IntoResponse { + let client_key = client_key_from_headers(&headers); + if !state.rate_limiter.allow_pair(&client_key) { + tracing::warn!("/pair rate limit exceeded for key: {client_key}"); + let err = serde_json::json!({ + "error": "Too many pairing requests. Please retry later.", + "retry_after": RATE_LIMIT_WINDOW_SECS, + }); + return (StatusCode::TOO_MANY_REQUESTS, Json(err)); + } + let code = headers .get("X-Pairing-Code") .and_then(|v| v.to_str().ok()) @@ -270,6 +402,16 @@ async fn handle_webhook( headers: HeaderMap, body: Result, axum::extract::rejection::JsonRejection>, ) -> impl IntoResponse { + let client_key = client_key_from_headers(&headers); + if !state.rate_limiter.allow_webhook(&client_key) { + tracing::warn!("/webhook rate limit exceeded for key: {client_key}"); + let err = serde_json::json!({ + "error": "Too many webhook requests. Please retry later.", + "retry_after": RATE_LIMIT_WINDOW_SECS, + }); + return (StatusCode::TOO_MANY_REQUESTS, Json(err)); + } + // ── Bearer token auth (pairing) ── if state.pairing.require_pairing() { let auth = headers @@ -312,6 +454,24 @@ async fn handle_webhook( } }; + // ── Idempotency (optional) ── + if let Some(idempotency_key) = headers + .get("X-Idempotency-Key") + .and_then(|v| v.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + { + if !state.idempotency_store.record_if_new(idempotency_key) { + tracing::info!("Webhook duplicate ignored (idempotency key: {idempotency_key})"); + let body = serde_json::json!({ + "status": "duplicate", + "idempotent": true, + "message": "Request already processed for this idempotency key" + }); + return (StatusCode::OK, Json(body)); + } + } + let message = &webhook_body.message; if state.auto_save { @@ -508,6 +668,13 @@ async fn handle_whatsapp_message( #[cfg(test)] mod tests { use super::*; + use crate::memory::{Memory, MemoryCategory, MemoryEntry}; + use crate::providers::Provider; + use async_trait::async_trait; + use axum::http::HeaderValue; + use axum::response::IntoResponse; + use http_body_util::BodyExt; + use std::sync::atomic::{AtomicUsize, Ordering}; #[test] fn security_body_limit_is_64kb() { @@ -547,6 +714,133 @@ mod tests { assert_clone::(); } + #[test] + fn gateway_rate_limiter_blocks_after_limit() { + let limiter = GatewayRateLimiter::new(2, 2); + assert!(limiter.allow_pair("127.0.0.1")); + assert!(limiter.allow_pair("127.0.0.1")); + assert!(!limiter.allow_pair("127.0.0.1")); + } + + #[test] + fn idempotency_store_rejects_duplicate_key() { + let store = IdempotencyStore::new(Duration::from_secs(30)); + assert!(store.record_if_new("req-1")); + assert!(!store.record_if_new("req-1")); + assert!(store.record_if_new("req-2")); + } + + #[derive(Default)] + struct MockMemory; + + #[async_trait] + impl Memory for MockMemory { + fn name(&self) -> &str { + "mock" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(0) + } + + async fn health_check(&self) -> bool { + true + } + } + + #[derive(Default)] + struct MockProvider { + calls: AtomicUsize, + } + + #[async_trait] + impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("ok".into()) + } + } + + #[tokio::test] + async fn webhook_idempotency_skips_duplicate_provider_calls() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret: None, + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let mut headers = HeaderMap::new(); + headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123")); + + let body = Ok(Json(WebhookBody { + message: "hello".into(), + })); + let first = handle_webhook(State(state.clone()), headers.clone(), body) + .await + .into_response(); + assert_eq!(first.status(), StatusCode::OK); + + let body = Ok(Json(WebhookBody { + message: "hello".into(), + })); + let second = handle_webhook(State(state), headers, body) + .await + .into_response(); + assert_eq!(second.status(), StatusCode::OK); + + let payload = second.into_body().collect().await.unwrap().to_bytes(); + let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap(); + assert_eq!(parsed["status"], "duplicate"); + assert_eq!(parsed["idempotent"], true); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1); + } + // ══════════════════════════════════════════════════════════ // WhatsApp Signature Verification Tests (CWE-345 Prevention) // ══════════════════════════════════════════════════════════ @@ -572,7 +866,11 @@ mod tests { let signature_header = compute_whatsapp_signature_header(app_secret, body); - assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); } #[test] @@ -583,7 +881,11 @@ mod tests { let signature_header = compute_whatsapp_signature_header(wrong_secret, body); - assert!(!verify_whatsapp_signature(app_secret, body, &signature_header)); + assert!(!verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); } #[test] @@ -610,7 +912,11 @@ mod tests { // Signature without "sha256=" prefix let signature_header = "abc123def456"; - assert!(!verify_whatsapp_signature(app_secret, body, signature_header)); + assert!(!verify_whatsapp_signature( + app_secret, + body, + signature_header + )); } #[test] @@ -643,7 +949,11 @@ mod tests { let signature_header = compute_whatsapp_signature_header(app_secret, body); - assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); } #[test] @@ -653,7 +963,11 @@ mod tests { let signature_header = compute_whatsapp_signature_header(app_secret, body); - assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); } #[test] @@ -663,7 +977,11 @@ mod tests { let signature_header = compute_whatsapp_signature_header(app_secret, body); - assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); + assert!(verify_whatsapp_signature( + app_secret, + body, + &signature_header + )); } #[test] diff --git a/src/runtime/docker.rs b/src/runtime/docker.rs new file mode 100644 index 0000000..eaa3d09 --- /dev/null +++ b/src/runtime/docker.rs @@ -0,0 +1,199 @@ +use super::traits::RuntimeAdapter; +use crate::config::DockerRuntimeConfig; +use anyhow::{Context, Result}; +use std::path::{Path, PathBuf}; + +/// Docker runtime with lightweight container isolation. +#[derive(Debug, Clone)] +pub struct DockerRuntime { + config: DockerRuntimeConfig, +} + +impl DockerRuntime { + pub fn new(config: DockerRuntimeConfig) -> Self { + Self { config } + } + + fn workspace_mount_path(&self, workspace_dir: &Path) -> Result { + let resolved = workspace_dir + .canonicalize() + .unwrap_or_else(|_| workspace_dir.to_path_buf()); + + if !resolved.is_absolute() { + anyhow::bail!( + "Docker runtime requires an absolute workspace path, got: {}", + resolved.display() + ); + } + + if resolved == Path::new("/") { + anyhow::bail!("Refusing to mount filesystem root (/) into docker runtime"); + } + + if self.config.allowed_workspace_roots.is_empty() { + return Ok(resolved); + } + + let allowed = self.config.allowed_workspace_roots.iter().any(|root| { + let root_path = Path::new(root) + .canonicalize() + .unwrap_or_else(|_| PathBuf::from(root)); + resolved.starts_with(root_path) + }); + + if !allowed { + anyhow::bail!( + "Workspace path {} is not in runtime.docker.allowed_workspace_roots", + resolved.display() + ); + } + + Ok(resolved) + } +} + +impl RuntimeAdapter for DockerRuntime { + fn name(&self) -> &str { + "docker" + } + + fn has_shell_access(&self) -> bool { + true + } + + fn has_filesystem_access(&self) -> bool { + self.config.mount_workspace + } + + fn storage_path(&self) -> PathBuf { + if self.config.mount_workspace { + PathBuf::from("/workspace/.zeroclaw") + } else { + PathBuf::from("/tmp/.zeroclaw") + } + } + + fn supports_long_running(&self) -> bool { + false + } + + fn memory_budget(&self) -> u64 { + self.config + .memory_limit_mb + .map_or(0, |mb| mb.saturating_mul(1024 * 1024)) + } + + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result { + let mut process = tokio::process::Command::new("docker"); + process + .arg("run") + .arg("--rm") + .arg("--init") + .arg("--interactive"); + + let network = self.config.network.trim(); + if !network.is_empty() { + process.arg("--network").arg(network); + } + + if let Some(memory_limit_mb) = self.config.memory_limit_mb.filter(|mb| *mb > 0) { + process.arg("--memory").arg(format!("{memory_limit_mb}m")); + } + + if let Some(cpu_limit) = self.config.cpu_limit.filter(|cpus| *cpus > 0.0) { + process.arg("--cpus").arg(cpu_limit.to_string()); + } + + if self.config.read_only_rootfs { + process.arg("--read-only"); + } + + if self.config.mount_workspace { + let host_workspace = self.workspace_mount_path(workspace_dir).with_context(|| { + format!( + "Failed to validate workspace mount path {}", + workspace_dir.display() + ) + })?; + + process + .arg("--volume") + .arg(format!("{}:/workspace:rw", host_workspace.display())) + .arg("--workdir") + .arg("/workspace"); + } + + process + .arg(self.config.image.trim()) + .arg("sh") + .arg("-c") + .arg(command); + + Ok(process) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn docker_runtime_name() { + let runtime = DockerRuntime::new(DockerRuntimeConfig::default()); + assert_eq!(runtime.name(), "docker"); + } + + #[test] + fn docker_runtime_memory_budget() { + let mut cfg = DockerRuntimeConfig::default(); + cfg.memory_limit_mb = Some(256); + let runtime = DockerRuntime::new(cfg); + assert_eq!(runtime.memory_budget(), 256 * 1024 * 1024); + } + + #[test] + fn docker_build_shell_command_includes_runtime_flags() { + let cfg = DockerRuntimeConfig { + image: "alpine:3.20".into(), + network: "none".into(), + memory_limit_mb: Some(128), + cpu_limit: Some(1.5), + read_only_rootfs: true, + mount_workspace: true, + allowed_workspace_roots: Vec::new(), + }; + let runtime = DockerRuntime::new(cfg); + + let workspace = std::env::temp_dir(); + let command = runtime + .build_shell_command("echo hello", &workspace) + .unwrap(); + let debug = format!("{command:?}"); + + assert!(debug.contains("docker")); + assert!(debug.contains("--memory")); + assert!(debug.contains("128m")); + assert!(debug.contains("--cpus")); + assert!(debug.contains("1.5")); + assert!(debug.contains("--workdir")); + assert!(debug.contains("echo hello")); + } + + #[test] + fn docker_workspace_allowlist_blocks_outside_paths() { + let cfg = DockerRuntimeConfig { + allowed_workspace_roots: vec!["/tmp/allowed".into()], + ..DockerRuntimeConfig::default() + }; + let runtime = DockerRuntime::new(cfg); + + let outside = PathBuf::from("/tmp/blocked_workspace"); + let result = runtime.build_shell_command("echo test", &outside); + + assert!(result.is_err()); + } +} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index 9ed0ee0..cea7aa3 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -1,6 +1,8 @@ +pub mod docker; pub mod native; pub mod traits; +pub use docker::DockerRuntime; pub use native::NativeRuntime; pub use traits::RuntimeAdapter; @@ -10,18 +12,14 @@ use crate::config::RuntimeConfig; pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result> { match config.kind.as_str() { "native" => Ok(Box::new(NativeRuntime::new())), - "docker" => anyhow::bail!( - "runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands." - ), + "docker" => Ok(Box::new(DockerRuntime::new(config.docker.clone()))), "cloudflare" => anyhow::bail!( "runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now." ), - other if other.trim().is_empty() => anyhow::bail!( - "runtime.kind cannot be empty. Supported values: native" - ), - other => anyhow::bail!( - "Unknown runtime kind '{other}'. Supported values: native" - ), + other if other.trim().is_empty() => { + anyhow::bail!("runtime.kind cannot be empty. Supported values: native, docker") + } + other => anyhow::bail!("Unknown runtime kind '{other}'. Supported values: native, docker"), } } @@ -33,6 +31,7 @@ mod tests { fn factory_native() { let cfg = RuntimeConfig { kind: "native".into(), + ..RuntimeConfig::default() }; let rt = create_runtime(&cfg).unwrap(); assert_eq!(rt.name(), "native"); @@ -40,20 +39,21 @@ mod tests { } #[test] - fn factory_docker_errors() { + fn factory_docker() { let cfg = RuntimeConfig { kind: "docker".into(), + ..RuntimeConfig::default() }; - match create_runtime(&cfg) { - Err(err) => assert!(err.to_string().contains("not implemented")), - Ok(_) => panic!("docker runtime should error"), - } + let rt = create_runtime(&cfg).unwrap(); + assert_eq!(rt.name(), "docker"); + assert!(rt.has_shell_access()); } #[test] fn factory_cloudflare_errors() { let cfg = RuntimeConfig { kind: "cloudflare".into(), + ..RuntimeConfig::default() }; match create_runtime(&cfg) { Err(err) => assert!(err.to_string().contains("not implemented")), @@ -65,6 +65,7 @@ mod tests { fn factory_unknown_errors() { let cfg = RuntimeConfig { kind: "wasm-edge-unknown".into(), + ..RuntimeConfig::default() }; match create_runtime(&cfg) { Err(err) => assert!(err.to_string().contains("Unknown runtime kind")), @@ -76,6 +77,7 @@ mod tests { fn factory_empty_errors() { let cfg = RuntimeConfig { kind: String::new(), + ..RuntimeConfig::default() }; match create_runtime(&cfg) { Err(err) => assert!(err.to_string().contains("cannot be empty")), diff --git a/src/runtime/native.rs b/src/runtime/native.rs index 4b0ef3c..927c895 100644 --- a/src/runtime/native.rs +++ b/src/runtime/native.rs @@ -1,5 +1,5 @@ use super::traits::RuntimeAdapter; -use std::path::PathBuf; +use std::path::{Path, PathBuf}; /// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi pub struct NativeRuntime; @@ -33,6 +33,16 @@ impl RuntimeAdapter for NativeRuntime { fn supports_long_running(&self) -> bool { true } + + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result { + let mut process = tokio::process::Command::new("sh"); + process.arg("-c").arg(command).current_dir(workspace_dir); + Ok(process) + } } #[cfg(test)] @@ -69,4 +79,14 @@ mod tests { let path = NativeRuntime::new().storage_path(); assert!(path.to_string_lossy().contains("zeroclaw")); } + + #[test] + fn native_builds_shell_command() { + let cwd = std::env::temp_dir(); + let command = NativeRuntime::new() + .build_shell_command("echo hello", &cwd) + .unwrap(); + let debug = format!("{command:?}"); + assert!(debug.contains("echo hello")); + } } diff --git a/src/runtime/traits.rs b/src/runtime/traits.rs index cbff5b1..743ee5e 100644 --- a/src/runtime/traits.rs +++ b/src/runtime/traits.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::path::{Path, PathBuf}; /// Runtime adapter — abstracts platform differences so the same agent /// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc. @@ -22,4 +22,11 @@ pub trait RuntimeAdapter: Send + Sync { fn memory_budget(&self) -> u64 { 0 } + + /// Build a shell command process for this runtime. + fn build_shell_command( + &self, + command: &str, + workspace_dir: &Path, + ) -> anyhow::Result; } diff --git a/src/security/policy.rs b/src/security/policy.rs index 1dd6963..57e8526 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -16,6 +16,14 @@ pub enum AutonomyLevel { Full, } +/// Risk score for shell command execution. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CommandRiskLevel { + Low, + Medium, + High, +} + /// Sliding-window action tracker for rate limiting. #[derive(Debug)] pub struct ActionTracker { @@ -80,6 +88,8 @@ pub struct SecurityPolicy { pub forbidden_paths: Vec, pub max_actions_per_hour: u32, pub max_cost_per_day_cents: u32, + pub require_approval_for_medium_risk: bool, + pub block_high_risk_commands: bool, pub tracker: ActionTracker, } @@ -127,6 +137,8 @@ impl Default for SecurityPolicy { ], max_actions_per_hour: 20, max_cost_per_day_cents: 500, + require_approval_for_medium_risk: true, + block_high_risk_commands: true, tracker: ActionTracker::new(), } } @@ -156,6 +168,163 @@ fn skip_env_assignments(s: &str) -> &str { } impl SecurityPolicy { + /// Classify command risk. Any high-risk segment marks the whole command high. + pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel { + let mut normalized = command.to_string(); + for sep in ["&&", "||"] { + normalized = normalized.replace(sep, "\x00"); + } + for sep in ['\n', ';', '|'] { + normalized = normalized.replace(sep, "\x00"); + } + + let mut saw_medium = false; + + for segment in normalized.split('\x00') { + let segment = segment.trim(); + if segment.is_empty() { + continue; + } + + let cmd_part = skip_env_assignments(segment); + let mut words = cmd_part.split_whitespace(); + let Some(base_raw) = words.next() else { + continue; + }; + + let base = base_raw + .rsplit('/') + .next() + .unwrap_or("") + .to_ascii_lowercase(); + + let args: Vec = words.map(|w| w.to_ascii_lowercase()).collect(); + let joined_segment = cmd_part.to_ascii_lowercase(); + + // High-risk commands + if matches!( + base.as_str(), + "rm" | "mkfs" + | "dd" + | "shutdown" + | "reboot" + | "halt" + | "poweroff" + | "sudo" + | "su" + | "chown" + | "chmod" + | "useradd" + | "userdel" + | "usermod" + | "passwd" + | "mount" + | "umount" + | "iptables" + | "ufw" + | "firewall-cmd" + | "curl" + | "wget" + | "nc" + | "ncat" + | "netcat" + | "scp" + | "ssh" + | "ftp" + | "telnet" + ) { + return CommandRiskLevel::High; + } + + if joined_segment.contains("rm -rf /") + || joined_segment.contains("rm -fr /") + || joined_segment.contains(":(){:|:&};:") + { + return CommandRiskLevel::High; + } + + // Medium-risk commands (state-changing, but not inherently destructive) + let medium = match base.as_str() { + "git" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "commit" + | "push" + | "reset" + | "clean" + | "rebase" + | "merge" + | "cherry-pick" + | "revert" + | "branch" + | "checkout" + | "switch" + | "tag" + ) + }), + "npm" | "pnpm" | "yarn" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "install" | "add" | "remove" | "uninstall" | "update" | "publish" + ) + }), + "cargo" => args.first().is_some_and(|verb| { + matches!( + verb.as_str(), + "add" | "remove" | "install" | "clean" | "publish" + ) + }), + "touch" | "mkdir" | "mv" | "cp" | "ln" => true, + _ => false, + }; + + saw_medium |= medium; + } + + if saw_medium { + CommandRiskLevel::Medium + } else { + CommandRiskLevel::Low + } + } + + /// Validate full command execution policy (allowlist + risk gate). + pub fn validate_command_execution( + &self, + command: &str, + approved: bool, + ) -> Result { + if !self.is_command_allowed(command) { + return Err(format!("Command not allowed by security policy: {command}")); + } + + let risk = self.command_risk_level(command); + + if risk == CommandRiskLevel::High { + if self.block_high_risk_commands { + return Err("Command blocked: high-risk command is disallowed by policy".into()); + } + if self.autonomy == AutonomyLevel::Supervised && !approved { + return Err( + "Command requires explicit approval (approved=true): high-risk operation" + .into(), + ); + } + } + + if risk == CommandRiskLevel::Medium + && self.autonomy == AutonomyLevel::Supervised + && self.require_approval_for_medium_risk + && !approved + { + return Err( + "Command requires explicit approval (approved=true): medium-risk operation".into(), + ); + } + + Ok(risk) + } + /// Check if a shell command is allowed. /// /// Validates the **entire** command string, not just the first word: @@ -329,6 +498,8 @@ impl SecurityPolicy { forbidden_paths: autonomy_config.forbidden_paths.clone(), max_actions_per_hour: autonomy_config.max_actions_per_hour, max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents, + require_approval_for_medium_risk: autonomy_config.require_approval_for_medium_risk, + block_high_risk_commands: autonomy_config.block_high_risk_commands, tracker: ActionTracker::new(), } } @@ -473,6 +644,71 @@ mod tests { assert!(!p.is_command_allowed("echo hello")); } + #[test] + fn command_risk_low_for_read_commands() { + let p = default_policy(); + assert_eq!(p.command_risk_level("git status"), CommandRiskLevel::Low); + assert_eq!(p.command_risk_level("ls -la"), CommandRiskLevel::Low); + } + + #[test] + fn command_risk_medium_for_mutating_commands() { + let p = SecurityPolicy { + allowed_commands: vec!["git".into(), "touch".into()], + ..SecurityPolicy::default() + }; + assert_eq!( + p.command_risk_level("git reset --hard HEAD~1"), + CommandRiskLevel::Medium + ); + assert_eq!( + p.command_risk_level("touch file.txt"), + CommandRiskLevel::Medium + ); + } + + #[test] + fn command_risk_high_for_dangerous_commands() { + let p = SecurityPolicy { + allowed_commands: vec!["rm".into()], + ..SecurityPolicy::default() + }; + assert_eq!( + p.command_risk_level("rm -rf /tmp/test"), + CommandRiskLevel::High + ); + } + + #[test] + fn validate_command_requires_approval_for_medium_risk() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + require_approval_for_medium_risk: true, + allowed_commands: vec!["touch".into()], + ..SecurityPolicy::default() + }; + + let denied = p.validate_command_execution("touch test.txt", false); + assert!(denied.is_err()); + assert!(denied.unwrap_err().contains("requires explicit approval"),); + + let allowed = p.validate_command_execution("touch test.txt", true); + assert_eq!(allowed.unwrap(), CommandRiskLevel::Medium); + } + + #[test] + fn validate_command_blocks_high_risk_by_default() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + allowed_commands: vec!["rm".into()], + ..SecurityPolicy::default() + }; + + let result = p.validate_command_execution("rm -rf /tmp/test", true); + assert!(result.is_err()); + assert!(result.unwrap_err().contains("high-risk")); + } + // ── is_path_allowed ───────────────────────────────────── #[test] @@ -546,6 +782,8 @@ mod tests { forbidden_paths: vec!["/secret".into()], max_actions_per_hour: 100, max_cost_per_day_cents: 1000, + require_approval_for_medium_risk: false, + block_high_risk_commands: false, }; let workspace = PathBuf::from("/tmp/test-workspace"); let policy = SecurityPolicy::from_config(&autonomy_config, &workspace); @@ -556,6 +794,8 @@ mod tests { assert_eq!(policy.forbidden_paths, vec!["/secret"]); assert_eq!(policy.max_actions_per_hour, 100); assert_eq!(policy.max_cost_per_day_cents, 1000); + assert!(!policy.require_approval_for_medium_risk); + assert!(!policy.block_high_risk_commands); assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace")); } @@ -570,6 +810,8 @@ mod tests { assert!(!p.forbidden_paths.is_empty()); assert!(p.max_actions_per_hour > 0); assert!(p.max_cost_per_day_cents > 0); + assert!(p.require_approval_for_medium_risk); + assert!(p.block_high_risk_commands); } // ── ActionTracker / rate limiting ─────────────────────── @@ -853,6 +1095,8 @@ mod tests { forbidden_paths: vec![], max_actions_per_hour: 10, max_cost_per_day_cents: 100, + require_approval_for_medium_risk: true, + block_high_risk_commands: true, }; let workspace = PathBuf::from("/tmp/test"); let policy = SecurityPolicy::from_config(&autonomy_config, &workspace); diff --git a/src/tools/mod.rs b/src/tools/mod.rs index e02154d..6f9891f 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -23,13 +23,22 @@ pub use traits::Tool; pub use traits::{ToolResult, ToolSpec}; use crate::memory::Memory; +use crate::runtime::{NativeRuntime, RuntimeAdapter}; use crate::security::SecurityPolicy; use std::sync::Arc; /// Create the default tool registry pub fn default_tools(security: Arc) -> Vec> { + default_tools_with_runtime(security, Arc::new(NativeRuntime::new())) +} + +/// Create the default tool registry with explicit runtime adapter. +pub fn default_tools_with_runtime( + security: Arc, + runtime: Arc, +) -> Vec> { vec![ - Box::new(ShellTool::new(security.clone())), + Box::new(ShellTool::new(security.clone(), runtime)), Box::new(FileReadTool::new(security.clone())), Box::new(FileWriteTool::new(security)), ] @@ -41,9 +50,26 @@ pub fn all_tools( memory: Arc, composio_key: Option<&str>, browser_config: &crate::config::BrowserConfig, +) -> Vec> { + all_tools_with_runtime( + security, + Arc::new(NativeRuntime::new()), + memory, + composio_key, + browser_config, + ) +} + +/// Create full tool registry including memory tools and optional Composio. +pub fn all_tools_with_runtime( + security: &Arc, + runtime: Arc, + memory: Arc, + composio_key: Option<&str>, + browser_config: &crate::config::BrowserConfig, ) -> Vec> { let mut tools: Vec> = vec![ - Box::new(ShellTool::new(security.clone())), + Box::new(ShellTool::new(security.clone(), runtime)), Box::new(FileReadTool::new(security.clone())), Box::new(FileWriteTool::new(security.clone())), Box::new(MemoryStoreTool::new(memory.clone())), diff --git a/src/tools/shell.rs b/src/tools/shell.rs index a06558b..662d7ab 100644 --- a/src/tools/shell.rs +++ b/src/tools/shell.rs @@ -1,4 +1,5 @@ use super::traits::{Tool, ToolResult}; +use crate::runtime::RuntimeAdapter; use crate::security::SecurityPolicy; use async_trait::async_trait; use serde_json::json; @@ -18,11 +19,12 @@ const SAFE_ENV_VARS: &[&str] = &[ /// Shell command execution tool with sandboxing pub struct ShellTool { security: Arc, + runtime: Arc, } impl ShellTool { - pub fn new(security: Arc) -> Self { - Self { security } + pub fn new(security: Arc, runtime: Arc) -> Self { + Self { security, runtime } } } @@ -43,6 +45,11 @@ impl Tool for ShellTool { "command": { "type": "string", "description": "The shell command to execute" + }, + "approved": { + "type": "boolean", + "description": "Set true to explicitly approve medium/high-risk commands in supervised mode", + "default": false } }, "required": ["command"] @@ -54,24 +61,55 @@ impl Tool for ShellTool { .get("command") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?; + let approved = args + .get("approved") + .and_then(|v| v.as_bool()) + .unwrap_or(false); - // Security check: validate command against allowlist - if !self.security.is_command_allowed(command) { + if self.security.is_rate_limited() { return Ok(ToolResult { success: false, output: String::new(), - error: Some(format!("Command not allowed by security policy: {command}")), + error: Some("Rate limit exceeded: too many actions in the last hour".into()), + }); + } + + match self.security.validate_command_execution(command, approved) { + Ok(_) => {} + Err(reason) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(reason), + }); + } + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Rate limit exceeded: action budget exhausted".into()), }); } // Execute with timeout to prevent hanging commands. // Clear the environment to prevent leaking API keys and other secrets // (CWE-200), then re-add only safe, functional variables. - let mut cmd = tokio::process::Command::new("sh"); - cmd.arg("-c") - .arg(command) - .current_dir(&self.security.workspace_dir) - .env_clear(); + let mut cmd = match self + .runtime + .build_shell_command(command, &self.security.workspace_dir) + { + Ok(cmd) => cmd, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to build runtime command: {e}")), + }); + } + }; + cmd.env_clear(); for var in SAFE_ENV_VARS { if let Ok(val) = std::env::var(var) { @@ -126,6 +164,7 @@ impl Tool for ShellTool { #[cfg(test)] mod tests { use super::*; + use crate::runtime::{NativeRuntime, RuntimeAdapter}; use crate::security::{AutonomyLevel, SecurityPolicy}; fn test_security(autonomy: AutonomyLevel) -> Arc { @@ -136,32 +175,37 @@ mod tests { }) } + fn test_runtime() -> Arc { + Arc::new(NativeRuntime::new()) + } + #[test] fn shell_tool_name() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); assert_eq!(tool.name(), "shell"); } #[test] fn shell_tool_description() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); assert!(!tool.description().is_empty()); } #[test] fn shell_tool_schema_has_command() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let schema = tool.parameters_schema(); assert!(schema["properties"]["command"].is_object()); assert!(schema["required"] .as_array() .unwrap() .contains(&json!("command"))); + assert!(schema["properties"]["approved"].is_object()); } #[tokio::test] async fn shell_executes_allowed_command() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool .execute(json!({"command": "echo hello"})) .await @@ -173,15 +217,16 @@ mod tests { #[tokio::test] async fn shell_blocks_disallowed_command() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap(); assert!(!result.success); - assert!(result.error.as_ref().unwrap().contains("not allowed")); + let error = result.error.as_deref().unwrap_or(""); + assert!(error.contains("not allowed") || error.contains("high-risk")); } #[tokio::test] async fn shell_blocks_readonly() { - let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly)); + let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly), test_runtime()); let result = tool.execute(json!({"command": "ls"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("not allowed")); @@ -189,7 +234,7 @@ mod tests { #[tokio::test] async fn shell_missing_command_param() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool.execute(json!({})).await; assert!(result.is_err()); assert!(result.unwrap_err().to_string().contains("command")); @@ -197,14 +242,14 @@ mod tests { #[tokio::test] async fn shell_wrong_type_param() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool.execute(json!({"command": 123})).await; assert!(result.is_err()); } #[tokio::test] async fn shell_captures_exit_code() { - let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); + let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); let result = tool .execute(json!({"command": "ls /nonexistent_dir_xyz"})) .await @@ -250,7 +295,7 @@ mod tests { let _g1 = EnvGuard::set("API_KEY", "sk-test-secret-12345"); let _g2 = EnvGuard::set("ZEROCLAW_API_KEY", "sk-test-secret-67890"); - let tool = ShellTool::new(test_security_with_env_cmd()); + let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime()); let result = tool.execute(json!({"command": "env"})).await.unwrap(); assert!(result.success); assert!( @@ -265,7 +310,7 @@ mod tests { #[tokio::test] async fn shell_preserves_path_and_home() { - let tool = ShellTool::new(test_security_with_env_cmd()); + let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime()); let result = tool .execute(json!({"command": "echo $HOME"})) @@ -287,4 +332,37 @@ mod tests { "PATH should be available in shell" ); } + + #[tokio::test] + async fn shell_requires_approval_for_medium_risk_command() { + let security = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::Supervised, + allowed_commands: vec!["touch".into()], + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }); + + let tool = ShellTool::new(security.clone(), test_runtime()); + let denied = tool + .execute(json!({"command": "touch zeroclaw_shell_approval_test"})) + .await + .unwrap(); + assert!(!denied.success); + assert!(denied + .error + .as_deref() + .unwrap_or("") + .contains("explicit approval")); + + let allowed = tool + .execute(json!({ + "command": "touch zeroclaw_shell_approval_test", + "approved": true + })) + .await + .unwrap(); + assert!(allowed.success); + + let _ = std::fs::remove_file(std::env::temp_dir().join("zeroclaw_shell_approval_test")); + } }