use super::traits::{Tool, ToolResult}; use crate::memory::Memory; use async_trait::async_trait; use serde_json::json; use std::fmt::Write; use std::sync::Arc; /// Let the agent search its own memory pub struct MemoryRecallTool { memory: Arc, } impl MemoryRecallTool { pub fn new(memory: Arc) -> Self { Self { memory } } } #[async_trait] impl Tool for MemoryRecallTool { fn name(&self) -> &str { "memory_recall" } fn description(&self) -> &str { "Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance." } fn parameters_schema(&self) -> serde_json::Value { json!({ "type": "object", "properties": { "query": { "type": "string", "description": "Keywords or phrase to search for in memory" }, "limit": { "type": "integer", "description": "Max results to return (default: 5)" } }, "required": ["query"] }) } async fn execute(&self, args: serde_json::Value) -> anyhow::Result { let query = args .get("query") .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?; #[allow(clippy::cast_possible_truncation)] let limit = args .get("limit") .and_then(serde_json::Value::as_u64) .map_or(5, |v| v as usize); match self.memory.recall(query, limit, None).await { Ok(entries) if entries.is_empty() => Ok(ToolResult { success: true, output: "No memories found matching that query.".into(), error: None, }), Ok(entries) => { let mut output = format!("Found {} memories:\n", entries.len()); for entry in &entries { let score = entry .score .map_or_else(String::new, |s| format!(" [{s:.0}%]")); let _ = writeln!( output, "- [{}] {}: {}{score}", entry.category, entry.key, entry.content ); } Ok(ToolResult { success: true, output, error: None, }) } Err(e) => Ok(ToolResult { success: false, output: String::new(), error: Some(format!("Memory recall failed: {e}")), }), } } } #[cfg(test)] mod tests { use super::*; use crate::memory::{MemoryCategory, SqliteMemory}; use tempfile::TempDir; fn seeded_mem() -> (TempDir, Arc) { let tmp = TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); (tmp, Arc::new(mem)) } #[tokio::test] async fn recall_empty() { let (_tmp, mem) = seeded_mem(); let tool = MemoryRecallTool::new(mem); let result = tool.execute(json!({"query": "anything"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("No memories found")); } #[tokio::test] async fn recall_finds_match() { let (_tmp, mem) = seeded_mem(); mem.store("lang", "User prefers Rust", MemoryCategory::Core, None) .await .unwrap(); mem.store("tz", "Timezone is EST", MemoryCategory::Core, None) .await .unwrap(); let tool = MemoryRecallTool::new(mem); let result = tool.execute(json!({"query": "Rust"})).await.unwrap(); assert!(result.success); assert!(result.output.contains("Rust")); assert!(result.output.contains("Found 1")); } #[tokio::test] async fn recall_respects_limit() { let (_tmp, mem) = seeded_mem(); for i in 0..10 { mem.store( &format!("k{i}"), &format!("Rust fact {i}"), MemoryCategory::Core, None, ) .await .unwrap(); } let tool = MemoryRecallTool::new(mem); let result = tool .execute(json!({"query": "Rust", "limit": 3})) .await .unwrap(); assert!(result.success); assert!(result.output.contains("Found 3")); } #[tokio::test] async fn recall_missing_query() { let (_tmp, mem) = seeded_mem(); let tool = MemoryRecallTool::new(mem); let result = tool.execute(json!({})).await; assert!(result.is_err()); } #[test] fn name_and_schema() { let (_tmp, mem) = seeded_mem(); let tool = MemoryRecallTool::new(mem); assert_eq!(tool.name(), "memory_recall"); assert!(tool.parameters_schema()["properties"]["query"].is_object()); } }