feat(memory): add embedding hint routes and upgrade guidance
This commit is contained in:
parent
2b8547b386
commit
572aa77c2a
8 changed files with 449 additions and 15 deletions
|
|
@ -27,7 +27,7 @@ pub use traits::Memory;
|
|||
#[allow(unused_imports)]
|
||||
pub use traits::{MemoryCategory, MemoryEntry};
|
||||
|
||||
use crate::config::{MemoryConfig, StorageProviderConfig};
|
||||
use crate::config::{EmbeddingRouteConfig, MemoryConfig, StorageProviderConfig};
|
||||
use anyhow::Context;
|
||||
use std::path::Path;
|
||||
use std::sync::Arc;
|
||||
|
|
@ -75,13 +75,83 @@ pub fn effective_memory_backend_name(
|
|||
memory_backend.trim().to_ascii_lowercase()
|
||||
}
|
||||
|
||||
/// Embedding settings after hint/route resolution.
///
/// Produced by `resolve_embedding_config`, which fills every field either from a
/// matching embedding route or from the base memory configuration.
#[derive(Clone, Debug, Eq, PartialEq)]
struct ResolvedEmbeddingConfig {
    /// Embedding provider identifier, stored trimmed.
    provider: String,
    /// Embedding model name, stored trimmed.
    model: String,
    /// Embedding vector dimensions.
    dimensions: usize,
    /// API key to use for the provider, if any non-empty key was supplied.
    api_key: Option<String>,
}
|
||||
|
||||
fn resolve_embedding_config(
|
||||
config: &MemoryConfig,
|
||||
embedding_routes: &[EmbeddingRouteConfig],
|
||||
api_key: Option<&str>,
|
||||
) -> ResolvedEmbeddingConfig {
|
||||
let fallback_api_key = api_key
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(str::to_string);
|
||||
let fallback = ResolvedEmbeddingConfig {
|
||||
provider: config.embedding_provider.trim().to_string(),
|
||||
model: config.embedding_model.trim().to_string(),
|
||||
dimensions: config.embedding_dimensions,
|
||||
api_key: fallback_api_key.clone(),
|
||||
};
|
||||
|
||||
let Some(hint) = config
|
||||
.embedding_model
|
||||
.strip_prefix("hint:")
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
else {
|
||||
return fallback;
|
||||
};
|
||||
|
||||
let Some(route) = embedding_routes
|
||||
.iter()
|
||||
.find(|route| route.hint.trim() == hint)
|
||||
else {
|
||||
tracing::warn!(
|
||||
hint,
|
||||
"Unknown embedding route hint; falling back to [memory] embedding settings"
|
||||
);
|
||||
return fallback;
|
||||
};
|
||||
|
||||
let provider = route.provider.trim();
|
||||
let model = route.model.trim();
|
||||
let dimensions = route.dimensions.unwrap_or(config.embedding_dimensions);
|
||||
if provider.is_empty() || model.is_empty() || dimensions == 0 {
|
||||
tracing::warn!(
|
||||
hint,
|
||||
"Invalid embedding route configuration; falling back to [memory] embedding settings"
|
||||
);
|
||||
return fallback;
|
||||
}
|
||||
|
||||
let routed_api_key = route
|
||||
.api_key
|
||||
.as_deref()
|
||||
.map(str::trim)
|
||||
.filter(|value: &&str| !value.is_empty())
|
||||
.map(|value| value.to_string());
|
||||
|
||||
ResolvedEmbeddingConfig {
|
||||
provider: provider.to_string(),
|
||||
model: model.to_string(),
|
||||
dimensions,
|
||||
api_key: routed_api_key.or(fallback_api_key),
|
||||
}
|
||||
}
|
||||
|
||||
/// Factory: create the right memory backend from config
|
||||
pub fn create_memory(
|
||||
config: &MemoryConfig,
|
||||
workspace_dir: &Path,
|
||||
api_key: Option<&str>,
|
||||
) -> anyhow::Result<Box<dyn Memory>> {
|
||||
create_memory_with_storage(config, None, workspace_dir, api_key)
|
||||
create_memory_with_storage_and_routes(config, &[], None, workspace_dir, api_key)
|
||||
}
|
||||
|
||||
/// Factory: create memory with optional storage-provider override.
|
||||
|
|
@ -90,9 +160,21 @@ pub fn create_memory_with_storage(
|
|||
storage_provider: Option<&StorageProviderConfig>,
|
||||
workspace_dir: &Path,
|
||||
api_key: Option<&str>,
|
||||
) -> anyhow::Result<Box<dyn Memory>> {
|
||||
create_memory_with_storage_and_routes(config, &[], storage_provider, workspace_dir, api_key)
|
||||
}
|
||||
|
||||
/// Factory: create memory with optional storage-provider override and embedding routes.
|
||||
pub fn create_memory_with_storage_and_routes(
|
||||
config: &MemoryConfig,
|
||||
embedding_routes: &[EmbeddingRouteConfig],
|
||||
storage_provider: Option<&StorageProviderConfig>,
|
||||
workspace_dir: &Path,
|
||||
api_key: Option<&str>,
|
||||
) -> anyhow::Result<Box<dyn Memory>> {
|
||||
let backend_name = effective_memory_backend_name(&config.backend, storage_provider);
|
||||
let backend_kind = classify_memory_backend(&backend_name);
|
||||
let resolved_embedding = resolve_embedding_config(config, embedding_routes, api_key);
|
||||
|
||||
// Best-effort memory hygiene/retention pass (throttled by state file).
|
||||
if let Err(e) = hygiene::run_if_due(config, workspace_dir) {
|
||||
|
|
@ -137,14 +219,14 @@ pub fn create_memory_with_storage(
|
|||
fn build_sqlite_memory(
|
||||
config: &MemoryConfig,
|
||||
workspace_dir: &Path,
|
||||
api_key: Option<&str>,
|
||||
resolved_embedding: &ResolvedEmbeddingConfig,
|
||||
) -> anyhow::Result<SqliteMemory> {
|
||||
let embedder: Arc<dyn embeddings::EmbeddingProvider> =
|
||||
Arc::from(embeddings::create_embedding_provider(
|
||||
&config.embedding_provider,
|
||||
api_key,
|
||||
&config.embedding_model,
|
||||
config.embedding_dimensions,
|
||||
&resolved_embedding.provider,
|
||||
resolved_embedding.api_key.as_deref(),
|
||||
&resolved_embedding.model,
|
||||
resolved_embedding.dimensions,
|
||||
));
|
||||
|
||||
#[allow(clippy::cast_possible_truncation)]
|
||||
|
|
@ -184,7 +266,7 @@ pub fn create_memory_with_storage(
|
|||
create_memory_with_builders(
|
||||
&backend_name,
|
||||
workspace_dir,
|
||||
|| build_sqlite_memory(config, workspace_dir, api_key),
|
||||
|| build_sqlite_memory(config, workspace_dir, &resolved_embedding),
|
||||
|| build_postgres_memory(storage_provider),
|
||||
"",
|
||||
)
|
||||
|
|
@ -247,7 +329,7 @@ pub fn create_response_cache(config: &MemoryConfig, workspace_dir: &Path) -> Opt
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::config::StorageProviderConfig;
|
||||
use crate::config::{EmbeddingRouteConfig, StorageProviderConfig};
|
||||
use tempfile::TempDir;
|
||||
|
||||
#[test]
|
||||
|
|
@ -353,4 +435,102 @@ mod tests {
|
|||
.expect("postgres without db_url should be rejected");
|
||||
assert!(error.to_string().contains("db_url"));
|
||||
}
|
||||
|
||||
/// A plain model name (no `hint:` prefix) resolves to the base memory settings.
#[test]
fn resolve_embedding_config_uses_base_config_when_model_is_not_hint() {
    let memory_config = MemoryConfig {
        embedding_provider: String::from("openai"),
        embedding_model: String::from("text-embedding-3-small"),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };
    let expected = ResolvedEmbeddingConfig {
        provider: String::from("openai"),
        model: String::from("text-embedding-3-small"),
        dimensions: 1536,
        api_key: Some(String::from("base-key")),
    };

    let actual = resolve_embedding_config(&memory_config, &[], Some("base-key"));

    assert_eq!(actual, expected);
}
|
||||
|
||||
/// A matching route supplies provider, model, dimensions and its own API key.
#[test]
fn resolve_embedding_config_uses_matching_route_with_api_key_override() {
    let memory_config = MemoryConfig {
        embedding_provider: String::from("none"),
        embedding_model: String::from("hint:semantic"),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };
    let semantic_route = EmbeddingRouteConfig {
        hint: String::from("semantic"),
        provider: String::from("custom:https://api.example.com/v1"),
        model: String::from("custom-embed-v2"),
        dimensions: Some(1024),
        api_key: Some(String::from("route-key")),
    };
    let expected = ResolvedEmbeddingConfig {
        provider: String::from("custom:https://api.example.com/v1"),
        model: String::from("custom-embed-v2"),
        dimensions: 1024,
        api_key: Some(String::from("route-key")),
    };

    let actual = resolve_embedding_config(&memory_config, &[semantic_route], Some("base-key"));

    assert_eq!(actual, expected);
}
|
||||
|
||||
/// A hint with no matching route falls back to the base memory settings,
/// leaving the `hint:` model string as configured.
#[test]
fn resolve_embedding_config_falls_back_when_hint_is_missing() {
    let memory_config = MemoryConfig {
        embedding_provider: String::from("openai"),
        embedding_model: String::from("hint:semantic"),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };
    let expected = ResolvedEmbeddingConfig {
        provider: String::from("openai"),
        model: String::from("hint:semantic"),
        dimensions: 1536,
        api_key: Some(String::from("base-key")),
    };

    let actual = resolve_embedding_config(&memory_config, &[], Some("base-key"));

    assert_eq!(actual, expected);
}
|
||||
|
||||
/// A matching route with an empty provider and zero dimensions is rejected,
/// so the base memory settings are returned instead.
#[test]
fn resolve_embedding_config_falls_back_when_route_is_invalid() {
    let memory_config = MemoryConfig {
        embedding_provider: String::from("openai"),
        embedding_model: String::from("hint:semantic"),
        embedding_dimensions: 1536,
        ..MemoryConfig::default()
    };
    let broken_route = EmbeddingRouteConfig {
        hint: String::from("semantic"),
        provider: String::new(),
        model: String::from("text-embedding-3-small"),
        dimensions: Some(0),
        api_key: None,
    };
    let expected = ResolvedEmbeddingConfig {
        provider: String::from("openai"),
        model: String::from("hint:semantic"),
        dimensions: 1536,
        api_key: Some(String::from("base-key")),
    };

    let actual = resolve_embedding_config(&memory_config, &[broken_route], Some("base-key"));

    assert_eq!(actual, expected);
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue