diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index fa992a0..9f2a25c 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -326,16 +326,36 @@ impl SqliteMemory { Ok(results) } - /// Vector similarity search: scan embeddings and compute cosine similarity + /// Vector similarity search: scan embeddings and compute cosine similarity. + /// + /// Optional `category` and `session_id` filters reduce full-table scans + /// when the caller already knows the scope of relevant memories. fn vector_search( conn: &Connection, query_embedding: &[f32], limit: usize, + category: Option<&str>, + session_id: Option<&str>, ) -> anyhow::Result> { - let mut stmt = - conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?; + let mut sql = + "SELECT id, embedding FROM memories WHERE embedding IS NOT NULL".to_string(); + let mut param_values: Vec> = Vec::new(); + let mut idx = 1; - let rows = stmt.query_map([], |row| { + if let Some(cat) = category { + sql.push_str(&format!(" AND category = ?{idx}")); + param_values.push(Box::new(cat.to_string())); + idx += 1; + } + if let Some(sid) = session_id { + sql.push_str(&format!(" AND session_id = ?{idx}")); + param_values.push(Box::new(sid.to_string())); + } + + let mut stmt = conn.prepare(&sql)?; + let params_ref: Vec<&dyn rusqlite::types::ToSql> = + param_values.iter().map(AsRef::as_ref).collect(); + let rows = stmt.query_map(params_ref.as_slice(), |row| { let id: String = row.get(0)?; let blob: Vec = row.get(1)?; Ok((id, blob)) @@ -485,7 +505,8 @@ impl Memory for SqliteMemory { // Vector similarity search (if embeddings available) let vector_results = if let Some(ref qe) = query_embedding { - Self::vector_search(&conn, qe, limit * 2).unwrap_or_default() + Self::vector_search(&conn, qe, limit * 2, None, session_ref) + .unwrap_or_default() } else { Vec::new() }; @@ -511,36 +532,73 @@ impl Memory for SqliteMemory { ) }; - // Fetch full entries for merged results + // Fetch full entries for merged results in a single query + // instead of N round-trips (N+1 pattern). let mut results = Vec::new(); - for scored in &merged { - let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1", - )?; - if let Ok(entry) = stmt.query_row(params![scored.id], |row| { - Ok(MemoryEntry { - id: row.get(0)?, - key: row.get(1)?, - content: row.get(2)?, - category: Self::str_to_category(&row.get::<_, String>(3)?), - timestamp: row.get(4)?, - session_id: row.get(5)?, - score: Some(f64::from(scored.final_score)), - }) - }) { - if let Some(sid) = session_ref { - if entry.session_id.as_deref() != Some(sid) { - continue; + if !merged.is_empty() { + let placeholders: String = (1..=merged.len()) + .map(|i| format!("?{i}")) + .collect::>() + .join(", "); + let sql = format!( + "SELECT id, key, content, category, created_at, session_id \ + FROM memories WHERE id IN ({placeholders})" + ); + let mut stmt = conn.prepare(&sql)?; + let id_params: Vec> = merged + .iter() + .map(|s| Box::new(s.id.clone()) as Box) + .collect(); + let params_ref: Vec<&dyn rusqlite::types::ToSql> = + id_params.iter().map(AsRef::as_ref).collect(); + let rows = stmt.query_map(params_ref.as_slice(), |row| { + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + row.get::<_, String>(3)?, + row.get::<_, String>(4)?, + row.get::<_, Option>(5)?, + )) + })?; + + let mut entry_map = std::collections::HashMap::new(); + for row in rows { + let (id, key, content, cat, ts, sid) = row?; + entry_map.insert(id, (key, content, cat, ts, sid)); + } + + for scored in &merged { + if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) { + let entry = MemoryEntry { + id: scored.id.clone(), + key, + content, + category: Self::str_to_category(&cat), + timestamp: ts, + session_id: sid, + score: Some(f64::from(scored.final_score)), + }; + if let Some(filter_sid) = session_ref { + if entry.session_id.as_deref() != Some(filter_sid) { + continue; + } } + results.push(entry); } - results.push(entry); } } - // If hybrid returned nothing, fall back to LIKE search + // If hybrid returned nothing, fall back to LIKE search. + // Cap keyword count so we don't create too many SQL shapes, + // which helps prepared-statement cache efficiency. if results.is_empty() { - let keywords: Vec = - query.split_whitespace().map(|w| format!("%{w}%")).collect(); + const MAX_LIKE_KEYWORDS: usize = 8; + let keywords: Vec = query + .split_whitespace() + .take(MAX_LIKE_KEYWORDS) + .map(|w| format!("%{w}%")) + .collect(); if !keywords.is_empty() { let conditions: Vec = keywords .iter() @@ -635,6 +693,8 @@ impl Memory for SqliteMemory { category: Option<&MemoryCategory>, session_id: Option<&str>, ) -> anyhow::Result> { + const DEFAULT_LIST_LIMIT: i64 = 1000; + let conn = self.conn.clone(); let category = category.cloned(); let session_id = session_id.map(String::from); @@ -660,9 +720,9 @@ impl Memory for SqliteMemory { let cat_str = Self::category_to_str(cat); let mut stmt = conn.prepare( "SELECT id, key, content, category, created_at, session_id FROM memories - WHERE category = ?1 ORDER BY updated_at DESC", + WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2", )?; - let rows = stmt.query_map(params![cat_str], row_mapper)?; + let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?; for row in rows { let entry = row?; if let Some(sid) = session_ref { @@ -675,9 +735,9 @@ impl Memory for SqliteMemory { } else { let mut stmt = conn.prepare( "SELECT id, key, content, category, created_at, session_id FROM memories - ORDER BY updated_at DESC", + ORDER BY updated_at DESC LIMIT ?1", )?; - let rows = stmt.query_map([], row_mapper)?; + let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?; for row in rows { let entry = row?; if let Some(sid) = session_ref {