feat(config): make config writes atomic with rollback-safe replacement (#190)

* feat(runtime): add Docker runtime MVP and runtime-aware command builder

* feat(security): add shell risk classification, approval gates, and action throttling

* feat(gateway): add per-endpoint rate limiting and webhook idempotency

* feat(config): make config writes atomic with rollback-safe replacement

---------

Co-authored-by: chumyin <chumyin@users.noreply.github.com>
This commit is contained in:
Chummy 2026-02-16 01:18:45 +08:00 committed by GitHub
parent f1e3b1166d
commit b0e1e32819
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
11 changed files with 1202 additions and 67 deletions

View file

@ -40,7 +40,8 @@ pub async fn run(
// ── Wire up agnostic subsystems ────────────────────────────── // ── Wire up agnostic subsystems ──────────────────────────────
let observer: Arc<dyn Observer> = let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability)); Arc::from(observability::create_observer(&config.observability));
let _runtime = runtime::create_runtime(&config.runtime)?; let runtime: Arc<dyn runtime::RuntimeAdapter> =
Arc::from(runtime::create_runtime(&config.runtime)?);
let security = Arc::new(SecurityPolicy::from_config( let security = Arc::new(SecurityPolicy::from_config(
&config.autonomy, &config.autonomy,
&config.workspace_dir, &config.workspace_dir,
@ -60,7 +61,13 @@ pub async fn run(
} else { } else {
None None
}; };
let _tools = tools::all_tools(&security, mem.clone(), composio_key, &config.browser); let _tools = tools::all_tools_with_runtime(
&security,
runtime,
mem.clone(),
composio_key,
&config.browser,
);
// ── Resolve provider ───────────────────────────────────────── // ── Resolve provider ─────────────────────────────────────────
let provider_name = provider_override let provider_name = provider_override

View file

@ -2,7 +2,7 @@ pub mod schema;
pub use schema::{ pub use schema::{
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig, DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig,
ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig,
SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig,
}; };

View file

@ -2,8 +2,9 @@ use crate::security::AutonomyLevel;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use directories::UserDirs; use directories::UserDirs;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use std::fs; use std::fs::{self, File, OpenOptions};
use std::path::PathBuf; use std::io::Write;
use std::path::{Path, PathBuf};
// ── Top-level config ────────────────────────────────────────────── // ── Top-level config ──────────────────────────────────────────────
@ -112,6 +113,18 @@ pub struct GatewayConfig {
/// Paired bearer tokens (managed automatically, not user-edited) /// Paired bearer tokens (managed automatically, not user-edited)
#[serde(default)] #[serde(default)]
pub paired_tokens: Vec<String>, pub paired_tokens: Vec<String>,
/// Max `/pair` requests per minute per client key.
#[serde(default = "default_pair_rate_limit")]
pub pair_rate_limit_per_minute: u32,
/// Max `/webhook` requests per minute per client key.
#[serde(default = "default_webhook_rate_limit")]
pub webhook_rate_limit_per_minute: u32,
/// TTL for webhook idempotency keys.
#[serde(default = "default_idempotency_ttl_secs")]
pub idempotency_ttl_secs: u64,
} }
fn default_gateway_port() -> u16 { fn default_gateway_port() -> u16 {
@ -122,6 +135,18 @@ fn default_gateway_host() -> String {
"127.0.0.1".into() "127.0.0.1".into()
} }
/// Default: at most 10 `/pair` attempts per client key per minute.
fn default_pair_rate_limit() -> u32 {
    10
}

/// Default: at most 60 `/webhook` requests per client key per minute.
fn default_webhook_rate_limit() -> u32 {
    60
}

/// Default TTL (seconds) for remembered webhook idempotency keys.
fn default_idempotency_ttl_secs() -> u64 {
    300
}
fn default_true() -> bool { fn default_true() -> bool {
true true
} }
@ -134,6 +159,9 @@ impl Default for GatewayConfig {
require_pairing: true, require_pairing: true,
allow_public_bind: false, allow_public_bind: false,
paired_tokens: Vec::new(), paired_tokens: Vec::new(),
pair_rate_limit_per_minute: default_pair_rate_limit(),
webhook_rate_limit_per_minute: default_webhook_rate_limit(),
idempotency_ttl_secs: default_idempotency_ttl_secs(),
} }
} }
} }
@ -320,6 +348,14 @@ pub struct AutonomyConfig {
pub forbidden_paths: Vec<String>, pub forbidden_paths: Vec<String>,
pub max_actions_per_hour: u32, pub max_actions_per_hour: u32,
pub max_cost_per_day_cents: u32, pub max_cost_per_day_cents: u32,
/// Require explicit approval for medium-risk shell commands.
#[serde(default = "default_true")]
pub require_approval_for_medium_risk: bool,
/// Block high-risk shell commands even if allowlisted.
#[serde(default = "default_true")]
pub block_high_risk_commands: bool,
} }
impl Default for AutonomyConfig { impl Default for AutonomyConfig {
@ -363,6 +399,8 @@ impl Default for AutonomyConfig {
], ],
max_actions_per_hour: 20, max_actions_per_hour: 20,
max_cost_per_day_cents: 500, max_cost_per_day_cents: 500,
require_approval_for_medium_risk: true,
block_high_risk_commands: true,
} }
} }
} }
@ -371,16 +409,85 @@ impl Default for AutonomyConfig {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig { pub struct RuntimeConfig {
/// Runtime kind (currently supported: "native"). /// Runtime kind (`native` | `docker`).
/// #[serde(default = "default_runtime_kind")]
/// Reserved values (not implemented yet): "docker", "cloudflare".
pub kind: String, pub kind: String,
/// Docker runtime settings (used when `kind = "docker"`).
#[serde(default)]
pub docker: DockerRuntimeConfig,
}
/// Docker-specific runtime settings (used when `runtime.kind = "docker"`).
/// Every field has a serde default, so a bare `[runtime.docker]` table — or
/// omitting it entirely — yields the isolation-first defaults below.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DockerRuntimeConfig {
    /// Runtime image used to execute shell commands.
    #[serde(default = "default_docker_image")]
    pub image: String,

    /// Docker network mode (`none`, `bridge`, etc.).
    #[serde(default = "default_docker_network")]
    pub network: String,

    /// Optional memory limit in MB (`None` = no explicit limit).
    #[serde(default = "default_docker_memory_limit_mb")]
    pub memory_limit_mb: Option<u64>,

    /// Optional CPU limit (`None` = no explicit limit).
    #[serde(default = "default_docker_cpu_limit")]
    pub cpu_limit: Option<f64>,

    /// Mount root filesystem as read-only.
    #[serde(default = "default_true")]
    pub read_only_rootfs: bool,

    /// Mount configured workspace into `/workspace`.
    #[serde(default = "default_true")]
    pub mount_workspace: bool,

    /// Optional workspace root allowlist for Docker mount validation.
    /// Empty means any absolute, non-root workspace path is accepted.
    #[serde(default)]
    pub allowed_workspace_roots: Vec<String>,
}
/// Default runtime backend: execute directly on the host.
fn default_runtime_kind() -> String {
    "native".into()
}

/// Default container image: small Alpine base.
fn default_docker_image() -> String {
    "alpine:3.20".into()
}

/// Default network mode: fully isolated (`none`).
fn default_docker_network() -> String {
    "none".into()
}

/// Default container memory cap: 512 MB.
fn default_docker_memory_limit_mb() -> Option<u64> {
    Some(512)
}

/// Default CPU cap: one core.
fn default_docker_cpu_limit() -> Option<f64> {
    Some(1.0)
}
impl Default for DockerRuntimeConfig {
fn default() -> Self {
Self {
image: default_docker_image(),
network: default_docker_network(),
memory_limit_mb: default_docker_memory_limit_mb(),
cpu_limit: default_docker_cpu_limit(),
read_only_rootfs: true,
mount_workspace: true,
allowed_workspace_roots: Vec::new(),
}
}
} }
impl Default for RuntimeConfig { impl Default for RuntimeConfig {
fn default() -> Self { fn default() -> Self {
Self { Self {
kind: "native".into(), kind: default_runtime_kind(),
docker: DockerRuntimeConfig::default(),
} }
} }
} }
@ -811,11 +918,86 @@ impl Config {
pub fn save(&self) -> Result<()> { pub fn save(&self) -> Result<()> {
let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?; let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?;
fs::write(&self.config_path, toml_str).context("Failed to write config file")?;
let parent_dir = self
.config_path
.parent()
.context("Config path must have a parent directory")?;
fs::create_dir_all(parent_dir).with_context(|| {
format!(
"Failed to create config directory: {}",
parent_dir.display()
)
})?;
let file_name = self
.config_path
.file_name()
.and_then(|v| v.to_str())
.unwrap_or("config.toml");
let temp_path = parent_dir.join(format!(".{file_name}.tmp-{}", uuid::Uuid::new_v4()));
let backup_path = parent_dir.join(format!("{file_name}.bak"));
let mut temp_file = OpenOptions::new()
.create_new(true)
.write(true)
.open(&temp_path)
.with_context(|| {
format!(
"Failed to create temporary config file: {}",
temp_path.display()
)
})?;
temp_file
.write_all(toml_str.as_bytes())
.context("Failed to write temporary config contents")?;
temp_file
.sync_all()
.context("Failed to fsync temporary config file")?;
drop(temp_file);
let had_existing_config = self.config_path.exists();
if had_existing_config {
fs::copy(&self.config_path, &backup_path).with_context(|| {
format!(
"Failed to create config backup before atomic replace: {}",
backup_path.display()
)
})?;
}
if let Err(e) = fs::rename(&temp_path, &self.config_path) {
let _ = fs::remove_file(&temp_path);
if had_existing_config && backup_path.exists() {
let _ = fs::copy(&backup_path, &self.config_path);
}
anyhow::bail!("Failed to atomically replace config file: {e}");
}
sync_directory(parent_dir)?;
if had_existing_config {
let _ = fs::remove_file(&backup_path);
}
Ok(()) Ok(())
} }
} }
/// Fsync directory metadata so a just-completed rename inside `path` is
/// durable across a crash (unix only).
#[cfg(unix)]
fn sync_directory(path: &Path) -> Result<()> {
    let handle = File::open(path)
        .with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?;
    handle
        .sync_all()
        .with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?;
    Ok(())
}

/// No-op on platforms without directory-fsync semantics (e.g. Windows).
#[cfg(not(unix))]
fn sync_directory(_path: &Path) -> Result<()> {
    Ok(())
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -850,12 +1032,20 @@ mod tests {
assert!(a.forbidden_paths.contains(&"/etc".to_string())); assert!(a.forbidden_paths.contains(&"/etc".to_string()));
assert_eq!(a.max_actions_per_hour, 20); assert_eq!(a.max_actions_per_hour, 20);
assert_eq!(a.max_cost_per_day_cents, 500); assert_eq!(a.max_cost_per_day_cents, 500);
assert!(a.require_approval_for_medium_risk);
assert!(a.block_high_risk_commands);
} }
#[test] #[test]
fn runtime_config_default() { fn runtime_config_default() {
let r = RuntimeConfig::default(); let r = RuntimeConfig::default();
assert_eq!(r.kind, "native"); assert_eq!(r.kind, "native");
assert_eq!(r.docker.image, "alpine:3.20");
assert_eq!(r.docker.network, "none");
assert_eq!(r.docker.memory_limit_mb, Some(512));
assert_eq!(r.docker.cpu_limit, Some(1.0));
assert!(r.docker.read_only_rootfs);
assert!(r.docker.mount_workspace);
} }
#[test] #[test]
@ -905,9 +1095,12 @@ mod tests {
forbidden_paths: vec!["/secret".into()], forbidden_paths: vec!["/secret".into()],
max_actions_per_hour: 50, max_actions_per_hour: 50,
max_cost_per_day_cents: 1000, max_cost_per_day_cents: 1000,
require_approval_for_medium_risk: false,
block_high_risk_commands: true,
}, },
runtime: RuntimeConfig { runtime: RuntimeConfig {
kind: "docker".into(), kind: "docker".into(),
..RuntimeConfig::default()
}, },
reliability: ReliabilityConfig::default(), reliability: ReliabilityConfig::default(),
model_routes: Vec::new(), model_routes: Vec::new(),
@ -1022,6 +1215,38 @@ default_temperature = 0.7
let _ = fs::remove_dir_all(&dir); let _ = fs::remove_dir_all(&dir);
} }
#[test]
fn config_save_atomic_cleanup() {
let dir =
std::env::temp_dir().join(format!("zeroclaw_test_config_{}", uuid::Uuid::new_v4()));
fs::create_dir_all(&dir).unwrap();
let config_path = dir.join("config.toml");
let mut config = Config::default();
config.workspace_dir = dir.join("workspace");
config.config_path = config_path.clone();
config.default_model = Some("model-a".into());
config.save().unwrap();
assert!(config_path.exists());
config.default_model = Some("model-b".into());
config.save().unwrap();
let contents = fs::read_to_string(&config_path).unwrap();
assert!(contents.contains("model-b"));
let names: Vec<String> = fs::read_dir(&dir)
.unwrap()
.map(|entry| entry.unwrap().file_name().to_string_lossy().to_string())
.collect();
assert!(!names.iter().any(|name| name.contains(".tmp-")));
assert!(!names.iter().any(|name| name.ends_with(".bak")));
let _ = fs::remove_dir_all(&dir);
}
// ── Telegram / Discord config ──────────────────────────── // ── Telegram / Discord config ────────────────────────────
#[test] #[test]
@ -1343,6 +1568,9 @@ channel_id = "C123"
g.paired_tokens.is_empty(), g.paired_tokens.is_empty(),
"No pre-paired tokens by default" "No pre-paired tokens by default"
); );
assert_eq!(g.pair_rate_limit_per_minute, 10);
assert_eq!(g.webhook_rate_limit_per_minute, 60);
assert_eq!(g.idempotency_ttl_secs, 300);
} }
#[test] #[test]
@ -1368,12 +1596,18 @@ channel_id = "C123"
require_pairing: true, require_pairing: true,
allow_public_bind: false, allow_public_bind: false,
paired_tokens: vec!["zc_test_token".into()], paired_tokens: vec!["zc_test_token".into()],
pair_rate_limit_per_minute: 12,
webhook_rate_limit_per_minute: 80,
idempotency_ttl_secs: 600,
}; };
let toml_str = toml::to_string(&g).unwrap(); let toml_str = toml::to_string(&g).unwrap();
let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap(); let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap();
assert!(parsed.require_pairing); assert!(parsed.require_pairing);
assert!(!parsed.allow_public_bind); assert!(!parsed.allow_public_bind);
assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]); assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]);
assert_eq!(parsed.pair_rate_limit_per_minute, 12);
assert_eq!(parsed.webhook_rate_limit_per_minute, 80);
assert_eq!(parsed.idempotency_ttl_secs, 600);
} }
#[test] #[test]

View file

@ -22,9 +22,10 @@ use axum::{
routing::{get, post}, routing::{get, post},
Router, Router,
}; };
use std::collections::HashMap;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::Arc; use std::sync::{Arc, Mutex};
use std::time::Duration; use std::time::{Duration, Instant};
use tower_http::limit::RequestBodyLimitLayer; use tower_http::limit::RequestBodyLimitLayer;
use tower_http::timeout::TimeoutLayer; use tower_http::timeout::TimeoutLayer;
@ -32,6 +33,118 @@ use tower_http::timeout::TimeoutLayer;
pub const MAX_BODY_SIZE: usize = 65_536; pub const MAX_BODY_SIZE: usize = 65_536;
/// Request timeout (30s) — prevents slow-loris attacks /// Request timeout (30s) — prevents slow-loris attacks
pub const REQUEST_TIMEOUT_SECS: u64 = 30; pub const REQUEST_TIMEOUT_SECS: u64 = 30;
/// Sliding window used by gateway rate limiting.
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
/// Per-key sliding-window rate limiter backed by an in-memory timestamp log.
#[derive(Debug)]
struct SlidingWindowRateLimiter {
    /// Maximum requests allowed per `window`; 0 disables limiting.
    limit_per_window: u32,
    /// Length of the sliding window.
    window: Duration,
    /// Recent request timestamps, keyed by client identity.
    requests: Mutex<HashMap<String, Vec<Instant>>>,
}

impl SlidingWindowRateLimiter {
    fn new(limit_per_window: u32, window: Duration) -> Self {
        Self {
            limit_per_window,
            window,
            requests: Mutex::new(HashMap::new()),
        }
    }

    /// Record a request for `key` and report whether it is within the limit.
    ///
    /// A limit of 0 means "unlimited". Expired timestamps — and any keys left
    /// with no live timestamps — are evicted on every call, so the map stays
    /// bounded even when many distinct (possibly spoofed) client keys appear.
    fn allow(&self, key: &str) -> bool {
        if self.limit_per_window == 0 {
            return true;
        }
        let now = Instant::now();
        // `checked_sub` guards against Instant underflow when the process is
        // younger than the window; an underflow expires everything.
        let cutoff = now.checked_sub(self.window).unwrap_or(now);
        let mut requests = self
            .requests
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Prune expired timestamps for ALL keys and drop empty keys. The
        // previous version only pruned the requested key, so stale keys
        // accumulated forever — an unbounded-memory vector given that client
        // keys come from spoofable headers.
        requests.retain(|_, timestamps| {
            timestamps.retain(|seen| *seen > cutoff);
            !timestamps.is_empty()
        });
        let entry = requests.entry(key.to_owned()).or_default();
        if entry.len() >= self.limit_per_window as usize {
            return false;
        }
        entry.push(now);
        true
    }
}
/// Endpoint-scoped rate limiting for the gateway: independent budgets for
/// `/pair` and `/webhook` over the same sliding window.
#[derive(Debug)]
pub struct GatewayRateLimiter {
    pair: SlidingWindowRateLimiter,
    webhook: SlidingWindowRateLimiter,
}

impl GatewayRateLimiter {
    /// Construct both limiters over the shared `RATE_LIMIT_WINDOW_SECS` window.
    fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self {
        let shared_window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
        Self {
            webhook: SlidingWindowRateLimiter::new(webhook_per_minute, shared_window),
            pair: SlidingWindowRateLimiter::new(pair_per_minute, shared_window),
        }
    }

    /// True when `key` may attempt another `/pair` request.
    fn allow_pair(&self, key: &str) -> bool {
        self.pair.allow(key)
    }

    /// True when `key` may deliver another `/webhook` request.
    fn allow_webhook(&self, key: &str) -> bool {
        self.webhook.allow(key)
    }
}
/// Remembers recently seen webhook idempotency keys for a fixed TTL so
/// duplicate deliveries can be acknowledged without reprocessing.
#[derive(Debug)]
pub struct IdempotencyStore {
    ttl: Duration,
    keys: Mutex<HashMap<String, Instant>>,
}

impl IdempotencyStore {
    fn new(ttl: Duration) -> Self {
        Self {
            ttl,
            keys: Mutex::new(HashMap::new()),
        }
    }

    /// Returns true if this key is new and is now recorded.
    fn record_if_new(&self, key: &str) -> bool {
        let now = Instant::now();
        let mut seen = self
            .keys
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Evict expired entries before consulting the map, so the store's
        // size tracks only live keys.
        seen.retain(|_, recorded_at| now.duration_since(*recorded_at) < self.ttl);
        match seen.get(key) {
            Some(_) => false,
            None => {
                seen.insert(key.to_owned(), now);
                true
            }
        }
    }
}
/// Derive a client identity for rate limiting from proxy-forwarded headers,
/// falling back to `"unknown"` when neither header yields a usable value.
///
/// NOTE(review): `X-Forwarded-For` / `X-Real-IP` are client-controlled unless
/// a trusted reverse proxy overwrites them — confirm the deployment strips
/// spoofed values, otherwise the rate-limit key can be forged.
fn client_key_from_headers(headers: &HeaderMap) -> String {
    ["X-Forwarded-For", "X-Real-IP"]
        .into_iter()
        .filter_map(|name| headers.get(name)?.to_str().ok())
        .filter_map(|value| {
            // X-Forwarded-For may be a comma-separated chain; the first hop
            // is the originating client.
            let candidate = value.split(',').next().unwrap_or("").trim();
            (!candidate.is_empty()).then(|| candidate.to_owned())
        })
        .next()
        .unwrap_or_else(|| "unknown".into())
}
/// Shared state for all axum handlers /// Shared state for all axum handlers
#[derive(Clone)] #[derive(Clone)]
@ -43,6 +156,8 @@ pub struct AppState {
pub auto_save: bool, pub auto_save: bool,
pub webhook_secret: Option<Arc<str>>, pub webhook_secret: Option<Arc<str>>,
pub pairing: Arc<PairingGuard>, pub pairing: Arc<PairingGuard>,
pub rate_limiter: Arc<GatewayRateLimiter>,
pub idempotency_store: Arc<IdempotencyStore>,
pub whatsapp: Option<Arc<WhatsAppChannel>>, pub whatsapp: Option<Arc<WhatsAppChannel>>,
/// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`) /// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`)
pub whatsapp_app_secret: Option<Arc<str>>, pub whatsapp_app_secret: Option<Arc<str>>,
@ -66,17 +181,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let actual_port = listener.local_addr()?.port(); let actual_port = listener.local_addr()?.port();
let display_addr = format!("{host}:{actual_port}"); let display_addr = format!("{host}:{actual_port}");
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
&config.reliability,
)?);
let model = config let model = config
.default_model .default_model
.clone() .clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let provider: Arc<dyn Provider> = Arc::from(providers::create_routed_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
&config.reliability,
&config.model_routes,
&model,
)?);
let temperature = config.default_temperature; let temperature = config.default_temperature;
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory( let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
&config.memory, &config.memory,
@ -127,6 +240,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
config.gateway.require_pairing, config.gateway.require_pairing,
&config.gateway.paired_tokens, &config.gateway.paired_tokens,
)); ));
let rate_limiter = Arc::new(GatewayRateLimiter::new(
config.gateway.pair_rate_limit_per_minute,
config.gateway.webhook_rate_limit_per_minute,
));
let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs(
config.gateway.idempotency_ttl_secs.max(1),
)));
// ── Tunnel ──────────────────────────────────────────────── // ── Tunnel ────────────────────────────────────────────────
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?; let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
@ -185,6 +305,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
auto_save: config.memory.auto_save, auto_save: config.memory.auto_save,
webhook_secret, webhook_secret,
pairing, pairing,
rate_limiter,
idempotency_store,
whatsapp: whatsapp_channel, whatsapp: whatsapp_channel,
whatsapp_app_secret, whatsapp_app_secret,
}; };
@ -225,6 +347,16 @@ async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
/// POST /pair — exchange one-time code for bearer token /// POST /pair — exchange one-time code for bearer token
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse { async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
let client_key = client_key_from_headers(&headers);
if !state.rate_limiter.allow_pair(&client_key) {
tracing::warn!("/pair rate limit exceeded for key: {client_key}");
let err = serde_json::json!({
"error": "Too many pairing requests. Please retry later.",
"retry_after": RATE_LIMIT_WINDOW_SECS,
});
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
}
let code = headers let code = headers
.get("X-Pairing-Code") .get("X-Pairing-Code")
.and_then(|v| v.to_str().ok()) .and_then(|v| v.to_str().ok())
@ -270,6 +402,16 @@ async fn handle_webhook(
headers: HeaderMap, headers: HeaderMap,
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>, body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
) -> impl IntoResponse { ) -> impl IntoResponse {
let client_key = client_key_from_headers(&headers);
if !state.rate_limiter.allow_webhook(&client_key) {
tracing::warn!("/webhook rate limit exceeded for key: {client_key}");
let err = serde_json::json!({
"error": "Too many webhook requests. Please retry later.",
"retry_after": RATE_LIMIT_WINDOW_SECS,
});
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
}
// ── Bearer token auth (pairing) ── // ── Bearer token auth (pairing) ──
if state.pairing.require_pairing() { if state.pairing.require_pairing() {
let auth = headers let auth = headers
@ -312,6 +454,24 @@ async fn handle_webhook(
} }
}; };
// ── Idempotency (optional) ──
if let Some(idempotency_key) = headers
.get("X-Idempotency-Key")
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
{
if !state.idempotency_store.record_if_new(idempotency_key) {
tracing::info!("Webhook duplicate ignored (idempotency key: {idempotency_key})");
let body = serde_json::json!({
"status": "duplicate",
"idempotent": true,
"message": "Request already processed for this idempotency key"
});
return (StatusCode::OK, Json(body));
}
}
let message = &webhook_body.message; let message = &webhook_body.message;
if state.auto_save { if state.auto_save {
@ -508,6 +668,13 @@ async fn handle_whatsapp_message(
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::memory::{Memory, MemoryCategory, MemoryEntry};
use crate::providers::Provider;
use async_trait::async_trait;
use axum::http::HeaderValue;
use axum::response::IntoResponse;
use http_body_util::BodyExt;
use std::sync::atomic::{AtomicUsize, Ordering};
#[test] #[test]
fn security_body_limit_is_64kb() { fn security_body_limit_is_64kb() {
@ -547,6 +714,133 @@ mod tests {
assert_clone::<AppState>(); assert_clone::<AppState>();
} }
#[test]
fn gateway_rate_limiter_blocks_after_limit() {
let limiter = GatewayRateLimiter::new(2, 2);
assert!(limiter.allow_pair("127.0.0.1"));
assert!(limiter.allow_pair("127.0.0.1"));
assert!(!limiter.allow_pair("127.0.0.1"));
}
#[test]
fn idempotency_store_rejects_duplicate_key() {
let store = IdempotencyStore::new(Duration::from_secs(30));
assert!(store.record_if_new("req-1"));
assert!(!store.record_if_new("req-1"));
assert!(store.record_if_new("req-2"));
}
    /// Inert `Memory` implementation for handler tests: stores nothing,
    /// recalls nothing, always reports healthy.
    #[derive(Default)]
    struct MockMemory;

    #[async_trait]
    impl Memory for MockMemory {
        fn name(&self) -> &str {
            "mock"
        }

        // Accept any store request and discard it.
        async fn store(
            &self,
            _key: &str,
            _content: &str,
            _category: MemoryCategory,
        ) -> anyhow::Result<()> {
            Ok(())
        }

        async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
            Ok(None)
        }

        async fn list(
            &self,
            _category: Option<&MemoryCategory>,
        ) -> anyhow::Result<Vec<MemoryEntry>> {
            Ok(Vec::new())
        }

        // Nothing is ever stored, so nothing can be forgotten.
        async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
            Ok(false)
        }

        async fn count(&self) -> anyhow::Result<usize> {
            Ok(0)
        }

        async fn health_check(&self) -> bool {
            true
        }
    }
    /// Provider stub that counts invocations and always replies "ok",
    /// letting tests assert exactly how many times the model was called.
    #[derive(Default)]
    struct MockProvider {
        // Invocation counter read by tests via `calls.load(...)`.
        calls: AtomicUsize,
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            Ok("ok".into())
        }
    }
    /// End-to-end handler check: two webhook deliveries carrying the same
    /// `X-Idempotency-Key` must invoke the provider only once; the second
    /// delivery gets an explicit "duplicate" acknowledgement.
    #[tokio::test]
    async fn webhook_idempotency_skips_duplicate_provider_calls() {
        let provider_impl = Arc::new(MockProvider::default());
        let provider: Arc<dyn Provider> = provider_impl.clone();
        let memory: Arc<dyn Memory> = Arc::new(MockMemory);
        let state = AppState {
            provider,
            model: "test-model".into(),
            temperature: 0.0,
            mem: memory,
            auto_save: false,
            webhook_secret: None,
            // Pairing disabled so no bearer token is required.
            pairing: Arc::new(PairingGuard::new(false, &[])),
            // Generous limits so rate limiting cannot interfere here.
            rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
            idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
            whatsapp: None,
            whatsapp_app_secret: None,
        };

        let mut headers = HeaderMap::new();
        headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));

        // First delivery: processed normally.
        let body = Ok(Json(WebhookBody {
            message: "hello".into(),
        }));
        let first = handle_webhook(State(state.clone()), headers.clone(), body)
            .await
            .into_response();
        assert_eq!(first.status(), StatusCode::OK);

        // Second delivery with the identical key: must be short-circuited.
        let body = Ok(Json(WebhookBody {
            message: "hello".into(),
        }));
        let second = handle_webhook(State(state), headers, body)
            .await
            .into_response();
        assert_eq!(second.status(), StatusCode::OK);

        // Duplicate is acknowledged with a marker payload...
        let payload = second.into_body().collect().await.unwrap().to_bytes();
        let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
        assert_eq!(parsed["status"], "duplicate");
        assert_eq!(parsed["idempotent"], true);
        // ...and the provider ran exactly once.
        assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
    }
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════
// WhatsApp Signature Verification Tests (CWE-345 Prevention) // WhatsApp Signature Verification Tests (CWE-345 Prevention)
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════
@ -572,7 +866,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body); let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
} }
#[test] #[test]
@ -583,7 +881,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(wrong_secret, body); let signature_header = compute_whatsapp_signature_header(wrong_secret, body);
assert!(!verify_whatsapp_signature(app_secret, body, &signature_header)); assert!(!verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
} }
#[test] #[test]
@ -610,7 +912,11 @@ mod tests {
// Signature without "sha256=" prefix // Signature without "sha256=" prefix
let signature_header = "abc123def456"; let signature_header = "abc123def456";
assert!(!verify_whatsapp_signature(app_secret, body, signature_header)); assert!(!verify_whatsapp_signature(
app_secret,
body,
signature_header
));
} }
#[test] #[test]
@ -643,7 +949,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body); let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
} }
#[test] #[test]
@ -653,7 +963,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body); let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
} }
#[test] #[test]
@ -663,7 +977,11 @@ mod tests {
let signature_header = compute_whatsapp_signature_header(app_secret, body); let signature_header = compute_whatsapp_signature_header(app_secret, body);
assert!(verify_whatsapp_signature(app_secret, body, &signature_header)); assert!(verify_whatsapp_signature(
app_secret,
body,
&signature_header
));
} }
#[test] #[test]

199
src/runtime/docker.rs Normal file
View file

@ -0,0 +1,199 @@
use super::traits::RuntimeAdapter;
use crate::config::DockerRuntimeConfig;
use anyhow::{Context, Result};
use std::path::{Path, PathBuf};
/// Docker runtime with lightweight container isolation.
#[derive(Debug, Clone)]
pub struct DockerRuntime {
    config: DockerRuntimeConfig,
}

impl DockerRuntime {
    /// Wrap the docker configuration this runtime will execute against.
    pub fn new(config: DockerRuntimeConfig) -> Self {
        Self { config }
    }

    /// Resolve and validate the host-side workspace directory before it is
    /// bind-mounted into the container.
    ///
    /// # Errors
    /// Fails when the path is relative, is the filesystem root, still contains
    /// `..` segments after a failed canonicalization, or falls outside the
    /// configured `allowed_workspace_roots` allowlist.
    fn workspace_mount_path(&self, workspace_dir: &Path) -> Result<PathBuf> {
        // Canonicalize when possible; fall back to the raw path (the
        // directory may not exist yet at validation time).
        let resolved = workspace_dir
            .canonicalize()
            .unwrap_or_else(|_| workspace_dir.to_path_buf());
        if !resolved.is_absolute() {
            anyhow::bail!(
                "Docker runtime requires an absolute workspace path, got: {}",
                resolved.display()
            );
        }
        // If canonicalize failed, `..` segments survive in `resolved` and
        // could escape an allowlisted root during the prefix check below —
        // reject them outright. (Canonicalized paths never contain `..`.)
        if resolved
            .components()
            .any(|component| matches!(component, std::path::Component::ParentDir))
        {
            anyhow::bail!(
                "Workspace path {} must not contain '..' segments",
                resolved.display()
            );
        }
        if resolved == Path::new("/") {
            anyhow::bail!("Refusing to mount filesystem root (/) into docker runtime");
        }
        // Empty allowlist means any absolute, non-root path is acceptable.
        if self.config.allowed_workspace_roots.is_empty() {
            return Ok(resolved);
        }
        let allowed = self.config.allowed_workspace_roots.iter().any(|root| {
            let root_path = Path::new(root)
                .canonicalize()
                .unwrap_or_else(|_| PathBuf::from(root));
            // Component-wise prefix check (not string prefix).
            resolved.starts_with(root_path)
        });
        if !allowed {
            anyhow::bail!(
                "Workspace path {} is not in runtime.docker.allowed_workspace_roots",
                resolved.display()
            );
        }
        Ok(resolved)
    }
}
impl RuntimeAdapter for DockerRuntime {
    fn name(&self) -> &str {
        "docker"
    }

    fn has_shell_access(&self) -> bool {
        true
    }

    /// Filesystem access is only meaningful when the workspace is mounted.
    fn has_filesystem_access(&self) -> bool {
        self.config.mount_workspace
    }

    /// Agent storage lives inside the workspace mount when available,
    /// otherwise in the container's /tmp.
    fn storage_path(&self) -> PathBuf {
        if self.config.mount_workspace {
            PathBuf::from("/workspace/.zeroclaw")
        } else {
            PathBuf::from("/tmp/.zeroclaw")
        }
    }

    fn supports_long_running(&self) -> bool {
        false
    }

    /// Configured memory cap in bytes; 0 when no limit is set.
    fn memory_budget(&self) -> u64 {
        self.config
            .memory_limit_mb
            .map_or(0, |mb| mb.saturating_mul(1024 * 1024))
    }

    /// Build a `docker run` invocation that executes `command` under `sh -c`
    /// with the configured isolation flags.
    ///
    /// # Errors
    /// Fails when the configured image is empty or when workspace mount
    /// validation rejects `workspace_dir`.
    fn build_shell_command(
        &self,
        command: &str,
        workspace_dir: &Path,
    ) -> anyhow::Result<tokio::process::Command> {
        // Guard first: an empty image would make docker treat the next
        // positional arg ("sh") as the image name and silently misbehave.
        let image = self.config.image.trim();
        if image.is_empty() {
            anyhow::bail!("runtime.docker.image must not be empty");
        }

        let mut process = tokio::process::Command::new("docker");
        process
            .arg("run")
            .arg("--rm")
            .arg("--init")
            .arg("--interactive");

        let network = self.config.network.trim();
        if !network.is_empty() {
            process.arg("--network").arg(network);
        }
        if let Some(memory_limit_mb) = self.config.memory_limit_mb.filter(|mb| *mb > 0) {
            process.arg("--memory").arg(format!("{memory_limit_mb}m"));
        }
        if let Some(cpu_limit) = self.config.cpu_limit.filter(|cpus| *cpus > 0.0) {
            process.arg("--cpus").arg(cpu_limit.to_string());
        }
        if self.config.read_only_rootfs {
            process.arg("--read-only");
        }

        if self.config.mount_workspace {
            let host_workspace = self.workspace_mount_path(workspace_dir).with_context(|| {
                format!(
                    "Failed to validate workspace mount path {}",
                    workspace_dir.display()
                )
            })?;
            process
                .arg("--volume")
                .arg(format!("{}:/workspace:rw", host_workspace.display()))
                .arg("--workdir")
                .arg("/workspace");
        }

        // Command text is passed as a single exec argument — no host shell
        // interpolation happens here.
        process.arg(image).arg("sh").arg("-c").arg(command);
        Ok(process)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Adapter reports its kind as "docker".
    #[test]
    fn docker_runtime_name() {
        let runtime = DockerRuntime::new(DockerRuntimeConfig::default());
        assert_eq!(runtime.name(), "docker");
    }

    /// Memory budget is the configured MB limit converted to bytes.
    #[test]
    fn docker_runtime_memory_budget() {
        // Struct-update syntax instead of mutate-after-default
        // (clippy: field_reassign_with_default).
        let cfg = DockerRuntimeConfig {
            memory_limit_mb: Some(256),
            ..DockerRuntimeConfig::default()
        };
        let runtime = DockerRuntime::new(cfg);
        assert_eq!(runtime.memory_budget(), 256 * 1024 * 1024);
    }

    /// All configured isolation flags surface in the generated docker args.
    #[test]
    fn docker_build_shell_command_includes_runtime_flags() {
        let cfg = DockerRuntimeConfig {
            image: "alpine:3.20".into(),
            network: "none".into(),
            memory_limit_mb: Some(128),
            cpu_limit: Some(1.5),
            read_only_rootfs: true,
            mount_workspace: true,
            allowed_workspace_roots: Vec::new(),
        };
        let runtime = DockerRuntime::new(cfg);
        let workspace = std::env::temp_dir();
        let command = runtime
            .build_shell_command("echo hello", &workspace)
            .unwrap();
        // Inspect the Debug rendering of the Command to verify the arg list.
        let debug = format!("{command:?}");
        assert!(debug.contains("docker"));
        assert!(debug.contains("--memory"));
        assert!(debug.contains("128m"));
        assert!(debug.contains("--cpus"));
        assert!(debug.contains("1.5"));
        assert!(debug.contains("--workdir"));
        assert!(debug.contains("echo hello"));
    }

    /// A workspace outside `allowed_workspace_roots` must be rejected.
    #[test]
    fn docker_workspace_allowlist_blocks_outside_paths() {
        let cfg = DockerRuntimeConfig {
            allowed_workspace_roots: vec!["/tmp/allowed".into()],
            ..DockerRuntimeConfig::default()
        };
        let runtime = DockerRuntime::new(cfg);
        let outside = PathBuf::from("/tmp/blocked_workspace");
        let result = runtime.build_shell_command("echo test", &outside);
        assert!(result.is_err());
    }
}

View file

@ -1,6 +1,8 @@
pub mod docker;
pub mod native; pub mod native;
pub mod traits; pub mod traits;
pub use docker::DockerRuntime;
pub use native::NativeRuntime; pub use native::NativeRuntime;
pub use traits::RuntimeAdapter; pub use traits::RuntimeAdapter;
@ -10,18 +12,14 @@ use crate::config::RuntimeConfig;
pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result<Box<dyn RuntimeAdapter>> { pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result<Box<dyn RuntimeAdapter>> {
match config.kind.as_str() { match config.kind.as_str() {
"native" => Ok(Box::new(NativeRuntime::new())), "native" => Ok(Box::new(NativeRuntime::new())),
"docker" => anyhow::bail!( "docker" => Ok(Box::new(DockerRuntime::new(config.docker.clone()))),
"runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands."
),
"cloudflare" => anyhow::bail!( "cloudflare" => anyhow::bail!(
"runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now." "runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now."
), ),
other if other.trim().is_empty() => anyhow::bail!( other if other.trim().is_empty() => {
"runtime.kind cannot be empty. Supported values: native" anyhow::bail!("runtime.kind cannot be empty. Supported values: native, docker")
), }
other => anyhow::bail!( other => anyhow::bail!("Unknown runtime kind '{other}'. Supported values: native, docker"),
"Unknown runtime kind '{other}'. Supported values: native"
),
} }
} }
@ -33,6 +31,7 @@ mod tests {
fn factory_native() { fn factory_native() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "native".into(), kind: "native".into(),
..RuntimeConfig::default()
}; };
let rt = create_runtime(&cfg).unwrap(); let rt = create_runtime(&cfg).unwrap();
assert_eq!(rt.name(), "native"); assert_eq!(rt.name(), "native");
@ -40,20 +39,21 @@ mod tests {
} }
#[test] #[test]
fn factory_docker_errors() { fn factory_docker() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "docker".into(), kind: "docker".into(),
..RuntimeConfig::default()
}; };
match create_runtime(&cfg) { let rt = create_runtime(&cfg).unwrap();
Err(err) => assert!(err.to_string().contains("not implemented")), assert_eq!(rt.name(), "docker");
Ok(_) => panic!("docker runtime should error"), assert!(rt.has_shell_access());
}
} }
#[test] #[test]
fn factory_cloudflare_errors() { fn factory_cloudflare_errors() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "cloudflare".into(), kind: "cloudflare".into(),
..RuntimeConfig::default()
}; };
match create_runtime(&cfg) { match create_runtime(&cfg) {
Err(err) => assert!(err.to_string().contains("not implemented")), Err(err) => assert!(err.to_string().contains("not implemented")),
@ -65,6 +65,7 @@ mod tests {
fn factory_unknown_errors() { fn factory_unknown_errors() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "wasm-edge-unknown".into(), kind: "wasm-edge-unknown".into(),
..RuntimeConfig::default()
}; };
match create_runtime(&cfg) { match create_runtime(&cfg) {
Err(err) => assert!(err.to_string().contains("Unknown runtime kind")), Err(err) => assert!(err.to_string().contains("Unknown runtime kind")),
@ -76,6 +77,7 @@ mod tests {
fn factory_empty_errors() { fn factory_empty_errors() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: String::new(), kind: String::new(),
..RuntimeConfig::default()
}; };
match create_runtime(&cfg) { match create_runtime(&cfg) {
Err(err) => assert!(err.to_string().contains("cannot be empty")), Err(err) => assert!(err.to_string().contains("cannot be empty")),

View file

@ -1,5 +1,5 @@
use super::traits::RuntimeAdapter; use super::traits::RuntimeAdapter;
use std::path::PathBuf; use std::path::{Path, PathBuf};
/// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi /// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi
pub struct NativeRuntime; pub struct NativeRuntime;
@ -33,6 +33,16 @@ impl RuntimeAdapter for NativeRuntime {
fn supports_long_running(&self) -> bool { fn supports_long_running(&self) -> bool {
true true
} }
fn build_shell_command(
&self,
command: &str,
workspace_dir: &Path,
) -> anyhow::Result<tokio::process::Command> {
let mut process = tokio::process::Command::new("sh");
process.arg("-c").arg(command).current_dir(workspace_dir);
Ok(process)
}
} }
#[cfg(test)] #[cfg(test)]
@ -69,4 +79,14 @@ mod tests {
let path = NativeRuntime::new().storage_path(); let path = NativeRuntime::new().storage_path();
assert!(path.to_string_lossy().contains("zeroclaw")); assert!(path.to_string_lossy().contains("zeroclaw"));
} }
#[test]
fn native_builds_shell_command() {
let cwd = std::env::temp_dir();
let command = NativeRuntime::new()
.build_shell_command("echo hello", &cwd)
.unwrap();
let debug = format!("{command:?}");
assert!(debug.contains("echo hello"));
}
} }

View file

@ -1,4 +1,4 @@
use std::path::PathBuf; use std::path::{Path, PathBuf};
/// Runtime adapter — abstracts platform differences so the same agent /// Runtime adapter — abstracts platform differences so the same agent
/// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc. /// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc.
@ -22,4 +22,11 @@ pub trait RuntimeAdapter: Send + Sync {
fn memory_budget(&self) -> u64 { fn memory_budget(&self) -> u64 {
0 0
} }
/// Build a shell command process for this runtime.
fn build_shell_command(
&self,
command: &str,
workspace_dir: &Path,
) -> anyhow::Result<tokio::process::Command>;
} }

View file

@ -16,6 +16,14 @@ pub enum AutonomyLevel {
Full, Full,
} }
/// Risk score for shell command execution.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommandRiskLevel {
Low,
Medium,
High,
}
/// Sliding-window action tracker for rate limiting. /// Sliding-window action tracker for rate limiting.
#[derive(Debug)] #[derive(Debug)]
pub struct ActionTracker { pub struct ActionTracker {
@ -80,6 +88,8 @@ pub struct SecurityPolicy {
pub forbidden_paths: Vec<String>, pub forbidden_paths: Vec<String>,
pub max_actions_per_hour: u32, pub max_actions_per_hour: u32,
pub max_cost_per_day_cents: u32, pub max_cost_per_day_cents: u32,
pub require_approval_for_medium_risk: bool,
pub block_high_risk_commands: bool,
pub tracker: ActionTracker, pub tracker: ActionTracker,
} }
@ -127,6 +137,8 @@ impl Default for SecurityPolicy {
], ],
max_actions_per_hour: 20, max_actions_per_hour: 20,
max_cost_per_day_cents: 500, max_cost_per_day_cents: 500,
require_approval_for_medium_risk: true,
block_high_risk_commands: true,
tracker: ActionTracker::new(), tracker: ActionTracker::new(),
} }
} }
@ -156,6 +168,163 @@ fn skip_env_assignments(s: &str) -> &str {
} }
impl SecurityPolicy { impl SecurityPolicy {
/// Classify command risk. Any high-risk segment marks the whole command high.
pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel {
let mut normalized = command.to_string();
for sep in ["&&", "||"] {
normalized = normalized.replace(sep, "\x00");
}
for sep in ['\n', ';', '|'] {
normalized = normalized.replace(sep, "\x00");
}
let mut saw_medium = false;
for segment in normalized.split('\x00') {
let segment = segment.trim();
if segment.is_empty() {
continue;
}
let cmd_part = skip_env_assignments(segment);
let mut words = cmd_part.split_whitespace();
let Some(base_raw) = words.next() else {
continue;
};
let base = base_raw
.rsplit('/')
.next()
.unwrap_or("")
.to_ascii_lowercase();
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
let joined_segment = cmd_part.to_ascii_lowercase();
// High-risk commands
if matches!(
base.as_str(),
"rm" | "mkfs"
| "dd"
| "shutdown"
| "reboot"
| "halt"
| "poweroff"
| "sudo"
| "su"
| "chown"
| "chmod"
| "useradd"
| "userdel"
| "usermod"
| "passwd"
| "mount"
| "umount"
| "iptables"
| "ufw"
| "firewall-cmd"
| "curl"
| "wget"
| "nc"
| "ncat"
| "netcat"
| "scp"
| "ssh"
| "ftp"
| "telnet"
) {
return CommandRiskLevel::High;
}
if joined_segment.contains("rm -rf /")
|| joined_segment.contains("rm -fr /")
|| joined_segment.contains(":(){:|:&};:")
{
return CommandRiskLevel::High;
}
// Medium-risk commands (state-changing, but not inherently destructive)
let medium = match base.as_str() {
"git" => args.first().is_some_and(|verb| {
matches!(
verb.as_str(),
"commit"
| "push"
| "reset"
| "clean"
| "rebase"
| "merge"
| "cherry-pick"
| "revert"
| "branch"
| "checkout"
| "switch"
| "tag"
)
}),
"npm" | "pnpm" | "yarn" => args.first().is_some_and(|verb| {
matches!(
verb.as_str(),
"install" | "add" | "remove" | "uninstall" | "update" | "publish"
)
}),
"cargo" => args.first().is_some_and(|verb| {
matches!(
verb.as_str(),
"add" | "remove" | "install" | "clean" | "publish"
)
}),
"touch" | "mkdir" | "mv" | "cp" | "ln" => true,
_ => false,
};
saw_medium |= medium;
}
if saw_medium {
CommandRiskLevel::Medium
} else {
CommandRiskLevel::Low
}
}
/// Validate full command execution policy (allowlist + risk gate).
pub fn validate_command_execution(
&self,
command: &str,
approved: bool,
) -> Result<CommandRiskLevel, String> {
if !self.is_command_allowed(command) {
return Err(format!("Command not allowed by security policy: {command}"));
}
let risk = self.command_risk_level(command);
if risk == CommandRiskLevel::High {
if self.block_high_risk_commands {
return Err("Command blocked: high-risk command is disallowed by policy".into());
}
if self.autonomy == AutonomyLevel::Supervised && !approved {
return Err(
"Command requires explicit approval (approved=true): high-risk operation"
.into(),
);
}
}
if risk == CommandRiskLevel::Medium
&& self.autonomy == AutonomyLevel::Supervised
&& self.require_approval_for_medium_risk
&& !approved
{
return Err(
"Command requires explicit approval (approved=true): medium-risk operation".into(),
);
}
Ok(risk)
}
/// Check if a shell command is allowed. /// Check if a shell command is allowed.
/// ///
/// Validates the **entire** command string, not just the first word: /// Validates the **entire** command string, not just the first word:
@ -329,6 +498,8 @@ impl SecurityPolicy {
forbidden_paths: autonomy_config.forbidden_paths.clone(), forbidden_paths: autonomy_config.forbidden_paths.clone(),
max_actions_per_hour: autonomy_config.max_actions_per_hour, max_actions_per_hour: autonomy_config.max_actions_per_hour,
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents, max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
require_approval_for_medium_risk: autonomy_config.require_approval_for_medium_risk,
block_high_risk_commands: autonomy_config.block_high_risk_commands,
tracker: ActionTracker::new(), tracker: ActionTracker::new(),
} }
} }
@ -473,6 +644,71 @@ mod tests {
assert!(!p.is_command_allowed("echo hello")); assert!(!p.is_command_allowed("echo hello"));
} }
#[test]
fn command_risk_low_for_read_commands() {
let p = default_policy();
assert_eq!(p.command_risk_level("git status"), CommandRiskLevel::Low);
assert_eq!(p.command_risk_level("ls -la"), CommandRiskLevel::Low);
}
#[test]
fn command_risk_medium_for_mutating_commands() {
let p = SecurityPolicy {
allowed_commands: vec!["git".into(), "touch".into()],
..SecurityPolicy::default()
};
assert_eq!(
p.command_risk_level("git reset --hard HEAD~1"),
CommandRiskLevel::Medium
);
assert_eq!(
p.command_risk_level("touch file.txt"),
CommandRiskLevel::Medium
);
}
#[test]
fn command_risk_high_for_dangerous_commands() {
let p = SecurityPolicy {
allowed_commands: vec!["rm".into()],
..SecurityPolicy::default()
};
assert_eq!(
p.command_risk_level("rm -rf /tmp/test"),
CommandRiskLevel::High
);
}
#[test]
fn validate_command_requires_approval_for_medium_risk() {
let p = SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
require_approval_for_medium_risk: true,
allowed_commands: vec!["touch".into()],
..SecurityPolicy::default()
};
let denied = p.validate_command_execution("touch test.txt", false);
assert!(denied.is_err());
assert!(denied.unwrap_err().contains("requires explicit approval"),);
let allowed = p.validate_command_execution("touch test.txt", true);
assert_eq!(allowed.unwrap(), CommandRiskLevel::Medium);
}
#[test]
fn validate_command_blocks_high_risk_by_default() {
let p = SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
allowed_commands: vec!["rm".into()],
..SecurityPolicy::default()
};
let result = p.validate_command_execution("rm -rf /tmp/test", true);
assert!(result.is_err());
assert!(result.unwrap_err().contains("high-risk"));
}
// ── is_path_allowed ───────────────────────────────────── // ── is_path_allowed ─────────────────────────────────────
#[test] #[test]
@ -546,6 +782,8 @@ mod tests {
forbidden_paths: vec!["/secret".into()], forbidden_paths: vec!["/secret".into()],
max_actions_per_hour: 100, max_actions_per_hour: 100,
max_cost_per_day_cents: 1000, max_cost_per_day_cents: 1000,
require_approval_for_medium_risk: false,
block_high_risk_commands: false,
}; };
let workspace = PathBuf::from("/tmp/test-workspace"); let workspace = PathBuf::from("/tmp/test-workspace");
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace); let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);
@ -556,6 +794,8 @@ mod tests {
assert_eq!(policy.forbidden_paths, vec!["/secret"]); assert_eq!(policy.forbidden_paths, vec!["/secret"]);
assert_eq!(policy.max_actions_per_hour, 100); assert_eq!(policy.max_actions_per_hour, 100);
assert_eq!(policy.max_cost_per_day_cents, 1000); assert_eq!(policy.max_cost_per_day_cents, 1000);
assert!(!policy.require_approval_for_medium_risk);
assert!(!policy.block_high_risk_commands);
assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace")); assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace"));
} }
@ -570,6 +810,8 @@ mod tests {
assert!(!p.forbidden_paths.is_empty()); assert!(!p.forbidden_paths.is_empty());
assert!(p.max_actions_per_hour > 0); assert!(p.max_actions_per_hour > 0);
assert!(p.max_cost_per_day_cents > 0); assert!(p.max_cost_per_day_cents > 0);
assert!(p.require_approval_for_medium_risk);
assert!(p.block_high_risk_commands);
} }
// ── ActionTracker / rate limiting ─────────────────────── // ── ActionTracker / rate limiting ───────────────────────
@ -853,6 +1095,8 @@ mod tests {
forbidden_paths: vec![], forbidden_paths: vec![],
max_actions_per_hour: 10, max_actions_per_hour: 10,
max_cost_per_day_cents: 100, max_cost_per_day_cents: 100,
require_approval_for_medium_risk: true,
block_high_risk_commands: true,
}; };
let workspace = PathBuf::from("/tmp/test"); let workspace = PathBuf::from("/tmp/test");
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace); let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);

View file

@ -23,13 +23,22 @@ pub use traits::Tool;
pub use traits::{ToolResult, ToolSpec}; pub use traits::{ToolResult, ToolSpec};
use crate::memory::Memory; use crate::memory::Memory;
use crate::runtime::{NativeRuntime, RuntimeAdapter};
use crate::security::SecurityPolicy; use crate::security::SecurityPolicy;
use std::sync::Arc; use std::sync::Arc;
/// Create the default tool registry /// Create the default tool registry
pub fn default_tools(security: Arc<SecurityPolicy>) -> Vec<Box<dyn Tool>> { pub fn default_tools(security: Arc<SecurityPolicy>) -> Vec<Box<dyn Tool>> {
default_tools_with_runtime(security, Arc::new(NativeRuntime::new()))
}
/// Create the default tool registry with explicit runtime adapter.
pub fn default_tools_with_runtime(
security: Arc<SecurityPolicy>,
runtime: Arc<dyn RuntimeAdapter>,
) -> Vec<Box<dyn Tool>> {
vec![ vec![
Box::new(ShellTool::new(security.clone())), Box::new(ShellTool::new(security.clone(), runtime)),
Box::new(FileReadTool::new(security.clone())), Box::new(FileReadTool::new(security.clone())),
Box::new(FileWriteTool::new(security)), Box::new(FileWriteTool::new(security)),
] ]
@ -41,9 +50,26 @@ pub fn all_tools(
memory: Arc<dyn Memory>, memory: Arc<dyn Memory>,
composio_key: Option<&str>, composio_key: Option<&str>,
browser_config: &crate::config::BrowserConfig, browser_config: &crate::config::BrowserConfig,
) -> Vec<Box<dyn Tool>> {
all_tools_with_runtime(
security,
Arc::new(NativeRuntime::new()),
memory,
composio_key,
browser_config,
)
}
/// Create full tool registry including memory tools and optional Composio.
pub fn all_tools_with_runtime(
security: &Arc<SecurityPolicy>,
runtime: Arc<dyn RuntimeAdapter>,
memory: Arc<dyn Memory>,
composio_key: Option<&str>,
browser_config: &crate::config::BrowserConfig,
) -> Vec<Box<dyn Tool>> { ) -> Vec<Box<dyn Tool>> {
let mut tools: Vec<Box<dyn Tool>> = vec![ let mut tools: Vec<Box<dyn Tool>> = vec![
Box::new(ShellTool::new(security.clone())), Box::new(ShellTool::new(security.clone(), runtime)),
Box::new(FileReadTool::new(security.clone())), Box::new(FileReadTool::new(security.clone())),
Box::new(FileWriteTool::new(security.clone())), Box::new(FileWriteTool::new(security.clone())),
Box::new(MemoryStoreTool::new(memory.clone())), Box::new(MemoryStoreTool::new(memory.clone())),

View file

@ -1,4 +1,5 @@
use super::traits::{Tool, ToolResult}; use super::traits::{Tool, ToolResult};
use crate::runtime::RuntimeAdapter;
use crate::security::SecurityPolicy; use crate::security::SecurityPolicy;
use async_trait::async_trait; use async_trait::async_trait;
use serde_json::json; use serde_json::json;
@ -18,11 +19,12 @@ const SAFE_ENV_VARS: &[&str] = &[
/// Shell command execution tool with sandboxing /// Shell command execution tool with sandboxing
pub struct ShellTool { pub struct ShellTool {
security: Arc<SecurityPolicy>, security: Arc<SecurityPolicy>,
runtime: Arc<dyn RuntimeAdapter>,
} }
impl ShellTool { impl ShellTool {
pub fn new(security: Arc<SecurityPolicy>) -> Self { pub fn new(security: Arc<SecurityPolicy>, runtime: Arc<dyn RuntimeAdapter>) -> Self {
Self { security } Self { security, runtime }
} }
} }
@ -43,6 +45,11 @@ impl Tool for ShellTool {
"command": { "command": {
"type": "string", "type": "string",
"description": "The shell command to execute" "description": "The shell command to execute"
},
"approved": {
"type": "boolean",
"description": "Set true to explicitly approve medium/high-risk commands in supervised mode",
"default": false
} }
}, },
"required": ["command"] "required": ["command"]
@ -54,24 +61,55 @@ impl Tool for ShellTool {
.get("command") .get("command")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?; .ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?;
let approved = args
.get("approved")
.and_then(|v| v.as_bool())
.unwrap_or(false);
// Security check: validate command against allowlist if self.security.is_rate_limited() {
if !self.security.is_command_allowed(command) {
return Ok(ToolResult { return Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(format!("Command not allowed by security policy: {command}")), error: Some("Rate limit exceeded: too many actions in the last hour".into()),
});
}
match self.security.validate_command_execution(command, approved) {
Ok(_) => {}
Err(reason) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(reason),
});
}
}
if !self.security.record_action() {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Rate limit exceeded: action budget exhausted".into()),
}); });
} }
// Execute with timeout to prevent hanging commands. // Execute with timeout to prevent hanging commands.
// Clear the environment to prevent leaking API keys and other secrets // Clear the environment to prevent leaking API keys and other secrets
// (CWE-200), then re-add only safe, functional variables. // (CWE-200), then re-add only safe, functional variables.
let mut cmd = tokio::process::Command::new("sh"); let mut cmd = match self
cmd.arg("-c") .runtime
.arg(command) .build_shell_command(command, &self.security.workspace_dir)
.current_dir(&self.security.workspace_dir) {
.env_clear(); Ok(cmd) => cmd,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to build runtime command: {e}")),
});
}
};
cmd.env_clear();
for var in SAFE_ENV_VARS { for var in SAFE_ENV_VARS {
if let Ok(val) = std::env::var(var) { if let Ok(val) = std::env::var(var) {
@ -126,6 +164,7 @@ impl Tool for ShellTool {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use crate::runtime::{NativeRuntime, RuntimeAdapter};
use crate::security::{AutonomyLevel, SecurityPolicy}; use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> { fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
@ -136,32 +175,37 @@ mod tests {
}) })
} }
fn test_runtime() -> Arc<dyn RuntimeAdapter> {
Arc::new(NativeRuntime::new())
}
#[test] #[test]
fn shell_tool_name() { fn shell_tool_name() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
assert_eq!(tool.name(), "shell"); assert_eq!(tool.name(), "shell");
} }
#[test] #[test]
fn shell_tool_description() { fn shell_tool_description() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
assert!(!tool.description().is_empty()); assert!(!tool.description().is_empty());
} }
#[test] #[test]
fn shell_tool_schema_has_command() { fn shell_tool_schema_has_command() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
let schema = tool.parameters_schema(); let schema = tool.parameters_schema();
assert!(schema["properties"]["command"].is_object()); assert!(schema["properties"]["command"].is_object());
assert!(schema["required"] assert!(schema["required"]
.as_array() .as_array()
.unwrap() .unwrap()
.contains(&json!("command"))); .contains(&json!("command")));
assert!(schema["properties"]["approved"].is_object());
} }
#[tokio::test] #[tokio::test]
async fn shell_executes_allowed_command() { async fn shell_executes_allowed_command() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
let result = tool let result = tool
.execute(json!({"command": "echo hello"})) .execute(json!({"command": "echo hello"}))
.await .await
@ -173,15 +217,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn shell_blocks_disallowed_command() { async fn shell_blocks_disallowed_command() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap(); let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed")); let error = result.error.as_deref().unwrap_or("");
assert!(error.contains("not allowed") || error.contains("high-risk"));
} }
#[tokio::test] #[tokio::test]
async fn shell_blocks_readonly() { async fn shell_blocks_readonly() {
let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly)); let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly), test_runtime());
let result = tool.execute(json!({"command": "ls"})).await.unwrap(); let result = tool.execute(json!({"command": "ls"})).await.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed")); assert!(result.error.as_ref().unwrap().contains("not allowed"));
@ -189,7 +234,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn shell_missing_command_param() { async fn shell_missing_command_param() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
let result = tool.execute(json!({})).await; let result = tool.execute(json!({})).await;
assert!(result.is_err()); assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("command")); assert!(result.unwrap_err().to_string().contains("command"));
@ -197,14 +242,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn shell_wrong_type_param() { async fn shell_wrong_type_param() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
let result = tool.execute(json!({"command": 123})).await; let result = tool.execute(json!({"command": 123})).await;
assert!(result.is_err()); assert!(result.is_err());
} }
#[tokio::test] #[tokio::test]
async fn shell_captures_exit_code() { async fn shell_captures_exit_code() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised)); let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
let result = tool let result = tool
.execute(json!({"command": "ls /nonexistent_dir_xyz"})) .execute(json!({"command": "ls /nonexistent_dir_xyz"}))
.await .await
@ -250,7 +295,7 @@ mod tests {
let _g1 = EnvGuard::set("API_KEY", "sk-test-secret-12345"); let _g1 = EnvGuard::set("API_KEY", "sk-test-secret-12345");
let _g2 = EnvGuard::set("ZEROCLAW_API_KEY", "sk-test-secret-67890"); let _g2 = EnvGuard::set("ZEROCLAW_API_KEY", "sk-test-secret-67890");
let tool = ShellTool::new(test_security_with_env_cmd()); let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime());
let result = tool.execute(json!({"command": "env"})).await.unwrap(); let result = tool.execute(json!({"command": "env"})).await.unwrap();
assert!(result.success); assert!(result.success);
assert!( assert!(
@ -265,7 +310,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn shell_preserves_path_and_home() { async fn shell_preserves_path_and_home() {
let tool = ShellTool::new(test_security_with_env_cmd()); let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime());
let result = tool let result = tool
.execute(json!({"command": "echo $HOME"})) .execute(json!({"command": "echo $HOME"}))
@ -287,4 +332,37 @@ mod tests {
"PATH should be available in shell" "PATH should be available in shell"
); );
} }
#[tokio::test]
async fn shell_requires_approval_for_medium_risk_command() {
let security = Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
allowed_commands: vec!["touch".into()],
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
});
let tool = ShellTool::new(security.clone(), test_runtime());
let denied = tool
.execute(json!({"command": "touch zeroclaw_shell_approval_test"}))
.await
.unwrap();
assert!(!denied.success);
assert!(denied
.error
.as_deref()
.unwrap_or("")
.contains("explicit approval"));
let allowed = tool
.execute(json!({
"command": "touch zeroclaw_shell_approval_test",
"approved": true
}))
.await
.unwrap();
assert!(allowed.success);
let _ = std::fs::remove_file(std::env::temp_dir().join("zeroclaw_shell_approval_test"));
}
} }