fix(policy): standardize side-effect tool autonomy gates
This commit is contained in:
parent
89d0fb9a1e
commit
4f9c87ff74
6 changed files with 369 additions and 38 deletions
|
|
@ -1,5 +1,7 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::memory::Memory;
|
||||
use crate::security::policy::ToolOperation;
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use serde_json::json;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -7,11 +9,12 @@ use std::sync::Arc;
|
|||
/// Let the agent forget/delete a memory entry
|
||||
pub struct MemoryForgetTool {
|
||||
memory: Arc<dyn Memory>,
|
||||
security: Arc<SecurityPolicy>,
|
||||
}
|
||||
|
||||
impl MemoryForgetTool {
|
||||
pub fn new(memory: Arc<dyn Memory>) -> Self {
|
||||
Self { memory }
|
||||
pub fn new(memory: Arc<dyn Memory>, security: Arc<SecurityPolicy>) -> Self {
|
||||
Self { memory, security }
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -44,6 +47,17 @@ impl Tool for MemoryForgetTool {
|
|||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?;
|
||||
|
||||
if let Err(error) = self
|
||||
.security
|
||||
.enforce_tool_operation(ToolOperation::Act, "memory_forget")
|
||||
{
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(error),
|
||||
});
|
||||
}
|
||||
|
||||
match self.memory.forget(key).await {
|
||||
Ok(true) => Ok(ToolResult {
|
||||
success: true,
|
||||
|
|
@ -68,8 +82,13 @@ impl Tool for MemoryForgetTool {
|
|||
mod tests {
|
||||
use super::*;
|
||||
use crate::memory::{MemoryCategory, SqliteMemory};
|
||||
use crate::security::{AutonomyLevel, SecurityPolicy};
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_security() -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy::default())
|
||||
}
|
||||
|
||||
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
|
@ -79,7 +98,7 @@ mod tests {
|
|||
#[test]
|
||||
fn name_and_schema() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let tool = MemoryForgetTool::new(mem, test_security());
|
||||
assert_eq!(tool.name(), "memory_forget");
|
||||
assert!(tool.parameters_schema()["properties"]["key"].is_object());
|
||||
}
|
||||
|
|
@ -91,7 +110,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let tool = MemoryForgetTool::new(mem.clone());
|
||||
let tool = MemoryForgetTool::new(mem.clone(), test_security());
|
||||
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("Forgot"));
|
||||
|
|
@ -102,7 +121,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_nonexistent() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let tool = MemoryForgetTool::new(mem, test_security());
|
||||
let result = tool.execute(json!({"key": "nope"})).await.unwrap();
|
||||
assert!(result.success);
|
||||
assert!(result.output.contains("No memory found"));
|
||||
|
|
@ -111,8 +130,50 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_missing_key() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
let tool = MemoryForgetTool::new(mem);
|
||||
let tool = MemoryForgetTool::new(mem, test_security());
|
||||
let result = tool.execute(json!({})).await;
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forget_blocked_in_readonly_mode() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let readonly = Arc::new(SecurityPolicy {
|
||||
autonomy: AutonomyLevel::ReadOnly,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryForgetTool::new(mem.clone(), readonly);
|
||||
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("read-only mode"));
|
||||
assert!(mem.get("temp").await.unwrap().is_some());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn forget_blocked_when_rate_limited() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let limited = Arc::new(SecurityPolicy {
|
||||
max_actions_per_hour: 0,
|
||||
..SecurityPolicy::default()
|
||||
});
|
||||
let tool = MemoryForgetTool::new(mem.clone(), limited);
|
||||
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result
|
||||
.error
|
||||
.as_deref()
|
||||
.unwrap_or("")
|
||||
.contains("Rate limit exceeded"));
|
||||
assert!(mem.get("temp").await.unwrap().is_some());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue