zeroclaw/src/memory/embeddings.rs

321 lines
9.1 KiB
Rust

use async_trait::async_trait;
/// Trait for embedding providers — convert text to vectors
#[async_trait]
pub trait EmbeddingProvider: Send + Sync {
/// Provider name
fn name(&self) -> &str;
/// Embedding dimensions
fn dimensions(&self) -> usize;
/// Embed a batch of texts into vectors
async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>>;
/// Embed a single text
async fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
let mut results = self.embed(&[text]).await?;
results
.pop()
.ok_or_else(|| anyhow::anyhow!("Empty embedding result"))
}
}
// ── Noop provider (keyword-only fallback) ────────────────────
/// Placeholder provider used when no embedding backend is configured.
///
/// It reports zero dimensions and never produces vectors, so it serves as the
/// keyword-only fallback.
pub struct NoopEmbedding;

#[async_trait]
impl EmbeddingProvider for NoopEmbedding {
    fn name(&self) -> &str {
        "none"
    }

    fn dimensions(&self) -> usize {
        0
    }

    /// Ignores its input and always succeeds with an empty batch.
    async fn embed(&self, _texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        Ok(vec![])
    }
}
// ── OpenAI-compatible embedding provider ─────────────────────
/// Embedding provider that talks to an OpenAI-style `/embeddings` HTTP endpoint.
///
/// Used both for api.openai.com and for arbitrary compatible servers; the request URL
/// is derived from `base_url` by `embeddings_url`.
pub struct OpenAiEmbedding {
    /// Server base URL; `new` strips any trailing slashes.
    base_url: String,
    /// Credential for the `Authorization` header; may be an empty string.
    api_key: String,
    /// Model identifier sent as the request's `model` field.
    model: String,
    /// Vector width reported by `dimensions()`.
    dims: usize,
    /// HTTP client created in `new` and reused for every request from this provider.
    client: reqwest::Client,
}
impl OpenAiEmbedding {
    /// Creates a provider for the OpenAI-compatible server at `base_url`.
    ///
    /// Trailing slashes are removed from `base_url` so that `embeddings_url` never
    /// produces a doubled `//` when appending a path segment.
    pub fn new(base_url: &str, api_key: &str, model: &str, dims: usize) -> Self {
        let normalized_base = base_url.trim_end_matches('/');
        Self {
            client: reqwest::Client::new(),
            base_url: normalized_base.to_owned(),
            api_key: api_key.to_owned(),
            model: model.to_owned(),
            dims,
        }
    }

    /// Path component of `base_url` with trailing slashes removed, or `None` when
    /// `base_url` cannot be parsed as an absolute URL.
    fn base_path(&self) -> Option<String> {
        let parsed = reqwest::Url::parse(&self.base_url).ok()?;
        Some(parsed.path().trim_end_matches('/').to_owned())
    }

    /// True when `base_url` carries a non-root path such as `/v1` or `/api/coding/v3`.
    fn has_explicit_api_path(&self) -> bool {
        match self.base_path() {
            Some(path) => !(path.is_empty() || path == "/"),
            None => false,
        }
    }

    /// True when `base_url` already names an `.../embeddings` endpoint.
    fn has_embeddings_endpoint(&self) -> bool {
        match self.base_path() {
            Some(path) => path.ends_with("/embeddings"),
            None => false,
        }
    }

    /// Resolves the URL that embedding requests are POSTed to.
    ///
    /// * base already ending in `/embeddings` → used verbatim
    /// * base with a path (e.g. `/v1`) → `{base}/embeddings`
    /// * bare host, or a base that does not parse → `{base}/v1/embeddings`
    fn embeddings_url(&self) -> String {
        let suffix = if self.has_embeddings_endpoint() {
            ""
        } else if self.has_explicit_api_path() {
            "/embeddings"
        } else {
            "/v1/embeddings"
        };
        format!("{}{suffix}", self.base_url)
    }
}
#[async_trait]
impl EmbeddingProvider for OpenAiEmbedding {
    fn name(&self) -> &str {
        "openai"
    }

    fn dimensions(&self) -> usize {
        self.dims
    }

    /// POSTs `texts` to the OpenAI-compatible embeddings endpoint and returns one
    /// vector per input, in input order.
    ///
    /// # Errors
    /// Fails if the HTTP request fails, if the server answers with a non-success status
    /// (the status and body text are included in the message), or if the response body
    /// is not a well-formed embeddings payload for exactly `texts.len()` inputs.
    async fn embed(&self, texts: &[&str]) -> anyhow::Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }

        let body = serde_json::json!({
            "model": self.model,
            "input": texts,
        });

        let mut request = self
            .client
            .post(self.embeddings_url())
            .header("Content-Type", "application/json")
            .json(&body);
        // Key-less local servers (e.g. `custom:` with no api key) should not receive a
        // malformed bare "Bearer " header; only authenticate when a key is configured.
        if !self.api_key.is_empty() {
            request = request.bearer_auth(&self.api_key);
        }

        let resp = request.send().await?;
        if !resp.status().is_success() {
            let status = resp.status();
            let text = resp.text().await.unwrap_or_default();
            anyhow::bail!("Embedding API error {status}: {text}");
        }

        let json: serde_json::Value = resp.json().await?;
        parse_embeddings_response(&json, texts.len())
    }
}

/// Extracts embedding vectors from an OpenAI-style `/embeddings` response body.
///
/// When every item carries an integer `index`, items are reordered by it so the result
/// lines up with the request's `input` order even if the server returns them shuffled.
///
/// # Errors
/// Fails when `data` is missing or not an array, when an item has no `embedding` array,
/// when any vector component is not a number (instead of silently dropping it and
/// producing a shorter vector), or when the number of vectors differs from `expected`.
fn parse_embeddings_response(
    json: &serde_json::Value,
    expected: usize,
) -> anyhow::Result<Vec<Vec<f32>>> {
    let data = json
        .get("data")
        .and_then(|d| d.as_array())
        .ok_or_else(|| anyhow::anyhow!("Invalid embedding response: missing 'data'"))?;

    let mut items: Vec<(Option<u64>, Vec<f32>)> = Vec::with_capacity(data.len());
    for item in data {
        let embedding = item
            .get("embedding")
            .and_then(|e| e.as_array())
            .ok_or_else(|| anyhow::anyhow!("Invalid embedding item"))?;
        // f64 -> f32 narrowing is intentional: vectors are stored as f32.
        #[allow(clippy::cast_possible_truncation)]
        let vector = embedding
            .iter()
            .map(|v| {
                v.as_f64()
                    .map(|f| f as f32)
                    .ok_or_else(|| anyhow::anyhow!("Invalid embedding value: {v}"))
            })
            .collect::<anyhow::Result<Vec<f32>>>()?;
        let index = item.get("index").and_then(serde_json::Value::as_u64);
        items.push((index, vector));
    }

    if items.len() != expected {
        anyhow::bail!(
            "Invalid embedding response: expected {expected} embeddings, got {}",
            items.len()
        );
    }

    // Only reorder when the server labelled every item; otherwise trust response order.
    if items.iter().all(|(index, _)| index.is_some()) {
        items.sort_by_key(|(index, _)| *index);
    }

    Ok(items.into_iter().map(|(_, vector)| vector).collect())
}
// ── Factory ──────────────────────────────────────────────────
/// Builds the embedding provider selected by `provider`.
///
/// * `"openai"` — OpenAI's hosted API at `https://api.openai.com`.
/// * `"custom:<base_url>"` — an OpenAI-compatible server at `<base_url>`.
/// * anything else — [`NoopEmbedding`] (keyword-only fallback; `model` and `dims`
///   are ignored).
///
/// A missing `api_key` is treated as an empty key.
pub fn create_embedding_provider(
    provider: &str,
    api_key: Option<&str>,
    model: &str,
    dims: usize,
) -> Box<dyn EmbeddingProvider> {
    let base_url = if provider == "openai" {
        "https://api.openai.com"
    } else if let Some(custom_url) = provider.strip_prefix("custom:") {
        custom_url
    } else {
        return Box::new(NoopEmbedding);
    };
    let key = api_key.unwrap_or_default();
    Box::new(OpenAiEmbedding::new(base_url, key, model, dims))
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn noop_name() {
        let p = NoopEmbedding;
        assert_eq!(p.name(), "none");
        assert_eq!(p.dimensions(), 0);
    }

    #[tokio::test]
    async fn noop_embed_returns_empty() {
        let p = NoopEmbedding;
        let result = p.embed(&["hello"]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn factory_none() {
        let p = create_embedding_provider("none", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_openai() {
        let p = create_embedding_provider("openai", Some("key"), "text-embedding-3-small", 1536);
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn factory_custom_url() {
        let p = create_embedding_provider("custom:http://localhost:1234", None, "model", 768);
        assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally
        assert_eq!(p.dimensions(), 768);
    }

    // ── Edge cases ───────────────────────────────────────────────

    #[tokio::test]
    async fn noop_embed_one_returns_error() {
        let p = NoopEmbedding;
        // embed returns empty vec → pop() returns None → error
        let result = p.embed_one("hello").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn noop_embed_empty_batch() {
        let p = NoopEmbedding;
        let result = p.embed(&[]).await.unwrap();
        assert!(result.is_empty());
    }

    #[tokio::test]
    async fn noop_embed_multiple_texts() {
        let p = NoopEmbedding;
        let result = p.embed(&["a", "b", "c"]).await.unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn factory_empty_string_returns_noop() {
        let p = create_embedding_provider("", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_unknown_provider_returns_noop() {
        let p = create_embedding_provider("cohere", None, "model", 1536);
        assert_eq!(p.name(), "none");
    }

    #[test]
    fn factory_none_ignores_requested_dims() {
        // The Noop fallback always reports 0 regardless of the configured dims.
        let p = create_embedding_provider("none", None, "model", 1536);
        assert_eq!(p.dimensions(), 0);
    }

    #[test]
    fn factory_custom_empty_url() {
        // "custom:" with no URL — should still construct without panic
        let p = create_embedding_provider("custom:", None, "model", 768);
        assert_eq!(p.name(), "openai");
    }

    #[test]
    fn factory_openai_no_api_key() {
        let p = create_embedding_provider("openai", None, "text-embedding-3-small", 1536);
        assert_eq!(p.name(), "openai");
        assert_eq!(p.dimensions(), 1536);
    }

    #[test]
    fn openai_trailing_slash_stripped() {
        let p = OpenAiEmbedding::new("https://api.openai.com/", "key", "model", 1536);
        assert_eq!(p.base_url, "https://api.openai.com");
    }

    #[test]
    fn openai_dimensions_custom() {
        let p = OpenAiEmbedding::new("http://localhost", "k", "m", 384);
        assert_eq!(p.dimensions(), 384);
    }

    #[test]
    fn embeddings_url_standard_openai() {
        let p = OpenAiEmbedding::new("https://api.openai.com", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.openai.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_base_with_v1_no_duplicate() {
        let p = OpenAiEmbedding::new("https://api.example.com/v1", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_base_with_v1_trailing_slash() {
        // Trailing slash is stripped before the path check, so no `//` and no extra `/v1`.
        let p = OpenAiEmbedding::new("https://api.example.com/v1/", "key", "model", 1536);
        assert_eq!(p.embeddings_url(), "https://api.example.com/v1/embeddings");
    }

    #[test]
    fn embeddings_url_host_with_port_gets_v1() {
        let p = OpenAiEmbedding::new("http://localhost:1234", "", "model", 768);
        assert_eq!(p.embeddings_url(), "http://localhost:1234/v1/embeddings");
    }

    #[test]
    fn embeddings_url_non_v1_api_path_uses_raw_suffix() {
        let p = OpenAiEmbedding::new(
            "https://api.example.com/api/coding/v3",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://api.example.com/api/coding/v3/embeddings"
        );
    }

    #[test]
    fn embeddings_url_custom_full_endpoint() {
        let p = OpenAiEmbedding::new(
            "https://my-api.example.com/api/v2/embeddings",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://my-api.example.com/api/v2/embeddings"
        );
    }

    #[test]
    fn embeddings_url_full_endpoint_with_trailing_slash() {
        let p = OpenAiEmbedding::new(
            "https://my-api.example.com/api/v2/embeddings/",
            "key",
            "model",
            1536,
        );
        assert_eq!(
            p.embeddings_url(),
            "https://my-api.example.com/api/v2/embeddings"
        );
    }
}