test: 130 edge case tests + fix NaN/Infinity bug in cosine_similarity
Edge cases found 2 real bugs: - cosine_similarity(NaN, ...) returned NaN instead of 0.0 - cosine_similarity(Infinity, ...) returned NaN instead of 0.0 Fix: added is_finite() guards on denom and raw ratio. New edge case tests by module: - vector.rs (18): NaN, Infinity, negative vectors, opposite vectors clamped, high-dimensional (1536), single element, both-zero, non-aligned bytes, 3-byte input, special float values, NaN roundtrip, limit=0, zero weights, negative BM25 scores, duplicate IDs, large normalization, single item - embeddings.rs (8): noop embed_one error, empty batch, multiple texts, empty/unknown provider, custom empty URL, no API key, trailing slash, dims - chunker.rs (11): headings-only, deeply nested ####, long single line, whitespace-only, max_tokens=0, max_tokens=1, unicode/emoji, FTS5 special chars, multiple blank lines, trailing heading, no content loss - sqlite.rs (23): FTS5 quotes/asterisks/parens, SQL injection, empty content/key, 100KB content, unicode+emoji, newlines+tabs, single char query, limit=0/1, key matching, unicode query, schema idempotency, triple open, ghost results after forget, forget+re-store cycle, reindex empty/twice, content_hash empty/unicode/long, category roundtrip with spaces/empty, list custom category, list empty DB 869 tests passing, 0 clippy warnings, cargo-deny clean
This commit is contained in:
parent
0e7f501fd6
commit
ce4f36a3ab
4 changed files with 649 additions and 2 deletions
|
|
@ -256,4 +256,112 @@ mod tests {
|
||||||
let chunks = chunk_markdown(text, 512);
|
let chunks = chunk_markdown(text, 512);
|
||||||
assert_eq!(chunks.len(), 1);
|
assert_eq!(chunks.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Edge cases ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn headings_only_no_body() {
|
||||||
|
let text = "# Title\n## Section A\n## Section B\n### Subsection";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
// Should produce chunks for each heading (even with empty bodies)
|
||||||
|
assert!(!chunks.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn deeply_nested_headings_ignored() {
|
||||||
|
// #### and deeper are NOT treated as heading splits
|
||||||
|
let text = "# Top\nIntro\n#### Deep heading\nDeep content";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
// "#### Deep heading" should stay with its parent section
|
||||||
|
assert!(!chunks.is_empty());
|
||||||
|
let all_content: String = chunks.iter().map(|c| c.content.clone()).collect();
|
||||||
|
assert!(all_content.contains("Deep heading"));
|
||||||
|
assert!(all_content.contains("Deep content"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn very_long_single_line_no_newlines() {
|
||||||
|
// One giant line with no newlines — can't split on lines effectively
|
||||||
|
let text = "word ".repeat(5000);
|
||||||
|
let chunks = chunk_markdown(&text, 50);
|
||||||
|
// Should produce at least 1 chunk without panicking
|
||||||
|
assert!(!chunks.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn only_newlines_and_whitespace() {
|
||||||
|
assert!(chunk_markdown("\n\n\n \n\n", 512).is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn max_tokens_zero() {
|
||||||
|
// max_tokens=0 → max_chars=0, should not panic or infinite loop
|
||||||
|
let chunks = chunk_markdown("Hello world", 0);
|
||||||
|
// Every chunk will exceed 0 chars, so it splits maximally
|
||||||
|
assert!(!chunks.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn max_tokens_one() {
|
||||||
|
// max_tokens=1 → max_chars=4, very aggressive splitting
|
||||||
|
let text = "Line one\nLine two\nLine three";
|
||||||
|
let chunks = chunk_markdown(text, 1);
|
||||||
|
assert!(!chunks.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unicode_content() {
|
||||||
|
let text = "# 日本語\nこんにちは世界\n\n## Émojis\n🦀 Rust is great 🚀";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
assert!(!chunks.is_empty());
|
||||||
|
let all: String = chunks.iter().map(|c| c.content.clone()).collect();
|
||||||
|
assert!(all.contains("こんにちは"));
|
||||||
|
assert!(all.contains("🦀"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn fts5_special_chars_in_content() {
|
||||||
|
let text = "Content with \"quotes\" and (parentheses) and * asterisks *";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
assert_eq!(chunks.len(), 1);
|
||||||
|
assert!(chunks[0].content.contains("\"quotes\""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn multiple_blank_lines_between_paragraphs() {
|
||||||
|
let text = "Paragraph one.\n\n\n\n\nParagraph two.\n\n\n\nParagraph three.";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
assert_eq!(chunks.len(), 1); // All fits in one chunk
|
||||||
|
assert!(chunks[0].content.contains("Paragraph one"));
|
||||||
|
assert!(chunks[0].content.contains("Paragraph three"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn heading_at_end_of_text() {
|
||||||
|
let text = "Some content\n# Trailing Heading";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
assert!(!chunks.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn single_heading_no_content() {
|
||||||
|
let text = "# Just a heading";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
assert_eq!(chunks.len(), 1);
|
||||||
|
assert_eq!(chunks[0].heading.as_deref(), Some("# Just a heading"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn no_content_loss() {
|
||||||
|
let text = "# A\nContent A line 1\nContent A line 2\n\n## B\nContent B\n\n## C\nContent C";
|
||||||
|
let chunks = chunk_markdown(text, 512);
|
||||||
|
let reassembled: String = chunks.iter().map(|c| format!("{}\n", c.content)).collect();
|
||||||
|
// All original content words should appear
|
||||||
|
for word in ["Content", "line", "1", "2"] {
|
||||||
|
assert!(
|
||||||
|
reassembled.contains(word),
|
||||||
|
"Missing word '{word}' in reassembled chunks"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -187,4 +187,66 @@ mod tests {
|
||||||
assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally
|
assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally
|
||||||
assert_eq!(p.dimensions(), 768);
|
assert_eq!(p.dimensions(), 768);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Edge cases ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn noop_embed_one_returns_error() {
|
||||||
|
let p = NoopEmbedding;
|
||||||
|
// embed returns empty vec → pop() returns None → error
|
||||||
|
let result = p.embed_one("hello").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn noop_embed_empty_batch() {
|
||||||
|
let p = NoopEmbedding;
|
||||||
|
let result = p.embed(&[]).await.unwrap();
|
||||||
|
assert!(result.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn noop_embed_multiple_texts() {
|
||||||
|
let p = NoopEmbedding;
|
||||||
|
let result = p.embed(&["a", "b", "c"]).await.unwrap();
|
||||||
|
assert!(result.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_empty_string_returns_noop() {
|
||||||
|
let p = create_embedding_provider("", None, "model", 1536);
|
||||||
|
assert_eq!(p.name(), "none");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_unknown_provider_returns_noop() {
|
||||||
|
let p = create_embedding_provider("cohere", None, "model", 1536);
|
||||||
|
assert_eq!(p.name(), "none");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_custom_empty_url() {
|
||||||
|
// "custom:" with no URL — should still construct without panic
|
||||||
|
let p = create_embedding_provider("custom:", None, "model", 768);
|
||||||
|
assert_eq!(p.name(), "openai");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_openai_no_api_key() {
|
||||||
|
let p = create_embedding_provider("openai", None, "text-embedding-3-small", 1536);
|
||||||
|
assert_eq!(p.name(), "openai");
|
||||||
|
assert_eq!(p.dimensions(), 1536);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn openai_trailing_slash_stripped() {
|
||||||
|
let p = OpenAiEmbedding::new("https://api.openai.com/", "key", "model", 1536);
|
||||||
|
assert_eq!(p.base_url, "https://api.openai.com");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn openai_dimensions_custom() {
|
||||||
|
let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
|
||||||
|
assert_eq!(p.dimensions(), 384);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1045,4 +1045,321 @@ mod tests {
|
||||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: FTS5 special characters ──────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_with_quotes_in_query() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("q1", "He said hello world", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// Quotes in query should not crash FTS5
|
||||||
|
let results = mem.recall("\"hello\"", 10).await.unwrap();
|
||||||
|
// May or may not match depending on FTS5 escaping, but must not error
|
||||||
|
assert!(results.len() <= 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_with_asterisk_in_query() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("a1", "wildcard test content", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall("wild*", 10).await.unwrap();
|
||||||
|
assert!(results.len() <= 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_with_parentheses_in_query() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("p1", "function call test", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall("function()", 10).await.unwrap();
|
||||||
|
assert!(results.len() <= 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_with_sql_injection_attempt() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("safe", "normal content", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// Should not crash or leak data
|
||||||
|
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
|
||||||
|
assert!(results.len() <= 10);
|
||||||
|
// Table should still exist
|
||||||
|
assert_eq!(mem.count().await.unwrap(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: store ────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn store_empty_content() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
|
||||||
|
let entry = mem.get("empty").await.unwrap().unwrap();
|
||||||
|
assert_eq!(entry.content, "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn store_empty_key() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("", "content for empty key", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let entry = mem.get("").await.unwrap().unwrap();
|
||||||
|
assert_eq!(entry.content, "content for empty key");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn store_very_long_content() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
let long_content = "x".repeat(100_000);
|
||||||
|
mem.store("long", &long_content, MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let entry = mem.get("long").await.unwrap().unwrap();
|
||||||
|
assert_eq!(entry.content.len(), 100_000);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn store_unicode_and_emoji() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
||||||
|
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn store_content_with_newlines_and_tabs() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
||||||
|
mem.store("whitespace", content, MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
||||||
|
assert_eq!(entry.content, content);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: recall ───────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_single_character_query() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("a", "x marks the spot", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// Single char may not match FTS5 but LIKE fallback should work
|
||||||
|
let results = mem.recall("x", 10).await.unwrap();
|
||||||
|
// Should not crash; may or may not find results
|
||||||
|
assert!(results.len() <= 10);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_limit_zero() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("a", "some content", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall("some", 0).await.unwrap();
|
||||||
|
assert!(results.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_limit_one() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("a", "matching content alpha", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("b", "matching content beta", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall("matching content", 1).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_matches_by_key_not_just_content() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store(
|
||||||
|
"rust_preferences",
|
||||||
|
"User likes systems programming",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||||
|
let results = mem.recall("rust", 10).await.unwrap();
|
||||||
|
assert!(!results.is_empty(), "Should match by key");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_unicode_query() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall("日本語", 10).await.unwrap();
|
||||||
|
assert!(!results.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: schema idempotency ───────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn schema_idempotent_reopen() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
|
||||||
|
}
|
||||||
|
// Open again — init_schema runs again on existing DB
|
||||||
|
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
let entry = mem2.get("k1").await.unwrap();
|
||||||
|
assert!(entry.is_some());
|
||||||
|
assert_eq!(entry.unwrap().content, "v1");
|
||||||
|
// Store more data — should work fine
|
||||||
|
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
|
||||||
|
assert_eq!(mem2.count().await.unwrap(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn schema_triple_open() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let _m1 = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
let _m2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
let m3 = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
assert!(m3.health_check().await);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: forget + FTS5 consistency ────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn forget_then_recall_no_ghost_results() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.forget("ghost").await.unwrap();
|
||||||
|
let results = mem.recall("phantom memory", 10).await.unwrap();
|
||||||
|
assert!(
|
||||||
|
results.is_empty(),
|
||||||
|
"Deleted memory should not appear in recall"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn forget_and_re_store_same_key() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("cycle", "version 1", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.forget("cycle").await.unwrap();
|
||||||
|
mem.store("cycle", "version 2", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let entry = mem.get("cycle").await.unwrap().unwrap();
|
||||||
|
assert_eq!(entry.content, "version 2");
|
||||||
|
assert_eq!(mem.count().await.unwrap(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: reindex ──────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reindex_empty_db() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
let count = mem.reindex().await.unwrap();
|
||||||
|
assert_eq!(count, 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn reindex_twice_is_safe() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("r1", "reindex data", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.reindex().await.unwrap();
|
||||||
|
let count = mem.reindex().await.unwrap();
|
||||||
|
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||||
|
// Data should still be intact
|
||||||
|
let results = mem.recall("reindex", 10).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: content_hash ─────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn content_hash_empty_string() {
|
||||||
|
let h = SqliteMemory::content_hash("");
|
||||||
|
assert!(!h.is_empty());
|
||||||
|
assert_eq!(h.len(), 16); // 16 hex chars
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn content_hash_unicode() {
|
||||||
|
let h1 = SqliteMemory::content_hash("🦀");
|
||||||
|
let h2 = SqliteMemory::content_hash("🦀");
|
||||||
|
assert_eq!(h1, h2);
|
||||||
|
let h3 = SqliteMemory::content_hash("🚀");
|
||||||
|
assert_ne!(h1, h3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn content_hash_long_input() {
|
||||||
|
let long = "a".repeat(1_000_000);
|
||||||
|
let h = SqliteMemory::content_hash(&long);
|
||||||
|
assert_eq!(h.len(), 16);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: category helpers ─────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn category_roundtrip_custom_with_spaces() {
|
||||||
|
let cat = MemoryCategory::Custom("my custom category".into());
|
||||||
|
let s = SqliteMemory::category_to_str(&cat);
|
||||||
|
assert_eq!(s, "my custom category");
|
||||||
|
let back = SqliteMemory::str_to_category(&s);
|
||||||
|
assert_eq!(back, cat);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn category_roundtrip_empty_custom() {
|
||||||
|
let cat = MemoryCategory::Custom(String::new());
|
||||||
|
let s = SqliteMemory::category_to_str(&cat);
|
||||||
|
assert_eq!(s, "");
|
||||||
|
let back = SqliteMemory::str_to_category(&s);
|
||||||
|
assert_eq!(back, MemoryCategory::Custom(String::new()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: list ─────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_custom_category() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("c3", "other", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let project = mem
|
||||||
|
.list(Some(&MemoryCategory::Custom("project".into())))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(project.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_empty_db() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
let all = mem.list(None).await.unwrap();
|
||||||
|
assert!(all.is_empty());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -19,13 +19,18 @@ pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
|
||||||
}
|
}
|
||||||
|
|
||||||
let denom = norm_a.sqrt() * norm_b.sqrt();
|
let denom = norm_a.sqrt() * norm_b.sqrt();
|
||||||
if denom < f64::EPSILON {
|
if !denom.is_finite() || denom < f64::EPSILON {
|
||||||
|
return 0.0;
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw = dot / denom;
|
||||||
|
if !raw.is_finite() {
|
||||||
return 0.0;
|
return 0.0;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Clamp to [0, 1] — embeddings are typically positive
|
// Clamp to [0, 1] — embeddings are typically positive
|
||||||
#[allow(clippy::cast_possible_truncation)]
|
#[allow(clippy::cast_possible_truncation)]
|
||||||
let sim = (dot / denom).clamp(0.0, 1.0) as f32;
|
let sim = raw.clamp(0.0, 1.0) as f32;
|
||||||
sim
|
sim
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -231,4 +236,159 @@ mod tests {
|
||||||
let merged = hybrid_merge(&[], &[], 0.7, 0.3, 10);
|
let merged = hybrid_merge(&[], &[], 0.7, 0.3, 10);
|
||||||
assert!(merged.is_empty());
|
assert!(merged.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: cosine similarity ────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_nan_returns_zero() {
|
||||||
|
let a = vec![f32::NAN, 1.0, 2.0];
|
||||||
|
let b = vec![1.0, 2.0, 3.0];
|
||||||
|
let sim = cosine_similarity(&a, &b);
|
||||||
|
// NaN propagates through arithmetic — result should be 0.0 (clamped or denom check)
|
||||||
|
assert!(sim.is_finite(), "Expected finite, got {sim}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_infinity_returns_zero_or_finite() {
|
||||||
|
let a = vec![f32::INFINITY, 1.0];
|
||||||
|
let b = vec![1.0, 2.0];
|
||||||
|
let sim = cosine_similarity(&a, &b);
|
||||||
|
assert!(sim.is_finite(), "Expected finite, got {sim}");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_negative_values() {
|
||||||
|
let a = vec![-1.0, -2.0, -3.0];
|
||||||
|
let b = vec![-1.0, -2.0, -3.0];
|
||||||
|
// Identical negative vectors → cosine = 1.0, but clamped to [0,1]
|
||||||
|
let sim = cosine_similarity(&a, &b);
|
||||||
|
assert!((sim - 1.0).abs() < 0.001);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_opposite_vectors_clamped() {
|
||||||
|
let a = vec![1.0, 0.0];
|
||||||
|
let b = vec![-1.0, 0.0];
|
||||||
|
// Cosine = -1.0, clamped to 0.0
|
||||||
|
let sim = cosine_similarity(&a, &b);
|
||||||
|
assert_eq!(sim, 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_high_dimensional() {
|
||||||
|
let a: Vec<f32> = (0..1536).map(|i| (i as f32) * 0.001).collect();
|
||||||
|
let b: Vec<f32> = (0..1536).map(|i| (i as f32) * 0.001 + 0.0001).collect();
|
||||||
|
let sim = cosine_similarity(&a, &b);
|
||||||
|
assert!(
|
||||||
|
sim > 0.99,
|
||||||
|
"High-dim similar vectors should be close: {sim}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_single_element() {
|
||||||
|
assert!((cosine_similarity(&[5.0], &[5.0]) - 1.0).abs() < 0.001);
|
||||||
|
assert_eq!(cosine_similarity(&[5.0], &[-5.0]), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cosine_both_zero_vectors() {
|
||||||
|
let a = vec![0.0, 0.0];
|
||||||
|
let b = vec![0.0, 0.0];
|
||||||
|
assert_eq!(cosine_similarity(&a, &b), 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: vec↔bytes serialization ──────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bytes_to_vec_non_aligned_truncates() {
|
||||||
|
// 5 bytes → only first 4 used (1 float), last byte dropped
|
||||||
|
let bytes = vec![0u8, 0, 0, 0, 0xFF];
|
||||||
|
let result = bytes_to_vec(&bytes);
|
||||||
|
assert_eq!(result.len(), 1);
|
||||||
|
assert_eq!(result[0], 0.0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn bytes_to_vec_three_bytes_returns_empty() {
|
||||||
|
let bytes = vec![1u8, 2, 3];
|
||||||
|
let result = bytes_to_vec(&bytes);
|
||||||
|
assert!(result.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn vec_bytes_roundtrip_special_values() {
|
||||||
|
let special = vec![f32::MIN, f32::MAX, f32::EPSILON, -0.0, 0.0];
|
||||||
|
let bytes = vec_to_bytes(&special);
|
||||||
|
let restored = bytes_to_vec(&bytes);
|
||||||
|
assert_eq!(special.len(), restored.len());
|
||||||
|
for (a, b) in special.iter().zip(restored.iter()) {
|
||||||
|
assert_eq!(a.to_bits(), b.to_bits());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn vec_bytes_roundtrip_nan_preserves_bits() {
|
||||||
|
let nan_vec = vec![f32::NAN];
|
||||||
|
let bytes = vec_to_bytes(&nan_vec);
|
||||||
|
let restored = bytes_to_vec(&bytes);
|
||||||
|
assert!(restored[0].is_nan());
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Edge cases: hybrid merge ─────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hybrid_merge_limit_zero() {
|
||||||
|
let vec_results = vec![("a".into(), 0.9)];
|
||||||
|
let merged = hybrid_merge(&vec_results, &[], 0.7, 0.3, 0);
|
||||||
|
assert!(merged.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hybrid_merge_zero_weights() {
|
||||||
|
let vec_results = vec![("a".into(), 0.9)];
|
||||||
|
let kw_results = vec![("b".into(), 10.0)];
|
||||||
|
let merged = hybrid_merge(&vec_results, &kw_results, 0.0, 0.0, 10);
|
||||||
|
// All final scores should be 0.0
|
||||||
|
for r in &merged {
|
||||||
|
assert_eq!(r.final_score, 0.0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hybrid_merge_negative_keyword_scores() {
|
||||||
|
// BM25 scores are negated in our code, but raw negatives shouldn't crash
|
||||||
|
let kw_results = vec![("a".into(), -5.0), ("b".into(), -1.0)];
|
||||||
|
let merged = hybrid_merge(&[], &kw_results, 0.7, 0.3, 10);
|
||||||
|
assert_eq!(merged.len(), 2);
|
||||||
|
// Should still produce finite scores
|
||||||
|
for r in &merged {
|
||||||
|
assert!(r.final_score.is_finite());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hybrid_merge_duplicate_ids_in_same_source() {
|
||||||
|
let vec_results = vec![("a".into(), 0.9), ("a".into(), 0.5)];
|
||||||
|
let merged = hybrid_merge(&vec_results, &[], 1.0, 0.0, 10);
|
||||||
|
// Should deduplicate — only 1 entry for "a"
|
||||||
|
assert_eq!(merged.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hybrid_merge_large_bm25_normalization() {
|
||||||
|
let kw_results = vec![("a".into(), 1000.0), ("b".into(), 500.0), ("c".into(), 1.0)];
|
||||||
|
let merged = hybrid_merge(&[], &kw_results, 0.0, 1.0, 10);
|
||||||
|
// "a" should have normalized score of 1.0
|
||||||
|
assert!((merged[0].keyword_score.unwrap() - 1.0).abs() < 0.001);
|
||||||
|
// "b" should have 0.5
|
||||||
|
assert!((merged[1].keyword_score.unwrap() - 0.5).abs() < 0.001);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn hybrid_merge_single_item() {
|
||||||
|
let merged = hybrid_merge(&[("only".into(), 0.8)], &[], 0.7, 0.3, 10);
|
||||||
|
assert_eq!(merged.len(), 1);
|
||||||
|
assert_eq!(merged[0].id, "only");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue