Merge branch 'main' into pr-484-clean
This commit is contained in:
commit
ee05d62ce4
90 changed files with 6937 additions and 1403 deletions
|
|
@ -7,6 +7,7 @@ pub enum MemoryBackendKind {
|
|||
Unknown,
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||
pub struct MemoryBackendProfile {
|
||||
pub key: &'static str,
|
||||
|
|
|
|||
|
|
@ -502,10 +502,10 @@ mod tests {
|
|||
let workspace = tmp.path();
|
||||
|
||||
let mem = SqliteMemory::new(workspace).unwrap();
|
||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation)
|
||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("core_keep", "durable", MemoryCategory::Core)
|
||||
mem.store("core_keep", "durable", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
drop(mem);
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@ pub struct LucidMemory {
|
|||
impl LucidMemory {
|
||||
const DEFAULT_LUCID_CMD: &'static str = "lucid";
|
||||
const DEFAULT_TOKEN_BUDGET: usize = 200;
|
||||
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120;
|
||||
// Lucid CLI cold start can exceed 120ms on slower machines, which causes
|
||||
// avoidable fallback to local-only memory and premature cooldown.
|
||||
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500;
|
||||
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
|
||||
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
|
||||
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
|
||||
|
|
@ -74,6 +76,7 @@ impl LucidMemory {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn with_options(
|
||||
workspace_dir: &Path,
|
||||
local: SqliteMemory,
|
||||
|
|
@ -307,14 +310,22 @@ impl Memory for LucidMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.local.store(key, content, category.clone()).await?;
|
||||
self.local
|
||||
.store(key, content, category.clone(), session_id)
|
||||
.await?;
|
||||
self.sync_to_lucid_async(key, content, &category).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit).await?;
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit, session_id).await?;
|
||||
if limit == 0
|
||||
|| local_results.len() >= limit
|
||||
|| local_results.len() >= self.local_hit_threshold
|
||||
|
|
@ -351,8 +362,12 @@ impl Memory for LucidMemory {
|
|||
self.local.get(key).await
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.local.list(category).await
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.local.list(category, session_id).await
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
|
|
@ -396,6 +411,38 @@ EOF
|
|||
exit 0
|
||||
fi
|
||||
|
||||
echo "unsupported command" >&2
|
||||
exit 1
|
||||
"#;
|
||||
|
||||
fs::write(&script_path, script).unwrap();
|
||||
let mut perms = fs::metadata(&script_path).unwrap().permissions();
|
||||
perms.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, perms).unwrap();
|
||||
script_path.display().to_string()
|
||||
}
|
||||
|
||||
fn write_delayed_lucid_script(dir: &Path) -> String {
|
||||
let script_path = dir.join("delayed-lucid.sh");
|
||||
let script = r#"#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
if [[ "${1:-}" == "store" ]]; then
|
||||
echo '{"success":true,"id":"mem_1"}'
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "${1:-}" == "context" ]]; then
|
||||
# Simulate a cold start that is slower than 120ms but below the 500ms timeout.
|
||||
sleep 0.2
|
||||
cat <<'EOF'
|
||||
<lucid-context>
|
||||
- [decision] Delayed token refresh guidance
|
||||
</lucid-context>
|
||||
EOF
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "unsupported command" >&2
|
||||
exit 1
|
||||
"#;
|
||||
|
|
@ -449,7 +496,7 @@ exit 1
|
|||
cmd,
|
||||
200,
|
||||
3,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(2),
|
||||
)
|
||||
|
|
@ -468,7 +515,7 @@ exit 1
|
|||
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
|
||||
|
||||
memory
|
||||
.store("lang", "User prefers Rust", MemoryCategory::Core)
|
||||
.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -483,6 +530,30 @@ exit 1
|
|||
let fake_cmd = write_fake_lucid_script(tmp.path());
|
||||
let memory = test_memory(tmp.path(), fake_cmd);
|
||||
|
||||
memory
|
||||
.store(
|
||||
"local_note",
|
||||
"Local sqlite auth fallback note",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_handles_lucid_cold_start_delay_within_timeout() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let delayed_cmd = write_delayed_lucid_script(tmp.path());
|
||||
let memory = test_memory(tmp.path(), delayed_cmd);
|
||||
|
||||
memory
|
||||
.store(
|
||||
"local_note",
|
||||
|
|
@ -497,7 +568,9 @@ exit 1
|
|||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Delayed token refresh guidance")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -513,17 +586,22 @@ exit 1
|
|||
probe_cmd,
|
||||
200,
|
||||
1,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(2),
|
||||
);
|
||||
|
||||
memory
|
||||
.store("pref", "Rust should stay local-first", MemoryCategory::Core)
|
||||
.store(
|
||||
"pref",
|
||||
"Rust should stay local-first",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("rust", 5).await.unwrap();
|
||||
let entries = memory.recall("rust", 5, None).await.unwrap();
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Rust should stay local-first")));
|
||||
|
|
@ -578,13 +656,13 @@ exit 1
|
|||
failing_cmd,
|
||||
200,
|
||||
99,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let first = memory.recall("auth", 5).await.unwrap();
|
||||
let second = memory.recall("auth", 5).await.unwrap();
|
||||
let first = memory.recall("auth", 5, None).await.unwrap();
|
||||
let second = memory.recall("auth", 5, None).await.unwrap();
|
||||
|
||||
assert!(first.is_empty());
|
||||
assert!(second.is_empty());
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let entry = format!("- **{key}**: {content}");
|
||||
let path = match category {
|
||||
|
|
@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
|
|||
self.append_to_file(&path, &entry).await
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
let query_lower = query.to_lowercase();
|
||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
|
|
@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
|
|||
.find(|e| e.key == key || e.content.contains(key)))
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
match category {
|
||||
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
||||
|
|
@ -243,7 +253,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_store_core() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
||||
|
|
@ -253,7 +263,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_store_daily() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily)
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let path = mem.daily_path();
|
||||
|
|
@ -264,17 +274,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_recall_keyword() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core)
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core)
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
|
|
@ -284,18 +294,20 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_recall_no_match() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_count() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core)
|
||||
mem.store("a", "first", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let count = mem.count().await.unwrap();
|
||||
|
|
@ -305,24 +317,24 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_list_by_category() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "core fact", MemoryCategory::Core)
|
||||
mem.store("a", "core fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "daily note", MemoryCategory::Daily)
|
||||
mem.store("b", "daily note", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_forget_is_noop() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "permanent", MemoryCategory::Core)
|
||||
mem.store("a", "permanent", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let removed = mem.forget("a").await.unwrap();
|
||||
|
|
@ -332,7 +344,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_empty_recall() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
let results = mem.recall("anything", 10).await.unwrap();
|
||||
let results = mem.recall("anything", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,11 +25,17 @@ impl Memory for NoneMemory {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -37,7 +43,11 @@ impl Memory for NoneMemory {
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -62,11 +72,14 @@ mod tests {
|
|||
async fn none_memory_is_noop() {
|
||||
let memory = NoneMemory::new();
|
||||
|
||||
memory.store("k", "v", MemoryCategory::Core).await.unwrap();
|
||||
memory
|
||||
.store("k", "v", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(memory.get("k").await.unwrap().is_none());
|
||||
assert!(memory.recall("k", 10).await.unwrap().is_empty());
|
||||
assert!(memory.list(None).await.unwrap().is_empty());
|
||||
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
|
||||
assert!(memory.list(None, None).await.unwrap().is_empty());
|
||||
assert!(!memory.forget("k").await.unwrap());
|
||||
assert_eq!(memory.count().await.unwrap(), 0);
|
||||
assert!(memory.health_check().await);
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ impl ResponseCache {
|
|||
|row| row.get(0),
|
||||
)?;
|
||||
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok((count as usize, hits as u64, tokens_saved as u64))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,19 @@ impl SqliteMemory {
|
|||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
|
||||
)?;
|
||||
|
||||
// Migration: add session_id column if not present (safe to run repeatedly)
|
||||
let has_session_id: bool = conn
|
||||
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
|
||||
.query_row([], |row| row.get::<_, String>(0))?
|
||||
.contains("session_id");
|
||||
if !has_session_id {
|
||||
conn.execute_batch(
|
||||
"ALTER TABLE memories ADD COLUMN session_id TEXT;
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -361,6 +374,7 @@ impl Memory for SqliteMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Compute embedding (async, before lock)
|
||||
let embedding_bytes = self
|
||||
|
|
@ -377,20 +391,26 @@ impl Memory for SqliteMemory {
|
|||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
embedding = excluded.embedding,
|
||||
updated_at = excluded.updated_at",
|
||||
params![id, key, content, cat, embedding_bytes, now, now],
|
||||
updated_at = excluded.updated_at,
|
||||
session_id = excluded.session_id",
|
||||
params![id, key, content, cat, embedding_bytes, now, now, session_id],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if query.trim().is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
|
@ -439,7 +459,7 @@ impl Memory for SqliteMemory {
|
|||
let mut results = Vec::new();
|
||||
for scored in &merged {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
||||
)?;
|
||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||
Ok(MemoryEntry {
|
||||
|
|
@ -448,10 +468,16 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(f64::from(scored.final_score)),
|
||||
})
|
||||
}) {
|
||||
// Filter by session_id if requested
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
|
|
@ -470,7 +496,7 @@ impl Memory for SqliteMemory {
|
|||
.collect();
|
||||
let where_clause = conditions.join(" OR ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE {where_clause}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{}",
|
||||
|
|
@ -493,12 +519,18 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(1.0),
|
||||
})
|
||||
})?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -514,7 +546,7 @@ impl Memory for SqliteMemory {
|
|||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||
)?;
|
||||
|
||||
let mut rows = stmt.query_map(params![key], |row| {
|
||||
|
|
@ -524,7 +556,7 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
})?;
|
||||
|
|
@ -535,7 +567,11 @@ impl Memory for SqliteMemory {
|
|||
}
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
|
|
@ -550,7 +586,7 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
};
|
||||
|
|
@ -558,21 +594,33 @@ impl Memory for SqliteMemory {
|
|||
if let Some(cat) = category {
|
||||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map([], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -632,7 +680,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_store_and_get() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -647,10 +695,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_store_upsert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -662,17 +710,22 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c",
|
||||
"Rust has zero-cost abstractions",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
|
|
@ -682,14 +735,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_multi_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("fast safe", 10).await.unwrap();
|
||||
let results = mem.recall("fast safe", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// Entry with both keywords should score higher
|
||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||
|
|
@ -698,17 +751,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_no_match() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core)
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_forget() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation)
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
|
@ -728,29 +781,37 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_list_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation)
|
||||
mem.store("a", "one", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let all = mem.list(None).await.unwrap();
|
||||
let all = mem.list(None, None).await.unwrap();
|
||||
assert_eq!(all.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_list_by_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily)
|
||||
mem.store("a", "core1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
assert_eq!(core.len(), 2);
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
assert_eq!(daily.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -772,7 +833,7 @@ mod tests {
|
|||
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core)
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -795,7 +856,7 @@ mod tests {
|
|||
];
|
||||
|
||||
for (i, cat) in categories.iter().enumerate() {
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -815,21 +876,28 @@ mod tests {
|
|||
"a",
|
||||
"Rust is a systems programming language",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"b",
|
||||
"Python is great for scripting",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c",
|
||||
"Rust and Rust and Rust everywhere",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
// All results should contain "Rust"
|
||||
for r in &results {
|
||||
|
|
@ -844,17 +912,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_multi_word_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
|
||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
|
||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
|
||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("quick dog", 10).await.unwrap();
|
||||
let results = mem.recall("quick dog", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// "The quick dog runs fast" matches both terms
|
||||
assert!(results[0].content.contains("quick"));
|
||||
|
|
@ -863,16 +931,20 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_empty_query_returns_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
||||
let results = mem.recall("", 10).await.unwrap();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_whitespace_query_returns_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
||||
let results = mem.recall(" ", 10).await.unwrap();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall(" ", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
|
@ -937,9 +1009,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_insert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"test_key",
|
||||
"unique_searchterm_xyz",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let conn = mem.conn.lock();
|
||||
let count: i64 = conn
|
||||
|
|
@ -955,9 +1032,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_delete() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"del_key",
|
||||
"deletable_content_abc",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("del_key").await.unwrap();
|
||||
|
||||
let conn = mem.conn.lock();
|
||||
|
|
@ -974,10 +1056,15 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_update() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("upd_key", "original_content_111", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
|
||||
mem.store(
|
||||
"upd_key",
|
||||
"original_content_111",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1019,10 +1106,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn reindex_rebuilds_fts() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core)
|
||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("r2", "reindex test beta", MemoryCategory::Core)
|
||||
mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1031,7 +1118,7 @@ mod tests {
|
|||
assert_eq!(count, 0);
|
||||
|
||||
// FTS should still work after rebuild
|
||||
let results = mem.recall("reindex", 10).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
}
|
||||
|
||||
|
|
@ -1045,12 +1132,13 @@ mod tests {
|
|||
&format!("k{i}"),
|
||||
&format!("common keyword item {i}"),
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let results = mem.recall("common keyword", 5).await.unwrap();
|
||||
let results = mem.recall("common keyword", 5, None).await.unwrap();
|
||||
assert!(results.len() <= 5);
|
||||
}
|
||||
|
||||
|
|
@ -1059,11 +1147,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_results_have_scores() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("s1", "scored result test", MemoryCategory::Core)
|
||||
mem.store("s1", "scored result test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("scored", 10).await.unwrap();
|
||||
let results = mem.recall("scored", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
for r in &results {
|
||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||
|
|
@ -1075,11 +1163,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_with_quotes_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("q1", "He said hello world", MemoryCategory::Core)
|
||||
mem.store("q1", "He said hello world", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Quotes in query should not crash FTS5
|
||||
let results = mem.recall("\"hello\"", 10).await.unwrap();
|
||||
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
|
||||
// May or may not match depending on FTS5 escaping, but must not error
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
|
@ -1087,31 +1175,34 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_with_asterisk_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core)
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("wild*", 10).await.unwrap();
|
||||
let results = mem.recall("wild*", 10, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_with_parentheses_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("p1", "function call test", MemoryCategory::Core)
|
||||
mem.store("p1", "function call test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("function()", 10).await.unwrap();
|
||||
let results = mem.recall("function()", 10, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_with_sql_injection_attempt() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("safe", "normal content", MemoryCategory::Core)
|
||||
mem.store("safe", "normal content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Should not crash or leak data
|
||||
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
|
||||
let results = mem
|
||||
.recall("'; DROP TABLE memories; --", 10, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
// Table should still exist
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
|
@ -1122,7 +1213,9 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_empty_content() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("empty", "", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("empty").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "");
|
||||
}
|
||||
|
|
@ -1130,7 +1223,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_empty_key() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("", "content for empty key", MemoryCategory::Core)
|
||||
mem.store("", "content for empty key", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("").await.unwrap().unwrap();
|
||||
|
|
@ -1141,7 +1234,7 @@ mod tests {
|
|||
async fn store_very_long_content() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let long_content = "x".repeat(100_000);
|
||||
mem.store("long", &long_content, MemoryCategory::Core)
|
||||
mem.store("long", &long_content, MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("long").await.unwrap().unwrap();
|
||||
|
|
@ -1151,9 +1244,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_unicode_and_emoji() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"emoji_key_🦀",
|
||||
"こんにちは 🚀 Ñoño",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
|
||||
}
|
||||
|
|
@ -1162,7 +1260,7 @@ mod tests {
|
|||
async fn store_content_with_newlines_and_tabs() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
||||
mem.store("whitespace", content, MemoryCategory::Core)
|
||||
mem.store("whitespace", content, MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
||||
|
|
@ -1174,11 +1272,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_single_character_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "x marks the spot", MemoryCategory::Core)
|
||||
mem.store("a", "x marks the spot", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Single char may not match FTS5 but LIKE fallback should work
|
||||
let results = mem.recall("x", 10).await.unwrap();
|
||||
let results = mem.recall("x", 10, None).await.unwrap();
|
||||
// Should not crash; may or may not find results
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
|
@ -1186,23 +1284,23 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_limit_zero() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "some content", MemoryCategory::Core)
|
||||
mem.store("a", "some content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("some", 0).await.unwrap();
|
||||
let results = mem.recall("some", 0, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_limit_one() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "matching content alpha", MemoryCategory::Core)
|
||||
mem.store("a", "matching content alpha", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core)
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("matching content", 1).await.unwrap();
|
||||
let results = mem.recall("matching content", 1, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -1213,21 +1311,22 @@ mod tests {
|
|||
"rust_preferences",
|
||||
"User likes systems programming",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||
let results = mem.recall("rust", 10).await.unwrap();
|
||||
let results = mem.recall("rust", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty(), "Should match by key");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_unicode_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("日本語", 10).await.unwrap();
|
||||
let results = mem.recall("日本語", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
}
|
||||
|
||||
|
|
@ -1238,7 +1337,9 @@ mod tests {
|
|||
let tmp = TempDir::new().unwrap();
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("k1", "v1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
// Open again — init_schema runs again on existing DB
|
||||
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
|
@ -1246,7 +1347,9 @@ mod tests {
|
|||
assert!(entry.is_some());
|
||||
assert_eq!(entry.unwrap().content, "v1");
|
||||
// Store more data — should work fine
|
||||
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
|
||||
mem2.store("k2", "v2", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem2.count().await.unwrap(), 2);
|
||||
}
|
||||
|
||||
|
|
@ -1264,11 +1367,16 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_then_recall_no_ghost_results() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"ghost",
|
||||
"phantom memory content",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("ghost").await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10).await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10, None).await.unwrap();
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"Deleted memory should not appear in recall"
|
||||
|
|
@ -1278,11 +1386,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_and_re_store_same_key() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("cycle", "version 1", MemoryCategory::Core)
|
||||
mem.store("cycle", "version 1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("cycle").await.unwrap();
|
||||
mem.store("cycle", "version 2", MemoryCategory::Core)
|
||||
mem.store("cycle", "version 2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("cycle").await.unwrap().unwrap();
|
||||
|
|
@ -1302,14 +1410,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn reindex_twice_is_safe() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("r1", "reindex data", MemoryCategory::Core)
|
||||
mem.store("r1", "reindex data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.reindex().await.unwrap();
|
||||
let count = mem.reindex().await.unwrap();
|
||||
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||
// Data should still be intact
|
||||
let results = mem.recall("reindex", 10).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -1363,18 +1471,28 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn list_custom_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c3", "other", MemoryCategory::Core)
|
||||
mem.store(
|
||||
"c1",
|
||||
"custom1",
|
||||
MemoryCategory::Custom("project".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c2",
|
||||
"custom2",
|
||||
MemoryCategory::Custom("project".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c3", "other", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let project = mem
|
||||
.list(Some(&MemoryCategory::Custom("project".into())))
|
||||
.list(Some(&MemoryCategory::Custom("project".into())), None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(project.len(), 2);
|
||||
|
|
@ -1383,7 +1501,122 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn list_empty_db() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let all = mem.list(None).await.unwrap();
|
||||
let all = mem.list(None, None).await.unwrap();
|
||||
assert!(all.is_empty());
|
||||
}
|
||||
|
||||
// ── Session isolation ─────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_and_recall_with_session_id() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "no session fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Recall with session-a filter returns only session-a entry
|
||||
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_no_session_filter_returns_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "gamma fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Recall without session filter returns all matching entries
|
||||
let results = mem.recall("fact", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cross_session_recall_isolation() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store(
|
||||
"secret",
|
||||
"session A secret data",
|
||||
MemoryCategory::Core,
|
||||
Some("sess-a"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Session B cannot see session A data
|
||||
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
|
||||
// Session A can see its own data
|
||||
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_with_session_filter() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k4", "none1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// List with session-a filter
|
||||
let results = mem.list(None, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|e| e.session_id.as_deref() == Some("sess-a")));
|
||||
|
||||
// List with session-a + category filter
|
||||
let results = mem
|
||||
.list(Some(&MemoryCategory::Core), Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn schema_migration_idempotent_on_reopen() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
// First open: creates schema + migration
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Second open: migration runs again but is idempotent
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
|
|||
/// Backend name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Store a memory entry
|
||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
||||
-> anyhow::Result<()>;
|
||||
/// Store a memory entry, optionally scoped to a session
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Recall memories matching a query (keyword search)
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Get a specific memory by key
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||
|
||||
/// List all memory keys, optionally filtered by category
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
/// List all memory keys, optionally filtered by category and/or session
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Remove a memory by key
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue