security: harden architecture against Moltbot security model
- Discord: add allowed_users field + sender validation in listen() - Slack: add allowed_users field + sender validation in listen() - Webhook: add X-Webhook-Secret header auth (401 on mismatch) - SecurityPolicy: add ActionTracker with sliding-window rate limiting - record_action() enforces max_actions_per_hour - is_rate_limited() checks without recording - Gateway: print auth status on startup (ENABLED/DISABLED) - 22 new tests (Discord/Slack allowlists, gateway header extraction, rate limiter: starts at zero, records, allows within limit, blocks over limit, clone independence) - 554 tests passing, 0 clippy warnings
This commit is contained in:
parent
cf0ca71fdc
commit
542bb80743
7 changed files with 287 additions and 6 deletions
|
|
@ -9,18 +9,29 @@ use uuid::Uuid;
|
|||
pub struct DiscordChannel {
|
||||
bot_token: String,
|
||||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl DiscordChannel {
|
||||
pub fn new(bot_token: String, guild_id: Option<String>) -> Self {
|
||||
pub fn new(bot_token: String, guild_id: Option<String>, allowed_users: Vec<String>) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
allowed_users,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a Discord user ID is in the allowlist.
|
||||
/// Empty list or `["*"]` means allow everyone.
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
if self.allowed_users.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
||||
}
|
||||
|
||||
fn bot_user_id_from_token(token: &str) -> Option<String> {
|
||||
// Discord bot tokens are base64(bot_user_id).timestamp.hmac
|
||||
let part = token.split('.').next()?;
|
||||
|
|
@ -197,6 +208,12 @@ impl Channel for DiscordChannel {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Sender validation
|
||||
if !self.is_user_allowed(author_id) {
|
||||
tracing::warn!("Discord: ignoring message from unauthorized user: {author_id}");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Guild filter
|
||||
if let Some(ref gid) = guild_filter {
|
||||
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str).unwrap_or("");
|
||||
|
|
@ -250,7 +267,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn discord_channel_name() {
|
||||
let ch = DiscordChannel::new("fake".into(), None);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![]);
|
||||
assert_eq!(ch.name(), "discord");
|
||||
}
|
||||
|
||||
|
|
@ -268,4 +285,27 @@ mod tests {
|
|||
let id = DiscordChannel::bot_user_id_from_token(token);
|
||||
assert_eq!(id, Some("123456".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowlist_allows_everyone() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![]);
|
||||
assert!(ch.is_user_allowed("12345"));
|
||||
assert!(ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_allows_everyone() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()]);
|
||||
assert!(ch.is_user_allowed("12345"));
|
||||
assert!(ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_allowlist_filters() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()]);
|
||||
assert!(ch.is_user_allowed("111"));
|
||||
assert!(ch.is_user_allowed("222"));
|
||||
assert!(!ch.is_user_allowed("333"));
|
||||
assert!(!ch.is_user_allowed("unknown"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -250,6 +250,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
channels.push(Arc::new(DiscordChannel::new(
|
||||
dc.bot_token.clone(),
|
||||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
@ -257,6 +258,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
channels.push(Arc::new(SlackChannel::new(
|
||||
sl.bot_token.clone(),
|
||||
sl.channel_id.clone(),
|
||||
sl.allowed_users.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -6,18 +6,29 @@ use uuid::Uuid;
|
|||
pub struct SlackChannel {
|
||||
bot_token: String,
|
||||
channel_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl SlackChannel {
|
||||
pub fn new(bot_token: String, channel_id: Option<String>) -> Self {
|
||||
pub fn new(bot_token: String, channel_id: Option<String>, allowed_users: Vec<String>) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
channel_id,
|
||||
allowed_users,
|
||||
client: reqwest::Client::new(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a Slack user ID is in the allowlist.
|
||||
/// Empty list or `["*"]` means allow everyone.
|
||||
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||
if self.allowed_users.is_empty() {
|
||||
return true;
|
||||
}
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
||||
}
|
||||
|
||||
/// Get the bot's own user ID so we can ignore our own messages
|
||||
async fn get_bot_user_id(&self) -> Option<String> {
|
||||
let resp: serde_json::Value = self
|
||||
|
|
@ -119,6 +130,12 @@ impl Channel for SlackChannel {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Sender validation
|
||||
if !self.is_user_allowed(user) {
|
||||
tracing::warn!("Slack: ignoring message from unauthorized user: {user}");
|
||||
continue;
|
||||
}
|
||||
|
||||
// Skip empty or already-seen
|
||||
if text.is_empty() || ts <= last_ts.as_str() {
|
||||
continue;
|
||||
|
|
@ -162,13 +179,34 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn slack_channel_name() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None);
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, vec![]);
|
||||
assert_eq!(ch.name(), "slack");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn slack_channel_with_channel_id() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()));
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()), vec![]);
|
||||
assert_eq!(ch.channel_id, Some("C12345".to_string()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_allowlist_allows_everyone() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, vec![]);
|
||||
assert!(ch.is_user_allowed("U12345"));
|
||||
assert!(ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_allows_everyone() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, vec!["*".into()]);
|
||||
assert!(ch.is_user_allowed("U12345"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_allowlist_filters() {
|
||||
let ch = SlackChannel::new("xoxb-fake".into(), None, vec!["U111".into(), "U222".into()]);
|
||||
assert!(ch.is_user_allowed("U111"));
|
||||
assert!(ch.is_user_allowed("U222"));
|
||||
assert!(!ch.is_user_allowed("U333"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -183,6 +183,8 @@ pub struct TelegramConfig {
|
|||
pub struct DiscordConfig {
|
||||
pub bot_token: String,
|
||||
pub guild_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -190,6 +192,8 @@ pub struct SlackConfig {
|
|||
pub bot_token: String,
|
||||
pub app_token: Option<String>,
|
||||
pub channel_id: Option<String>,
|
||||
#[serde(default)]
|
||||
pub allowed_users: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -461,6 +465,7 @@ default_temperature = 0.7
|
|||
let dc = DiscordConfig {
|
||||
bot_token: "discord-token".into(),
|
||||
guild_id: Some("12345".into()),
|
||||
allowed_users: vec![],
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -473,6 +478,7 @@ default_temperature = 0.7
|
|||
let dc = DiscordConfig {
|
||||
bot_token: "tok".into(),
|
||||
guild_id: None,
|
||||
allowed_users: vec![],
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
|
|||
|
|
@ -25,9 +25,22 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
let mem: Arc<dyn Memory> =
|
||||
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
|
||||
|
||||
// Extract webhook secret for authentication
|
||||
let webhook_secret: Option<Arc<str>> = config
|
||||
.channels_config
|
||||
.webhook
|
||||
.as_ref()
|
||||
.and_then(|w| w.secret.as_deref())
|
||||
.map(Arc::from);
|
||||
|
||||
println!("🦀 ZeroClaw Gateway listening on http://{addr}");
|
||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||
println!(" GET /health — health check");
|
||||
if webhook_secret.is_some() {
|
||||
println!(" 🔒 Webhook authentication: ENABLED (X-Webhook-Secret header required)");
|
||||
} else {
|
||||
println!(" ⚠️ Webhook authentication: DISABLED (set [channels.webhook] secret to enable)");
|
||||
}
|
||||
println!(" Press Ctrl+C to stop.\n");
|
||||
|
||||
loop {
|
||||
|
|
@ -36,6 +49,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
let model = model.clone();
|
||||
let mem = mem.clone();
|
||||
let auto_save = config.memory.auto_save;
|
||||
let secret = webhook_secret.clone();
|
||||
|
||||
tokio::spawn(async move {
|
||||
let mut buf = vec![0u8; 8192];
|
||||
|
|
@ -50,7 +64,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
|
||||
if let [method, path, ..] = parts.as_slice() {
|
||||
tracing::info!("{peer} → {method} {path}");
|
||||
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save).await;
|
||||
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save, secret.as_ref()).await;
|
||||
} else {
|
||||
let _ = send_response(&mut stream, 400, "Bad Request").await;
|
||||
}
|
||||
|
|
@ -58,6 +72,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
/// Extract a header value from a raw HTTP request.
|
||||
fn extract_header<'a>(request: &'a str, header_name: &str) -> Option<&'a str> {
|
||||
let lower_name = header_name.to_lowercase();
|
||||
for line in request.lines() {
|
||||
if let Some((key, value)) = line.split_once(':') {
|
||||
if key.trim().to_lowercase() == lower_name {
|
||||
return Some(value.trim());
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
async fn handle_request(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
|
|
@ -69,6 +96,7 @@ async fn handle_request(
|
|||
temperature: f64,
|
||||
mem: &Arc<dyn Memory>,
|
||||
auto_save: bool,
|
||||
webhook_secret: Option<&Arc<str>>,
|
||||
) {
|
||||
match (method, path) {
|
||||
("GET", "/health") => {
|
||||
|
|
@ -82,6 +110,19 @@ async fn handle_request(
|
|||
}
|
||||
|
||||
("POST", "/webhook") => {
|
||||
// Authenticate webhook requests if a secret is configured
|
||||
if let Some(secret) = webhook_secret {
|
||||
let header_val = extract_header(request, "X-Webhook-Secret");
|
||||
match header_val {
|
||||
Some(val) if val == secret.as_ref() => {}
|
||||
_ => {
|
||||
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
||||
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
||||
let _ = send_json(stream, 401, &err).await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
handle_webhook(stream, request, provider, model, temperature, mem, auto_save).await;
|
||||
}
|
||||
|
||||
|
|
@ -159,6 +200,35 @@ async fn send_response(
|
|||
stream.write_all(response.as_bytes()).await
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn extract_header_finds_value() {
|
||||
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\nX-Webhook-Secret: my-secret\r\n\r\n{}";
|
||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("my-secret"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_header_case_insensitive() {
|
||||
let req = "POST /webhook HTTP/1.1\r\nx-webhook-secret: abc123\r\n\r\n{}";
|
||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("abc123"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_header_missing_returns_none() {
|
||||
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\n\r\n{}";
|
||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), None);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_header_trims_whitespace() {
|
||||
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: spaced \r\n\r\n{}";
|
||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("spaced"));
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_json(
|
||||
stream: &mut tokio::net::TcpStream,
|
||||
status: u16,
|
||||
|
|
|
|||
|
|
@ -715,6 +715,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
|||
config.discord = Some(DiscordConfig {
|
||||
bot_token: token,
|
||||
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
||||
allowed_users: vec![],
|
||||
});
|
||||
}
|
||||
2 => {
|
||||
|
|
@ -791,6 +792,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
|||
bot_token: token,
|
||||
app_token: if app_token.is_empty() { None } else { Some(app_token) },
|
||||
channel_id: if channel.is_empty() { None } else { Some(channel) },
|
||||
allowed_users: vec![],
|
||||
});
|
||||
}
|
||||
3 => {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Mutex;
|
||||
use std::time::Instant;
|
||||
|
||||
/// How much autonomy the agent has
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
|
|
@ -19,6 +21,47 @@ impl Default for AutonomyLevel {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sliding-window action tracker for rate limiting.
|
||||
#[derive(Debug)]
|
||||
pub struct ActionTracker {
|
||||
/// Timestamps of recent actions (kept within the last hour).
|
||||
actions: Mutex<Vec<Instant>>,
|
||||
}
|
||||
|
||||
impl ActionTracker {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
actions: Mutex::new(Vec::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Record an action and return the current count within the window.
|
||||
pub fn record(&self) -> usize {
|
||||
let mut actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let cutoff = Instant::now().checked_sub(std::time::Duration::from_secs(3600)).unwrap_or_else(Instant::now);
|
||||
actions.retain(|t| *t > cutoff);
|
||||
actions.push(Instant::now());
|
||||
actions.len()
|
||||
}
|
||||
|
||||
/// Count of actions in the current window without recording.
|
||||
pub fn count(&self) -> usize {
|
||||
let mut actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
let cutoff = Instant::now().checked_sub(std::time::Duration::from_secs(3600)).unwrap_or_else(Instant::now);
|
||||
actions.retain(|t| *t > cutoff);
|
||||
actions.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl Clone for ActionTracker {
|
||||
fn clone(&self) -> Self {
|
||||
let actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
Self {
|
||||
actions: Mutex::new(actions.clone()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Security policy enforced on all tool executions
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SecurityPolicy {
|
||||
|
|
@ -29,6 +72,7 @@ pub struct SecurityPolicy {
|
|||
pub forbidden_paths: Vec<String>,
|
||||
pub max_actions_per_hour: u32,
|
||||
pub max_cost_per_day_cents: u32,
|
||||
pub tracker: ActionTracker,
|
||||
}
|
||||
|
||||
impl Default for SecurityPolicy {
|
||||
|
|
@ -60,6 +104,7 @@ impl Default for SecurityPolicy {
|
|||
],
|
||||
max_actions_per_hour: 20,
|
||||
max_cost_per_day_cents: 500,
|
||||
tracker: ActionTracker::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -112,6 +157,18 @@ impl SecurityPolicy {
|
|||
self.autonomy != AutonomyLevel::ReadOnly
|
||||
}
|
||||
|
||||
/// Record an action and check if the rate limit has been exceeded.
|
||||
/// Returns `true` if the action is allowed, `false` if rate-limited.
|
||||
pub fn record_action(&self) -> bool {
|
||||
let count = self.tracker.record();
|
||||
count <= self.max_actions_per_hour as usize
|
||||
}
|
||||
|
||||
/// Check if the rate limit would be exceeded without recording.
|
||||
pub fn is_rate_limited(&self) -> bool {
|
||||
self.tracker.count() >= self.max_actions_per_hour as usize
|
||||
}
|
||||
|
||||
/// Build from config sections
|
||||
pub fn from_config(
|
||||
autonomy_config: &crate::config::AutonomyConfig,
|
||||
|
|
@ -125,6 +182,7 @@ impl SecurityPolicy {
|
|||
forbidden_paths: autonomy_config.forbidden_paths.clone(),
|
||||
max_actions_per_hour: autonomy_config.max_actions_per_hour,
|
||||
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
|
||||
tracker: ActionTracker::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -362,4 +420,69 @@ mod tests {
|
|||
assert!(p.max_actions_per_hour > 0);
|
||||
assert!(p.max_cost_per_day_cents > 0);
|
||||
}
|
||||
|
||||
// ── ActionTracker / rate limiting ───────────────────────
|
||||
|
||||
#[test]
|
||||
fn action_tracker_starts_at_zero() {
|
||||
let tracker = ActionTracker::new();
|
||||
assert_eq!(tracker.count(), 0);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn action_tracker_records_actions() {
|
||||
let tracker = ActionTracker::new();
|
||||
assert_eq!(tracker.record(), 1);
|
||||
assert_eq!(tracker.record(), 2);
|
||||
assert_eq!(tracker.record(), 3);
|
||||
assert_eq!(tracker.count(), 3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_action_allows_within_limit() {
|
||||
let p = SecurityPolicy {
|
||||
max_actions_per_hour: 5,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
for _ in 0..5 {
|
||||
assert!(p.record_action(), "should allow actions within limit");
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn record_action_blocks_over_limit() {
|
||||
let p = SecurityPolicy {
|
||||
max_actions_per_hour: 3,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
assert!(p.record_action()); // 1
|
||||
assert!(p.record_action()); // 2
|
||||
assert!(p.record_action()); // 3
|
||||
assert!(!p.record_action()); // 4 — over limit
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn is_rate_limited_reflects_count() {
|
||||
let p = SecurityPolicy {
|
||||
max_actions_per_hour: 2,
|
||||
..SecurityPolicy::default()
|
||||
};
|
||||
assert!(!p.is_rate_limited());
|
||||
p.record_action();
|
||||
assert!(!p.is_rate_limited());
|
||||
p.record_action();
|
||||
assert!(p.is_rate_limited());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn action_tracker_clone_is_independent() {
|
||||
let tracker = ActionTracker::new();
|
||||
tracker.record();
|
||||
tracker.record();
|
||||
let cloned = tracker.clone();
|
||||
assert_eq!(cloned.count(), 2);
|
||||
tracker.record();
|
||||
assert_eq!(tracker.count(), 3);
|
||||
assert_eq!(cloned.count(), 2); // clone is independent
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue