feat(config): make config writes atomic with rollback-safe replacement (#190)
* feat(runtime): add Docker runtime MVP and runtime-aware command builder * feat(security): add shell risk classification, approval gates, and action throttling * feat(gateway): add per-endpoint rate limiting and webhook idempotency * feat(config): make config writes atomic with rollback-safe replacement --------- Co-authored-by: chumyin <chumyin@users.noreply.github.com>
This commit is contained in:
parent
f1e3b1166d
commit
b0e1e32819
11 changed files with 1202 additions and 67 deletions
|
|
@ -40,7 +40,8 @@ pub async fn run(
|
||||||
// ── Wire up agnostic subsystems ──────────────────────────────
|
// ── Wire up agnostic subsystems ──────────────────────────────
|
||||||
let observer: Arc<dyn Observer> =
|
let observer: Arc<dyn Observer> =
|
||||||
Arc::from(observability::create_observer(&config.observability));
|
Arc::from(observability::create_observer(&config.observability));
|
||||||
let _runtime = runtime::create_runtime(&config.runtime)?;
|
let runtime: Arc<dyn runtime::RuntimeAdapter> =
|
||||||
|
Arc::from(runtime::create_runtime(&config.runtime)?);
|
||||||
let security = Arc::new(SecurityPolicy::from_config(
|
let security = Arc::new(SecurityPolicy::from_config(
|
||||||
&config.autonomy,
|
&config.autonomy,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
|
|
@ -60,7 +61,13 @@ pub async fn run(
|
||||||
} else {
|
} else {
|
||||||
None
|
None
|
||||||
};
|
};
|
||||||
let _tools = tools::all_tools(&security, mem.clone(), composio_key, &config.browser);
|
let _tools = tools::all_tools_with_runtime(
|
||||||
|
&security,
|
||||||
|
runtime,
|
||||||
|
mem.clone(),
|
||||||
|
composio_key,
|
||||||
|
&config.browser,
|
||||||
|
);
|
||||||
|
|
||||||
// ── Resolve provider ─────────────────────────────────────────
|
// ── Resolve provider ─────────────────────────────────────────
|
||||||
let provider_name = provider_override
|
let provider_name = provider_override
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@ pub mod schema;
|
||||||
|
|
||||||
pub use schema::{
|
pub use schema::{
|
||||||
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
||||||
GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig,
|
DockerRuntimeConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig,
|
||||||
ModelRouteConfig, ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig,
|
MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, ReliabilityConfig,
|
||||||
SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig,
|
RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, WebhookConfig,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,9 @@ use crate::security::AutonomyLevel;
|
||||||
use anyhow::{Context, Result};
|
use anyhow::{Context, Result};
|
||||||
use directories::UserDirs;
|
use directories::UserDirs;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::fs;
|
use std::fs::{self, File, OpenOptions};
|
||||||
use std::path::PathBuf;
|
use std::io::Write;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
// ── Top-level config ──────────────────────────────────────────────
|
// ── Top-level config ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
@ -112,6 +113,18 @@ pub struct GatewayConfig {
|
||||||
/// Paired bearer tokens (managed automatically, not user-edited)
|
/// Paired bearer tokens (managed automatically, not user-edited)
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub paired_tokens: Vec<String>,
|
pub paired_tokens: Vec<String>,
|
||||||
|
|
||||||
|
/// Max `/pair` requests per minute per client key.
|
||||||
|
#[serde(default = "default_pair_rate_limit")]
|
||||||
|
pub pair_rate_limit_per_minute: u32,
|
||||||
|
|
||||||
|
/// Max `/webhook` requests per minute per client key.
|
||||||
|
#[serde(default = "default_webhook_rate_limit")]
|
||||||
|
pub webhook_rate_limit_per_minute: u32,
|
||||||
|
|
||||||
|
/// TTL for webhook idempotency keys.
|
||||||
|
#[serde(default = "default_idempotency_ttl_secs")]
|
||||||
|
pub idempotency_ttl_secs: u64,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_gateway_port() -> u16 {
|
fn default_gateway_port() -> u16 {
|
||||||
|
|
@ -122,6 +135,18 @@ fn default_gateway_host() -> String {
|
||||||
"127.0.0.1".into()
|
"127.0.0.1".into()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn default_pair_rate_limit() -> u32 {
|
||||||
|
10
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_webhook_rate_limit() -> u32 {
|
||||||
|
60
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_idempotency_ttl_secs() -> u64 {
|
||||||
|
300
|
||||||
|
}
|
||||||
|
|
||||||
fn default_true() -> bool {
|
fn default_true() -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
@ -134,6 +159,9 @@ impl Default for GatewayConfig {
|
||||||
require_pairing: true,
|
require_pairing: true,
|
||||||
allow_public_bind: false,
|
allow_public_bind: false,
|
||||||
paired_tokens: Vec::new(),
|
paired_tokens: Vec::new(),
|
||||||
|
pair_rate_limit_per_minute: default_pair_rate_limit(),
|
||||||
|
webhook_rate_limit_per_minute: default_webhook_rate_limit(),
|
||||||
|
idempotency_ttl_secs: default_idempotency_ttl_secs(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -320,6 +348,14 @@ pub struct AutonomyConfig {
|
||||||
pub forbidden_paths: Vec<String>,
|
pub forbidden_paths: Vec<String>,
|
||||||
pub max_actions_per_hour: u32,
|
pub max_actions_per_hour: u32,
|
||||||
pub max_cost_per_day_cents: u32,
|
pub max_cost_per_day_cents: u32,
|
||||||
|
|
||||||
|
/// Require explicit approval for medium-risk shell commands.
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub require_approval_for_medium_risk: bool,
|
||||||
|
|
||||||
|
/// Block high-risk shell commands even if allowlisted.
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub block_high_risk_commands: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for AutonomyConfig {
|
impl Default for AutonomyConfig {
|
||||||
|
|
@ -363,6 +399,8 @@ impl Default for AutonomyConfig {
|
||||||
],
|
],
|
||||||
max_actions_per_hour: 20,
|
max_actions_per_hour: 20,
|
||||||
max_cost_per_day_cents: 500,
|
max_cost_per_day_cents: 500,
|
||||||
|
require_approval_for_medium_risk: true,
|
||||||
|
block_high_risk_commands: true,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -371,16 +409,85 @@ impl Default for AutonomyConfig {
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct RuntimeConfig {
|
pub struct RuntimeConfig {
|
||||||
/// Runtime kind (currently supported: "native").
|
/// Runtime kind (`native` | `docker`).
|
||||||
///
|
#[serde(default = "default_runtime_kind")]
|
||||||
/// Reserved values (not implemented yet): "docker", "cloudflare".
|
|
||||||
pub kind: String,
|
pub kind: String,
|
||||||
|
|
||||||
|
/// Docker runtime settings (used when `kind = "docker"`).
|
||||||
|
#[serde(default)]
|
||||||
|
pub docker: DockerRuntimeConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct DockerRuntimeConfig {
|
||||||
|
/// Runtime image used to execute shell commands.
|
||||||
|
#[serde(default = "default_docker_image")]
|
||||||
|
pub image: String,
|
||||||
|
|
||||||
|
/// Docker network mode (`none`, `bridge`, etc.).
|
||||||
|
#[serde(default = "default_docker_network")]
|
||||||
|
pub network: String,
|
||||||
|
|
||||||
|
/// Optional memory limit in MB (`None` = no explicit limit).
|
||||||
|
#[serde(default = "default_docker_memory_limit_mb")]
|
||||||
|
pub memory_limit_mb: Option<u64>,
|
||||||
|
|
||||||
|
/// Optional CPU limit (`None` = no explicit limit).
|
||||||
|
#[serde(default = "default_docker_cpu_limit")]
|
||||||
|
pub cpu_limit: Option<f64>,
|
||||||
|
|
||||||
|
/// Mount root filesystem as read-only.
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub read_only_rootfs: bool,
|
||||||
|
|
||||||
|
/// Mount configured workspace into `/workspace`.
|
||||||
|
#[serde(default = "default_true")]
|
||||||
|
pub mount_workspace: bool,
|
||||||
|
|
||||||
|
/// Optional workspace root allowlist for Docker mount validation.
|
||||||
|
#[serde(default)]
|
||||||
|
pub allowed_workspace_roots: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_runtime_kind() -> String {
|
||||||
|
"native".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_docker_image() -> String {
|
||||||
|
"alpine:3.20".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_docker_network() -> String {
|
||||||
|
"none".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_docker_memory_limit_mb() -> Option<u64> {
|
||||||
|
Some(512)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_docker_cpu_limit() -> Option<f64> {
|
||||||
|
Some(1.0)
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for DockerRuntimeConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
image: default_docker_image(),
|
||||||
|
network: default_docker_network(),
|
||||||
|
memory_limit_mb: default_docker_memory_limit_mb(),
|
||||||
|
cpu_limit: default_docker_cpu_limit(),
|
||||||
|
read_only_rootfs: true,
|
||||||
|
mount_workspace: true,
|
||||||
|
allowed_workspace_roots: Vec::new(),
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RuntimeConfig {
|
impl Default for RuntimeConfig {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self {
|
Self {
|
||||||
kind: "native".into(),
|
kind: default_runtime_kind(),
|
||||||
|
docker: DockerRuntimeConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -811,11 +918,86 @@ impl Config {
|
||||||
|
|
||||||
pub fn save(&self) -> Result<()> {
|
pub fn save(&self) -> Result<()> {
|
||||||
let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?;
|
let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?;
|
||||||
fs::write(&self.config_path, toml_str).context("Failed to write config file")?;
|
|
||||||
|
let parent_dir = self
|
||||||
|
.config_path
|
||||||
|
.parent()
|
||||||
|
.context("Config path must have a parent directory")?;
|
||||||
|
fs::create_dir_all(parent_dir).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Failed to create config directory: {}",
|
||||||
|
parent_dir.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let file_name = self
|
||||||
|
.config_path
|
||||||
|
.file_name()
|
||||||
|
.and_then(|v| v.to_str())
|
||||||
|
.unwrap_or("config.toml");
|
||||||
|
let temp_path = parent_dir.join(format!(".{file_name}.tmp-{}", uuid::Uuid::new_v4()));
|
||||||
|
let backup_path = parent_dir.join(format!("{file_name}.bak"));
|
||||||
|
|
||||||
|
let mut temp_file = OpenOptions::new()
|
||||||
|
.create_new(true)
|
||||||
|
.write(true)
|
||||||
|
.open(&temp_path)
|
||||||
|
.with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Failed to create temporary config file: {}",
|
||||||
|
temp_path.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
temp_file
|
||||||
|
.write_all(toml_str.as_bytes())
|
||||||
|
.context("Failed to write temporary config contents")?;
|
||||||
|
temp_file
|
||||||
|
.sync_all()
|
||||||
|
.context("Failed to fsync temporary config file")?;
|
||||||
|
drop(temp_file);
|
||||||
|
|
||||||
|
let had_existing_config = self.config_path.exists();
|
||||||
|
if had_existing_config {
|
||||||
|
fs::copy(&self.config_path, &backup_path).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Failed to create config backup before atomic replace: {}",
|
||||||
|
backup_path.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Err(e) = fs::rename(&temp_path, &self.config_path) {
|
||||||
|
let _ = fs::remove_file(&temp_path);
|
||||||
|
if had_existing_config && backup_path.exists() {
|
||||||
|
let _ = fs::copy(&backup_path, &self.config_path);
|
||||||
|
}
|
||||||
|
anyhow::bail!("Failed to atomically replace config file: {e}");
|
||||||
|
}
|
||||||
|
|
||||||
|
sync_directory(parent_dir)?;
|
||||||
|
|
||||||
|
if had_existing_config {
|
||||||
|
let _ = fs::remove_file(&backup_path);
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
fn sync_directory(path: &Path) -> Result<()> {
|
||||||
|
let dir = File::open(path)
|
||||||
|
.with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?;
|
||||||
|
dir.sync_all()
|
||||||
|
.with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
fn sync_directory(_path: &Path) -> Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -850,12 +1032,20 @@ mod tests {
|
||||||
assert!(a.forbidden_paths.contains(&"/etc".to_string()));
|
assert!(a.forbidden_paths.contains(&"/etc".to_string()));
|
||||||
assert_eq!(a.max_actions_per_hour, 20);
|
assert_eq!(a.max_actions_per_hour, 20);
|
||||||
assert_eq!(a.max_cost_per_day_cents, 500);
|
assert_eq!(a.max_cost_per_day_cents, 500);
|
||||||
|
assert!(a.require_approval_for_medium_risk);
|
||||||
|
assert!(a.block_high_risk_commands);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn runtime_config_default() {
|
fn runtime_config_default() {
|
||||||
let r = RuntimeConfig::default();
|
let r = RuntimeConfig::default();
|
||||||
assert_eq!(r.kind, "native");
|
assert_eq!(r.kind, "native");
|
||||||
|
assert_eq!(r.docker.image, "alpine:3.20");
|
||||||
|
assert_eq!(r.docker.network, "none");
|
||||||
|
assert_eq!(r.docker.memory_limit_mb, Some(512));
|
||||||
|
assert_eq!(r.docker.cpu_limit, Some(1.0));
|
||||||
|
assert!(r.docker.read_only_rootfs);
|
||||||
|
assert!(r.docker.mount_workspace);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -905,9 +1095,12 @@ mod tests {
|
||||||
forbidden_paths: vec!["/secret".into()],
|
forbidden_paths: vec!["/secret".into()],
|
||||||
max_actions_per_hour: 50,
|
max_actions_per_hour: 50,
|
||||||
max_cost_per_day_cents: 1000,
|
max_cost_per_day_cents: 1000,
|
||||||
|
require_approval_for_medium_risk: false,
|
||||||
|
block_high_risk_commands: true,
|
||||||
},
|
},
|
||||||
runtime: RuntimeConfig {
|
runtime: RuntimeConfig {
|
||||||
kind: "docker".into(),
|
kind: "docker".into(),
|
||||||
|
..RuntimeConfig::default()
|
||||||
},
|
},
|
||||||
reliability: ReliabilityConfig::default(),
|
reliability: ReliabilityConfig::default(),
|
||||||
model_routes: Vec::new(),
|
model_routes: Vec::new(),
|
||||||
|
|
@ -1022,6 +1215,38 @@ default_temperature = 0.7
|
||||||
let _ = fs::remove_dir_all(&dir);
|
let _ = fs::remove_dir_all(&dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_save_atomic_cleanup() {
|
||||||
|
let dir =
|
||||||
|
std::env::temp_dir().join(format!("zeroclaw_test_config_{}", uuid::Uuid::new_v4()));
|
||||||
|
fs::create_dir_all(&dir).unwrap();
|
||||||
|
|
||||||
|
let config_path = dir.join("config.toml");
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.workspace_dir = dir.join("workspace");
|
||||||
|
config.config_path = config_path.clone();
|
||||||
|
config.default_model = Some("model-a".into());
|
||||||
|
|
||||||
|
config.save().unwrap();
|
||||||
|
assert!(config_path.exists());
|
||||||
|
|
||||||
|
config.default_model = Some("model-b".into());
|
||||||
|
config.save().unwrap();
|
||||||
|
|
||||||
|
let contents = fs::read_to_string(&config_path).unwrap();
|
||||||
|
assert!(contents.contains("model-b"));
|
||||||
|
|
||||||
|
let names: Vec<String> = fs::read_dir(&dir)
|
||||||
|
.unwrap()
|
||||||
|
.map(|entry| entry.unwrap().file_name().to_string_lossy().to_string())
|
||||||
|
.collect();
|
||||||
|
assert!(!names.iter().any(|name| name.contains(".tmp-")));
|
||||||
|
assert!(!names.iter().any(|name| name.ends_with(".bak")));
|
||||||
|
|
||||||
|
let _ = fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
// ── Telegram / Discord config ────────────────────────────
|
// ── Telegram / Discord config ────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -1343,6 +1568,9 @@ channel_id = "C123"
|
||||||
g.paired_tokens.is_empty(),
|
g.paired_tokens.is_empty(),
|
||||||
"No pre-paired tokens by default"
|
"No pre-paired tokens by default"
|
||||||
);
|
);
|
||||||
|
assert_eq!(g.pair_rate_limit_per_minute, 10);
|
||||||
|
assert_eq!(g.webhook_rate_limit_per_minute, 60);
|
||||||
|
assert_eq!(g.idempotency_ttl_secs, 300);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -1368,12 +1596,18 @@ channel_id = "C123"
|
||||||
require_pairing: true,
|
require_pairing: true,
|
||||||
allow_public_bind: false,
|
allow_public_bind: false,
|
||||||
paired_tokens: vec!["zc_test_token".into()],
|
paired_tokens: vec!["zc_test_token".into()],
|
||||||
|
pair_rate_limit_per_minute: 12,
|
||||||
|
webhook_rate_limit_per_minute: 80,
|
||||||
|
idempotency_ttl_secs: 600,
|
||||||
};
|
};
|
||||||
let toml_str = toml::to_string(&g).unwrap();
|
let toml_str = toml::to_string(&g).unwrap();
|
||||||
let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap();
|
let parsed: GatewayConfig = toml::from_str(&toml_str).unwrap();
|
||||||
assert!(parsed.require_pairing);
|
assert!(parsed.require_pairing);
|
||||||
assert!(!parsed.allow_public_bind);
|
assert!(!parsed.allow_public_bind);
|
||||||
assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]);
|
assert_eq!(parsed.paired_tokens, vec!["zc_test_token"]);
|
||||||
|
assert_eq!(parsed.pair_rate_limit_per_minute, 12);
|
||||||
|
assert_eq!(parsed.webhook_rate_limit_per_minute, 80);
|
||||||
|
assert_eq!(parsed.idempotency_ttl_secs, 600);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -22,9 +22,10 @@ use axum::{
|
||||||
routing::{get, post},
|
routing::{get, post},
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::net::SocketAddr;
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, Mutex};
|
||||||
use std::time::Duration;
|
use std::time::{Duration, Instant};
|
||||||
use tower_http::limit::RequestBodyLimitLayer;
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
use tower_http::timeout::TimeoutLayer;
|
use tower_http::timeout::TimeoutLayer;
|
||||||
|
|
||||||
|
|
@ -32,6 +33,118 @@ use tower_http::timeout::TimeoutLayer;
|
||||||
pub const MAX_BODY_SIZE: usize = 65_536;
|
pub const MAX_BODY_SIZE: usize = 65_536;
|
||||||
/// Request timeout (30s) — prevents slow-loris attacks
|
/// Request timeout (30s) — prevents slow-loris attacks
|
||||||
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
|
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
|
||||||
|
/// Sliding window used by gateway rate limiting.
|
||||||
|
pub const RATE_LIMIT_WINDOW_SECS: u64 = 60;
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
struct SlidingWindowRateLimiter {
|
||||||
|
limit_per_window: u32,
|
||||||
|
window: Duration,
|
||||||
|
requests: Mutex<HashMap<String, Vec<Instant>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SlidingWindowRateLimiter {
|
||||||
|
fn new(limit_per_window: u32, window: Duration) -> Self {
|
||||||
|
Self {
|
||||||
|
limit_per_window,
|
||||||
|
window,
|
||||||
|
requests: Mutex::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn allow(&self, key: &str) -> bool {
|
||||||
|
if self.limit_per_window == 0 {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
let now = Instant::now();
|
||||||
|
let cutoff = now.checked_sub(self.window).unwrap_or_else(Instant::now);
|
||||||
|
|
||||||
|
let mut requests = self
|
||||||
|
.requests
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
|
||||||
|
let entry = requests.entry(key.to_owned()).or_default();
|
||||||
|
entry.retain(|instant| *instant > cutoff);
|
||||||
|
|
||||||
|
if entry.len() >= self.limit_per_window as usize {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
entry.push(now);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct GatewayRateLimiter {
|
||||||
|
pair: SlidingWindowRateLimiter,
|
||||||
|
webhook: SlidingWindowRateLimiter,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl GatewayRateLimiter {
|
||||||
|
fn new(pair_per_minute: u32, webhook_per_minute: u32) -> Self {
|
||||||
|
let window = Duration::from_secs(RATE_LIMIT_WINDOW_SECS);
|
||||||
|
Self {
|
||||||
|
pair: SlidingWindowRateLimiter::new(pair_per_minute, window),
|
||||||
|
webhook: SlidingWindowRateLimiter::new(webhook_per_minute, window),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn allow_pair(&self, key: &str) -> bool {
|
||||||
|
self.pair.allow(key)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn allow_webhook(&self, key: &str) -> bool {
|
||||||
|
self.webhook.allow(key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct IdempotencyStore {
|
||||||
|
ttl: Duration,
|
||||||
|
keys: Mutex<HashMap<String, Instant>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IdempotencyStore {
|
||||||
|
fn new(ttl: Duration) -> Self {
|
||||||
|
Self {
|
||||||
|
ttl,
|
||||||
|
keys: Mutex::new(HashMap::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Returns true if this key is new and is now recorded.
|
||||||
|
fn record_if_new(&self, key: &str) -> bool {
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut keys = self
|
||||||
|
.keys
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
|
||||||
|
keys.retain(|_, seen_at| now.duration_since(*seen_at) < self.ttl);
|
||||||
|
|
||||||
|
if keys.contains_key(key) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
keys.insert(key.to_owned(), now);
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn client_key_from_headers(headers: &HeaderMap) -> String {
|
||||||
|
for header_name in ["X-Forwarded-For", "X-Real-IP"] {
|
||||||
|
if let Some(value) = headers.get(header_name).and_then(|v| v.to_str().ok()) {
|
||||||
|
let first = value.split(',').next().unwrap_or("").trim();
|
||||||
|
if !first.is_empty() {
|
||||||
|
return first.to_owned();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"unknown".into()
|
||||||
|
}
|
||||||
|
|
||||||
/// Shared state for all axum handlers
|
/// Shared state for all axum handlers
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
|
@ -43,6 +156,8 @@ pub struct AppState {
|
||||||
pub auto_save: bool,
|
pub auto_save: bool,
|
||||||
pub webhook_secret: Option<Arc<str>>,
|
pub webhook_secret: Option<Arc<str>>,
|
||||||
pub pairing: Arc<PairingGuard>,
|
pub pairing: Arc<PairingGuard>,
|
||||||
|
pub rate_limiter: Arc<GatewayRateLimiter>,
|
||||||
|
pub idempotency_store: Arc<IdempotencyStore>,
|
||||||
pub whatsapp: Option<Arc<WhatsAppChannel>>,
|
pub whatsapp: Option<Arc<WhatsAppChannel>>,
|
||||||
/// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`)
|
/// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`)
|
||||||
pub whatsapp_app_secret: Option<Arc<str>>,
|
pub whatsapp_app_secret: Option<Arc<str>>,
|
||||||
|
|
@ -66,17 +181,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
let actual_port = listener.local_addr()?.port();
|
let actual_port = listener.local_addr()?.port();
|
||||||
let display_addr = format!("{host}:{actual_port}");
|
let display_addr = format!("{host}:{actual_port}");
|
||||||
|
|
||||||
|
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||||
|
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||||
|
config.api_key.as_deref(),
|
||||||
|
&config.reliability,
|
||||||
|
)?);
|
||||||
let model = config
|
let model = config
|
||||||
.default_model
|
.default_model
|
||||||
.clone()
|
.clone()
|
||||||
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
|
||||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_routed_provider(
|
|
||||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
|
||||||
config.api_key.as_deref(),
|
|
||||||
&config.reliability,
|
|
||||||
&config.model_routes,
|
|
||||||
&model,
|
|
||||||
)?);
|
|
||||||
let temperature = config.default_temperature;
|
let temperature = config.default_temperature;
|
||||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||||
&config.memory,
|
&config.memory,
|
||||||
|
|
@ -127,6 +240,13 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
config.gateway.require_pairing,
|
config.gateway.require_pairing,
|
||||||
&config.gateway.paired_tokens,
|
&config.gateway.paired_tokens,
|
||||||
));
|
));
|
||||||
|
let rate_limiter = Arc::new(GatewayRateLimiter::new(
|
||||||
|
config.gateway.pair_rate_limit_per_minute,
|
||||||
|
config.gateway.webhook_rate_limit_per_minute,
|
||||||
|
));
|
||||||
|
let idempotency_store = Arc::new(IdempotencyStore::new(Duration::from_secs(
|
||||||
|
config.gateway.idempotency_ttl_secs.max(1),
|
||||||
|
)));
|
||||||
|
|
||||||
// ── Tunnel ────────────────────────────────────────────────
|
// ── Tunnel ────────────────────────────────────────────────
|
||||||
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
let tunnel = crate::tunnel::create_tunnel(&config.tunnel)?;
|
||||||
|
|
@ -185,6 +305,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
auto_save: config.memory.auto_save,
|
auto_save: config.memory.auto_save,
|
||||||
webhook_secret,
|
webhook_secret,
|
||||||
pairing,
|
pairing,
|
||||||
|
rate_limiter,
|
||||||
|
idempotency_store,
|
||||||
whatsapp: whatsapp_channel,
|
whatsapp: whatsapp_channel,
|
||||||
whatsapp_app_secret,
|
whatsapp_app_secret,
|
||||||
};
|
};
|
||||||
|
|
@ -225,6 +347,16 @@ async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
|
|
||||||
/// POST /pair — exchange one-time code for bearer token
|
/// POST /pair — exchange one-time code for bearer token
|
||||||
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
|
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
|
||||||
|
let client_key = client_key_from_headers(&headers);
|
||||||
|
if !state.rate_limiter.allow_pair(&client_key) {
|
||||||
|
tracing::warn!("/pair rate limit exceeded for key: {client_key}");
|
||||||
|
let err = serde_json::json!({
|
||||||
|
"error": "Too many pairing requests. Please retry later.",
|
||||||
|
"retry_after": RATE_LIMIT_WINDOW_SECS,
|
||||||
|
});
|
||||||
|
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
|
||||||
|
}
|
||||||
|
|
||||||
let code = headers
|
let code = headers
|
||||||
.get("X-Pairing-Code")
|
.get("X-Pairing-Code")
|
||||||
.and_then(|v| v.to_str().ok())
|
.and_then(|v| v.to_str().ok())
|
||||||
|
|
@ -270,6 +402,16 @@ async fn handle_webhook(
|
||||||
headers: HeaderMap,
|
headers: HeaderMap,
|
||||||
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
|
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
|
||||||
) -> impl IntoResponse {
|
) -> impl IntoResponse {
|
||||||
|
let client_key = client_key_from_headers(&headers);
|
||||||
|
if !state.rate_limiter.allow_webhook(&client_key) {
|
||||||
|
tracing::warn!("/webhook rate limit exceeded for key: {client_key}");
|
||||||
|
let err = serde_json::json!({
|
||||||
|
"error": "Too many webhook requests. Please retry later.",
|
||||||
|
"retry_after": RATE_LIMIT_WINDOW_SECS,
|
||||||
|
});
|
||||||
|
return (StatusCode::TOO_MANY_REQUESTS, Json(err));
|
||||||
|
}
|
||||||
|
|
||||||
// ── Bearer token auth (pairing) ──
|
// ── Bearer token auth (pairing) ──
|
||||||
if state.pairing.require_pairing() {
|
if state.pairing.require_pairing() {
|
||||||
let auth = headers
|
let auth = headers
|
||||||
|
|
@ -312,6 +454,24 @@ async fn handle_webhook(
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// ── Idempotency (optional) ──
|
||||||
|
if let Some(idempotency_key) = headers
|
||||||
|
.get("X-Idempotency-Key")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
{
|
||||||
|
if !state.idempotency_store.record_if_new(idempotency_key) {
|
||||||
|
tracing::info!("Webhook duplicate ignored (idempotency key: {idempotency_key})");
|
||||||
|
let body = serde_json::json!({
|
||||||
|
"status": "duplicate",
|
||||||
|
"idempotent": true,
|
||||||
|
"message": "Request already processed for this idempotency key"
|
||||||
|
});
|
||||||
|
return (StatusCode::OK, Json(body));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let message = &webhook_body.message;
|
let message = &webhook_body.message;
|
||||||
|
|
||||||
if state.auto_save {
|
if state.auto_save {
|
||||||
|
|
@ -508,6 +668,13 @@ async fn handle_whatsapp_message(
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::memory::{Memory, MemoryCategory, MemoryEntry};
|
||||||
|
use crate::providers::Provider;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use axum::http::HeaderValue;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use http_body_util::BodyExt;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn security_body_limit_is_64kb() {
|
fn security_body_limit_is_64kb() {
|
||||||
|
|
@ -547,6 +714,133 @@ mod tests {
|
||||||
assert_clone::<AppState>();
|
assert_clone::<AppState>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn gateway_rate_limiter_blocks_after_limit() {
|
||||||
|
let limiter = GatewayRateLimiter::new(2, 2);
|
||||||
|
assert!(limiter.allow_pair("127.0.0.1"));
|
||||||
|
assert!(limiter.allow_pair("127.0.0.1"));
|
||||||
|
assert!(!limiter.allow_pair("127.0.0.1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn idempotency_store_rejects_duplicate_key() {
|
||||||
|
let store = IdempotencyStore::new(Duration::from_secs(30));
|
||||||
|
assert!(store.record_if_new("req-1"));
|
||||||
|
assert!(!store.record_if_new("req-1"));
|
||||||
|
assert!(store.record_if_new("req-2"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct MockMemory;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Memory for MockMemory {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"mock"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn store(
|
||||||
|
&self,
|
||||||
|
_key: &str,
|
||||||
|
_content: &str,
|
||||||
|
_category: MemoryCategory,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
Ok(Vec::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, _key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list(
|
||||||
|
&self,
|
||||||
|
_category: Option<&MemoryCategory>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
Ok(Vec::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
|
Ok(0)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Default)]
|
||||||
|
struct MockProvider {
|
||||||
|
calls: AtomicUsize,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MockProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
Ok("ok".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_idempotency_skips_duplicate_provider_calls() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret: None,
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Idempotency-Key", HeaderValue::from_static("abc-123"));
|
||||||
|
|
||||||
|
let body = Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
}));
|
||||||
|
let first = handle_webhook(State(state.clone()), headers.clone(), body)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
assert_eq!(first.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let body = Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
}));
|
||||||
|
let second = handle_webhook(State(state), headers, body)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
assert_eq!(second.status(), StatusCode::OK);
|
||||||
|
|
||||||
|
let payload = second.into_body().collect().await.unwrap().to_bytes();
|
||||||
|
let parsed: serde_json::Value = serde_json::from_slice(&payload).unwrap();
|
||||||
|
assert_eq!(parsed["status"], "duplicate");
|
||||||
|
assert_eq!(parsed["idempotent"], true);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
@ -572,7 +866,11 @@ mod tests {
|
||||||
|
|
||||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||||
|
|
||||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
assert!(verify_whatsapp_signature(
|
||||||
|
app_secret,
|
||||||
|
body,
|
||||||
|
&signature_header
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -583,7 +881,11 @@ mod tests {
|
||||||
|
|
||||||
let signature_header = compute_whatsapp_signature_header(wrong_secret, body);
|
let signature_header = compute_whatsapp_signature_header(wrong_secret, body);
|
||||||
|
|
||||||
assert!(!verify_whatsapp_signature(app_secret, body, &signature_header));
|
assert!(!verify_whatsapp_signature(
|
||||||
|
app_secret,
|
||||||
|
body,
|
||||||
|
&signature_header
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -610,7 +912,11 @@ mod tests {
|
||||||
// Signature without "sha256=" prefix
|
// Signature without "sha256=" prefix
|
||||||
let signature_header = "abc123def456";
|
let signature_header = "abc123def456";
|
||||||
|
|
||||||
assert!(!verify_whatsapp_signature(app_secret, body, signature_header));
|
assert!(!verify_whatsapp_signature(
|
||||||
|
app_secret,
|
||||||
|
body,
|
||||||
|
signature_header
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -643,7 +949,11 @@ mod tests {
|
||||||
|
|
||||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||||
|
|
||||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
assert!(verify_whatsapp_signature(
|
||||||
|
app_secret,
|
||||||
|
body,
|
||||||
|
&signature_header
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -653,7 +963,11 @@ mod tests {
|
||||||
|
|
||||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||||
|
|
||||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
assert!(verify_whatsapp_signature(
|
||||||
|
app_secret,
|
||||||
|
body,
|
||||||
|
&signature_header
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -663,7 +977,11 @@ mod tests {
|
||||||
|
|
||||||
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
let signature_header = compute_whatsapp_signature_header(app_secret, body);
|
||||||
|
|
||||||
assert!(verify_whatsapp_signature(app_secret, body, &signature_header));
|
assert!(verify_whatsapp_signature(
|
||||||
|
app_secret,
|
||||||
|
body,
|
||||||
|
&signature_header
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
199
src/runtime/docker.rs
Normal file
199
src/runtime/docker.rs
Normal file
|
|
@ -0,0 +1,199 @@
|
||||||
|
use super::traits::RuntimeAdapter;
|
||||||
|
use crate::config::DockerRuntimeConfig;
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
/// Docker runtime with lightweight container isolation.
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct DockerRuntime {
|
||||||
|
config: DockerRuntimeConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DockerRuntime {
|
||||||
|
pub fn new(config: DockerRuntimeConfig) -> Self {
|
||||||
|
Self { config }
|
||||||
|
}
|
||||||
|
|
||||||
|
fn workspace_mount_path(&self, workspace_dir: &Path) -> Result<PathBuf> {
|
||||||
|
let resolved = workspace_dir
|
||||||
|
.canonicalize()
|
||||||
|
.unwrap_or_else(|_| workspace_dir.to_path_buf());
|
||||||
|
|
||||||
|
if !resolved.is_absolute() {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Docker runtime requires an absolute workspace path, got: {}",
|
||||||
|
resolved.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if resolved == Path::new("/") {
|
||||||
|
anyhow::bail!("Refusing to mount filesystem root (/) into docker runtime");
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.config.allowed_workspace_roots.is_empty() {
|
||||||
|
return Ok(resolved);
|
||||||
|
}
|
||||||
|
|
||||||
|
let allowed = self.config.allowed_workspace_roots.iter().any(|root| {
|
||||||
|
let root_path = Path::new(root)
|
||||||
|
.canonicalize()
|
||||||
|
.unwrap_or_else(|_| PathBuf::from(root));
|
||||||
|
resolved.starts_with(root_path)
|
||||||
|
});
|
||||||
|
|
||||||
|
if !allowed {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Workspace path {} is not in runtime.docker.allowed_workspace_roots",
|
||||||
|
resolved.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(resolved)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RuntimeAdapter for DockerRuntime {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"docker"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_shell_access(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
fn has_filesystem_access(&self) -> bool {
|
||||||
|
self.config.mount_workspace
|
||||||
|
}
|
||||||
|
|
||||||
|
fn storage_path(&self) -> PathBuf {
|
||||||
|
if self.config.mount_workspace {
|
||||||
|
PathBuf::from("/workspace/.zeroclaw")
|
||||||
|
} else {
|
||||||
|
PathBuf::from("/tmp/.zeroclaw")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_long_running(&self) -> bool {
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
fn memory_budget(&self) -> u64 {
|
||||||
|
self.config
|
||||||
|
.memory_limit_mb
|
||||||
|
.map_or(0, |mb| mb.saturating_mul(1024 * 1024))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn build_shell_command(
|
||||||
|
&self,
|
||||||
|
command: &str,
|
||||||
|
workspace_dir: &Path,
|
||||||
|
) -> anyhow::Result<tokio::process::Command> {
|
||||||
|
let mut process = tokio::process::Command::new("docker");
|
||||||
|
process
|
||||||
|
.arg("run")
|
||||||
|
.arg("--rm")
|
||||||
|
.arg("--init")
|
||||||
|
.arg("--interactive");
|
||||||
|
|
||||||
|
let network = self.config.network.trim();
|
||||||
|
if !network.is_empty() {
|
||||||
|
process.arg("--network").arg(network);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(memory_limit_mb) = self.config.memory_limit_mb.filter(|mb| *mb > 0) {
|
||||||
|
process.arg("--memory").arg(format!("{memory_limit_mb}m"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(cpu_limit) = self.config.cpu_limit.filter(|cpus| *cpus > 0.0) {
|
||||||
|
process.arg("--cpus").arg(cpu_limit.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.config.read_only_rootfs {
|
||||||
|
process.arg("--read-only");
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.config.mount_workspace {
|
||||||
|
let host_workspace = self.workspace_mount_path(workspace_dir).with_context(|| {
|
||||||
|
format!(
|
||||||
|
"Failed to validate workspace mount path {}",
|
||||||
|
workspace_dir.display()
|
||||||
|
)
|
||||||
|
})?;
|
||||||
|
|
||||||
|
process
|
||||||
|
.arg("--volume")
|
||||||
|
.arg(format!("{}:/workspace:rw", host_workspace.display()))
|
||||||
|
.arg("--workdir")
|
||||||
|
.arg("/workspace");
|
||||||
|
}
|
||||||
|
|
||||||
|
process
|
||||||
|
.arg(self.config.image.trim())
|
||||||
|
.arg("sh")
|
||||||
|
.arg("-c")
|
||||||
|
.arg(command);
|
||||||
|
|
||||||
|
Ok(process)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn docker_runtime_name() {
|
||||||
|
let runtime = DockerRuntime::new(DockerRuntimeConfig::default());
|
||||||
|
assert_eq!(runtime.name(), "docker");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn docker_runtime_memory_budget() {
|
||||||
|
let mut cfg = DockerRuntimeConfig::default();
|
||||||
|
cfg.memory_limit_mb = Some(256);
|
||||||
|
let runtime = DockerRuntime::new(cfg);
|
||||||
|
assert_eq!(runtime.memory_budget(), 256 * 1024 * 1024);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn docker_build_shell_command_includes_runtime_flags() {
|
||||||
|
let cfg = DockerRuntimeConfig {
|
||||||
|
image: "alpine:3.20".into(),
|
||||||
|
network: "none".into(),
|
||||||
|
memory_limit_mb: Some(128),
|
||||||
|
cpu_limit: Some(1.5),
|
||||||
|
read_only_rootfs: true,
|
||||||
|
mount_workspace: true,
|
||||||
|
allowed_workspace_roots: Vec::new(),
|
||||||
|
};
|
||||||
|
let runtime = DockerRuntime::new(cfg);
|
||||||
|
|
||||||
|
let workspace = std::env::temp_dir();
|
||||||
|
let command = runtime
|
||||||
|
.build_shell_command("echo hello", &workspace)
|
||||||
|
.unwrap();
|
||||||
|
let debug = format!("{command:?}");
|
||||||
|
|
||||||
|
assert!(debug.contains("docker"));
|
||||||
|
assert!(debug.contains("--memory"));
|
||||||
|
assert!(debug.contains("128m"));
|
||||||
|
assert!(debug.contains("--cpus"));
|
||||||
|
assert!(debug.contains("1.5"));
|
||||||
|
assert!(debug.contains("--workdir"));
|
||||||
|
assert!(debug.contains("echo hello"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn docker_workspace_allowlist_blocks_outside_paths() {
|
||||||
|
let cfg = DockerRuntimeConfig {
|
||||||
|
allowed_workspace_roots: vec!["/tmp/allowed".into()],
|
||||||
|
..DockerRuntimeConfig::default()
|
||||||
|
};
|
||||||
|
let runtime = DockerRuntime::new(cfg);
|
||||||
|
|
||||||
|
let outside = PathBuf::from("/tmp/blocked_workspace");
|
||||||
|
let result = runtime.build_shell_command("echo test", &outside);
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,6 +1,8 @@
|
||||||
|
pub mod docker;
|
||||||
pub mod native;
|
pub mod native;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
|
pub use docker::DockerRuntime;
|
||||||
pub use native::NativeRuntime;
|
pub use native::NativeRuntime;
|
||||||
pub use traits::RuntimeAdapter;
|
pub use traits::RuntimeAdapter;
|
||||||
|
|
||||||
|
|
@ -10,18 +12,14 @@ use crate::config::RuntimeConfig;
|
||||||
pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result<Box<dyn RuntimeAdapter>> {
|
pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result<Box<dyn RuntimeAdapter>> {
|
||||||
match config.kind.as_str() {
|
match config.kind.as_str() {
|
||||||
"native" => Ok(Box::new(NativeRuntime::new())),
|
"native" => Ok(Box::new(NativeRuntime::new())),
|
||||||
"docker" => anyhow::bail!(
|
"docker" => Ok(Box::new(DockerRuntime::new(config.docker.clone()))),
|
||||||
"runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands."
|
|
||||||
),
|
|
||||||
"cloudflare" => anyhow::bail!(
|
"cloudflare" => anyhow::bail!(
|
||||||
"runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now."
|
"runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now."
|
||||||
),
|
),
|
||||||
other if other.trim().is_empty() => anyhow::bail!(
|
other if other.trim().is_empty() => {
|
||||||
"runtime.kind cannot be empty. Supported values: native"
|
anyhow::bail!("runtime.kind cannot be empty. Supported values: native, docker")
|
||||||
),
|
}
|
||||||
other => anyhow::bail!(
|
other => anyhow::bail!("Unknown runtime kind '{other}'. Supported values: native, docker"),
|
||||||
"Unknown runtime kind '{other}'. Supported values: native"
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -33,6 +31,7 @@ mod tests {
|
||||||
fn factory_native() {
|
fn factory_native() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "native".into(),
|
kind: "native".into(),
|
||||||
|
..RuntimeConfig::default()
|
||||||
};
|
};
|
||||||
let rt = create_runtime(&cfg).unwrap();
|
let rt = create_runtime(&cfg).unwrap();
|
||||||
assert_eq!(rt.name(), "native");
|
assert_eq!(rt.name(), "native");
|
||||||
|
|
@ -40,20 +39,21 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_docker_errors() {
|
fn factory_docker() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "docker".into(),
|
kind: "docker".into(),
|
||||||
|
..RuntimeConfig::default()
|
||||||
};
|
};
|
||||||
match create_runtime(&cfg) {
|
let rt = create_runtime(&cfg).unwrap();
|
||||||
Err(err) => assert!(err.to_string().contains("not implemented")),
|
assert_eq!(rt.name(), "docker");
|
||||||
Ok(_) => panic!("docker runtime should error"),
|
assert!(rt.has_shell_access());
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_cloudflare_errors() {
|
fn factory_cloudflare_errors() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "cloudflare".into(),
|
kind: "cloudflare".into(),
|
||||||
|
..RuntimeConfig::default()
|
||||||
};
|
};
|
||||||
match create_runtime(&cfg) {
|
match create_runtime(&cfg) {
|
||||||
Err(err) => assert!(err.to_string().contains("not implemented")),
|
Err(err) => assert!(err.to_string().contains("not implemented")),
|
||||||
|
|
@ -65,6 +65,7 @@ mod tests {
|
||||||
fn factory_unknown_errors() {
|
fn factory_unknown_errors() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "wasm-edge-unknown".into(),
|
kind: "wasm-edge-unknown".into(),
|
||||||
|
..RuntimeConfig::default()
|
||||||
};
|
};
|
||||||
match create_runtime(&cfg) {
|
match create_runtime(&cfg) {
|
||||||
Err(err) => assert!(err.to_string().contains("Unknown runtime kind")),
|
Err(err) => assert!(err.to_string().contains("Unknown runtime kind")),
|
||||||
|
|
@ -76,6 +77,7 @@ mod tests {
|
||||||
fn factory_empty_errors() {
|
fn factory_empty_errors() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: String::new(),
|
kind: String::new(),
|
||||||
|
..RuntimeConfig::default()
|
||||||
};
|
};
|
||||||
match create_runtime(&cfg) {
|
match create_runtime(&cfg) {
|
||||||
Err(err) => assert!(err.to_string().contains("cannot be empty")),
|
Err(err) => assert!(err.to_string().contains("cannot be empty")),
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,5 @@
|
||||||
use super::traits::RuntimeAdapter;
|
use super::traits::RuntimeAdapter;
|
||||||
use std::path::PathBuf;
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
/// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi
|
/// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi
|
||||||
pub struct NativeRuntime;
|
pub struct NativeRuntime;
|
||||||
|
|
@ -33,6 +33,16 @@ impl RuntimeAdapter for NativeRuntime {
|
||||||
fn supports_long_running(&self) -> bool {
|
fn supports_long_running(&self) -> bool {
|
||||||
true
|
true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn build_shell_command(
|
||||||
|
&self,
|
||||||
|
command: &str,
|
||||||
|
workspace_dir: &Path,
|
||||||
|
) -> anyhow::Result<tokio::process::Command> {
|
||||||
|
let mut process = tokio::process::Command::new("sh");
|
||||||
|
process.arg("-c").arg(command).current_dir(workspace_dir);
|
||||||
|
Ok(process)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -69,4 +79,14 @@ mod tests {
|
||||||
let path = NativeRuntime::new().storage_path();
|
let path = NativeRuntime::new().storage_path();
|
||||||
assert!(path.to_string_lossy().contains("zeroclaw"));
|
assert!(path.to_string_lossy().contains("zeroclaw"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn native_builds_shell_command() {
|
||||||
|
let cwd = std::env::temp_dir();
|
||||||
|
let command = NativeRuntime::new()
|
||||||
|
.build_shell_command("echo hello", &cwd)
|
||||||
|
.unwrap();
|
||||||
|
let debug = format!("{command:?}");
|
||||||
|
assert!(debug.contains("echo hello"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
use std::path::PathBuf;
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
/// Runtime adapter — abstracts platform differences so the same agent
|
/// Runtime adapter — abstracts platform differences so the same agent
|
||||||
/// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc.
|
/// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc.
|
||||||
|
|
@ -22,4 +22,11 @@ pub trait RuntimeAdapter: Send + Sync {
|
||||||
fn memory_budget(&self) -> u64 {
|
fn memory_budget(&self) -> u64 {
|
||||||
0
|
0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build a shell command process for this runtime.
|
||||||
|
fn build_shell_command(
|
||||||
|
&self,
|
||||||
|
command: &str,
|
||||||
|
workspace_dir: &Path,
|
||||||
|
) -> anyhow::Result<tokio::process::Command>;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,6 +16,14 @@ pub enum AutonomyLevel {
|
||||||
Full,
|
Full,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Risk score for shell command execution.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum CommandRiskLevel {
|
||||||
|
Low,
|
||||||
|
Medium,
|
||||||
|
High,
|
||||||
|
}
|
||||||
|
|
||||||
/// Sliding-window action tracker for rate limiting.
|
/// Sliding-window action tracker for rate limiting.
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct ActionTracker {
|
pub struct ActionTracker {
|
||||||
|
|
@ -80,6 +88,8 @@ pub struct SecurityPolicy {
|
||||||
pub forbidden_paths: Vec<String>,
|
pub forbidden_paths: Vec<String>,
|
||||||
pub max_actions_per_hour: u32,
|
pub max_actions_per_hour: u32,
|
||||||
pub max_cost_per_day_cents: u32,
|
pub max_cost_per_day_cents: u32,
|
||||||
|
pub require_approval_for_medium_risk: bool,
|
||||||
|
pub block_high_risk_commands: bool,
|
||||||
pub tracker: ActionTracker,
|
pub tracker: ActionTracker,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -127,6 +137,8 @@ impl Default for SecurityPolicy {
|
||||||
],
|
],
|
||||||
max_actions_per_hour: 20,
|
max_actions_per_hour: 20,
|
||||||
max_cost_per_day_cents: 500,
|
max_cost_per_day_cents: 500,
|
||||||
|
require_approval_for_medium_risk: true,
|
||||||
|
block_high_risk_commands: true,
|
||||||
tracker: ActionTracker::new(),
|
tracker: ActionTracker::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -156,6 +168,163 @@ fn skip_env_assignments(s: &str) -> &str {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SecurityPolicy {
|
impl SecurityPolicy {
|
||||||
|
/// Classify command risk. Any high-risk segment marks the whole command high.
|
||||||
|
pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel {
|
||||||
|
let mut normalized = command.to_string();
|
||||||
|
for sep in ["&&", "||"] {
|
||||||
|
normalized = normalized.replace(sep, "\x00");
|
||||||
|
}
|
||||||
|
for sep in ['\n', ';', '|'] {
|
||||||
|
normalized = normalized.replace(sep, "\x00");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut saw_medium = false;
|
||||||
|
|
||||||
|
for segment in normalized.split('\x00') {
|
||||||
|
let segment = segment.trim();
|
||||||
|
if segment.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let cmd_part = skip_env_assignments(segment);
|
||||||
|
let mut words = cmd_part.split_whitespace();
|
||||||
|
let Some(base_raw) = words.next() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let base = base_raw
|
||||||
|
.rsplit('/')
|
||||||
|
.next()
|
||||||
|
.unwrap_or("")
|
||||||
|
.to_ascii_lowercase();
|
||||||
|
|
||||||
|
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
|
||||||
|
let joined_segment = cmd_part.to_ascii_lowercase();
|
||||||
|
|
||||||
|
// High-risk commands
|
||||||
|
if matches!(
|
||||||
|
base.as_str(),
|
||||||
|
"rm" | "mkfs"
|
||||||
|
| "dd"
|
||||||
|
| "shutdown"
|
||||||
|
| "reboot"
|
||||||
|
| "halt"
|
||||||
|
| "poweroff"
|
||||||
|
| "sudo"
|
||||||
|
| "su"
|
||||||
|
| "chown"
|
||||||
|
| "chmod"
|
||||||
|
| "useradd"
|
||||||
|
| "userdel"
|
||||||
|
| "usermod"
|
||||||
|
| "passwd"
|
||||||
|
| "mount"
|
||||||
|
| "umount"
|
||||||
|
| "iptables"
|
||||||
|
| "ufw"
|
||||||
|
| "firewall-cmd"
|
||||||
|
| "curl"
|
||||||
|
| "wget"
|
||||||
|
| "nc"
|
||||||
|
| "ncat"
|
||||||
|
| "netcat"
|
||||||
|
| "scp"
|
||||||
|
| "ssh"
|
||||||
|
| "ftp"
|
||||||
|
| "telnet"
|
||||||
|
) {
|
||||||
|
return CommandRiskLevel::High;
|
||||||
|
}
|
||||||
|
|
||||||
|
if joined_segment.contains("rm -rf /")
|
||||||
|
|| joined_segment.contains("rm -fr /")
|
||||||
|
|| joined_segment.contains(":(){:|:&};:")
|
||||||
|
{
|
||||||
|
return CommandRiskLevel::High;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Medium-risk commands (state-changing, but not inherently destructive)
|
||||||
|
let medium = match base.as_str() {
|
||||||
|
"git" => args.first().is_some_and(|verb| {
|
||||||
|
matches!(
|
||||||
|
verb.as_str(),
|
||||||
|
"commit"
|
||||||
|
| "push"
|
||||||
|
| "reset"
|
||||||
|
| "clean"
|
||||||
|
| "rebase"
|
||||||
|
| "merge"
|
||||||
|
| "cherry-pick"
|
||||||
|
| "revert"
|
||||||
|
| "branch"
|
||||||
|
| "checkout"
|
||||||
|
| "switch"
|
||||||
|
| "tag"
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
"npm" | "pnpm" | "yarn" => args.first().is_some_and(|verb| {
|
||||||
|
matches!(
|
||||||
|
verb.as_str(),
|
||||||
|
"install" | "add" | "remove" | "uninstall" | "update" | "publish"
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
"cargo" => args.first().is_some_and(|verb| {
|
||||||
|
matches!(
|
||||||
|
verb.as_str(),
|
||||||
|
"add" | "remove" | "install" | "clean" | "publish"
|
||||||
|
)
|
||||||
|
}),
|
||||||
|
"touch" | "mkdir" | "mv" | "cp" | "ln" => true,
|
||||||
|
_ => false,
|
||||||
|
};
|
||||||
|
|
||||||
|
saw_medium |= medium;
|
||||||
|
}
|
||||||
|
|
||||||
|
if saw_medium {
|
||||||
|
CommandRiskLevel::Medium
|
||||||
|
} else {
|
||||||
|
CommandRiskLevel::Low
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate full command execution policy (allowlist + risk gate).
|
||||||
|
pub fn validate_command_execution(
|
||||||
|
&self,
|
||||||
|
command: &str,
|
||||||
|
approved: bool,
|
||||||
|
) -> Result<CommandRiskLevel, String> {
|
||||||
|
if !self.is_command_allowed(command) {
|
||||||
|
return Err(format!("Command not allowed by security policy: {command}"));
|
||||||
|
}
|
||||||
|
|
||||||
|
let risk = self.command_risk_level(command);
|
||||||
|
|
||||||
|
if risk == CommandRiskLevel::High {
|
||||||
|
if self.block_high_risk_commands {
|
||||||
|
return Err("Command blocked: high-risk command is disallowed by policy".into());
|
||||||
|
}
|
||||||
|
if self.autonomy == AutonomyLevel::Supervised && !approved {
|
||||||
|
return Err(
|
||||||
|
"Command requires explicit approval (approved=true): high-risk operation"
|
||||||
|
.into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if risk == CommandRiskLevel::Medium
|
||||||
|
&& self.autonomy == AutonomyLevel::Supervised
|
||||||
|
&& self.require_approval_for_medium_risk
|
||||||
|
&& !approved
|
||||||
|
{
|
||||||
|
return Err(
|
||||||
|
"Command requires explicit approval (approved=true): medium-risk operation".into(),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(risk)
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if a shell command is allowed.
|
/// Check if a shell command is allowed.
|
||||||
///
|
///
|
||||||
/// Validates the **entire** command string, not just the first word:
|
/// Validates the **entire** command string, not just the first word:
|
||||||
|
|
@ -329,6 +498,8 @@ impl SecurityPolicy {
|
||||||
forbidden_paths: autonomy_config.forbidden_paths.clone(),
|
forbidden_paths: autonomy_config.forbidden_paths.clone(),
|
||||||
max_actions_per_hour: autonomy_config.max_actions_per_hour,
|
max_actions_per_hour: autonomy_config.max_actions_per_hour,
|
||||||
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
|
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
|
||||||
|
require_approval_for_medium_risk: autonomy_config.require_approval_for_medium_risk,
|
||||||
|
block_high_risk_commands: autonomy_config.block_high_risk_commands,
|
||||||
tracker: ActionTracker::new(),
|
tracker: ActionTracker::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -473,6 +644,71 @@ mod tests {
|
||||||
assert!(!p.is_command_allowed("echo hello"));
|
assert!(!p.is_command_allowed("echo hello"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn command_risk_low_for_read_commands() {
|
||||||
|
let p = default_policy();
|
||||||
|
assert_eq!(p.command_risk_level("git status"), CommandRiskLevel::Low);
|
||||||
|
assert_eq!(p.command_risk_level("ls -la"), CommandRiskLevel::Low);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn command_risk_medium_for_mutating_commands() {
|
||||||
|
let p = SecurityPolicy {
|
||||||
|
allowed_commands: vec!["git".into(), "touch".into()],
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
p.command_risk_level("git reset --hard HEAD~1"),
|
||||||
|
CommandRiskLevel::Medium
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
p.command_risk_level("touch file.txt"),
|
||||||
|
CommandRiskLevel::Medium
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn command_risk_high_for_dangerous_commands() {
|
||||||
|
let p = SecurityPolicy {
|
||||||
|
allowed_commands: vec!["rm".into()],
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
};
|
||||||
|
assert_eq!(
|
||||||
|
p.command_risk_level("rm -rf /tmp/test"),
|
||||||
|
CommandRiskLevel::High
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_command_requires_approval_for_medium_risk() {
|
||||||
|
let p = SecurityPolicy {
|
||||||
|
autonomy: AutonomyLevel::Supervised,
|
||||||
|
require_approval_for_medium_risk: true,
|
||||||
|
allowed_commands: vec!["touch".into()],
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let denied = p.validate_command_execution("touch test.txt", false);
|
||||||
|
assert!(denied.is_err());
|
||||||
|
assert!(denied.unwrap_err().contains("requires explicit approval"),);
|
||||||
|
|
||||||
|
let allowed = p.validate_command_execution("touch test.txt", true);
|
||||||
|
assert_eq!(allowed.unwrap(), CommandRiskLevel::Medium);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn validate_command_blocks_high_risk_by_default() {
|
||||||
|
let p = SecurityPolicy {
|
||||||
|
autonomy: AutonomyLevel::Supervised,
|
||||||
|
allowed_commands: vec!["rm".into()],
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = p.validate_command_execution("rm -rf /tmp/test", true);
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(result.unwrap_err().contains("high-risk"));
|
||||||
|
}
|
||||||
|
|
||||||
// ── is_path_allowed ─────────────────────────────────────
|
// ── is_path_allowed ─────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -546,6 +782,8 @@ mod tests {
|
||||||
forbidden_paths: vec!["/secret".into()],
|
forbidden_paths: vec!["/secret".into()],
|
||||||
max_actions_per_hour: 100,
|
max_actions_per_hour: 100,
|
||||||
max_cost_per_day_cents: 1000,
|
max_cost_per_day_cents: 1000,
|
||||||
|
require_approval_for_medium_risk: false,
|
||||||
|
block_high_risk_commands: false,
|
||||||
};
|
};
|
||||||
let workspace = PathBuf::from("/tmp/test-workspace");
|
let workspace = PathBuf::from("/tmp/test-workspace");
|
||||||
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);
|
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);
|
||||||
|
|
@ -556,6 +794,8 @@ mod tests {
|
||||||
assert_eq!(policy.forbidden_paths, vec!["/secret"]);
|
assert_eq!(policy.forbidden_paths, vec!["/secret"]);
|
||||||
assert_eq!(policy.max_actions_per_hour, 100);
|
assert_eq!(policy.max_actions_per_hour, 100);
|
||||||
assert_eq!(policy.max_cost_per_day_cents, 1000);
|
assert_eq!(policy.max_cost_per_day_cents, 1000);
|
||||||
|
assert!(!policy.require_approval_for_medium_risk);
|
||||||
|
assert!(!policy.block_high_risk_commands);
|
||||||
assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace"));
|
assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -570,6 +810,8 @@ mod tests {
|
||||||
assert!(!p.forbidden_paths.is_empty());
|
assert!(!p.forbidden_paths.is_empty());
|
||||||
assert!(p.max_actions_per_hour > 0);
|
assert!(p.max_actions_per_hour > 0);
|
||||||
assert!(p.max_cost_per_day_cents > 0);
|
assert!(p.max_cost_per_day_cents > 0);
|
||||||
|
assert!(p.require_approval_for_medium_risk);
|
||||||
|
assert!(p.block_high_risk_commands);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── ActionTracker / rate limiting ───────────────────────
|
// ── ActionTracker / rate limiting ───────────────────────
|
||||||
|
|
@ -853,6 +1095,8 @@ mod tests {
|
||||||
forbidden_paths: vec![],
|
forbidden_paths: vec![],
|
||||||
max_actions_per_hour: 10,
|
max_actions_per_hour: 10,
|
||||||
max_cost_per_day_cents: 100,
|
max_cost_per_day_cents: 100,
|
||||||
|
require_approval_for_medium_risk: true,
|
||||||
|
block_high_risk_commands: true,
|
||||||
};
|
};
|
||||||
let workspace = PathBuf::from("/tmp/test");
|
let workspace = PathBuf::from("/tmp/test");
|
||||||
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);
|
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);
|
||||||
|
|
|
||||||
|
|
@ -23,13 +23,22 @@ pub use traits::Tool;
|
||||||
pub use traits::{ToolResult, ToolSpec};
|
pub use traits::{ToolResult, ToolSpec};
|
||||||
|
|
||||||
use crate::memory::Memory;
|
use crate::memory::Memory;
|
||||||
|
use crate::runtime::{NativeRuntime, RuntimeAdapter};
|
||||||
use crate::security::SecurityPolicy;
|
use crate::security::SecurityPolicy;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Create the default tool registry
|
/// Create the default tool registry
|
||||||
pub fn default_tools(security: Arc<SecurityPolicy>) -> Vec<Box<dyn Tool>> {
|
pub fn default_tools(security: Arc<SecurityPolicy>) -> Vec<Box<dyn Tool>> {
|
||||||
|
default_tools_with_runtime(security, Arc::new(NativeRuntime::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create the default tool registry with explicit runtime adapter.
|
||||||
|
pub fn default_tools_with_runtime(
|
||||||
|
security: Arc<SecurityPolicy>,
|
||||||
|
runtime: Arc<dyn RuntimeAdapter>,
|
||||||
|
) -> Vec<Box<dyn Tool>> {
|
||||||
vec![
|
vec![
|
||||||
Box::new(ShellTool::new(security.clone())),
|
Box::new(ShellTool::new(security.clone(), runtime)),
|
||||||
Box::new(FileReadTool::new(security.clone())),
|
Box::new(FileReadTool::new(security.clone())),
|
||||||
Box::new(FileWriteTool::new(security)),
|
Box::new(FileWriteTool::new(security)),
|
||||||
]
|
]
|
||||||
|
|
@ -41,9 +50,26 @@ pub fn all_tools(
|
||||||
memory: Arc<dyn Memory>,
|
memory: Arc<dyn Memory>,
|
||||||
composio_key: Option<&str>,
|
composio_key: Option<&str>,
|
||||||
browser_config: &crate::config::BrowserConfig,
|
browser_config: &crate::config::BrowserConfig,
|
||||||
|
) -> Vec<Box<dyn Tool>> {
|
||||||
|
all_tools_with_runtime(
|
||||||
|
security,
|
||||||
|
Arc::new(NativeRuntime::new()),
|
||||||
|
memory,
|
||||||
|
composio_key,
|
||||||
|
browser_config,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Create full tool registry including memory tools and optional Composio.
|
||||||
|
pub fn all_tools_with_runtime(
|
||||||
|
security: &Arc<SecurityPolicy>,
|
||||||
|
runtime: Arc<dyn RuntimeAdapter>,
|
||||||
|
memory: Arc<dyn Memory>,
|
||||||
|
composio_key: Option<&str>,
|
||||||
|
browser_config: &crate::config::BrowserConfig,
|
||||||
) -> Vec<Box<dyn Tool>> {
|
) -> Vec<Box<dyn Tool>> {
|
||||||
let mut tools: Vec<Box<dyn Tool>> = vec![
|
let mut tools: Vec<Box<dyn Tool>> = vec![
|
||||||
Box::new(ShellTool::new(security.clone())),
|
Box::new(ShellTool::new(security.clone(), runtime)),
|
||||||
Box::new(FileReadTool::new(security.clone())),
|
Box::new(FileReadTool::new(security.clone())),
|
||||||
Box::new(FileWriteTool::new(security.clone())),
|
Box::new(FileWriteTool::new(security.clone())),
|
||||||
Box::new(MemoryStoreTool::new(memory.clone())),
|
Box::new(MemoryStoreTool::new(memory.clone())),
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
use super::traits::{Tool, ToolResult};
|
use super::traits::{Tool, ToolResult};
|
||||||
|
use crate::runtime::RuntimeAdapter;
|
||||||
use crate::security::SecurityPolicy;
|
use crate::security::SecurityPolicy;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
|
|
@ -18,11 +19,12 @@ const SAFE_ENV_VARS: &[&str] = &[
|
||||||
/// Shell command execution tool with sandboxing
|
/// Shell command execution tool with sandboxing
|
||||||
pub struct ShellTool {
|
pub struct ShellTool {
|
||||||
security: Arc<SecurityPolicy>,
|
security: Arc<SecurityPolicy>,
|
||||||
|
runtime: Arc<dyn RuntimeAdapter>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl ShellTool {
|
impl ShellTool {
|
||||||
pub fn new(security: Arc<SecurityPolicy>) -> Self {
|
pub fn new(security: Arc<SecurityPolicy>, runtime: Arc<dyn RuntimeAdapter>) -> Self {
|
||||||
Self { security }
|
Self { security, runtime }
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -43,6 +45,11 @@ impl Tool for ShellTool {
|
||||||
"command": {
|
"command": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
"description": "The shell command to execute"
|
"description": "The shell command to execute"
|
||||||
|
},
|
||||||
|
"approved": {
|
||||||
|
"type": "boolean",
|
||||||
|
"description": "Set true to explicitly approve medium/high-risk commands in supervised mode",
|
||||||
|
"default": false
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"required": ["command"]
|
"required": ["command"]
|
||||||
|
|
@ -54,24 +61,55 @@ impl Tool for ShellTool {
|
||||||
.get("command")
|
.get("command")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?;
|
.ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?;
|
||||||
|
let approved = args
|
||||||
|
.get("approved")
|
||||||
|
.and_then(|v| v.as_bool())
|
||||||
|
.unwrap_or(false);
|
||||||
|
|
||||||
// Security check: validate command against allowlist
|
if self.security.is_rate_limited() {
|
||||||
if !self.security.is_command_allowed(command) {
|
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(format!("Command not allowed by security policy: {command}")),
|
error: Some("Rate limit exceeded: too many actions in the last hour".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.security.validate_command_execution(command, approved) {
|
||||||
|
Ok(_) => {}
|
||||||
|
Err(reason) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(reason),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.security.record_action() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Rate limit exceeded: action budget exhausted".into()),
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
// Execute with timeout to prevent hanging commands.
|
// Execute with timeout to prevent hanging commands.
|
||||||
// Clear the environment to prevent leaking API keys and other secrets
|
// Clear the environment to prevent leaking API keys and other secrets
|
||||||
// (CWE-200), then re-add only safe, functional variables.
|
// (CWE-200), then re-add only safe, functional variables.
|
||||||
let mut cmd = tokio::process::Command::new("sh");
|
let mut cmd = match self
|
||||||
cmd.arg("-c")
|
.runtime
|
||||||
.arg(command)
|
.build_shell_command(command, &self.security.workspace_dir)
|
||||||
.current_dir(&self.security.workspace_dir)
|
{
|
||||||
.env_clear();
|
Ok(cmd) => cmd,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to build runtime command: {e}")),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
cmd.env_clear();
|
||||||
|
|
||||||
for var in SAFE_ENV_VARS {
|
for var in SAFE_ENV_VARS {
|
||||||
if let Ok(val) = std::env::var(var) {
|
if let Ok(val) = std::env::var(var) {
|
||||||
|
|
@ -126,6 +164,7 @@ impl Tool for ShellTool {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use crate::runtime::{NativeRuntime, RuntimeAdapter};
|
||||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||||
|
|
||||||
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
|
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
|
||||||
|
|
@ -136,32 +175,37 @@ mod tests {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn test_runtime() -> Arc<dyn RuntimeAdapter> {
|
||||||
|
Arc::new(NativeRuntime::new())
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn shell_tool_name() {
|
fn shell_tool_name() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
assert_eq!(tool.name(), "shell");
|
assert_eq!(tool.name(), "shell");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn shell_tool_description() {
|
fn shell_tool_description() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
assert!(!tool.description().is_empty());
|
assert!(!tool.description().is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn shell_tool_schema_has_command() {
|
fn shell_tool_schema_has_command() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
let schema = tool.parameters_schema();
|
let schema = tool.parameters_schema();
|
||||||
assert!(schema["properties"]["command"].is_object());
|
assert!(schema["properties"]["command"].is_object());
|
||||||
assert!(schema["required"]
|
assert!(schema["required"]
|
||||||
.as_array()
|
.as_array()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.contains(&json!("command")));
|
.contains(&json!("command")));
|
||||||
|
assert!(schema["properties"]["approved"].is_object());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shell_executes_allowed_command() {
|
async fn shell_executes_allowed_command() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({"command": "echo hello"}))
|
.execute(json!({"command": "echo hello"}))
|
||||||
.await
|
.await
|
||||||
|
|
@ -173,15 +217,16 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shell_blocks_disallowed_command() {
|
async fn shell_blocks_disallowed_command() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap();
|
let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap();
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
let error = result.error.as_deref().unwrap_or("");
|
||||||
|
assert!(error.contains("not allowed") || error.contains("high-risk"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shell_blocks_readonly() {
|
async fn shell_blocks_readonly() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly));
|
let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly), test_runtime());
|
||||||
let result = tool.execute(json!({"command": "ls"})).await.unwrap();
|
let result = tool.execute(json!({"command": "ls"})).await.unwrap();
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
assert!(result.error.as_ref().unwrap().contains("not allowed"));
|
||||||
|
|
@ -189,7 +234,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shell_missing_command_param() {
|
async fn shell_missing_command_param() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
let result = tool.execute(json!({})).await;
|
let result = tool.execute(json!({})).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
assert!(result.unwrap_err().to_string().contains("command"));
|
assert!(result.unwrap_err().to_string().contains("command"));
|
||||||
|
|
@ -197,14 +242,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shell_wrong_type_param() {
|
async fn shell_wrong_type_param() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
let result = tool.execute(json!({"command": 123})).await;
|
let result = tool.execute(json!({"command": 123})).await;
|
||||||
assert!(result.is_err());
|
assert!(result.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shell_captures_exit_code() {
|
async fn shell_captures_exit_code() {
|
||||||
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
|
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime());
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({"command": "ls /nonexistent_dir_xyz"}))
|
.execute(json!({"command": "ls /nonexistent_dir_xyz"}))
|
||||||
.await
|
.await
|
||||||
|
|
@ -250,7 +295,7 @@ mod tests {
|
||||||
let _g1 = EnvGuard::set("API_KEY", "sk-test-secret-12345");
|
let _g1 = EnvGuard::set("API_KEY", "sk-test-secret-12345");
|
||||||
let _g2 = EnvGuard::set("ZEROCLAW_API_KEY", "sk-test-secret-67890");
|
let _g2 = EnvGuard::set("ZEROCLAW_API_KEY", "sk-test-secret-67890");
|
||||||
|
|
||||||
let tool = ShellTool::new(test_security_with_env_cmd());
|
let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime());
|
||||||
let result = tool.execute(json!({"command": "env"})).await.unwrap();
|
let result = tool.execute(json!({"command": "env"})).await.unwrap();
|
||||||
assert!(result.success);
|
assert!(result.success);
|
||||||
assert!(
|
assert!(
|
||||||
|
|
@ -265,7 +310,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn shell_preserves_path_and_home() {
|
async fn shell_preserves_path_and_home() {
|
||||||
let tool = ShellTool::new(test_security_with_env_cmd());
|
let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime());
|
||||||
|
|
||||||
let result = tool
|
let result = tool
|
||||||
.execute(json!({"command": "echo $HOME"}))
|
.execute(json!({"command": "echo $HOME"}))
|
||||||
|
|
@ -287,4 +332,37 @@ mod tests {
|
||||||
"PATH should be available in shell"
|
"PATH should be available in shell"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn shell_requires_approval_for_medium_risk_command() {
|
||||||
|
let security = Arc::new(SecurityPolicy {
|
||||||
|
autonomy: AutonomyLevel::Supervised,
|
||||||
|
allowed_commands: vec!["touch".into()],
|
||||||
|
workspace_dir: std::env::temp_dir(),
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
});
|
||||||
|
|
||||||
|
let tool = ShellTool::new(security.clone(), test_runtime());
|
||||||
|
let denied = tool
|
||||||
|
.execute(json!({"command": "touch zeroclaw_shell_approval_test"}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(!denied.success);
|
||||||
|
assert!(denied
|
||||||
|
.error
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("")
|
||||||
|
.contains("explicit approval"));
|
||||||
|
|
||||||
|
let allowed = tool
|
||||||
|
.execute(json!({
|
||||||
|
"command": "touch zeroclaw_shell_approval_test",
|
||||||
|
"approved": true
|
||||||
|
}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert!(allowed.success);
|
||||||
|
|
||||||
|
let _ = std::fs::remove_file(std::env::temp_dir().join("zeroclaw_shell_approval_test"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue