feat(memory): add session_id isolation to Memory trait (#530)
* feat(memory): add session_id isolation to Memory trait Add optional session_id parameter to store(), recall(), and list() methods across the Memory trait and all four backends (sqlite, markdown, lucid, none). This enables per-session memory isolation so different agent sessions cannot cross-read each other's stored memories. Changes: - traits.rs: Add session_id: Option<&str> to store/recall/list - sqlite.rs: Schema migration (ALTER TABLE ADD COLUMN session_id), index, persist/filter by session_id in all query paths - markdown.rs, lucid.rs, none.rs: Updated signatures - All callers pass None for backward compatibility - 5 new tests: session-filtered recall, cross-session isolation, session-filtered list, no-filter returns all, migration idempotency Closes #518 Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * fix(channels): fix discord _channel_id typo and lark missing reply_to Pre-existing compilation errors on main after reply_to was added to ChannelMessage: discord.rs used _channel_id (underscore prefix) but referenced channel_id, and lark.rs was missing the reply_to field. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
f30f87662e
commit
ebb78afda4
16 changed files with 556 additions and 221 deletions
|
|
@ -389,7 +389,7 @@ impl Agent {
|
||||||
if self.auto_save {
|
if self.auto_save {
|
||||||
let _ = self
|
let _ = self
|
||||||
.memory
|
.memory
|
||||||
.store("user_msg", user_message, MemoryCategory::Conversation)
|
.store("user_msg", user_message, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -448,7 +448,7 @@ impl Agent {
|
||||||
let summary = truncate_with_ellipsis(&final_text, 100);
|
let summary = truncate_with_ellipsis(&final_text, 100);
|
||||||
let _ = self
|
let _ = self
|
||||||
.memory
|
.memory
|
||||||
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
.store("assistant_resp", &summary, MemoryCategory::Daily, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -145,7 +145,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||||
let mut context = String::new();
|
let mut context = String::new();
|
||||||
|
|
||||||
// Pull relevant memories for this message
|
// Pull relevant memories for this message
|
||||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||||
if !entries.is_empty() {
|
if !entries.is_empty() {
|
||||||
context.push_str("[Memory context]\n");
|
context.push_str("[Memory context]\n");
|
||||||
for entry in &entries {
|
for entry in &entries {
|
||||||
|
|
@ -913,7 +913,7 @@ pub async fn run(
|
||||||
if config.memory.auto_save {
|
if config.memory.auto_save {
|
||||||
let user_key = autosave_memory_key("user_msg");
|
let user_key = autosave_memory_key("user_msg");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&user_key, &msg, MemoryCategory::Conversation)
|
.store(&user_key, &msg, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -956,7 +956,7 @@ pub async fn run(
|
||||||
let summary = truncate_with_ellipsis(&response, 100);
|
let summary = truncate_with_ellipsis(&response, 100);
|
||||||
let response_key = autosave_memory_key("assistant_resp");
|
let response_key = autosave_memory_key("assistant_resp");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -979,7 +979,7 @@ pub async fn run(
|
||||||
if config.memory.auto_save {
|
if config.memory.auto_save {
|
||||||
let user_key = autosave_memory_key("user_msg");
|
let user_key = autosave_memory_key("user_msg");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&user_key, &msg.content, MemoryCategory::Conversation)
|
.store(&user_key, &msg.content, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1037,7 +1037,7 @@ pub async fn run(
|
||||||
let summary = truncate_with_ellipsis(&response, 100);
|
let summary = truncate_with_ellipsis(&response, 100);
|
||||||
let response_key = autosave_memory_key("assistant_resp");
|
let response_key = autosave_memory_key("assistant_resp");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1499,16 +1499,16 @@ I will now call the tool with this payload:
|
||||||
let key1 = autosave_memory_key("user_msg");
|
let key1 = autosave_memory_key("user_msg");
|
||||||
let key2 = autosave_memory_key("user_msg");
|
let key2 = autosave_memory_key("user_msg");
|
||||||
|
|
||||||
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation)
|
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store(&key2, "I'm 45", MemoryCategory::Conversation)
|
mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(mem.count().await.unwrap(), 2);
|
assert_eq!(mem.count().await.unwrap(), 2);
|
||||||
|
|
||||||
let recalled = mem.recall("45", 5).await.unwrap();
|
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||||
memory: &dyn Memory,
|
memory: &dyn Memory,
|
||||||
user_message: &str,
|
user_message: &str,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let entries = memory.recall(user_message, self.limit).await?;
|
let entries = memory.recall(user_message, self.limit, None).await?;
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
return Ok(String::new());
|
return Ok(String::new());
|
||||||
}
|
}
|
||||||
|
|
@ -61,11 +61,17 @@ mod tests {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
if limit == 0 {
|
if limit == 0 {
|
||||||
return Ok(vec![]);
|
return Ok(vec![]);
|
||||||
}
|
}
|
||||||
|
|
@ -87,6 +93,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&MemoryCategory>,
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(vec![])
|
Ok(vec![])
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,7 +72,7 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||||
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
|
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||||
let mut context = String::new();
|
let mut context = String::new();
|
||||||
|
|
||||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||||
if !entries.is_empty() {
|
if !entries.is_empty() {
|
||||||
context.push_str("[Memory context]\n");
|
context.push_str("[Memory context]\n");
|
||||||
for entry in &entries {
|
for entry in &entries {
|
||||||
|
|
@ -158,6 +158,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
&autosave_key,
|
&autosave_key,
|
||||||
&msg.content,
|
&msg.content,
|
||||||
crate::memory::MemoryCategory::Conversation,
|
crate::memory::MemoryCategory::Conversation,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
@ -1260,6 +1261,7 @@ mod tests {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: crate::memory::MemoryCategory,
|
_category: crate::memory::MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -1268,6 +1270,7 @@ mod tests {
|
||||||
&self,
|
&self,
|
||||||
_query: &str,
|
_query: &str,
|
||||||
_limit: usize,
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
@ -1279,6 +1282,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&crate::memory::MemoryCategory>,
|
_category: Option<&crate::memory::MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
@ -1636,6 +1640,7 @@ mod tests {
|
||||||
&conversation_memory_key(&msg1),
|
&conversation_memory_key(&msg1),
|
||||||
&msg1.content,
|
&msg1.content,
|
||||||
MemoryCategory::Conversation,
|
MemoryCategory::Conversation,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -1643,13 +1648,14 @@ mod tests {
|
||||||
&conversation_memory_key(&msg2),
|
&conversation_memory_key(&msg2),
|
||||||
&msg2.content,
|
&msg2.content,
|
||||||
MemoryCategory::Conversation,
|
MemoryCategory::Conversation,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(mem.count().await.unwrap(), 2);
|
assert_eq!(mem.count().await.unwrap(), 2);
|
||||||
|
|
||||||
let recalled = mem.recall("45", 5).await.unwrap();
|
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1657,7 +1663,7 @@ mod tests {
|
||||||
async fn build_memory_context_includes_recalled_entries() {
|
async fn build_memory_context_includes_recalled_entries() {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation)
|
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -544,7 +544,7 @@ async fn handle_webhook(
|
||||||
let key = webhook_memory_key();
|
let key = webhook_memory_key();
|
||||||
let _ = state
|
let _ = state
|
||||||
.mem
|
.mem
|
||||||
.store(&key, message, MemoryCategory::Conversation)
|
.store(&key, message, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -697,7 +697,7 @@ async fn handle_whatsapp_message(
|
||||||
let key = whatsapp_memory_key(msg);
|
let key = whatsapp_memory_key(msg);
|
||||||
let _ = state
|
let _ = state
|
||||||
.mem
|
.mem
|
||||||
.store(&key, &msg.content, MemoryCategory::Conversation)
|
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -886,11 +886,17 @@ mod tests {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -901,6 +907,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&MemoryCategory>,
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
@ -953,6 +960,7 @@ mod tests {
|
||||||
key: &str,
|
key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
self.keys
|
self.keys
|
||||||
.lock()
|
.lock()
|
||||||
|
|
@ -961,7 +969,12 @@ mod tests {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -972,6 +985,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&MemoryCategory>,
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -502,10 +502,10 @@ mod tests {
|
||||||
let workspace = tmp.path();
|
let workspace = tmp.path();
|
||||||
|
|
||||||
let mem = SqliteMemory::new(workspace).unwrap();
|
let mem = SqliteMemory::new(workspace).unwrap();
|
||||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation)
|
mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("core_keep", "durable", MemoryCategory::Core)
|
mem.store("core_keep", "durable", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
drop(mem);
|
drop(mem);
|
||||||
|
|
|
||||||
|
|
@ -314,14 +314,22 @@ impl Memory for LucidMemory {
|
||||||
key: &str,
|
key: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
self.local.store(key, content, category.clone()).await?;
|
self.local
|
||||||
|
.store(key, content, category.clone(), session_id)
|
||||||
|
.await?;
|
||||||
self.sync_to_lucid_async(key, content, &category).await;
|
self.sync_to_lucid_async(key, content, &category).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
let local_results = self.local.recall(query, limit).await?;
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
let local_results = self.local.recall(query, limit, session_id).await?;
|
||||||
if limit == 0
|
if limit == 0
|
||||||
|| local_results.len() >= limit
|
|| local_results.len() >= limit
|
||||||
|| local_results.len() >= self.local_hit_threshold
|
|| local_results.len() >= self.local_hit_threshold
|
||||||
|
|
@ -358,8 +366,12 @@ impl Memory for LucidMemory {
|
||||||
self.local.get(key).await
|
self.local.get(key).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
self.local.list(category).await
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
self.local.list(category, session_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||||
|
|
@ -475,7 +487,7 @@ exit 1
|
||||||
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
|
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
|
||||||
|
|
||||||
memory
|
memory
|
||||||
.store("lang", "User prefers Rust", MemoryCategory::Core)
|
.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -495,11 +507,12 @@ exit 1
|
||||||
"local_note",
|
"local_note",
|
||||||
"Local sqlite auth fallback note",
|
"Local sqlite auth fallback note",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let entries = memory.recall("auth", 5).await.unwrap();
|
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||||
|
|
||||||
assert!(entries
|
assert!(entries
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -526,11 +539,16 @@ exit 1
|
||||||
);
|
);
|
||||||
|
|
||||||
memory
|
memory
|
||||||
.store("pref", "Rust should stay local-first", MemoryCategory::Core)
|
.store(
|
||||||
|
"pref",
|
||||||
|
"Rust should stay local-first",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let entries = memory.recall("rust", 5).await.unwrap();
|
let entries = memory.recall("rust", 5, None).await.unwrap();
|
||||||
assert!(entries
|
assert!(entries
|
||||||
.iter()
|
.iter()
|
||||||
.any(|e| e.content.contains("Rust should stay local-first")));
|
.any(|e| e.content.contains("Rust should stay local-first")));
|
||||||
|
|
@ -590,8 +608,8 @@ exit 1
|
||||||
Duration::from_secs(5),
|
Duration::from_secs(5),
|
||||||
);
|
);
|
||||||
|
|
||||||
let first = memory.recall("auth", 5).await.unwrap();
|
let first = memory.recall("auth", 5, None).await.unwrap();
|
||||||
let second = memory.recall("auth", 5).await.unwrap();
|
let second = memory.recall("auth", 5, None).await.unwrap();
|
||||||
|
|
||||||
assert!(first.is_empty());
|
assert!(first.is_empty());
|
||||||
assert!(second.is_empty());
|
assert!(second.is_empty());
|
||||||
|
|
|
||||||
|
|
@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
|
||||||
key: &str,
|
key: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let entry = format!("- **{key}**: {content}");
|
let entry = format!("- **{key}**: {content}");
|
||||||
let path = match category {
|
let path = match category {
|
||||||
|
|
@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
|
||||||
self.append_to_file(&path, &entry).await
|
self.append_to_file(&path, &entry).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let all = self.read_all_entries().await?;
|
let all = self.read_all_entries().await?;
|
||||||
let query_lower = query.to_lowercase();
|
let query_lower = query.to_lowercase();
|
||||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||||
|
|
@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
|
||||||
.find(|e| e.key == key || e.content.contains(key)))
|
.find(|e| e.key == key || e.content.contains(key)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let all = self.read_all_entries().await?;
|
let all = self.read_all_entries().await?;
|
||||||
match category {
|
match category {
|
||||||
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
||||||
|
|
@ -243,7 +253,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_store_core() {
|
async fn markdown_store_core() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("pref", "User likes Rust", MemoryCategory::Core)
|
mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
||||||
|
|
@ -253,7 +263,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_store_daily() {
|
async fn markdown_store_daily() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("note", "Finished tests", MemoryCategory::Daily)
|
mem.store("note", "Finished tests", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let path = mem.daily_path();
|
let path = mem.daily_path();
|
||||||
|
|
@ -264,17 +274,17 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_recall_keyword() {
|
async fn markdown_recall_keyword() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Python is slow", MemoryCategory::Core)
|
mem.store("b", "Python is slow", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("c", "Rust and safety", MemoryCategory::Core)
|
mem.store("c", "Rust and safety", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("Rust", 10).await.unwrap();
|
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||||
assert!(results.len() >= 2);
|
assert!(results.len() >= 2);
|
||||||
assert!(results
|
assert!(results
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -284,18 +294,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_recall_no_match() {
|
async fn markdown_recall_no_match() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "Rust is great", MemoryCategory::Core)
|
mem.store("a", "Rust is great", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("javascript", 10).await.unwrap();
|
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_count() {
|
async fn markdown_count() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "first", MemoryCategory::Core, None)
|
||||||
mem.store("b", "second", MemoryCategory::Core)
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("b", "second", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let count = mem.count().await.unwrap();
|
let count = mem.count().await.unwrap();
|
||||||
|
|
@ -305,24 +317,24 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_list_by_category() {
|
async fn markdown_list_by_category() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "core fact", MemoryCategory::Core)
|
mem.store("a", "core fact", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "daily note", MemoryCategory::Daily)
|
mem.store("b", "daily note", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
||||||
|
|
||||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_forget_is_noop() {
|
async fn markdown_forget_is_noop() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "permanent", MemoryCategory::Core)
|
mem.store("a", "permanent", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let removed = mem.forget("a").await.unwrap();
|
let removed = mem.forget("a").await.unwrap();
|
||||||
|
|
@ -332,7 +344,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_empty_recall() {
|
async fn markdown_empty_recall() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
let results = mem.recall("anything", 10).await.unwrap();
|
let results = mem.recall("anything", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,11 +25,17 @@ impl Memory for NoneMemory {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -37,7 +43,11 @@ impl Memory for NoneMemory {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
|
&self,
|
||||||
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -62,11 +72,14 @@ mod tests {
|
||||||
async fn none_memory_is_noop() {
|
async fn none_memory_is_noop() {
|
||||||
let memory = NoneMemory::new();
|
let memory = NoneMemory::new();
|
||||||
|
|
||||||
memory.store("k", "v", MemoryCategory::Core).await.unwrap();
|
memory
|
||||||
|
.store("k", "v", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(memory.get("k").await.unwrap().is_none());
|
assert!(memory.get("k").await.unwrap().is_none());
|
||||||
assert!(memory.recall("k", 10).await.unwrap().is_empty());
|
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
|
||||||
assert!(memory.list(None).await.unwrap().is_empty());
|
assert!(memory.list(None, None).await.unwrap().is_empty());
|
||||||
assert!(!memory.forget("k").await.unwrap());
|
assert!(!memory.forget("k").await.unwrap());
|
||||||
assert_eq!(memory.count().await.unwrap(), 0);
|
assert_eq!(memory.count().await.unwrap(), 0);
|
||||||
assert!(memory.health_check().await);
|
assert!(memory.health_check().await);
|
||||||
|
|
|
||||||
|
|
@ -123,6 +123,19 @@ impl SqliteMemory {
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
|
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
// Migration: add session_id column if not present (safe to run repeatedly)
|
||||||
|
let has_session_id: bool = conn
|
||||||
|
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
|
||||||
|
.query_row([], |row| row.get::<_, String>(0))?
|
||||||
|
.contains("session_id");
|
||||||
|
if !has_session_id {
|
||||||
|
conn.execute_batch(
|
||||||
|
"ALTER TABLE memories ADD COLUMN session_id TEXT;
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -360,6 +373,7 @@ impl Memory for SqliteMemory {
|
||||||
key: &str,
|
key: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
// Compute embedding (async, before lock)
|
// Compute embedding (async, before lock)
|
||||||
let embedding_bytes = self
|
let embedding_bytes = self
|
||||||
|
|
@ -376,20 +390,26 @@ impl Memory for SqliteMemory {
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
|
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||||
ON CONFLICT(key) DO UPDATE SET
|
ON CONFLICT(key) DO UPDATE SET
|
||||||
content = excluded.content,
|
content = excluded.content,
|
||||||
category = excluded.category,
|
category = excluded.category,
|
||||||
embedding = excluded.embedding,
|
embedding = excluded.embedding,
|
||||||
updated_at = excluded.updated_at",
|
updated_at = excluded.updated_at,
|
||||||
params![id, key, content, cat, embedding_bytes, now, now],
|
session_id = excluded.session_id",
|
||||||
|
params![id, key, content, cat, embedding_bytes, now, now, session_id],
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
if query.trim().is_empty() {
|
if query.trim().is_empty() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
@ -438,7 +458,7 @@ impl Memory for SqliteMemory {
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
for scored in &merged {
|
for scored in &merged {
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
||||||
)?;
|
)?;
|
||||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||||
Ok(MemoryEntry {
|
Ok(MemoryEntry {
|
||||||
|
|
@ -447,10 +467,16 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: Some(f64::from(scored.final_score)),
|
score: Some(f64::from(scored.final_score)),
|
||||||
})
|
})
|
||||||
}) {
|
}) {
|
||||||
|
// Filter by session_id if requested
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
results.push(entry);
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -469,7 +495,7 @@ impl Memory for SqliteMemory {
|
||||||
.collect();
|
.collect();
|
||||||
let where_clause = conditions.join(" OR ");
|
let where_clause = conditions.join(" OR ");
|
||||||
let sql = format!(
|
let sql = format!(
|
||||||
"SELECT id, key, content, category, created_at FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
WHERE {where_clause}
|
WHERE {where_clause}
|
||||||
ORDER BY updated_at DESC
|
ORDER BY updated_at DESC
|
||||||
LIMIT ?{}",
|
LIMIT ?{}",
|
||||||
|
|
@ -492,12 +518,18 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: Some(1.0),
|
score: Some(1.0),
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
results.push(row?);
|
let entry = row?;
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -513,7 +545,7 @@ impl Memory for SqliteMemory {
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut rows = stmt.query_map(params![key], |row| {
|
let mut rows = stmt.query_map(params![key], |row| {
|
||||||
|
|
@ -523,7 +555,7 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: None,
|
score: None,
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
@ -534,7 +566,11 @@ impl Memory for SqliteMemory {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let conn = self
|
let conn = self
|
||||||
.conn
|
.conn
|
||||||
.lock()
|
.lock()
|
||||||
|
|
@ -549,7 +585,7 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: None,
|
score: None,
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
@ -557,21 +593,33 @@ impl Memory for SqliteMemory {
|
||||||
if let Some(cat) = category {
|
if let Some(cat) = category {
|
||||||
let cat_str = Self::category_to_str(cat);
|
let cat_str = Self::category_to_str(cat);
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
results.push(row?);
|
let entry = row?;
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
ORDER BY updated_at DESC",
|
ORDER BY updated_at DESC",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map([], row_mapper)?;
|
let rows = stmt.query_map([], row_mapper)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
results.push(row?);
|
let entry = row?;
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -631,7 +679,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_store_and_get() {
|
async fn sqlite_store_and_get() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
|
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -646,10 +694,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_store_upsert() {
|
async fn sqlite_store_upsert() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("pref", "likes Rust", MemoryCategory::Core)
|
mem.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("pref", "loves Rust", MemoryCategory::Core)
|
mem.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -661,17 +709,22 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_recall_keyword() {
|
async fn sqlite_recall_keyword() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
|
mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Python is interpreted", MemoryCategory::Core)
|
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
|
mem.store(
|
||||||
|
"c",
|
||||||
|
"Rust has zero-cost abstractions",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("Rust", 10).await.unwrap();
|
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 2);
|
assert_eq!(results.len(), 2);
|
||||||
assert!(results
|
assert!(results
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -681,14 +734,14 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_recall_multi_keyword() {
|
async fn sqlite_recall_multi_keyword() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
|
mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("fast safe", 10).await.unwrap();
|
let results = mem.recall("fast safe", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
// Entry with both keywords should score higher
|
// Entry with both keywords should score higher
|
||||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||||
|
|
@ -697,17 +750,17 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_recall_no_match() {
|
async fn sqlite_recall_no_match() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "Rust rocks", MemoryCategory::Core)
|
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("javascript", 10).await.unwrap();
|
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_forget() {
|
async fn sqlite_forget() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("temp", "temporary data", MemoryCategory::Conversation)
|
mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(mem.count().await.unwrap(), 1);
|
assert_eq!(mem.count().await.unwrap(), 1);
|
||||||
|
|
@ -727,29 +780,37 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_list_all() {
|
async fn sqlite_list_all() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "one", MemoryCategory::Core, None)
|
||||||
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
|
.await
|
||||||
mem.store("c", "three", MemoryCategory::Conversation)
|
.unwrap();
|
||||||
|
mem.store("b", "two", MemoryCategory::Daily, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("c", "three", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let all = mem.list(None).await.unwrap();
|
let all = mem.list(None, None).await.unwrap();
|
||||||
assert_eq!(all.len(), 3);
|
assert_eq!(all.len(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_list_by_category() {
|
async fn sqlite_list_by_category() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "core1", MemoryCategory::Core, None)
|
||||||
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
|
.await
|
||||||
mem.store("c", "daily1", MemoryCategory::Daily)
|
.unwrap();
|
||||||
|
mem.store("b", "core2", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("c", "daily1", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
assert_eq!(core.len(), 2);
|
assert_eq!(core.len(), 2);
|
||||||
|
|
||||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
assert_eq!(daily.len(), 1);
|
assert_eq!(daily.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -771,7 +832,7 @@ mod tests {
|
||||||
|
|
||||||
{
|
{
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
mem.store("persist", "I survive restarts", MemoryCategory::Core)
|
mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -794,7 +855,7 @@ mod tests {
|
||||||
];
|
];
|
||||||
|
|
||||||
for (i, cat) in categories.iter().enumerate() {
|
for (i, cat) in categories.iter().enumerate() {
|
||||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
|
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -814,21 +875,28 @@ mod tests {
|
||||||
"a",
|
"a",
|
||||||
"Rust is a systems programming language",
|
"Rust is a systems programming language",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
|
mem.store(
|
||||||
|
"b",
|
||||||
|
"Python is great for scripting",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store(
|
mem.store(
|
||||||
"c",
|
"c",
|
||||||
"Rust and Rust and Rust everywhere",
|
"Rust and Rust and Rust everywhere",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("Rust", 10).await.unwrap();
|
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||||
assert!(results.len() >= 2);
|
assert!(results.len() >= 2);
|
||||||
// All results should contain "Rust"
|
// All results should contain "Rust"
|
||||||
for r in &results {
|
for r in &results {
|
||||||
|
|
@ -843,17 +911,17 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_multi_word_query() {
|
async fn fts5_multi_word_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
|
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
|
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
|
mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("quick dog", 10).await.unwrap();
|
let results = mem.recall("quick dog", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
// "The quick dog runs fast" matches both terms
|
// "The quick dog runs fast" matches both terms
|
||||||
assert!(results[0].content.contains("quick"));
|
assert!(results[0].content.contains("quick"));
|
||||||
|
|
@ -862,16 +930,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_empty_query_returns_empty() {
|
async fn recall_empty_query_returns_empty() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "data", MemoryCategory::Core, None)
|
||||||
let results = mem.recall("", 10).await.unwrap();
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall("", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_whitespace_query_returns_empty() {
|
async fn recall_whitespace_query_returns_empty() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "data", MemoryCategory::Core, None)
|
||||||
let results = mem.recall(" ", 10).await.unwrap();
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall(" ", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -936,7 +1008,12 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_syncs_on_insert() {
|
async fn fts5_syncs_on_insert() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
|
mem.store(
|
||||||
|
"test_key",
|
||||||
|
"unique_searchterm_xyz",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -954,7 +1031,12 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_syncs_on_delete() {
|
async fn fts5_syncs_on_delete() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
|
mem.store(
|
||||||
|
"del_key",
|
||||||
|
"deletable_content_abc",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.forget("del_key").await.unwrap();
|
mem.forget("del_key").await.unwrap();
|
||||||
|
|
@ -973,10 +1055,15 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_syncs_on_update() {
|
async fn fts5_syncs_on_update() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("upd_key", "original_content_111", MemoryCategory::Core)
|
mem.store(
|
||||||
|
"upd_key",
|
||||||
|
"original_content_111",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
|
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -1018,10 +1105,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn reindex_rebuilds_fts() {
|
async fn reindex_rebuilds_fts() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core)
|
mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("r2", "reindex test beta", MemoryCategory::Core)
|
mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -1030,7 +1117,7 @@ mod tests {
|
||||||
assert_eq!(count, 0);
|
assert_eq!(count, 0);
|
||||||
|
|
||||||
// FTS should still work after rebuild
|
// FTS should still work after rebuild
|
||||||
let results = mem.recall("reindex", 10).await.unwrap();
|
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 2);
|
assert_eq!(results.len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1044,12 +1131,13 @@ mod tests {
|
||||||
&format!("k{i}"),
|
&format!("k{i}"),
|
||||||
&format!("common keyword item {i}"),
|
&format!("common keyword item {i}"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let results = mem.recall("common keyword", 5).await.unwrap();
|
let results = mem.recall("common keyword", 5, None).await.unwrap();
|
||||||
assert!(results.len() <= 5);
|
assert!(results.len() <= 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1058,11 +1146,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_results_have_scores() {
|
async fn recall_results_have_scores() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("s1", "scored result test", MemoryCategory::Core)
|
mem.store("s1", "scored result test", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("scored", 10).await.unwrap();
|
let results = mem.recall("scored", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
for r in &results {
|
for r in &results {
|
||||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||||
|
|
@ -1074,11 +1162,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_quotes_in_query() {
|
async fn recall_with_quotes_in_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("q1", "He said hello world", MemoryCategory::Core)
|
mem.store("q1", "He said hello world", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Quotes in query should not crash FTS5
|
// Quotes in query should not crash FTS5
|
||||||
let results = mem.recall("\"hello\"", 10).await.unwrap();
|
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
|
||||||
// May or may not match depending on FTS5 escaping, but must not error
|
// May or may not match depending on FTS5 escaping, but must not error
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
@ -1086,31 +1174,34 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_asterisk_in_query() {
|
async fn recall_with_asterisk_in_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a1", "wildcard test content", MemoryCategory::Core)
|
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("wild*", 10).await.unwrap();
|
let results = mem.recall("wild*", 10, None).await.unwrap();
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_parentheses_in_query() {
|
async fn recall_with_parentheses_in_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("p1", "function call test", MemoryCategory::Core)
|
mem.store("p1", "function call test", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("function()", 10).await.unwrap();
|
let results = mem.recall("function()", 10, None).await.unwrap();
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_sql_injection_attempt() {
|
async fn recall_with_sql_injection_attempt() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("safe", "normal content", MemoryCategory::Core)
|
mem.store("safe", "normal content", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Should not crash or leak data
|
// Should not crash or leak data
|
||||||
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
|
let results = mem
|
||||||
|
.recall("'; DROP TABLE memories; --", 10, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
// Table should still exist
|
// Table should still exist
|
||||||
assert_eq!(mem.count().await.unwrap(), 1);
|
assert_eq!(mem.count().await.unwrap(), 1);
|
||||||
|
|
@ -1121,7 +1212,9 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn store_empty_content() {
|
async fn store_empty_content() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
|
mem.store("empty", "", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let entry = mem.get("empty").await.unwrap().unwrap();
|
let entry = mem.get("empty").await.unwrap().unwrap();
|
||||||
assert_eq!(entry.content, "");
|
assert_eq!(entry.content, "");
|
||||||
}
|
}
|
||||||
|
|
@ -1129,7 +1222,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn store_empty_key() {
|
async fn store_empty_key() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("", "content for empty key", MemoryCategory::Core)
|
mem.store("", "content for empty key", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("").await.unwrap().unwrap();
|
let entry = mem.get("").await.unwrap().unwrap();
|
||||||
|
|
@ -1140,7 +1233,7 @@ mod tests {
|
||||||
async fn store_very_long_content() {
|
async fn store_very_long_content() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let long_content = "x".repeat(100_000);
|
let long_content = "x".repeat(100_000);
|
||||||
mem.store("long", &long_content, MemoryCategory::Core)
|
mem.store("long", &long_content, MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("long").await.unwrap().unwrap();
|
let entry = mem.get("long").await.unwrap().unwrap();
|
||||||
|
|
@ -1150,7 +1243,12 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn store_unicode_and_emoji() {
|
async fn store_unicode_and_emoji() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
|
mem.store(
|
||||||
|
"emoji_key_🦀",
|
||||||
|
"こんにちは 🚀 Ñoño",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
||||||
|
|
@ -1161,7 +1259,7 @@ mod tests {
|
||||||
async fn store_content_with_newlines_and_tabs() {
|
async fn store_content_with_newlines_and_tabs() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
||||||
mem.store("whitespace", content, MemoryCategory::Core)
|
mem.store("whitespace", content, MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
||||||
|
|
@ -1173,11 +1271,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_single_character_query() {
|
async fn recall_single_character_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "x marks the spot", MemoryCategory::Core)
|
mem.store("a", "x marks the spot", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Single char may not match FTS5 but LIKE fallback should work
|
// Single char may not match FTS5 but LIKE fallback should work
|
||||||
let results = mem.recall("x", 10).await.unwrap();
|
let results = mem.recall("x", 10, None).await.unwrap();
|
||||||
// Should not crash; may or may not find results
|
// Should not crash; may or may not find results
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
@ -1185,23 +1283,23 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_limit_zero() {
|
async fn recall_limit_zero() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "some content", MemoryCategory::Core)
|
mem.store("a", "some content", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("some", 0).await.unwrap();
|
let results = mem.recall("some", 0, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_limit_one() {
|
async fn recall_limit_one() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "matching content alpha", MemoryCategory::Core)
|
mem.store("a", "matching content alpha", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "matching content beta", MemoryCategory::Core)
|
mem.store("b", "matching content beta", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("matching content", 1).await.unwrap();
|
let results = mem.recall("matching content", 1, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1212,21 +1310,22 @@ mod tests {
|
||||||
"rust_preferences",
|
"rust_preferences",
|
||||||
"User likes systems programming",
|
"User likes systems programming",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// "rust" appears in key but not content — LIKE fallback checks key too
|
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||||
let results = mem.recall("rust", 10).await.unwrap();
|
let results = mem.recall("rust", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty(), "Should match by key");
|
assert!(!results.is_empty(), "Should match by key");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_unicode_query() {
|
async fn recall_unicode_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
|
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("日本語", 10).await.unwrap();
|
let results = mem.recall("日本語", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1237,7 +1336,9 @@ mod tests {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
{
|
{
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
|
mem.store("k1", "v1", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
// Open again — init_schema runs again on existing DB
|
// Open again — init_schema runs again on existing DB
|
||||||
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
@ -1245,7 +1346,9 @@ mod tests {
|
||||||
assert!(entry.is_some());
|
assert!(entry.is_some());
|
||||||
assert_eq!(entry.unwrap().content, "v1");
|
assert_eq!(entry.unwrap().content, "v1");
|
||||||
// Store more data — should work fine
|
// Store more data — should work fine
|
||||||
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
|
mem2.store("k2", "v2", MemoryCategory::Daily, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert_eq!(mem2.count().await.unwrap(), 2);
|
assert_eq!(mem2.count().await.unwrap(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1263,11 +1366,16 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn forget_then_recall_no_ghost_results() {
|
async fn forget_then_recall_no_ghost_results() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
|
mem.store(
|
||||||
|
"ghost",
|
||||||
|
"phantom memory content",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.forget("ghost").await.unwrap();
|
mem.forget("ghost").await.unwrap();
|
||||||
let results = mem.recall("phantom memory", 10).await.unwrap();
|
let results = mem.recall("phantom memory", 10, None).await.unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
results.is_empty(),
|
results.is_empty(),
|
||||||
"Deleted memory should not appear in recall"
|
"Deleted memory should not appear in recall"
|
||||||
|
|
@ -1277,11 +1385,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn forget_and_re_store_same_key() {
|
async fn forget_and_re_store_same_key() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("cycle", "version 1", MemoryCategory::Core)
|
mem.store("cycle", "version 1", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.forget("cycle").await.unwrap();
|
mem.forget("cycle").await.unwrap();
|
||||||
mem.store("cycle", "version 2", MemoryCategory::Core)
|
mem.store("cycle", "version 2", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("cycle").await.unwrap().unwrap();
|
let entry = mem.get("cycle").await.unwrap().unwrap();
|
||||||
|
|
@ -1301,14 +1409,14 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn reindex_twice_is_safe() {
|
async fn reindex_twice_is_safe() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("r1", "reindex data", MemoryCategory::Core)
|
mem.store("r1", "reindex data", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.reindex().await.unwrap();
|
mem.reindex().await.unwrap();
|
||||||
let count = mem.reindex().await.unwrap();
|
let count = mem.reindex().await.unwrap();
|
||||||
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||||
// Data should still be intact
|
// Data should still be intact
|
||||||
let results = mem.recall("reindex", 10).await.unwrap();
|
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1362,18 +1470,28 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn list_custom_category() {
|
async fn list_custom_category() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
|
mem.store(
|
||||||
|
"c1",
|
||||||
|
"custom1",
|
||||||
|
MemoryCategory::Custom("project".into()),
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
|
mem.store(
|
||||||
|
"c2",
|
||||||
|
"custom2",
|
||||||
|
MemoryCategory::Custom("project".into()),
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("c3", "other", MemoryCategory::Core)
|
mem.store("c3", "other", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let project = mem
|
let project = mem
|
||||||
.list(Some(&MemoryCategory::Custom("project".into())))
|
.list(Some(&MemoryCategory::Custom("project".into())), None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(project.len(), 2);
|
assert_eq!(project.len(), 2);
|
||||||
|
|
@ -1382,7 +1500,122 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn list_empty_db() {
|
async fn list_empty_db() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let all = mem.list(None).await.unwrap();
|
let all = mem.list(None, None).await.unwrap();
|
||||||
assert!(all.is_empty());
|
assert!(all.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Session isolation ─────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn store_and_recall_with_session_id() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k3", "no session fact", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Recall with session-a filter returns only session-a entry
|
||||||
|
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].key, "k1");
|
||||||
|
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_no_session_filter_returns_all() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k3", "gamma fact", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Recall without session filter returns all matching entries
|
||||||
|
let results = mem.recall("fact", 10, None).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn cross_session_recall_isolation() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store(
|
||||||
|
"secret",
|
||||||
|
"session A secret data",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
Some("sess-a"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Session B cannot see session A data
|
||||||
|
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
|
||||||
|
assert!(results.is_empty());
|
||||||
|
|
||||||
|
// Session A can see its own data
|
||||||
|
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_with_session_filter() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k4", "none1", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// List with session-a filter
|
||||||
|
let results = mem.list(None, Some("sess-a")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 2);
|
||||||
|
assert!(results
|
||||||
|
.iter()
|
||||||
|
.all(|e| e.session_id.as_deref() == Some("sess-a")));
|
||||||
|
|
||||||
|
// List with session-a + category filter
|
||||||
|
let results = mem
|
||||||
|
.list(Some(&MemoryCategory::Core), Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].key, "k1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn schema_migration_idempotent_on_reopen() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// First open: creates schema + migration
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second open: migration runs again but is idempotent
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].key, "k1");
|
||||||
|
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
|
||||||
/// Backend name
|
/// Backend name
|
||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
/// Store a memory entry
|
/// Store a memory entry, optionally scoped to a session
|
||||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
async fn store(
|
||||||
-> anyhow::Result<()>;
|
&self,
|
||||||
|
key: &str,
|
||||||
|
content: &str,
|
||||||
|
category: MemoryCategory,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<()>;
|
||||||
|
|
||||||
/// Recall memories matching a query (keyword search)
|
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||||
|
|
||||||
/// Get a specific memory by key
|
/// Get a specific memory by key
|
||||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||||
|
|
||||||
/// List all memory keys, optionally filtered by category
|
/// List all memory keys, optionally filtered by category and/or session
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
|
async fn list(
|
||||||
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||||
|
|
||||||
/// Remove a memory by key
|
/// Remove a memory by key
|
||||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,9 @@ async fn migrate_openclaw_memory(
|
||||||
stats.renamed_conflicts += 1;
|
stats.renamed_conflicts += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
memory.store(&key, &entry.content, entry.category).await?;
|
memory
|
||||||
|
.store(&key, &entry.content, entry.category, None)
|
||||||
|
.await?;
|
||||||
stats.imported += 1;
|
stats.imported += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -488,7 +490,7 @@ mod tests {
|
||||||
// Existing target memory
|
// Existing target memory
|
||||||
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
||||||
target_mem
|
target_mem
|
||||||
.store("k", "new value", MemoryCategory::Core)
|
.store("k", "new value", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -510,7 +512,7 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let all = target_mem.list(None).await.unwrap();
|
let all = target_mem.list(None, None).await.unwrap();
|
||||||
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
|
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
|
||||||
assert!(all
|
assert!(all
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn forget_existing() {
|
async fn forget_existing() {
|
||||||
let (_tmp, mem) = test_mem();
|
let (_tmp, mem) = test_mem();
|
||||||
mem.store("temp", "temporary", MemoryCategory::Conversation)
|
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool {
|
||||||
.and_then(serde_json::Value::as_u64)
|
.and_then(serde_json::Value::as_u64)
|
||||||
.map_or(5, |v| v as usize);
|
.map_or(5, |v| v as usize);
|
||||||
|
|
||||||
match self.memory.recall(query, limit).await {
|
match self.memory.recall(query, limit, None).await {
|
||||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: "No memories found matching that query.".into(),
|
output: "No memories found matching that query.".into(),
|
||||||
|
|
@ -112,10 +112,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_finds_match() {
|
async fn recall_finds_match() {
|
||||||
let (_tmp, mem) = seeded_mem();
|
let (_tmp, mem) = seeded_mem();
|
||||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core)
|
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("tz", "Timezone is EST", MemoryCategory::Core)
|
mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -134,6 +134,7 @@ mod tests {
|
||||||
&format!("k{i}"),
|
&format!("k{i}"),
|
||||||
&format!("Rust fact {i}"),
|
&format!("Rust fact {i}"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool {
|
||||||
_ => MemoryCategory::Core,
|
_ => MemoryCategory::Core,
|
||||||
};
|
};
|
||||||
|
|
||||||
match self.memory.store(key, content, category).await {
|
match self.memory.store(key, content, category, None).await {
|
||||||
Ok(()) => Ok(ToolResult {
|
Ok(()) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Stored memory: {key}"),
|
output: format!("Stored memory: {key}"),
|
||||||
|
|
|
||||||
|
|
@ -36,6 +36,7 @@ async fn compare_store_speed() {
|
||||||
&format!("key_{i}"),
|
&format!("key_{i}"),
|
||||||
&format!("Memory entry number {i} about Rust programming"),
|
&format!("Memory entry number {i} about Rust programming"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -49,6 +50,7 @@ async fn compare_store_speed() {
|
||||||
&format!("key_{i}"),
|
&format!("key_{i}"),
|
||||||
&format!("Memory entry number {i} about Rust programming"),
|
&format!("Memory entry number {i} about Rust programming"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -127,8 +129,8 @@ async fn compare_recall_quality() {
|
||||||
];
|
];
|
||||||
|
|
||||||
for (key, content, cat) in &entries {
|
for (key, content, cat) in &entries {
|
||||||
sq.store(key, content, cat.clone()).await.unwrap();
|
sq.store(key, content, cat.clone(), None).await.unwrap();
|
||||||
md.store(key, content, cat.clone()).await.unwrap();
|
md.store(key, content, cat.clone(), None).await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test queries and compare results
|
// Test queries and compare results
|
||||||
|
|
@ -145,8 +147,8 @@ async fn compare_recall_quality() {
|
||||||
println!("RECALL QUALITY (10 entries seeded):\n");
|
println!("RECALL QUALITY (10 entries seeded):\n");
|
||||||
|
|
||||||
for (query, desc) in &queries {
|
for (query, desc) in &queries {
|
||||||
let sq_results = sq.recall(query, 10).await.unwrap();
|
let sq_results = sq.recall(query, 10, None).await.unwrap();
|
||||||
let md_results = md.recall(query, 10).await.unwrap();
|
let md_results = md.recall(query, 10, None).await.unwrap();
|
||||||
|
|
||||||
println!(" Query: \"{query}\" — {desc}");
|
println!(" Query: \"{query}\" — {desc}");
|
||||||
println!(" SQLite: {} results", sq_results.len());
|
println!(" SQLite: {} results", sq_results.len());
|
||||||
|
|
@ -190,21 +192,21 @@ async fn compare_recall_speed() {
|
||||||
} else {
|
} else {
|
||||||
format!("TypeScript powers modern web apps, entry {i}")
|
format!("TypeScript powers modern web apps, entry {i}")
|
||||||
};
|
};
|
||||||
sq.store(&format!("e{i}"), &content, MemoryCategory::Core)
|
sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store(&format!("e{i}"), &content, MemoryCategory::Daily)
|
md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark recall
|
// Benchmark recall
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let sq_results = sq.recall("Rust systems", 10).await.unwrap();
|
let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
|
||||||
let sq_dur = start.elapsed();
|
let sq_dur = start.elapsed();
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let md_results = md.recall("Rust systems", 10).await.unwrap();
|
let md_results = md.recall("Rust systems", 10, None).await.unwrap();
|
||||||
let md_dur = start.elapsed();
|
let md_dur = start.elapsed();
|
||||||
|
|
||||||
println!("\n============================================================");
|
println!("\n============================================================");
|
||||||
|
|
@ -227,13 +229,23 @@ async fn compare_persistence() {
|
||||||
// Store in both, then drop and re-open
|
// Store in both, then drop and re-open
|
||||||
{
|
{
|
||||||
let sq = sqlite_backend(tmp_sq.path());
|
let sq = sqlite_backend(tmp_sq.path());
|
||||||
sq.store("persist_test", "I should survive", MemoryCategory::Core)
|
sq.store(
|
||||||
|
"persist_test",
|
||||||
|
"I should survive",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
md.store("persist_test", "I should survive", MemoryCategory::Core)
|
md.store(
|
||||||
|
"persist_test",
|
||||||
|
"I should survive",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -282,17 +294,17 @@ async fn compare_upsert() {
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
|
|
||||||
// Store twice with same key, different content
|
// Store twice with same key, different content
|
||||||
sq.store("pref", "likes Rust", MemoryCategory::Core)
|
sq.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("pref", "loves Rust", MemoryCategory::Core)
|
sq.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
md.store("pref", "likes Rust", MemoryCategory::Core)
|
md.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("pref", "loves Rust", MemoryCategory::Core)
|
md.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -300,7 +312,7 @@ async fn compare_upsert() {
|
||||||
let md_count = md.count().await.unwrap();
|
let md_count = md.count().await.unwrap();
|
||||||
|
|
||||||
let sq_entry = sq.get("pref").await.unwrap();
|
let sq_entry = sq.get("pref").await.unwrap();
|
||||||
let md_results = md.recall("loves Rust", 5).await.unwrap();
|
let md_results = md.recall("loves Rust", 5, None).await.unwrap();
|
||||||
|
|
||||||
println!("\n============================================================");
|
println!("\n============================================================");
|
||||||
println!("UPSERT (store same key twice):");
|
println!("UPSERT (store same key twice):");
|
||||||
|
|
@ -328,10 +340,10 @@ async fn compare_forget() {
|
||||||
let sq = sqlite_backend(tmp_sq.path());
|
let sq = sqlite_backend(tmp_sq.path());
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
|
|
||||||
sq.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
md.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -372,37 +384,40 @@ async fn compare_category_filter() {
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
|
|
||||||
// Mix of categories
|
// Mix of categories
|
||||||
sq.store("a", "core fact 1", MemoryCategory::Core)
|
sq.store("a", "core fact 1", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("b", "core fact 2", MemoryCategory::Core)
|
sq.store("b", "core fact 2", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("c", "daily note", MemoryCategory::Daily)
|
sq.store("c", "daily note", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("d", "convo msg", MemoryCategory::Conversation)
|
sq.store("d", "convo msg", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
md.store("a", "core fact 1", MemoryCategory::Core)
|
md.store("a", "core fact 1", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("b", "core fact 2", MemoryCategory::Core)
|
md.store("b", "core fact 2", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("c", "daily note", MemoryCategory::Daily)
|
md.store("c", "daily note", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap();
|
let sq_conv = sq
|
||||||
let sq_all = sq.list(None).await.unwrap();
|
.list(Some(&MemoryCategory::Conversation), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let sq_all = sq.list(None, None).await.unwrap();
|
||||||
|
|
||||||
let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
let md_all = md.list(None).await.unwrap();
|
let md_all = md.list(None, None).await.unwrap();
|
||||||
|
|
||||||
println!("\n============================================================");
|
println!("\n============================================================");
|
||||||
println!("CATEGORY FILTERING:");
|
println!("CATEGORY FILTERING:");
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue