zeroclaw/src/memory/embeddings.rs
argenis de la rosa ce4f36a3ab test: 130 edge case tests + fix NaN/Infinity bug in cosine_similarity
Edge cases found 2 real bugs:
- cosine_similarity(NaN, ...) returned NaN instead of 0.0
- cosine_similarity(Infinity, ...) returned NaN instead of 0.0
Fix: added is_finite() guards on denom and raw ratio.

New edge case tests by module:
- vector.rs (18): NaN, Infinity, negative vectors, opposite vectors clamped,
  high-dimensional (1536), single element, both-zero, non-aligned bytes,
  3-byte input, special float values, NaN roundtrip, limit=0, zero weights,
  negative BM25 scores, duplicate IDs, large normalization, single item
- embeddings.rs (8): noop embed_one error, empty batch, multiple texts,
  empty/unknown provider, custom empty URL, no API key, trailing slash, dims
- chunker.rs (11): headings-only, deeply nested ####, long single line,
  whitespace-only, max_tokens=0, max_tokens=1, unicode/emoji, FTS5 special
  chars, multiple blank lines, trailing heading, no content loss
- sqlite.rs (23): FTS5 quotes/asterisks/parens, SQL injection, empty
  content/key, 100KB content, unicode+emoji, newlines+tabs, single char
  query, limit=0/1, key matching, unicode query, schema idempotency,
  triple open, ghost results after forget, forget+re-store cycle,
  reindex empty/twice, content_hash empty/unicode/long, category
  roundtrip with spaces/empty, list custom category, list empty DB

869 tests passing, 0 clippy warnings, cargo-deny clean
2026-02-14 00:28:55 -05:00

252 lines
7.2 KiB
Rust

use async_trait::async_trait;
/// Trait for embedding providers — convert text to vectors
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
/// Provider name
fn name(&self) -> &str;
/// Embedding dimensions
fn dimensions(&self) -> usize;
/// Embed a batch of texts into vectors
async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
/// Embed a single text
async fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
let mut results = self.embed(&[text]).await?;
results
.pop()
.ok_or_else(|| anyhow::anyhow!("Empty embedding result"))
}
}
// ── Noop provider (keyword-only fallback) ────────────────────
/// Embedding provider that never produces vectors.
///
/// Acts as the keyword-only fallback. `embed` always succeeds with an empty list
/// regardless of input, so the default `embed_one` on it reports an error.
pub struct NoopEmbedding;

#[async_trait]
impl EmbeddingProvider for NoopEmbedding {
    fn name(&self) -> &str {
        "none"
    }

    fn dimensions(&self) -> usize {
        0
    }

    async fn embed(&self, _texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        // No backend to call: every batch maps to "no vectors".
        Ok(vec![])
    }
}
// ── OpenAI-compatible embedding provider ─────────────────────
/// Embedding provider speaking the OpenAI `/v1/embeddings` HTTP protocol.
///
/// Also used for any OpenAI-compatible server reached through a custom base URL.
pub struct OpenAiEmbedding {
    /// Shared HTTP client used for every request.
    client: reqwest::Client,
    /// Server root without a trailing `/`; `/v1/embeddings` is appended per request.
    base_url: String,
    /// Sent as `Authorization: Bearer <api_key>` (may be empty).
    api_key: String,
    /// Value of the `"model"` field in the request body.
    model: String,
    /// Reported by `dimensions()`; not sent to or checked against the server.
    dims: usize,
}

impl OpenAiEmbedding {
    /// Creates a provider for `base_url`, dropping any trailing slashes so that
    /// endpoint paths can be appended without producing `//`.
    pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
        let root = base_url.trim_end_matches('/');
        Self {
            base_url: root.to_owned(),
            api_key: api_key.to_owned(),
            model: model.to_owned(),
            dims,
            client: reqwest::Client::new(),
        }
    }
}
#[async_trait]
impl EmbeddingProvider for OpenAiEmbedding {
    fn name(&self) -> &str {
        "openai"
    }

    fn dimensions(&self) -> usize {
        self.dims
    }

    /// POSTs `texts` to `{base_url}/v1/embeddings` and returns one vector per input,
    /// in input order.
    ///
    /// An empty batch short-circuits without any network call.
    ///
    /// # Errors
    /// Fails in any of these cases:
    /// - a transport error occurs;
    /// - the server returns a non-2xx status (the status and response body are
    ///   included in the message);
    /// - the response body cannot be decoded as JSON;
    /// - `parse_embedding_response` rejects the payload.
    async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let body = serde_json::json!({
            "model": self.model,
            "input": texts,
        });

        let resp = self
            .client
            .post(format!("{}/v1/embeddings", self.base_url))
            .header("Authorization", format!("Bearer {}", self.api_key))
            .header("Content-Type", "application/json")
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("Embedding API error {status}: {text}");
        }

        let json: serde_json::Value = resp.json().await?;
        parse_embedding_response(&json, texts.len())
    }
}

/// Extracts the embedding vectors from an OpenAI-style `/v1/embeddings` response.
///
/// Items are ordered by their integer `index` field when present. Otherwise their
/// position in `data` is used, and a stable sort keeps ties in response order.
///
/// # Errors
/// Fails when:
/// - `data` is missing or is not an array;
/// - `data` does not contain exactly `expected` items, which would misalign texts
///   and vectors;
/// - an item lacks an `embedding` array;
/// - any embedding component is not a number. Such components are rejected rather
///   than dropped, so a malformed vector cannot silently come out shorter than the
///   others.
fn parse_embedding_response(
    json: &serde_json::Value,
    expected: usize,
) -> anyhow::Result<Vec<Vec<f32>>> {
    let data = json
        .get("data")
        .and_then(serde_json::Value::as_array)
        .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data'"))?;

    if data.len() != expected {
        anyhow::bail!(
            "Invalid embedding response: expected {expected} embeddings, got {}",
            data.len()
        );
    }

    let mut ordered = Vec::with_capacity(data.len());
    for (position, item) in data.iter().enumerate() {
        let embedding = item
            .get("embedding")
            .and_then(serde_json::Value::as_array)
            .ok_or_else(|| anyhow::anyhow!("Invalid embedding item"))?;

        // Narrowing f64 -> f32 is intentional: vectors are stored and compared as f32.
        #[allow(clippy::cast_possible_truncation)]
        let vector = embedding
            .iter()
            .map(|component| {
                component
                    .as_f64()
                    .map(|f| f as f32)
                    .ok_or_else(|| anyhow::anyhow!("Invalid embedding item: non-numeric value"))
            })
            .collect::<anyhow::Result<Vec<f32>>>()?;

        let order = item
            .get("index")
            .and_then(serde_json::Value::as_u64)
            .and_then(|i| usize::try_from(i).ok())
            .unwrap_or(position);
        ordered.push((order, vector));
    }

    // Stable sort: without `index` fields the response order is preserved as-is.
    ordered.sort_by_key(|(order, _)| *order);
    Ok(ordered.into_iter().map(|(_, vector)| vector).collect())
}
// ── Factory ──────────────────────────────────────────────────
/// Builds an embedding provider from a provider name.
///
/// The name is resolved as follows:
/// - `"openai"` → [`OpenAiEmbedding`] against `https://api.openai.com`.
/// - `"custom:<base_url>"` → [`OpenAiEmbedding`] against `<base_url>`. An empty URL
///   (plain `"custom:"`) is accepted as-is.
/// - anything else, including `""` → [`NoopEmbedding`].
///
/// A missing `api_key` becomes an empty string. NOTE(review): presumably the hosted
/// OpenAI API rejects such requests; local OpenAI-compatible servers may not need a key.
pub fn create_embedding_provider(
    provider: &str,
    api_key: Option<&str>,
    model: &str,
    dims: usize,
) -> Box<dyn EmbeddingProvider> {
    let key = api_key.unwrap_or("");

    if provider == "openai" {
        return Box::new(OpenAiEmbedding::new(
            "https://api.openai.com",
            key,
            model,
            dims,
        ));
    }

    if let Some(base_url) = provider.strip_prefix("custom:") {
        return Box::new(OpenAiEmbedding::new(base_url, key, model, dims));
    }

    Box::new(NoopEmbedding)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a provider through the factory with no API key, as most tests here do.
    fn build(provider: &str, model: &str, dims: usize) -> Box<dyn EmbeddingProvider> {
        create_embedding_provider(provider, None, model, dims)
    }

    #[test]
    fn noop_name() {
        let noop = NoopEmbedding;
        assert_eq!(noop.name(), "none");
        assert_eq!(noop.dimensions(), 0);
    }

    #[tokio::test]
    async fn noop_embed_returns_empty() {
        let noop = NoopEmbedding;
        let vectors = noop.embed(&["hello"]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[test]
    fn factory_none() {
        let provider = build("none", "model", 1536);
        assert_eq!(provider.name(), "none");
    }

    #[test]
    fn factory_openai() {
        let provider =
            create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536);
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimensions(), 1536);
    }

    #[test]
    fn factory_custom_url() {
        let provider = build("custom:http://localhost:1234", "model", 768);
        // Custom endpoints reuse the OpenAI-compatible client.
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimensions(), 768);
    }

    // ── Edge cases ───────────────────────────────────────────────

    #[tokio::test]
    async fn noop_embed_one_returns_error() {
        let noop = NoopEmbedding;
        // `embed` yields nothing, so the default `embed_one` has no vector to return.
        let outcome = noop.embed_one("hello").await;
        assert!(outcome.is_err());
    }

    #[tokio::test]
    async fn noop_embed_empty_batch() {
        let noop = NoopEmbedding;
        let vectors = noop.embed(&[]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[tokio::test]
    async fn noop_embed_multiple_texts() {
        let noop = NoopEmbedding;
        let vectors = noop.embed(&["a", "b", "c"]).await.unwrap();
        assert!(vectors.is_empty());
    }

    #[test]
    fn factory_empty_string_returns_noop() {
        let provider = build("", "model", 1536);
        assert_eq!(provider.name(), "none");
    }

    #[test]
    fn factory_unknown_provider_returns_noop() {
        let provider = build("cohere", "model", 1536);
        assert_eq!(provider.name(), "none");
    }

    #[test]
    fn factory_custom_empty_url() {
        // A bare "custom:" prefix must still construct a provider without panicking.
        let provider = build("custom:", "model", 768);
        assert_eq!(provider.name(), "openai");
    }

    #[test]
    fn factory_openai_no_api_key() {
        let provider = build("openai", "text-embedding-3-small", 1536);
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimensions(), 1536);
    }

    #[test]
    fn openai_trailing_slash_stripped() {
        let client = OpenAiEmbedding::new("https://api.openai.com/", "key", "model", 1536);
        assert_eq!(client.base_url, "https://api.openai.com");
    }

    #[test]
    fn openai_dimensions_custom() {
        let client = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
        assert_eq!(client.dimensions(), 384);
    }
}