perf(memory): wrap blocking SQLite calls in tokio::task::spawn_blocking
Problem: Every async fn in SqliteMemory acquired self.conn.lock() and ran synchronous rusqlite queries directly on the Tokio runtime thread. This blocks the async executor, preventing other tasks from making progress — especially harmful under concurrent recall/store load. Fix: - Change conn from Mutex<Connection> to Arc<Mutex<Connection>> so the connection handle can be cloned into spawn_blocking closures. - Wrap all synchronous database operations (store, recall, get, list, forget, count, health_check) in tokio::task::spawn_blocking. - Split get_or_compute_embedding into three phases: cache check (blocking), embedding computation (async I/O), cache store (blocking) — ensuring no lock is held across await points. - Apply the same pattern to the reindex method. The async I/O (embedding computation) remains on the Tokio runtime while all SQLite access runs on the blocking thread pool, preventing executor starvation. Ref: zeroclaw-labs/zeroclaw#710 (Item 4) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
parent
83b098d7ac
commit
4e528dde7d
1 changed files with 279 additions and 215 deletions
|
|
@ -25,7 +25,7 @@ const SQLITE_OPEN_TIMEOUT_CAP_SECS: u64 = 300;
|
||||||
/// - **Embedding Cache**: LRU-evicted cache to avoid redundant API calls
|
/// - **Embedding Cache**: LRU-evicted cache to avoid redundant API calls
|
||||||
/// - **Safe Reindex**: temp DB → seed → sync → atomic swap → rollback
|
/// - **Safe Reindex**: temp DB → seed → sync → atomic swap → rollback
|
||||||
pub struct SqliteMemory {
|
pub struct SqliteMemory {
|
||||||
conn: Mutex<Connection>,
|
conn: Arc<Mutex<Connection>>,
|
||||||
db_path: PathBuf,
|
db_path: PathBuf,
|
||||||
embedder: Arc<dyn EmbeddingProvider>,
|
embedder: Arc<dyn EmbeddingProvider>,
|
||||||
vector_weight: f32,
|
vector_weight: f32,
|
||||||
|
|
@ -83,7 +83,7 @@ impl SqliteMemory {
|
||||||
Self::init_schema(&conn)?;
|
Self::init_schema(&conn)?;
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Mutex::new(conn),
|
conn: Arc::new(Mutex::new(conn)),
|
||||||
db_path,
|
db_path,
|
||||||
embedder,
|
embedder,
|
||||||
vector_weight,
|
vector_weight,
|
||||||
|
|
@ -229,50 +229,56 @@ impl SqliteMemory {
|
||||||
let hash = Self::content_hash(text);
|
let hash = Self::content_hash(text);
|
||||||
let now = Local::now().to_rfc3339();
|
let now = Local::now().to_rfc3339();
|
||||||
|
|
||||||
// Check cache
|
// Check cache (offloaded to blocking thread)
|
||||||
{
|
let conn = self.conn.clone();
|
||||||
let conn = self.conn.lock();
|
let hash_c = hash.clone();
|
||||||
|
let now_c = now.clone();
|
||||||
|
let cached = tokio::task::spawn_blocking(move || -> anyhow::Result<Option<Vec<f32>>> {
|
||||||
|
let conn = conn.lock();
|
||||||
let mut stmt =
|
let mut stmt =
|
||||||
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
|
conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
|
||||||
let cached: Option<Vec<u8>> = stmt.query_row(params![hash], |row| row.get(0)).ok();
|
let blob: Option<Vec<u8>> = stmt.query_row(params![hash_c], |row| row.get(0)).ok();
|
||||||
|
if let Some(bytes) = blob {
|
||||||
if let Some(bytes) = cached {
|
|
||||||
// Update accessed_at for LRU
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2",
|
"UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2",
|
||||||
params![now, hash],
|
params![now_c, hash_c],
|
||||||
)?;
|
)?;
|
||||||
return Ok(Some(vector::bytes_to_vec(&bytes)));
|
return Ok(Some(vector::bytes_to_vec(&bytes)));
|
||||||
}
|
}
|
||||||
|
Ok(None)
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
|
|
||||||
|
if cached.is_some() {
|
||||||
|
return Ok(cached);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute embedding
|
// Compute embedding (async I/O)
|
||||||
let embedding = self.embedder.embed_one(text).await?;
|
let embedding = self.embedder.embed_one(text).await?;
|
||||||
let bytes = vector::vec_to_bytes(&embedding);
|
let bytes = vector::vec_to_bytes(&embedding);
|
||||||
|
|
||||||
// Store in cache + LRU eviction
|
// Store in cache + LRU eviction (offloaded to blocking thread)
|
||||||
{
|
let conn = self.conn.clone();
|
||||||
let conn = self.conn.lock();
|
#[allow(clippy::cast_possible_wrap)]
|
||||||
|
let cache_max = self.cache_max as i64;
|
||||||
|
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||||
|
let conn = conn.lock();
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
|
"INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
|
||||||
VALUES (?1, ?2, ?3, ?4)",
|
VALUES (?1, ?2, ?3, ?4)",
|
||||||
params![hash, bytes, now, now],
|
params![hash, bytes, now, now],
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
// LRU eviction: keep only cache_max entries
|
|
||||||
#[allow(clippy::cast_possible_wrap)]
|
|
||||||
let max = self.cache_max as i64;
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"DELETE FROM embedding_cache WHERE content_hash IN (
|
"DELETE FROM embedding_cache WHERE content_hash IN (
|
||||||
SELECT content_hash FROM embedding_cache
|
SELECT content_hash FROM embedding_cache
|
||||||
ORDER BY accessed_at ASC
|
ORDER BY accessed_at ASC
|
||||||
LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1)
|
LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1)
|
||||||
)",
|
)",
|
||||||
params![max],
|
params![cache_max],
|
||||||
)?;
|
)?;
|
||||||
}
|
Ok(())
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
|
|
||||||
Ok(Some(embedding))
|
Ok(Some(embedding))
|
||||||
}
|
}
|
||||||
|
|
@ -355,9 +361,13 @@ impl SqliteMemory {
|
||||||
pub async fn reindex(&self) -> anyhow::Result<usize> {
|
pub async fn reindex(&self) -> anyhow::Result<usize> {
|
||||||
// Step 1: Rebuild FTS5
|
// Step 1: Rebuild FTS5
|
||||||
{
|
{
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
|
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||||
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
|
let conn = conn.lock();
|
||||||
|
conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Step 2: Re-embed all memories that lack embeddings
|
// Step 2: Re-embed all memories that lack embeddings
|
||||||
|
|
@ -365,26 +375,33 @@ impl SqliteMemory {
|
||||||
return Ok(0);
|
return Ok(0);
|
||||||
}
|
}
|
||||||
|
|
||||||
let entries: Vec<(String, String)> = {
|
let conn = self.conn.clone();
|
||||||
let conn = self.conn.lock();
|
let entries: Vec<(String, String)> = tokio::task::spawn_blocking(move || {
|
||||||
|
let conn = conn.lock();
|
||||||
let mut stmt =
|
let mut stmt =
|
||||||
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
|
conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
|
||||||
let rows = stmt.query_map([], |row| {
|
let rows = stmt.query_map([], |row| {
|
||||||
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
|
Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
|
||||||
})?;
|
})?;
|
||||||
rows.filter_map(std::result::Result::ok).collect()
|
Ok::<_, anyhow::Error>(rows.filter_map(std::result::Result::ok).collect())
|
||||||
};
|
})
|
||||||
|
.await??;
|
||||||
|
|
||||||
let mut count = 0;
|
let mut count = 0;
|
||||||
for (id, content) in &entries {
|
for (id, content) in &entries {
|
||||||
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
|
if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
|
||||||
let bytes = vector::vec_to_bytes(&emb);
|
let bytes = vector::vec_to_bytes(&emb);
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
conn.execute(
|
let id = id.clone();
|
||||||
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
|
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||||
params![bytes, id],
|
let conn = conn.lock();
|
||||||
)?;
|
conn.execute(
|
||||||
|
"UPDATE memories SET embedding = ?1 WHERE id = ?2",
|
||||||
|
params![bytes, id],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await??;
|
||||||
count += 1;
|
count += 1;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -406,30 +423,37 @@ impl Memory for SqliteMemory {
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
session_id: Option<&str>,
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
// Compute embedding (async, before lock)
|
// Compute embedding (async, before blocking work)
|
||||||
let embedding_bytes = self
|
let embedding_bytes = self
|
||||||
.get_or_compute_embedding(content)
|
.get_or_compute_embedding(content)
|
||||||
.await?
|
.await?
|
||||||
.map(|emb| vector::vec_to_bytes(&emb));
|
.map(|emb| vector::vec_to_bytes(&emb));
|
||||||
|
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
let now = Local::now().to_rfc3339();
|
let key = key.to_string();
|
||||||
let cat = Self::category_to_str(&category);
|
let content = content.to_string();
|
||||||
let id = Uuid::new_v4().to_string();
|
let session_id = session_id.map(String::from);
|
||||||
|
|
||||||
conn.execute(
|
tokio::task::spawn_blocking(move || -> anyhow::Result<()> {
|
||||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
let conn = conn.lock();
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
let now = Local::now().to_rfc3339();
|
||||||
ON CONFLICT(key) DO UPDATE SET
|
let cat = Self::category_to_str(&category);
|
||||||
content = excluded.content,
|
let id = Uuid::new_v4().to_string();
|
||||||
category = excluded.category,
|
|
||||||
embedding = excluded.embedding,
|
|
||||||
updated_at = excluded.updated_at,
|
|
||||||
session_id = excluded.session_id",
|
|
||||||
params![id, key, content, cat, embedding_bytes, now, now, session_id],
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
conn.execute(
|
||||||
|
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||||
|
ON CONFLICT(key) DO UPDATE SET
|
||||||
|
content = excluded.content,
|
||||||
|
category = excluded.category,
|
||||||
|
embedding = excluded.embedding,
|
||||||
|
updated_at = excluded.updated_at,
|
||||||
|
session_id = excluded.session_id",
|
||||||
|
params![id, key, content, cat, embedding_bytes, now, now, session_id],
|
||||||
|
)?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(
|
async fn recall(
|
||||||
|
|
@ -442,101 +466,58 @@ impl Memory for SqliteMemory {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
// Compute query embedding (async, before lock)
|
// Compute query embedding (async, before blocking work)
|
||||||
let query_embedding = self.get_or_compute_embedding(query).await?;
|
let query_embedding = self.get_or_compute_embedding(query).await?;
|
||||||
|
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
|
let query = query.to_string();
|
||||||
|
let session_id = session_id.map(String::from);
|
||||||
|
let vector_weight = self.vector_weight;
|
||||||
|
let keyword_weight = self.keyword_weight;
|
||||||
|
|
||||||
// FTS5 BM25 keyword search
|
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
|
let conn = conn.lock();
|
||||||
|
let session_ref = session_id.as_deref();
|
||||||
|
|
||||||
// Vector similarity search (if embeddings available)
|
// FTS5 BM25 keyword search
|
||||||
let vector_results = if let Some(ref qe) = query_embedding {
|
let keyword_results =
|
||||||
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default()
|
Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default();
|
||||||
} else {
|
|
||||||
Vec::new()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Hybrid merge
|
// Vector similarity search (if embeddings available)
|
||||||
let merged = if vector_results.is_empty() {
|
let vector_results = if let Some(ref qe) = query_embedding {
|
||||||
// No embeddings — use keyword results only
|
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default()
|
||||||
keyword_results
|
} else {
|
||||||
.iter()
|
Vec::new()
|
||||||
.map(|(id, score)| vector::ScoredResult {
|
};
|
||||||
id: id.clone(),
|
|
||||||
vector_score: None,
|
|
||||||
keyword_score: Some(*score),
|
|
||||||
final_score: *score,
|
|
||||||
})
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
} else {
|
|
||||||
vector::hybrid_merge(
|
|
||||||
&vector_results,
|
|
||||||
&keyword_results,
|
|
||||||
self.vector_weight,
|
|
||||||
self.keyword_weight,
|
|
||||||
limit,
|
|
||||||
)
|
|
||||||
};
|
|
||||||
|
|
||||||
// Fetch full entries for merged results
|
// Hybrid merge
|
||||||
let mut results = Vec::new();
|
let merged = if vector_results.is_empty() {
|
||||||
for scored in &merged {
|
keyword_results
|
||||||
let mut stmt = conn.prepare(
|
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
|
||||||
)?;
|
|
||||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
|
||||||
Ok(MemoryEntry {
|
|
||||||
id: row.get(0)?,
|
|
||||||
key: row.get(1)?,
|
|
||||||
content: row.get(2)?,
|
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
|
||||||
timestamp: row.get(4)?,
|
|
||||||
session_id: row.get(5)?,
|
|
||||||
score: Some(f64::from(scored.final_score)),
|
|
||||||
})
|
|
||||||
}) {
|
|
||||||
// Filter by session_id if requested
|
|
||||||
if let Some(sid) = session_id {
|
|
||||||
if entry.session_id.as_deref() != Some(sid) {
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
results.push(entry);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// If hybrid returned nothing, fall back to LIKE search
|
|
||||||
if results.is_empty() {
|
|
||||||
let keywords: Vec<String> =
|
|
||||||
query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
|
||||||
if !keywords.is_empty() {
|
|
||||||
let conditions: Vec<String> = keywords
|
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.map(|(id, score)| vector::ScoredResult {
|
||||||
.map(|(i, _)| {
|
id: id.clone(),
|
||||||
format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2)
|
vector_score: None,
|
||||||
|
keyword_score: Some(*score),
|
||||||
|
final_score: *score,
|
||||||
})
|
})
|
||||||
.collect();
|
.collect::<Vec<_>>()
|
||||||
let where_clause = conditions.join(" OR ");
|
} else {
|
||||||
let sql = format!(
|
vector::hybrid_merge(
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
&vector_results,
|
||||||
WHERE {where_clause}
|
&keyword_results,
|
||||||
ORDER BY updated_at DESC
|
vector_weight,
|
||||||
LIMIT ?{}",
|
keyword_weight,
|
||||||
keywords.len() * 2 + 1
|
limit,
|
||||||
);
|
)
|
||||||
let mut stmt = conn.prepare(&sql)?;
|
};
|
||||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
|
||||||
for kw in &keywords {
|
// Fetch full entries for merged results
|
||||||
param_values.push(Box::new(kw.clone()));
|
let mut results = Vec::new();
|
||||||
param_values.push(Box::new(kw.clone()));
|
for scored in &merged {
|
||||||
}
|
let mut stmt = conn.prepare(
|
||||||
#[allow(clippy::cast_possible_wrap)]
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
||||||
param_values.push(Box::new(limit as i64));
|
)?;
|
||||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||||
param_values.iter().map(AsRef::as_ref).collect();
|
|
||||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
|
||||||
Ok(MemoryEntry {
|
Ok(MemoryEntry {
|
||||||
id: row.get(0)?,
|
id: row.get(0)?,
|
||||||
key: row.get(1)?,
|
key: row.get(1)?,
|
||||||
|
|
@ -544,12 +525,10 @@ impl Memory for SqliteMemory {
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: row.get(5)?,
|
session_id: row.get(5)?,
|
||||||
score: Some(1.0),
|
score: Some(f64::from(scored.final_score)),
|
||||||
})
|
})
|
||||||
})?;
|
}) {
|
||||||
for row in rows {
|
if let Some(sid) = session_ref {
|
||||||
let entry = row?;
|
|
||||||
if let Some(sid) = session_id {
|
|
||||||
if entry.session_id.as_deref() != Some(sid) {
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -557,35 +536,98 @@ impl Memory for SqliteMemory {
|
||||||
results.push(entry);
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
results.truncate(limit);
|
// If hybrid returned nothing, fall back to LIKE search
|
||||||
Ok(results)
|
if results.is_empty() {
|
||||||
|
let keywords: Vec<String> =
|
||||||
|
query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
||||||
|
if !keywords.is_empty() {
|
||||||
|
let conditions: Vec<String> = keywords
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, _)| {
|
||||||
|
format!(
|
||||||
|
"(content LIKE ?{} OR key LIKE ?{})",
|
||||||
|
i * 2 + 1,
|
||||||
|
i * 2 + 2
|
||||||
|
)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let where_clause = conditions.join(" OR ");
|
||||||
|
let sql = format!(
|
||||||
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
|
WHERE {where_clause}
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ?{}",
|
||||||
|
keywords.len() * 2 + 1
|
||||||
|
);
|
||||||
|
let mut stmt = conn.prepare(&sql)?;
|
||||||
|
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||||
|
for kw in &keywords {
|
||||||
|
param_values.push(Box::new(kw.clone()));
|
||||||
|
param_values.push(Box::new(kw.clone()));
|
||||||
|
}
|
||||||
|
#[allow(clippy::cast_possible_wrap)]
|
||||||
|
param_values.push(Box::new(limit as i64));
|
||||||
|
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||||
|
param_values.iter().map(AsRef::as_ref).collect();
|
||||||
|
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||||
|
Ok(MemoryEntry {
|
||||||
|
id: row.get(0)?,
|
||||||
|
key: row.get(1)?,
|
||||||
|
content: row.get(2)?,
|
||||||
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
|
timestamp: row.get(4)?,
|
||||||
|
session_id: row.get(5)?,
|
||||||
|
score: Some(1.0),
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
for row in rows {
|
||||||
|
let entry = row?;
|
||||||
|
if let Some(sid) = session_ref {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(entry);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results.truncate(limit);
|
||||||
|
Ok(results)
|
||||||
|
})
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
|
let key = key.to_string();
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
tokio::task::spawn_blocking(move || -> anyhow::Result<Option<MemoryEntry>> {
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
let conn = conn.lock();
|
||||||
)?;
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||||
|
)?;
|
||||||
|
|
||||||
let mut rows = stmt.query_map(params![key], |row| {
|
let mut rows = stmt.query_map(params![key], |row| {
|
||||||
Ok(MemoryEntry {
|
Ok(MemoryEntry {
|
||||||
id: row.get(0)?,
|
id: row.get(0)?,
|
||||||
key: row.get(1)?,
|
key: row.get(1)?,
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: row.get(5)?,
|
session_id: row.get(5)?,
|
||||||
score: None,
|
score: None,
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
match rows.next() {
|
match rows.next() {
|
||||||
Some(Ok(entry)) => Ok(Some(entry)),
|
Some(Ok(entry)) => Ok(Some(entry)),
|
||||||
_ => Ok(None),
|
_ => Ok(None),
|
||||||
}
|
}
|
||||||
|
})
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(
|
async fn list(
|
||||||
|
|
@ -593,73 +635,95 @@ impl Memory for SqliteMemory {
|
||||||
category: Option<&MemoryCategory>,
|
category: Option<&MemoryCategory>,
|
||||||
session_id: Option<&str>,
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
|
let category = category.cloned();
|
||||||
|
let session_id = session_id.map(String::from);
|
||||||
|
|
||||||
let mut results = Vec::new();
|
tokio::task::spawn_blocking(move || -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
let conn = conn.lock();
|
||||||
|
let session_ref = session_id.as_deref();
|
||||||
|
let mut results = Vec::new();
|
||||||
|
|
||||||
let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
|
let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
|
||||||
Ok(MemoryEntry {
|
Ok(MemoryEntry {
|
||||||
id: row.get(0)?,
|
id: row.get(0)?,
|
||||||
key: row.get(1)?,
|
key: row.get(1)?,
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: row.get(5)?,
|
session_id: row.get(5)?,
|
||||||
score: None,
|
score: None,
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(cat) = category {
|
if let Some(ref cat) = category {
|
||||||
let cat_str = Self::category_to_str(cat);
|
let cat_str = Self::category_to_str(cat);
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
let entry = row?;
|
let entry = row?;
|
||||||
if let Some(sid) = session_id {
|
if let Some(sid) = session_ref {
|
||||||
if entry.session_id.as_deref() != Some(sid) {
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
continue;
|
continue;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
results.push(entry);
|
} else {
|
||||||
}
|
let mut stmt = conn.prepare(
|
||||||
} else {
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
let mut stmt = conn.prepare(
|
ORDER BY updated_at DESC",
|
||||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
)?;
|
||||||
ORDER BY updated_at DESC",
|
let rows = stmt.query_map([], row_mapper)?;
|
||||||
)?;
|
for row in rows {
|
||||||
let rows = stmt.query_map([], row_mapper)?;
|
let entry = row?;
|
||||||
for row in rows {
|
if let Some(sid) = session_ref {
|
||||||
let entry = row?;
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
if let Some(sid) = session_id {
|
continue;
|
||||||
if entry.session_id.as_deref() != Some(sid) {
|
}
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
results.push(entry);
|
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
Ok(results)
|
Ok(results)
|
||||||
|
})
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
let key = key.to_string();
|
||||||
Ok(affected > 0)
|
|
||||||
|
tokio::task::spawn_blocking(move || -> anyhow::Result<bool> {
|
||||||
|
let conn = conn.lock();
|
||||||
|
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
|
||||||
|
Ok(affected > 0)
|
||||||
|
})
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn count(&self) -> anyhow::Result<usize> {
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
let conn = self.conn.lock();
|
let conn = self.conn.clone();
|
||||||
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
|
||||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
tokio::task::spawn_blocking(move || -> anyhow::Result<usize> {
|
||||||
Ok(count as usize)
|
let conn = conn.lock();
|
||||||
|
let count: i64 =
|
||||||
|
conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
|
||||||
|
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||||
|
Ok(count as usize)
|
||||||
|
})
|
||||||
|
.await?
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn health_check(&self) -> bool {
|
async fn health_check(&self) -> bool {
|
||||||
self.conn.lock().execute_batch("SELECT 1").is_ok()
|
let conn = self.conn.clone();
|
||||||
|
tokio::task::spawn_blocking(move || conn.lock().execute_batch("SELECT 1").is_ok())
|
||||||
|
.await
|
||||||
|
.unwrap_or(false)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue