perf(memory): fold recall/vector/list optimizations into spawn_blocking refactor
This commit is contained in:
parent
4e528dde7d
commit
dd454178ed
1 changed file with 92 additions and 32 deletions
|
|
@@ -326,16 +326,36 @@ impl SqliteMemory {
|
|||
Ok(results)
|
||||
}
|
||||
|
||||
/// Vector similarity search: scan embeddings and compute cosine similarity
|
||||
/// Vector similarity search: scan embeddings and compute cosine similarity.
|
||||
///
|
||||
/// Optional `category` and `session_id` filters reduce full-table scans
|
||||
/// when the caller already knows the scope of relevant memories.
|
||||
fn vector_search(
|
||||
conn: &Connection,
|
||||
query_embedding: &[f32],
|
||||
limit: usize,
|
||||
category: Option<&str>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<(String, f32)>> {
|
||||
let mut stmt =
|
||||
conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?;
|
||||
let mut sql =
|
||||
"SELECT id, embedding FROM memories WHERE embedding IS NOT NULL".to_string();
|
||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||
let mut idx = 1;
|
||||
|
||||
let rows = stmt.query_map([], |row| {
|
||||
if let Some(cat) = category {
|
||||
sql.push_str(&format!(" AND category = ?{idx}"));
|
||||
param_values.push(Box::new(cat.to_string()));
|
||||
idx += 1;
|
||||
}
|
||||
if let Some(sid) = session_id {
|
||||
sql.push_str(&format!(" AND session_id = ?{idx}"));
|
||||
param_values.push(Box::new(sid.to_string()));
|
||||
}
|
||||
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
param_values.iter().map(AsRef::as_ref).collect();
|
||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||
let id: String = row.get(0)?;
|
||||
let blob: Vec<u8> = row.get(1)?;
|
||||
Ok((id, blob))
|
||||
|
|
@@ -485,7 +505,8 @@ impl Memory for SqliteMemory {
|
|||
|
||||
// Vector similarity search (if embeddings available)
|
||||
let vector_results = if let Some(ref qe) = query_embedding {
|
||||
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default()
|
||||
Self::vector_search(&conn, qe, limit * 2, None, session_ref)
|
||||
.unwrap_or_default()
|
||||
} else {
|
||||
Vec::new()
|
||||
};
|
||||
|
|
@@ -511,36 +532,73 @@ impl Memory for SqliteMemory {
|
|||
)
|
||||
};
|
||||
|
||||
// Fetch full entries for merged results
|
||||
// Fetch full entries for merged results in a single query
|
||||
// instead of N round-trips (N+1 pattern).
|
||||
let mut results = Vec::new();
|
||||
if !merged.is_empty() {
|
||||
let placeholders: String = (1..=merged.len())
|
||||
.map(|i| format!("?{i}"))
|
||||
.collect::<Vec<_>>()
|
||||
.join(", ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at, session_id \
|
||||
FROM memories WHERE id IN ({placeholders})"
|
||||
);
|
||||
let mut stmt = conn.prepare(&sql)?;
|
||||
let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
|
||||
.iter()
|
||||
.map(|s| Box::new(s.id.clone()) as Box<dyn rusqlite::types::ToSql>)
|
||||
.collect();
|
||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||
id_params.iter().map(AsRef::as_ref).collect();
|
||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||
Ok((
|
||||
row.get::<_, String>(0)?,
|
||||
row.get::<_, String>(1)?,
|
||||
row.get::<_, String>(2)?,
|
||||
row.get::<_, String>(3)?,
|
||||
row.get::<_, String>(4)?,
|
||||
row.get::<_, Option<String>>(5)?,
|
||||
))
|
||||
})?;
|
||||
|
||||
let mut entry_map = std::collections::HashMap::new();
|
||||
for row in rows {
|
||||
let (id, key, content, cat, ts, sid) = row?;
|
||||
entry_map.insert(id, (key, content, cat, ts, sid));
|
||||
}
|
||||
|
||||
for scored in &merged {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
||||
)?;
|
||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||
Ok(MemoryEntry {
|
||||
id: row.get(0)?,
|
||||
key: row.get(1)?,
|
||||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: row.get(5)?,
|
||||
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
|
||||
let entry = MemoryEntry {
|
||||
id: scored.id.clone(),
|
||||
key,
|
||||
content,
|
||||
category: Self::str_to_category(&cat),
|
||||
timestamp: ts,
|
||||
session_id: sid,
|
||||
score: Some(f64::from(scored.final_score)),
|
||||
})
|
||||
}) {
|
||||
if let Some(sid) = session_ref {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
};
|
||||
if let Some(filter_sid) = session_ref {
|
||||
if entry.session_id.as_deref() != Some(filter_sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If hybrid returned nothing, fall back to LIKE search
|
||||
// If hybrid returned nothing, fall back to LIKE search.
|
||||
// Cap keyword count so we don't create too many SQL shapes,
|
||||
// which helps prepared-statement cache efficiency.
|
||||
if results.is_empty() {
|
||||
let keywords: Vec<String> =
|
||||
query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
||||
const MAX_LIKE_KEYWORDS: usize = 8;
|
||||
let keywords: Vec<String> = query
|
||||
.split_whitespace()
|
||||
.take(MAX_LIKE_KEYWORDS)
|
||||
.map(|w| format!("%{w}%"))
|
||||
.collect();
|
||||
if !keywords.is_empty() {
|
||||
let conditions: Vec<String> = keywords
|
||||
.iter()
|
||||
|
|
@@ -635,6 +693,8 @@ impl Memory for SqliteMemory {
|
|||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
const DEFAULT_LIST_LIMIT: i64 = 1000;
|
||||
|
||||
let conn = self.conn.clone();
|
||||
let category = category.cloned();
|
||||
let session_id = session_id.map(String::from);
|
||||
|
|
@@ -660,9 +720,9 @@ impl Memory for SqliteMemory {
|
|||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||
WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||
let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||
for row in rows {
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_ref {
|
||||
|
|
@@ -675,9 +735,9 @@ impl Memory for SqliteMemory {
|
|||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
ORDER BY updated_at DESC",
|
||||
ORDER BY updated_at DESC LIMIT ?1",
|
||||
)?;
|
||||
let rows = stmt.query_map([], row_mapper)?;
|
||||
let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||
for row in rows {
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_ref {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue