zeroclaw/src/tools/memory_recall.rs
argenis de la rosa bc31e4389b style: cargo fmt — fix all formatting for CI
Ran cargo fmt across entire codebase to pass CI's cargo fmt --check.
No logic changes, only whitespace/formatting.
2026-02-13 16:03:50 -05:00

166 lines
4.9 KiB
Rust

use super::traits::{Tool, ToolResult};
use crate::memory::Memory;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::sync::Arc;
/// Tool that lets the agent search its own memory.
///
/// Holds a shared [`Memory`] handle; `execute` forwards the query to
/// `Memory::recall` and renders the returned entries as plain text.
pub struct MemoryRecallTool {
/// Shared memory backend that `execute` queries.
memory: Arc<dyn Memory>,
}
impl MemoryRecallTool {
    /// Builds a recall tool that runs its searches against `memory`.
    ///
    /// The handle is stored as-is; nothing is queried at construction time.
    pub fn new(memory: Arc<dyn Memory>) -> Self {
        Self { memory }
    }
}
#[async_trait]
impl Tool for MemoryRecallTool {
    /// Identifier under which this tool is exposed.
    fn name(&self) -> &str {
        "memory_recall"
    }

    /// Short human-readable summary of what the tool does.
    fn description(&self) -> &str {
        "Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance."
    }

    /// JSON Schema describing the accepted arguments: a required string
    /// `query` and an optional integer `limit`.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Keywords or phrase to search for in memory"
                },
                "limit": {
                    "type": "integer",
                    "description": "Max results to return (default: 5)"
                }
            },
            "required": ["query"]
        })
    }

    /// Runs a memory search and formats the hits as a bulleted text list.
    ///
    /// # Arguments
    /// * `args["query"]` — required string to search for.
    /// * `args["limit"]` — optional unsigned integer; when it is absent or is
    ///   not representable as a `u64`, the default of 5 is used.
    ///
    /// # Returns
    /// * `Ok(ToolResult { success: true, .. })` with either a "no memories"
    ///   message or one line per entry (`- [category] key: content`, followed
    ///   by a score suffix when the entry carries one).
    /// * `Ok(ToolResult { success: false, .. })` with `error` populated when
    ///   the memory backend's `recall` call fails.
    ///
    /// # Errors
    /// Returns `Err` only when `query` is missing or is not a string.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let query = args
            .get("query")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;

        // Saturate instead of `as`-casting: where `usize` is narrower than
        // `u64`, a plain cast would silently wrap a huge limit (e.g. 2^32 -> 0).
        let limit = args
            .get("limit")
            .and_then(serde_json::Value::as_u64)
            .map_or(5, |v| usize::try_from(v).unwrap_or(usize::MAX));

        match self.memory.recall(query, limit).await {
            Ok(entries) if entries.is_empty() => Ok(ToolResult {
                success: true,
                output: "No memories found matching that query.".into(),
                error: None,
            }),
            Ok(entries) => {
                let mut output = format!("Found {} memories:\n", entries.len());
                for entry in &entries {
                    // NOTE(review): the score is printed with a `%` sign and no
                    // scaling — confirm the backend reports scores on a 0-100
                    // scale rather than 0-1.
                    let score = entry
                        .score
                        .map_or_else(String::new, |s| format!(" [{s:.0}%]"));
                    // Writing into a `String` cannot fail, so the result is
                    // deliberately discarded.
                    let _ = writeln!(
                        output,
                        "- [{}] {}: {}{score}",
                        entry.category, entry.key, entry.content
                    );
                }
                Ok(ToolResult {
                    success: true,
                    output,
                    error: None,
                })
            }
            // Backend failures are reported to the caller as a failed tool
            // result rather than propagated as an `Err`.
            Err(e) => Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some(format!("Memory recall failed: {e}")),
            }),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{MemoryCategory, SqliteMemory};
    use tempfile::TempDir;

    /// Creates an empty `SqliteMemory` rooted in a new temporary directory.
    ///
    /// The `TempDir` is handed back together with the memory handle so the
    /// calling test can keep it in scope while the memory is in use.
    fn fresh_memory() -> (TempDir, Arc<dyn Memory>) {
        let dir = TempDir::new().unwrap();
        let memory: Arc<dyn Memory> = Arc::new(SqliteMemory::new(dir.path()).unwrap());
        (dir, memory)
    }

    /// Searching an empty store succeeds and reports that nothing was found.
    #[tokio::test]
    async fn recall_empty() {
        let (_dir, memory) = fresh_memory();
        let tool = MemoryRecallTool::new(memory);

        let args = json!({"query": "anything"});
        let outcome = tool.execute(args).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.output.contains("No memories found"));
    }

    /// Only the entry that mentions the query term is returned.
    #[tokio::test]
    async fn recall_finds_match() {
        let (_dir, memory) = fresh_memory();
        memory
            .store("lang", "User prefers Rust", MemoryCategory::Core)
            .await
            .unwrap();
        memory
            .store("tz", "Timezone is EST", MemoryCategory::Core)
            .await
            .unwrap();
        let tool = MemoryRecallTool::new(memory);

        let args = json!({"query": "Rust"});
        let outcome = tool.execute(args).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.output.contains("Rust"));
        assert!(outcome.output.contains("Found 1"));
    }

    /// With ten matching entries stored, `limit: 3` caps the result count at three.
    #[tokio::test]
    async fn recall_respects_limit() {
        let (_dir, memory) = fresh_memory();
        for i in 0..10 {
            let key = format!("k{i}");
            let content = format!("Rust fact {i}");
            memory
                .store(&key, &content, MemoryCategory::Core)
                .await
                .unwrap();
        }
        let tool = MemoryRecallTool::new(memory);

        let args = json!({"query": "Rust", "limit": 3});
        let outcome = tool.execute(args).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.output.contains("Found 3"));
    }

    /// Omitting `query` makes `execute` return an error.
    #[tokio::test]
    async fn recall_missing_query() {
        let (_dir, memory) = fresh_memory();
        let tool = MemoryRecallTool::new(memory);

        let outcome = tool.execute(json!({})).await;

        assert!(outcome.is_err());
    }

    /// The tool reports its expected name and declares `query` in its schema.
    #[test]
    fn name_and_schema() {
        let (_dir, memory) = fresh_memory();
        let tool = MemoryRecallTool::new(memory);

        assert_eq!(tool.name(), "memory_recall");

        let schema = tool.parameters_schema();
        assert!(schema["properties"]["query"].is_object());
    }
}