perf(memory): fold recall/vector/list optimizations into spawn_blocking refactor
This commit is contained in:
parent
4e528dde7d
commit
dd454178ed
1 changed files with 92 additions and 32 deletions
|
|
@ -326,16 +326,36 @@ impl SqliteMemory {
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Vector similarity search: scan embeddings and compute cosine similarity
|
/// Vector similarity search: scan embeddings and compute cosine similarity.
|
||||||
|
///
|
||||||
|
/// Optional `category` and `session_id` filters reduce full-table scans
|
||||||
|
/// when the caller already knows the scope of relevant memories.
|
||||||
fn vector_search(
|
fn vector_search(
|
||||||
conn: &Connection,
|
conn: &Connection,
|
||||||
query_embedding: &[f32],
|
query_embedding: &[f32],
|
||||||
limit: usize,
|
limit: usize,
|
||||||
|
category: Option<&str>,
|
||||||
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<(String, f32)>> {
|
) -> anyhow::Result<Vec<(String, f32)>> {
|
||||||
let mut stmt =
|
let mut sql =
|
||||||
conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?;
|
"SELECT id, embedding FROM memories WHERE embedding IS NOT NULL".to_string();
|
||||||
|
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||||
|
let mut idx = 1;
|
||||||
|
|
||||||
let rows = stmt.query_map([], |row| {
|
if let Some(cat) = category {
|
||||||
|
sql.push_str(&format!(" AND category = ?{idx}"));
|
||||||
|
param_values.push(Box::new(cat.to_string()));
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
sql.push_str(&format!(" AND session_id = ?{idx}"));
|
||||||
|
param_values.push(Box::new(sid.to_string()));
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut stmt = conn.prepare(&sql)?;
|
||||||
|
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||||
|
param_values.iter().map(AsRef::as_ref).collect();
|
||||||
|
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||||
let id: String = row.get(0)?;
|
let id: String = row.get(0)?;
|
||||||
let blob: Vec<u8> = row.get(1)?;
|
let blob: Vec<u8> = row.get(1)?;
|
||||||
Ok((id, blob))
|
Ok((id, blob))
|
||||||
|
|
@ -485,7 +505,8 @@ impl Memory for SqliteMemory {
|
||||||
|
|
||||||
// Vector similarity search (if embeddings available)
|
// Vector similarity search (if embeddings available)
|
||||||
let vector_results = if let Some(ref qe) = query_embedding {
|
let vector_results = if let Some(ref qe) = query_embedding {
|
||||||
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default()
|
Self::vector_search(&conn, qe, limit * 2, None, session_ref)
|
||||||
|
.unwrap_or_default()
|
||||||
} else {
|
} else {
|
||||||
Vec::new()
|
Vec::new()
|
||||||
};
|
};
|
||||||
|
|
@ -511,36 +532,73 @@ impl Memory for SqliteMemory {
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
|
|
||||||
// Fetch full entries for merged results
|
// Fetch full entries for merged results in a single query
|
||||||
|
// instead of N round-trips (N+1 pattern).
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
for scored in &merged {
|
if !merged.is_empty() {
|
||||||
let mut stmt = conn.prepare(
|
let placeholders: String = (1..=merged.len())
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
.map(|i| format!("?{i}"))
|
||||||
)?;
|
.collect::<Vec<_>>()
|
||||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
.join(", ");
|
||||||
Ok(MemoryEntry {
|
let sql = format!(
|
||||||
id: row.get(0)?,
|
"SELECT id, key, content, category, created_at, session_id \
|
||||||
key: row.get(1)?,
|
FROM memories WHERE id IN ({placeholders})"
|
||||||
content: row.get(2)?,
|
);
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
let mut stmt = conn.prepare(&sql)?;
|
||||||
timestamp: row.get(4)?,
|
let id_params: Vec<Box<dyn rusqlite::types::ToSql>> = merged
|
||||||
session_id: row.get(5)?,
|
.iter()
|
||||||
score: Some(f64::from(scored.final_score)),
|
.map(|s| Box::new(s.id.clone()) as Box<dyn rusqlite::types::ToSql>)
|
||||||
})
|
.collect();
|
||||||
}) {
|
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||||
if let Some(sid) = session_ref {
|
id_params.iter().map(AsRef::as_ref).collect();
|
||||||
if entry.session_id.as_deref() != Some(sid) {
|
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||||
continue;
|
Ok((
|
||||||
|
row.get::<_, String>(0)?,
|
||||||
|
row.get::<_, String>(1)?,
|
||||||
|
row.get::<_, String>(2)?,
|
||||||
|
row.get::<_, String>(3)?,
|
||||||
|
row.get::<_, String>(4)?,
|
||||||
|
row.get::<_, Option<String>>(5)?,
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut entry_map = std::collections::HashMap::new();
|
||||||
|
for row in rows {
|
||||||
|
let (id, key, content, cat, ts, sid) = row?;
|
||||||
|
entry_map.insert(id, (key, content, cat, ts, sid));
|
||||||
|
}
|
||||||
|
|
||||||
|
for scored in &merged {
|
||||||
|
if let Some((key, content, cat, ts, sid)) = entry_map.remove(&scored.id) {
|
||||||
|
let entry = MemoryEntry {
|
||||||
|
id: scored.id.clone(),
|
||||||
|
key,
|
||||||
|
content,
|
||||||
|
category: Self::str_to_category(&cat),
|
||||||
|
timestamp: ts,
|
||||||
|
session_id: sid,
|
||||||
|
score: Some(f64::from(scored.final_score)),
|
||||||
|
};
|
||||||
|
if let Some(filter_sid) = session_ref {
|
||||||
|
if entry.session_id.as_deref() != Some(filter_sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
results.push(entry);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// If hybrid returned nothing, fall back to LIKE search
|
// If hybrid returned nothing, fall back to LIKE search.
|
||||||
|
// Cap keyword count so we don't create too many SQL shapes,
|
||||||
|
// which helps prepared-statement cache efficiency.
|
||||||
if results.is_empty() {
|
if results.is_empty() {
|
||||||
let keywords: Vec<String> =
|
const MAX_LIKE_KEYWORDS: usize = 8;
|
||||||
query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
let keywords: Vec<String> = query
|
||||||
|
.split_whitespace()
|
||||||
|
.take(MAX_LIKE_KEYWORDS)
|
||||||
|
.map(|w| format!("%{w}%"))
|
||||||
|
.collect();
|
||||||
if !keywords.is_empty() {
|
if !keywords.is_empty() {
|
||||||
let conditions: Vec<String> = keywords
|
let conditions: Vec<String> = keywords
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -635,6 +693,8 @@ impl Memory for SqliteMemory {
|
||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
session_id: Option<&str>,
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
const DEFAULT_LIST_LIMIT: i64 = 1000;
|
||||||
|
|
||||||
let conn = self.conn.clone();
|
let conn = self.conn.clone();
|
||||||
let category = category.cloned();
|
let category = category.cloned();
|
||||||
let session_id = session_id.map(String::from);
|
let session_id = session_id.map(String::from);
|
||||||
|
|
@ -660,9 +720,9 @@ impl Memory for SqliteMemory {
|
||||||
let cat_str = Self::category_to_str(cat);
|
let cat_str = Self::category_to_str(cat);
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
WHERE category = ?1 ORDER BY updated_at DESC LIMIT ?2",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
let rows = stmt.query_map(params![cat_str, DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
let entry = row?;
|
let entry = row?;
|
||||||
if let Some(sid) = session_ref {
|
if let Some(sid) = session_ref {
|
||||||
|
|
@ -675,9 +735,9 @@ impl Memory for SqliteMemory {
|
||||||
} else {
|
} else {
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
ORDER BY updated_at DESC",
|
ORDER BY updated_at DESC LIMIT ?1",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map([], row_mapper)?;
|
let rows = stmt.query_map(params![DEFAULT_LIST_LIMIT], row_mapper)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
let entry = row?;
|
let entry = row?;
|
||||||
if let Some(sid) = session_ref {
|
if let Some(sid) = session_ref {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue