security: harden architecture against Moltbot security model

- Discord: add allowed_users field + sender validation in listen()
- Slack: add allowed_users field + sender validation in listen()
- Webhook: add X-Webhook-Secret header auth (401 on mismatch)
- SecurityPolicy: add ActionTracker with sliding-window rate limiting
  - record_action() enforces max_actions_per_hour
  - is_rate_limited() checks without recording
- Gateway: print auth status on startup (ENABLED/DISABLED)
- 22 new tests (Discord/Slack allowlists, gateway header extraction,
  rate limiter: starts at zero, records, allows within limit,
  blocks over limit, clone independence)
- 554 tests passing, 0 clippy warnings
This commit is contained in:
argenis de la rosa 2026-02-13 15:31:21 -05:00
parent cf0ca71fdc
commit 542bb80743
7 changed files with 287 additions and 6 deletions

View file

@ -9,18 +9,29 @@ use uuid::Uuid;
pub struct DiscordChannel {
bot_token: String,
guild_id: Option<String>,
allowed_users: Vec<String>,
client: reqwest::Client,
}
impl DiscordChannel {
pub fn new(bot_token: String, guild_id: Option<String>) -> Self {
pub fn new(bot_token: String, guild_id: Option<String>, allowed_users: Vec<String>) -> Self {
Self {
bot_token,
guild_id,
allowed_users,
client: reqwest::Client::new(),
}
}
/// Check if a Discord user ID is in the allowlist.
/// Empty list or `["*"]` means allow everyone.
fn is_user_allowed(&self, user_id: &str) -> bool {
if self.allowed_users.is_empty() {
return true;
}
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
}
fn bot_user_id_from_token(token: &str) -> Option<String> {
// Discord bot tokens are base64(bot_user_id).timestamp.hmac
let part = token.split('.').next()?;
@ -197,6 +208,12 @@ impl Channel for DiscordChannel {
continue;
}
// Sender validation
if !self.is_user_allowed(author_id) {
tracing::warn!("Discord: ignoring message from unauthorized user: {author_id}");
continue;
}
// Guild filter
if let Some(ref gid) = guild_filter {
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str).unwrap_or("");
@ -250,7 +267,7 @@ mod tests {
#[test]
fn discord_channel_name() {
let ch = DiscordChannel::new("fake".into(), None);
let ch = DiscordChannel::new("fake".into(), None, vec![]);
assert_eq!(ch.name(), "discord");
}
@ -268,4 +285,27 @@ mod tests {
let id = DiscordChannel::bot_user_id_from_token(token);
assert_eq!(id, Some("123456".to_string()));
}
#[test]
fn empty_allowlist_allows_everyone() {
let ch = DiscordChannel::new("fake".into(), None, vec![]);
assert!(ch.is_user_allowed("12345"));
assert!(ch.is_user_allowed("anyone"));
}
#[test]
fn wildcard_allows_everyone() {
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()]);
assert!(ch.is_user_allowed("12345"));
assert!(ch.is_user_allowed("anyone"));
}
#[test]
fn specific_allowlist_filters() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()]);
assert!(ch.is_user_allowed("111"));
assert!(ch.is_user_allowed("222"));
assert!(!ch.is_user_allowed("333"));
assert!(!ch.is_user_allowed("unknown"));
}
}

View file

@ -250,6 +250,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
channels.push(Arc::new(DiscordChannel::new(
dc.bot_token.clone(),
dc.guild_id.clone(),
dc.allowed_users.clone(),
)));
}
@ -257,6 +258,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
channels.push(Arc::new(SlackChannel::new(
sl.bot_token.clone(),
sl.channel_id.clone(),
sl.allowed_users.clone(),
)));
}

View file

@ -6,18 +6,29 @@ use uuid::Uuid;
pub struct SlackChannel {
bot_token: String,
channel_id: Option<String>,
allowed_users: Vec<String>,
client: reqwest::Client,
}
impl SlackChannel {
pub fn new(bot_token: String, channel_id: Option<String>) -> Self {
pub fn new(bot_token: String, channel_id: Option<String>, allowed_users: Vec<String>) -> Self {
Self {
bot_token,
channel_id,
allowed_users,
client: reqwest::Client::new(),
}
}
/// Check if a Slack user ID is in the allowlist.
/// Empty list or `["*"]` means allow everyone.
fn is_user_allowed(&self, user_id: &str) -> bool {
if self.allowed_users.is_empty() {
return true;
}
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
}
/// Get the bot's own user ID so we can ignore our own messages
async fn get_bot_user_id(&self) -> Option<String> {
let resp: serde_json::Value = self
@ -119,6 +130,12 @@ impl Channel for SlackChannel {
continue;
}
// Sender validation
if !self.is_user_allowed(user) {
tracing::warn!("Slack: ignoring message from unauthorized user: {user}");
continue;
}
// Skip empty or already-seen
if text.is_empty() || ts <= last_ts.as_str() {
continue;
@ -162,13 +179,34 @@ mod tests {
#[test]
fn slack_channel_name() {
let ch = SlackChannel::new("xoxb-fake".into(), None);
let ch = SlackChannel::new("xoxb-fake".into(), None, vec![]);
assert_eq!(ch.name(), "slack");
}
#[test]
fn slack_channel_with_channel_id() {
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()));
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()), vec![]);
assert_eq!(ch.channel_id, Some("C12345".to_string()));
}
#[test]
fn empty_allowlist_allows_everyone() {
let ch = SlackChannel::new("xoxb-fake".into(), None, vec![]);
assert!(ch.is_user_allowed("U12345"));
assert!(ch.is_user_allowed("anyone"));
}
#[test]
fn wildcard_allows_everyone() {
let ch = SlackChannel::new("xoxb-fake".into(), None, vec!["*".into()]);
assert!(ch.is_user_allowed("U12345"));
}
#[test]
fn specific_allowlist_filters() {
let ch = SlackChannel::new("xoxb-fake".into(), None, vec!["U111".into(), "U222".into()]);
assert!(ch.is_user_allowed("U111"));
assert!(ch.is_user_allowed("U222"));
assert!(!ch.is_user_allowed("U333"));
}
}

View file

@ -183,6 +183,8 @@ pub struct TelegramConfig {
pub struct DiscordConfig {
pub bot_token: String,
pub guild_id: Option<String>,
#[serde(default)]
pub allowed_users: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -190,6 +192,8 @@ pub struct SlackConfig {
pub bot_token: String,
pub app_token: Option<String>,
pub channel_id: Option<String>,
#[serde(default)]
pub allowed_users: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -461,6 +465,7 @@ default_temperature = 0.7
let dc = DiscordConfig {
bot_token: "discord-token".into(),
guild_id: Some("12345".into()),
allowed_users: vec![],
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@ -473,6 +478,7 @@ default_temperature = 0.7
let dc = DiscordConfig {
bot_token: "tok".into(),
guild_id: None,
allowed_users: vec![],
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();

View file

@ -25,9 +25,22 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let mem: Arc<dyn Memory> =
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
// Extract webhook secret for authentication
let webhook_secret: Option<Arc<str>> = config
.channels_config
.webhook
.as_ref()
.and_then(|w| w.secret.as_deref())
.map(Arc::from);
println!("🦀 ZeroClaw Gateway listening on http://{addr}");
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
println!(" GET /health — health check");
if webhook_secret.is_some() {
println!(" 🔒 Webhook authentication: ENABLED (X-Webhook-Secret header required)");
} else {
println!(" ⚠️ Webhook authentication: DISABLED (set [channels.webhook] secret to enable)");
}
println!(" Press Ctrl+C to stop.\n");
loop {
@ -36,6 +49,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let model = model.clone();
let mem = mem.clone();
let auto_save = config.memory.auto_save;
let secret = webhook_secret.clone();
tokio::spawn(async move {
let mut buf = vec![0u8; 8192];
@ -50,7 +64,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
if let [method, path, ..] = parts.as_slice() {
tracing::info!("{peer} → {method} {path}");
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save).await;
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save, secret.as_ref()).await;
} else {
let _ = send_response(&mut stream, 400, "Bad Request").await;
}
@ -58,6 +72,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
}
}
/// Extract a header value from a raw HTTP request.
fn extract_header<'a>(request: &'a str, header_name: &str) -> Option<&'a str> {
let lower_name = header_name.to_lowercase();
for line in request.lines() {
if let Some((key, value)) = line.split_once(':') {
if key.trim().to_lowercase() == lower_name {
return Some(value.trim());
}
}
}
None
}
#[allow(clippy::too_many_arguments)]
async fn handle_request(
stream: &mut tokio::net::TcpStream,
@ -69,6 +96,7 @@ async fn handle_request(
temperature: f64,
mem: &Arc<dyn Memory>,
auto_save: bool,
webhook_secret: Option<&Arc<str>>,
) {
match (method, path) {
("GET", "/health") => {
@ -82,6 +110,19 @@ async fn handle_request(
}
("POST", "/webhook") => {
// Authenticate webhook requests if a secret is configured
if let Some(secret) = webhook_secret {
let header_val = extract_header(request, "X-Webhook-Secret");
match header_val {
Some(val) if val == secret.as_ref() => {}
_ => {
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
let _ = send_json(stream, 401, &err).await;
return;
}
}
}
handle_webhook(stream, request, provider, model, temperature, mem, auto_save).await;
}
@ -159,6 +200,35 @@ async fn send_response(
stream.write_all(response.as_bytes()).await
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn extract_header_finds_value() {
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\nX-Webhook-Secret: my-secret\r\n\r\n{}";
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("my-secret"));
}
#[test]
fn extract_header_case_insensitive() {
let req = "POST /webhook HTTP/1.1\r\nx-webhook-secret: abc123\r\n\r\n{}";
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("abc123"));
}
#[test]
fn extract_header_missing_returns_none() {
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\n\r\n{}";
assert_eq!(extract_header(req, "X-Webhook-Secret"), None);
}
#[test]
fn extract_header_trims_whitespace() {
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: spaced \r\n\r\n{}";
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("spaced"));
}
}
async fn send_json(
stream: &mut tokio::net::TcpStream,
status: u16,

View file

@ -715,6 +715,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
config.discord = Some(DiscordConfig {
bot_token: token,
guild_id: if guild.is_empty() { None } else { Some(guild) },
allowed_users: vec![],
});
}
2 => {
@ -791,6 +792,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
bot_token: token,
app_token: if app_token.is_empty() { None } else { Some(app_token) },
channel_id: if channel.is_empty() { None } else { Some(channel) },
allowed_users: vec![],
});
}
3 => {

View file

@ -1,5 +1,7 @@
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use std::time::Instant;
/// How much autonomy the agent has
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
@ -19,6 +21,47 @@ impl Default for AutonomyLevel {
}
}
/// Sliding-window action tracker for rate limiting.
#[derive(Debug)]
pub struct ActionTracker {
/// Timestamps of recent actions (kept within the last hour).
actions: Mutex<Vec<Instant>>,
}
impl ActionTracker {
pub fn new() -> Self {
Self {
actions: Mutex::new(Vec::new()),
}
}
/// Record an action and return the current count within the window.
pub fn record(&self) -> usize {
let mut actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
let cutoff = Instant::now().checked_sub(std::time::Duration::from_secs(3600)).unwrap_or_else(Instant::now);
actions.retain(|t| *t > cutoff);
actions.push(Instant::now());
actions.len()
}
/// Count of actions in the current window without recording.
pub fn count(&self) -> usize {
let mut actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
let cutoff = Instant::now().checked_sub(std::time::Duration::from_secs(3600)).unwrap_or_else(Instant::now);
actions.retain(|t| *t > cutoff);
actions.len()
}
}
impl Clone for ActionTracker {
fn clone(&self) -> Self {
let actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
Self {
actions: Mutex::new(actions.clone()),
}
}
}
/// Security policy enforced on all tool executions
#[derive(Debug, Clone)]
pub struct SecurityPolicy {
@ -29,6 +72,7 @@ pub struct SecurityPolicy {
pub forbidden_paths: Vec<String>,
pub max_actions_per_hour: u32,
pub max_cost_per_day_cents: u32,
pub tracker: ActionTracker,
}
impl Default for SecurityPolicy {
@ -60,6 +104,7 @@ impl Default for SecurityPolicy {
],
max_actions_per_hour: 20,
max_cost_per_day_cents: 500,
tracker: ActionTracker::new(),
}
}
}
@ -112,6 +157,18 @@ impl SecurityPolicy {
self.autonomy != AutonomyLevel::ReadOnly
}
/// Record an action and check if the rate limit has been exceeded.
/// Returns `true` if the action is allowed, `false` if rate-limited.
pub fn record_action(&self) -> bool {
let count = self.tracker.record();
count <= self.max_actions_per_hour as usize
}
/// Check if the rate limit would be exceeded without recording.
pub fn is_rate_limited(&self) -> bool {
self.tracker.count() >= self.max_actions_per_hour as usize
}
/// Build from config sections
pub fn from_config(
autonomy_config: &crate::config::AutonomyConfig,
@ -125,6 +182,7 @@ impl SecurityPolicy {
forbidden_paths: autonomy_config.forbidden_paths.clone(),
max_actions_per_hour: autonomy_config.max_actions_per_hour,
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
tracker: ActionTracker::new(),
}
}
}
@ -362,4 +420,69 @@ mod tests {
assert!(p.max_actions_per_hour > 0);
assert!(p.max_cost_per_day_cents > 0);
}
// ── ActionTracker / rate limiting ───────────────────────
#[test]
fn action_tracker_starts_at_zero() {
let tracker = ActionTracker::new();
assert_eq!(tracker.count(), 0);
}
#[test]
fn action_tracker_records_actions() {
let tracker = ActionTracker::new();
assert_eq!(tracker.record(), 1);
assert_eq!(tracker.record(), 2);
assert_eq!(tracker.record(), 3);
assert_eq!(tracker.count(), 3);
}
#[test]
fn record_action_allows_within_limit() {
let p = SecurityPolicy {
max_actions_per_hour: 5,
..SecurityPolicy::default()
};
for _ in 0..5 {
assert!(p.record_action(), "should allow actions within limit");
}
}
#[test]
fn record_action_blocks_over_limit() {
let p = SecurityPolicy {
max_actions_per_hour: 3,
..SecurityPolicy::default()
};
assert!(p.record_action()); // 1
assert!(p.record_action()); // 2
assert!(p.record_action()); // 3
assert!(!p.record_action()); // 4 — over limit
}
#[test]
fn is_rate_limited_reflects_count() {
let p = SecurityPolicy {
max_actions_per_hour: 2,
..SecurityPolicy::default()
};
assert!(!p.is_rate_limited());
p.record_action();
assert!(!p.is_rate_limited());
p.record_action();
assert!(p.is_rate_limited());
}
#[test]
fn action_tracker_clone_is_independent() {
let tracker = ActionTracker::new();
tracker.record();
tracker.record();
let cloned = tracker.clone();
assert_eq!(cloned.count(), 2);
tracker.record();
assert_eq!(tracker.count(), 3);
assert_eq!(cloned.count(), 2); // clone is independent
}
}