security: harden architecture against Moltbot security model
- Discord: add allowed_users field + sender validation in listen() - Slack: add allowed_users field + sender validation in listen() - Webhook: add X-Webhook-Secret header auth (401 on mismatch) - SecurityPolicy: add ActionTracker with sliding-window rate limiting - record_action() enforces max_actions_per_hour - is_rate_limited() checks without recording - Gateway: print auth status on startup (ENABLED/DISABLED) - 22 new tests (Discord/Slack allowlists, gateway header extraction, rate limiter: starts at zero, records, allows within limit, blocks over limit, clone independence) - 554 tests passing, 0 clippy warnings
This commit is contained in:
parent
cf0ca71fdc
commit
542bb80743
7 changed files with 287 additions and 6 deletions
|
|
@ -9,18 +9,29 @@ use uuid::Uuid;
|
||||||
pub struct DiscordChannel {
|
pub struct DiscordChannel {
|
||||||
bot_token: String,
|
bot_token: String,
|
||||||
guild_id: Option<String>,
|
guild_id: Option<String>,
|
||||||
|
allowed_users: Vec<String>,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DiscordChannel {
|
impl DiscordChannel {
|
||||||
pub fn new(bot_token: String, guild_id: Option<String>) -> Self {
|
pub fn new(bot_token: String, guild_id: Option<String>, allowed_users: Vec<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
bot_token,
|
bot_token,
|
||||||
guild_id,
|
guild_id,
|
||||||
|
allowed_users,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if a Discord user ID is in the allowlist.
|
||||||
|
/// Empty list or `["*"]` means allow everyone.
|
||||||
|
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||||
|
if self.allowed_users.is_empty() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
||||||
|
}
|
||||||
|
|
||||||
fn bot_user_id_from_token(token: &str) -> Option<String> {
|
fn bot_user_id_from_token(token: &str) -> Option<String> {
|
||||||
// Discord bot tokens are base64(bot_user_id).timestamp.hmac
|
// Discord bot tokens are base64(bot_user_id).timestamp.hmac
|
||||||
let part = token.split('.').next()?;
|
let part = token.split('.').next()?;
|
||||||
|
|
@ -197,6 +208,12 @@ impl Channel for DiscordChannel {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sender validation
|
||||||
|
if !self.is_user_allowed(author_id) {
|
||||||
|
tracing::warn!("Discord: ignoring message from unauthorized user: {author_id}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Guild filter
|
// Guild filter
|
||||||
if let Some(ref gid) = guild_filter {
|
if let Some(ref gid) = guild_filter {
|
||||||
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str).unwrap_or("");
|
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str).unwrap_or("");
|
||||||
|
|
@ -250,7 +267,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn discord_channel_name() {
|
fn discord_channel_name() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None);
|
let ch = DiscordChannel::new("fake".into(), None, vec![]);
|
||||||
assert_eq!(ch.name(), "discord");
|
assert_eq!(ch.name(), "discord");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -268,4 +285,27 @@ mod tests {
|
||||||
let id = DiscordChannel::bot_user_id_from_token(token);
|
let id = DiscordChannel::bot_user_id_from_token(token);
|
||||||
assert_eq!(id, Some("123456".to_string()));
|
assert_eq!(id, Some("123456".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_allowlist_allows_everyone() {
|
||||||
|
let ch = DiscordChannel::new("fake".into(), None, vec![]);
|
||||||
|
assert!(ch.is_user_allowed("12345"));
|
||||||
|
assert!(ch.is_user_allowed("anyone"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wildcard_allows_everyone() {
|
||||||
|
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()]);
|
||||||
|
assert!(ch.is_user_allowed("12345"));
|
||||||
|
assert!(ch.is_user_allowed("anyone"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn specific_allowlist_filters() {
|
||||||
|
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()]);
|
||||||
|
assert!(ch.is_user_allowed("111"));
|
||||||
|
assert!(ch.is_user_allowed("222"));
|
||||||
|
assert!(!ch.is_user_allowed("333"));
|
||||||
|
assert!(!ch.is_user_allowed("unknown"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -250,6 +250,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
channels.push(Arc::new(DiscordChannel::new(
|
channels.push(Arc::new(DiscordChannel::new(
|
||||||
dc.bot_token.clone(),
|
dc.bot_token.clone(),
|
||||||
dc.guild_id.clone(),
|
dc.guild_id.clone(),
|
||||||
|
dc.allowed_users.clone(),
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -257,6 +258,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
channels.push(Arc::new(SlackChannel::new(
|
channels.push(Arc::new(SlackChannel::new(
|
||||||
sl.bot_token.clone(),
|
sl.bot_token.clone(),
|
||||||
sl.channel_id.clone(),
|
sl.channel_id.clone(),
|
||||||
|
sl.allowed_users.clone(),
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -6,18 +6,29 @@ use uuid::Uuid;
|
||||||
pub struct SlackChannel {
|
pub struct SlackChannel {
|
||||||
bot_token: String,
|
bot_token: String,
|
||||||
channel_id: Option<String>,
|
channel_id: Option<String>,
|
||||||
|
allowed_users: Vec<String>,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl SlackChannel {
|
impl SlackChannel {
|
||||||
pub fn new(bot_token: String, channel_id: Option<String>) -> Self {
|
pub fn new(bot_token: String, channel_id: Option<String>, allowed_users: Vec<String>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
bot_token,
|
bot_token,
|
||||||
channel_id,
|
channel_id,
|
||||||
|
allowed_users,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if a Slack user ID is in the allowlist.
|
||||||
|
/// Empty list or `["*"]` means allow everyone.
|
||||||
|
fn is_user_allowed(&self, user_id: &str) -> bool {
|
||||||
|
if self.allowed_users.is_empty() {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
||||||
|
}
|
||||||
|
|
||||||
/// Get the bot's own user ID so we can ignore our own messages
|
/// Get the bot's own user ID so we can ignore our own messages
|
||||||
async fn get_bot_user_id(&self) -> Option<String> {
|
async fn get_bot_user_id(&self) -> Option<String> {
|
||||||
let resp: serde_json::Value = self
|
let resp: serde_json::Value = self
|
||||||
|
|
@ -119,6 +130,12 @@ impl Channel for SlackChannel {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Sender validation
|
||||||
|
if !self.is_user_allowed(user) {
|
||||||
|
tracing::warn!("Slack: ignoring message from unauthorized user: {user}");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
// Skip empty or already-seen
|
// Skip empty or already-seen
|
||||||
if text.is_empty() || ts <= last_ts.as_str() {
|
if text.is_empty() || ts <= last_ts.as_str() {
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -162,13 +179,34 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn slack_channel_name() {
|
fn slack_channel_name() {
|
||||||
let ch = SlackChannel::new("xoxb-fake".into(), None);
|
let ch = SlackChannel::new("xoxb-fake".into(), None, vec![]);
|
||||||
assert_eq!(ch.name(), "slack");
|
assert_eq!(ch.name(), "slack");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn slack_channel_with_channel_id() {
|
fn slack_channel_with_channel_id() {
|
||||||
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()));
|
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()), vec![]);
|
||||||
assert_eq!(ch.channel_id, Some("C12345".to_string()));
|
assert_eq!(ch.channel_id, Some("C12345".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_allowlist_allows_everyone() {
|
||||||
|
let ch = SlackChannel::new("xoxb-fake".into(), None, vec![]);
|
||||||
|
assert!(ch.is_user_allowed("U12345"));
|
||||||
|
assert!(ch.is_user_allowed("anyone"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn wildcard_allows_everyone() {
|
||||||
|
let ch = SlackChannel::new("xoxb-fake".into(), None, vec!["*".into()]);
|
||||||
|
assert!(ch.is_user_allowed("U12345"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn specific_allowlist_filters() {
|
||||||
|
let ch = SlackChannel::new("xoxb-fake".into(), None, vec!["U111".into(), "U222".into()]);
|
||||||
|
assert!(ch.is_user_allowed("U111"));
|
||||||
|
assert!(ch.is_user_allowed("U222"));
|
||||||
|
assert!(!ch.is_user_allowed("U333"));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -183,6 +183,8 @@ pub struct TelegramConfig {
|
||||||
pub struct DiscordConfig {
|
pub struct DiscordConfig {
|
||||||
pub bot_token: String,
|
pub bot_token: String,
|
||||||
pub guild_id: Option<String>,
|
pub guild_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub allowed_users: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -190,6 +192,8 @@ pub struct SlackConfig {
|
||||||
pub bot_token: String,
|
pub bot_token: String,
|
||||||
pub app_token: Option<String>,
|
pub app_token: Option<String>,
|
||||||
pub channel_id: Option<String>,
|
pub channel_id: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
pub allowed_users: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -461,6 +465,7 @@ default_temperature = 0.7
|
||||||
let dc = DiscordConfig {
|
let dc = DiscordConfig {
|
||||||
bot_token: "discord-token".into(),
|
bot_token: "discord-token".into(),
|
||||||
guild_id: Some("12345".into()),
|
guild_id: Some("12345".into()),
|
||||||
|
allowed_users: vec![],
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&dc).unwrap();
|
let json = serde_json::to_string(&dc).unwrap();
|
||||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
@ -473,6 +478,7 @@ default_temperature = 0.7
|
||||||
let dc = DiscordConfig {
|
let dc = DiscordConfig {
|
||||||
bot_token: "tok".into(),
|
bot_token: "tok".into(),
|
||||||
guild_id: None,
|
guild_id: None,
|
||||||
|
allowed_users: vec![],
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&dc).unwrap();
|
let json = serde_json::to_string(&dc).unwrap();
|
||||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
|
||||||
|
|
@ -25,9 +25,22 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
let mem: Arc<dyn Memory> =
|
let mem: Arc<dyn Memory> =
|
||||||
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
|
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
|
||||||
|
|
||||||
|
// Extract webhook secret for authentication
|
||||||
|
let webhook_secret: Option<Arc<str>> = config
|
||||||
|
.channels_config
|
||||||
|
.webhook
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|w| w.secret.as_deref())
|
||||||
|
.map(Arc::from);
|
||||||
|
|
||||||
println!("🦀 ZeroClaw Gateway listening on http://{addr}");
|
println!("🦀 ZeroClaw Gateway listening on http://{addr}");
|
||||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||||
println!(" GET /health — health check");
|
println!(" GET /health — health check");
|
||||||
|
if webhook_secret.is_some() {
|
||||||
|
println!(" 🔒 Webhook authentication: ENABLED (X-Webhook-Secret header required)");
|
||||||
|
} else {
|
||||||
|
println!(" ⚠️ Webhook authentication: DISABLED (set [channels.webhook] secret to enable)");
|
||||||
|
}
|
||||||
println!(" Press Ctrl+C to stop.\n");
|
println!(" Press Ctrl+C to stop.\n");
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
|
@ -36,6 +49,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
let model = model.clone();
|
let model = model.clone();
|
||||||
let mem = mem.clone();
|
let mem = mem.clone();
|
||||||
let auto_save = config.memory.auto_save;
|
let auto_save = config.memory.auto_save;
|
||||||
|
let secret = webhook_secret.clone();
|
||||||
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
let mut buf = vec![0u8; 8192];
|
let mut buf = vec![0u8; 8192];
|
||||||
|
|
@ -50,7 +64,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
|
|
||||||
if let [method, path, ..] = parts.as_slice() {
|
if let [method, path, ..] = parts.as_slice() {
|
||||||
tracing::info!("{peer} → {method} {path}");
|
tracing::info!("{peer} → {method} {path}");
|
||||||
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save).await;
|
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save, secret.as_ref()).await;
|
||||||
} else {
|
} else {
|
||||||
let _ = send_response(&mut stream, 400, "Bad Request").await;
|
let _ = send_response(&mut stream, 400, "Bad Request").await;
|
||||||
}
|
}
|
||||||
|
|
@ -58,6 +72,19 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Extract a header value from a raw HTTP request.
|
||||||
|
fn extract_header<'a>(request: &'a str, header_name: &str) -> Option<&'a str> {
|
||||||
|
let lower_name = header_name.to_lowercase();
|
||||||
|
for line in request.lines() {
|
||||||
|
if let Some((key, value)) = line.split_once(':') {
|
||||||
|
if key.trim().to_lowercase() == lower_name {
|
||||||
|
return Some(value.trim());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
async fn handle_request(
|
async fn handle_request(
|
||||||
stream: &mut tokio::net::TcpStream,
|
stream: &mut tokio::net::TcpStream,
|
||||||
|
|
@ -69,6 +96,7 @@ async fn handle_request(
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
mem: &Arc<dyn Memory>,
|
mem: &Arc<dyn Memory>,
|
||||||
auto_save: bool,
|
auto_save: bool,
|
||||||
|
webhook_secret: Option<&Arc<str>>,
|
||||||
) {
|
) {
|
||||||
match (method, path) {
|
match (method, path) {
|
||||||
("GET", "/health") => {
|
("GET", "/health") => {
|
||||||
|
|
@ -82,6 +110,19 @@ async fn handle_request(
|
||||||
}
|
}
|
||||||
|
|
||||||
("POST", "/webhook") => {
|
("POST", "/webhook") => {
|
||||||
|
// Authenticate webhook requests if a secret is configured
|
||||||
|
if let Some(secret) = webhook_secret {
|
||||||
|
let header_val = extract_header(request, "X-Webhook-Secret");
|
||||||
|
match header_val {
|
||||||
|
Some(val) if val == secret.as_ref() => {}
|
||||||
|
_ => {
|
||||||
|
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
||||||
|
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
||||||
|
let _ = send_json(stream, 401, &err).await;
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
handle_webhook(stream, request, provider, model, temperature, mem, auto_save).await;
|
handle_webhook(stream, request, provider, model, temperature, mem, auto_save).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -159,6 +200,35 @@ async fn send_response(
|
||||||
stream.write_all(response.as_bytes()).await
|
stream.write_all(response.as_bytes()).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_header_finds_value() {
|
||||||
|
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\nX-Webhook-Secret: my-secret\r\n\r\n{}";
|
||||||
|
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("my-secret"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_header_case_insensitive() {
|
||||||
|
let req = "POST /webhook HTTP/1.1\r\nx-webhook-secret: abc123\r\n\r\n{}";
|
||||||
|
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("abc123"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_header_missing_returns_none() {
|
||||||
|
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\n\r\n{}";
|
||||||
|
assert_eq!(extract_header(req, "X-Webhook-Secret"), None);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_header_trims_whitespace() {
|
||||||
|
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: spaced \r\n\r\n{}";
|
||||||
|
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("spaced"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn send_json(
|
async fn send_json(
|
||||||
stream: &mut tokio::net::TcpStream,
|
stream: &mut tokio::net::TcpStream,
|
||||||
status: u16,
|
status: u16,
|
||||||
|
|
|
||||||
|
|
@ -715,6 +715,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
config.discord = Some(DiscordConfig {
|
config.discord = Some(DiscordConfig {
|
||||||
bot_token: token,
|
bot_token: token,
|
||||||
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
||||||
|
allowed_users: vec![],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
2 => {
|
2 => {
|
||||||
|
|
@ -791,6 +792,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
bot_token: token,
|
bot_token: token,
|
||||||
app_token: if app_token.is_empty() { None } else { Some(app_token) },
|
app_token: if app_token.is_empty() { None } else { Some(app_token) },
|
||||||
channel_id: if channel.is_empty() { None } else { Some(channel) },
|
channel_id: if channel.is_empty() { None } else { Some(channel) },
|
||||||
|
allowed_users: vec![],
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
3 => {
|
3 => {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,7 @@
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Mutex;
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
/// How much autonomy the agent has
|
/// How much autonomy the agent has
|
||||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||||
|
|
@ -19,6 +21,47 @@ impl Default for AutonomyLevel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Sliding-window action tracker for rate limiting.
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct ActionTracker {
|
||||||
|
/// Timestamps of recent actions (kept within the last hour).
|
||||||
|
actions: Mutex<Vec<Instant>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ActionTracker {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
actions: Mutex::new(Vec::new()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Record an action and return the current count within the window.
|
||||||
|
pub fn record(&self) -> usize {
|
||||||
|
let mut actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
let cutoff = Instant::now().checked_sub(std::time::Duration::from_secs(3600)).unwrap_or_else(Instant::now);
|
||||||
|
actions.retain(|t| *t > cutoff);
|
||||||
|
actions.push(Instant::now());
|
||||||
|
actions.len()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Count of actions in the current window without recording.
|
||||||
|
pub fn count(&self) -> usize {
|
||||||
|
let mut actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
let cutoff = Instant::now().checked_sub(std::time::Duration::from_secs(3600)).unwrap_or_else(Instant::now);
|
||||||
|
actions.retain(|t| *t > cutoff);
|
||||||
|
actions.len()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Clone for ActionTracker {
|
||||||
|
fn clone(&self) -> Self {
|
||||||
|
let actions = self.actions.lock().unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
|
Self {
|
||||||
|
actions: Mutex::new(actions.clone()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Security policy enforced on all tool executions
|
/// Security policy enforced on all tool executions
|
||||||
#[derive(Debug, Clone)]
|
#[derive(Debug, Clone)]
|
||||||
pub struct SecurityPolicy {
|
pub struct SecurityPolicy {
|
||||||
|
|
@ -29,6 +72,7 @@ pub struct SecurityPolicy {
|
||||||
pub forbidden_paths: Vec<String>,
|
pub forbidden_paths: Vec<String>,
|
||||||
pub max_actions_per_hour: u32,
|
pub max_actions_per_hour: u32,
|
||||||
pub max_cost_per_day_cents: u32,
|
pub max_cost_per_day_cents: u32,
|
||||||
|
pub tracker: ActionTracker,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for SecurityPolicy {
|
impl Default for SecurityPolicy {
|
||||||
|
|
@ -60,6 +104,7 @@ impl Default for SecurityPolicy {
|
||||||
],
|
],
|
||||||
max_actions_per_hour: 20,
|
max_actions_per_hour: 20,
|
||||||
max_cost_per_day_cents: 500,
|
max_cost_per_day_cents: 500,
|
||||||
|
tracker: ActionTracker::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -112,6 +157,18 @@ impl SecurityPolicy {
|
||||||
self.autonomy != AutonomyLevel::ReadOnly
|
self.autonomy != AutonomyLevel::ReadOnly
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Record an action and check if the rate limit has been exceeded.
|
||||||
|
/// Returns `true` if the action is allowed, `false` if rate-limited.
|
||||||
|
pub fn record_action(&self) -> bool {
|
||||||
|
let count = self.tracker.record();
|
||||||
|
count <= self.max_actions_per_hour as usize
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if the rate limit would be exceeded without recording.
|
||||||
|
pub fn is_rate_limited(&self) -> bool {
|
||||||
|
self.tracker.count() >= self.max_actions_per_hour as usize
|
||||||
|
}
|
||||||
|
|
||||||
/// Build from config sections
|
/// Build from config sections
|
||||||
pub fn from_config(
|
pub fn from_config(
|
||||||
autonomy_config: &crate::config::AutonomyConfig,
|
autonomy_config: &crate::config::AutonomyConfig,
|
||||||
|
|
@ -125,6 +182,7 @@ impl SecurityPolicy {
|
||||||
forbidden_paths: autonomy_config.forbidden_paths.clone(),
|
forbidden_paths: autonomy_config.forbidden_paths.clone(),
|
||||||
max_actions_per_hour: autonomy_config.max_actions_per_hour,
|
max_actions_per_hour: autonomy_config.max_actions_per_hour,
|
||||||
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
|
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
|
||||||
|
tracker: ActionTracker::new(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -362,4 +420,69 @@ mod tests {
|
||||||
assert!(p.max_actions_per_hour > 0);
|
assert!(p.max_actions_per_hour > 0);
|
||||||
assert!(p.max_cost_per_day_cents > 0);
|
assert!(p.max_cost_per_day_cents > 0);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── ActionTracker / rate limiting ───────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn action_tracker_starts_at_zero() {
|
||||||
|
let tracker = ActionTracker::new();
|
||||||
|
assert_eq!(tracker.count(), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn action_tracker_records_actions() {
|
||||||
|
let tracker = ActionTracker::new();
|
||||||
|
assert_eq!(tracker.record(), 1);
|
||||||
|
assert_eq!(tracker.record(), 2);
|
||||||
|
assert_eq!(tracker.record(), 3);
|
||||||
|
assert_eq!(tracker.count(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn record_action_allows_within_limit() {
|
||||||
|
let p = SecurityPolicy {
|
||||||
|
max_actions_per_hour: 5,
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
};
|
||||||
|
for _ in 0..5 {
|
||||||
|
assert!(p.record_action(), "should allow actions within limit");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn record_action_blocks_over_limit() {
|
||||||
|
let p = SecurityPolicy {
|
||||||
|
max_actions_per_hour: 3,
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
};
|
||||||
|
assert!(p.record_action()); // 1
|
||||||
|
assert!(p.record_action()); // 2
|
||||||
|
assert!(p.record_action()); // 3
|
||||||
|
assert!(!p.record_action()); // 4 — over limit
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_rate_limited_reflects_count() {
|
||||||
|
let p = SecurityPolicy {
|
||||||
|
max_actions_per_hour: 2,
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
};
|
||||||
|
assert!(!p.is_rate_limited());
|
||||||
|
p.record_action();
|
||||||
|
assert!(!p.is_rate_limited());
|
||||||
|
p.record_action();
|
||||||
|
assert!(p.is_rate_limited());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn action_tracker_clone_is_independent() {
|
||||||
|
let tracker = ActionTracker::new();
|
||||||
|
tracker.record();
|
||||||
|
tracker.record();
|
||||||
|
let cloned = tracker.clone();
|
||||||
|
assert_eq!(cloned.count(), 2);
|
||||||
|
tracker.record();
|
||||||
|
assert_eq!(tracker.count(), 3);
|
||||||
|
assert_eq!(cloned.count(), 2); // clone is independent
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue