zeroclaw/src/memory/sqlite.rs
xj 65a12dd611 fix: resolve all clippy warnings and fix test compilation errors
Address clippy pedantic/all lints: format_push_string in sqlite memory,
match_same_arms and match_wildcard_for_single_variants in anthropic
provider and prometheus observer, option_as_ref_cloned in main. Fix
pre-existing test compilation errors in gateway (missing max_keys arg
and trust_forwarded_headers field) and memory_store (missing security
arg). Add .worktrees/ to gitignore.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-18 21:06:54 +08:00

1900 lines
65 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use super::embeddings::EmbeddingProvider;
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use super::vector;
use anyhow::Context;
use async_trait::async_trait;
use chrono::Local;
use parking_lot::Mutex;
use rusqlite::{params, Connection};
use std::fmt::Write as _;
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::sync::Arc;
use std::thread;
use std::time::Duration;
use uuid::Uuid;
/// Maximum allowed open timeout (seconds) to avoid unreasonable waits.
const SQLITE_OPEN_TIMEOUT_CAP_SECS: u64 = 300;
/// SQLite-backed persistent memory — the brain
///
/// Full-stack search engine:
/// - **Vector DB**: embeddings stored as BLOB, cosine similarity search
/// - **Keyword Search**: FTS5 virtual table with BM25 scoring
/// - **Hybrid Merge**: weighted fusion of vector + keyword results
/// - **Embedding Cache**: LRU-evicted cache to avoid redundant API calls
/// - **Safe Reindex**: temp DB → seed → sync → atomic swap → rollback
pub struct SqliteMemory {
conn: Arc<Mutex<Connection>>,
db_path: PathBuf,
embedder: Arc<dyn EmbeddingProvider>,
vector_weight: f32,
keyword_weight: f32,
cache_max: usize,
}
impl SqliteMemory {
pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
Self::with_embedder(
workspace_dir,
Arc::new(super::embeddings::NoopEmbedding),
0.7,
0.3,
10_000,
None,
)
}
/// Build SQLite memory with optional open timeout.
///
/// If `open_timeout_secs` is `Some(n)`, opening the database is limited to `n` seconds
/// (capped at 300). Useful when the DB file may be locked or on slow storage.
/// `None` = wait indefinitely (default).
pub fn with_embedder(
workspace_dir: &Path,
embedder: Arc<dyn EmbeddingProvider>,
vector_weight: f32,
keyword_weight: f32,
cache_max: usize,
open_timeout_secs: Option<u64>,
) -> anyhow::Result<Self> {
let db_path = workspace_dir.join("memory").join("brain.db");
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Self::open_connection(&db_path, open_timeout_secs)?;
// ── Production-grade PRAGMA tuning ──────────────────────
// WAL mode: concurrent reads during writes, crash-safe
// normal sync: 2× write speed, still durable on WAL
// mmap 8 MB: let the OS page-cache serve hot reads
// cache 2 MB: keep ~500 hot pages in-process
// temp_store memory: temp tables never hit disk
conn.execute_batch(
"PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA mmap_size = 8388608;
PRAGMA cache_size = -2000;
PRAGMA temp_store = MEMORY;",
)?;
Self::init_schema(&conn)?;
Ok(Self {
conn: Arc::new(Mutex::new(conn)),
db_path,
embedder,
vector_weight,
keyword_weight,
cache_max,
})
}
/// Open SQLite connection, optionally with a timeout (for locked/slow storage).
fn open_connection(
db_path: &Path,
open_timeout_secs: Option<u64>,
) -> anyhow::Result<Connection> {
let path_buf = db_path.to_path_buf();
let conn = if let Some(secs) = open_timeout_secs {
let capped = secs.min(SQLITE_OPEN_TIMEOUT_CAP_SECS);
let (tx, rx) = mpsc::channel();
thread::spawn(move || {
let result = Connection::open(&path_buf);
let _ = tx.send(result);
});
match rx.recv_timeout(Duration::from_secs(capped)) {
Ok(Ok(c)) => c,
Ok(Err(e)) => return Err(e).context("SQLite failed to open database"),
Err(mpsc::RecvTimeoutError::Timeout) => {
anyhow::bail!("SQLite connection open timed out after {} seconds", capped);
}
Err(mpsc::RecvTimeoutError::Disconnected) => {
anyhow::bail!("SQLite open thread exited unexpectedly");
}
}
} else {
Connection::open(&path_buf).context("SQLite failed to open database")?
};
Ok(conn)
}
/// Initialize all tables: memories, FTS5, `embedding_cache`
fn init_schema(conn: &Connection) -> anyhow::Result<()> {
conn.execute_batch(
"-- Core memories table
CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
key TEXT NOT NULL UNIQUE,
content TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'core',
embedding BLOB,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);
-- FTS5 full-text search (BM25 scoring)
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
key, content, content=memories, content_rowid=rowid
);
-- FTS5 triggers: keep in sync with memories table
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, key, content)
VALUES (new.rowid, new.key, new.content);
END;
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, key, content)
VALUES ('delete', old.rowid, old.key, old.content);
END;
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, key, content)
VALUES ('delete', old.rowid, old.key, old.content);
INSERT INTO memories_fts(rowid, key, content)
VALUES (new.rowid, new.key, new.content);
END;
-- Embedding cache with LRU eviction
CREATE TABLE IF NOT EXISTS embedding_cache (
content_hash TEXT PRIMARY KEY,
embedding BLOB NOT NULL,
created_at TEXT NOT NULL,
accessed_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
)?;
// Migration: add session_id column if not present (safe to run repeatedly)
let has_session_id: bool = conn
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
.query_row([], |row| row.get::<_, String>(0))?
.contains("session_id");
if !has_session_id {
conn.execute_batch(
"ALTER TABLE memories ADD COLUMN session_id TEXT;
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
)?;
}
Ok(())
}
fn category_to_str(cat: &MemoryCategory) -> String {
match cat {
MemoryCategory::Core => "core".into(),
MemoryCategory::Daily => "daily".into(),
MemoryCategory::Conversation => "conversation".into(),
MemoryCategory::Custom(name) => name.clone(),
}
}
fn str_to_category(s: &str) -> MemoryCategory {
match s {
"core" => MemoryCategory::Core,
"daily" => MemoryCategory::Daily,
"conversation" => MemoryCategory::Conversation,
other => MemoryCategory::Custom(other.to_string()),
}
}
/// Deterministic content hash for embedding cache.
/// Uses SHA-256 (truncated) instead of DefaultHasher, which is
/// explicitly documented as unstable across Rust versions.
fn content_hash(text: &str) -> String {
use sha2::{Digest, Sha256};
let hash = Sha256::digest(text.as_bytes());
// First 8 bytes → 16 hex chars, matching previous format length
format!(
"{:016x}",
u64::from_be_bytes(
hash[..8]
.try_into()
.expect("SHA-256 always produces >= 8 bytes")
)
)
}
/// Get embedding from cache, or compute + cache it
async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
if self.embedder.dimensions() == 0 {
return Ok(None); // Noop embedder
}
let hash = Self::content_hash(text);
let now = Local::now().to_rfc3339();
// Check cache (offloaded to blocking thread)
let conn = self.conn.clone();
let hash_c = hash.clone();
let now_c = now.clone();
let cached = tokio::task::spawn_blocking(move || -> anyhow::Result<Option<Vec<f32>>> {
let conn = conn.lock();
let mut stmt =
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
let blob: Option<Vec<u8>> = stmt.query_row(params![hash_c], |row| row.get(0)).ok();
if let Some(bytes) = blob {
conn.execute(
"UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2",
params![now_c, hash_c],
)?;
return Ok(Some(vector::bytes_to_vec(&bytes)));
}
Ok(None)
})
.await??;
if cached.is_some() {
return Ok(cached);
}
// Compute embedding (async I/O)
let embedding = self.embedder.embed_one(text).await?;
let bytes = vector::vec_to_bytes(&embedding);
// Store in cache + LRU eviction (offloaded to blocking thread)
let conn = self.conn.clone();
#[allow(clippy::cast_possible_wrap)]
let cache_max = self.cache_max as i64;
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let conn = conn.lock();
conn.execute(
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
VALUES (?1, ?2, ?3, ?4)",
params![hash, bytes, now, now],
)?;
conn.execute(
"DELETE FROM embedding_cache WHERE content_hash IN (
SELECT content_hash FROM embedding_cache
ORDER BY accessed_at ASC
LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1)
)",
params![cache_max],
)?;
Ok(())
})
.await??;
Ok(Some(embedding))
}
/// FTS5 BM25 keyword search
fn fts5_search(
conn: &Connection,
query: &str,
limit: usize,
) -> anyhow::Result<Vec<(String, f32)>> {
// Escape FTS5 special chars and build query
let fts_query: String = query
.split_whitespace()
.map(|w| format!("\"{w}\""))
.collect::<Vec<_>>()
.join(" OR ");
if fts_query.is_empty() {
return Ok(Vec::new());
}
let sql = "SELECT m.id, bm25(memories_fts) as score
FROM memories_fts f
JOIN memories m ON m.rowid = f.rowid
WHERE memories_fts MATCH ?1
ORDER BY score
LIMIT ?2";
let mut stmt = conn.prepare(sql)?;
#[allow(clippy::cast_possible_wrap)]
let limit_i64 = limit as i64;
let rows = stmt.query_map(params![fts_query, limit_i64], |row| {
let id: String = row.get(0)?;
let score: f64 = row.get(1)?;
// BM25 returns negative scores (lower = better), negate for ranking
#[allow(clippy::cast_possible_truncation)]
Ok((id, (-score) as f32))
})?;
let mut results = Vec::new();
for row in rows {
results.push(row?);
}
Ok(results)
}
/// Vector similarity search: scan embeddings and compute cosine similarity.
///
/// Optional `category` and `session_id` filters reduce full-table scans
/// when the caller already knows the scope of relevant memories.
fn vector_search(
conn: &Connection,
query_embedding: &[f32],
limit: usize,
category: Option<&str>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<(String, f32)>> {
let mut sql = "SELECT id, embedding FROM memories WHERE embedding IS NOT NULL".to_string();
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
let mut idx = 1;
if let Some(cat) = category {
let _ = write!(sql, " AND category = ?{idx}");
param_values.push(Box::new(cat.to_string()));
idx += 1;
}
if let Some(sid) = session_id {
let _ = write!(sql, " AND session_id = ?{idx}");
param_values.push(Box::new(sid.to_string()));
}
let mut stmt = conn.prepare(&sql)?;
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
param_values.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
let id: String = row.get(0)?;
let blob: Vec<u8> = row.get(1)?;
Ok((id, blob))
})?;
let mut scored: Vec<(String, f32)> = Vec::new();
for row in rows {
let (id, blob) = row?;
let emb = vector::bytes_to_vec(&blob);
let sim = vector::cosine_similarity(query_embedding, &emb);
if sim > 0.0 {
scored.push((id, sim));
}
}
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
Ok(scored)
}
/// Safe reindex: rebuild FTS5 + embeddings with rollback on failure
#[allow(dead_code)]
pub async fn reindex(&self) -> anyhow::Result<usize> {
// Step 1: Rebuild FTS5
{
let conn = self.conn.clone();
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let conn = conn.lock();
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
Ok(())
})
.await??;
}
// Step 2: Re-embed all memories that lack embeddings
if self.embedder.dimensions() == 0 {
return Ok(0);
}
let conn = self.conn.clone();
let entries: Vec<(String, String)> = tokio::task::spawn_blocking(move || {
let conn = conn.lock();
let mut stmt =
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
let rows = stmt.query_map([], |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
})?;
Ok::<_, anyhow::Error>(rows.filter_map(std::result::Result::ok).collect())
})
.await??;
let mut count = 0;
for (id, content) in &entries {
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
let bytes = vector::vec_to_bytes(&emb);
let conn = self.conn.clone();
let id = id.clone();
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let conn = conn.lock();
conn.execute(
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
params![bytes, id],
)?;
Ok(())
})
.await??;
count += 1;
}
}
Ok(count)
}
}
#[async_trait]
impl Memory for SqliteMemory {
fn name(&self) -> &str {
"sqlite"
}
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()> {
// Compute embedding (async, before blocking work)
let embedding_bytes = self
.get_or_compute_embedding(content)
.await?
.map(|emb| vector::vec_to_bytes(&emb));
let conn = self.conn.clone();
let key = key.to_string();
let content = content.to_string();
let session_id = session_id.map(String::from);
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
let conn = conn.lock();
let now = Local::now().to_rfc3339();
let cat = Self::category_to_str(&category);
let id = Uuid::new_v4().to_string();
conn.execute(
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
embedding = excluded.embedding,
updated_at = excluded.updated_at,
session_id = excluded.session_id",
params![id, key, content, cat, embedding_bytes, now, now, session_id],
)?;
Ok(())
})
.await?
}
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
if query.trim().is_empty() {
return Ok(Vec::new());
}
// Compute query embedding (async, before blocking work)
let query_embedding = self.get_or_compute_embedding(query).await?;
let conn = self.conn.clone();
let query = query.to_string();
let session_id = session_id.map(String::from);
let vector_weight = self.vector_weight;
let keyword_weight = self.keyword_weight;
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
let conn = conn.lock();
let session_ref = session_id.as_deref();
// FTS5 BM25 keyword search
let keyword_results = Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default();
// Vector similarity search (if embeddings available)
let vector_results = if let Some(ref qe) = query_embedding {
Self::vector_search(&conn, qe, limit * 2, None, session_ref).unwrap_or_default()
} else {
Vec::new()
};
// Hybrid merge
let merged = if vector_results.is_empty() {
keyword_results
.iter()
.map(|(id, score)| vector::ScoredResult {
id: id.clone(),
vector_score: None,
keyword_score: Some(*score),
final_score: *score,
})
.collect::<Vec<_>>()
} else {
vector::hybrid_merge(
&vector_results,
&keyword_results,
vector_weight,
keyword_weight,
limit,
)
};
// Fetch full entries for merged results in a single query
// instead of N round-trips (N+1 pattern).
let mut results = Vec::new();
if !merged.is_empty() {
let placeholders: String = (1..=merged.len())
.map(|i| format!("?{i}"))
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"SELECT id, key, content, category, created_at, session_id \
FROM memories WHERE id IN ({placeholders})"
);
let mut stmt = conn.prepare(&sql)?;
let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
.iter()
.map(|s| Box::new(s.id.clone()) as Box<dyn rusqlite::types::ToSql>)
.collect();
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
id_params.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
row.get::<_, Option<String>>(5)?,
))
})?;
let mut entry_map = std::collections::HashMap::new();
for row in rows {
let (id, key, content, cat, ts, sid) = row?;
entry_map.insert(id, (key, content, cat, ts, sid));
}
for scored in &merged {
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
let entry = MemoryEntry {
id: scored.id.clone(),
key,
content,
category: Self::str_to_category(&cat),
timestamp: ts,
session_id: sid,
score: Some(f64::from(scored.final_score)),
};
if let Some(filter_sid) = session_ref {
if entry.session_id.as_deref() != Some(filter_sid) {
continue;
}
}
results.push(entry);
}
}
}
// If hybrid returned nothing, fall back to LIKE search.
// Cap keyword count so we don't create too many SQL shapes,
// which helps prepared-statement cache efficiency.
if results.is_empty() {
const MAX_LIKE_KEYWORDS: usize = 8;
let keywords: Vec<String> = query
.split_whitespace()
.take(MAX_LIKE_KEYWORDS)
.map(|w| format!("%{w}%"))
.collect();
if !keywords.is_empty() {
let conditions: Vec<String> = keywords
.iter()
.enumerate()
.map(|(i, _)| {
format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2)
})
.collect();
let where_clause = conditions.join(" OR ");
let sql = format!(
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE {where_clause}
ORDER BY updated_at DESC
LIMIT ?{}",
keywords.len() * 2 + 1
);
let mut stmt = conn.prepare(&sql)?;
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
for kw in &keywords {
param_values.push(Box::new(kw.clone()));
param_values.push(Box::new(kw.clone()));
}
#[allow(clippy::cast_possible_wrap)]
param_values.push(Box::new(limit as i64));
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
param_values.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: Some(1.0),
})
})?;
for row in rows {
let entry = row?;
if let Some(sid) = session_ref {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
}
}
}
results.truncate(limit);
Ok(results)
})
.await?
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
let conn = self.conn.clone();
let key = key.to_string();
tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
let conn = conn.lock();
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
)?;
let mut rows = stmt.query_map(params![key], |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: None,
})
})?;
match rows.next() {
Some(Ok(entry)) => Ok(Some(entry)),
_ => Ok(None),
}
})
.await?
}
async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
const DEFAULT_LIST_LIMIT: i64 = 1000;
let conn = self.conn.clone();
let category = category.cloned();
let session_id = session_id.map(String::from);
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
let conn = conn.lock();
let session_ref = session_id.as_deref();
let mut results = Vec::new();
let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: row.get(5)?,
score: None,
})
};
if let Some(ref cat) = category {
let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
)?;
let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows {
let entry = row?;
if let Some(sid) = session_ref {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
}
} else {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories
ORDER BY updated_at DESC LIMIT ?1",
)?;
let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows {
let entry = row?;
if let Some(sid) = session_ref {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
}
}
Ok(results)
})
.await?
}
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
let conn = self.conn.clone();
let key = key.to_string();
tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
let conn = conn.lock();
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
Ok(affected > 0)
})
.await?
}
async fn count(&self) -> anyhow::Result<usize> {
let conn = self.conn.clone();
tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
let conn = conn.lock();
let count: i64 =
conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok(count as usize)
})
.await?
}
async fn health_check(&self) -> bool {
let conn = self.conn.clone();
tokio::task::spawn_blocking(move || conn.lock().execute_batch("SELECT 1").is_ok())
.await
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn temp_sqlite() -> (TempDir, SqliteMemory) {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
(tmp, mem)
}
#[tokio::test]
async fn sqlite_name() {
let (_tmp, mem) = temp_sqlite();
assert_eq!(mem.name(), "sqlite");
}
#[tokio::test]
async fn sqlite_health() {
let (_tmp, mem) = temp_sqlite();
assert!(mem.health_check().await);
}
#[tokio::test]
async fn sqlite_store_and_get() {
let (_tmp, mem) = temp_sqlite();
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("user_lang").await.unwrap();
assert!(entry.is_some());
let entry = entry.unwrap();
assert_eq!(entry.key, "user_lang");
assert_eq!(entry.content, "Prefers Rust");
assert_eq!(entry.category, MemoryCategory::Core);
}
#[tokio::test]
async fn sqlite_store_upsert() {
let (_tmp, mem) = temp_sqlite();
mem.store("pref", "likes Rust", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("pref", "loves Rust", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("pref").await.unwrap().unwrap();
assert_eq!(entry.content, "loves Rust");
assert_eq!(mem.count().await.unwrap(), 1);
}
#[tokio::test]
async fn sqlite_recall_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
.await
.unwrap();
mem.store(
"c",
"Rust has zero-cost abstractions",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
.all(|r| r.content.to_lowercase().contains("rust")));
}
#[tokio::test]
async fn sqlite_recall_multi_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("fast safe", 10, None).await.unwrap();
assert!(!results.is_empty());
// Entry with both keywords should score higher
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
}
#[tokio::test]
async fn sqlite_recall_no_match() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("javascript", 10, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn sqlite_forget() {
let (_tmp, mem) = temp_sqlite();
mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 1);
let removed = mem.forget("temp").await.unwrap();
assert!(removed);
assert_eq!(mem.count().await.unwrap(), 0);
}
#[tokio::test]
async fn sqlite_forget_nonexistent() {
let (_tmp, mem) = temp_sqlite();
let removed = mem.forget("nope").await.unwrap();
assert!(!removed);
}
#[tokio::test]
async fn sqlite_list_all() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "one", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "two", MemoryCategory::Daily, None)
.await
.unwrap();
mem.store("c", "three", MemoryCategory::Conversation, None)
.await
.unwrap();
let all = mem.list(None, None).await.unwrap();
assert_eq!(all.len(), 3);
}
#[tokio::test]
async fn sqlite_list_by_category() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "core1", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "core2", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "daily1", MemoryCategory::Daily, None)
.await
.unwrap();
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
assert_eq!(core.len(), 2);
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
assert_eq!(daily.len(), 1);
}
#[tokio::test]
async fn sqlite_count_empty() {
let (_tmp, mem) = temp_sqlite();
assert_eq!(mem.count().await.unwrap(), 0);
}
#[tokio::test]
async fn sqlite_get_nonexistent() {
let (_tmp, mem) = temp_sqlite();
assert!(mem.get("nope").await.unwrap().is_none());
}
#[tokio::test]
async fn sqlite_db_persists() {
let tmp = TempDir::new().unwrap();
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
.await
.unwrap();
}
// Reopen
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
let entry = mem2.get("persist").await.unwrap();
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "I survive restarts");
}
#[tokio::test]
async fn sqlite_category_roundtrip() {
let (_tmp, mem) = temp_sqlite();
let categories = [
MemoryCategory::Core,
MemoryCategory::Daily,
MemoryCategory::Conversation,
MemoryCategory::Custom("project".into()),
];
for (i, cat) in categories.iter().enumerate() {
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
.await
.unwrap();
}
for (i, cat) in categories.iter().enumerate() {
let entry = mem.get(&format!("k{i}")).await.unwrap().unwrap();
assert_eq!(&entry.category, cat);
}
}
// ── FTS5 search tests ────────────────────────────────────────
#[tokio::test]
async fn fts5_bm25_ranking() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"a",
"Rust is a systems programming language",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store(
"b",
"Python is great for scripting",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store(
"c",
"Rust and Rust and Rust everywhere",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
assert!(results.len() >= 2);
// All results should contain "Rust"
for r in &results {
assert!(
r.content.to_lowercase().contains("rust"),
"Expected 'rust' in: {}",
r.content
);
}
}
#[tokio::test]
async fn fts5_multi_word_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("quick dog", 10, None).await.unwrap();
assert!(!results.is_empty());
// "The quick dog runs fast" matches both terms
assert!(results[0].content.contains("quick"));
}
#[tokio::test]
async fn recall_empty_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("", 10, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_whitespace_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall(" ", 10, None).await.unwrap();
assert!(results.is_empty());
}
// ── Embedding cache tests ────────────────────────────────────
#[test]
fn content_hash_deterministic() {
let h1 = SqliteMemory::content_hash("hello world");
let h2 = SqliteMemory::content_hash("hello world");
assert_eq!(h1, h2);
}
#[test]
fn content_hash_different_inputs() {
let h1 = SqliteMemory::content_hash("hello");
let h2 = SqliteMemory::content_hash("world");
assert_ne!(h1, h2);
}
// ── Schema tests ─────────────────────────────────────────────
#[tokio::test]
async fn schema_has_fts5_table() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock();
// FTS5 table should exist
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories_fts'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 1);
}
#[tokio::test]
async fn schema_has_embedding_cache() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 1);
}
#[tokio::test]
async fn schema_memories_has_embedding_column() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock();
// Check that embedding column exists by querying it
let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
assert!(result.is_ok());
}
// ── FTS5 sync trigger tests ──────────────────────────────────
#[tokio::test]
async fn fts5_syncs_on_insert() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"test_key",
"unique_searchterm_xyz",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let conn = mem.conn.lock();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 1);
}
#[tokio::test]
async fn fts5_syncs_on_delete() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"del_key",
"deletable_content_abc",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.forget("del_key").await.unwrap();
let conn = mem.conn.lock();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 0);
}
#[tokio::test]
async fn fts5_syncs_on_update() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"upd_key",
"original_content_111",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
.await
.unwrap();
let conn = mem.conn.lock();
// Old content should not be findable
let old: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"original_content_111\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(old, 0);
// New content should be findable
let new: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"updated_content_222\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(new, 1);
}
// ── Open timeout tests ────────────────────────────────────────
#[test]
fn open_with_timeout_succeeds_when_fast() {
let tmp = TempDir::new().unwrap();
let embedder = Arc::new(super::super::embeddings::NoopEmbedding);
let mem = SqliteMemory::with_embedder(tmp.path(), embedder, 0.7, 0.3, 1000, Some(5));
assert!(
mem.is_ok(),
"open with 5s timeout should succeed on fast path"
);
assert_eq!(mem.unwrap().name(), "sqlite");
}
#[tokio::test]
async fn open_with_timeout_store_recall_unchanged() {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::with_embedder(
tmp.path(),
Arc::new(super::super::embeddings::NoopEmbedding),
0.7,
0.3,
1000,
Some(2),
)
.unwrap();
mem.store(
"timeout_key",
"value with timeout",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entry = mem.get("timeout_key").await.unwrap().unwrap();
assert_eq!(entry.content, "value with timeout");
}
// ── With-embedder constructor test ───────────────────────────
#[test]
fn with_embedder_noop() {
let tmp = TempDir::new().unwrap();
let embedder = Arc::new(super::super::embeddings::NoopEmbedding);
let mem = SqliteMemory::with_embedder(tmp.path(), embedder, 0.7, 0.3, 1000, None);
assert!(mem.is_ok());
assert_eq!(mem.unwrap().name(), "sqlite");
}
// ── Reindex test ─────────────────────────────────────────────
#[tokio::test]
async fn reindex_rebuilds_fts() {
let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
.await
.unwrap();
// Reindex should succeed (noop embedder → 0 re-embedded)
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0);
// FTS should still work after rebuild
let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 2);
}
// ── Recall limit test ────────────────────────────────────────
#[tokio::test]
async fn recall_respects_limit() {
let (_tmp, mem) = temp_sqlite();
for i in 0..20 {
mem.store(
&format!("k{i}"),
&format!("common keyword item {i}"),
MemoryCategory::Core,
None,
)
.await
.unwrap();
}
let results = mem.recall("common keyword", 5, None).await.unwrap();
assert!(results.len() <= 5);
}
// ── Score presence test ──────────────────────────────────────
#[tokio::test]
async fn recall_results_have_scores() {
let (_tmp, mem) = temp_sqlite();
mem.store("s1", "scored result test", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("scored", 10, None).await.unwrap();
assert!(!results.is_empty());
for r in &results {
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
}
}
// ── Edge cases: FTS5 special characters ──────────────────────
#[tokio::test]
async fn recall_with_quotes_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("q1", "He said hello world", MemoryCategory::Core, None)
.await
.unwrap();
// Quotes in query should not crash FTS5
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
// May or may not match depending on FTS5 escaping, but must not error
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_asterisk_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("wild*", 10, None).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_parentheses_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("p1", "function call test", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("function()", 10, None).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_sql_injection_attempt() {
let (_tmp, mem) = temp_sqlite();
mem.store("safe", "normal content", MemoryCategory::Core, None)
.await
.unwrap();
// Should not crash or leak data
let results = mem
.recall("'; DROP TABLE memories; --", 10, None)
.await
.unwrap();
assert!(results.len() <= 10);
// Table should still exist
assert_eq!(mem.count().await.unwrap(), 1);
}
// ── Edge cases: store ────────────────────────────────────────
#[tokio::test]
async fn store_empty_content() {
let (_tmp, mem) = temp_sqlite();
mem.store("empty", "", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("empty").await.unwrap().unwrap();
assert_eq!(entry.content, "");
}
#[tokio::test]
async fn store_empty_key() {
let (_tmp, mem) = temp_sqlite();
mem.store("", "content for empty key", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("").await.unwrap().unwrap();
assert_eq!(entry.content, "content for empty key");
}
#[tokio::test]
async fn store_very_long_content() {
let (_tmp, mem) = temp_sqlite();
let long_content = "x".repeat(100_000);
mem.store("long", &long_content, MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("long").await.unwrap().unwrap();
assert_eq!(entry.content.len(), 100_000);
}
#[tokio::test]
async fn store_unicode_and_emoji() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"emoji_key_🦀",
"こんにちは 🚀 Ñoño",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
}
#[tokio::test]
async fn store_content_with_newlines_and_tabs() {
let (_tmp, mem) = temp_sqlite();
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
mem.store("whitespace", content, MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("whitespace").await.unwrap().unwrap();
assert_eq!(entry.content, content);
}
// ── Edge cases: recall ───────────────────────────────────────
#[tokio::test]
async fn recall_single_character_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "x marks the spot", MemoryCategory::Core, None)
.await
.unwrap();
// Single char may not match FTS5 but LIKE fallback should work
let results = mem.recall("x", 10, None).await.unwrap();
// Should not crash; may or may not find results
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_limit_zero() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "some content", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("some", 0, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_limit_one() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "matching content alpha", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "matching content beta", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("matching content", 1, None).await.unwrap();
assert_eq!(results.len(), 1);
}
#[tokio::test]
async fn recall_matches_by_key_not_just_content() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"rust_preferences",
"User likes systems programming",
MemoryCategory::Core,
None,
)
.await
.unwrap();
// "rust" appears in key but not content — LIKE fallback checks key too
let results = mem.recall("rust", 10, None).await.unwrap();
assert!(!results.is_empty(), "Should match by key");
}
#[tokio::test]
async fn recall_unicode_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("日本語", 10, None).await.unwrap();
assert!(!results.is_empty());
}
// ── Edge cases: schema idempotency ───────────────────────────
#[tokio::test]
async fn schema_idempotent_reopen() {
let tmp = TempDir::new().unwrap();
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("k1", "v1", MemoryCategory::Core, None)
.await
.unwrap();
}
// Open again — init_schema runs again on existing DB
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
let entry = mem2.get("k1").await.unwrap();
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "v1");
// Store more data — should work fine
mem2.store("k2", "v2", MemoryCategory::Daily, None)
.await
.unwrap();
assert_eq!(mem2.count().await.unwrap(), 2);
}
#[tokio::test]
async fn schema_triple_open() {
let tmp = TempDir::new().unwrap();
let _m1 = SqliteMemory::new(tmp.path()).unwrap();
let _m2 = SqliteMemory::new(tmp.path()).unwrap();
let m3 = SqliteMemory::new(tmp.path()).unwrap();
assert!(m3.health_check().await);
}
// ── Edge cases: forget + FTS5 consistency ────────────────────
#[tokio::test]
async fn forget_then_recall_no_ghost_results() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"ghost",
"phantom memory content",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.forget("ghost").await.unwrap();
let results = mem.recall("phantom memory", 10, None).await.unwrap();
assert!(
results.is_empty(),
"Deleted memory should not appear in recall"
);
}
#[tokio::test]
async fn forget_and_re_store_same_key() {
let (_tmp, mem) = temp_sqlite();
mem.store("cycle", "version 1", MemoryCategory::Core, None)
.await
.unwrap();
mem.forget("cycle").await.unwrap();
mem.store("cycle", "version 2", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("cycle").await.unwrap().unwrap();
assert_eq!(entry.content, "version 2");
assert_eq!(mem.count().await.unwrap(), 1);
}
// ── Edge cases: reindex ──────────────────────────────────────
#[tokio::test]
async fn reindex_empty_db() {
let (_tmp, mem) = temp_sqlite();
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0);
}
#[tokio::test]
async fn reindex_twice_is_safe() {
let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex data", MemoryCategory::Core, None)
.await
.unwrap();
mem.reindex().await.unwrap();
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
// Data should still be intact
let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 1);
}
// ── Edge cases: content_hash ─────────────────────────────────
#[test]
fn content_hash_empty_string() {
let h = SqliteMemory::content_hash("");
assert!(!h.is_empty());
assert_eq!(h.len(), 16); // 16 hex chars
}
#[test]
fn content_hash_unicode() {
let h1 = SqliteMemory::content_hash("🦀");
let h2 = SqliteMemory::content_hash("🦀");
assert_eq!(h1, h2);
let h3 = SqliteMemory::content_hash("🚀");
assert_ne!(h1, h3);
}
#[test]
fn content_hash_long_input() {
let long = "a".repeat(1_000_000);
let h = SqliteMemory::content_hash(&long);
assert_eq!(h.len(), 16);
}
// ── Edge cases: category helpers ─────────────────────────────
#[test]
fn category_roundtrip_custom_with_spaces() {
let cat = MemoryCategory::Custom("my custom category".into());
let s = SqliteMemory::category_to_str(&cat);
assert_eq!(s, "my custom category");
let back = SqliteMemory::str_to_category(&s);
assert_eq!(back, cat);
}
#[test]
fn category_roundtrip_empty_custom() {
let cat = MemoryCategory::Custom(String::new());
let s = SqliteMemory::category_to_str(&cat);
assert_eq!(s, "");
let back = SqliteMemory::str_to_category(&s);
assert_eq!(back, MemoryCategory::Custom(String::new()));
}
// ── Edge cases: list ─────────────────────────────────────────
#[tokio::test]
async fn list_custom_category() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"c1",
"custom1",
MemoryCategory::Custom("project".into()),
None,
)
.await
.unwrap();
mem.store(
"c2",
"custom2",
MemoryCategory::Custom("project".into()),
None,
)
.await
.unwrap();
mem.store("c3", "other", MemoryCategory::Core, None)
.await
.unwrap();
let project = mem
.list(Some(&MemoryCategory::Custom("project".into())), None)
.await
.unwrap();
assert_eq!(project.len(), 2);
}
#[tokio::test]
async fn list_empty_db() {
let (_tmp, mem) = temp_sqlite();
let all = mem.list(None, None).await.unwrap();
assert!(all.is_empty());
}
// ── Session isolation ─────────────────────────────────────────
#[tokio::test]
async fn store_and_recall_with_session_id() {
let (_tmp, mem) = temp_sqlite();
mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
.await
.unwrap();
mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
.await
.unwrap();
mem.store("k3", "no session fact", MemoryCategory::Core, None)
.await
.unwrap();
// Recall with session-a filter returns only session-a entry
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
}
#[tokio::test]
async fn recall_no_session_filter_returns_all() {
let (_tmp, mem) = temp_sqlite();
mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
.await
.unwrap();
mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
.await
.unwrap();
mem.store("k3", "gamma fact", MemoryCategory::Core, None)
.await
.unwrap();
// Recall without session filter returns all matching entries
let results = mem.recall("fact", 10, None).await.unwrap();
assert_eq!(results.len(), 3);
}
#[tokio::test]
async fn cross_session_recall_isolation() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"secret",
"session A secret data",
MemoryCategory::Core,
Some("sess-a"),
)
.await
.unwrap();
// Session B cannot see session A data
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
assert!(results.is_empty());
// Session A can see its own data
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
assert_eq!(results.len(), 1);
}
#[tokio::test]
async fn list_with_session_filter() {
let (_tmp, mem) = temp_sqlite();
mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
.await
.unwrap();
mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
.await
.unwrap();
mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
.await
.unwrap();
mem.store("k4", "none1", MemoryCategory::Core, None)
.await
.unwrap();
// List with session-a filter
let results = mem.list(None, Some("sess-a")).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
.all(|e| e.session_id.as_deref() == Some("sess-a")));
// List with session-a + category filter
let results = mem
.list(Some(&MemoryCategory::Core), Some("sess-a"))
.await
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
}
#[tokio::test]
async fn schema_migration_idempotent_on_reopen() {
let tmp = TempDir::new().unwrap();
// First open: creates schema + migration
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
.await
.unwrap();
}
// Second open: migration runs again but is idempotent
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
}
}
// ── §4.1 Concurrent write contention tests ──────────────
#[tokio::test]
async fn sqlite_concurrent_writes_no_data_loss() {
let (_tmp, mem) = temp_sqlite();
let mem = std::sync::Arc::new(mem);
let mut handles = Vec::new();
for i in 0..10 {
let mem = std::sync::Arc::clone(&mem);
handles.push(tokio::spawn(async move {
mem.store(
&format!("concurrent_key_{i}"),
&format!("value_{i}"),
MemoryCategory::Core,
None,
)
.await
.unwrap();
}));
}
for handle in handles {
handle.await.unwrap();
}
let count = mem.count().await.unwrap();
assert_eq!(
count, 10,
"all 10 concurrent writes must succeed without data loss"
);
}
#[tokio::test]
async fn sqlite_concurrent_read_write_no_panic() {
let (_tmp, mem) = temp_sqlite();
let mem = std::sync::Arc::new(mem);
// Pre-populate
mem.store("shared_key", "initial", MemoryCategory::Core, None)
.await
.unwrap();
let mut handles = Vec::new();
// Concurrent reads
for _ in 0..5 {
let mem = std::sync::Arc::clone(&mem);
handles.push(tokio::spawn(async move {
let _ = mem.get("shared_key").await.unwrap();
}));
}
// Concurrent writes
for i in 0..5 {
let mem = std::sync::Arc::clone(&mem);
handles.push(tokio::spawn(async move {
mem.store(
&format!("key_{i}"),
&format!("val_{i}"),
MemoryCategory::Core,
None,
)
.await
.unwrap();
}));
}
for handle in handles {
handle.await.unwrap();
}
// Should have 6 total entries (1 pre-existing + 5 new)
assert_eq!(mem.count().await.unwrap(), 6);
}
// ── §4.2 Reindex / corruption recovery tests ────────────
#[tokio::test]
async fn sqlite_reindex_preserves_data() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
.await
.unwrap();
mem.reindex().await.unwrap();
let count = mem.count().await.unwrap();
assert_eq!(count, 2, "reindex must preserve all entries");
let entry = mem.get("a").await.unwrap();
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "Rust is fast");
}
#[tokio::test]
async fn sqlite_reindex_idempotent() {
let (_tmp, mem) = temp_sqlite();
mem.store("x", "test data", MemoryCategory::Core, None)
.await
.unwrap();
// Multiple reindex calls should be safe
mem.reindex().await.unwrap();
mem.reindex().await.unwrap();
mem.reindex().await.unwrap();
assert_eq!(mem.count().await.unwrap(), 1);
}
}