perf(memory): fold recall/vector/list optimizations into spawn_blocking refactor

This commit is contained in:
Chummy 2026-02-18 14:44:17 +08:00
parent 4e528dde7d
commit dd454178ed

View file

@ -326,16 +326,36 @@ impl SqliteMemory {
Ok(results)
}
/// Vector similarity search: scan embeddings and compute cosine similarity
/// Vector similarity search: scan embeddings and compute cosine similarity.
///
/// Optional `category` and `session_id` filters reduce full-table scans
/// when the caller already knows the scope of relevant memories.
fn vector_search(
conn: &Connection,
query_embedding: &[f32],
limit: usize,
category: Option<&str>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<(String, f32)>> {
let mut stmt =
conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?;
let mut sql =
"SELECT id, embedding FROM memories WHERE embedding IS NOT NULL".to_string();
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
let mut idx = 1;
let rows = stmt.query_map([], |row| {
if let Some(cat) = category {
sql.push_str(&format!(" AND category = ?{idx}"));
param_values.push(Box::new(cat.to_string()));
idx += 1;
}
if let Some(sid) = session_id {
sql.push_str(&format!(" AND session_id = ?{idx}"));
param_values.push(Box::new(sid.to_string()));
}
let mut stmt = conn.prepare(&sql)?;
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
param_values.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
let id: String = row.get(0)?;
let blob: Vec<u8> = row.get(1)?;
Ok((id, blob))
@ -485,7 +505,8 @@ impl Memory for SqliteMemory {
// Vector similarity search (if embeddings available)
let vector_results = if let Some(ref qe) = query_embedding {
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default()
Self::vector_search(&conn, qe, limit * 2, None, session_ref)
.unwrap_or_default()
} else {
Vec::new()
};
@ -511,36 +532,73 @@ impl Memory for SqliteMemory {
)
};
// Fetch full entries for merged results
// Fetch full entries for merged results in a single query
// instead of N round-trips (N+1 pattern).
let mut results = Vec::new();
if !merged.is_empty() {
let placeholders: String = (1..=merged.len())
.map(|i| format!("?{i}"))
.collect::<Vec<_>>()
.join(", ");
let sql = format!(
"SELECT id, key, content, category, created_at, session_id \
FROM memories WHERE id IN ({placeholders})"
);
let mut stmt = conn.prepare(&sql)?;
let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
.iter()
.map(|s| Box::new(s.id.clone()) as Box<dyn rusqlite::types::ToSql>)
.collect();
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
id_params.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
Ok((
row.get::<_, String>(0)?,
row.get::<_, String>(1)?,
row.get::<_, String>(2)?,
row.get::<_, String>(3)?,
row.get::<_, String>(4)?,
row.get::<_, Option<String>>(5)?,
))
})?;
let mut entry_map = std::collections::HashMap::new();
for row in rows {
let (id, key, content, cat, ts, sid) = row?;
entry_map.insert(id, (key, content, cat, ts, sid));
}
for scored in &merged {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
)?;
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: row.get(5)?,
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
let entry = MemoryEntry {
id: scored.id.clone(),
key,
content,
category: Self::str_to_category(&cat),
timestamp: ts,
session_id: sid,
score: Some(f64::from(scored.final_score)),
})
}) {
if let Some(sid) = session_ref {
if entry.session_id.as_deref() != Some(sid) {
};
if let Some(filter_sid) = session_ref {
if entry.session_id.as_deref() != Some(filter_sid) {
continue;
}
}
results.push(entry);
}
}
}
// If hybrid returned nothing, fall back to LIKE search
// If hybrid returned nothing, fall back to LIKE search.
// Cap keyword count so we don't create too many SQL shapes,
// which helps prepared-statement cache efficiency.
if results.is_empty() {
let keywords: Vec<String> =
query.split_whitespace().map(|w| format!("%{w}%")).collect();
const MAX_LIKE_KEYWORDS: usize = 8;
let keywords: Vec<String> = query
.split_whitespace()
.take(MAX_LIKE_KEYWORDS)
.map(|w| format!("%{w}%"))
.collect();
if !keywords.is_empty() {
let conditions: Vec<String> = keywords
.iter()
@ -635,6 +693,8 @@ impl Memory for SqliteMemory {
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
const DEFAULT_LIST_LIMIT: i64 = 1000;
let conn = self.conn.clone();
let category = category.cloned();
let session_id = session_id.map(String::from);
@ -660,9 +720,9 @@ impl Memory for SqliteMemory {
let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE category = ?1 ORDER BY updated_at DESC",
WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
)?;
let rows = stmt.query_map(params![cat_str], row_mapper)?;
let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows {
let entry = row?;
if let Some(sid) = session_ref {
@ -675,9 +735,9 @@ impl Memory for SqliteMemory {
} else {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at, session_id FROM memories
ORDER BY updated_at DESC",
ORDER BY updated_at DESC LIMIT ?1",
)?;
let rows = stmt.query_map([], row_mapper)?;
let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
for row in rows {
let entry = row?;
if let Some(sid) = session_ref {