diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index a7b0a44..fa992a0 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -25,7 +25,7 @@ const SQLITE_OPEN_TIMEOUT_CAP_SECS: u64 = 300; /// - **Embedding Cache**: LRU-evicted cache to avoid redundant API calls /// - **Safe Reindex**: temp DB → seed → sync → atomic swap → rollback pub struct SqliteMemory { - conn: Mutex, + conn: Arc>, db_path: PathBuf, embedder: Arc, vector_weight: f32, @@ -83,7 +83,7 @@ impl SqliteMemory { Self::init_schema(&conn)?; Ok(Self { - conn: Mutex::new(conn), + conn: Arc::new(Mutex::new(conn)), db_path, embedder, vector_weight, @@ -229,50 +229,56 @@ impl SqliteMemory { let hash = Self::content_hash(text); let now = Local::now().to_rfc3339(); - // Check cache - { - let conn = self.conn.lock(); - + // Check cache (offloaded to blocking thread) + let conn = self.conn.clone(); + let hash_c = hash.clone(); + let now_c = now.clone(); + let cached = tokio::task::spawn_blocking(move || -> anyhow::Result>> { + let conn = conn.lock(); let mut stmt = conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?; - let cached: Option> = stmt.query_row(params![hash], |row| row.get(0)).ok(); - - if let Some(bytes) = cached { - // Update accessed_at for LRU + let blob: Option> = stmt.query_row(params![hash_c], |row| row.get(0)).ok(); + if let Some(bytes) = blob { conn.execute( "UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2", - params![now, hash], + params![now_c, hash_c], )?; return Ok(Some(vector::bytes_to_vec(&bytes))); } + Ok(None) + }) + .await??; + + if cached.is_some() { + return Ok(cached); } - // Compute embedding + // Compute embedding (async I/O) let embedding = self.embedder.embed_one(text).await?; let bytes = vector::vec_to_bytes(&embedding); - // Store in cache + LRU eviction - { - let conn = self.conn.lock(); - + // Store in cache + LRU eviction (offloaded to blocking thread) + let conn = self.conn.clone(); + #[allow(clippy::cast_possible_wrap)] + let cache_max = self.cache_max as i64; + tokio::task::spawn_blocking(move || -> anyhow::Result<()> { + let conn = conn.lock(); conn.execute( "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at) VALUES (?1, ?2, ?3, ?4)", params![hash, bytes, now, now], )?; - - // LRU eviction: keep only cache_max entries - #[allow(clippy::cast_possible_wrap)] - let max = self.cache_max as i64; conn.execute( "DELETE FROM embedding_cache WHERE content_hash IN ( SELECT content_hash FROM embedding_cache ORDER BY accessed_at ASC LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1) )", - params![max], + params![cache_max], )?; - } + Ok(()) + }) + .await??; Ok(Some(embedding)) } @@ -355,9 +361,13 @@ impl SqliteMemory { pub async fn reindex(&self) -> anyhow::Result { // Step 1: Rebuild FTS5 { - let conn = self.conn.lock(); - - conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?; + let conn = self.conn.clone(); + tokio::task::spawn_blocking(move || -> anyhow::Result<()> { + let conn = conn.lock(); + conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?; + Ok(()) + }) + .await??; } // Step 2: Re-embed all memories that lack embeddings @@ -365,26 +375,33 @@ impl SqliteMemory { return Ok(0); } - let entries: Vec<(String, String)> = { - let conn = self.conn.lock(); - + let conn = self.conn.clone(); + let entries: Vec<(String, String)> = tokio::task::spawn_blocking(move || { + let conn = conn.lock(); let mut stmt = conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?; let rows = stmt.query_map([], |row| { Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) })?; - rows.filter_map(std::result::Result::ok).collect() - }; + Ok::<_, anyhow::Error>(rows.filter_map(std::result::Result::ok).collect()) + }) + .await??; let mut count = 0; for (id, content) in &entries { if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await { let bytes = vector::vec_to_bytes(&emb); - let conn = self.conn.lock(); - conn.execute( - "UPDATE memories SET embedding = ?1 WHERE id = ?2", - params![bytes, id], - )?; + let conn = self.conn.clone(); + let id = id.clone(); + tokio::task::spawn_blocking(move || -> anyhow::Result<()> { + let conn = conn.lock(); + conn.execute( + "UPDATE memories SET embedding = ?1 WHERE id = ?2", + params![bytes, id], + )?; + Ok(()) + }) + .await??; count += 1; } } @@ -406,30 +423,37 @@ impl Memory for SqliteMemory { category: MemoryCategory, session_id: Option<&str>, ) -> anyhow::Result<()> { - // Compute embedding (async, before lock) + // Compute embedding (async, before blocking work) let embedding_bytes = self .get_or_compute_embedding(content) .await? .map(|emb| vector::vec_to_bytes(&emb)); - let conn = self.conn.lock(); - let now = Local::now().to_rfc3339(); - let cat = Self::category_to_str(&category); - let id = Uuid::new_v4().to_string(); + let conn = self.conn.clone(); + let key = key.to_string(); + let content = content.to_string(); + let session_id = session_id.map(String::from); - conn.execute( - "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) - ON CONFLICT(key) DO UPDATE SET - content = excluded.content, - category = excluded.category, - embedding = excluded.embedding, - updated_at = excluded.updated_at, - session_id = excluded.session_id", - params![id, key, content, cat, embedding_bytes, now, now, session_id], - )?; + tokio::task::spawn_blocking(move || -> anyhow::Result<()> { + let conn = conn.lock(); + let now = Local::now().to_rfc3339(); + let cat = Self::category_to_str(&category); + let id = Uuid::new_v4().to_string(); - Ok(()) + conn.execute( + "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) + ON CONFLICT(key) DO UPDATE SET + content = excluded.content, + category = excluded.category, + embedding = excluded.embedding, + updated_at = excluded.updated_at, + session_id = excluded.session_id", + params![id, key, content, cat, embedding_bytes, now, now, session_id], + )?; + Ok(()) + }) + .await? } async fn recall( @@ -442,101 +466,58 @@ impl Memory for SqliteMemory { return Ok(Vec::new()); } - // Compute query embedding (async, before lock) + // Compute query embedding (async, before blocking work) let query_embedding = self.get_or_compute_embedding(query).await?; - let conn = self.conn.lock(); + let conn = self.conn.clone(); + let query = query.to_string(); + let session_id = session_id.map(String::from); + let vector_weight = self.vector_weight; + let keyword_weight = self.keyword_weight; - // FTS5 BM25 keyword search - let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default(); + tokio::task::spawn_blocking(move || -> anyhow::Result> { + let conn = conn.lock(); + let session_ref = session_id.as_deref(); - // Vector similarity search (if embeddings available) - let vector_results = if let Some(ref qe) = query_embedding { - Self::vector_search(&conn, qe, limit * 2).unwrap_or_default() - } else { - Vec::new() - }; + // FTS5 BM25 keyword search + let keyword_results = + Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default(); - // Hybrid merge - let merged = if vector_results.is_empty() { - // No embeddings — use keyword results only - keyword_results - .iter() - .map(|(id, score)| vector::ScoredResult { - id: id.clone(), - vector_score: None, - keyword_score: Some(*score), - final_score: *score, - }) - .collect::>() - } else { - vector::hybrid_merge( - &vector_results, - &keyword_results, - self.vector_weight, - self.keyword_weight, - limit, - ) - }; + // Vector similarity search (if embeddings available) + let vector_results = if let Some(ref qe) = query_embedding { + Self::vector_search(&conn, qe, limit * 2).unwrap_or_default() + } else { + Vec::new() + }; - // Fetch full entries for merged results - let mut results = Vec::new(); - for scored in &merged { - let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1", - )?; - if let Ok(entry) = stmt.query_row(params![scored.id], |row| { - Ok(MemoryEntry { - id: row.get(0)?, - key: row.get(1)?, - content: row.get(2)?, - category: Self::str_to_category(&row.get::<_, String>(3)?), - timestamp: row.get(4)?, - session_id: row.get(5)?, - score: Some(f64::from(scored.final_score)), - }) - }) { - // Filter by session_id if requested - if let Some(sid) = session_id { - if entry.session_id.as_deref() != Some(sid) { - continue; - } - } - results.push(entry); - } - } - - // If hybrid returned nothing, fall back to LIKE search - if results.is_empty() { - let keywords: Vec = - query.split_whitespace().map(|w| format!("%{w}%")).collect(); - if !keywords.is_empty() { - let conditions: Vec = keywords + // Hybrid merge + let merged = if vector_results.is_empty() { + keyword_results .iter() - .enumerate() - .map(|(i, _)| { - format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2) + .map(|(id, score)| vector::ScoredResult { + id: id.clone(), + vector_score: None, + keyword_score: Some(*score), + final_score: *score, }) - .collect(); - let where_clause = conditions.join(" OR "); - let sql = format!( - "SELECT id, key, content, category, created_at, session_id FROM memories - WHERE {where_clause} - ORDER BY updated_at DESC - LIMIT ?{}", - keywords.len() * 2 + 1 - ); - let mut stmt = conn.prepare(&sql)?; - let mut param_values: Vec> = Vec::new(); - for kw in &keywords { - param_values.push(Box::new(kw.clone())); - param_values.push(Box::new(kw.clone())); - } - #[allow(clippy::cast_possible_wrap)] - param_values.push(Box::new(limit as i64)); - let params_ref: Vec<&dyn rusqlite::types::ToSql> = - param_values.iter().map(AsRef::as_ref).collect(); - let rows = stmt.query_map(params_ref.as_slice(), |row| { + .collect::>() + } else { + vector::hybrid_merge( + &vector_results, + &keyword_results, + vector_weight, + keyword_weight, + limit, + ) + }; + + // Fetch full entries for merged results + let mut results = Vec::new(); + for scored in &merged { + let mut stmt = conn.prepare( + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1", + )?; + if let Ok(entry) = stmt.query_row(params![scored.id], |row| { Ok(MemoryEntry { id: row.get(0)?, key: row.get(1)?, @@ -544,12 +525,10 @@ impl Memory for SqliteMemory { category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, session_id: row.get(5)?, - score: Some(1.0), + score: Some(f64::from(scored.final_score)), }) - })?; - for row in rows { - let entry = row?; - if let Some(sid) = session_id { + }) { + if let Some(sid) = session_ref { if entry.session_id.as_deref() != Some(sid) { continue; } @@ -557,35 +536,98 @@ impl Memory for SqliteMemory { results.push(entry); } } - } - results.truncate(limit); - Ok(results) + // If hybrid returned nothing, fall back to LIKE search + if results.is_empty() { + let keywords: Vec = + query.split_whitespace().map(|w| format!("%{w}%")).collect(); + if !keywords.is_empty() { + let conditions: Vec = keywords + .iter() + .enumerate() + .map(|(i, _)| { + format!( + "(content LIKE ?{} OR key LIKE ?{})", + i * 2 + 1, + i * 2 + 2 + ) + }) + .collect(); + let where_clause = conditions.join(" OR "); + let sql = format!( + "SELECT id, key, content, category, created_at, session_id FROM memories + WHERE {where_clause} + ORDER BY updated_at DESC + LIMIT ?{}", + keywords.len() * 2 + 1 + ); + let mut stmt = conn.prepare(&sql)?; + let mut param_values: Vec> = Vec::new(); + for kw in &keywords { + param_values.push(Box::new(kw.clone())); + param_values.push(Box::new(kw.clone())); + } + #[allow(clippy::cast_possible_wrap)] + param_values.push(Box::new(limit as i64)); + let params_ref: Vec<&dyn rusqlite::types::ToSql> = + param_values.iter().map(AsRef::as_ref).collect(); + let rows = stmt.query_map(params_ref.as_slice(), |row| { + Ok(MemoryEntry { + id: row.get(0)?, + key: row.get(1)?, + content: row.get(2)?, + category: Self::str_to_category(&row.get::<_, String>(3)?), + timestamp: row.get(4)?, + session_id: row.get(5)?, + score: Some(1.0), + }) + })?; + for row in rows { + let entry = row?; + if let Some(sid) = session_ref { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); + } + } + } + + results.truncate(limit); + Ok(results) + }) + .await? } async fn get(&self, key: &str) -> anyhow::Result> { - let conn = self.conn.lock(); + let conn = self.conn.clone(); + let key = key.to_string(); - let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1", - )?; + tokio::task::spawn_blocking(move || -> anyhow::Result> { + let conn = conn.lock(); + let mut stmt = conn.prepare( + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1", + )?; - let mut rows = stmt.query_map(params![key], |row| { - Ok(MemoryEntry { - id: row.get(0)?, - key: row.get(1)?, - content: row.get(2)?, - category: Self::str_to_category(&row.get::<_, String>(3)?), - timestamp: row.get(4)?, - session_id: row.get(5)?, - score: None, - }) - })?; + let mut rows = stmt.query_map(params![key], |row| { + Ok(MemoryEntry { + id: row.get(0)?, + key: row.get(1)?, + content: row.get(2)?, + category: Self::str_to_category(&row.get::<_, String>(3)?), + timestamp: row.get(4)?, + session_id: row.get(5)?, + score: None, + }) + })?; - match rows.next() { - Some(Ok(entry)) => Ok(Some(entry)), - _ => Ok(None), - } + match rows.next() { + Some(Ok(entry)) => Ok(Some(entry)), + _ => Ok(None), + } + }) + .await? } async fn list( @@ -593,73 +635,95 @@ impl Memory for SqliteMemory { category: Option<&MemoryCategory>, session_id: Option<&str>, ) -> anyhow::Result> { - let conn = self.conn.lock(); + let conn = self.conn.clone(); + let category = category.cloned(); + let session_id = session_id.map(String::from); - let mut results = Vec::new(); + tokio::task::spawn_blocking(move || -> anyhow::Result> { + let conn = conn.lock(); + let session_ref = session_id.as_deref(); + let mut results = Vec::new(); - let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result { - Ok(MemoryEntry { - id: row.get(0)?, - key: row.get(1)?, - content: row.get(2)?, - category: Self::str_to_category(&row.get::<_, String>(3)?), - timestamp: row.get(4)?, - session_id: row.get(5)?, - score: None, - }) - }; + let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result { + Ok(MemoryEntry { + id: row.get(0)?, + key: row.get(1)?, + content: row.get(2)?, + category: Self::str_to_category(&row.get::<_, String>(3)?), + timestamp: row.get(4)?, + session_id: row.get(5)?, + score: None, + }) + }; - if let Some(cat) = category { - let cat_str = Self::category_to_str(cat); - let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at, session_id FROM memories - WHERE category = ?1 ORDER BY updated_at DESC", - )?; - let rows = stmt.query_map(params![cat_str], row_mapper)?; - for row in rows { - let entry = row?; - if let Some(sid) = session_id { - if entry.session_id.as_deref() != Some(sid) { - continue; + if let Some(ref cat) = category { + let cat_str = Self::category_to_str(cat); + let mut stmt = conn.prepare( + "SELECT id, key, content, category, created_at, session_id FROM memories + WHERE category = ?1 ORDER BY updated_at DESC", + )?; + let rows = stmt.query_map(params![cat_str], row_mapper)?; + for row in rows { + let entry = row?; + if let Some(sid) = session_ref { + if entry.session_id.as_deref() != Some(sid) { + continue; + } } + results.push(entry); } - results.push(entry); - } - } else { - let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at, session_id FROM memories - ORDER BY updated_at DESC", - )?; - let rows = stmt.query_map([], row_mapper)?; - for row in rows { - let entry = row?; - if let Some(sid) = session_id { - if entry.session_id.as_deref() != Some(sid) { - continue; + } else { + let mut stmt = conn.prepare( + "SELECT id, key, content, category, created_at, session_id FROM memories + ORDER BY updated_at DESC", + )?; + let rows = stmt.query_map([], row_mapper)?; + for row in rows { + let entry = row?; + if let Some(sid) = session_ref { + if entry.session_id.as_deref() != Some(sid) { + continue; + } } + results.push(entry); } - results.push(entry); } - } - Ok(results) + Ok(results) + }) + .await? } async fn forget(&self, key: &str) -> anyhow::Result { - let conn = self.conn.lock(); - let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?; - Ok(affected > 0) + let conn = self.conn.clone(); + let key = key.to_string(); + + tokio::task::spawn_blocking(move || -> anyhow::Result { + let conn = conn.lock(); + let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?; + Ok(affected > 0) + }) + .await? } async fn count(&self) -> anyhow::Result { - let conn = self.conn.lock(); - let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?; - #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] - Ok(count as usize) + let conn = self.conn.clone(); + + tokio::task::spawn_blocking(move || -> anyhow::Result { + let conn = conn.lock(); + let count: i64 = + conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?; + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] + Ok(count as usize) + }) + .await? } async fn health_check(&self) -> bool { - self.conn.lock().execute_batch("SELECT 1").is_ok() + let conn = self.conn.clone(); + tokio::task::spawn_blocking(move || conn.lock().execute_batch("SELECT 1").is_ok()) + .await + .unwrap_or(false) } }