From 4f9c87ff7429d84378003be4d04808a729d04409 Mon Sep 17 00:00:00 2001 From: fettpl <38704082+fettpl@users.noreply.github.com> Date: Tue, 17 Feb 2026 22:56:07 +0100 Subject: [PATCH] fix(policy): standardize side-effect tool autonomy gates --- src/security/policy.rs | 63 ++++++++++++++++++++++++ src/tools/composio.rs | 98 ++++++++++++++++++++++++++++++++++---- src/tools/delegate.rs | 87 ++++++++++++++++++++++++++++----- src/tools/memory_forget.rs | 73 +++++++++++++++++++++++++--- src/tools/memory_store.rs | 75 ++++++++++++++++++++++++++--- src/tools/mod.rs | 11 +++-- 6 files changed, 369 insertions(+), 38 deletions(-) diff --git a/src/security/policy.rs b/src/security/policy.rs index 7db3ef8..3e726dd 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -24,6 +24,13 @@ pub enum CommandRiskLevel { High, } +/// Classifies whether a tool operation is read-only or side-effecting. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ToolOperation { + Read, + Act, +} + /// Sliding-window action tracker for rate limiting. #[derive(Debug)] pub struct ActionTracker { @@ -530,6 +537,33 @@ impl SecurityPolicy { self.autonomy != AutonomyLevel::ReadOnly } + /// Enforce policy for a tool operation. + /// + /// Read operations are always allowed by autonomy/rate gates. + /// Act operations require non-readonly autonomy and available action budget. + pub fn enforce_tool_operation( + &self, + operation: ToolOperation, + operation_name: &str, + ) -> Result<(), String> { + match operation { + ToolOperation::Read => Ok(()), + ToolOperation::Act => { + if !self.can_act() { + return Err(format!( + "Security policy: read-only mode, cannot perform '{operation_name}'" + )); + } + + if !self.record_action() { + return Err("Rate limit exceeded: action budget exhausted".to_string()); + } + + Ok(()) + } + } + } + /// Record an action and check if the rate limit has been exceeded. /// Returns `true` if the action is allowed, `false` if rate-limited. pub fn record_action(&self) -> bool { @@ -616,6 +650,35 @@ mod tests { assert!(full_policy().can_act()); } + #[test] + fn enforce_tool_operation_read_allowed_in_readonly_mode() { + let p = readonly_policy(); + assert!(p + .enforce_tool_operation(ToolOperation::Read, "memory_recall") + .is_ok()); + } + + #[test] + fn enforce_tool_operation_act_blocked_in_readonly_mode() { + let p = readonly_policy(); + let err = p + .enforce_tool_operation(ToolOperation::Act, "memory_store") + .unwrap_err(); + assert!(err.contains("read-only mode")); + } + + #[test] + fn enforce_tool_operation_act_uses_rate_budget() { + let p = SecurityPolicy { + max_actions_per_hour: 0, + ..default_policy() + }; + let err = p + .enforce_tool_operation(ToolOperation::Act, "memory_store") + .unwrap_err(); + assert!(err.contains("Rate limit exceeded")); + } + // ── is_command_allowed ─────────────────────────────────── #[test] diff --git a/src/tools/composio.rs b/src/tools/composio.rs index 65f128e..916e571 100644 --- a/src/tools/composio.rs +++ b/src/tools/composio.rs @@ -7,11 +7,14 @@ // The Composio API key is stored in the encrypted secret store. use super::traits::{Tool, ToolResult}; +use crate::security::policy::ToolOperation; +use crate::security::SecurityPolicy; use anyhow::Context; use async_trait::async_trait; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::json; +use std::sync::Arc; const COMPOSIO_API_BASE_V2: &str = "https://backend.composio.dev/api/v2"; const COMPOSIO_API_BASE_V3: &str = "https://backend.composio.dev/api/v3"; @@ -20,14 +23,20 @@ const COMPOSIO_API_BASE_V3: &str = "https://backend.composio.dev/api/v3"; pub struct ComposioTool { api_key: String, default_entity_id: String, + security: Arc, client: Client, } impl ComposioTool { - pub fn new(api_key: &str, default_entity_id: Option<&str>) -> Self { + pub fn new( + api_key: &str, + default_entity_id: Option<&str>, + security: Arc, + ) -> Self { Self { api_key: api_key.to_string(), default_entity_id: normalize_entity_id(default_entity_id.unwrap_or("default")), + security, client: Client::builder() .timeout(std::time::Duration::from_secs(60)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -481,6 +490,17 @@ impl Tool for ComposioTool { } "execute" => { + if let Err(error) = self + .security + .enforce_tool_operation(ToolOperation::Act, "composio.execute") + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + let action_name = args .get("tool_slug") .or_else(|| args.get("action_name")) @@ -515,6 +535,17 @@ impl Tool for ComposioTool { } "connect" => { + if let Err(error) = self + .security + .enforce_tool_operation(ToolOperation::Act, "composio.connect") + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + let app = args.get("app").and_then(|v| v.as_str()); let auth_config_id = args.get("auth_config_id").and_then(|v| v.as_str()); @@ -734,25 +765,30 @@ pub struct ComposioAction { #[cfg(test)] mod tests { use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy::default()) + } // ── Constructor ─────────────────────────────────────────── #[test] fn composio_tool_has_correct_name() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); assert_eq!(tool.name(), "composio"); } #[test] fn composio_tool_has_description() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); assert!(!tool.description().is_empty()); assert!(tool.description().contains("1000+")); } #[test] fn composio_tool_schema_has_required_fields() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); let schema = tool.parameters_schema(); assert!(schema["properties"]["action"].is_object()); assert!(schema["properties"]["action_name"].is_object()); @@ -767,7 +803,7 @@ mod tests { #[test] fn composio_tool_spec_roundtrip() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); let spec = tool.spec(); assert_eq!(spec.name, "composio"); assert!(spec.parameters.is_object()); @@ -777,14 +813,14 @@ mod tests { #[tokio::test] async fn execute_missing_action_returns_error() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); let result = tool.execute(json!({})).await; assert!(result.is_err()); } #[tokio::test] async fn execute_unknown_action_returns_error() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); let result = tool.execute(json!({"action": "unknown"})).await.unwrap(); assert!(!result.success); assert!(result.error.as_ref().unwrap().contains("Unknown action")); @@ -792,18 +828,62 @@ mod tests { #[tokio::test] async fn execute_without_action_name_returns_error() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); let result = tool.execute(json!({"action": "execute"})).await; assert!(result.is_err()); } #[tokio::test] async fn connect_without_target_returns_error() { - let tool = ComposioTool::new("test-key", None); + let tool = ComposioTool::new("test-key", None, test_security()); let result = tool.execute(json!({"action": "connect"})).await; assert!(result.is_err()); } + #[tokio::test] + async fn execute_blocked_in_readonly_mode() { + let readonly = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = ComposioTool::new("test-key", None, readonly); + let result = tool + .execute(json!({ + "action": "execute", + "action_name": "GITHUB_LIST_REPOS" + })) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("read-only mode")); + } + + #[tokio::test] + async fn execute_blocked_when_rate_limited() { + let limited = Arc::new(SecurityPolicy { + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = ComposioTool::new("test-key", None, limited); + let result = tool + .execute(json!({ + "action": "execute", + "action_name": "GITHUB_LIST_REPOS" + })) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Rate limit exceeded")); + } + // ── API response parsing ────────────────────────────────── #[test] diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index 3de7872..2f3cd71 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -1,6 +1,8 @@ use super::traits::{Tool, ToolResult}; use crate::config::DelegateAgentConfig; use crate::providers::{self, Provider}; +use crate::security::policy::ToolOperation; +use crate::security::SecurityPolicy; use async_trait::async_trait; use serde_json::json; use std::collections::HashMap; @@ -16,6 +18,7 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120; /// summarization) to purpose-built sub-agents. pub struct DelegateTool { agents: Arc>, + security: Arc, /// Global credential fallback (from config.api_key) fallback_credential: Option, /// Depth at which this tool instance lives in the delegation chain. @@ -26,9 +29,11 @@ impl DelegateTool { pub fn new( agents: HashMap, fallback_credential: Option, + security: Arc, ) -> Self { Self { agents: Arc::new(agents), + security, fallback_credential, depth: 0, } @@ -40,10 +45,12 @@ impl DelegateTool { pub fn with_depth( agents: HashMap, fallback_credential: Option, + security: Arc, depth: u32, ) -> Self { Self { agents: Arc::new(agents), + security, fallback_credential, depth, } @@ -164,6 +171,17 @@ impl Tool for DelegateTool { }); } + if let Err(error) = self + .security + .enforce_tool_operation(ToolOperation::Act, "delegate") + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + // Create provider for this agent let provider_credential_owned = agent_config .api_key @@ -250,6 +268,11 @@ impl Tool for DelegateTool { #[cfg(test)] mod tests { use super::*; + use crate::security::{AutonomyLevel, SecurityPolicy}; + + fn test_security() -> Arc { + Arc::new(SecurityPolicy::default()) + } fn sample_agents() -> HashMap { let mut agents = HashMap::new(); @@ -280,7 +303,7 @@ mod tests { #[test] fn name_and_schema() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); assert_eq!(tool.name(), "delegate"); let schema = tool.parameters_schema(); assert!(schema["properties"]["agent"].is_object()); @@ -296,13 +319,13 @@ mod tests { #[test] fn description_not_empty() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); assert!(!tool.description().is_empty()); } #[test] fn schema_lists_agent_names() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); let schema = tool.parameters_schema(); let desc = schema["properties"]["agent"]["description"] .as_str() @@ -312,21 +335,21 @@ mod tests { #[tokio::test] async fn missing_agent_param() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); let result = tool.execute(json!({"prompt": "test"})).await; assert!(result.is_err()); } #[tokio::test] async fn missing_prompt_param() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); let result = tool.execute(json!({"agent": "researcher"})).await; assert!(result.is_err()); } #[tokio::test] async fn unknown_agent_returns_error() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); let result = tool .execute(json!({"agent": "nonexistent", "prompt": "test"})) .await @@ -337,7 +360,7 @@ mod tests { #[tokio::test] async fn depth_limit_enforced() { - let tool = DelegateTool::with_depth(sample_agents(), None, 3); + let tool = DelegateTool::with_depth(sample_agents(), None, test_security(), 3); let result = tool .execute(json!({"agent": "researcher", "prompt": "test"})) .await @@ -349,7 +372,7 @@ mod tests { #[tokio::test] async fn depth_limit_per_agent() { // coder has max_depth=2, so depth=2 should be blocked - let tool = DelegateTool::with_depth(sample_agents(), None, 2); + let tool = DelegateTool::with_depth(sample_agents(), None, test_security(), 2); let result = tool .execute(json!({"agent": "coder", "prompt": "test"})) .await @@ -360,7 +383,7 @@ mod tests { #[test] fn empty_agents_schema() { - let tool = DelegateTool::new(HashMap::new(), None); + let tool = DelegateTool::new(HashMap::new(), None, test_security()); let schema = tool.parameters_schema(); let desc = schema["properties"]["agent"]["description"] .as_str() @@ -382,7 +405,7 @@ mod tests { max_depth: 3, }, ); - let tool = DelegateTool::new(agents, None); + let tool = DelegateTool::new(agents, None, test_security()); let result = tool .execute(json!({"agent": "broken", "prompt": "test"})) .await @@ -393,7 +416,7 @@ mod tests { #[tokio::test] async fn blank_agent_rejected() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); let result = tool .execute(json!({"agent": " ", "prompt": "test"})) .await @@ -404,7 +427,7 @@ mod tests { #[tokio::test] async fn blank_prompt_rejected() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); let result = tool .execute(json!({"agent": "researcher", "prompt": " \t "})) .await @@ -415,7 +438,7 @@ mod tests { #[tokio::test] async fn whitespace_agent_name_trimmed_and_found() { - let tool = DelegateTool::new(sample_agents(), None); + let tool = DelegateTool::new(sample_agents(), None, test_security()); // " researcher " with surrounding whitespace — after trim becomes "researcher" let result = tool .execute(json!({"agent": " researcher ", "prompt": "test"})) @@ -432,4 +455,42 @@ mod tests { .contains("Unknown agent") ); } + + #[tokio::test] + async fn delegation_blocked_in_readonly_mode() { + let readonly = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = DelegateTool::new(sample_agents(), None, readonly); + let result = tool + .execute(json!({"agent": "researcher", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("read-only mode")); + } + + #[tokio::test] + async fn delegation_blocked_when_rate_limited() { + let limited = Arc::new(SecurityPolicy { + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = DelegateTool::new(sample_agents(), None, limited); + let result = tool + .execute(json!({"agent": "researcher", "prompt": "test"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Rate limit exceeded")); + } } diff --git a/src/tools/memory_forget.rs b/src/tools/memory_forget.rs index a53885e..67e8ce6 100644 --- a/src/tools/memory_forget.rs +++ b/src/tools/memory_forget.rs @@ -1,5 +1,7 @@ use super::traits::{Tool, ToolResult}; use crate::memory::Memory; +use crate::security::policy::ToolOperation; +use crate::security::SecurityPolicy; use async_trait::async_trait; use serde_json::json; use std::sync::Arc; @@ -7,11 +9,12 @@ use std::sync::Arc; /// Let the agent forget/delete a memory entry pub struct MemoryForgetTool { memory: Arc, + security: Arc, } impl MemoryForgetTool { - pub fn new(memory: Arc) -> Self { - Self { memory } + pub fn new(memory: Arc, security: Arc) -> Self { + Self { memory, security } } } @@ -44,6 +47,17 @@ impl Tool for MemoryForgetTool { .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?; + if let Err(error) = self + .security + .enforce_tool_operation(ToolOperation::Act, "memory_forget") + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + match self.memory.forget(key).await { Ok(true) => Ok(ToolResult { success: true, @@ -68,8 +82,13 @@ impl Tool for MemoryForgetTool { mod tests { use super::*; use crate::memory::{MemoryCategory, SqliteMemory}; + use crate::security::{AutonomyLevel, SecurityPolicy}; use tempfile::TempDir; + fn test_security() -> Arc { + Arc::new(SecurityPolicy::default()) + } + fn test_mem() -> (TempDir, Arc) { let tmp = TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); @@ -79,7 +98,7 @@ mod tests { #[test] fn name_and_schema() { let (_tmp, mem) = test_mem(); - let tool = MemoryForgetTool::new(mem); + let tool = MemoryForgetTool::new(mem, test_security()); assert_eq!(tool.name(), "memory_forget"); assert!(tool.parameters_schema()["properties"]["key"].is_object()); } @@ -91,7 +110,7 @@ mod tests { .await .unwrap(); - let tool = MemoryForgetTool::new(mem.clone()); + let tool = MemoryForgetTool::new(mem.clone(), test_security()); let result = tool.execute(json!({"key": "temp"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("Forgot")); @@ -102,7 +121,7 @@ mod tests { #[tokio::test] async fn forget_nonexistent() { let (_tmp, mem) = test_mem(); - let tool = MemoryForgetTool::new(mem); + let tool = MemoryForgetTool::new(mem, test_security()); let result = tool.execute(json!({"key": "nope"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("No memory found")); @@ -111,8 +130,50 @@ mod tests { #[tokio::test] async fn forget_missing_key() { let (_tmp, mem) = test_mem(); - let tool = MemoryForgetTool::new(mem); + let tool = MemoryForgetTool::new(mem, test_security()); let result = tool.execute(json!({})).await; assert!(result.is_err()); } + + #[tokio::test] + async fn forget_blocked_in_readonly_mode() { + let (_tmp, mem) = test_mem(); + mem.store("temp", "temporary", MemoryCategory::Conversation, None) + .await + .unwrap(); + let readonly = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = MemoryForgetTool::new(mem.clone(), readonly); + let result = tool.execute(json!({"key": "temp"})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("read-only mode")); + assert!(mem.get("temp").await.unwrap().is_some()); + } + + #[tokio::test] + async fn forget_blocked_when_rate_limited() { + let (_tmp, mem) = test_mem(); + mem.store("temp", "temporary", MemoryCategory::Conversation, None) + .await + .unwrap(); + let limited = Arc::new(SecurityPolicy { + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = MemoryForgetTool::new(mem.clone(), limited); + let result = tool.execute(json!({"key": "temp"})).await.unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Rate limit exceeded")); + assert!(mem.get("temp").await.unwrap().is_some()); + } } diff --git a/src/tools/memory_store.rs b/src/tools/memory_store.rs index 8d28714..1095f04 100644 --- a/src/tools/memory_store.rs +++ b/src/tools/memory_store.rs @@ -1,5 +1,7 @@ use super::traits::{Tool, ToolResult}; use crate::memory::{Memory, MemoryCategory}; +use crate::security::policy::ToolOperation; +use crate::security::SecurityPolicy; use async_trait::async_trait; use serde_json::json; use std::sync::Arc; @@ -7,11 +9,12 @@ use std::sync::Arc; /// Let the agent store memories — its own brain writes pub struct MemoryStoreTool { memory: Arc, + security: Arc, } impl MemoryStoreTool { - pub fn new(memory: Arc) -> Self { - Self { memory } + pub fn new(memory: Arc, security: Arc) -> Self { + Self { memory, security } } } @@ -64,6 +67,17 @@ impl Tool for MemoryStoreTool { Some(other) => MemoryCategory::Custom(other.to_string()), }; + if let Err(error) = self + .security + .enforce_tool_operation(ToolOperation::Act, "memory_store") + { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(error), + }); + } + match self.memory.store(key, content, category, None).await { Ok(()) => Ok(ToolResult { success: true, @@ -83,8 +97,13 @@ impl Tool for MemoryStoreTool { mod tests { use super::*; use crate::memory::SqliteMemory; + use crate::security::{AutonomyLevel, SecurityPolicy}; use tempfile::TempDir; + fn test_security() -> Arc { + Arc::new(SecurityPolicy::default()) + } + fn test_mem() -> (TempDir, Arc) { let tmp = TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); @@ -94,7 +113,7 @@ mod tests { #[test] fn name_and_schema() { let (_tmp, mem) = test_mem(); - let tool = MemoryStoreTool::new(mem); + let tool = MemoryStoreTool::new(mem, test_security()); assert_eq!(tool.name(), "memory_store"); let schema = tool.parameters_schema(); assert!(schema["properties"]["key"].is_object()); @@ -104,7 +123,7 @@ mod tests { #[tokio::test] async fn store_core() { let (_tmp, mem) = test_mem(); - let tool = MemoryStoreTool::new(mem.clone()); + let tool = MemoryStoreTool::new(mem.clone(), test_security()); let result = tool .execute(json!({"key": "lang", "content": "Prefers Rust"})) .await @@ -120,7 +139,7 @@ mod tests { #[tokio::test] async fn store_with_category() { let (_tmp, mem) = test_mem(); - let tool = MemoryStoreTool::new(mem.clone()); + let tool = MemoryStoreTool::new(mem.clone(), test_security()); let result = tool .execute(json!({"key": "note", "content": "Fixed bug", "category": "daily"})) .await @@ -148,7 +167,7 @@ mod tests { #[tokio::test] async fn store_missing_key() { let (_tmp, mem) = test_mem(); - let tool = MemoryStoreTool::new(mem); + let tool = MemoryStoreTool::new(mem, test_security()); let result = tool.execute(json!({"content": "no key"})).await; assert!(result.is_err()); } @@ -156,8 +175,50 @@ mod tests { #[tokio::test] async fn store_missing_content() { let (_tmp, mem) = test_mem(); - let tool = MemoryStoreTool::new(mem); + let tool = MemoryStoreTool::new(mem, test_security()); let result = tool.execute(json!({"key": "no_content"})).await; assert!(result.is_err()); } + + #[tokio::test] + async fn store_blocked_in_readonly_mode() { + let (_tmp, mem) = test_mem(); + let readonly = Arc::new(SecurityPolicy { + autonomy: AutonomyLevel::ReadOnly, + ..SecurityPolicy::default() + }); + let tool = MemoryStoreTool::new(mem.clone(), readonly); + let result = tool + .execute(json!({"key": "lang", "content": "Prefers Rust"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("read-only mode")); + assert!(mem.get("lang").await.unwrap().is_none()); + } + + #[tokio::test] + async fn store_blocked_when_rate_limited() { + let (_tmp, mem) = test_mem(); + let limited = Arc::new(SecurityPolicy { + max_actions_per_hour: 0, + ..SecurityPolicy::default() + }); + let tool = MemoryStoreTool::new(mem.clone(), limited); + let result = tool + .execute(json!({"key": "lang", "content": "Prefers Rust"})) + .await + .unwrap(); + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("Rate limit exceeded")); + assert!(mem.get("lang").await.unwrap().is_none()); + } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 3c6309f..03fc067 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -138,9 +138,9 @@ pub fn all_tools_with_runtime( Box::new(CronUpdateTool::new(config.clone(), security.clone())), Box::new(CronRunTool::new(config.clone())), Box::new(CronRunsTool::new(config.clone())), - Box::new(MemoryStoreTool::new(memory.clone())), + Box::new(MemoryStoreTool::new(memory.clone(), security.clone())), Box::new(MemoryRecallTool::new(memory.clone())), - Box::new(MemoryForgetTool::new(memory)), + Box::new(MemoryForgetTool::new(memory, security.clone())), Box::new(ScheduleTool::new(security.clone(), root_config.clone())), Box::new(GitOperationsTool::new( security.clone(), @@ -194,7 +194,11 @@ pub fn all_tools_with_runtime( if let Some(key) = composio_key { if !key.is_empty() { - tools.push(Box::new(ComposioTool::new(key, composio_entity_id))); + tools.push(Box::new(ComposioTool::new( + key, + composio_entity_id, + security.clone(), + ))); } } @@ -211,6 +215,7 @@ pub fn all_tools_with_runtime( tools.push(Box::new(DelegateTool::new( delegate_agents, delegate_fallback_credential, + security.clone(), ))); }