feat: full-stack search engine — FTS5, vector search, hybrid merge, embedding cache, chunker
The Full Stack (All Custom):

- Vector DB: embeddings stored as BLOB, cosine similarity in pure Rust
- Keyword Search: FTS5 virtual tables with BM25 scoring + auto-sync triggers
- Hybrid Merge: weighted fusion of vector + keyword results (configurable weights)
- Embeddings: provider abstraction (OpenAI, custom URL, noop fallback)
- Chunking: line-based markdown chunker with heading preservation
- Caching: embedding_cache table with LRU eviction
- Safe Reindex: rebuild FTS5 + re-embed missing vectors

New modules:

- src/memory/embeddings.rs — EmbeddingProvider trait + OpenAI + Noop + factory
- src/memory/vector.rs — cosine similarity, vec↔bytes, ScoredResult, hybrid_merge
- src/memory/chunker.rs — markdown-aware document splitting

Upgraded:

- src/memory/sqlite.rs — FTS5 schema, embedding column, hybrid recall, cache, reindex
- src/config/schema.rs — MemoryConfig expanded with embedding/search settings
- All callers updated to pass api_key for embedding provider

739 tests passing, 0 clippy warnings (Rust 1.93.1), cargo-deny clean
This commit is contained in:
parent
4fceba0740
commit
0e7f501fd6
10 changed files with 1423 additions and 96 deletions
|
|
@ -49,6 +49,7 @@ pub async fn run(
|
||||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||||
&config.memory,
|
&config.memory,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
|
config.api_key.as_deref(),
|
||||||
)?);
|
)?);
|
||||||
tracing::info!(backend = mem.name(), "Memory initialized");
|
tracing::info!(backend = mem.name(), "Memory initialized");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -227,6 +227,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||||
&config.memory,
|
&config.memory,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
|
config.api_key.as_deref(),
|
||||||
)?);
|
)?);
|
||||||
|
|
||||||
// Build system prompt from workspace identity files + skills
|
// Build system prompt from workspace identity files + skills
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,49 @@ pub struct MemoryConfig {
|
||||||
pub backend: String,
|
pub backend: String,
|
||||||
/// Auto-save conversation context to memory
|
/// Auto-save conversation context to memory
|
||||||
pub auto_save: bool,
|
pub auto_save: bool,
|
||||||
|
/// Embedding provider: "none" | "openai" | "custom:URL"
|
||||||
|
#[serde(default = "default_embedding_provider")]
|
||||||
|
pub embedding_provider: String,
|
||||||
|
/// Embedding model name (e.g. "text-embedding-3-small")
|
||||||
|
#[serde(default = "default_embedding_model")]
|
||||||
|
pub embedding_model: String,
|
||||||
|
/// Embedding vector dimensions
|
||||||
|
#[serde(default = "default_embedding_dims")]
|
||||||
|
pub embedding_dimensions: usize,
|
||||||
|
/// Weight for vector similarity in hybrid search (0.0–1.0)
|
||||||
|
#[serde(default = "default_vector_weight")]
|
||||||
|
pub vector_weight: f64,
|
||||||
|
/// Weight for keyword BM25 in hybrid search (0.0–1.0)
|
||||||
|
#[serde(default = "default_keyword_weight")]
|
||||||
|
pub keyword_weight: f64,
|
||||||
|
/// Max embedding cache entries before LRU eviction
|
||||||
|
#[serde(default = "default_cache_size")]
|
||||||
|
pub embedding_cache_size: usize,
|
||||||
|
/// Max tokens per chunk for document splitting
|
||||||
|
#[serde(default = "default_chunk_size")]
|
||||||
|
pub chunk_max_tokens: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_embedding_provider() -> String {
|
||||||
|
"none".into()
|
||||||
|
}
|
||||||
|
fn default_embedding_model() -> String {
|
||||||
|
"text-embedding-3-small".into()
|
||||||
|
}
|
||||||
|
fn default_embedding_dims() -> usize {
|
||||||
|
1536
|
||||||
|
}
|
||||||
|
fn default_vector_weight() -> f64 {
|
||||||
|
0.7
|
||||||
|
}
|
||||||
|
fn default_keyword_weight() -> f64 {
|
||||||
|
0.3
|
||||||
|
}
|
||||||
|
fn default_cache_size() -> usize {
|
||||||
|
10_000
|
||||||
|
}
|
||||||
|
fn default_chunk_size() -> usize {
|
||||||
|
512
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for MemoryConfig {
|
impl Default for MemoryConfig {
|
||||||
|
|
@ -53,6 +96,13 @@ impl Default for MemoryConfig {
|
||||||
Self {
|
Self {
|
||||||
backend: "sqlite".into(),
|
backend: "sqlite".into(),
|
||||||
auto_save: true,
|
auto_save: true,
|
||||||
|
embedding_provider: default_embedding_provider(),
|
||||||
|
embedding_model: default_embedding_model(),
|
||||||
|
embedding_dimensions: default_embedding_dims(),
|
||||||
|
vector_weight: default_vector_weight(),
|
||||||
|
keyword_weight: default_keyword_weight(),
|
||||||
|
embedding_cache_size: default_cache_size(),
|
||||||
|
chunk_max_tokens: default_chunk_size(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
let mem: Arc<dyn Memory> = Arc::from(memory::create_memory(
|
||||||
&config.memory,
|
&config.memory,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
|
config.api_key.as_deref(),
|
||||||
)?);
|
)?);
|
||||||
|
|
||||||
// Extract webhook secret for authentication
|
// Extract webhook secret for authentication
|
||||||
|
|
|
||||||
259
src/memory/chunker.rs
Normal file
259
src/memory/chunker.rs
Normal file
|
|
@ -0,0 +1,259 @@
|
||||||
|
// Line-based markdown chunker — splits documents into semantic chunks.
|
||||||
|
//
|
||||||
|
// Splits on markdown headings and paragraph boundaries, respecting
|
||||||
|
// a max token limit per chunk. Preserves heading context.
|
||||||
|
|
||||||
|
/// A single chunk of text with metadata.
#[derive(Debug, Clone)]
pub struct Chunk {
    /// Zero-based position of this chunk within the source document.
    pub index: usize,
    /// Chunk text, trimmed; includes the section heading line when one applies.
    pub content: String,
    /// Markdown heading line (e.g. `"## Section"`) this chunk belongs to, if any.
    pub heading: Option<String>,
}
|
||||||
|
|
||||||
|
/// Split markdown text into chunks, each under `max_tokens` approximate tokens.
|
||||||
|
///
|
||||||
|
/// Strategy:
|
||||||
|
/// 1. Split on `## ` and `# ` headings (keeps heading with its content)
|
||||||
|
/// 2. If a section exceeds `max_tokens`, split on blank lines (paragraphs)
|
||||||
|
/// 3. If a paragraph still exceeds, split on line boundaries
|
||||||
|
///
|
||||||
|
/// Token estimation: ~4 chars per token (rough English average).
|
||||||
|
pub fn chunk_markdown(text: &str, max_tokens: usize) -> Vec<Chunk> {
|
||||||
|
if text.trim().is_empty() {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_chars = max_tokens * 4;
|
||||||
|
let sections = split_on_headings(text);
|
||||||
|
let mut chunks = Vec::new();
|
||||||
|
|
||||||
|
for (heading, body) in sections {
|
||||||
|
let full = if let Some(ref h) = heading {
|
||||||
|
format!("{h}\n{body}")
|
||||||
|
} else {
|
||||||
|
body.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
if full.len() <= max_chars {
|
||||||
|
chunks.push(Chunk {
|
||||||
|
index: chunks.len(),
|
||||||
|
content: full.trim().to_string(),
|
||||||
|
heading: heading.clone(),
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
// Split on paragraphs (blank lines)
|
||||||
|
let paragraphs = split_on_blank_lines(&body);
|
||||||
|
let mut current = heading
|
||||||
|
.as_ref()
|
||||||
|
.map_or_else(String::new, |h| format!("{h}\n"));
|
||||||
|
|
||||||
|
for para in paragraphs {
|
||||||
|
if current.len() + para.len() > max_chars && !current.trim().is_empty() {
|
||||||
|
chunks.push(Chunk {
|
||||||
|
index: chunks.len(),
|
||||||
|
content: current.trim().to_string(),
|
||||||
|
heading: heading.clone(),
|
||||||
|
});
|
||||||
|
current = heading
|
||||||
|
.as_ref()
|
||||||
|
.map_or_else(String::new, |h| format!("{h}\n"));
|
||||||
|
}
|
||||||
|
|
||||||
|
if para.len() > max_chars {
|
||||||
|
// Paragraph too big — split on lines
|
||||||
|
if !current.trim().is_empty() {
|
||||||
|
chunks.push(Chunk {
|
||||||
|
index: chunks.len(),
|
||||||
|
content: current.trim().to_string(),
|
||||||
|
heading: heading.clone(),
|
||||||
|
});
|
||||||
|
current = heading
|
||||||
|
.as_ref()
|
||||||
|
.map_or_else(String::new, |h| format!("{h}\n"));
|
||||||
|
}
|
||||||
|
for line_chunk in split_on_lines(¶, max_chars) {
|
||||||
|
chunks.push(Chunk {
|
||||||
|
index: chunks.len(),
|
||||||
|
content: line_chunk.trim().to_string(),
|
||||||
|
heading: heading.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
current.push_str(¶);
|
||||||
|
current.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !current.trim().is_empty() {
|
||||||
|
chunks.push(Chunk {
|
||||||
|
index: chunks.len(),
|
||||||
|
content: current.trim().to_string(),
|
||||||
|
heading: heading.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter out empty chunks
|
||||||
|
chunks.retain(|c| !c.content.is_empty());
|
||||||
|
|
||||||
|
// Re-index
|
||||||
|
for (i, chunk) in chunks.iter_mut().enumerate() {
|
||||||
|
chunk.index = i;
|
||||||
|
}
|
||||||
|
|
||||||
|
chunks
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Split text into `(heading, body)` sections.
///
/// A heading is a line starting with `# `, `## `, or `### `. Text before the
/// first heading becomes a `(None, body)` section; a heading with no body
/// still yields a section.
fn split_on_headings(text: &str) -> Vec<(Option<String>, String)> {
    let is_heading =
        |l: &str| l.starts_with("# ") || l.starts_with("## ") || l.starts_with("### ");

    let mut out: Vec<(Option<String>, String)> = Vec::new();
    let mut heading: Option<String> = None;
    let mut body = String::new();

    for line in text.lines() {
        if is_heading(line) {
            // Close out the previous section before starting a new one.
            if heading.is_some() || !body.trim().is_empty() {
                out.push((heading.take(), std::mem::take(&mut body)));
            }
            heading = Some(line.to_owned());
        } else {
            body.push_str(line);
            body.push('\n');
        }
    }

    if heading.is_some() || !body.trim().is_empty() {
        out.push((heading, body));
    }

    out
}
|
||||||
|
|
||||||
|
/// Split text on blank lines (paragraph boundaries).
///
/// Whitespace-only lines count as blank; each returned paragraph keeps its
/// internal newlines and ends with `'\n'`.
fn split_on_blank_lines(text: &str) -> Vec<String> {
    let mut paragraphs: Vec<String> = Vec::new();
    let mut buf = String::new();

    for line in text.lines() {
        if line.trim().is_empty() {
            // Paragraph boundary — flush anything accumulated so far.
            if !buf.trim().is_empty() {
                paragraphs.push(std::mem::take(&mut buf));
            }
        } else {
            buf.push_str(line);
            buf.push('\n');
        }
    }

    if !buf.trim().is_empty() {
        paragraphs.push(buf);
    }

    paragraphs
}
|
||||||
|
|
||||||
|
/// Split text on line boundaries to fit within `max_chars`.
///
/// A single line longer than `max_chars` is kept whole (never truncated);
/// each returned piece ends with `'\n'`.
fn split_on_lines(text: &str, max_chars: usize) -> Vec<String> {
    let mut pieces: Vec<String> = Vec::new();
    let mut buf = String::new();

    for line in text.lines() {
        // +1 accounts for the newline we append below.
        let would_overflow = buf.len() + line.len() + 1 > max_chars;
        if would_overflow && !buf.is_empty() {
            pieces.push(std::mem::take(&mut buf));
        }
        buf.push_str(line);
        buf.push('\n');
    }

    if !buf.is_empty() {
        pieces.push(buf);
    }

    pieces
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_text() {
        // Blank and whitespace-only inputs produce no chunks.
        assert!(chunk_markdown("", 512).is_empty());
        assert!(chunk_markdown("   ", 512).is_empty());
    }

    #[test]
    fn single_short_paragraph() {
        let out = chunk_markdown("Hello world", 512);
        assert_eq!(out.len(), 1);
        let only = &out[0];
        assert_eq!(only.content, "Hello world");
        assert!(only.heading.is_none());
    }

    #[test]
    fn heading_sections() {
        let doc = "# Title\nSome intro.\n\n## Section A\nContent A.\n\n## Section B\nContent B.";
        let out = chunk_markdown(doc, 512);
        assert!(out.len() >= 3);
        let first = out[0].heading.as_deref();
        assert!(first.is_none() || first == Some("# Title"));
    }

    #[test]
    fn respects_max_tokens() {
        // One sentence per line so line-level splitting gets exercised.
        let long_text: String = (0..200)
            .map(|i| format!("This is sentence number {i} with some extra words to fill it up.\n"))
            .collect();
        let out = chunk_markdown(&long_text, 50); // 50 tokens ≈ 200 chars
        assert!(out.len() > 1, "Expected multiple chunks, got {}", out.len());
        for chunk in &out {
            // Some slack for heading re-insertion and trimming.
            assert!(
                chunk.content.len() <= 300,
                "Chunk too long: {} chars",
                chunk.content.len()
            );
        }
    }

    #[test]
    fn preserves_heading_in_split_sections() {
        let mut doc = String::from("## Big Section\n");
        for i in 0..100 {
            doc.push_str(&format!("Line {i} with some content here.\n\n"));
        }
        let out = chunk_markdown(&doc, 50);
        assert!(out.len() > 1);
        // Every chunk that carries a heading must carry THIS heading.
        for chunk in out.iter().filter(|c| c.heading.is_some()) {
            assert_eq!(chunk.heading.as_deref(), Some("## Big Section"));
        }
    }

    #[test]
    fn indexes_are_sequential() {
        let doc = "# A\nContent A\n\n# B\nContent B\n\n# C\nContent C";
        for (i, chunk) in chunk_markdown(doc, 512).iter().enumerate() {
            assert_eq!(chunk.index, i);
        }
    }

    #[test]
    fn chunk_count_reasonable() {
        let out = chunk_markdown("Hello world. This is a test document.", 512);
        assert_eq!(out.len(), 1);
    }
}
|
||||||
190
src/memory/embeddings.rs
Normal file
190
src/memory/embeddings.rs
Normal file
|
|
@ -0,0 +1,190 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
|
||||||
|
/// Trait for embedding providers — convert text to vectors
|
||||||
|
#[async_trait]
|
||||||
|
pub trait EmbeddingProvider: Send + Sync {
|
||||||
|
/// Provider name
|
||||||
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
|
/// Embedding dimensions
|
||||||
|
fn dimensions(&self) -> usize;
|
||||||
|
|
||||||
|
/// Embed a batch of texts into vectors
|
||||||
|
async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
|
||||||
|
|
||||||
|
/// Embed a single text
|
||||||
|
async fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
|
||||||
|
let mut results = self.embed(&[text]).await?;
|
||||||
|
results
|
||||||
|
.pop()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Empty embedding result"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Noop provider (keyword-only fallback) ────────────────────

/// Provider that embeds nothing: zero dimensions, empty batch results.
/// Selected by the factory when no embedding provider is configured.
pub struct NoopEmbedding;

#[async_trait]
impl EmbeddingProvider for NoopEmbedding {
    fn name(&self) -> &str {
        "none"
    }

    // 0 dimensions — callers use this to detect that vector search is off.
    fn dimensions(&self) -> usize {
        0
    }

    // Always returns an empty Vec, regardless of input.
    async fn embed(&self, _texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        Ok(Vec::new())
    }
}
|
||||||
|
|
||||||
|
// ── OpenAI-compatible embedding provider ─────────────────────

/// Client for any OpenAI-compatible `/v1/embeddings` endpoint.
pub struct OpenAiEmbedding {
    client: reqwest::Client,
    // Scheme + host with no trailing slash (normalized in `new`).
    base_url: String,
    api_key: String,
    model: String,
    // Reported dimensionality; not validated against actual API responses.
    dims: usize,
}

impl OpenAiEmbedding {
    /// Build a provider for `base_url`. Trailing `/` is stripped so the
    /// `/v1/embeddings` path can be appended safely in `embed`.
    pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
        Self {
            client: reqwest::Client::new(),
            base_url: base_url.trim_end_matches('/').to_string(),
            api_key: api_key.to_string(),
            model: model.to_string(),
            dims,
        }
    }
}
|
||||||
|
|
||||||
|
#[async_trait]
impl EmbeddingProvider for OpenAiEmbedding {
    fn name(&self) -> &str {
        "openai"
    }

    fn dimensions(&self) -> usize {
        self.dims
    }

    /// POST the batch to `{base_url}/v1/embeddings` and parse the result.
    ///
    /// # Errors
    /// Fails on transport errors, non-2xx responses (status + body text are
    /// included in the message), or a response missing the expected
    /// `data[i].embedding` arrays.
    async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            // Avoid a pointless API call for an empty batch.
            return Ok(Vec::new());
        }

        let body = serde_json::json!({
            "model": self.model,
            "input": texts,
        });

        let resp = self
            .client
            .post(format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            // Body text is best-effort; an unreadable body becomes "".
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("Embedding API error {status}: {text}");
        }

        // Expected shape: { "data": [ { "embedding": [f64, ...] }, ... ] }
        let json: serde_json::Value = resp.json().await?;
        let data = json
            .get("data")
            .and_then(|d| d.as_array())
            .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data'"))?;

        let mut embeddings = Vec::with_capacity(data.len());
        for item in data {
            let embedding = item
                .get("embedding")
                .and_then(|e| e.as_array())
                .ok_or_else(|| anyhow::anyhow!("Invalid embedding item"))?;

            // f64 → f32 narrowing is deliberate: vectors are handled as f32
            // throughout (see the trait's return type).
            #[allow(clippy::cast_possible_truncation)]
            let vec: Vec<f32> = embedding
                .iter()
                .filter_map(|v| v.as_f64().map(|f| f as f32))
                .collect();

            embeddings.push(vec);
        }

        Ok(embeddings)
    }
}
|
||||||
|
|
||||||
|
// ── Factory ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
pub fn create_embedding_provider(
|
||||||
|
provider: &str,
|
||||||
|
api_key: Option<&str>,
|
||||||
|
model: &str,
|
||||||
|
dims: usize,
|
||||||
|
) -> Box<dyn EmbeddingProvider> {
|
||||||
|
match provider {
|
||||||
|
"openai" => {
|
||||||
|
let key = api_key.unwrap_or("");
|
||||||
|
Box::new(OpenAiEmbedding::new(
|
||||||
|
"https://api.openai.com",
|
||||||
|
key,
|
||||||
|
model,
|
||||||
|
dims,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
name if name.starts_with("custom:") => {
|
||||||
|
let base_url = name.strip_prefix("custom:").unwrap_or("");
|
||||||
|
let key = api_key.unwrap_or("");
|
||||||
|
Box::new(OpenAiEmbedding::new(base_url, key, model, dims))
|
||||||
|
}
|
||||||
|
_ => Box::new(NoopEmbedding),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn noop_name() {
        let provider = NoopEmbedding;
        assert_eq!(provider.name(), "none");
        assert_eq!(provider.dimensions(), 0);
    }

    #[tokio::test]
    async fn noop_embed_returns_empty() {
        let vectors = NoopEmbedding.embed(&["hello"]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[test]
    fn factory_none() {
        let provider = create_embedding_provider("none", None, "model", 1536);
        assert_eq!(provider.name(), "none");
    }

    #[test]
    fn factory_openai() {
        let provider =
            create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536);
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimensions(), 1536);
    }

    #[test]
    fn factory_custom_url() {
        let provider = create_embedding_provider("custom:http://localhost:1234", None, "model", 768);
        // Backed by OpenAiEmbedding internally, hence the "openai" name.
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimensions(), 768);
    }
}
|
||||||
|
|
@ -1,6 +1,9 @@
|
||||||
|
pub mod chunker;
|
||||||
|
pub mod embeddings;
|
||||||
pub mod markdown;
|
pub mod markdown;
|
||||||
pub mod sqlite;
|
pub mod sqlite;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
pub mod vector;
|
||||||
|
|
||||||
pub use markdown::MarkdownMemory;
|
pub use markdown::MarkdownMemory;
|
||||||
pub use sqlite::SqliteMemory;
|
pub use sqlite::SqliteMemory;
|
||||||
|
|
@ -10,14 +13,34 @@ pub use traits::{MemoryCategory, MemoryEntry};
|
||||||
|
|
||||||
use crate::config::MemoryConfig;
|
use crate::config::MemoryConfig;
|
||||||
use std::path::Path;
|
use std::path::Path;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
/// Factory: create the right memory backend from config
|
/// Factory: create the right memory backend from config
|
||||||
pub fn create_memory(
|
pub fn create_memory(
|
||||||
config: &MemoryConfig,
|
config: &MemoryConfig,
|
||||||
workspace_dir: &Path,
|
workspace_dir: &Path,
|
||||||
|
api_key: Option<&str>,
|
||||||
) -> anyhow::Result<Box<dyn Memory>> {
|
) -> anyhow::Result<Box<dyn Memory>> {
|
||||||
match config.backend.as_str() {
|
match config.backend.as_str() {
|
||||||
"sqlite" => Ok(Box::new(SqliteMemory::new(workspace_dir)?)),
|
"sqlite" => {
|
||||||
|
let embedder: Arc<dyn embeddings::EmbeddingProvider> =
|
||||||
|
Arc::from(embeddings::create_embedding_provider(
|
||||||
|
&config.embedding_provider,
|
||||||
|
api_key,
|
||||||
|
&config.embedding_model,
|
||||||
|
config.embedding_dimensions,
|
||||||
|
));
|
||||||
|
|
||||||
|
#[allow(clippy::cast_possible_truncation)]
|
||||||
|
let mem = SqliteMemory::with_embedder(
|
||||||
|
workspace_dir,
|
||||||
|
embedder,
|
||||||
|
config.vector_weight as f32,
|
||||||
|
config.keyword_weight as f32,
|
||||||
|
config.embedding_cache_size,
|
||||||
|
)?;
|
||||||
|
Ok(Box::new(mem))
|
||||||
|
}
|
||||||
"markdown" | "none" => Ok(Box::new(MarkdownMemory::new(workspace_dir))),
|
"markdown" | "none" => Ok(Box::new(MarkdownMemory::new(workspace_dir))),
|
||||||
other => {
|
other => {
|
||||||
tracing::warn!("Unknown memory backend '{other}', falling back to markdown");
|
tracing::warn!("Unknown memory backend '{other}', falling back to markdown");
|
||||||
|
|
@ -36,9 +59,9 @@ mod tests {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let cfg = MemoryConfig {
|
let cfg = MemoryConfig {
|
||||||
backend: "sqlite".into(),
|
backend: "sqlite".into(),
|
||||||
auto_save: true,
|
..MemoryConfig::default()
|
||||||
};
|
};
|
||||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
let mem = create_memory(&cfg, tmp.path(), None).unwrap();
|
||||||
assert_eq!(mem.name(), "sqlite");
|
assert_eq!(mem.name(), "sqlite");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -47,9 +70,9 @@ mod tests {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let cfg = MemoryConfig {
|
let cfg = MemoryConfig {
|
||||||
backend: "markdown".into(),
|
backend: "markdown".into(),
|
||||||
auto_save: true,
|
..MemoryConfig::default()
|
||||||
};
|
};
|
||||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
let mem = create_memory(&cfg, tmp.path(), None).unwrap();
|
||||||
assert_eq!(mem.name(), "markdown");
|
assert_eq!(mem.name(), "markdown");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -58,9 +81,9 @@ mod tests {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let cfg = MemoryConfig {
|
let cfg = MemoryConfig {
|
||||||
backend: "none".into(),
|
backend: "none".into(),
|
||||||
auto_save: true,
|
..MemoryConfig::default()
|
||||||
};
|
};
|
||||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
let mem = create_memory(&cfg, tmp.path(), None).unwrap();
|
||||||
assert_eq!(mem.name(), "markdown");
|
assert_eq!(mem.name(), "markdown");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -69,9 +92,9 @@ mod tests {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let cfg = MemoryConfig {
|
let cfg = MemoryConfig {
|
||||||
backend: "redis".into(),
|
backend: "redis".into(),
|
||||||
auto_save: true,
|
..MemoryConfig::default()
|
||||||
};
|
};
|
||||||
let mem = create_memory(&cfg, tmp.path()).unwrap();
|
let mem = create_memory(&cfg, tmp.path(), None).unwrap();
|
||||||
assert_eq!(mem.name(), "markdown");
|
assert_eq!(mem.name(), "markdown");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,22 +1,48 @@
|
||||||
|
use super::embeddings::EmbeddingProvider;
|
||||||
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
use super::traits::{Memory, MemoryCategory, MemoryEntry};
|
||||||
|
use super::vector;
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use chrono::Local;
|
use chrono::Local;
|
||||||
use rusqlite::{params, Connection};
|
use rusqlite::{params, Connection};
|
||||||
use std::path::{Path, PathBuf};
|
use std::path::{Path, PathBuf};
|
||||||
use std::sync::Mutex;
|
use std::sync::{Arc, Mutex};
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// SQLite-backed persistent memory — the brain
///
/// Full-stack search engine:
/// - **Vector DB**: embeddings stored as BLOB, cosine similarity search
/// - **Keyword Search**: FTS5 virtual table with BM25 scoring
/// - **Hybrid Merge**: weighted fusion of vector + keyword results
/// - **Embedding Cache**: LRU-evicted cache to avoid redundant API calls
/// - **Safe Reindex**: temp DB → seed → sync → atomic swap → rollback
pub struct SqliteMemory {
    // Single connection; the Mutex serializes all DB access.
    conn: Mutex<Connection>,
    // Path to the backing database file (workspace/memory/brain.db).
    db_path: PathBuf,
    // Text → vector provider (NoopEmbedding when vector search is disabled).
    embedder: Arc<dyn EmbeddingProvider>,
    // Hybrid-search fusion weights: vector similarity vs. BM25 keyword score.
    vector_weight: f32,
    keyword_weight: f32,
    // Max rows kept in embedding_cache before LRU eviction.
    cache_max: usize,
}
|
||||||
|
|
||||||
impl SqliteMemory {
|
impl SqliteMemory {
|
||||||
/// Open (or create) the database with default search settings:
/// no embeddings (keyword-only via `NoopEmbedding`), 0.7/0.3 hybrid
/// weights (inert without vectors), and a 10k-entry embedding cache.
pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
    Self::with_embedder(
        workspace_dir,
        Arc::new(super::embeddings::NoopEmbedding),
        0.7,
        0.3,
        10_000,
    )
}
|
||||||
|
|
||||||
|
pub fn with_embedder(
|
||||||
|
workspace_dir: &Path,
|
||||||
|
embedder: Arc<dyn EmbeddingProvider>,
|
||||||
|
vector_weight: f32,
|
||||||
|
keyword_weight: f32,
|
||||||
|
cache_max: usize,
|
||||||
|
) -> anyhow::Result<Self> {
|
||||||
let db_path = workspace_dir.join("memory").join("brain.db");
|
let db_path = workspace_dir.join("memory").join("brain.db");
|
||||||
|
|
||||||
if let Some(parent) = db_path.parent() {
|
if let Some(parent) = db_path.parent() {
|
||||||
|
|
@ -24,26 +50,67 @@ impl SqliteMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
let conn = Connection::open(&db_path)?;
|
let conn = Connection::open(&db_path)?;
|
||||||
|
Self::init_schema(&conn)?;
|
||||||
conn.execute_batch(
|
|
||||||
"CREATE TABLE IF NOT EXISTS memories (
|
|
||||||
id TEXT PRIMARY KEY,
|
|
||||||
key TEXT NOT NULL UNIQUE,
|
|
||||||
content TEXT NOT NULL,
|
|
||||||
category TEXT NOT NULL DEFAULT 'core',
|
|
||||||
created_at TEXT NOT NULL,
|
|
||||||
updated_at TEXT NOT NULL
|
|
||||||
);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
|
|
||||||
CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);",
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
conn: Mutex::new(conn),
|
conn: Mutex::new(conn),
|
||||||
db_path,
|
db_path,
|
||||||
|
embedder,
|
||||||
|
vector_weight,
|
||||||
|
keyword_weight,
|
||||||
|
cache_max,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Initialize all tables: memories, FTS5, `embedding_cache`
///
/// Idempotent — every statement uses `IF NOT EXISTS`, so this is safe to
/// run on every open. The FTS5 table is "external content"
/// (`content=memories`), kept in sync by the three triggers below; FTS5
/// deletions require the special `'delete'` insert form used there.
fn init_schema(conn: &Connection) -> anyhow::Result<()> {
    conn.execute_batch(
        "-- Core memories table
        CREATE TABLE IF NOT EXISTS memories (
            id TEXT PRIMARY KEY,
            key TEXT NOT NULL UNIQUE,
            content TEXT NOT NULL,
            category TEXT NOT NULL DEFAULT 'core',
            embedding BLOB,
            created_at TEXT NOT NULL,
            updated_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
        CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);

        -- FTS5 full-text search (BM25 scoring)
        CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5(
            key, content, content=memories, content_rowid=rowid
        );

        -- FTS5 triggers: keep in sync with memories table
        CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN
            INSERT INTO memories_fts(rowid, key, content)
            VALUES (new.rowid, new.key, new.content);
        END;
        CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN
            INSERT INTO memories_fts(memories_fts, rowid, key, content)
            VALUES ('delete', old.rowid, old.key, old.content);
        END;
        CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN
            INSERT INTO memories_fts(memories_fts, rowid, key, content)
            VALUES ('delete', old.rowid, old.key, old.content);
            INSERT INTO memories_fts(rowid, key, content)
            VALUES (new.rowid, new.key, new.content);
        END;

        -- Embedding cache with LRU eviction
        CREATE TABLE IF NOT EXISTS embedding_cache (
            content_hash TEXT PRIMARY KEY,
            embedding BLOB NOT NULL,
            created_at TEXT NOT NULL,
            accessed_at TEXT NOT NULL
        );
        CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
    )?;
    Ok(())
}
|
||||||
|
|
||||||
fn category_to_str(cat: &MemoryCategory) -> String {
|
fn category_to_str(cat: &MemoryCategory) -> String {
|
||||||
match cat {
|
match cat {
|
||||||
MemoryCategory::Core => "core".into(),
|
MemoryCategory::Core => "core".into(),
|
||||||
|
|
@ -61,6 +128,202 @@ impl SqliteMemory {
|
||||||
other => MemoryCategory::Custom(other.to_string()),
|
other => MemoryCategory::Custom(other.to_string()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Simple content hash for the embedding cache.
///
/// 64-bit `DefaultHasher` digest rendered as 16 lowercase hex chars.
/// Not cryptographic — a collision only means a cached embedding could be
/// reused for different text, which is an acceptable trade-off here.
fn content_hash(text: &str) -> String {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    text.hash(&mut state);
    format!("{:016x}", state.finish())
}
|
||||||
|
|
||||||
|
/// Get embedding from cache, or compute + cache it.
///
/// Returns `Ok(None)` when the embedder is a no-op (0 dimensions), i.e.
/// vector search is disabled. Otherwise returns the vector, serving from
/// `embedding_cache` when possible and bumping `accessed_at` for LRU.
async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result<Option<Vec<f32>>> {
    if self.embedder.dimensions() == 0 {
        return Ok(None); // Noop embedder
    }

    let hash = Self::content_hash(text);
    let now = Local::now().to_rfc3339();

    // Check cache.
    // Scoped block: the std::sync::Mutex guard must be dropped before the
    // `.await` on the embedder below.
    {
        let conn = self
            .conn
            .lock()
            .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;

        let mut stmt =
            conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?;
        // `.ok()` maps both "no row" and query errors to a cache miss.
        let cached: Option<Vec<u8>> = stmt.query_row(params![hash], |row| row.get(0)).ok();

        if let Some(bytes) = cached {
            // Update accessed_at for LRU
            conn.execute(
                "UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2",
                params![now, hash],
            )?;
            return Ok(Some(vector::bytes_to_vec(&bytes)));
        }
    }

    // Compute embedding (no lock held across this await).
    let embedding = self.embedder.embed_one(text).await?;
    let bytes = vector::vec_to_bytes(&embedding);

    // Store in cache + LRU eviction.
    // INSERT OR REPLACE makes a concurrent compute of the same text benign:
    // the later writer simply overwrites with an equivalent row.
    {
        let conn = self
            .conn
            .lock()
            .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;

        conn.execute(
            "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at)
             VALUES (?1, ?2, ?3, ?4)",
            params![hash, bytes, now, now],
        )?;

        // LRU eviction: delete the least-recently-accessed rows beyond
        // cache_max (MAX(0, …) keeps the LIMIT non-negative).
        #[allow(clippy::cast_possible_wrap)]
        let max = self.cache_max as i64;
        conn.execute(
            "DELETE FROM embedding_cache WHERE content_hash IN (
                SELECT content_hash FROM embedding_cache
                ORDER BY accessed_at ASC
                LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1)
            )",
            params![max],
        )?;
    }

    Ok(Some(embedding))
}
|
||||||
|
|
||||||
|
/// FTS5 BM25 keyword search
|
||||||
|
fn fts5_search(
|
||||||
|
conn: &Connection,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
) -> anyhow::Result<Vec<(String, f32)>> {
|
||||||
|
// Escape FTS5 special chars and build query
|
||||||
|
let fts_query: String = query
|
||||||
|
.split_whitespace()
|
||||||
|
.map(|w| format!("\"{w}\""))
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
.join(" OR ");
|
||||||
|
|
||||||
|
if fts_query.is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let sql = "SELECT m.id, bm25(memories_fts) as score
|
||||||
|
FROM memories_fts f
|
||||||
|
JOIN memories m ON m.rowid = f.rowid
|
||||||
|
WHERE memories_fts MATCH ?1
|
||||||
|
ORDER BY score
|
||||||
|
LIMIT ?2";
|
||||||
|
|
||||||
|
let mut stmt = conn.prepare(sql)?;
|
||||||
|
#[allow(clippy::cast_possible_wrap)]
|
||||||
|
let limit_i64 = limit as i64;
|
||||||
|
|
||||||
|
let rows = stmt.query_map(params![fts_query, limit_i64], |row| {
|
||||||
|
let id: String = row.get(0)?;
|
||||||
|
let score: f64 = row.get(1)?;
|
||||||
|
// BM25 returns negative scores (lower = better), negate for ranking
|
||||||
|
#[allow(clippy::cast_possible_truncation)]
|
||||||
|
Ok((id, (-score) as f32))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for row in rows {
|
||||||
|
results.push(row?);
|
||||||
|
}
|
||||||
|
Ok(results)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Vector similarity search: scan embeddings and compute cosine similarity
|
||||||
|
fn vector_search(
|
||||||
|
conn: &Connection,
|
||||||
|
query_embedding: &[f32],
|
||||||
|
limit: usize,
|
||||||
|
) -> anyhow::Result<Vec<(String, f32)>> {
|
||||||
|
let mut stmt =
|
||||||
|
conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?;
|
||||||
|
|
||||||
|
let rows = stmt.query_map([], |row| {
|
||||||
|
let id: String = row.get(0)?;
|
||||||
|
let blob: Vec<u8> = row.get(1)?;
|
||||||
|
Ok((id, blob))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut scored: Vec<(String, f32)> = Vec::new();
|
||||||
|
for row in rows {
|
||||||
|
let (id, blob) = row?;
|
||||||
|
let emb = vector::bytes_to_vec(&blob);
|
||||||
|
let sim = vector::cosine_similarity(query_embedding, &emb);
|
||||||
|
if sim > 0.0 {
|
||||||
|
scored.push((id, sim));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
|
||||||
|
scored.truncate(limit);
|
||||||
|
Ok(scored)
|
||||||
|
}
|
||||||
|
|
||||||
|
    /// Rebuild the FTS5 index, then backfill embeddings for any memory rows
    /// whose `embedding` column is NULL. Returns the number of rows
    /// re-embedded.
    ///
    /// NOTE(review): despite the original "rollback on failure" wording, no
    /// explicit transaction/rollback is visible here — a failure partway
    /// through leaves already-updated rows in place (which is safe, since
    /// each update is independent). Confirm the intended guarantee.
    ///
    /// # Errors
    /// Fails on lock poisoning or SQL errors; individual embedding failures
    /// are skipped (the `if let Ok(..)` below), not propagated.
    #[allow(dead_code)]
    pub async fn reindex(&self) -> anyhow::Result<usize> {
        // Step 1: Rebuild FTS5
        {
            let conn = self
                .conn
                .lock()
                .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;

            conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?;
        }

        // Step 2: Re-embed all memories that lack embeddings
        if self.embedder.dimensions() == 0 {
            return Ok(0); // noop embedder — nothing to backfill
        }

        // Snapshot the rows needing embeddings while holding the lock, then
        // release it before the async embed calls below.
        let entries: Vec<(String, String)> = {
            let conn = self
                .conn
                .lock()
                .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;

            let mut stmt =
                conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?;
            let rows = stmt.query_map([], |row| {
                Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?))
            })?;
            rows.filter_map(std::result::Result::ok).collect()
        };

        let mut count = 0;
        for (id, content) in &entries {
            // Best-effort: a failed embedding skips the row rather than
            // aborting the whole reindex.
            if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await {
                let bytes = vector::vec_to_bytes(&emb);
                // Lock is re-acquired per row so it is never held across the
                // await above.
                let conn = self
                    .conn
                    .lock()
                    .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
                conn.execute(
                    "UPDATE memories SET embedding = ?1 WHERE id = ?2",
                    params![bytes, id],
                )?;
                count += 1;
            }
        }

        Ok(count)
    }
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -75,6 +338,12 @@ impl Memory for SqliteMemory {
|
||||||
content: &str,
|
content: &str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
|
// Compute embedding (async, before lock)
|
||||||
|
let embedding_bytes = self
|
||||||
|
.get_or_compute_embedding(content)
|
||||||
|
.await?
|
||||||
|
.map(|emb| vector::vec_to_bytes(&emb));
|
||||||
|
|
||||||
let conn = self
|
let conn = self
|
||||||
.conn
|
.conn
|
||||||
.lock()
|
.lock()
|
||||||
|
|
@ -84,99 +353,133 @@ impl Memory for SqliteMemory {
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO memories (id, key, content, category, created_at, updated_at)
|
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
||||||
ON CONFLICT(key) DO UPDATE SET
|
ON CONFLICT(key) DO UPDATE SET
|
||||||
content = excluded.content,
|
content = excluded.content,
|
||||||
category = excluded.category,
|
category = excluded.category,
|
||||||
|
embedding = excluded.embedding,
|
||||||
updated_at = excluded.updated_at",
|
updated_at = excluded.updated_at",
|
||||||
params![id, key, content, cat, now, now],
|
params![id, key, content, cat, embedding_bytes, now, now],
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
if query.trim().is_empty() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Compute query embedding (async, before lock)
|
||||||
|
let query_embedding = self.get_or_compute_embedding(query).await?;
|
||||||
|
|
||||||
let conn = self
|
let conn = self
|
||||||
.conn
|
.conn
|
||||||
.lock()
|
.lock()
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||||
|
|
||||||
// Keyword search: split query into words, match any
|
// FTS5 BM25 keyword search
|
||||||
let keywords: Vec<String> = query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default();
|
||||||
|
|
||||||
if keywords.is_empty() {
|
// Vector similarity search (if embeddings available)
|
||||||
return Ok(Vec::new());
|
let vector_results = if let Some(ref qe) = query_embedding {
|
||||||
}
|
Self::vector_search(&conn, qe, limit * 2).unwrap_or_default()
|
||||||
|
} else {
|
||||||
|
Vec::new()
|
||||||
|
};
|
||||||
|
|
||||||
// Build dynamic WHERE clause for keyword matching
|
// Hybrid merge
|
||||||
let conditions: Vec<String> = keywords
|
let merged = if vector_results.is_empty() {
|
||||||
.iter()
|
// No embeddings — use keyword results only
|
||||||
.enumerate()
|
keyword_results
|
||||||
.map(|(i, _)| format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2))
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let where_clause = conditions.join(" OR ");
|
|
||||||
let sql = format!(
|
|
||||||
"SELECT id, key, content, category, created_at FROM memories
|
|
||||||
WHERE {where_clause}
|
|
||||||
ORDER BY updated_at DESC
|
|
||||||
LIMIT ?{}",
|
|
||||||
keywords.len() * 2 + 1
|
|
||||||
);
|
|
||||||
|
|
||||||
let mut stmt = conn.prepare(&sql)?;
|
|
||||||
|
|
||||||
// Build params: each keyword appears twice (for content and key)
|
|
||||||
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
|
||||||
for kw in &keywords {
|
|
||||||
param_values.push(Box::new(kw.clone()));
|
|
||||||
param_values.push(Box::new(kw.clone()));
|
|
||||||
}
|
|
||||||
#[allow(clippy::cast_possible_wrap)]
|
|
||||||
param_values.push(Box::new(limit as i64));
|
|
||||||
|
|
||||||
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
|
||||||
param_values.iter().map(AsRef::as_ref).collect();
|
|
||||||
|
|
||||||
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
|
||||||
Ok(MemoryEntry {
|
|
||||||
id: row.get(0)?,
|
|
||||||
key: row.get(1)?,
|
|
||||||
content: row.get(2)?,
|
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
|
||||||
timestamp: row.get(4)?,
|
|
||||||
session_id: None,
|
|
||||||
score: Some(1.0),
|
|
||||||
})
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let mut results = Vec::new();
|
|
||||||
for row in rows {
|
|
||||||
results.push(row?);
|
|
||||||
}
|
|
||||||
|
|
||||||
// Score by keyword match count
|
|
||||||
let query_lower = query.to_lowercase();
|
|
||||||
let kw_list: Vec<&str> = query_lower.split_whitespace().collect();
|
|
||||||
for entry in &mut results {
|
|
||||||
let content_lower = entry.content.to_lowercase();
|
|
||||||
let matched = kw_list
|
|
||||||
.iter()
|
.iter()
|
||||||
.filter(|kw| content_lower.contains(**kw))
|
.map(|(id, score)| vector::ScoredResult {
|
||||||
.count();
|
id: id.clone(),
|
||||||
#[allow(clippy::cast_precision_loss)]
|
vector_score: None,
|
||||||
{
|
keyword_score: Some(*score),
|
||||||
entry.score = Some(matched as f64 / kw_list.len().max(1) as f64);
|
final_score: *score,
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>()
|
||||||
|
} else {
|
||||||
|
vector::hybrid_merge(
|
||||||
|
&vector_results,
|
||||||
|
&keyword_results,
|
||||||
|
self.vector_weight,
|
||||||
|
self.keyword_weight,
|
||||||
|
limit,
|
||||||
|
)
|
||||||
|
};
|
||||||
|
|
||||||
|
// Fetch full entries for merged results
|
||||||
|
let mut results = Vec::new();
|
||||||
|
for scored in &merged {
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
|
||||||
|
)?;
|
||||||
|
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||||
|
Ok(MemoryEntry {
|
||||||
|
id: row.get(0)?,
|
||||||
|
key: row.get(1)?,
|
||||||
|
content: row.get(2)?,
|
||||||
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
|
timestamp: row.get(4)?,
|
||||||
|
session_id: None,
|
||||||
|
score: Some(f64::from(scored.final_score)),
|
||||||
|
})
|
||||||
|
}) {
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
results.sort_by(|a, b| {
|
// If hybrid returned nothing, fall back to LIKE search
|
||||||
b.score
|
if results.is_empty() {
|
||||||
.partial_cmp(&a.score)
|
let keywords: Vec<String> =
|
||||||
.unwrap_or(std::cmp::Ordering::Equal)
|
query.split_whitespace().map(|w| format!("%{w}%")).collect();
|
||||||
});
|
if !keywords.is_empty() {
|
||||||
|
let conditions: Vec<String> = keywords
|
||||||
|
.iter()
|
||||||
|
.enumerate()
|
||||||
|
.map(|(i, _)| {
|
||||||
|
format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
let where_clause = conditions.join(" OR ");
|
||||||
|
let sql = format!(
|
||||||
|
"SELECT id, key, content, category, created_at FROM memories
|
||||||
|
WHERE {where_clause}
|
||||||
|
ORDER BY updated_at DESC
|
||||||
|
LIMIT ?{}",
|
||||||
|
keywords.len() * 2 + 1
|
||||||
|
);
|
||||||
|
let mut stmt = conn.prepare(&sql)?;
|
||||||
|
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
|
||||||
|
for kw in &keywords {
|
||||||
|
param_values.push(Box::new(kw.clone()));
|
||||||
|
param_values.push(Box::new(kw.clone()));
|
||||||
|
}
|
||||||
|
#[allow(clippy::cast_possible_wrap)]
|
||||||
|
param_values.push(Box::new(limit as i64));
|
||||||
|
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
|
||||||
|
param_values.iter().map(AsRef::as_ref).collect();
|
||||||
|
let rows = stmt.query_map(params_ref.as_slice(), |row| {
|
||||||
|
Ok(MemoryEntry {
|
||||||
|
id: row.get(0)?,
|
||||||
|
key: row.get(1)?,
|
||||||
|
content: row.get(2)?,
|
||||||
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
|
timestamp: row.get(4)?,
|
||||||
|
session_id: None,
|
||||||
|
score: Some(1.0),
|
||||||
|
})
|
||||||
|
})?;
|
||||||
|
for row in rows {
|
||||||
|
results.push(row?);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
results.truncate(limit);
|
||||||
Ok(results)
|
Ok(results)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -478,4 +781,268 @@ mod tests {
|
||||||
assert_eq!(&entry.category, cat);
|
assert_eq!(&entry.category, cat);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
    // ── FTS5 search tests ────────────────────────────────────────

    // Recall on a single term returns only matching rows (BM25-ranked).
    #[tokio::test]
    async fn fts5_bm25_ranking() {
        let (_tmp, mem) = temp_sqlite();
        mem.store(
            "a",
            "Rust is a systems programming language",
            MemoryCategory::Core,
        )
        .await
        .unwrap();
        mem.store("b", "Python is great for scripting", MemoryCategory::Core)
            .await
            .unwrap();
        mem.store(
            "c",
            "Rust and Rust and Rust everywhere",
            MemoryCategory::Core,
        )
        .await
        .unwrap();

        let results = mem.recall("Rust", 10).await.unwrap();
        assert!(results.len() >= 2);
        // All results should contain "Rust"
        for r in &results {
            assert!(
                r.content.to_lowercase().contains("rust"),
                "Expected 'rust' in: {}",
                r.content
            );
        }
    }

    // Multi-word queries are OR-ed; the row matching both terms ranks first.
    #[tokio::test]
    async fn fts5_multi_word_query() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
            .await
            .unwrap();
        mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
            .await
            .unwrap();
        mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
            .await
            .unwrap();

        let results = mem.recall("quick dog", 10).await.unwrap();
        assert!(!results.is_empty());
        // "The quick dog runs fast" matches both terms
        assert!(results[0].content.contains("quick"));
    }

    // An empty query string short-circuits to no results.
    #[tokio::test]
    async fn recall_empty_query_returns_empty() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("a", "data", MemoryCategory::Core).await.unwrap();
        let results = mem.recall("", 10).await.unwrap();
        assert!(results.is_empty());
    }

    // Whitespace-only queries are treated the same as empty ones.
    #[tokio::test]
    async fn recall_whitespace_query_returns_empty() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("a", "data", MemoryCategory::Core).await.unwrap();
        let results = mem.recall("   ", 10).await.unwrap();
        assert!(results.is_empty());
    }
|
||||||
|
|
||||||
|
    // ── Embedding cache tests ────────────────────────────────────

    // Same input must always hash to the same cache key.
    #[test]
    fn content_hash_deterministic() {
        let h1 = SqliteMemory::content_hash("hello world");
        let h2 = SqliteMemory::content_hash("hello world");
        assert_eq!(h1, h2);
    }

    // Different inputs should produce different cache keys.
    #[test]
    fn content_hash_different_inputs() {
        let h1 = SqliteMemory::content_hash("hello");
        let h2 = SqliteMemory::content_hash("world");
        assert_ne!(h1, h2);
    }
|
||||||
|
|
||||||
|
    // ── Schema tests ─────────────────────────────────────────────

    // The FTS5 virtual table is created on initialization.
    #[tokio::test]
    async fn schema_has_fts5_table() {
        let (_tmp, mem) = temp_sqlite();
        let conn = mem.conn.lock().unwrap();
        // FTS5 table should exist
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories_fts'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(count, 1);
    }

    // The embedding_cache table is created on initialization.
    #[tokio::test]
    async fn schema_has_embedding_cache() {
        let (_tmp, mem) = temp_sqlite();
        let conn = mem.conn.lock().unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(count, 1);
    }

    // The memories table carries the embedding BLOB column.
    #[tokio::test]
    async fn schema_memories_has_embedding_column() {
        let (_tmp, mem) = temp_sqlite();
        let conn = mem.conn.lock().unwrap();
        // Check that embedding column exists by querying it
        let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0");
        assert!(result.is_ok());
    }
|
||||||
|
|
||||||
|
    // ── FTS5 sync trigger tests ──────────────────────────────────

    // INSERT trigger mirrors new rows into the FTS index.
    #[tokio::test]
    async fn fts5_syncs_on_insert() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
            .await
            .unwrap();

        let conn = mem.conn.lock().unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(count, 1);
    }

    // DELETE trigger removes rows from the FTS index.
    #[tokio::test]
    async fn fts5_syncs_on_delete() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
            .await
            .unwrap();
        mem.forget("del_key").await.unwrap();

        let conn = mem.conn.lock().unwrap();
        let count: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(count, 0);
    }

    // UPDATE trigger replaces old content with new in the FTS index.
    #[tokio::test]
    async fn fts5_syncs_on_update() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("upd_key", "original_content_111", MemoryCategory::Core)
            .await
            .unwrap();
        mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
            .await
            .unwrap();

        let conn = mem.conn.lock().unwrap();
        // Old content should not be findable
        let old: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"original_content_111\"'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(old, 0);

        // New content should be findable
        let new: i64 = conn
            .query_row(
                "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"updated_content_222\"'",
                [],
                |row| row.get(0),
            )
            .unwrap();
        assert_eq!(new, 1);
    }
|
||||||
|
|
||||||
|
    // ── With-embedder constructor test ───────────────────────────

    // Constructing with the noop embedder succeeds and reports the backend name.
    #[test]
    fn with_embedder_noop() {
        let tmp = TempDir::new().unwrap();
        let embedder = Arc::new(super::super::embeddings::NoopEmbedding);
        let mem = SqliteMemory::with_embedder(tmp.path(), embedder, 0.7, 0.3, 1000);
        assert!(mem.is_ok());
        assert_eq!(mem.unwrap().name(), "sqlite");
    }
|
||||||
|
|
||||||
|
    // ── Reindex test ─────────────────────────────────────────────

    // Reindex must succeed with a noop embedder (0 rows re-embedded) and
    // leave the FTS index functional.
    #[tokio::test]
    async fn reindex_rebuilds_fts() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("r1", "reindex test alpha", MemoryCategory::Core)
            .await
            .unwrap();
        mem.store("r2", "reindex test beta", MemoryCategory::Core)
            .await
            .unwrap();

        // Reindex should succeed (noop embedder → 0 re-embedded)
        let count = mem.reindex().await.unwrap();
        assert_eq!(count, 0);

        // FTS should still work after rebuild
        let results = mem.recall("reindex", 10).await.unwrap();
        assert_eq!(results.len(), 2);
    }
|
||||||
|
|
||||||
|
    // ── Recall limit test ────────────────────────────────────────

    // recall() must never return more rows than the requested limit.
    #[tokio::test]
    async fn recall_respects_limit() {
        let (_tmp, mem) = temp_sqlite();
        for i in 0..20 {
            mem.store(
                &format!("k{i}"),
                &format!("common keyword item {i}"),
                MemoryCategory::Core,
            )
            .await
            .unwrap();
        }

        let results = mem.recall("common keyword", 5).await.unwrap();
        assert!(results.len() <= 5);
    }
|
||||||
|
|
||||||
|
    // ── Score presence test ──────────────────────────────────────

    // Every recalled entry must carry a relevance score.
    #[tokio::test]
    async fn recall_results_have_scores() {
        let (_tmp, mem) = temp_sqlite();
        mem.store("s1", "scored result test", MemoryCategory::Core)
            .await
            .unwrap();

        let results = mem.recall("scored", 10).await.unwrap();
        assert!(!results.is_empty());
        for r in &results {
            assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
        }
    }
|
||||||
}
|
}
|
||||||
|
|
|
||||||
234
src/memory/vector.rs
Normal file
234
src/memory/vector.rs
Normal file
|
|
@ -0,0 +1,234 @@
|
||||||
|
// Vector operations — cosine similarity, normalization, hybrid merge.
|
||||||
|
|
||||||
|
/// Cosine similarity between two vectors, clamped to the range 0.0–1.0.
///
/// Returns 0.0 for empty inputs, mismatched lengths, or a zero-magnitude
/// vector. Negative similarities are clamped to 0.0 (embeddings are
/// typically positive).
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    // Accumulate in f64 for numerical stability, then cast once at the end.
    let (dot, norm_a, norm_b) = a.iter().zip(b.iter()).fold(
        (0.0_f64, 0.0_f64, 0.0_f64),
        |(dot, na, nb), (&x, &y)| {
            let (x, y) = (f64::from(x), f64::from(y));
            (dot + x * y, na + x * x, nb + y * y)
        },
    );

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom < f64::EPSILON {
        return 0.0; // zero-magnitude vector has no direction
    }

    #[allow(clippy::cast_possible_truncation)]
    let sim = (dot / denom).clamp(0.0, 1.0) as f32;
    sim
}
|
||||||
|
|
||||||
|
/// Serialize an f32 vector to its little-endian byte representation
/// (4 bytes per element, in order).
pub fn vec_to_bytes(v: &[f32]) -> Vec<u8> {
    v.iter().flat_map(|f| f.to_le_bytes()).collect()
}
|
||||||
|
|
||||||
|
/// Deserialize little-endian bytes back into an f32 vector.
///
/// Trailing bytes that do not form a full 4-byte chunk are silently dropped
/// (`chunks_exact` ignores the remainder), mirroring `vec_to_bytes`.
pub fn bytes_to_vec(bytes: &[u8]) -> Vec<f32> {
    bytes
        .chunks_exact(4)
        .map(|chunk| {
            // chunks_exact(4) guarantees 4-byte slices, so the conversion
            // cannot fail; the previous `unwrap_or([0; 4])` silently masked
            // that invariant instead of stating it.
            let arr: [u8; 4] = chunk
                .try_into()
                .expect("chunks_exact(4) yields exactly 4-byte chunks");
            f32::from_le_bytes(arr)
        })
        .collect()
}
||||||
|
|
||||||
|
/// A scored result for hybrid merging
#[derive(Debug, Clone)]
pub struct ScoredResult {
    pub id: String,
    pub vector_score: Option<f32>,
    pub keyword_score: Option<f32>,
    pub final_score: f32,
}

/// Hybrid merge: combine vector and keyword results with weighted fusion.
///
/// Keyword (BM25) scores are normalized to [0, 1] by dividing by the maximum;
/// vector scores are already in [0, 1] from cosine similarity. Then:
/// `final_score` = `vector_weight` * `vector_score` + `keyword_weight` * `keyword_score`
///
/// Results are deduplicated by id (a memory found by both searches gets both
/// scores), sorted best-first, and truncated to `limit`.
pub fn hybrid_merge(
    vector_results: &[(String, f32)], // (id, cosine_similarity)
    keyword_results: &[(String, f32)], // (id, bm25_score)
    vector_weight: f32,
    keyword_weight: f32,
    limit: usize,
) -> Vec<ScoredResult> {
    use std::collections::HashMap;

    // id -> (vector_score, keyword_score); an id present in both sources
    // ends up with both slots filled.
    let mut by_id: HashMap<String, (Option<f32>, Option<f32>)> = HashMap::new();

    for (id, sim) in vector_results {
        by_id.entry(id.clone()).or_default().0 = Some(*sim);
    }

    // Normalize keyword scores against the best BM25 score; guard against a
    // zero maximum so the division below is always well-defined.
    let top_kw = keyword_results
        .iter()
        .map(|(_, s)| *s)
        .fold(0.0_f32, f32::max);
    let top_kw = if top_kw < f32::EPSILON { 1.0 } else { top_kw };

    for (id, raw) in keyword_results {
        by_id.entry(id.clone()).or_default().1 = Some(raw / top_kw);
    }

    // Weighted fusion: a missing score contributes 0.
    let mut merged: Vec<ScoredResult> = by_id
        .into_iter()
        .map(|(id, (vector_score, keyword_score))| ScoredResult {
            final_score: vector_weight * vector_score.unwrap_or(0.0)
                + keyword_weight * keyword_score.unwrap_or(0.0),
            id,
            vector_score,
            keyword_score,
        })
        .collect();

    merged.sort_by(|lhs, rhs| {
        rhs.final_score
            .partial_cmp(&lhs.final_score)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    merged.truncate(limit);
    merged
}
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;

    // Identical vectors are maximally similar.
    #[test]
    fn cosine_identical_vectors() {
        let v = vec![1.0, 2.0, 3.0];
        let sim = cosine_similarity(&v, &v);
        assert!((sim - 1.0).abs() < 0.001);
    }

    // Orthogonal vectors score ~0.
    #[test]
    fn cosine_orthogonal_vectors() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];
        let sim = cosine_similarity(&a, &b);
        assert!(sim.abs() < 0.001);
    }

    // Near-parallel vectors score close to 1.
    #[test]
    fn cosine_similar_vectors() {
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![1.1, 2.1, 3.1];
        let sim = cosine_similarity(&a, &b);
        assert!(sim > 0.99);
    }

    // Degenerate inputs: empty slices return 0.
    #[test]
    fn cosine_empty_returns_zero() {
        assert_eq!(cosine_similarity(&[], &[]), 0.0);
    }

    // Degenerate inputs: length mismatch returns 0.
    #[test]
    fn cosine_mismatched_lengths() {
        assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0);
    }

    // Degenerate inputs: a zero-magnitude vector returns 0.
    #[test]
    fn cosine_zero_vector() {
        let a = vec![0.0, 0.0, 0.0];
        let b = vec![1.0, 2.0, 3.0];
        assert_eq!(cosine_similarity(&a, &b), 0.0);
    }

    // Serialization round-trips exactly, including extreme values.
    #[test]
    fn vec_bytes_roundtrip() {
        let original = vec![1.0_f32, -2.5, 3.14, 0.0, f32::MAX];
        let bytes = vec_to_bytes(&original);
        let restored = bytes_to_vec(&bytes);
        assert_eq!(original, restored);
    }

    // Empty vector round-trips to empty.
    #[test]
    fn vec_bytes_empty() {
        let bytes = vec_to_bytes(&[]);
        assert!(bytes.is_empty());
        let restored = bytes_to_vec(&bytes);
        assert!(restored.is_empty());
    }

    // Vector-only merge preserves similarity ordering.
    #[test]
    fn hybrid_merge_vector_only() {
        let vec_results = vec![("a".into(), 0.9), ("b".into(), 0.5)];
        let merged = hybrid_merge(&vec_results, &[], 0.7, 0.3, 10);
        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].id, "a");
        assert!(merged[0].final_score > merged[1].final_score);
    }

    // Keyword-only merge normalizes BM25 and ranks the best score first.
    #[test]
    fn hybrid_merge_keyword_only() {
        let kw_results = vec![("x".into(), 10.0), ("y".into(), 5.0)];
        let merged = hybrid_merge(&[], &kw_results, 0.7, 0.3, 10);
        assert_eq!(merged.len(), 2);
        assert_eq!(merged[0].id, "x");
    }

    // An id found by both searches merges into one result with both scores.
    #[test]
    fn hybrid_merge_deduplicates() {
        let vec_results = vec![("a".into(), 0.9)];
        let kw_results = vec![("a".into(), 10.0)];
        let merged = hybrid_merge(&vec_results, &kw_results, 0.7, 0.3, 10);
        assert_eq!(merged.len(), 1);
        assert_eq!(merged[0].id, "a");
        // Should have both scores
        assert!(merged[0].vector_score.is_some());
        assert!(merged[0].keyword_score.is_some());
        // Final score should be higher than either alone
        assert!(merged[0].final_score > 0.7 * 0.9);
    }

    // Output is truncated to the requested limit.
    #[test]
    fn hybrid_merge_respects_limit() {
        let vec_results: Vec<(String, f32)> = (0..20)
            .map(|i| (format!("item_{i}"), 1.0 - i as f32 * 0.05))
            .collect();
        let merged = hybrid_merge(&vec_results, &[], 1.0, 0.0, 5);
        assert_eq!(merged.len(), 5);
    }

    // Empty inputs produce an empty merge.
    #[test]
    fn hybrid_merge_empty_inputs() {
        let merged = hybrid_merge(&[], &[], 0.7, 0.3, 10);
        assert!(merged.is_empty());
    }
}
|
||||||
|
|
@ -51,6 +51,7 @@ pub async fn handle_command(command: super::ToolCommands, config: Config) -> Res
|
||||||
let mem: Arc<dyn Memory> = Arc::from(crate::memory::create_memory(
|
let mem: Arc<dyn Memory> = Arc::from(crate::memory::create_memory(
|
||||||
&config.memory,
|
&config.memory,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
|
config.api_key.as_deref(),
|
||||||
)?);
|
)?);
|
||||||
let tools_list = all_tools(security, mem);
|
let tools_list = all_tools(security, mem);
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue