diff --git a/src/memory/chunker.rs b/src/memory/chunker.rs index 23cf2f1..d45eb4b 100644 --- a/src/memory/chunker.rs +++ b/src/memory/chunker.rs @@ -256,4 +256,112 @@ mod tests { let chunks = chunk_markdown(text, 512); assert_eq!(chunks.len(), 1); } + + // ── Edge cases ─────────────────────────────────────────────── + + #[test] + fn headings_only_no_body() { + let text = "# Title\n## Section A\n## Section B\n### Subsection"; + let chunks = chunk_markdown(text, 512); + // Should produce chunks for each heading (even with empty bodies) + assert!(!chunks.is_empty()); + } + + #[test] + fn deeply_nested_headings_ignored() { + // #### and deeper are NOT treated as heading splits + let text = "# Top\nIntro\n#### Deep heading\nDeep content"; + let chunks = chunk_markdown(text, 512); + // "#### Deep heading" should stay with its parent section + assert!(!chunks.is_empty()); + let all_content: String = chunks.iter().map(|c| c.content.clone()).collect(); + assert!(all_content.contains("Deep heading")); + assert!(all_content.contains("Deep content")); + } + + #[test] + fn very_long_single_line_no_newlines() { + // One giant line with no newlines — can't split on lines effectively + let text = "word ".repeat(5000); + let chunks = chunk_markdown(&text, 50); + // Should produce at least 1 chunk without panicking + assert!(!chunks.is_empty()); + } + + #[test] + fn only_newlines_and_whitespace() { + assert!(chunk_markdown("\n\n\n \n\n", 512).is_empty()); + } + + #[test] + fn max_tokens_zero() { + // max_tokens=0 → max_chars=0, should not panic or infinite loop + let chunks = chunk_markdown("Hello world", 0); + // Every chunk will exceed 0 chars, so it splits maximally + assert!(!chunks.is_empty()); + } + + #[test] + fn max_tokens_one() { + // max_tokens=1 → max_chars=4, very aggressive splitting + let text = "Line one\nLine two\nLine three"; + let chunks = chunk_markdown(text, 1); + assert!(!chunks.is_empty()); + } + + #[test] + fn unicode_content() { + let text = "# 
日本語\nこんにちは世界\n\n## Émojis\n🦀 Rust is great 🚀"; + let chunks = chunk_markdown(text, 512); + assert!(!chunks.is_empty()); + let all: String = chunks.iter().map(|c| c.content.clone()).collect(); + assert!(all.contains("こんにちは")); + assert!(all.contains("🦀")); + } + + #[test] + fn fts5_special_chars_in_content() { + let text = "Content with \"quotes\" and (parentheses) and * asterisks *"; + let chunks = chunk_markdown(text, 512); + assert_eq!(chunks.len(), 1); + assert!(chunks[0].content.contains("\"quotes\"")); + } + + #[test] + fn multiple_blank_lines_between_paragraphs() { + let text = "Paragraph one.\n\n\n\n\nParagraph two.\n\n\n\nParagraph three."; + let chunks = chunk_markdown(text, 512); + assert_eq!(chunks.len(), 1); // All fits in one chunk + assert!(chunks[0].content.contains("Paragraph one")); + assert!(chunks[0].content.contains("Paragraph three")); + } + + #[test] + fn heading_at_end_of_text() { + let text = "Some content\n# Trailing Heading"; + let chunks = chunk_markdown(text, 512); + assert!(!chunks.is_empty()); + } + + #[test] + fn single_heading_no_content() { + let text = "# Just a heading"; + let chunks = chunk_markdown(text, 512); + assert_eq!(chunks.len(), 1); + assert_eq!(chunks[0].heading.as_deref(), Some("# Just a heading")); + } + + #[test] + fn no_content_loss() { + let text = "# A\nContent A line 1\nContent A line 2\n\n## B\nContent B\n\n## C\nContent C"; + let chunks = chunk_markdown(text, 512); + let reassembled: String = chunks.iter().map(|c| format!("{}\n", c.content)).collect(); + // All original content words should appear + for word in ["Content", "line", "1", "2"] { + assert!( + reassembled.contains(word), + "Missing word '{word}' in reassembled chunks" + ); + } + } } diff --git a/src/memory/embeddings.rs b/src/memory/embeddings.rs index 882082b..270ebfe 100644 --- a/src/memory/embeddings.rs +++ b/src/memory/embeddings.rs @@ -187,4 +187,66 @@ mod tests { assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally 
assert_eq!(p.dimensions(), 768); } + + // ── Edge cases ─────────────────────────────────────────────── + + #[tokio::test] + async fn noop_embed_one_returns_error() { + let p = NoopEmbedding; + // embed returns empty vec → pop() returns None → error + let result = p.embed_one("hello").await; + assert!(result.is_err()); + } + + #[tokio::test] + async fn noop_embed_empty_batch() { + let p = NoopEmbedding; + let result = p.embed(&[]).await.unwrap(); + assert!(result.is_empty()); + } + + #[tokio::test] + async fn noop_embed_multiple_texts() { + let p = NoopEmbedding; + let result = p.embed(&["a", "b", "c"]).await.unwrap(); + assert!(result.is_empty()); + } + + #[test] + fn factory_empty_string_returns_noop() { + let p = create_embedding_provider("", None, "model", 1536); + assert_eq!(p.name(), "none"); + } + + #[test] + fn factory_unknown_provider_returns_noop() { + let p = create_embedding_provider("cohere", None, "model", 1536); + assert_eq!(p.name(), "none"); + } + + #[test] + fn factory_custom_empty_url() { + // "custom:" with no URL — should still construct without panic + let p = create_embedding_provider("custom:", None, "model", 768); + assert_eq!(p.name(), "openai"); + } + + #[test] + fn factory_openai_no_api_key() { + let p = create_embedding_provider("openai", None, "text-embedding-3-small", 1536); + assert_eq!(p.name(), "openai"); + assert_eq!(p.dimensions(), 1536); + } + + #[test] + fn openai_trailing_slash_stripped() { + let p = OpenAiEmbedding::new("https://api.openai.com/", "key", "model", 1536); + assert_eq!(p.base_url, "https://api.openai.com"); + } + + #[test] + fn openai_dimensions_custom() { + let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384); + assert_eq!(p.dimensions(), 384); + } } diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index ed7eec2..8b17ed5 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -1045,4 +1045,321 @@ mod tests { assert!(r.score.is_some(), "Expected score on result: {:?}", r.key); } } + + 
// ── Edge cases: FTS5 special characters ────────────────────── + + #[tokio::test] + async fn recall_with_quotes_in_query() { + let (_tmp, mem) = temp_sqlite(); + mem.store("q1", "He said hello world", MemoryCategory::Core) + .await + .unwrap(); + // Quotes in query should not crash FTS5 + let results = mem.recall("\"hello\"", 10).await.unwrap(); + // May or may not match depending on FTS5 escaping, but must not error + assert!(results.len() <= 10); + } + + #[tokio::test] + async fn recall_with_asterisk_in_query() { + let (_tmp, mem) = temp_sqlite(); + mem.store("a1", "wildcard test content", MemoryCategory::Core) + .await + .unwrap(); + let results = mem.recall("wild*", 10).await.unwrap(); + assert!(results.len() <= 10); + } + + #[tokio::test] + async fn recall_with_parentheses_in_query() { + let (_tmp, mem) = temp_sqlite(); + mem.store("p1", "function call test", MemoryCategory::Core) + .await + .unwrap(); + let results = mem.recall("function()", 10).await.unwrap(); + assert!(results.len() <= 10); + } + + #[tokio::test] + async fn recall_with_sql_injection_attempt() { + let (_tmp, mem) = temp_sqlite(); + mem.store("safe", "normal content", MemoryCategory::Core) + .await + .unwrap(); + // Should not crash or leak data + let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap(); + assert!(results.len() <= 10); + // Table should still exist + assert_eq!(mem.count().await.unwrap(), 1); + } + + // ── Edge cases: store ──────────────────────────────────────── + + #[tokio::test] + async fn store_empty_content() { + let (_tmp, mem) = temp_sqlite(); + mem.store("empty", "", MemoryCategory::Core).await.unwrap(); + let entry = mem.get("empty").await.unwrap().unwrap(); + assert_eq!(entry.content, ""); + } + + #[tokio::test] + async fn store_empty_key() { + let (_tmp, mem) = temp_sqlite(); + mem.store("", "content for empty key", MemoryCategory::Core) + .await + .unwrap(); + let entry = mem.get("").await.unwrap().unwrap(); + assert_eq!(entry.content, "content 
for empty key"); + } + + #[tokio::test] + async fn store_very_long_content() { + let (_tmp, mem) = temp_sqlite(); + let long_content = "x".repeat(100_000); + mem.store("long", &long_content, MemoryCategory::Core) + .await + .unwrap(); + let entry = mem.get("long").await.unwrap().unwrap(); + assert_eq!(entry.content.len(), 100_000); + } + + #[tokio::test] + async fn store_unicode_and_emoji() { + let (_tmp, mem) = temp_sqlite(); + mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core) + .await + .unwrap(); + let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap(); + assert_eq!(entry.content, "こんにちは 🚀 Ñoño"); + } + + #[tokio::test] + async fn store_content_with_newlines_and_tabs() { + let (_tmp, mem) = temp_sqlite(); + let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph"; + mem.store("whitespace", content, MemoryCategory::Core) + .await + .unwrap(); + let entry = mem.get("whitespace").await.unwrap().unwrap(); + assert_eq!(entry.content, content); + } + + // ── Edge cases: recall ─────────────────────────────────────── + + #[tokio::test] + async fn recall_single_character_query() { + let (_tmp, mem) = temp_sqlite(); + mem.store("a", "x marks the spot", MemoryCategory::Core) + .await + .unwrap(); + // Single char may not match FTS5 but LIKE fallback should work + let results = mem.recall("x", 10).await.unwrap(); + // Should not crash; may or may not find results + assert!(results.len() <= 10); + } + + #[tokio::test] + async fn recall_limit_zero() { + let (_tmp, mem) = temp_sqlite(); + mem.store("a", "some content", MemoryCategory::Core) + .await + .unwrap(); + let results = mem.recall("some", 0).await.unwrap(); + assert!(results.is_empty()); + } + + #[tokio::test] + async fn recall_limit_one() { + let (_tmp, mem) = temp_sqlite(); + mem.store("a", "matching content alpha", MemoryCategory::Core) + .await + .unwrap(); + mem.store("b", "matching content beta", MemoryCategory::Core) + .await + .unwrap(); + let results = mem.recall("matching content", 
1).await.unwrap(); + assert_eq!(results.len(), 1); + } + + #[tokio::test] + async fn recall_matches_by_key_not_just_content() { + let (_tmp, mem) = temp_sqlite(); + mem.store( + "rust_preferences", + "User likes systems programming", + MemoryCategory::Core, + ) + .await + .unwrap(); + // "rust" appears in key but not content — LIKE fallback checks key too + let results = mem.recall("rust", 10).await.unwrap(); + assert!(!results.is_empty(), "Should match by key"); + } + + #[tokio::test] + async fn recall_unicode_query() { + let (_tmp, mem) = temp_sqlite(); + mem.store("jp", "日本語のテスト", MemoryCategory::Core) + .await + .unwrap(); + let results = mem.recall("日本語", 10).await.unwrap(); + assert!(!results.is_empty()); + } + + // ── Edge cases: schema idempotency ─────────────────────────── + + #[tokio::test] + async fn schema_idempotent_reopen() { + let tmp = TempDir::new().unwrap(); + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("k1", "v1", MemoryCategory::Core).await.unwrap(); + } + // Open again — init_schema runs again on existing DB + let mem2 = SqliteMemory::new(tmp.path()).unwrap(); + let entry = mem2.get("k1").await.unwrap(); + assert!(entry.is_some()); + assert_eq!(entry.unwrap().content, "v1"); + // Store more data — should work fine + mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap(); + assert_eq!(mem2.count().await.unwrap(), 2); + } + + #[tokio::test] + async fn schema_triple_open() { + let tmp = TempDir::new().unwrap(); + let _m1 = SqliteMemory::new(tmp.path()).unwrap(); + let _m2 = SqliteMemory::new(tmp.path()).unwrap(); + let m3 = SqliteMemory::new(tmp.path()).unwrap(); + assert!(m3.health_check().await); + } + + // ── Edge cases: forget + FTS5 consistency ──────────────────── + + #[tokio::test] + async fn forget_then_recall_no_ghost_results() { + let (_tmp, mem) = temp_sqlite(); + mem.store("ghost", "phantom memory content", MemoryCategory::Core) + .await + .unwrap(); + mem.forget("ghost").await.unwrap(); + let results = 
mem.recall("phantom memory", 10).await.unwrap(); + assert!( + results.is_empty(), + "Deleted memory should not appear in recall" + ); + } + + #[tokio::test] + async fn forget_and_re_store_same_key() { + let (_tmp, mem) = temp_sqlite(); + mem.store("cycle", "version 1", MemoryCategory::Core) + .await + .unwrap(); + mem.forget("cycle").await.unwrap(); + mem.store("cycle", "version 2", MemoryCategory::Core) + .await + .unwrap(); + let entry = mem.get("cycle").await.unwrap().unwrap(); + assert_eq!(entry.content, "version 2"); + assert_eq!(mem.count().await.unwrap(), 1); + } + + // ── Edge cases: reindex ────────────────────────────────────── + + #[tokio::test] + async fn reindex_empty_db() { + let (_tmp, mem) = temp_sqlite(); + let count = mem.reindex().await.unwrap(); + assert_eq!(count, 0); + } + + #[tokio::test] + async fn reindex_twice_is_safe() { + let (_tmp, mem) = temp_sqlite(); + mem.store("r1", "reindex data", MemoryCategory::Core) + .await + .unwrap(); + mem.reindex().await.unwrap(); + let count = mem.reindex().await.unwrap(); + assert_eq!(count, 0); // Noop embedder → nothing to re-embed + // Data should still be intact + let results = mem.recall("reindex", 10).await.unwrap(); + assert_eq!(results.len(), 1); + } + + // ── Edge cases: content_hash ───────────────────────────────── + + #[test] + fn content_hash_empty_string() { + let h = SqliteMemory::content_hash(""); + assert!(!h.is_empty()); + assert_eq!(h.len(), 16); // 16 hex chars + } + + #[test] + fn content_hash_unicode() { + let h1 = SqliteMemory::content_hash("🦀"); + let h2 = SqliteMemory::content_hash("🦀"); + assert_eq!(h1, h2); + let h3 = SqliteMemory::content_hash("🚀"); + assert_ne!(h1, h3); + } + + #[test] + fn content_hash_long_input() { + let long = "a".repeat(1_000_000); + let h = SqliteMemory::content_hash(&long); + assert_eq!(h.len(), 16); + } + + // ── Edge cases: category helpers ───────────────────────────── + + #[test] + fn category_roundtrip_custom_with_spaces() { + let cat = 
MemoryCategory::Custom("my custom category".into()); + let s = SqliteMemory::category_to_str(&cat); + assert_eq!(s, "my custom category"); + let back = SqliteMemory::str_to_category(&s); + assert_eq!(back, cat); + } + + #[test] + fn category_roundtrip_empty_custom() { + let cat = MemoryCategory::Custom(String::new()); + let s = SqliteMemory::category_to_str(&cat); + assert_eq!(s, ""); + let back = SqliteMemory::str_to_category(&s); + assert_eq!(back, MemoryCategory::Custom(String::new())); + } + + // ── Edge cases: list ───────────────────────────────────────── + + #[tokio::test] + async fn list_custom_category() { + let (_tmp, mem) = temp_sqlite(); + mem.store("c1", "custom1", MemoryCategory::Custom("project".into())) + .await + .unwrap(); + mem.store("c2", "custom2", MemoryCategory::Custom("project".into())) + .await + .unwrap(); + mem.store("c3", "other", MemoryCategory::Core) + .await + .unwrap(); + + let project = mem + .list(Some(&MemoryCategory::Custom("project".into()))) + .await + .unwrap(); + assert_eq!(project.len(), 2); + } + + #[tokio::test] + async fn list_empty_db() { + let (_tmp, mem) = temp_sqlite(); + let all = mem.list(None).await.unwrap(); + assert!(all.is_empty()); + } } diff --git a/src/memory/vector.rs b/src/memory/vector.rs index 1ca82f4..b7a7d79 100644 --- a/src/memory/vector.rs +++ b/src/memory/vector.rs @@ -19,13 +19,18 @@ pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { } let denom = norm_a.sqrt() * norm_b.sqrt(); - if denom < f64::EPSILON { + if !denom.is_finite() || denom < f64::EPSILON { + return 0.0; + } + + let raw = dot / denom; + if !raw.is_finite() { return 0.0; } // Clamp to [0, 1] — embeddings are typically positive #[allow(clippy::cast_possible_truncation)] - let sim = (dot / denom).clamp(0.0, 1.0) as f32; + let sim = raw.clamp(0.0, 1.0) as f32; sim } @@ -231,4 +236,159 @@ mod tests { let merged = hybrid_merge(&[], &[], 0.7, 0.3, 10); assert!(merged.is_empty()); } + + // ── Edge cases: cosine similarity 
──────────────────────────── + + #[test] + fn cosine_nan_returns_zero() { + let a = vec![f32::NAN, 1.0, 2.0]; + let b = vec![1.0, 2.0, 3.0]; + let sim = cosine_similarity(&a, &b); + // NaN propagates through arithmetic — result should be 0.0 (clamped or denom check) + assert!(sim.is_finite(), "Expected finite, got {sim}"); + } + + #[test] + fn cosine_infinity_returns_zero_or_finite() { + let a = vec![f32::INFINITY, 1.0]; + let b = vec![1.0, 2.0]; + let sim = cosine_similarity(&a, &b); + assert!(sim.is_finite(), "Expected finite, got {sim}"); + } + + #[test] + fn cosine_negative_values() { + let a = vec![-1.0, -2.0, -3.0]; + let b = vec![-1.0, -2.0, -3.0]; + // Identical negative vectors → cosine = 1.0, but clamped to [0,1] + let sim = cosine_similarity(&a, &b); + assert!((sim - 1.0).abs() < 0.001); + } + + #[test] + fn cosine_opposite_vectors_clamped() { + let a = vec![1.0, 0.0]; + let b = vec![-1.0, 0.0]; + // Cosine = -1.0, clamped to 0.0 + let sim = cosine_similarity(&a, &b); + assert_eq!(sim, 0.0); + } + + #[test] + fn cosine_high_dimensional() { + let a: Vec<f32> = (0..1536).map(|i| (i as f32) * 0.001).collect(); + let b: Vec<f32> = (0..1536).map(|i| (i as f32) * 0.001 + 0.0001).collect(); + let sim = cosine_similarity(&a, &b); + assert!( + sim > 0.99, + "High-dim similar vectors should be close: {sim}" + ); + } + + #[test] + fn cosine_single_element() { + assert!((cosine_similarity(&[5.0], &[5.0]) - 1.0).abs() < 0.001); + assert_eq!(cosine_similarity(&[5.0], &[-5.0]), 0.0); + } + + #[test] + fn cosine_both_zero_vectors() { + let a = vec![0.0, 0.0]; + let b = vec![0.0, 0.0]; + assert_eq!(cosine_similarity(&a, &b), 0.0); + } + + // ── Edge cases: vec↔bytes serialization ────────────────────── + + #[test] + fn bytes_to_vec_non_aligned_truncates() { + // 5 bytes → only first 4 used (1 float), last byte dropped + let bytes = vec![0u8, 0, 0, 0, 0xFF]; + let result = bytes_to_vec(&bytes); + assert_eq!(result.len(), 1); + assert_eq!(result[0], 0.0); + } + + #[test] + fn 
bytes_to_vec_three_bytes_returns_empty() { + let bytes = vec![1u8, 2, 3]; + let result = bytes_to_vec(&bytes); + assert!(result.is_empty()); + } + + #[test] + fn vec_bytes_roundtrip_special_values() { + let special = vec![f32::MIN, f32::MAX, f32::EPSILON, -0.0, 0.0]; + let bytes = vec_to_bytes(&special); + let restored = bytes_to_vec(&bytes); + assert_eq!(special.len(), restored.len()); + for (a, b) in special.iter().zip(restored.iter()) { + assert_eq!(a.to_bits(), b.to_bits()); + } + } + + #[test] + fn vec_bytes_roundtrip_nan_preserves_bits() { + let nan_vec = vec![f32::NAN]; + let bytes = vec_to_bytes(&nan_vec); + let restored = bytes_to_vec(&bytes); + assert!(restored[0].is_nan()); + } + + // ── Edge cases: hybrid merge ───────────────────────────────── + + #[test] + fn hybrid_merge_limit_zero() { + let vec_results = vec![("a".into(), 0.9)]; + let merged = hybrid_merge(&vec_results, &[], 0.7, 0.3, 0); + assert!(merged.is_empty()); + } + + #[test] + fn hybrid_merge_zero_weights() { + let vec_results = vec![("a".into(), 0.9)]; + let kw_results = vec![("b".into(), 10.0)]; + let merged = hybrid_merge(&vec_results, &kw_results, 0.0, 0.0, 10); + // All final scores should be 0.0 + for r in &merged { + assert_eq!(r.final_score, 0.0); + } + } + + #[test] + fn hybrid_merge_negative_keyword_scores() { + // BM25 scores are negated in our code, but raw negatives shouldn't crash + let kw_results = vec![("a".into(), -5.0), ("b".into(), -1.0)]; + let merged = hybrid_merge(&[], &kw_results, 0.7, 0.3, 10); + assert_eq!(merged.len(), 2); + // Should still produce finite scores + for r in &merged { + assert!(r.final_score.is_finite()); + } + } + + #[test] + fn hybrid_merge_duplicate_ids_in_same_source() { + let vec_results = vec![("a".into(), 0.9), ("a".into(), 0.5)]; + let merged = hybrid_merge(&vec_results, &[], 1.0, 0.0, 10); + // Should deduplicate — only 1 entry for "a" + assert_eq!(merged.len(), 1); + } + + #[test] + fn hybrid_merge_large_bm25_normalization() { + let 
kw_results = vec![("a".into(), 1000.0), ("b".into(), 500.0), ("c".into(), 1.0)]; + let merged = hybrid_merge(&[], &kw_results, 0.0, 1.0, 10); + // "a" should have normalized score of 1.0 + assert!((merged[0].keyword_score.unwrap() - 1.0).abs() < 0.001); + // "b" should have 0.5 + assert!((merged[1].keyword_score.unwrap() - 0.5).abs() < 0.001); + } + + #[test] + fn hybrid_merge_single_item() { + let merged = hybrid_merge(&[("only".into(), 0.8)], &[], 0.7, 0.3, 10); + assert_eq!(merged.len(), 1); + assert_eq!(merged[0].id, "only"); + } }