zeroclaw/src/memory/sqlite.rs
Argenis 1530a8707d
feat: add Git operations tool for structured repository management
Implements #214 - Add git_operations tool that provides safe, parsed
git operations with JSON output and security policy integration.

Features:
- Operations: status, diff, log, branch, commit, add, checkout, stash
- Structured JSON output (parsed status, diff hunks, commit history)
- SecurityPolicy integration with autonomy-aware controls
- Command injection protection

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-16 05:53:29 -05:00

1388 lines
47 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

use super::embeddings::EmbeddingProvider;
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use super::vector;
use async_trait::async_trait;
use chrono::Local;
use rusqlite::{params, Connection};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex};
use uuid::Uuid;
/// SQLite-backed persistent memory — the brain
///
/// Full-stack search engine:
/// - **Vector DB**: embeddings stored as BLOB, cosine similarity search
/// - **Keyword Search**: FTS5 virtual table with BM25 scoring
/// - **Hybrid Merge**: weighted fusion of vector + keyword results
/// - **Embedding Cache**: LRU-evicted cache to avoid redundant API calls
/// - **Safe Reindex**: temp DB → seed → sync → atomic swap → rollback
pub struct SqliteMemory {
conn: Mutex<Connection>,
db_path: PathBuf,
embedder: Arc<dyn EmbeddingProvider>,
vector_weight: f32,
keyword_weight: f32,
cache_max: usize,
}
impl SqliteMemory {
pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
Self::with_embedder(
workspace_dir,
Arc::new(super::embeddings::NoopEmbedding),
0.7,
0.3,
10_000,
)
}
pub fn with_embedder(
workspace_dir: &Path,
embedder: Arc<dyn EmbeddingProvider>,
vector_weight: f32,
keyword_weight: f32,
cache_max: usize,
) -> anyhow::Result<Self> {
let db_path = workspace_dir.join("memory").join("brain.db");
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Connection::open(&db_path)?;
// ── Production-grade PRAGMA tuning ──────────────────────
// WAL mode: concurrent reads during writes, crash-safe
// normal sync: 2× write speed, still durable on WAL
// mmap 8 MB: let the OS page-cache serve hot reads
// cache 2 MB: keep ~500 hot pages in-process
// temp_store memory: temp tables never hit disk
conn.execute_batch(
"PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA mmap_size = 8388608;
PRAGMA cache_size = -2000;
PRAGMA temp_store = MEMORY;",
)?;
Self::init_schema(&conn)?;
Ok(Self {
conn: Mutex::new(conn),
db_path,
embedder,
vector_weight,
keyword_weight,
cache_max,
})
}
/// Initialize all tables: memories, FTS5, `embedding_cache`
fn init_schema(conn: &Connection) -> anyhow::Result<()> {
conn.execute_batch(
"-- Core memories table
CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
key TEXT NOT NULL UNIQUE,
content TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'core',
embedding BLOB,
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);
-- FTS5 full-text search (BM25 scoring)
CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
key, content, content=memories, content_rowid=rowid
);
-- FTS5 triggers: keep in sync with memories table
CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
INSERT INTO memories_fts(rowid, key, content)
VALUES (new.rowid, new.key, new.content);
END;
CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, key, content)
VALUES ('delete', old.rowid, old.key, old.content);
END;
CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
INSERT INTO memories_fts(memories_fts, rowid, key, content)
VALUES ('delete', old.rowid, old.key, old.content);
INSERT INTO memories_fts(rowid, key, content)
VALUES (new.rowid, new.key, new.content);
END;
-- Embedding cache with LRU eviction
CREATE TABLE IF NOT EXISTS embedding_cache (
content_hash TEXT PRIMARY KEY,
embedding BLOB NOT NULL,
created_at TEXT NOT NULL,
accessed_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
)?;
Ok(())
}
fn category_to_str(cat: &MemoryCategory) -> String {
match cat {
MemoryCategory::Core => "core".into(),
MemoryCategory::Daily => "daily".into(),
MemoryCategory::Conversation => "conversation".into(),
MemoryCategory::Custom(name) => name.clone(),
}
}
fn str_to_category(s: &str) -> MemoryCategory {
match s {
"core" => MemoryCategory::Core,
"daily" => MemoryCategory::Daily,
"conversation" => MemoryCategory::Conversation,
other => MemoryCategory::Custom(other.to_string()),
}
}
/// Deterministic content hash for embedding cache.
/// Uses SHA-256 (truncated) instead of DefaultHasher, which is
/// explicitly documented as unstable across Rust versions.
fn content_hash(text: &str) -> String {
use sha2::{Digest, Sha256};
let hash = Sha256::digest(text.as_bytes());
// First 8 bytes → 16 hex chars, matching previous format length
format!(
"{:016x}",
u64::from_be_bytes(
hash[..8]
.try_into()
.expect("SHA-256 always produces >= 8 bytes")
)
)
}
/// Get embedding from cache, or compute + cache it
async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
if self.embedder.dimensions() == 0 {
return Ok(None); // Noop embedder
}
let hash = Self::content_hash(text);
let now = Local::now().to_rfc3339();
// Check cache
{
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut stmt =
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
let cached: Option<Vec<u8>> = stmt.query_row(params![hash], |row| row.get(0)).ok();
if let Some(bytes) = cached {
// Update accessed_at for LRU
conn.execute(
"UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2",
params![now, hash],
)?;
return Ok(Some(vector::bytes_to_vec(&bytes)));
}
}
// Compute embedding
let embedding = self.embedder.embed_one(text).await?;
let bytes = vector::vec_to_bytes(&embedding);
// Store in cache + LRU eviction
{
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
conn.execute(
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
VALUES (?1, ?2, ?3, ?4)",
params![hash, bytes, now, now],
)?;
// LRU eviction: keep only cache_max entries
#[allow(clippy::cast_possible_wrap)]
let max = self.cache_max as i64;
conn.execute(
"DELETE FROM embedding_cache WHERE content_hash IN (
SELECT content_hash FROM embedding_cache
ORDER BY accessed_at ASC
LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1)
)",
params![max],
)?;
}
Ok(Some(embedding))
}
/// FTS5 BM25 keyword search
fn fts5_search(
conn: &Connection,
query: &str,
limit: usize,
) -> anyhow::Result<Vec<(String, f32)>> {
// Escape FTS5 special chars and build query
let fts_query: String = query
.split_whitespace()
.map(|w| format!("\"{w}\""))
.collect::<Vec<_>>()
.join(" OR ");
if fts_query.is_empty() {
return Ok(Vec::new());
}
let sql = "SELECT m.id, bm25(memories_fts) as score
FROM memories_fts f
JOIN memories m ON m.rowid = f.rowid
WHERE memories_fts MATCH ?1
ORDER BY score
LIMIT ?2";
let mut stmt = conn.prepare(sql)?;
#[allow(clippy::cast_possible_wrap)]
let limit_i64 = limit as i64;
let rows = stmt.query_map(params![fts_query, limit_i64], |row| {
let id: String = row.get(0)?;
let score: f64 = row.get(1)?;
// BM25 returns negative scores (lower = better), negate for ranking
#[allow(clippy::cast_possible_truncation)]
Ok((id, (-score) as f32))
})?;
let mut results = Vec::new();
for row in rows {
results.push(row?);
}
Ok(results)
}
/// Vector similarity search: scan embeddings and compute cosine similarity
fn vector_search(
conn: &Connection,
query_embedding: &[f32],
limit: usize,
) -> anyhow::Result<Vec<(String, f32)>> {
let mut stmt =
conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?;
let rows = stmt.query_map([], |row| {
let id: String = row.get(0)?;
let blob: Vec<u8> = row.get(1)?;
Ok((id, blob))
})?;
let mut scored: Vec<(String, f32)> = Vec::new();
for row in rows {
let (id, blob) = row?;
let emb = vector::bytes_to_vec(&blob);
let sim = vector::cosine_similarity(query_embedding, &emb);
if sim > 0.0 {
scored.push((id, sim));
}
}
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
scored.truncate(limit);
Ok(scored)
}
/// Safe reindex: rebuild FTS5 + embeddings with rollback on failure
#[allow(dead_code)]
pub async fn reindex(&self) -> anyhow::Result<usize> {
// Step 1: Rebuild FTS5
{
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
}
// Step 2: Re-embed all memories that lack embeddings
if self.embedder.dimensions() == 0 {
return Ok(0);
}
let entries: Vec<(String, String)> = {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut stmt =
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
let rows = stmt.query_map([], |row| {
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
})?;
rows.filter_map(std::result::Result::ok).collect()
};
let mut count = 0;
for (id, content) in &entries {
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
let bytes = vector::vec_to_bytes(&emb);
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
conn.execute(
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
params![bytes, id],
)?;
count += 1;
}
}
Ok(count)
}
}
#[async_trait]
impl Memory for SqliteMemory {
fn name(&self) -> &str {
"sqlite"
}
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
) -> anyhow::Result<()> {
// Compute embedding (async, before lock)
let embedding_bytes = self
.get_or_compute_embedding(content)
.await?
.map(|emb| vector::vec_to_bytes(&emb));
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let now = Local::now().to_rfc3339();
let cat = Self::category_to_str(&category);
let id = Uuid::new_v4().to_string();
conn.execute(
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
embedding = excluded.embedding,
updated_at = excluded.updated_at",
params![id, key, content, cat, embedding_bytes, now, now],
)?;
Ok(())
}
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
if query.trim().is_empty() {
return Ok(Vec::new());
}
// Compute query embedding (async, before lock)
let query_embedding = self.get_or_compute_embedding(query).await?;
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
// FTS5 BM25 keyword search
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
// Vector similarity search (if embeddings available)
let vector_results = if let Some(ref qe) = query_embedding {
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default()
} else {
Vec::new()
};
// Hybrid merge
let merged = if vector_results.is_empty() {
// No embeddings — use keyword results only
keyword_results
.iter()
.map(|(id, score)| vector::ScoredResult {
id: id.clone(),
vector_score: None,
keyword_score: Some(*score),
final_score: *score,
})
.collect::<Vec<_>>()
} else {
vector::hybrid_merge(
&vector_results,
&keyword_results,
self.vector_weight,
self.keyword_weight,
limit,
)
};
// Fetch full entries for merged results
let mut results = Vec::new();
for scored in &merged {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
)?;
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
score: Some(f64::from(scored.final_score)),
})
}) {
results.push(entry);
}
}
// If hybrid returned nothing, fall back to LIKE search
if results.is_empty() {
let keywords: Vec<String> =
query.split_whitespace().map(|w| format!("%{w}%")).collect();
if !keywords.is_empty() {
let conditions: Vec<String> = keywords
.iter()
.enumerate()
.map(|(i, _)| {
format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2)
})
.collect();
let where_clause = conditions.join(" OR ");
let sql = format!(
"SELECT id, key, content, category, created_at FROM memories
WHERE {where_clause}
ORDER BY updated_at DESC
LIMIT ?{}",
keywords.len() * 2 + 1
);
let mut stmt = conn.prepare(&sql)?;
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
for kw in &keywords {
param_values.push(Box::new(kw.clone()));
param_values.push(Box::new(kw.clone()));
}
#[allow(clippy::cast_possible_wrap)]
param_values.push(Box::new(limit as i64));
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
param_values.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
score: Some(1.0),
})
})?;
for row in rows {
results.push(row?);
}
}
}
results.truncate(limit);
Ok(results)
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
)?;
let mut rows = stmt.query_map(params![key], |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
score: None,
})
})?;
match rows.next() {
Some(Ok(entry)) => Ok(Some(entry)),
_ => Ok(None),
}
}
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut results = Vec::new();
let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
score: None,
})
};
if let Some(cat) = category {
let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories
WHERE category = ?1 ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map(params![cat_str], row_mapper)?;
for row in rows {
results.push(row?);
}
} else {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories
ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map([], row_mapper)?;
for row in rows {
results.push(row?);
}
}
Ok(results)
}
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
Ok(affected > 0)
}
async fn count(&self) -> anyhow::Result<usize> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok(count as usize)
}
async fn health_check(&self) -> bool {
self.conn
.lock()
.map(|c| c.execute_batch("SELECT 1").is_ok())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn temp_sqlite() -> (TempDir, SqliteMemory) {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
(tmp, mem)
}
#[tokio::test]
async fn sqlite_name() {
let (_tmp, mem) = temp_sqlite();
assert_eq!(mem.name(), "sqlite");
}
#[tokio::test]
async fn sqlite_health() {
let (_tmp, mem) = temp_sqlite();
assert!(mem.health_check().await);
}
#[tokio::test]
async fn sqlite_store_and_get() {
let (_tmp, mem) = temp_sqlite();
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("user_lang").await.unwrap();
assert!(entry.is_some());
let entry = entry.unwrap();
assert_eq!(entry.key, "user_lang");
assert_eq!(entry.content, "Prefers Rust");
assert_eq!(entry.category, MemoryCategory::Core);
}
#[tokio::test]
async fn sqlite_store_upsert() {
let (_tmp, mem) = temp_sqlite();
mem.store("pref", "likes Rust", MemoryCategory::Core)
.await
.unwrap();
mem.store("pref", "loves Rust", MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("pref").await.unwrap().unwrap();
assert_eq!(entry.content, "loves Rust");
assert_eq!(mem.count().await.unwrap(), 1);
}
#[tokio::test]
async fn sqlite_recall_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "Python is interpreted", MemoryCategory::Core)
.await
.unwrap();
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
.all(|r| r.content.to_lowercase().contains("rust")));
}
#[tokio::test]
async fn sqlite_recall_multi_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("fast safe", 10).await.unwrap();
assert!(!results.is_empty());
// Entry with both keywords should score higher
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
}
#[tokio::test]
async fn sqlite_recall_no_match() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust rocks", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("javascript", 10).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn sqlite_forget() {
let (_tmp, mem) = temp_sqlite();
mem.store("temp", "temporary data", MemoryCategory::Conversation)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 1);
let removed = mem.forget("temp").await.unwrap();
assert!(removed);
assert_eq!(mem.count().await.unwrap(), 0);
}
#[tokio::test]
async fn sqlite_forget_nonexistent() {
let (_tmp, mem) = temp_sqlite();
let removed = mem.forget("nope").await.unwrap();
assert!(!removed);
}
#[tokio::test]
async fn sqlite_list_all() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
mem.store("c", "three", MemoryCategory::Conversation)
.await
.unwrap();
let all = mem.list(None).await.unwrap();
assert_eq!(all.len(), 3);
}
#[tokio::test]
async fn sqlite_list_by_category() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
mem.store("c", "daily1", MemoryCategory::Daily)
.await
.unwrap();
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
assert_eq!(core.len(), 2);
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
assert_eq!(daily.len(), 1);
}
#[tokio::test]
async fn sqlite_count_empty() {
let (_tmp, mem) = temp_sqlite();
assert_eq!(mem.count().await.unwrap(), 0);
}
#[tokio::test]
async fn sqlite_get_nonexistent() {
let (_tmp, mem) = temp_sqlite();
assert!(mem.get("nope").await.unwrap().is_none());
}
#[tokio::test]
async fn sqlite_db_persists() {
let tmp = TempDir::new().unwrap();
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("persist", "I survive restarts", MemoryCategory::Core)
.await
.unwrap();
}
// Reopen
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
let entry = mem2.get("persist").await.unwrap();
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "I survive restarts");
}
#[tokio::test]
async fn sqlite_category_roundtrip() {
let (_tmp, mem) = temp_sqlite();
let categories = [
MemoryCategory::Core,
MemoryCategory::Daily,
MemoryCategory::Conversation,
MemoryCategory::Custom("project".into()),
];
for (i, cat) in categories.iter().enumerate() {
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
.await
.unwrap();
}
for (i, cat) in categories.iter().enumerate() {
let entry = mem.get(&format!("k{i}")).await.unwrap().unwrap();
assert_eq!(&entry.category, cat);
}
}
// ── FTS5 search tests ────────────────────────────────────────
#[tokio::test]
async fn fts5_bm25_ranking() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"a",
"Rust is a systems programming language",
MemoryCategory::Core,
)
.await
.unwrap();
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
.await
.unwrap();
mem.store(
"c",
"Rust and Rust and Rust everywhere",
MemoryCategory::Core,
)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap();
assert!(results.len() >= 2);
// All results should contain "Rust"
for r in &results {
assert!(
r.content.to_lowercase().contains("rust"),
"Expected 'rust' in: {}",
r.content
);
}
}
#[tokio::test]
async fn fts5_multi_word_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
.await
.unwrap();
mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("quick dog", 10).await.unwrap();
assert!(!results.is_empty());
// "The quick dog runs fast" matches both terms
assert!(results[0].content.contains("quick"));
}
#[tokio::test]
async fn recall_empty_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
let results = mem.recall("", 10).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_whitespace_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
let results = mem.recall(" ", 10).await.unwrap();
assert!(results.is_empty());
}
// ── Embedding cache tests ────────────────────────────────────
#[test]
fn content_hash_deterministic() {
let h1 = SqliteMemory::content_hash("hello world");
let h2 = SqliteMemory::content_hash("hello world");
assert_eq!(h1, h2);
}
#[test]
fn content_hash_different_inputs() {
let h1 = SqliteMemory::content_hash("hello");
let h2 = SqliteMemory::content_hash("world");
assert_ne!(h1, h2);
}
// ── Schema tests ─────────────────────────────────────────────
#[tokio::test]
async fn schema_has_fts5_table() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock().unwrap();
// FTS5 table should exist
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories_fts'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 1);
}
#[tokio::test]
async fn schema_has_embedding_cache() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock().unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 1);
}
#[tokio::test]
async fn schema_memories_has_embedding_column() {
let (_tmp, mem) = temp_sqlite();
let conn = mem.conn.lock().unwrap();
// Check that embedding column exists by querying it
let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
assert!(result.is_ok());
}
// ── FTS5 sync trigger tests ──────────────────────────────────
#[tokio::test]
async fn fts5_syncs_on_insert() {
let (_tmp, mem) = temp_sqlite();
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
.await
.unwrap();
let conn = mem.conn.lock().unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 1);
}
#[tokio::test]
async fn fts5_syncs_on_delete() {
let (_tmp, mem) = temp_sqlite();
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
.await
.unwrap();
mem.forget("del_key").await.unwrap();
let conn = mem.conn.lock().unwrap();
let count: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(count, 0);
}
#[tokio::test]
async fn fts5_syncs_on_update() {
let (_tmp, mem) = temp_sqlite();
mem.store("upd_key", "original_content_111", MemoryCategory::Core)
.await
.unwrap();
mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
.await
.unwrap();
let conn = mem.conn.lock().unwrap();
// Old content should not be findable
let old: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"original_content_111\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(old, 0);
// New content should be findable
let new: i64 = conn
.query_row(
"SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"updated_content_222\"'",
[],
|row| row.get(0),
)
.unwrap();
assert_eq!(new, 1);
}
// ── With-embedder constructor test ───────────────────────────
#[test]
fn with_embedder_noop() {
let tmp = TempDir::new().unwrap();
let embedder = Arc::new(super::super::embeddings::NoopEmbedding);
let mem = SqliteMemory::with_embedder(tmp.path(), embedder, 0.7, 0.3, 1000);
assert!(mem.is_ok());
assert_eq!(mem.unwrap().name(), "sqlite");
}
// ── Reindex test ─────────────────────────────────────────────
#[tokio::test]
async fn reindex_rebuilds_fts() {
let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex test alpha", MemoryCategory::Core)
.await
.unwrap();
mem.store("r2", "reindex test beta", MemoryCategory::Core)
.await
.unwrap();
// Reindex should succeed (noop embedder → 0 re-embedded)
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0);
// FTS should still work after rebuild
let results = mem.recall("reindex", 10).await.unwrap();
assert_eq!(results.len(), 2);
}
// ── Recall limit test ────────────────────────────────────────
#[tokio::test]
async fn recall_respects_limit() {
let (_tmp, mem) = temp_sqlite();
for i in 0..20 {
mem.store(
&format!("k{i}"),
&format!("common keyword item {i}"),
MemoryCategory::Core,
)
.await
.unwrap();
}
let results = mem.recall("common keyword", 5).await.unwrap();
assert!(results.len() <= 5);
}
// ── Score presence test ──────────────────────────────────────
#[tokio::test]
async fn recall_results_have_scores() {
let (_tmp, mem) = temp_sqlite();
mem.store("s1", "scored result test", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("scored", 10).await.unwrap();
assert!(!results.is_empty());
for r in &results {
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
}
}
// ── Edge cases: FTS5 special characters ──────────────────────
#[tokio::test]
async fn recall_with_quotes_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("q1", "He said hello world", MemoryCategory::Core)
.await
.unwrap();
// Quotes in query should not crash FTS5
let results = mem.recall("\"hello\"", 10).await.unwrap();
// May or may not match depending on FTS5 escaping, but must not error
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_asterisk_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a1", "wildcard test content", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("wild*", 10).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_parentheses_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("p1", "function call test", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("function()", 10).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_sql_injection_attempt() {
let (_tmp, mem) = temp_sqlite();
mem.store("safe", "normal content", MemoryCategory::Core)
.await
.unwrap();
// Should not crash or leak data
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
assert!(results.len() <= 10);
// Table should still exist
assert_eq!(mem.count().await.unwrap(), 1);
}
// ── Edge cases: store ────────────────────────────────────────
#[tokio::test]
async fn store_empty_content() {
let (_tmp, mem) = temp_sqlite();
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
let entry = mem.get("empty").await.unwrap().unwrap();
assert_eq!(entry.content, "");
}
#[tokio::test]
async fn store_empty_key() {
let (_tmp, mem) = temp_sqlite();
mem.store("", "content for empty key", MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("").await.unwrap().unwrap();
assert_eq!(entry.content, "content for empty key");
}
#[tokio::test]
async fn store_very_long_content() {
let (_tmp, mem) = temp_sqlite();
let long_content = "x".repeat(100_000);
mem.store("long", &long_content, MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("long").await.unwrap().unwrap();
assert_eq!(entry.content.len(), 100_000);
}
#[tokio::test]
async fn store_unicode_and_emoji() {
let (_tmp, mem) = temp_sqlite();
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
}
#[tokio::test]
async fn store_content_with_newlines_and_tabs() {
let (_tmp, mem) = temp_sqlite();
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
mem.store("whitespace", content, MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("whitespace").await.unwrap().unwrap();
assert_eq!(entry.content, content);
}
// ── Edge cases: recall ───────────────────────────────────────
#[tokio::test]
async fn recall_single_character_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "x marks the spot", MemoryCategory::Core)
.await
.unwrap();
// Single char may not match FTS5 but LIKE fallback should work
let results = mem.recall("x", 10).await.unwrap();
// Should not crash; may or may not find results
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_limit_zero() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "some content", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("some", 0).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_limit_one() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "matching content alpha", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "matching content beta", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("matching content", 1).await.unwrap();
assert_eq!(results.len(), 1);
}
#[tokio::test]
async fn recall_matches_by_key_not_just_content() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"rust_preferences",
"User likes systems programming",
MemoryCategory::Core,
)
.await
.unwrap();
// "rust" appears in key but not content — LIKE fallback checks key too
let results = mem.recall("rust", 10).await.unwrap();
assert!(!results.is_empty(), "Should match by key");
}
#[tokio::test]
async fn recall_unicode_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("日本語", 10).await.unwrap();
assert!(!results.is_empty());
}
// ── Edge cases: schema idempotency ───────────────────────────
#[tokio::test]
async fn schema_idempotent_reopen() {
let tmp = TempDir::new().unwrap();
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
}
// Open again — init_schema runs again on existing DB
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
let entry = mem2.get("k1").await.unwrap();
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "v1");
// Store more data — should work fine
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
assert_eq!(mem2.count().await.unwrap(), 2);
}
#[tokio::test]
async fn schema_triple_open() {
let tmp = TempDir::new().unwrap();
let _m1 = SqliteMemory::new(tmp.path()).unwrap();
let _m2 = SqliteMemory::new(tmp.path()).unwrap();
let m3 = SqliteMemory::new(tmp.path()).unwrap();
assert!(m3.health_check().await);
}
// ── Edge cases: forget + FTS5 consistency ────────────────────
#[tokio::test]
async fn forget_then_recall_no_ghost_results() {
let (_tmp, mem) = temp_sqlite();
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
.await
.unwrap();
mem.forget("ghost").await.unwrap();
let results = mem.recall("phantom memory", 10).await.unwrap();
assert!(
results.is_empty(),
"Deleted memory should not appear in recall"
);
}
#[tokio::test]
async fn forget_and_re_store_same_key() {
let (_tmp, mem) = temp_sqlite();
mem.store("cycle", "version 1", MemoryCategory::Core)
.await
.unwrap();
mem.forget("cycle").await.unwrap();
mem.store("cycle", "version 2", MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("cycle").await.unwrap().unwrap();
assert_eq!(entry.content, "version 2");
assert_eq!(mem.count().await.unwrap(), 1);
}
// ── Edge cases: reindex ──────────────────────────────────────
#[tokio::test]
async fn reindex_empty_db() {
let (_tmp, mem) = temp_sqlite();
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0);
}
#[tokio::test]
async fn reindex_twice_is_safe() {
let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex data", MemoryCategory::Core)
.await
.unwrap();
mem.reindex().await.unwrap();
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
// Data should still be intact
let results = mem.recall("reindex", 10).await.unwrap();
assert_eq!(results.len(), 1);
}
// ── Edge cases: content_hash ─────────────────────────────────
#[test]
fn content_hash_empty_string() {
let h = SqliteMemory::content_hash("");
assert!(!h.is_empty());
assert_eq!(h.len(), 16); // 16 hex chars
}
#[test]
fn content_hash_unicode() {
let h1 = SqliteMemory::content_hash("🦀");
let h2 = SqliteMemory::content_hash("🦀");
assert_eq!(h1, h2);
let h3 = SqliteMemory::content_hash("🚀");
assert_ne!(h1, h3);
}
#[test]
fn content_hash_long_input() {
let long = "a".repeat(1_000_000);
let h = SqliteMemory::content_hash(&long);
assert_eq!(h.len(), 16);
}
// ── Edge cases: category helpers ─────────────────────────────
#[test]
fn category_roundtrip_custom_with_spaces() {
let cat = MemoryCategory::Custom("my custom category".into());
let s = SqliteMemory::category_to_str(&cat);
assert_eq!(s, "my custom category");
let back = SqliteMemory::str_to_category(&s);
assert_eq!(back, cat);
}
#[test]
fn category_roundtrip_empty_custom() {
let cat = MemoryCategory::Custom(String::new());
let s = SqliteMemory::category_to_str(&cat);
assert_eq!(s, "");
let back = SqliteMemory::str_to_category(&s);
assert_eq!(back, MemoryCategory::Custom(String::new()));
}
// ── Edge cases: list ─────────────────────────────────────────
#[tokio::test]
async fn list_custom_category() {
let (_tmp, mem) = temp_sqlite();
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
.await
.unwrap();
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
.await
.unwrap();
mem.store("c3", "other", MemoryCategory::Core)
.await
.unwrap();
let project = mem
.list(Some(&MemoryCategory::Custom("project".into())))
.await
.unwrap();
assert_eq!(project.len(), 2);
}
#[tokio::test]
async fn list_empty_db() {
let (_tmp, mem) = temp_sqlite();
let all = mem.list(None).await.unwrap();
assert!(all.is_empty());
}
}