diff --git a/src/agent/agent.rs b/src/agent/agent.rs index 44e40b6..4495736 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -389,7 +389,7 @@ impl Agent { if self.auto_save { let _ = self .memory - .store("user_msg", user_message, MemoryCategory::Conversation) + .store("user_msg", user_message, MemoryCategory::Conversation, None) .await; } @@ -448,7 +448,7 @@ impl Agent { let summary = truncate_with_ellipsis(&final_text, 100); let _ = self .memory - .store("assistant_resp", &summary, MemoryCategory::Daily) + .store("assistant_resp", &summary, MemoryCategory::Daily, None) .await; } diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 4f4d84c..fd04b63 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -145,7 +145,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { let mut context = String::new(); // Pull relevant memories for this message - if let Ok(entries) = mem.recall(user_msg, 5).await { + if let Ok(entries) = mem.recall(user_msg, 5, None).await { if !entries.is_empty() { context.push_str("[Memory context]\n"); for entry in &entries { @@ -913,7 +913,7 @@ pub async fn run( if config.memory.auto_save { let user_key = autosave_memory_key("user_msg"); let _ = mem - .store(&user_key, &msg, MemoryCategory::Conversation) + .store(&user_key, &msg, MemoryCategory::Conversation, None) .await; } @@ -956,7 +956,7 @@ pub async fn run( let summary = truncate_with_ellipsis(&response, 100); let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store(&response_key, &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily, None) .await; } } else { @@ -979,7 +979,7 @@ pub async fn run( if config.memory.auto_save { let user_key = autosave_memory_key("user_msg"); let _ = mem - .store(&user_key, &msg.content, MemoryCategory::Conversation) + .store(&user_key, &msg.content, MemoryCategory::Conversation, None) .await; } @@ -1037,7 +1037,7 @@ pub async fn run( let summary = truncate_with_ellipsis(&response, 100); let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store(&response_key, &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily, None) .await; } } @@ -1499,16 +1499,16 @@ I will now call the tool with this payload: let key1 = autosave_memory_key("user_msg"); let key2 = autosave_memory_key("user_msg"); - mem.store(&key1, "I'm Paul", MemoryCategory::Conversation) + mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None) .await .unwrap(); - mem.store(&key2, "I'm 45", MemoryCategory::Conversation) + mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None) .await .unwrap(); assert_eq!(mem.count().await.unwrap(), 2); - let recalled = mem.recall("45", 5).await.unwrap(); + let recalled = mem.recall("45", 5, None).await.unwrap(); assert!(recalled.iter().any(|entry| entry.content.contains("45"))); } diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs index f5733ec..0cc530f 100644 --- a/src/agent/memory_loader.rs +++ b/src/agent/memory_loader.rs @@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader { memory: &dyn Memory, user_message: &str, ) -> anyhow::Result { - let entries = memory.recall(user_message, self.limit).await?; + let entries = memory.recall(user_message, self.limit, None).await?; if entries.is_empty() { return Ok(String::new()); } @@ -61,11 +61,17 @@ mod tests { _key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } - async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { if limit == 0 { return Ok(vec![]); } @@ -87,6 +93,7 @@ mod tests { async fn list( &self, _category: Option<&MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(vec![]) } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 6c21fe8..783ce04 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -72,7 +72,7 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String { let mut context = String::new(); - if let Ok(entries) = mem.recall(user_msg, 5).await { + if let Ok(entries) = mem.recall(user_msg, 5, None).await { if !entries.is_empty() { context.push_str("[Memory context]\n"); for entry in &entries { @@ -158,6 +158,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C &autosave_key, &msg.content, crate::memory::MemoryCategory::Conversation, + None, ) .await; } @@ -1260,6 +1261,7 @@ mod tests { _key: &str, _content: &str, _category: crate::memory::MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } @@ -1268,6 +1270,7 @@ mod tests { &self, _query: &str, _limit: usize, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } @@ -1279,6 +1282,7 @@ mod tests { async fn list( &self, _category: Option<&crate::memory::MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } @@ -1636,6 +1640,7 @@ mod tests { &conversation_memory_key(&msg1), &msg1.content, MemoryCategory::Conversation, + None, ) .await .unwrap(); @@ -1643,13 +1648,14 @@ mod tests { &conversation_memory_key(&msg2), &msg2.content, MemoryCategory::Conversation, + None, ) .await .unwrap(); assert_eq!(mem.count().await.unwrap(), 2); - let recalled = mem.recall("45", 5).await.unwrap(); + let recalled = mem.recall("45", 5, None).await.unwrap(); assert!(recalled.iter().any(|entry| entry.content.contains("45"))); } @@ -1657,7 +1663,7 @@ mod tests { async fn build_memory_context_includes_recalled_entries() { let tmp = TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("age_fact", "Age is 45", MemoryCategory::Conversation) + mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None) .await .unwrap(); diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index df500a5..86111da 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -544,7 +544,7 @@ async fn handle_webhook( let key = webhook_memory_key(); let _ = state .mem - .store(&key, message, MemoryCategory::Conversation) + .store(&key, message, MemoryCategory::Conversation, None) .await; } @@ -697,7 +697,7 @@ async fn handle_whatsapp_message( let key = whatsapp_memory_key(msg); let _ = state .mem - .store(&key, &msg.content, MemoryCategory::Conversation) + .store(&key, &msg.content, MemoryCategory::Conversation, None) .await; } @@ -886,11 +886,17 @@ mod tests { _key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } - async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -901,6 +907,7 @@ mod tests { async fn list( &self, _category: Option<&MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } @@ -953,6 +960,7 @@ mod tests { key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { self.keys .lock() @@ -961,7 +969,12 @@ mod tests { Ok(()) } - async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -972,6 +985,7 @@ mod tests { async fn list( &self, _category: Option<&MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } diff --git a/src/memory/hygiene.rs b/src/memory/hygiene.rs index cf58e21..01054ce 100644 --- a/src/memory/hygiene.rs +++ b/src/memory/hygiene.rs @@ -502,10 +502,10 @@ mod tests { let workspace = tmp.path(); let mem = SqliteMemory::new(workspace).unwrap(); - mem.store("conv_old", "outdated", MemoryCategory::Conversation) + mem.store("conv_old", "outdated", MemoryCategory::Conversation, None) .await .unwrap(); - mem.store("core_keep", "durable", MemoryCategory::Core) + mem.store("core_keep", "durable", MemoryCategory::Core, None) .await .unwrap(); drop(mem); diff --git a/src/memory/lucid.rs b/src/memory/lucid.rs index 9a0e84d..4747bbd 100644 --- a/src/memory/lucid.rs +++ b/src/memory/lucid.rs @@ -314,14 +314,22 @@ impl Memory for LucidMemory { key: &str, content: &str, category: MemoryCategory, + session_id: Option<&str>, ) -> anyhow::Result<()> { - self.local.store(key, content, category.clone()).await?; + self.local + .store(key, content, category.clone(), session_id) + .await?; self.sync_to_lucid_async(key, content, &category).await; Ok(()) } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { - let local_results = self.local.recall(query, limit).await?; + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result> { + let local_results = self.local.recall(query, limit, session_id).await?; if limit == 0 || local_results.len() >= limit || local_results.len() >= self.local_hit_threshold @@ -358,8 +366,12 @@ impl Memory for LucidMemory { self.local.get(key).await } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { - self.local.list(category).await + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result> { + self.local.list(category, session_id).await } async fn forget(&self, key: &str) -> anyhow::Result { @@ -475,7 +487,7 @@ exit 1 let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string()); memory - .store("lang", "User prefers Rust", MemoryCategory::Core) + .store("lang", "User prefers Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -495,11 +507,12 @@ exit 1 "local_note", "Local sqlite auth fallback note", MemoryCategory::Core, + None, ) .await .unwrap(); - let entries = memory.recall("auth", 5).await.unwrap(); + let entries = memory.recall("auth", 5, None).await.unwrap(); assert!(entries .iter() @@ -526,11 +539,16 @@ exit 1 ); memory - .store("pref", "Rust should stay local-first", MemoryCategory::Core) + .store( + "pref", + "Rust should stay local-first", + MemoryCategory::Core, + None, + ) .await .unwrap(); - let entries = memory.recall("rust", 5).await.unwrap(); + let entries = memory.recall("rust", 5, None).await.unwrap(); assert!(entries .iter() .any(|e| e.content.contains("Rust should stay local-first"))); @@ -590,8 +608,8 @@ exit 1 Duration::from_secs(5), ); - let first = memory.recall("auth", 5).await.unwrap(); - let second = memory.recall("auth", 5).await.unwrap(); + let first = memory.recall("auth", 5, None).await.unwrap(); + let second = memory.recall("auth", 5, None).await.unwrap(); assert!(first.is_empty()); assert!(second.is_empty()); diff --git a/src/memory/markdown.rs b/src/memory/markdown.rs index 8dcd667..9038683 100644 --- a/src/memory/markdown.rs +++ b/src/memory/markdown.rs @@ -143,6 +143,7 @@ impl Memory for MarkdownMemory { key: &str, content: &str, category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { let entry = format!("- **{key}**: {content}"); let path = match category { @@ -152,7 +153,12 @@ impl Memory for MarkdownMemory { self.append_to_file(&path, &entry).await } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + query: &str, + limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { let all = self.read_all_entries().await?; let query_lower = query.to_lowercase(); let keywords: Vec<&str> = query_lower.split_whitespace().collect(); @@ -192,7 +198,11 @@ impl Memory for MarkdownMemory { .find(|e| e.key == key || e.content.contains(key))) } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { + async fn list( + &self, + category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { let all = self.read_all_entries().await?; match category { Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()), @@ -243,7 +253,7 @@ mod tests { #[tokio::test] async fn markdown_store_core() { let (_tmp, mem) = temp_workspace(); - mem.store("pref", "User likes Rust", MemoryCategory::Core) + mem.store("pref", "User likes Rust", MemoryCategory::Core, None) .await .unwrap(); let content = sync_fs::read_to_string(mem.core_path()).unwrap(); @@ -253,7 +263,7 @@ mod tests { #[tokio::test] async fn markdown_store_daily() { let (_tmp, mem) = temp_workspace(); - mem.store("note", "Finished tests", MemoryCategory::Daily) + mem.store("note", "Finished tests", MemoryCategory::Daily, None) .await .unwrap(); let path = mem.daily_path(); @@ -264,17 +274,17 @@ mod tests { #[tokio::test] async fn markdown_recall_keyword() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "Rust is fast", MemoryCategory::Core) + mem.store("a", "Rust is fast", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Python is slow", MemoryCategory::Core) + mem.store("b", "Python is slow", MemoryCategory::Core, None) .await .unwrap(); - mem.store("c", "Rust and safety", MemoryCategory::Core) + mem.store("c", "Rust and safety", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert!(results.len() >= 2); assert!(results .iter() @@ -284,18 +294,20 @@ mod tests { #[tokio::test] async fn markdown_recall_no_match() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "Rust is great", MemoryCategory::Core) + mem.store("a", "Rust is great", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("javascript", 10).await.unwrap(); + let results = mem.recall("javascript", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn markdown_count() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "first", MemoryCategory::Core).await.unwrap(); - mem.store("b", "second", MemoryCategory::Core) + mem.store("a", "first", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "second", MemoryCategory::Core, None) .await .unwrap(); let count = mem.count().await.unwrap(); @@ -305,24 +317,24 @@ mod tests { #[tokio::test] async fn markdown_list_by_category() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "core fact", MemoryCategory::Core) + mem.store("a", "core fact", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "daily note", MemoryCategory::Daily) + mem.store("b", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); + let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); assert!(core.iter().all(|e| e.category == MemoryCategory::Core)); - let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); + let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily)); } #[tokio::test] async fn markdown_forget_is_noop() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "permanent", MemoryCategory::Core) + mem.store("a", "permanent", MemoryCategory::Core, None) .await .unwrap(); let removed = mem.forget("a").await.unwrap(); @@ -332,7 +344,7 @@ mod tests { #[tokio::test] async fn markdown_empty_recall() { let (_tmp, mem) = temp_workspace(); - let results = mem.recall("anything", 10).await.unwrap(); + let results = mem.recall("anything", 10, None).await.unwrap(); assert!(results.is_empty()); } diff --git a/src/memory/none.rs b/src/memory/none.rs index 6057ad0..4ccd2f8 100644 --- a/src/memory/none.rs +++ b/src/memory/none.rs @@ -25,11 +25,17 @@ impl Memory for NoneMemory { _key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } - async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -37,7 +43,11 @@ impl Memory for NoneMemory { Ok(None) } - async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result> { + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -62,11 +72,14 @@ mod tests { async fn none_memory_is_noop() { let memory = NoneMemory::new(); - memory.store("k", "v", MemoryCategory::Core).await.unwrap(); + memory + .store("k", "v", MemoryCategory::Core, None) + .await + .unwrap(); assert!(memory.get("k").await.unwrap().is_none()); - assert!(memory.recall("k", 10).await.unwrap().is_empty()); - assert!(memory.list(None).await.unwrap().is_empty()); + assert!(memory.recall("k", 10, None).await.unwrap().is_empty()); + assert!(memory.list(None, None).await.unwrap().is_empty()); assert!(!memory.forget("k").await.unwrap()); assert_eq!(memory.count().await.unwrap(), 0); assert!(memory.health_check().await); diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index 6219989..f5df9a3 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -123,6 +123,19 @@ impl SqliteMemory { ); CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);", )?; + + // Migration: add session_id column if not present (safe to run repeatedly) + let has_session_id: bool = conn + .prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")? + .query_row([], |row| row.get::<_, String>(0))? + .contains("session_id"); + if !has_session_id { + conn.execute_batch( + "ALTER TABLE memories ADD COLUMN session_id TEXT; + CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);", + )?; + } + Ok(()) } @@ -360,6 +373,7 @@ impl Memory for SqliteMemory { key: &str, content: &str, category: MemoryCategory, + session_id: Option<&str>, ) -> anyhow::Result<()> { // Compute embedding (async, before lock) let embedding_bytes = self @@ -376,20 +390,26 @@ impl Memory for SqliteMemory { let id = Uuid::new_v4().to_string(); conn.execute( - "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) + "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) ON CONFLICT(key) DO UPDATE SET content = excluded.content, category = excluded.category, embedding = excluded.embedding, - updated_at = excluded.updated_at", - params![id, key, content, cat, embedding_bytes, now, now], + updated_at = excluded.updated_at, + session_id = excluded.session_id", + params![id, key, content, cat, embedding_bytes, now, now, session_id], )?; Ok(()) } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result> { if query.trim().is_empty() { return Ok(Vec::new()); } @@ -438,7 +458,7 @@ impl Memory for SqliteMemory { let mut results = Vec::new(); for scored in &merged { let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories WHERE id = ?1", + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1", )?; if let Ok(entry) = stmt.query_row(params![scored.id], |row| { Ok(MemoryEntry { @@ -447,10 +467,16 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: Some(f64::from(scored.final_score)), }) }) { + // Filter by session_id if requested + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } results.push(entry); } } @@ -469,7 +495,7 @@ impl Memory for SqliteMemory { .collect(); let where_clause = conditions.join(" OR "); let sql = format!( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE {where_clause} ORDER BY updated_at DESC LIMIT ?{}", @@ -492,12 +518,18 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: Some(1.0), }) })?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } } @@ -513,7 +545,7 @@ impl Memory for SqliteMemory { .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories WHERE key = ?1", + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1", )?; let mut rows = stmt.query_map(params![key], |row| { @@ -523,7 +555,7 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: None, }) })?; @@ -534,7 +566,11 @@ impl Memory for SqliteMemory { } } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result> { let conn = self .conn .lock() @@ -549,7 +585,7 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: None, }) }; @@ -557,21 +593,33 @@ impl Memory for SqliteMemory { if let Some(cat) = category { let cat_str = Self::category_to_str(cat); let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE category = ?1 ORDER BY updated_at DESC", )?; let rows = stmt.query_map(params![cat_str], row_mapper)?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } else { let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories ORDER BY updated_at DESC", )?; let rows = stmt.query_map([], row_mapper)?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } @@ -631,7 +679,7 @@ mod tests { #[tokio::test] async fn sqlite_store_and_get() { let (_tmp, mem) = temp_sqlite(); - mem.store("user_lang", "Prefers Rust", MemoryCategory::Core) + mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -646,10 +694,10 @@ mod tests { #[tokio::test] async fn sqlite_store_upsert() { let (_tmp, mem) = temp_sqlite(); - mem.store("pref", "likes Rust", MemoryCategory::Core) + mem.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - mem.store("pref", "loves Rust", MemoryCategory::Core) + mem.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -661,17 +709,22 @@ mod tests { #[tokio::test] async fn sqlite_recall_keyword() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust is fast and safe", MemoryCategory::Core) + mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Python is interpreted", MemoryCategory::Core) - .await - .unwrap(); - mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core) + mem.store("b", "Python is interpreted", MemoryCategory::Core, None) .await .unwrap(); + mem.store( + "c", + "Rust has zero-cost abstractions", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert_eq!(results.len(), 2); assert!(results .iter() @@ -681,14 +734,14 @@ mod tests { #[tokio::test] async fn sqlite_recall_multi_keyword() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust is fast", MemoryCategory::Core) + mem.store("a", "Rust is fast", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Rust is safe and fast", MemoryCategory::Core) + mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("fast safe", 10).await.unwrap(); + let results = mem.recall("fast safe", 10, None).await.unwrap(); assert!(!results.is_empty()); // Entry with both keywords should score higher assert!(results[0].content.contains("safe") && results[0].content.contains("fast")); @@ -697,17 +750,17 @@ mod tests { #[tokio::test] async fn sqlite_recall_no_match() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust rocks", MemoryCategory::Core) + mem.store("a", "Rust rocks", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("javascript", 10).await.unwrap(); + let results = mem.recall("javascript", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn sqlite_forget() { let (_tmp, mem) = temp_sqlite(); - mem.store("temp", "temporary data", MemoryCategory::Conversation) + mem.store("temp", "temporary data", MemoryCategory::Conversation, None) .await .unwrap(); assert_eq!(mem.count().await.unwrap(), 1); @@ -727,29 +780,37 @@ mod tests { #[tokio::test] async fn sqlite_list_all() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "one", MemoryCategory::Core).await.unwrap(); - mem.store("b", "two", MemoryCategory::Daily).await.unwrap(); - mem.store("c", "three", MemoryCategory::Conversation) + mem.store("a", "one", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "two", MemoryCategory::Daily, None) + .await + .unwrap(); + mem.store("c", "three", MemoryCategory::Conversation, None) .await .unwrap(); - let all = mem.list(None).await.unwrap(); + let all = mem.list(None, None).await.unwrap(); assert_eq!(all.len(), 3); } #[tokio::test] async fn sqlite_list_by_category() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "core1", MemoryCategory::Core).await.unwrap(); - mem.store("b", "core2", MemoryCategory::Core).await.unwrap(); - mem.store("c", "daily1", MemoryCategory::Daily) + mem.store("a", "core1", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "core2", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("c", "daily1", MemoryCategory::Daily, None) .await .unwrap(); - let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); + let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); assert_eq!(core.len(), 2); - let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); + let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); assert_eq!(daily.len(), 1); } @@ -771,7 +832,7 @@ mod tests { { let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("persist", "I survive restarts", MemoryCategory::Core) + mem.store("persist", "I survive restarts", MemoryCategory::Core, None) .await .unwrap(); } @@ -794,7 +855,7 @@ mod tests { ]; for (i, cat) in categories.iter().enumerate() { - mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone()) + mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None) .await .unwrap(); } @@ -814,21 +875,28 @@ mod tests { "a", "Rust is a systems programming language", MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store( + "b", + "Python is great for scripting", + MemoryCategory::Core, + None, ) .await .unwrap(); - mem.store("b", "Python is great for scripting", MemoryCategory::Core) - .await - .unwrap(); mem.store( "c", "Rust and Rust and Rust everywhere", MemoryCategory::Core, + None, ) .await .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert!(results.len() >= 2); // All results should contain "Rust" for r in &results { @@ -843,17 +911,17 @@ mod tests { #[tokio::test] async fn fts5_multi_word_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "The quick brown fox jumps", MemoryCategory::Core) + mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "A lazy dog sleeps", MemoryCategory::Core) + mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None) .await .unwrap(); - mem.store("c", "The quick dog runs fast", MemoryCategory::Core) + mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("quick dog", 10).await.unwrap(); + let results = mem.recall("quick dog", 10, None).await.unwrap(); assert!(!results.is_empty()); // "The quick dog runs fast" matches both terms assert!(results[0].content.contains("quick")); @@ -862,16 +930,20 @@ mod tests { #[tokio::test] async fn recall_empty_query_returns_empty() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "data", MemoryCategory::Core).await.unwrap(); - let results = mem.recall("", 10).await.unwrap(); + mem.store("a", "data", MemoryCategory::Core, None) + .await + .unwrap(); + let results = mem.recall("", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn recall_whitespace_query_returns_empty() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "data", MemoryCategory::Core).await.unwrap(); - let results = mem.recall(" ", 10).await.unwrap(); + mem.store("a", "data", MemoryCategory::Core, None) + .await + .unwrap(); + let results = mem.recall(" ", 10, None).await.unwrap(); assert!(results.is_empty()); } @@ -936,9 +1008,14 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_insert() { let (_tmp, mem) = temp_sqlite(); - mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "test_key", + "unique_searchterm_xyz", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); let conn = mem.conn.lock().unwrap(); let count: i64 = conn @@ -954,9 +1031,14 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_delete() { let (_tmp, mem) = temp_sqlite(); - mem.store("del_key", "deletable_content_abc", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "del_key", + "deletable_content_abc", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); mem.forget("del_key").await.unwrap(); let conn = mem.conn.lock().unwrap(); @@ -973,10 +1055,15 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_update() { let (_tmp, mem) = temp_sqlite(); - mem.store("upd_key", "original_content_111", MemoryCategory::Core) - .await - .unwrap(); - mem.store("upd_key", "updated_content_222", MemoryCategory::Core) + mem.store( + "upd_key", + "original_content_111", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None) .await .unwrap(); @@ -1018,10 +1105,10 @@ mod tests { #[tokio::test] async fn reindex_rebuilds_fts() { let (_tmp, mem) = temp_sqlite(); - mem.store("r1", "reindex test alpha", MemoryCategory::Core) + mem.store("r1", "reindex test alpha", MemoryCategory::Core, None) .await .unwrap(); - mem.store("r2", "reindex test beta", MemoryCategory::Core) + mem.store("r2", "reindex test beta", MemoryCategory::Core, None) .await .unwrap(); @@ -1030,7 +1117,7 @@ mod tests { assert_eq!(count, 0); // FTS should still work after rebuild - let results = mem.recall("reindex", 10).await.unwrap(); + let results = mem.recall("reindex", 10, None).await.unwrap(); assert_eq!(results.len(), 2); } @@ -1044,12 +1131,13 @@ mod tests { &format!("k{i}"), &format!("common keyword item {i}"), MemoryCategory::Core, + None, ) .await .unwrap(); } - let results = mem.recall("common keyword", 5).await.unwrap(); + let results = mem.recall("common keyword", 5, None).await.unwrap(); assert!(results.len() <= 5); } @@ -1058,11 +1146,11 @@ mod tests { #[tokio::test] async fn recall_results_have_scores() { let (_tmp, mem) = temp_sqlite(); - mem.store("s1", "scored result test", MemoryCategory::Core) + mem.store("s1", "scored result test", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("scored", 10).await.unwrap(); + let results = mem.recall("scored", 10, None).await.unwrap(); assert!(!results.is_empty()); for r in &results { assert!(r.score.is_some(), "Expected score on result: {:?}", r.key); @@ -1074,11 +1162,11 @@ mod tests { #[tokio::test] async fn recall_with_quotes_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("q1", "He said hello world", MemoryCategory::Core) + mem.store("q1", "He said hello world", MemoryCategory::Core, None) .await .unwrap(); // Quotes in query should not crash FTS5 - let results = mem.recall("\"hello\"", 10).await.unwrap(); + let results = mem.recall("\"hello\"", 10, None).await.unwrap(); // May or may not match depending on FTS5 escaping, but must not error assert!(results.len() <= 10); } @@ -1086,31 +1174,34 @@ mod tests { #[tokio::test] async fn recall_with_asterisk_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a1", "wildcard test content", MemoryCategory::Core) + mem.store("a1", "wildcard test content", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("wild*", 10).await.unwrap(); + let results = mem.recall("wild*", 10, None).await.unwrap(); assert!(results.len() <= 10); } #[tokio::test] async fn recall_with_parentheses_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("p1", "function call test", MemoryCategory::Core) + mem.store("p1", "function call test", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("function()", 10).await.unwrap(); + let results = mem.recall("function()", 10, None).await.unwrap(); assert!(results.len() <= 10); } #[tokio::test] async fn recall_with_sql_injection_attempt() { let (_tmp, mem) = temp_sqlite(); - mem.store("safe", "normal content", MemoryCategory::Core) + mem.store("safe", "normal content", MemoryCategory::Core, None) .await .unwrap(); // Should not crash or leak data - let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap(); + let results = mem + .recall("'; DROP TABLE memories; --", 10, None) + .await + .unwrap(); assert!(results.len() <= 10); // Table should still exist assert_eq!(mem.count().await.unwrap(), 1); @@ -1121,7 +1212,9 @@ mod tests { #[tokio::test] async fn store_empty_content() { let (_tmp, mem) = temp_sqlite(); - mem.store("empty", "", MemoryCategory::Core).await.unwrap(); + mem.store("empty", "", MemoryCategory::Core, None) + .await + .unwrap(); let entry = mem.get("empty").await.unwrap().unwrap(); assert_eq!(entry.content, ""); } @@ -1129,7 +1222,7 @@ mod tests { #[tokio::test] async fn store_empty_key() { let (_tmp, mem) = temp_sqlite(); - mem.store("", "content for empty key", MemoryCategory::Core) + mem.store("", "content for empty key", MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("").await.unwrap().unwrap(); @@ -1140,7 +1233,7 @@ mod tests { async fn store_very_long_content() { let (_tmp, mem) = temp_sqlite(); let long_content = "x".repeat(100_000); - mem.store("long", &long_content, MemoryCategory::Core) + mem.store("long", &long_content, MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("long").await.unwrap().unwrap(); @@ -1150,9 +1243,14 @@ mod tests { #[tokio::test] async fn store_unicode_and_emoji() { let (_tmp, mem) = temp_sqlite(); - mem.store("emoji_key_๐Ÿฆ€", "ใ“ใ‚“ใซใกใฏ ๐Ÿš€ ร‘oรฑo", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "emoji_key_๐Ÿฆ€", + "ใ“ใ‚“ใซใกใฏ ๐Ÿš€ ร‘oรฑo", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); let entry = mem.get("emoji_key_๐Ÿฆ€").await.unwrap().unwrap(); assert_eq!(entry.content, "ใ“ใ‚“ใซใกใฏ ๐Ÿš€ ร‘oรฑo"); } @@ -1161,7 +1259,7 @@ mod tests { async fn store_content_with_newlines_and_tabs() { let (_tmp, mem) = temp_sqlite(); let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph"; - mem.store("whitespace", content, MemoryCategory::Core) + mem.store("whitespace", content, MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("whitespace").await.unwrap().unwrap(); @@ -1173,11 +1271,11 @@ mod tests { #[tokio::test] async fn recall_single_character_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "x marks the spot", MemoryCategory::Core) + mem.store("a", "x marks the spot", MemoryCategory::Core, None) .await .unwrap(); // Single char may not match FTS5 but LIKE fallback should work - let results = mem.recall("x", 10).await.unwrap(); + let results = mem.recall("x", 10, None).await.unwrap(); // Should not crash; may or may not find results assert!(results.len() <= 10); } @@ -1185,23 +1283,23 @@ mod tests { #[tokio::test] async fn recall_limit_zero() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "some content", MemoryCategory::Core) + mem.store("a", "some content", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("some", 0).await.unwrap(); + let results = mem.recall("some", 0, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn recall_limit_one() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "matching content alpha", MemoryCategory::Core) + mem.store("a", "matching content alpha", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "matching content beta", MemoryCategory::Core) + mem.store("b", "matching content beta", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("matching content", 1).await.unwrap(); + let results = mem.recall("matching content", 1, None).await.unwrap(); assert_eq!(results.len(), 1); } @@ -1212,21 +1310,22 @@ mod tests { "rust_preferences", "User likes systems programming", MemoryCategory::Core, + None, ) .await .unwrap(); // "rust" appears in key but not content โ€” LIKE fallback checks key too - let results = mem.recall("rust", 10).await.unwrap(); + let results = mem.recall("rust", 10, None).await.unwrap(); assert!(!results.is_empty(), "Should match by key"); } #[tokio::test] async fn recall_unicode_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("jp", "ๆ—ฅๆœฌ่ชžใฎใƒ†ใ‚นใƒˆ", MemoryCategory::Core) + mem.store("jp", "ๆ—ฅๆœฌ่ชžใฎใƒ†ใ‚นใƒˆ", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("ๆ—ฅๆœฌ่ชž", 10).await.unwrap(); + let results = mem.recall("ๆ—ฅๆœฌ่ชž", 10, None).await.unwrap(); assert!(!results.is_empty()); } @@ -1237,7 +1336,9 @@ mod tests { let tmp = TempDir::new().unwrap(); { let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("k1", "v1", MemoryCategory::Core).await.unwrap(); + mem.store("k1", "v1", MemoryCategory::Core, None) + .await + .unwrap(); } // Open again โ€” init_schema runs again on existing DB let mem2 = SqliteMemory::new(tmp.path()).unwrap(); @@ -1245,7 +1346,9 @@ mod tests { assert!(entry.is_some()); assert_eq!(entry.unwrap().content, "v1"); // Store more data โ€” should work fine - mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap(); + mem2.store("k2", "v2", MemoryCategory::Daily, None) + .await + .unwrap(); assert_eq!(mem2.count().await.unwrap(), 2); } @@ -1263,11 +1366,16 @@ mod tests { #[tokio::test] async fn forget_then_recall_no_ghost_results() { let (_tmp, mem) = temp_sqlite(); - mem.store("ghost", "phantom memory content", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "ghost", + "phantom memory content", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); mem.forget("ghost").await.unwrap(); - let results = mem.recall("phantom memory", 10).await.unwrap(); + let results = mem.recall("phantom memory", 10, None).await.unwrap(); assert!( results.is_empty(), "Deleted memory should not appear in recall" @@ -1277,11 +1385,11 @@ mod tests { #[tokio::test] async fn forget_and_re_store_same_key() { let (_tmp, mem) = temp_sqlite(); - mem.store("cycle", "version 1", MemoryCategory::Core) + mem.store("cycle", "version 1", MemoryCategory::Core, None) .await .unwrap(); mem.forget("cycle").await.unwrap(); - mem.store("cycle", "version 2", MemoryCategory::Core) + mem.store("cycle", "version 2", MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("cycle").await.unwrap().unwrap(); @@ -1301,14 +1409,14 @@ mod tests { #[tokio::test] async fn reindex_twice_is_safe() { let (_tmp, mem) = temp_sqlite(); - mem.store("r1", "reindex data", MemoryCategory::Core) + mem.store("r1", "reindex data", MemoryCategory::Core, None) .await .unwrap(); mem.reindex().await.unwrap(); let count = mem.reindex().await.unwrap(); assert_eq!(count, 0); // Noop embedder โ†’ nothing to re-embed // Data should still be intact - let results = mem.recall("reindex", 10).await.unwrap(); + let results = mem.recall("reindex", 10, None).await.unwrap(); assert_eq!(results.len(), 1); } @@ -1362,18 +1470,28 @@ mod tests { #[tokio::test] async fn list_custom_category() { let (_tmp, mem) = temp_sqlite(); - mem.store("c1", "custom1", MemoryCategory::Custom("project".into())) - .await - .unwrap(); - mem.store("c2", "custom2", MemoryCategory::Custom("project".into())) - .await - .unwrap(); - mem.store("c3", "other", MemoryCategory::Core) + mem.store( + "c1", + "custom1", + MemoryCategory::Custom("project".into()), + None, + ) + .await + .unwrap(); + mem.store( + "c2", + "custom2", + MemoryCategory::Custom("project".into()), + None, + ) + .await + .unwrap(); + mem.store("c3", "other", MemoryCategory::Core, None) .await .unwrap(); let project = mem - .list(Some(&MemoryCategory::Custom("project".into()))) + .list(Some(&MemoryCategory::Custom("project".into())), None) .await .unwrap(); assert_eq!(project.len(), 2); @@ -1382,7 +1500,122 @@ mod tests { #[tokio::test] async fn list_empty_db() { let (_tmp, mem) = temp_sqlite(); - let all = mem.list(None).await.unwrap(); + let all = mem.list(None, None).await.unwrap(); assert!(all.is_empty()); } + + // โ”€โ”€ Session isolation โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€โ”€ + + #[tokio::test] + async fn store_and_recall_with_session_id() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k3", "no session fact", MemoryCategory::Core, None) + .await + .unwrap(); + + // Recall with session-a filter returns only session-a entry + let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + assert_eq!(results[0].session_id.as_deref(), Some("sess-a")); + } + + #[tokio::test] + async fn recall_no_session_filter_returns_all() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k3", "gamma fact", MemoryCategory::Core, None) + .await + .unwrap(); + + // Recall without session filter returns all matching entries + let results = mem.recall("fact", 10, None).await.unwrap(); + assert_eq!(results.len(), 3); + } + + #[tokio::test] + async fn cross_session_recall_isolation() { + let (_tmp, mem) = temp_sqlite(); + mem.store( + "secret", + "session A secret data", + MemoryCategory::Core, + Some("sess-a"), + ) + .await + .unwrap(); + + // Session B cannot see session A data + let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap(); + assert!(results.is_empty()); + + // Session A can see its own data + let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 1); + } + + #[tokio::test] + async fn list_with_session_filter() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a")) + .await + .unwrap(); + mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k4", "none1", MemoryCategory::Core, None) + .await + .unwrap(); + + // List with session-a filter + let results = mem.list(None, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 2); + assert!(results + .iter() + .all(|e| e.session_id.as_deref() == Some("sess-a"))); + + // List with session-a + category filter + let results = mem + .list(Some(&MemoryCategory::Core), Some("sess-a")) + .await + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + } + + #[tokio::test] + async fn schema_migration_idempotent_on_reopen() { + let tmp = TempDir::new().unwrap(); + + // First open: creates schema + migration + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x")) + .await + .unwrap(); + } + + // Second open: migration runs again but is idempotent + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + assert_eq!(results[0].session_id.as_deref(), Some("sess-x")); + } + } } diff --git a/src/memory/traits.rs b/src/memory/traits.rs index 72e120e..bf8c021 100644 --- a/src/memory/traits.rs +++ b/src/memory/traits.rs @@ -44,18 +44,32 @@ pub trait Memory: Send + Sync { /// Backend name fn name(&self) -> &str; - /// Store a memory entry - async fn store(&self, key: &str, content: &str, category: MemoryCategory) - -> anyhow::Result<()>; + /// Store a memory entry, optionally scoped to a session + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> anyhow::Result<()>; - /// Recall memories matching a query (keyword search) - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result>; + /// Recall memories matching a query (keyword search), optionally scoped to a session + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result>; /// Get a specific memory by key async fn get(&self, key: &str) -> anyhow::Result>; - /// List all memory keys, optionally filtered by category - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result>; + /// List all memory keys, optionally filtered by category and/or session + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result>; /// Remove a memory by key async fn forget(&self, key: &str) -> anyhow::Result; diff --git a/src/migration.rs b/src/migration.rs index f217030..8a83262 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -95,7 +95,9 @@ async fn migrate_openclaw_memory( stats.renamed_conflicts += 1; } - memory.store(&key, &entry.content, entry.category).await?; + memory + .store(&key, &entry.content, entry.category, None) + .await?; stats.imported += 1; } @@ -488,7 +490,7 @@ mod tests { // Existing target memory let target_mem = SqliteMemory::new(target.path()).unwrap(); target_mem - .store("k", "new value", MemoryCategory::Core) + .store("k", "new value", MemoryCategory::Core, None) .await .unwrap(); @@ -510,7 +512,7 @@ mod tests { .await .unwrap(); - let all = target_mem.list(None).await.unwrap(); + let all = target_mem.list(None, None).await.unwrap(); assert!(all.iter().any(|e| e.key == "k" && e.content == "new value")); assert!(all .iter() diff --git a/src/tools/memory_forget.rs b/src/tools/memory_forget.rs index 16b2b8a..a53885e 100644 --- a/src/tools/memory_forget.rs +++ b/src/tools/memory_forget.rs @@ -87,7 +87,7 @@ mod tests { #[tokio::test] async fn forget_existing() { let (_tmp, mem) = test_mem(); - mem.store("temp", "temporary", MemoryCategory::Conversation) + mem.store("temp", "temporary", MemoryCategory::Conversation, None) .await .unwrap(); diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs index ff1385a..fada306 100644 --- a/src/tools/memory_recall.rs +++ b/src/tools/memory_recall.rs @@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool { .and_then(serde_json::Value::as_u64) .map_or(5, |v| v as usize); - match self.memory.recall(query, limit).await { + match self.memory.recall(query, limit, None).await { Ok(entries) if entries.is_empty() => Ok(ToolResult { success: true, output: "No memories found matching that query.".into(), @@ -112,10 +112,10 @@ mod tests { #[tokio::test] async fn recall_finds_match() { let (_tmp, mem) = seeded_mem(); - mem.store("lang", "User prefers Rust", MemoryCategory::Core) + mem.store("lang", "User prefers Rust", MemoryCategory::Core, None) .await .unwrap(); - mem.store("tz", "Timezone is EST", MemoryCategory::Core) + mem.store("tz", "Timezone is EST", MemoryCategory::Core, None) .await .unwrap(); @@ -134,6 +134,7 @@ mod tests { &format!("k{i}"), &format!("Rust fact {i}"), MemoryCategory::Core, + None, ) .await .unwrap(); diff --git a/src/tools/memory_store.rs b/src/tools/memory_store.rs index b90222c..d2aad40 100644 --- a/src/tools/memory_store.rs +++ b/src/tools/memory_store.rs @@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool { _ => MemoryCategory::Core, }; - match self.memory.store(key, content, category).await { + match self.memory.store(key, content, category, None).await { Ok(()) => Ok(ToolResult { success: true, output: format!("Stored memory: {key}"), diff --git a/tests/memory_comparison.rs b/tests/memory_comparison.rs index 8e0f4d6..2523829 100644 --- a/tests/memory_comparison.rs +++ b/tests/memory_comparison.rs @@ -36,6 +36,7 @@ async fn compare_store_speed() { &format!("key_{i}"), &format!("Memory entry number {i} about Rust programming"), MemoryCategory::Core, + None, ) .await .unwrap(); @@ -49,6 +50,7 @@ async fn compare_store_speed() { &format!("key_{i}"), &format!("Memory entry number {i} about Rust programming"), MemoryCategory::Core, + None, ) .await .unwrap(); @@ -127,8 +129,8 @@ async fn compare_recall_quality() { ]; for (key, content, cat) in &entries { - sq.store(key, content, cat.clone()).await.unwrap(); - md.store(key, content, cat.clone()).await.unwrap(); + sq.store(key, content, cat.clone(), None).await.unwrap(); + md.store(key, content, cat.clone(), None).await.unwrap(); } // Test queries and compare results @@ -145,8 +147,8 @@ async fn compare_recall_quality() { println!("RECALL QUALITY (10 entries seeded):\n"); for (query, desc) in &queries { - let sq_results = sq.recall(query, 10).await.unwrap(); - let md_results = md.recall(query, 10).await.unwrap(); + let sq_results = sq.recall(query, 10, None).await.unwrap(); + let md_results = md.recall(query, 10, None).await.unwrap(); println!(" Query: \"{query}\" โ€” {desc}"); println!(" SQLite: {} results", sq_results.len()); @@ -190,21 +192,21 @@ async fn compare_recall_speed() { } else { format!("TypeScript powers modern web apps, entry {i}") }; - sq.store(&format!("e{i}"), &content, MemoryCategory::Core) + sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None) .await .unwrap(); - md.store(&format!("e{i}"), &content, MemoryCategory::Daily) + md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None) .await .unwrap(); } // Benchmark recall let start = Instant::now(); - let sq_results = sq.recall("Rust systems", 10).await.unwrap(); + let sq_results = sq.recall("Rust systems", 10, None).await.unwrap(); let sq_dur = start.elapsed(); let start = Instant::now(); - let md_results = md.recall("Rust systems", 10).await.unwrap(); + let md_results = md.recall("Rust systems", 10, None).await.unwrap(); let md_dur = start.elapsed(); println!("\n============================================================"); @@ -227,15 +229,25 @@ async fn compare_persistence() { // Store in both, then drop and re-open { let sq = sqlite_backend(tmp_sq.path()); - sq.store("persist_test", "I should survive", MemoryCategory::Core) - .await - .unwrap(); + sq.store( + "persist_test", + "I should survive", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); } { let md = markdown_backend(tmp_md.path()); - md.store("persist_test", "I should survive", MemoryCategory::Core) - .await - .unwrap(); + md.store( + "persist_test", + "I should survive", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); } // Re-open @@ -282,17 +294,17 @@ async fn compare_upsert() { let md = markdown_backend(tmp_md.path()); // Store twice with same key, different content - sq.store("pref", "likes Rust", MemoryCategory::Core) + sq.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - sq.store("pref", "loves Rust", MemoryCategory::Core) + sq.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); - md.store("pref", "likes Rust", MemoryCategory::Core) + md.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - md.store("pref", "loves Rust", MemoryCategory::Core) + md.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -300,7 +312,7 @@ async fn compare_upsert() { let md_count = md.count().await.unwrap(); let sq_entry = sq.get("pref").await.unwrap(); - let md_results = md.recall("loves Rust", 5).await.unwrap(); + let md_results = md.recall("loves Rust", 5, None).await.unwrap(); println!("\n============================================================"); println!("UPSERT (store same key twice):"); @@ -328,10 +340,10 @@ async fn compare_forget() { let sq = sqlite_backend(tmp_sq.path()); let md = markdown_backend(tmp_md.path()); - sq.store("secret", "API key: sk-1234", MemoryCategory::Core) + sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None) .await .unwrap(); - md.store("secret", "API key: sk-1234", MemoryCategory::Core) + md.store("secret", "API key: sk-1234", MemoryCategory::Core, None) .await .unwrap(); @@ -372,37 +384,40 @@ async fn compare_category_filter() { let md = markdown_backend(tmp_md.path()); // Mix of categories - sq.store("a", "core fact 1", MemoryCategory::Core) + sq.store("a", "core fact 1", MemoryCategory::Core, None) .await .unwrap(); - sq.store("b", "core fact 2", MemoryCategory::Core) + sq.store("b", "core fact 2", MemoryCategory::Core, None) .await .unwrap(); - sq.store("c", "daily note", MemoryCategory::Daily) + sq.store("c", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - sq.store("d", "convo msg", MemoryCategory::Conversation) + sq.store("d", "convo msg", MemoryCategory::Conversation, None) .await .unwrap(); - md.store("a", "core fact 1", MemoryCategory::Core) + md.store("a", "core fact 1", MemoryCategory::Core, None) .await .unwrap(); - md.store("b", "core fact 2", MemoryCategory::Core) + md.store("b", "core fact 2", MemoryCategory::Core, None) .await .unwrap(); - md.store("c", "daily note", MemoryCategory::Daily) + md.store("c", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap(); - let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap(); - let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap(); - let sq_all = sq.list(None).await.unwrap(); + let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap(); + let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + let sq_conv = sq + .list(Some(&MemoryCategory::Conversation), None) + .await + .unwrap(); + let sq_all = sq.list(None, None).await.unwrap(); - let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap(); - let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap(); - let md_all = md.list(None).await.unwrap(); + let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap(); + let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + let md_all = md.list(None, None).await.unwrap(); println!("\n============================================================"); println!("CATEGORY FILTERING:");