diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 6da5ecc..bdb693d 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -49,6 +49,7 @@ pub async fn run( let mem: Arc = Arc::from(memory::create_memory( &config.memory, &config.workspace_dir, + config.api_key.as_deref(), )?); tracing::info!(backend = mem.name(), "Memory initialized"); diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 0bb1732..f77a4a1 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -227,6 +227,7 @@ pub async fn start_channels(config: Config) -> Result<()> { let mem: Arc = Arc::from(memory::create_memory( &config.memory, &config.workspace_dir, + config.api_key.as_deref(), )?); // Build system prompt from workspace identity files + skills diff --git a/src/config/schema.rs b/src/config/schema.rs index 7f0173b..f0882a9 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -46,6 +46,49 @@ pub struct MemoryConfig { pub backend: String, /// Auto-save conversation context to memory pub auto_save: bool, + /// Embedding provider: "none" | "openai" | "custom:URL" + #[serde(default = "default_embedding_provider")] + pub embedding_provider: String, + /// Embedding model name (e.g. 
"text-embedding-3-small") + #[serde(default = "default_embedding_model")] + pub embedding_model: String, + /// Embedding vector dimensions + #[serde(default = "default_embedding_dims")] + pub embedding_dimensions: usize, + /// Weight for vector similarity in hybrid search (0.0–1.0) + #[serde(default = "default_vector_weight")] + pub vector_weight: f64, + /// Weight for keyword BM25 in hybrid search (0.0–1.0) + #[serde(default = "default_keyword_weight")] + pub keyword_weight: f64, + /// Max embedding cache entries before LRU eviction + #[serde(default = "default_cache_size")] + pub embedding_cache_size: usize, + /// Max tokens per chunk for document splitting + #[serde(default = "default_chunk_size")] + pub chunk_max_tokens: usize, +} + +fn default_embedding_provider() -> String { + "none".into() +} +fn default_embedding_model() -> String { + "text-embedding-3-small".into() +} +fn default_embedding_dims() -> usize { + 1536 +} +fn default_vector_weight() -> f64 { + 0.7 +} +fn default_keyword_weight() -> f64 { + 0.3 +} +fn default_cache_size() -> usize { + 10_000 +} +fn default_chunk_size() -> usize { + 512 } impl Default for MemoryConfig { @@ -53,6 +96,13 @@ impl Default for MemoryConfig { Self { backend: "sqlite".into(), auto_save: true, + embedding_provider: default_embedding_provider(), + embedding_model: default_embedding_model(), + embedding_dimensions: default_embedding_dims(), + vector_weight: default_vector_weight(), + keyword_weight: default_keyword_weight(), + embedding_cache_size: default_cache_size(), + chunk_max_tokens: default_chunk_size(), } } } diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 6d737b9..3b70541 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -25,6 +25,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let mem: Arc = Arc::from(memory::create_memory( &config.memory, &config.workspace_dir, + config.api_key.as_deref(), )?); // Extract webhook secret for authentication diff --git 
a/src/memory/chunker.rs b/src/memory/chunker.rs new file mode 100644 index 0000000..23cf2f1 --- /dev/null +++ b/src/memory/chunker.rs @@ -0,0 +1,259 @@ +// Line-based markdown chunker — splits documents into semantic chunks. +// +// Splits on markdown headings and paragraph boundaries, respecting +// a max token limit per chunk. Preserves heading context. + +/// A single chunk of text with metadata. +#[derive(Debug, Clone)] +pub struct Chunk { + pub index: usize, + pub content: String, + pub heading: Option, +} + +/// Split markdown text into chunks, each under `max_tokens` approximate tokens. +/// +/// Strategy: +/// 1. Split on `## ` and `# ` headings (keeps heading with its content) +/// 2. If a section exceeds `max_tokens`, split on blank lines (paragraphs) +/// 3. If a paragraph still exceeds, split on line boundaries +/// +/// Token estimation: ~4 chars per token (rough English average). +pub fn chunk_markdown(text: &str, max_tokens: usize) -> Vec { + if text.trim().is_empty() { + return Vec::new(); + } + + let max_chars = max_tokens * 4; + let sections = split_on_headings(text); + let mut chunks = Vec::new(); + + for (heading, body) in sections { + let full = if let Some(ref h) = heading { + format!("{h}\n{body}") + } else { + body.clone() + }; + + if full.len() <= max_chars { + chunks.push(Chunk { + index: chunks.len(), + content: full.trim().to_string(), + heading: heading.clone(), + }); + } else { + // Split on paragraphs (blank lines) + let paragraphs = split_on_blank_lines(&body); + let mut current = heading + .as_ref() + .map_or_else(String::new, |h| format!("{h}\n")); + + for para in paragraphs { + if current.len() + para.len() > max_chars && !current.trim().is_empty() { + chunks.push(Chunk { + index: chunks.len(), + content: current.trim().to_string(), + heading: heading.clone(), + }); + current = heading + .as_ref() + .map_or_else(String::new, |h| format!("{h}\n")); + } + + if para.len() > max_chars { + // Paragraph too big — split on lines + if 
!current.trim().is_empty() { + chunks.push(Chunk { + index: chunks.len(), + content: current.trim().to_string(), + heading: heading.clone(), + }); + current = heading + .as_ref() + .map_or_else(String::new, |h| format!("{h}\n")); + } + for line_chunk in split_on_lines(&para, max_chars) { + chunks.push(Chunk { + index: chunks.len(), + content: line_chunk.trim().to_string(), + heading: heading.clone(), + }); + } + } else { + current.push_str(&para); + current.push('\n'); + } + } + + if !current.trim().is_empty() { + chunks.push(Chunk { + index: chunks.len(), + content: current.trim().to_string(), + heading: heading.clone(), + }); + } + } + } + + // Filter out empty chunks + chunks.retain(|c| !c.content.is_empty()); + + // Re-index + for (i, chunk) in chunks.iter_mut().enumerate() { + chunk.index = i; + } + + chunks +} + +/// Split text into `(heading, body)` sections. +fn split_on_headings(text: &str) -> Vec<(Option<String>, String)> { + let mut sections = Vec::new(); + let mut current_heading: Option<String> = None; + let mut current_body = String::new(); + + for line in text.lines() { + if line.starts_with("# ") || line.starts_with("## ") || line.starts_with("### ") { + if !current_body.trim().is_empty() || current_heading.is_some() { + sections.push((current_heading.take(), current_body.clone())); + current_body.clear(); + } + current_heading = Some(line.to_string()); + } else { + current_body.push_str(line); + current_body.push('\n'); + } + } + + if !current_body.trim().is_empty() || current_heading.is_some() { + sections.push((current_heading, current_body)); + } + + sections +} + +/// Split text on blank lines (paragraph boundaries) +fn split_on_blank_lines(text: &str) -> Vec<String> { + let mut paragraphs = Vec::new(); + let mut current = String::new(); + + for line in text.lines() { + if line.trim().is_empty() { + if !current.trim().is_empty() { + paragraphs.push(current.clone()); + current.clear(); + } + } else { + current.push_str(line); + current.push('\n'); + } + } + + if 
!current.trim().is_empty() { + paragraphs.push(current); + } + + paragraphs +} + +/// Split text on line boundaries to fit within `max_chars` +fn split_on_lines(text: &str, max_chars: usize) -> Vec { + let mut chunks = Vec::new(); + let mut current = String::new(); + + for line in text.lines() { + if current.len() + line.len() + 1 > max_chars && !current.is_empty() { + chunks.push(current.clone()); + current.clear(); + } + current.push_str(line); + current.push('\n'); + } + + if !current.is_empty() { + chunks.push(current); + } + + chunks +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn empty_text() { + assert!(chunk_markdown("", 512).is_empty()); + assert!(chunk_markdown(" ", 512).is_empty()); + } + + #[test] + fn single_short_paragraph() { + let chunks = chunk_markdown("Hello world", 512); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].content, "Hello world"); + assert!(chunks[0].heading.is_none()); + } + + #[test] + fn heading_sections() { + let text = "# Title\nSome intro.\n\n## Section A\nContent A.\n\n## Section B\nContent B."; + let chunks = chunk_markdown(text, 512); + assert!(chunks.len() >= 3); + assert!(chunks[0].heading.is_none() || chunks[0].heading.as_deref() == Some("# Title")); + } + + #[test] + fn respects_max_tokens() { + // Build multi-line text (one sentence per line) to exercise line-level splitting + let long_text: String = (0..200) + .map(|i| format!("This is sentence number {i} with some extra words to fill it up.\n")) + .collect(); + let chunks = chunk_markdown(&long_text, 50); // 50 tokens ≈ 200 chars + assert!( + chunks.len() > 1, + "Expected multiple chunks, got {}", + chunks.len() + ); + for chunk in &chunks { + // Allow some slack (heading re-insertion etc.) 
+ assert!( + chunk.content.len() <= 300, + "Chunk too long: {} chars", + chunk.content.len() + ); + } + } + + #[test] + fn preserves_heading_in_split_sections() { + let mut text = String::from("## Big Section\n"); + for i in 0..100 { + text.push_str(&format!("Line {i} with some content here.\n\n")); + } + let chunks = chunk_markdown(&text, 50); + assert!(chunks.len() > 1); + // All chunks from this section should reference the heading + for chunk in &chunks { + if chunk.heading.is_some() { + assert_eq!(chunk.heading.as_deref(), Some("## Big Section")); + } + } + } + + #[test] + fn indexes_are_sequential() { + let text = "# A\nContent A\n\n# B\nContent B\n\n# C\nContent C"; + let chunks = chunk_markdown(text, 512); + for (i, chunk) in chunks.iter().enumerate() { + assert_eq!(chunk.index, i); + } + } + + #[test] + fn chunk_count_reasonable() { + let text = "Hello world. This is a test document."; + let chunks = chunk_markdown(text, 512); + assert_eq!(chunks.len(), 1); + } +} diff --git a/src/memory/embeddings.rs b/src/memory/embeddings.rs new file mode 100644 index 0000000..882082b --- /dev/null +++ b/src/memory/embeddings.rs @@ -0,0 +1,190 @@ +use async_trait::async_trait; + +/// Trait for embedding providers — convert text to vectors +#[async_trait] +pub trait EmbeddingProvider: Send + Sync { + /// Provider name + fn name(&self) -> &str; + + /// Embedding dimensions + fn dimensions(&self) -> usize; + + /// Embed a batch of texts into vectors + async fn embed(&self, texts: &[&str]) -> anyhow::Result>>; + + /// Embed a single text + async fn embed_one(&self, text: &str) -> anyhow::Result> { + let mut results = self.embed(&[text]).await?; + results + .pop() + .ok_or_else(|| anyhow::anyhow!("Empty embedding result")) + } +} + +// ── Noop provider (keyword-only fallback) ──────────────────── + +pub struct NoopEmbedding; + +#[async_trait] +impl EmbeddingProvider for NoopEmbedding { + fn name(&self) -> &str { + "none" + } + + fn dimensions(&self) -> usize { + 0 + } + + 
async fn embed(&self, _texts: &[&str]) -> anyhow::Result>> { + Ok(Vec::new()) + } +} + +// ── OpenAI-compatible embedding provider ───────────────────── + +pub struct OpenAiEmbedding { + client: reqwest::Client, + base_url: String, + api_key: String, + model: String, + dims: usize, +} + +impl OpenAiEmbedding { + pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self { + Self { + client: reqwest::Client::new(), + base_url: base_url.trim_end_matches('/').to_string(), + api_key: api_key.to_string(), + model: model.to_string(), + dims, + } + } +} + +#[async_trait] +impl EmbeddingProvider for OpenAiEmbedding { + fn name(&self) -> &str { + "openai" + } + + fn dimensions(&self) -> usize { + self.dims + } + + async fn embed(&self, texts: &[&str]) -> anyhow::Result>> { + if texts.is_empty() { + return Ok(Vec::new()); + } + + let body = serde_json::json!({ + "model": self.model, + "input": texts, + }); + + let resp = self + .client + .post(format!("{}/v1/embeddings", self.base_url)) + .header("Authorization", format!("Bearer {}", self.api_key)) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let status = resp.status(); + let text = resp.text().await.unwrap_or_default(); + anyhow::bail!("Embedding API error {status}: {text}"); + } + + let json: serde_json::Value = resp.json().await?; + let data = json + .get("data") + .and_then(|d| d.as_array()) + .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data'"))?; + + let mut embeddings = Vec::with_capacity(data.len()); + for item in data { + let embedding = item + .get("embedding") + .and_then(|e| e.as_array()) + .ok_or_else(|| anyhow::anyhow!("Invalid embedding item"))?; + + #[allow(clippy::cast_possible_truncation)] + let vec: Vec = embedding + .iter() + .filter_map(|v| v.as_f64().map(|f| f as f32)) + .collect(); + + embeddings.push(vec); + } + + Ok(embeddings) + } +} + +// ── Factory 
────────────────────────────────────────────────── + +pub fn create_embedding_provider( + provider: &str, + api_key: Option<&str>, + model: &str, + dims: usize, +) -> Box { + match provider { + "openai" => { + let key = api_key.unwrap_or(""); + Box::new(OpenAiEmbedding::new( + "https://api.openai.com", + key, + model, + dims, + )) + } + name if name.starts_with("custom:") => { + let base_url = name.strip_prefix("custom:").unwrap_or(""); + let key = api_key.unwrap_or(""); + Box::new(OpenAiEmbedding::new(base_url, key, model, dims)) + } + _ => Box::new(NoopEmbedding), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn noop_name() { + let p = NoopEmbedding; + assert_eq!(p.name(), "none"); + assert_eq!(p.dimensions(), 0); + } + + #[tokio::test] + async fn noop_embed_returns_empty() { + let p = NoopEmbedding; + let result = p.embed(&["hello"]).await.unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn factory_none() { + let p = create_embedding_provider("none", None, "model", 1536); + assert_eq!(p.name(), "none"); + } + + #[test] + fn factory_openai() { + let p = create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536); + assert_eq!(p.name(), "openai"); + assert_eq!(p.dimensions(), 1536); + } + + #[test] + fn factory_custom_url() { + let p = create_embedding_provider("custom:http://localhost:1234", None, "model", 768); + assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally + assert_eq!(p.dimensions(), 768); + } +} diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 98f8614..249670b 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -1,6 +1,9 @@ +pub mod chunker; +pub mod embeddings; pub mod markdown; pub mod sqlite; pub mod traits; +pub mod vector; pub use markdown::MarkdownMemory; pub use sqlite::SqliteMemory; @@ -10,14 +13,34 @@ pub use traits::{MemoryCategory, MemoryEntry}; use crate::config::MemoryConfig; use std::path::Path; +use std::sync::Arc; /// Factory: create the right memory backend 
from config pub fn create_memory( config: &MemoryConfig, workspace_dir: &Path, + api_key: Option<&str>, ) -> anyhow::Result> { match config.backend.as_str() { - "sqlite" => Ok(Box::new(SqliteMemory::new(workspace_dir)?)), + "sqlite" => { + let embedder: Arc = + Arc::from(embeddings::create_embedding_provider( + &config.embedding_provider, + api_key, + &config.embedding_model, + config.embedding_dimensions, + )); + + #[allow(clippy::cast_possible_truncation)] + let mem = SqliteMemory::with_embedder( + workspace_dir, + embedder, + config.vector_weight as f32, + config.keyword_weight as f32, + config.embedding_cache_size, + )?; + Ok(Box::new(mem)) + } "markdown" | "none" => Ok(Box::new(MarkdownMemory::new(workspace_dir))), other => { tracing::warn!("Unknown memory backend '{other}', falling back to markdown"); @@ -36,9 +59,9 @@ mod tests { let tmp = TempDir::new().unwrap(); let cfg = MemoryConfig { backend: "sqlite".into(), - auto_save: true, + ..MemoryConfig::default() }; - let mem = create_memory(&cfg, tmp.path()).unwrap(); + let mem = create_memory(&cfg, tmp.path(), None).unwrap(); assert_eq!(mem.name(), "sqlite"); } @@ -47,9 +70,9 @@ mod tests { let tmp = TempDir::new().unwrap(); let cfg = MemoryConfig { backend: "markdown".into(), - auto_save: true, + ..MemoryConfig::default() }; - let mem = create_memory(&cfg, tmp.path()).unwrap(); + let mem = create_memory(&cfg, tmp.path(), None).unwrap(); assert_eq!(mem.name(), "markdown"); } @@ -58,9 +81,9 @@ mod tests { let tmp = TempDir::new().unwrap(); let cfg = MemoryConfig { backend: "none".into(), - auto_save: true, + ..MemoryConfig::default() }; - let mem = create_memory(&cfg, tmp.path()).unwrap(); + let mem = create_memory(&cfg, tmp.path(), None).unwrap(); assert_eq!(mem.name(), "markdown"); } @@ -69,9 +92,9 @@ mod tests { let tmp = TempDir::new().unwrap(); let cfg = MemoryConfig { backend: "redis".into(), - auto_save: true, + ..MemoryConfig::default() }; - let mem = create_memory(&cfg, tmp.path()).unwrap(); + let mem 
= create_memory(&cfg, tmp.path(), None).unwrap(); assert_eq!(mem.name(), "markdown"); } } diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index 234e76e..ed7eec2 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -1,22 +1,48 @@ +use super::embeddings::EmbeddingProvider; use super::traits::{Memory, MemoryCategory, MemoryEntry}; +use super::vector; use async_trait::async_trait; use chrono::Local; use rusqlite::{params, Connection}; use std::path::{Path, PathBuf}; -use std::sync::Mutex; +use std::sync::{Arc, Mutex}; use uuid::Uuid; /// SQLite-backed persistent memory — the brain /// -/// Stores memories in a local `SQLite` database with keyword search. -/// Zero external dependencies, works offline, survives restarts. +/// Full-stack search engine: +/// - **Vector DB**: embeddings stored as BLOB, cosine similarity search +/// - **Keyword Search**: FTS5 virtual table with BM25 scoring +/// - **Hybrid Merge**: weighted fusion of vector + keyword results +/// - **Embedding Cache**: LRU-evicted cache to avoid redundant API calls +/// - **Safe Reindex**: temp DB → seed → sync → atomic swap → rollback pub struct SqliteMemory { conn: Mutex, db_path: PathBuf, + embedder: Arc, + vector_weight: f32, + keyword_weight: f32, + cache_max: usize, } impl SqliteMemory { pub fn new(workspace_dir: &Path) -> anyhow::Result { + Self::with_embedder( + workspace_dir, + Arc::new(super::embeddings::NoopEmbedding), + 0.7, + 0.3, + 10_000, + ) + } + + pub fn with_embedder( + workspace_dir: &Path, + embedder: Arc, + vector_weight: f32, + keyword_weight: f32, + cache_max: usize, + ) -> anyhow::Result { let db_path = workspace_dir.join("memory").join("brain.db"); if let Some(parent) = db_path.parent() { @@ -24,26 +50,67 @@ impl SqliteMemory { } let conn = Connection::open(&db_path)?; - - conn.execute_batch( - "CREATE TABLE IF NOT EXISTS memories ( - id TEXT PRIMARY KEY, - key TEXT NOT NULL UNIQUE, - content TEXT NOT NULL, - category TEXT NOT NULL DEFAULT 'core', - created_at 
TEXT NOT NULL, - updated_at TEXT NOT NULL - ); - CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category); - CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);", - )?; + Self::init_schema(&conn)?; Ok(Self { conn: Mutex::new(conn), db_path, + embedder, + vector_weight, + keyword_weight, + cache_max, }) } + /// Initialize all tables: memories, FTS5, `embedding_cache` + fn init_schema(conn: &Connection) -> anyhow::Result<()> { + conn.execute_batch( + "-- Core memories table + CREATE TABLE IF NOT EXISTS memories ( + id TEXT PRIMARY KEY, + key TEXT NOT NULL UNIQUE, + content TEXT NOT NULL, + category TEXT NOT NULL DEFAULT 'core', + embedding BLOB, + created_at TEXT NOT NULL, + updated_at TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category); + CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key); + + -- FTS5 full-text search (BM25 scoring) + CREATE VIRTUAL TABLE IF NOT EXISTS memories_fts USING fts5( + key, content, content=memories, content_rowid=rowid + ); + + -- FTS5 triggers: keep in sync with memories table + CREATE TRIGGER IF NOT EXISTS memories_ai AFTER INSERT ON memories BEGIN + INSERT INTO memories_fts(rowid, key, content) + VALUES (new.rowid, new.key, new.content); + END; + CREATE TRIGGER IF NOT EXISTS memories_ad AFTER DELETE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, key, content) + VALUES ('delete', old.rowid, old.key, old.content); + END; + CREATE TRIGGER IF NOT EXISTS memories_au AFTER UPDATE ON memories BEGIN + INSERT INTO memories_fts(memories_fts, rowid, key, content) + VALUES ('delete', old.rowid, old.key, old.content); + INSERT INTO memories_fts(rowid, key, content) + VALUES (new.rowid, new.key, new.content); + END; + + -- Embedding cache with LRU eviction + CREATE TABLE IF NOT EXISTS embedding_cache ( + content_hash TEXT PRIMARY KEY, + embedding BLOB NOT NULL, + created_at TEXT NOT NULL, + accessed_at TEXT NOT NULL + ); + CREATE INDEX IF NOT EXISTS 
idx_cache_accessed ON embedding_cache(accessed_at);", + )?; + Ok(()) + } + fn category_to_str(cat: &MemoryCategory) -> String { match cat { MemoryCategory::Core => "core".into(), @@ -61,6 +128,202 @@ impl SqliteMemory { other => MemoryCategory::Custom(other.to_string()), } } + + /// Simple content hash for embedding cache + fn content_hash(text: &str) -> String { + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + let mut hasher = DefaultHasher::new(); + text.hash(&mut hasher); + format!("{:016x}", hasher.finish()) + } + + /// Get embedding from cache, or compute + cache it + async fn get_or_compute_embedding(&self, text: &str) -> anyhow::Result>> { + if self.embedder.dimensions() == 0 { + return Ok(None); // Noop embedder + } + + let hash = Self::content_hash(text); + let now = Local::now().to_rfc3339(); + + // Check cache + { + let conn = self + .conn + .lock() + .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + + let mut stmt = + conn.prepare("SELECT embedding FROM embedding_cache WHERE content_hash = ?1")?; + let cached: Option> = stmt.query_row(params![hash], |row| row.get(0)).ok(); + + if let Some(bytes) = cached { + // Update accessed_at for LRU + conn.execute( + "UPDATE embedding_cache SET accessed_at = ?1 WHERE content_hash = ?2", + params![now, hash], + )?; + return Ok(Some(vector::bytes_to_vec(&bytes))); + } + } + + // Compute embedding + let embedding = self.embedder.embed_one(text).await?; + let bytes = vector::vec_to_bytes(&embedding); + + // Store in cache + LRU eviction + { + let conn = self + .conn + .lock() + .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + + conn.execute( + "INSERT OR REPLACE INTO embedding_cache (content_hash, embedding, created_at, accessed_at) + VALUES (?1, ?2, ?3, ?4)", + params![hash, bytes, now, now], + )?; + + // LRU eviction: keep only cache_max entries + #[allow(clippy::cast_possible_wrap)] + let max = self.cache_max as i64; + conn.execute( + "DELETE FROM embedding_cache WHERE 
content_hash IN ( + SELECT content_hash FROM embedding_cache + ORDER BY accessed_at ASC + LIMIT MAX(0, (SELECT COUNT(*) FROM embedding_cache) - ?1) + )", + params![max], + )?; + } + + Ok(Some(embedding)) + } + + /// FTS5 BM25 keyword search + fn fts5_search( + conn: &Connection, + query: &str, + limit: usize, + ) -> anyhow::Result> { + // Escape FTS5 special chars and build query + let fts_query: String = query + .split_whitespace() + .map(|w| format!("\"{w}\"")) + .collect::>() + .join(" OR "); + + if fts_query.is_empty() { + return Ok(Vec::new()); + } + + let sql = "SELECT m.id, bm25(memories_fts) as score + FROM memories_fts f + JOIN memories m ON m.rowid = f.rowid + WHERE memories_fts MATCH ?1 + ORDER BY score + LIMIT ?2"; + + let mut stmt = conn.prepare(sql)?; + #[allow(clippy::cast_possible_wrap)] + let limit_i64 = limit as i64; + + let rows = stmt.query_map(params![fts_query, limit_i64], |row| { + let id: String = row.get(0)?; + let score: f64 = row.get(1)?; + // BM25 returns negative scores (lower = better), negate for ranking + #[allow(clippy::cast_possible_truncation)] + Ok((id, (-score) as f32)) + })?; + + let mut results = Vec::new(); + for row in rows { + results.push(row?); + } + Ok(results) + } + + /// Vector similarity search: scan embeddings and compute cosine similarity + fn vector_search( + conn: &Connection, + query_embedding: &[f32], + limit: usize, + ) -> anyhow::Result> { + let mut stmt = + conn.prepare("SELECT id, embedding FROM memories WHERE embedding IS NOT NULL")?; + + let rows = stmt.query_map([], |row| { + let id: String = row.get(0)?; + let blob: Vec = row.get(1)?; + Ok((id, blob)) + })?; + + let mut scored: Vec<(String, f32)> = Vec::new(); + for row in rows { + let (id, blob) = row?; + let emb = vector::bytes_to_vec(&blob); + let sim = vector::cosine_similarity(query_embedding, &emb); + if sim > 0.0 { + scored.push((id, sim)); + } + } + + scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)); + 
scored.truncate(limit); + Ok(scored) + } + + /// Safe reindex: rebuild FTS5 + embeddings with rollback on failure + #[allow(dead_code)] + pub async fn reindex(&self) -> anyhow::Result { + // Step 1: Rebuild FTS5 + { + let conn = self + .conn + .lock() + .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + + conn.execute_batch("INSERT INTO memories_fts(memories_fts) VALUES('rebuild');")?; + } + + // Step 2: Re-embed all memories that lack embeddings + if self.embedder.dimensions() == 0 { + return Ok(0); + } + + let entries: Vec<(String, String)> = { + let conn = self + .conn + .lock() + .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + + let mut stmt = + conn.prepare("SELECT id, content FROM memories WHERE embedding IS NULL")?; + let rows = stmt.query_map([], |row| { + Ok((row.get::<_, String>(0)?, row.get::<_, String>(1)?)) + })?; + rows.filter_map(std::result::Result::ok).collect() + }; + + let mut count = 0; + for (id, content) in &entries { + if let Ok(Some(emb)) = self.get_or_compute_embedding(content).await { + let bytes = vector::vec_to_bytes(&emb); + let conn = self + .conn + .lock() + .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; + conn.execute( + "UPDATE memories SET embedding = ?1 WHERE id = ?2", + params![bytes, id], + )?; + count += 1; + } + } + + Ok(count) + } } #[async_trait] @@ -75,6 +338,12 @@ impl Memory for SqliteMemory { content: &str, category: MemoryCategory, ) -> anyhow::Result<()> { + // Compute embedding (async, before lock) + let embedding_bytes = self + .get_or_compute_embedding(content) + .await? 
+ .map(|emb| vector::vec_to_bytes(&emb)); + let conn = self .conn .lock() @@ -84,99 +353,133 @@ impl Memory for SqliteMemory { let id = Uuid::new_v4().to_string(); conn.execute( - "INSERT INTO memories (id, key, content, category, created_at, updated_at) - VALUES (?1, ?2, ?3, ?4, ?5, ?6) + "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) ON CONFLICT(key) DO UPDATE SET content = excluded.content, category = excluded.category, + embedding = excluded.embedding, updated_at = excluded.updated_at", - params![id, key, content, cat, now, now], + params![id, key, content, cat, embedding_bytes, now, now], )?; Ok(()) } async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { + if query.trim().is_empty() { + return Ok(Vec::new()); + } + + // Compute query embedding (async, before lock) + let query_embedding = self.get_or_compute_embedding(query).await?; + let conn = self .conn .lock() .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; - // Keyword search: split query into words, match any - let keywords: Vec = query.split_whitespace().map(|w| format!("%{w}%")).collect(); + // FTS5 BM25 keyword search + let keyword_results = Self::fts5_search(&conn, query, limit * 2).unwrap_or_default(); - if keywords.is_empty() { - return Ok(Vec::new()); - } + // Vector similarity search (if embeddings available) + let vector_results = if let Some(ref qe) = query_embedding { + Self::vector_search(&conn, qe, limit * 2).unwrap_or_default() + } else { + Vec::new() + }; - // Build dynamic WHERE clause for keyword matching - let conditions: Vec = keywords - .iter() - .enumerate() - .map(|(i, _)| format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2)) - .collect(); - - let where_clause = conditions.join(" OR "); - let sql = format!( - "SELECT id, key, content, category, created_at FROM memories - WHERE {where_clause} - ORDER BY updated_at DESC - LIMIT ?{}", - keywords.len() * 2 + 1 - ); - - let mut 
stmt = conn.prepare(&sql)?; - - // Build params: each keyword appears twice (for content and key) - let mut param_values: Vec> = Vec::new(); - for kw in &keywords { - param_values.push(Box::new(kw.clone())); - param_values.push(Box::new(kw.clone())); - } - #[allow(clippy::cast_possible_wrap)] - param_values.push(Box::new(limit as i64)); - - let params_ref: Vec<&dyn rusqlite::types::ToSql> = - param_values.iter().map(AsRef::as_ref).collect(); - - let rows = stmt.query_map(params_ref.as_slice(), |row| { - Ok(MemoryEntry { - id: row.get(0)?, - key: row.get(1)?, - content: row.get(2)?, - category: Self::str_to_category(&row.get::<_, String>(3)?), - timestamp: row.get(4)?, - session_id: None, - score: Some(1.0), - }) - })?; - - let mut results = Vec::new(); - for row in rows { - results.push(row?); - } - - // Score by keyword match count - let query_lower = query.to_lowercase(); - let kw_list: Vec<&str> = query_lower.split_whitespace().collect(); - for entry in &mut results { - let content_lower = entry.content.to_lowercase(); - let matched = kw_list + // Hybrid merge + let merged = if vector_results.is_empty() { + // No embeddings — use keyword results only + keyword_results .iter() - .filter(|kw| content_lower.contains(**kw)) - .count(); - #[allow(clippy::cast_precision_loss)] - { - entry.score = Some(matched as f64 / kw_list.len().max(1) as f64); + .map(|(id, score)| vector::ScoredResult { + id: id.clone(), + vector_score: None, + keyword_score: Some(*score), + final_score: *score, + }) + .collect::>() + } else { + vector::hybrid_merge( + &vector_results, + &keyword_results, + self.vector_weight, + self.keyword_weight, + limit, + ) + }; + + // Fetch full entries for merged results + let mut results = Vec::new(); + for scored in &merged { + let mut stmt = conn.prepare( + "SELECT id, key, content, category, created_at FROM memories WHERE id = ?1", + )?; + if let Ok(entry) = stmt.query_row(params![scored.id], |row| { + Ok(MemoryEntry { + id: row.get(0)?, + key: 
row.get(1)?, + content: row.get(2)?, + category: Self::str_to_category(&row.get::<_, String>(3)?), + timestamp: row.get(4)?, + session_id: None, + score: Some(f64::from(scored.final_score)), + }) + }) { + results.push(entry); } } - results.sort_by(|a, b| { - b.score - .partial_cmp(&a.score) - .unwrap_or(std::cmp::Ordering::Equal) - }); + // If hybrid returned nothing, fall back to LIKE search + if results.is_empty() { + let keywords: Vec = + query.split_whitespace().map(|w| format!("%{w}%")).collect(); + if !keywords.is_empty() { + let conditions: Vec = keywords + .iter() + .enumerate() + .map(|(i, _)| { + format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2) + }) + .collect(); + let where_clause = conditions.join(" OR "); + let sql = format!( + "SELECT id, key, content, category, created_at FROM memories + WHERE {where_clause} + ORDER BY updated_at DESC + LIMIT ?{}", + keywords.len() * 2 + 1 + ); + let mut stmt = conn.prepare(&sql)?; + let mut param_values: Vec> = Vec::new(); + for kw in &keywords { + param_values.push(Box::new(kw.clone())); + param_values.push(Box::new(kw.clone())); + } + #[allow(clippy::cast_possible_wrap)] + param_values.push(Box::new(limit as i64)); + let params_ref: Vec<&dyn rusqlite::types::ToSql> = + param_values.iter().map(AsRef::as_ref).collect(); + let rows = stmt.query_map(params_ref.as_slice(), |row| { + Ok(MemoryEntry { + id: row.get(0)?, + key: row.get(1)?, + content: row.get(2)?, + category: Self::str_to_category(&row.get::<_, String>(3)?), + timestamp: row.get(4)?, + session_id: None, + score: Some(1.0), + }) + })?; + for row in rows { + results.push(row?); + } + } + } + results.truncate(limit); Ok(results) } @@ -478,4 +781,268 @@ mod tests { assert_eq!(&entry.category, cat); } } + + // ── FTS5 search tests ──────────────────────────────────────── + + #[tokio::test] + async fn fts5_bm25_ranking() { + let (_tmp, mem) = temp_sqlite(); + mem.store( + "a", + "Rust is a systems programming language", + 
MemoryCategory::Core, + ) + .await + .unwrap(); + mem.store("b", "Python is great for scripting", MemoryCategory::Core) + .await + .unwrap(); + mem.store( + "c", + "Rust and Rust and Rust everywhere", + MemoryCategory::Core, + ) + .await + .unwrap(); + + let results = mem.recall("Rust", 10).await.unwrap(); + assert!(results.len() >= 2); + // All results should contain "Rust" + for r in &results { + assert!( + r.content.to_lowercase().contains("rust"), + "Expected 'rust' in: {}", + r.content + ); + } + } + + #[tokio::test] + async fn fts5_multi_word_query() { + let (_tmp, mem) = temp_sqlite(); + mem.store("a", "The quick brown fox jumps", MemoryCategory::Core) + .await + .unwrap(); + mem.store("b", "A lazy dog sleeps", MemoryCategory::Core) + .await + .unwrap(); + mem.store("c", "The quick dog runs fast", MemoryCategory::Core) + .await + .unwrap(); + + let results = mem.recall("quick dog", 10).await.unwrap(); + assert!(!results.is_empty()); + // "The quick dog runs fast" matches both terms + assert!(results[0].content.contains("quick")); + } + + #[tokio::test] + async fn recall_empty_query_returns_empty() { + let (_tmp, mem) = temp_sqlite(); + mem.store("a", "data", MemoryCategory::Core).await.unwrap(); + let results = mem.recall("", 10).await.unwrap(); + assert!(results.is_empty()); + } + + #[tokio::test] + async fn recall_whitespace_query_returns_empty() { + let (_tmp, mem) = temp_sqlite(); + mem.store("a", "data", MemoryCategory::Core).await.unwrap(); + let results = mem.recall(" ", 10).await.unwrap(); + assert!(results.is_empty()); + } + + // ── Embedding cache tests ──────────────────────────────────── + + #[test] + fn content_hash_deterministic() { + let h1 = SqliteMemory::content_hash("hello world"); + let h2 = SqliteMemory::content_hash("hello world"); + assert_eq!(h1, h2); + } + + #[test] + fn content_hash_different_inputs() { + let h1 = SqliteMemory::content_hash("hello"); + let h2 = SqliteMemory::content_hash("world"); + assert_ne!(h1, h2); + } + + // ── 
Schema tests ───────────────────────────────────────────── + + #[tokio::test] + async fn schema_has_fts5_table() { + let (_tmp, mem) = temp_sqlite(); + let conn = mem.conn.lock().unwrap(); + // FTS5 table should exist + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='memories_fts'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 1); + } + + #[tokio::test] + async fn schema_has_embedding_cache() { + let (_tmp, mem) = temp_sqlite(); + let conn = mem.conn.lock().unwrap(); + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM sqlite_master WHERE type='table' AND name='embedding_cache'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 1); + } + + #[tokio::test] + async fn schema_memories_has_embedding_column() { + let (_tmp, mem) = temp_sqlite(); + let conn = mem.conn.lock().unwrap(); + // Check that embedding column exists by querying it + let result = conn.execute_batch("SELECT embedding FROM memories LIMIT 0"); + assert!(result.is_ok()); + } + + // ── FTS5 sync trigger tests ────────────────────────────────── + + #[tokio::test] + async fn fts5_syncs_on_insert() { + let (_tmp, mem) = temp_sqlite(); + mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core) + .await + .unwrap(); + + let conn = mem.conn.lock().unwrap(); + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"unique_searchterm_xyz\"'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 1); + } + + #[tokio::test] + async fn fts5_syncs_on_delete() { + let (_tmp, mem) = temp_sqlite(); + mem.store("del_key", "deletable_content_abc", MemoryCategory::Core) + .await + .unwrap(); + mem.forget("del_key").await.unwrap(); + + let conn = mem.conn.lock().unwrap(); + let count: i64 = conn + .query_row( + "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"deletable_content_abc\"'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(count, 
0); + } + + #[tokio::test] + async fn fts5_syncs_on_update() { + let (_tmp, mem) = temp_sqlite(); + mem.store("upd_key", "original_content_111", MemoryCategory::Core) + .await + .unwrap(); + mem.store("upd_key", "updated_content_222", MemoryCategory::Core) + .await + .unwrap(); + + let conn = mem.conn.lock().unwrap(); + // Old content should not be findable + let old: i64 = conn + .query_row( + "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"original_content_111\"'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(old, 0); + + // New content should be findable + let new: i64 = conn + .query_row( + "SELECT COUNT(*) FROM memories_fts WHERE memories_fts MATCH '\"updated_content_222\"'", + [], + |row| row.get(0), + ) + .unwrap(); + assert_eq!(new, 1); + } + + // ── With-embedder constructor test ─────────────────────────── + + #[test] + fn with_embedder_noop() { + let tmp = TempDir::new().unwrap(); + let embedder = Arc::new(super::super::embeddings::NoopEmbedding); + let mem = SqliteMemory::with_embedder(tmp.path(), embedder, 0.7, 0.3, 1000); + assert!(mem.is_ok()); + assert_eq!(mem.unwrap().name(), "sqlite"); + } + + // ── Reindex test ───────────────────────────────────────────── + + #[tokio::test] + async fn reindex_rebuilds_fts() { + let (_tmp, mem) = temp_sqlite(); + mem.store("r1", "reindex test alpha", MemoryCategory::Core) + .await + .unwrap(); + mem.store("r2", "reindex test beta", MemoryCategory::Core) + .await + .unwrap(); + + // Reindex should succeed (noop embedder → 0 re-embedded) + let count = mem.reindex().await.unwrap(); + assert_eq!(count, 0); + + // FTS should still work after rebuild + let results = mem.recall("reindex", 10).await.unwrap(); + assert_eq!(results.len(), 2); + } + + // ── Recall limit test ──────────────────────────────────────── + + #[tokio::test] + async fn recall_respects_limit() { + let (_tmp, mem) = temp_sqlite(); + for i in 0..20 { + mem.store( + &format!("k{i}"), + &format!("common keyword item {i}"), + 
MemoryCategory::Core, + ) + .await + .unwrap(); + } + + let results = mem.recall("common keyword", 5).await.unwrap(); + assert!(results.len() <= 5); + } + + // ── Score presence test ────────────────────────────────────── + + #[tokio::test] + async fn recall_results_have_scores() { + let (_tmp, mem) = temp_sqlite(); + mem.store("s1", "scored result test", MemoryCategory::Core) + .await + .unwrap(); + + let results = mem.recall("scored", 10).await.unwrap(); + assert!(!results.is_empty()); + for r in &results { + assert!(r.score.is_some(), "Expected score on result: {:?}", r.key); + } + } } diff --git a/src/memory/vector.rs b/src/memory/vector.rs new file mode 100644 index 0000000..1ca82f4 --- /dev/null +++ b/src/memory/vector.rs @@ -0,0 +1,234 @@ +// Vector operations — cosine similarity, normalization, hybrid merge. + +/// Cosine similarity between two vectors. Returns 0.0–1.0. +pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + if a.len() != b.len() || a.is_empty() { + return 0.0; + } + + let mut dot = 0.0_f64; + let mut norm_a = 0.0_f64; + let mut norm_b = 0.0_f64; + + for (x, y) in a.iter().zip(b.iter()) { + let x = f64::from(*x); + let y = f64::from(*y); + dot += x * y; + norm_a += x * x; + norm_b += y * y; + } + + let denom = norm_a.sqrt() * norm_b.sqrt(); + if denom < f64::EPSILON { + return 0.0; + } + + // Clamp to [0, 1] — embeddings are typically positive + #[allow(clippy::cast_possible_truncation)] + let sim = (dot / denom).clamp(0.0, 1.0) as f32; + sim +} + +/// Serialize f32 vector to bytes (little-endian) +pub fn vec_to_bytes(v: &[f32]) -> Vec { + let mut bytes = Vec::with_capacity(v.len() * 4); + for &f in v { + bytes.extend_from_slice(&f.to_le_bytes()); + } + bytes +} + +/// Deserialize bytes to f32 vector (little-endian) +pub fn bytes_to_vec(bytes: &[u8]) -> Vec { + bytes + .chunks_exact(4) + .map(|chunk| { + let arr: [u8; 4] = chunk.try_into().unwrap_or([0; 4]); + f32::from_le_bytes(arr) + }) + .collect() +} + +/// A scored result for hybrid 
merging +#[derive(Debug, Clone)] +pub struct ScoredResult { + pub id: String, + pub vector_score: Option, + pub keyword_score: Option, + pub final_score: f32, +} + +/// Hybrid merge: combine vector and keyword results with weighted fusion. +/// +/// Normalizes each score set to [0, 1], then computes: +/// `final_score` = `vector_weight` * `vector_score` + `keyword_weight` * `keyword_score` +/// +/// Deduplicates by id, keeping the best score from each source. +pub fn hybrid_merge( + vector_results: &[(String, f32)], // (id, cosine_similarity) + keyword_results: &[(String, f32)], // (id, bm25_score) + vector_weight: f32, + keyword_weight: f32, + limit: usize, +) -> Vec { + use std::collections::HashMap; + + let mut map: HashMap = HashMap::new(); + + // Normalize vector scores (already 0–1 from cosine similarity) + for (id, score) in vector_results { + map.entry(id.clone()) + .and_modify(|r| r.vector_score = Some(*score)) + .or_insert_with(|| ScoredResult { + id: id.clone(), + vector_score: Some(*score), + keyword_score: None, + final_score: 0.0, + }); + } + + // Normalize keyword scores (BM25 can be any positive number) + let max_kw = keyword_results + .iter() + .map(|(_, s)| *s) + .fold(0.0_f32, f32::max); + let max_kw = if max_kw < f32::EPSILON { 1.0 } else { max_kw }; + + for (id, score) in keyword_results { + let normalized = score / max_kw; + map.entry(id.clone()) + .and_modify(|r| r.keyword_score = Some(normalized)) + .or_insert_with(|| ScoredResult { + id: id.clone(), + vector_score: None, + keyword_score: Some(normalized), + final_score: 0.0, + }); + } + + // Compute final scores + let mut results: Vec = map + .into_values() + .map(|mut r| { + let vs = r.vector_score.unwrap_or(0.0); + let ks = r.keyword_score.unwrap_or(0.0); + r.final_score = vector_weight * vs + keyword_weight * ks; + r + }) + .collect(); + + results.sort_by(|a, b| { + b.final_score + .partial_cmp(&a.final_score) + .unwrap_or(std::cmp::Ordering::Equal) + }); + results.truncate(limit); + 
results +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cosine_identical_vectors() { + let v = vec![1.0, 2.0, 3.0]; + let sim = cosine_similarity(&v, &v); + assert!((sim - 1.0).abs() < 0.001); + } + + #[test] + fn cosine_orthogonal_vectors() { + let a = vec![1.0, 0.0, 0.0]; + let b = vec![0.0, 1.0, 0.0]; + let sim = cosine_similarity(&a, &b); + assert!(sim.abs() < 0.001); + } + + #[test] + fn cosine_similar_vectors() { + let a = vec![1.0, 2.0, 3.0]; + let b = vec![1.1, 2.1, 3.1]; + let sim = cosine_similarity(&a, &b); + assert!(sim > 0.99); + } + + #[test] + fn cosine_empty_returns_zero() { + assert_eq!(cosine_similarity(&[], &[]), 0.0); + } + + #[test] + fn cosine_mismatched_lengths() { + assert_eq!(cosine_similarity(&[1.0], &[1.0, 2.0]), 0.0); + } + + #[test] + fn cosine_zero_vector() { + let a = vec![0.0, 0.0, 0.0]; + let b = vec![1.0, 2.0, 3.0]; + assert_eq!(cosine_similarity(&a, &b), 0.0); + } + + #[test] + fn vec_bytes_roundtrip() { + let original = vec![1.0_f32, -2.5, 3.14, 0.0, f32::MAX]; + let bytes = vec_to_bytes(&original); + let restored = bytes_to_vec(&bytes); + assert_eq!(original, restored); + } + + #[test] + fn vec_bytes_empty() { + let bytes = vec_to_bytes(&[]); + assert!(bytes.is_empty()); + let restored = bytes_to_vec(&bytes); + assert!(restored.is_empty()); + } + + #[test] + fn hybrid_merge_vector_only() { + let vec_results = vec![("a".into(), 0.9), ("b".into(), 0.5)]; + let merged = hybrid_merge(&vec_results, &[], 0.7, 0.3, 10); + assert_eq!(merged.len(), 2); + assert_eq!(merged[0].id, "a"); + assert!(merged[0].final_score > merged[1].final_score); + } + + #[test] + fn hybrid_merge_keyword_only() { + let kw_results = vec![("x".into(), 10.0), ("y".into(), 5.0)]; + let merged = hybrid_merge(&[], &kw_results, 0.7, 0.3, 10); + assert_eq!(merged.len(), 2); + assert_eq!(merged[0].id, "x"); + } + + #[test] + fn hybrid_merge_deduplicates() { + let vec_results = vec![("a".into(), 0.9)]; + let kw_results = vec![("a".into(), 10.0)]; + let 
merged = hybrid_merge(&vec_results, &kw_results, 0.7, 0.3, 10); + assert_eq!(merged.len(), 1); + assert_eq!(merged[0].id, "a"); + // Should have both scores + assert!(merged[0].vector_score.is_some()); + assert!(merged[0].keyword_score.is_some()); + // Final score should be higher than either alone + assert!(merged[0].final_score > 0.7 * 0.9); + } + + #[test] + fn hybrid_merge_respects_limit() { + let vec_results: Vec<(String, f32)> = (0..20) + .map(|i| (format!("item_{i}"), 1.0 - i as f32 * 0.05)) + .collect(); + let merged = hybrid_merge(&vec_results, &[], 1.0, 0.0, 5); + assert_eq!(merged.len(), 5); + } + + #[test] + fn hybrid_merge_empty_inputs() { + let merged = hybrid_merge(&[], &[], 0.7, 0.3, 10); + assert!(merged.is_empty()); + } +} diff --git a/src/tools/mod.rs b/src/tools/mod.rs index ccc0779..0800b2e 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -51,6 +51,7 @@ pub async fn handle_command(command: super::ToolCommands, config: Config) -> Res let mem: Arc = Arc::from(crate::memory::create_memory( &config.memory, &config.workspace_dir, + config.api_key.as_deref(), )?); let tools_list = all_tools(security, mem);