fix(policy): standardize side-effect tool autonomy gates
This commit is contained in:
parent
89d0fb9a1e
commit
4f9c87ff74
6 changed files with 369 additions and 38 deletions
|
|
@ -24,6 +24,13 @@ pub enum CommandRiskLevel {
|
|||
High,
|
||||
}
|
||||
|
||||
/// Classifies whether a tool operation is read-only or side-effecting.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ToolOperation {
|
||||
Read,
|
||||
Act,
|
||||
}
|
||||
|
||||
/// Sliding-window action tracker for rate limiting.
|
||||
#[derive(Debug)]
|
||||
pub struct ActionTracker {
|
||||
|
|
@ -530,6 +537,33 @@ impl SecurityPolicy {
|
|||
self.autonomy != AutonomyLevel::ReadOnly
|
||||
}
|
||||
|
||||
/// Enforce policy for a tool operation.
|
||||
///
|
||||
/// Read operations are always allowed by autonomy/rate gates.
|
||||
/// Act operations require non-readonly autonomy and available action budget.
|
||||
pub fn enforce_tool_operation(
|
||||
&self,
|
||||
operation: ToolOperation,
|
||||
operation_name: &str,
|
||||
) -> Result<(), String> {
|
||||
match operation {
|
||||
ToolOperation::Read => Ok(()),
|
||||
ToolOperation::Act => {
|
||||
if !self.can_act() {
|
||||
return Err(format!(
|
||||
"Security policy: read-only mode, cannot perform '{operation_name}'"
|
||||
));
|
||||
}
|
||||
|
||||
if !self.record_action() {
|
||||
return Err("Rate limit exceeded: action budget exhausted".to_string());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Record an action and check if the rate limit has been exceeded.
|
||||
/// Returns `true` if the action is allowed, `false` if rate-limited.
|
||||
pub fn record_action(&self) -> bool {
|
||||
|
|
@ -616,6 +650,35 @@ mod tests {
|
|||
assert!(full_policy().can_act());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enforce_tool_operation_read_allowed_in_readonly_mode() {
|
||||
let p = readonly_policy();
|
||||
assert!(p
|
||||
.enforce_tool_operation(ToolOperation::Read, "memory_recall")
|
||||
.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enforce_tool_operation_act_blocked_in_readonly_mode() {
|
||||
let p = readonly_policy();
|
||||
let err = p
|
||||
.enforce_tool_operation(ToolOperation::Act, "memory_store")
|
||||
.unwrap_err();
|
||||
assert!(err.contains("read-only mode"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn enforce_tool_operation_act_uses_rate_budget() {
|
||||
let p = SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..default_policy()
|
||||
};
|
||||
let err = p
|
||||
.enforce_tool_operation(ToolOperation::Act, "memory_store")
|
||||
.unwrap_err();
|
||||
assert!(err.contains("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
// ── is_command_allowed ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -7,11 +7,14 @@
|
|||
// The Composio API key is stored in the encrypted secret store.
|
||||
|
||||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use anyhow::Context;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
||||
const COMPOSIO_API_BASE_V2: &str = "https://backend.composio.dev/api/v2";
|
||||
const COMPOSIO_API_BASE_V3: &str = "https://backend.composio.dev/api/v3";
|
||||
|
|
@ -20,14 +23,20 @@ const COMPOSIO_API_BASE_V3: &str = "https://backend.composio.dev/api/v3";
|
|||
pub struct ComposioTool {
|
||||
api_key: String,
|
||||
default_entity_id: String,
|
||||
security: Arc<SecurityPolicy>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
impl ComposioTool {
|
||||
pub fn new(api_key: &str, default_entity_id: Option<&str>) -> Self {
|
||||
pub fn new(
|
||||
api_key: &str,
|
||||
default_entity_id: Option<&str>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
) -> Self {
|
||||
Self {
|
||||
api_key: api_key.to_string(),
|
||||
default_entity_id: normalize_entity_id(default_entity_id.unwrap_or("default")),
|
||||
security,
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(60))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -481,6 +490,17 @@ impl Tool for ComposioTool {
|
|||
}
|
||||
|
||||
"execute" => {
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "composio.execute")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let action_name = args
|
||||
.get("tool_slug")
|
||||
.or_else(|| args.get("action_name"))
|
||||
|
|
@ -515,6 +535,17 @@ impl Tool for ComposioTool {
|
|||
}
|
||||
|
||||
"connect" => {
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "composio.connect")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
let app = args.get("app").and_then(|v| v.as_str());
|
||||
let auth_config_id = args.get("auth_config_id").and_then(|v| v.as_str());
|
||||
|
||||
|
|
@ -734,25 +765,30 @@ pub struct ComposioAction {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
// ── Constructor ───────────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn composio_tool_has_correct_name() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
assert_eq!(tool.name(), "composio");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composio_tool_has_description() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
assert!(!tool.description().is_empty());
|
||||
assert!(tool.description().contains("1000+"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn composio_tool_schema_has_required_fields() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["action"].is_object());
|
||||
assert!(schema["properties"]["action_name"].is_object());
|
||||
|
|
@ -767,7 +803,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn composio_tool_spec_roundtrip() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
let spec = tool.spec();
|
||||
assert_eq!(spec.name, "composio");
|
||||
assert!(spec.parameters.is_object());
|
||||
|
|
@ -777,14 +813,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn execute_missing_action_returns_error() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_unknown_action_returns_error() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
let result = tool.execute(json!({"action": "unknown"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.as_ref().unwrap().contains("Unknown action"));
|
||||
|
|
@ -792,18 +828,62 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn execute_without_action_name_returns_error() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
let result = tool.execute(json!({"action": "execute"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn connect_without_target_returns_error() {
|
||||
let tool = ComposioTool::new("test-key", None);
|
||||
let tool = ComposioTool::new("test-key", None, test_security());
|
||||
let result = tool.execute(json!({"action": "connect"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_blocked_in_readonly_mode() {
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = ComposioTool::new("test-key", None, readonly);
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "execute",
|
||||
"action_name": "GITHUB_LIST_REPOS"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_blocked_when_rate_limited() {
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = ComposioTool::new("test-key", None, limited);
|
||||
let result = tool
|
||||
.execute(json!({
|
||||
"action": "execute",
|
||||
"action_name": "GITHUB_LIST_REPOS"
|
||||
}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
}
|
||||
|
||||
// ── API response parsing ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::config::DelegateAgentConfig;
|
||||
use crate::providers::{self, Provider};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::collections::HashMap;
|
||||
|
|
@ -16,6 +18,7 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
|
|||
/// summarization) to purpose-built sub-agents.
|
||||
pub struct DelegateTool {
|
||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
/// Global credential fallback (from config.api_key)
|
||||
fallback_credential: Option<String>,
|
||||
/// Depth at which this tool instance lives in the delegation chain.
|
||||
|
|
@ -26,9 +29,11 @@ impl DelegateTool {
|
|||
pub fn new(
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_credential: Option<String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
) -> Self {
|
||||
Self {
|
||||
agents: Arc::new(agents),
|
||||
security,
|
||||
fallback_credential,
|
||||
depth: 0,
|
||||
}
|
||||
|
|
@ -40,10 +45,12 @@ impl DelegateTool {
|
|||
pub fn with_depth(
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_credential: Option<String>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
depth: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
agents: Arc::new(agents),
|
||||
security,
|
||||
fallback_credential,
|
||||
depth,
|
||||
}
|
||||
|
|
@ -164,6 +171,17 @@ impl Tool for DelegateTool {
|
|||
});
|
||||
}
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "delegate")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
// Create provider for this agent
|
||||
let provider_credential_owned = agent_config
|
||||
.api_key
|
||||
|
|
@ -250,6 +268,11 @@ impl Tool for DelegateTool {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn sample_agents() -> HashMap<String, DelegateAgentConfig> {
|
||||
let mut agents = HashMap::new();
|
||||
|
|
@ -280,7 +303,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn name_and_schema() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
assert_eq!(tool.name(), "delegate");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["agent"].is_object());
|
||||
|
|
@ -296,13 +319,13 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn description_not_empty() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
assert!(!tool.description().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn schema_lists_agent_names() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
let schema = tool.parameters_schema();
|
||||
let desc = schema["properties"]["agent"]["description"]
|
||||
.as_str()
|
||||
|
|
@ -312,21 +335,21 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn missing_agent_param() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
let result = tool.execute(json!({"prompt": "test"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn missing_prompt_param() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
let result = tool.execute(json!({"agent": "researcher"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unknown_agent_returns_error() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"agent": "nonexistent", "prompt": "test"}))
|
||||
.await
|
||||
|
|
@ -337,7 +360,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn depth_limit_enforced() {
|
||||
let tool = DelegateTool::with_depth(sample_agents(), None, 3);
|
||||
let tool = DelegateTool::with_depth(sample_agents(), None, test_security(), 3);
|
||||
let result = tool
|
||||
.execute(json!({"agent": "researcher", "prompt": "test"}))
|
||||
.await
|
||||
|
|
@ -349,7 +372,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn depth_limit_per_agent() {
|
||||
// coder has max_depth=2, so depth=2 should be blocked
|
||||
let tool = DelegateTool::with_depth(sample_agents(), None, 2);
|
||||
let tool = DelegateTool::with_depth(sample_agents(), None, test_security(), 2);
|
||||
let result = tool
|
||||
.execute(json!({"agent": "coder", "prompt": "test"}))
|
||||
.await
|
||||
|
|
@ -360,7 +383,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn empty_agents_schema() {
|
||||
let tool = DelegateTool::new(HashMap::new(), None);
|
||||
let tool = DelegateTool::new(HashMap::new(), None, test_security());
|
||||
let schema = tool.parameters_schema();
|
||||
let desc = schema["properties"]["agent"]["description"]
|
||||
.as_str()
|
||||
|
|
@ -382,7 +405,7 @@ mod tests {
|
|||
max_depth: 3,
|
||||
},
|
||||
);
|
||||
let tool = DelegateTool::new(agents, None);
|
||||
let tool = DelegateTool::new(agents, None, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"agent": "broken", "prompt": "test"}))
|
||||
.await
|
||||
|
|
@ -393,7 +416,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn blank_agent_rejected() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"agent": " ", "prompt": "test"}))
|
||||
.await
|
||||
|
|
@ -404,7 +427,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn blank_prompt_rejected() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
let result = tool
|
||||
.execute(json!({"agent": "researcher", "prompt": " \t "}))
|
||||
.await
|
||||
|
|
@ -415,7 +438,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn whitespace_agent_name_trimmed_and_found() {
|
||||
let tool = DelegateTool::new(sample_agents(), None);
|
||||
let tool = DelegateTool::new(sample_agents(), None, test_security());
|
||||
// " researcher " with surrounding whitespace — after trim becomes "researcher"
|
||||
let result = tool
|
||||
.execute(json!({"agent": " researcher ", "prompt": "test"}))
|
||||
|
|
@ -432,4 +455,42 @@ mod tests {
|
|||
.contains("Unknown agent")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delegation_blocked_in_readonly_mode() {
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = DelegateTool::new(sample_agents(), None, readonly);
|
||||
let result = tool
|
||||
.execute(json!({"agent": "researcher", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delegation_blocked_when_rate_limited() {
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = DelegateTool::new(sample_agents(), None, limited);
|
||||
let result = tool
|
||||
.execute(json!({"agent": "researcher", "prompt": "test"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::Memory;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -7,11 +9,12 @@ use std::sync::Arc;
|
|||
/// Let the agent forget/delete a memory entry
|
||||
pub struct MemoryForgetTool {
|
||||
memory: Arc<dyn Memory>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl MemoryForgetTool {
|
||||
pub fn new(memory: Arc<dyn Memory>) -> Self {
|
||||
Self { memory }
|
||||
pub fn new(memory: Arc<dyn Memory>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { memory, security }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -44,6 +47,17 @@ impl Tool for MemoryForgetTool {
|
|||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?;
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "memory_forget")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
match self.memory.forget(key).await {
|
||||
Ok(true) => Ok(ToolResult {
|
||||
success: true,
|
||||
|
|
@ -68,8 +82,13 @@ impl Tool for MemoryForgetTool {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{MemoryCategory, SqliteMemory};
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
|
@ -79,7 +98,7 @@ mod tests {
|
|||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let tool = MemoryForgetTool::new(mem, test_security());
|
||||
assert_eq!(tool.name(), "memory_forget");
|
||||
assert!(tool.parameters_schema()["properties"]["key"].is_object());
|
||||
}
|
||||
|
|
@ -91,7 +110,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryForgetTool::new(mem.clone());
|
||||
let tool = MemoryForgetTool::new(mem.clone(), test_security());
|
||||
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Forgot"));
|
||||
|
|
@ -102,7 +121,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_nonexistent() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let tool = MemoryForgetTool::new(mem, test_security());
|
||||
let result = tool.execute(json!({"key": "nope"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No memory found"));
|
||||
|
|
@ -111,8 +130,50 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_missing_key() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let tool = MemoryForgetTool::new(mem, test_security());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forget_blocked_in_readonly_mode() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryForgetTool::new(mem.clone(), readonly);
|
||||
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
assert!(mem.get("temp").await.unwrap().is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forget_blocked_when_rate_limited() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryForgetTool::new(mem.clone(), limited);
|
||||
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
assert!(mem.get("temp").await.unwrap().is_some());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::{Memory, MemoryCategory};
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -7,11 +9,12 @@ use std::sync::Arc;
|
|||
/// Let the agent store memories — its own brain writes
|
||||
pub struct MemoryStoreTool {
|
||||
memory: Arc<dyn Memory>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl MemoryStoreTool {
|
||||
pub fn new(memory: Arc<dyn Memory>) -> Self {
|
||||
Self { memory }
|
||||
pub fn new(memory: Arc<dyn Memory>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { memory, security }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -64,6 +67,17 @@ impl Tool for MemoryStoreTool {
|
|||
Some(other) => MemoryCategory::Custom(other.to_string()),
|
||||
};
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "memory_store")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
match self.memory.store(key, content, category, None).await {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
|
|
@ -83,8 +97,13 @@ impl Tool for MemoryStoreTool {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::SqliteMemory;
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
|
@ -94,7 +113,7 @@ mod tests {
|
|||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem);
|
||||
let tool = MemoryStoreTool::new(mem, test_security());
|
||||
assert_eq!(tool.name(), "memory_store");
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"]["key"].is_object());
|
||||
|
|
@ -104,7 +123,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_core() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem.clone());
|
||||
let tool = MemoryStoreTool::new(mem.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
|
||||
.await
|
||||
|
|
@ -120,7 +139,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_with_category() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem.clone());
|
||||
let tool = MemoryStoreTool::new(mem.clone(), test_security());
|
||||
let result = tool
|
||||
.execute(json!({"key": "note", "content": "Fixed bug", "category": "daily"}))
|
||||
.await
|
||||
|
|
@ -148,7 +167,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_missing_key() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem);
|
||||
let tool = MemoryStoreTool::new(mem, test_security());
|
||||
let result = tool.execute(json!({"content": "no key"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
|
@ -156,8 +175,50 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_missing_content() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryStoreTool::new(mem);
|
||||
let tool = MemoryStoreTool::new(mem, test_security());
|
||||
let result = tool.execute(json!({"key": "no_content"})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_blocked_in_readonly_mode() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryStoreTool::new(mem.clone(), readonly);
|
||||
let result = tool
|
||||
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
assert!(mem.get("lang").await.unwrap().is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_blocked_when_rate_limited() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryStoreTool::new(mem.clone(), limited);
|
||||
let result = tool
|
||||
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
assert!(mem.get("lang").await.unwrap().is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -138,9 +138,9 @@ pub fn all_tools_with_runtime(
|
|||
Box::new(CronUpdateTool::new(config.clone(), security.clone())),
|
||||
Box::new(CronRunTool::new(config.clone())),
|
||||
Box::new(CronRunsTool::new(config.clone())),
|
||||
Box::new(MemoryStoreTool::new(memory.clone())),
|
||||
Box::new(MemoryStoreTool::new(memory.clone(), security.clone())),
|
||||
Box::new(MemoryRecallTool::new(memory.clone())),
|
||||
Box::new(MemoryForgetTool::new(memory)),
|
||||
Box::new(MemoryForgetTool::new(memory, security.clone())),
|
||||
Box::new(ScheduleTool::new(security.clone(), root_config.clone())),
|
||||
Box::new(GitOperationsTool::new(
|
||||
security.clone(),
|
||||
|
|
@ -194,7 +194,11 @@ pub fn all_tools_with_runtime(
|
|||
|
||||
if let Some(key) = composio_key {
|
||||
if !key.is_empty() {
|
||||
tools.push(Box::new(ComposioTool::new(key, composio_entity_id)));
|
||||
tools.push(Box::new(ComposioTool::new(
|
||||
key,
|
||||
composio_entity_id,
|
||||
security.clone(),
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -211,6 +215,7 @@ pub fn all_tools_with_runtime(
|
|||
tools.push(Box::new(DelegateTool::new(
|
||||
delegate_agents,
|
||||
delegate_fallback_credential,
|
||||
security.clone(),
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue