use super::traits::{Tool, ToolResult}; use crate::memory::Memory; use crate::security::policy::ToolOperation; use crate::security::SecurityPolicy; use async_trait::async_trait; use serde_json::json; use std::sync::Arc; /// Let the agent forget/delete a memory entry pub struct MemoryForgetTool { memory: Arc, security: Arc, } impl MemoryForgetTool { pub fn new(memory: Arc, security: Arc) -> Self { Self { memory, security } } } #[async_trait] impl Tool for MemoryForgetTool { fn name(&self) -> &str { "memory_forget" } fn description(&self) -> &str { "Remove a memory by key. Use to delete outdated facts or sensitive data. Returns whether the memory was found and removed." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "key": { "type": "string", "description": "The key of the memory to forget" } }, "required": ["key"] }) } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let key = args .get("key") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?; if let Err(error) = self .security .enforce_tool_operation(ToolOperation::Act, "memory_forget") { return Ok(ToolResult { success: false, output: String::new(), error: Some(error), }); } match self.memory.forget(key).await { Ok(true) => Ok(ToolResult { success: true, output: format!("Forgot memory: {key}"), error: None, }), Ok(false) => Ok(ToolResult { success: true, output: format!("No memory found with key: {key}"), error: None, }), Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(format!("Failed to forget memory: {e}")), }), } } } #[cfg(test)] mod tests { use super::*; use crate::memory::{MemoryCategory, SqliteMemory}; use crate::security::{AutonomyLevel, SecurityPolicy}; use tempfile::TempDir; fn test_security() -> Arc { Arc::new(SecurityPolicy::default()) } fn test_mem() -> (TempDir, Arc) { let tmp = TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); (tmp, Arc::new(mem)) } #[test] fn name_and_schema() { let (_tmp, mem) = test_mem(); let tool = MemoryForgetTool::new(mem, test_security()); assert_eq!(tool.name(), "memory_forget"); assert!(tool.parameters_schema()["properties"]["key"].is_object()); } #[tokio::test] async fn forget_existing() { let (_tmp, mem) = test_mem(); mem.store("temp", "temporary", MemoryCategory::Conversation, None) .await .unwrap(); let tool = MemoryForgetTool::new(mem.clone(), test_security()); let result = tool.execute(json!({"key": "temp"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("Forgot")); assert!(mem.get("temp").await.unwrap().is_none()); } #[tokio::test] async fn forget_nonexistent() { let (_tmp, mem) = test_mem(); let tool = MemoryForgetTool::new(mem, test_security()); let result = tool.execute(json!({"key": "nope"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("No memory found")); } #[tokio::test] async fn forget_missing_key() { let (_tmp, mem) = test_mem(); let tool = MemoryForgetTool::new(mem, test_security()); let result = tool.execute(json!({})).await; assert!(result.is_err()); } #[tokio::test] async fn forget_blocked_in_readonly_mode() { let (_tmp, mem) = test_mem(); mem.store("temp", "temporary", MemoryCategory::Conversation, None) .await .unwrap(); let readonly = Arc::new(SecurityPolicy { autonomy: AutonomyLevel::ReadOnly, ..SecurityPolicy::default() }); let tool = MemoryForgetTool::new(mem.clone(), readonly); let result = tool.execute(json!({"key": "temp"})).await.unwrap(); assert!(!result.success); assert!(result .error .as_deref() .unwrap_or("") .contains("read-only mode")); assert!(mem.get("temp").await.unwrap().is_some()); } #[tokio::test] async fn forget_blocked_when_rate_limited() { let (_tmp, mem) = test_mem(); mem.store("temp", "temporary", MemoryCategory::Conversation, None) .await .unwrap(); let limited = Arc::new(SecurityPolicy { max_actions_per_hour: 0, ..SecurityPolicy::default() }); let tool = MemoryForgetTool::new(mem.clone(), limited); let result = tool.execute(json!({"key": "temp"})).await.unwrap(); assert!(!result.success); assert!(result .error .as_deref() .unwrap_or("") .contains("Rate limit exceeded")); assert!(mem.get("temp").await.unwrap().is_some()); } }