perf(memory): fold recall/vector/list optimizations into spawn_blocking refactor

This commit is contained in:
Chummy 2026-02-18 14:44:17 +08:00
parent 4e528dde7d
commit dd454178ed

View file

@ -326,16 +326,36 @@ impl SqliteMemory {
Ok(results) Ok(results)
} }
/// Vector similarity search: scan embeddings and compute cosine similarity /// Vector similarity search: scan embeddings and compute cosine similarity.
///
/// Optional `category` and `session_id` filters reduce full-table scans
/// when the caller already knows the scope of relevant memories.
fn vector_search( fn vector_search(
conn: &Connection, conn: &Connection,
query_embedding: &[f32], query_embedding: &[f32],
limit: usize, limit: usize,
category: Option<&str>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<(String, f32)>> { ) -> anyhow::Result<Vec<(String, f32)>> {
let mut stmt = let mut sql =
conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?; "SELECT id, embedding FROM memories WHERE embedding IS NOT NULL".to_string();
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
let mut idx = 1;
let rows = stmt.query_map([], |row| { if let Some(cat) = category {
sql.push_str(&format!(" AND category = ?{idx}"));
param_values.push(Box::new(cat.to_string()));
idx += 1;
}
if let Some(sid) = session_id {
sql.push_str(&format!(" AND session_id = ?{idx}"));
param_values.push(Box::new(sid.to_string()));
}
let mut stmt = conn.prepare(&sql)?;
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
param_values.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
let id: String = row.get(0)?; let id: String = row.get(0)?;
let blob: Vec<u8> = row.get(1)?; let blob: Vec<u8> = row.get(1)?;
Ok((id, blob)) Ok((id, blob))
@ -485,7 +505,8 @@ impl Memory for SqliteMemory {
// Vector similarity search (if embeddings available) // Vector similarity search (if embeddings available)
let vector_results = if let Some(ref qe) = query_embedding { let vector_results = if let Some(ref qe) = query_embedding {
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default() Self::vector_search(&conn, qe, limit * 2, None, session_ref)
.unwrap_or_default()
} else { } else {
Vec::new() Vec::new()
}; };
@ -511,36 +532,73 @@ impl Memory for SqliteMemory {
) )
}; };
// Fetch full entries for merged results // Fetch full entries for merged results in a single query
// instead of N round-trips (N+1 pattern).
let mut results = Vec::new(); let mut results = Vec::new();
for scored in &merged { if !merged.is_empty() {
let mut stmt = conn.prepare( let placeholders: String = (1..=merged.len())
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1", .map(|i| format!("?{i}"))
)?; .collect::<Vec<_>>()
if let Ok(entry) = stmt.query_row(params![scored.id], |row| { .join(", ");
Ok(MemoryEntry { let sql = format!(
id: row.get(0)?, "SELECT id, key, content, category, created_at, session_id \
key: row.get(1)?, FROM memories WHERE id IN ({placeholders})"
content: row.get(2)?, );
category: Self::str_to_category(&row.get::<_, String>(3)?), let mut stmt = conn.prepare(&sql)?;
timestamp: row.get(4)?, let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
session_id: row.get(5)?, .iter()
score: Some(f64::from(scored.final_score)), .map(|s| Box::new(s.id.clone()) as Box<dyn rusqlite::types::ToSql>)
}) .collect();
}) { let params_ref: Vec<&dyn rusqlite::types::ToSql> =
if let Some(sid) = session_ref { id_params.iter().map(AsRef::as_ref).collect();
if entry.session_id.as_deref() != Some(sid) { let rows = stmt.query_map(params_ref.as_slice(), |row| {
continue; Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
row.get::<_, Option<String>>(5)?,
))
})?;
let mut entry_map = std::collections::HashMap::new();
for row in rows {
let (id, key, content, cat, ts, sid) = row?;
entry_map.insert(id, (key, content, cat, ts, sid));
}
for scored in &merged {
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
let entry = MemoryEntry {
id: scored.id.clone(),
key,
content,
category: Self::str_to_category(&cat),
timestamp: ts,
session_id: sid,
score: Some(f64::from(scored.final_score)),
};
if let Some(filter_sid) = session_ref {
if entry.session_id.as_deref() != Some(filter_sid) {
continue;
}
} }
results.push(entry);
} }
results.push(entry);
} }
} }
// If hybrid returned nothing, fall back to LIKE search // If hybrid returned nothing, fall back to LIKE search.
// Cap keyword count so we don't create too many SQL shapes,
// which helps prepared-statement cache efficiency.
if results.is_empty() { if results.is_empty() {
let keywords: Vec<String> = const MAX_LIKE_KEYWORDS: usize = 8;
query.split_whitespace().map(|w| format!("%{w}%")).collect(); let keywords: Vec<String> = query
.split_whitespace()
.take(MAX_LIKE_KEYWORDS)
.map(|w| format!("%{w}%"))
.collect();
if !keywords.is_empty() { if !keywords.is_empty() {
let conditions: Vec<String> = keywords let conditions: Vec<String> = keywords
.iter() .iter()
@ -635,6 +693,8 @@ impl Memory for SqliteMemory {
category: Option<&MemoryCategory>, category: Option<&MemoryCategory>,
session_id: Option<&str>, session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> { ) -> anyhow::Result<Vec<MemoryEntry>> {
const DEFAULT_LIST_LIMIT: i64 = 1000;
let conn = self.conn.clone(); let conn = self.conn.clone();
let category = category.cloned(); let category = category.cloned();
let session_id = session_id.map(String::from); let session_id = session_id.map(String::from);
@ -660,9 +720,9 @@ impl Memory for SqliteMemory {
let cat_str = Self::category_to_str(cat); let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories "SELECT id, key, content, category, created_at, session_id FROM memories
WHERE category = ?1 ORDER BY updated_at DESC", WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
)?; )?;
let rows = stmt.query_map(params![cat_str], row_mapper)?; let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows { for row in rows {
let entry = row?; let entry = row?;
if let Some(sid) = session_ref { if let Some(sid) = session_ref {
@ -675,9 +735,9 @@ impl Memory for SqliteMemory {
} else { } else {
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories "SELECT id, key, content, category, created_at, session_id FROM memories
ORDER BY updated_at DESC", ORDER BY updated_at DESC LIMIT ?1",
)?; )?;
let rows = stmt.query_map([], row_mapper)?; let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows { for row in rows {
let entry = row?; let entry = row?;
if let Some(sid) = session_ref { if let Some(sid) = session_ref {