feat(memory): add embedding hint routes and upgrade guidance

This commit is contained in:
Chummy 2026-02-19 17:51:35 +08:00
parent 2b8547b386
commit 572aa77c2a
8 changed files with 449 additions and 15 deletions

View file

@ -27,7 +27,7 @@ pub use traits::Memory;
#[allow(unused_imports)]
pub use traits::{MemoryCategory, MemoryEntry};
use crate::config::{MemoryConfig, StorageProviderConfig};
use crate::config::{EmbeddingRouteConfig, MemoryConfig, StorageProviderConfig};
use anyhow::Context;
use std::path::Path;
use std::sync::Arc;
@ -75,13 +75,83 @@ pub fn effective_memory_backend_name(
memory_backend.trim().to_ascii_lowercase()
}
/// Embedding settings after any `hint:` route lookup has been applied.
///
/// Produced by `resolve_embedding_config` and consumed when the embedding
/// provider is constructed.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ResolvedEmbeddingConfig {
    /// Embedding provider identifier (whitespace-trimmed).
    provider: String,
    /// Embedding model name (whitespace-trimmed).
    model: String,
    /// Embedding vector dimensionality.
    dimensions: usize,
    /// API key to use with the provider, if one was supplied.
    api_key: Option<String>,
}
/// Resolves the effective embedding settings for the memory backend.
///
/// When `config.embedding_model` (after trimming) has the form `hint:<name>`,
/// the first entry of `embedding_routes` whose trimmed `hint` equals `<name>`
/// supplies the provider, model, dimensions and, optionally, the API key.
/// A route without explicit dimensions inherits `config.embedding_dimensions`.
///
/// The `[memory]` settings from `config` are returned unchanged (apart from
/// trimming) when:
/// - the model is not a `hint:` reference, or the hint name is blank;
/// - no route matches the hint (a warning is logged);
/// - the matched route has a blank provider or model, or resolves to zero
///   dimensions (a warning is logged).
///
/// `api_key` is the caller-supplied default key. Blank keys are treated as
/// absent, and a non-blank route key takes precedence over the default.
fn resolve_embedding_config(
    config: &MemoryConfig,
    embedding_routes: &[EmbeddingRouteConfig],
    api_key: Option<&str>,
) -> ResolvedEmbeddingConfig {
    let fallback_api_key = api_key
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(str::to_string);

    // Trim once and use the same value for both the fallback model and the
    // hint detection. Otherwise a model with leading whitespace (e.g.
    // " hint:semantic") would skip routing yet still be reported as
    // "hint:semantic" in the fallback.
    let configured_model = config.embedding_model.trim();

    let fallback = ResolvedEmbeddingConfig {
        provider: config.embedding_provider.trim().to_string(),
        model: configured_model.to_string(),
        dimensions: config.embedding_dimensions,
        api_key: fallback_api_key.clone(),
    };

    let Some(hint) = configured_model
        .strip_prefix("hint:")
        .map(str::trim)
        .filter(|value| !value.is_empty())
    else {
        return fallback;
    };

    let Some(route) = embedding_routes
        .iter()
        .find(|route| route.hint.trim() == hint)
    else {
        tracing::warn!(
            hint,
            "Unknown embedding route hint; falling back to [memory] embedding settings"
        );
        return fallback;
    };

    let provider = route.provider.trim();
    let model = route.model.trim();
    let dimensions = route.dimensions.unwrap_or(config.embedding_dimensions);
    if provider.is_empty() || model.is_empty() || dimensions == 0 {
        tracing::warn!(
            hint,
            "Invalid embedding route configuration; falling back to [memory] embedding settings"
        );
        return fallback;
    }

    // A blank route key counts as "not set" so the default key still applies.
    let routed_api_key = route
        .api_key
        .as_deref()
        .map(str::trim)
        .filter(|value| !value.is_empty())
        .map(str::to_string);

    ResolvedEmbeddingConfig {
        provider: provider.to_string(),
        model: model.to_string(),
        dimensions,
        api_key: routed_api_key.or(fallback_api_key),
    }
}
/// Factory: create the right memory backend from config.
///
/// Convenience entry point that forwards `config`, `workspace_dir` and
/// `api_key` unchanged, passing `None` for the storage-provider override.
pub fn create_memory(
config: &MemoryConfig,
workspace_dir: &Path,
api_key: Option<&str>,
) -> anyhow::Result<Box<dyn Memory>> {
// NOTE(review): two consecutive calls appear here with no separator. This looks
// like a diff rendering where the first line is the pre-change call and the
// second (which additionally passes an empty embedding-route slice) is its
// replacement — confirm against the committed file.
create_memory_with_storage(config, None, workspace_dir, api_key)
create_memory_with_storage_and_routes(config, &[], None, workspace_dir, api_key)
}
/// Factory: create memory with optional storage-provider override.
@ -90,9 +160,21 @@ pub fn create_memory_with_storage(
storage_provider: Option<&StorageProviderConfig>,
workspace_dir: &Path,
api_key: Option<&str>,
) -> anyhow::Result<Box<dyn Memory>> {
create_memory_with_storage_and_routes(config, &[], storage_provider, workspace_dir, api_key)
}
/// Factory: create memory with optional storage-provider override and embedding routes.
pub fn create_memory_with_storage_and_routes(
config: &MemoryConfig,
embedding_routes: &[EmbeddingRouteConfig],
storage_provider: Option<&StorageProviderConfig>,
workspace_dir: &Path,
api_key: Option<&str>,
) -> anyhow::Result<Box<dyn Memory>> {
let backend_name = effective_memory_backend_name(&config.backend, storage_provider);
let backend_kind = classify_memory_backend(&backend_name);
let resolved_embedding = resolve_embedding_config(config, embedding_routes, api_key);
// Best-effort memory hygiene/retention pass (throttled by state file).
if let Err(e) = hygiene::run_if_due(config, workspace_dir) {
@ -137,14 +219,14 @@ pub fn create_memory_with_storage(
fn build_sqlite_memory(
config: &MemoryConfig,
workspace_dir: &Path,
api_key: Option<&str>,
resolved_embedding: &ResolvedEmbeddingConfig,
) -> anyhow::Result<SqliteMemory> {
let embedder: Arc<dyn embeddings::EmbeddingProvider> =
Arc::from(embeddings::create_embedding_provider(
&config.embedding_provider,
api_key,
&config.embedding_model,
config.embedding_dimensions,
&resolved_embedding.provider,
resolved_embedding.api_key.as_deref(),
&resolved_embedding.model,
resolved_embedding.dimensions,
));
#[allow(clippy::cast_possible_truncation)]
@ -184,7 +266,7 @@ pub fn create_memory_with_storage(
create_memory_with_builders(
&backend_name,
workspace_dir,
|| build_sqlite_memory(config, workspace_dir, api_key),
|| build_sqlite_memory(config, workspace_dir, &resolved_embedding),
|| build_postgres_memory(storage_provider),
"",
)
@ -247,7 +329,7 @@ pub fn create_response_cache(config: &MemoryConfig, workspace_dir: &Path) -> Opt
#[cfg(test)]
mod tests {
use super::*;
use crate::config::StorageProviderConfig;
use crate::config::{EmbeddingRouteConfig, StorageProviderConfig};
use tempfile::TempDir;
#[test]
@ -353,4 +435,102 @@ mod tests {
.expect("postgres without db_url should be rejected");
assert!(error.to_string().contains("db_url"));
}
#[test]
fn resolve_embedding_config_uses_base_config_when_model_is_not_hint() {
    // A plain model name (no `hint:` prefix) must resolve to the base settings.
    let memory_config = MemoryConfig {
        embedding_provider: "openai".into(),
        embedding_model: "text-embedding-3-small".into(),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };

    let actual = resolve_embedding_config(&memory_config, &[], Some("base-key"));

    let expected = ResolvedEmbeddingConfig {
        provider: String::from("openai"),
        model: String::from("text-embedding-3-small"),
        dimensions: 1536,
        api_key: Some(String::from("base-key")),
    };
    assert_eq!(actual, expected);
}
#[test]
fn resolve_embedding_config_uses_matching_route_with_api_key_override() {
    // The `hint:semantic` model selects the matching route, and the route's
    // own API key takes precedence over the base key.
    let memory_config = MemoryConfig {
        embedding_provider: "none".into(),
        embedding_model: "hint:semantic".into(),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };
    let semantic_route = EmbeddingRouteConfig {
        hint: "semantic".into(),
        provider: "custom:https://api.example.com/v1".into(),
        model: "custom-embed-v2".into(),
        dimensions: Some(1024),
        api_key: Some("route-key".into()),
    };

    let actual = resolve_embedding_config(&memory_config, &[semantic_route], Some("base-key"));

    let expected = ResolvedEmbeddingConfig {
        provider: String::from("custom:https://api.example.com/v1"),
        model: String::from("custom-embed-v2"),
        dimensions: 1024,
        api_key: Some(String::from("route-key")),
    };
    assert_eq!(actual, expected);
}
#[test]
fn resolve_embedding_config_falls_back_when_hint_is_missing() {
    // With no routes configured, a `hint:` model cannot be resolved and the
    // base settings (including the literal hint string) are returned.
    let memory_config = MemoryConfig {
        embedding_provider: "openai".into(),
        embedding_model: "hint:semantic".into(),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };

    let actual = resolve_embedding_config(&memory_config, &[], Some("base-key"));

    let expected = ResolvedEmbeddingConfig {
        provider: String::from("openai"),
        model: String::from("hint:semantic"),
        dimensions: 1536,
        api_key: Some(String::from("base-key")),
    };
    assert_eq!(actual, expected);
}
#[test]
fn resolve_embedding_config_falls_back_when_route_is_invalid() {
    // The route matches the hint but has an empty provider and zero
    // dimensions, so the base settings must be used instead.
    let memory_config = MemoryConfig {
        embedding_provider: "openai".into(),
        embedding_model: "hint:semantic".into(),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };
    let invalid_route = EmbeddingRouteConfig {
        hint: "semantic".into(),
        provider: String::new(),
        model: "text-embedding-3-small".into(),
        dimensions: Some(0),
        api_key: None,
    };

    let actual = resolve_embedding_config(&memory_config, &[invalid_route], Some("base-key"));

    let expected = ResolvedEmbeddingConfig {
        provider: String::from("openai"),
        model: String::from("hint:semantic"),
        dimensions: 1536,
        api_key: Some(String::from("base-key")),
    };
    assert_eq!(actual, expected);
}
}