diff --git a/src/agent/agent.rs b/src/agent/agent.rs index 3df62fc..fbb5ec6 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -271,7 +271,10 @@ impl Agent { .memory(memory) .observer(observer) .tool_dispatcher(tool_dispatcher) - .memory_loader(Box::new(DefaultMemoryLoader::default())) + .memory_loader(Box::new(DefaultMemoryLoader::new( + 5, + config.memory.min_relevance_score, + ))) .prompt_builder(SystemPromptBuilder::with_defaults()) .config(config.agent.clone()) .model_name(model_name) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 80014df..b9b4344 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -16,8 +16,7 @@ use std::time::Instant; use uuid::Uuid; /// Default maximum agentic tool-use iterations per user message to prevent runaway loops. -/// Prefer passing the config-driven value via `run_tool_call_loop`; this constant is only -/// used when callers omit the parameter. +/// Used as a safe fallback when `max_tool_iterations` is unset or configured as zero. const DEFAULT_MAX_TOOL_ITERATIONS: usize = 10; static SENSITIVE_KEY_PATTERNS: LazyLock = LazyLock::new(|| { @@ -202,15 +201,25 @@ async fn auto_compact_history( Ok(true) } -/// Build context preamble by searching memory for relevant entries -async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { +/// Build context preamble by searching memory for relevant entries. +/// Entries with a hybrid score below `min_relevance_score` are dropped to +/// prevent unrelated memories from bleeding into the conversation. +async fn build_context(mem: &dyn Memory, user_msg: &str, min_relevance_score: f64) -> String { let mut context = String::new(); // Pull relevant memories for this message if let Ok(entries) = mem.recall(user_msg, 5, None).await { - if !entries.is_empty() { + let relevant: Vec<_> = entries + .iter() + .filter(|e| match e.score { + Some(score) => score >= min_relevance_score, + None => true, + }) + .collect(); + + if !relevant.is_empty() { context.push_str("[Memory context]\n"); - for entry in &entries { + for entry in &relevant { let _ = writeln!(context, "- {}: {}", entry.key, entry.content); } context.push('\n'); @@ -598,6 +607,7 @@ pub(crate) async fn agent_turn( model: &str, temperature: f64, silent: bool, + max_tool_iterations: usize, ) -> Result { run_tool_call_loop( provider, @@ -610,7 +620,7 @@ pub(crate) async fn agent_turn( silent, None, "channel", - DEFAULT_MAX_TOOL_ITERATIONS, + max_tool_iterations, ) .await } @@ -631,6 +641,12 @@ pub(crate) async fn run_tool_call_loop( channel_name: &str, max_tool_iterations: usize, ) -> Result { + let max_iterations = if max_tool_iterations == 0 { + DEFAULT_MAX_TOOL_ITERATIONS + } else { + max_tool_iterations + }; + // Build native tool definitions once if the provider supports them. let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty(); let tool_definitions = if use_native_tools { @@ -639,7 +655,7 @@ pub(crate) async fn run_tool_call_loop( Vec::new() }; - for _iteration in 0..max_tool_iterations { + for _iteration in 0..max_iterations { observer.record_event(&ObserverEvent::LlmRequest { provider: provider_name.to_string(), model: model.to_string(), @@ -857,7 +873,7 @@ pub(crate) async fn run_tool_call_loop( } } - anyhow::bail!("Agent exceeded maximum tool iterations ({max_tool_iterations})") + anyhow::bail!("Agent exceeded maximum tool iterations ({max_iterations})") } /// Build the tool instruction block for the system prompt so the LLM knows @@ -1142,7 +1158,8 @@ pub async fn run( } // Inject memory + hardware RAG context into user message - let mem_context = build_context(mem.as_ref(), &msg).await; + let mem_context = + build_context(mem.as_ref(), &msg, config.memory.min_relevance_score).await; let rag_limit = if config.agent.compact_context { 2 } else { 5 }; let hw_context = hardware_rag .as_ref() @@ -1270,7 +1287,8 @@ pub async fn run( } // Inject memory + hardware RAG context into user message - let mem_context = build_context(mem.as_ref(), &user_input).await; + let mem_context = + build_context(mem.as_ref(), &user_input, config.memory.min_relevance_score).await; let rag_limit = if config.agent.compact_context { 2 } else { 5 }; let hw_context = hardware_rag .as_ref() @@ -1487,7 +1505,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { ); system_prompt.push_str(&build_tool_instructions(&tools_registry)); - let mem_context = build_context(mem.as_ref(), message).await; + let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await; let rag_limit = if config.agent.compact_context { 2 } else { 5 }; let hw_context = hardware_rag .as_ref() @@ -1514,6 +1532,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { &model_name, config.default_temperature, true, + config.agent.max_tool_iterations, ) .await } diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs index 0cc530f..b171eed 100644 --- a/src/agent/memory_loader.rs +++ b/src/agent/memory_loader.rs @@ -10,18 +10,23 @@ pub trait MemoryLoader: Send + Sync { pub struct DefaultMemoryLoader { limit: usize, + min_relevance_score: f64, } impl Default for DefaultMemoryLoader { fn default() -> Self { - Self { limit: 5 } + Self { + limit: 5, + min_relevance_score: 0.4, + } } } impl DefaultMemoryLoader { - pub fn new(limit: usize) -> Self { + pub fn new(limit: usize, min_relevance_score: f64) -> Self { Self { limit: limit.max(1), + min_relevance_score, } } } @@ -40,8 +45,19 @@ impl MemoryLoader for DefaultMemoryLoader { let mut context = String::from("[Memory context]\n"); for entry in entries { + if let Some(score) = entry.score { + if score < self.min_relevance_score { + continue; + } + } let _ = writeln!(context, "- {}: {}", entry.key, entry.content); } + + // If all entries were below threshold, return empty + if context == "[Memory context]\n" { + return Ok(String::new()); + } + context.push('\n'); Ok(context) } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 2cf7892..ec11c2b 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -72,6 +72,7 @@ struct ChannelRuntimeContext { temperature: f64, auto_save_memory: bool, max_tool_iterations: usize, + min_relevance_score: f64, } fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { @@ -87,13 +88,25 @@ fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { } } -async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String { +async fn build_memory_context( + mem: &dyn Memory, + user_msg: &str, + min_relevance_score: f64, +) -> String { let mut context = String::new(); if let Ok(entries) = mem.recall(user_msg, 5, None).await { - if !entries.is_empty() { + let relevant: Vec<_> = entries + .iter() + .filter(|e| match e.score { + Some(score) => score >= min_relevance_score, + None => true, // keep entries without a score (e.g. non-vector backends) + }) + .collect(); + + if !relevant.is_empty() { context.push_str("[Memory context]\n"); - for entry in &entries { + for entry in &relevant { let _ = writeln!(context, "- {}: {}", entry.key, entry.content); } context.push('\n'); @@ -166,7 +179,8 @@ async fn process_channel_message(ctx: Arc, msg: traits::C truncate_with_ellipsis(&msg.content, 80) ); - let memory_context = build_memory_context(ctx.memory.as_ref(), &msg.content).await; + let memory_context = + build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await; if ctx.auto_save_memory { let autosave_key = conversation_memory_key(&msg); @@ -1279,6 +1293,7 @@ pub async fn start_channels(config: Config) -> Result<()> { temperature, auto_save_memory: config.memory.auto_save, max_tool_iterations: config.agent.max_tool_iterations, + min_relevance_score: config.memory.min_relevance_score, }); run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; @@ -1504,6 +1519,7 @@ mod tests { temperature: 0.0, auto_save_memory: false, max_tool_iterations: 10, + min_relevance_score: 0.0, }); process_channel_message( @@ -1546,6 +1562,7 @@ mod tests { temperature: 0.0, auto_save_memory: false, max_tool_iterations: 10, + min_relevance_score: 0.0, }); process_channel_message( @@ -1642,6 +1659,7 @@ mod tests { temperature: 0.0, auto_save_memory: false, max_tool_iterations: 10, + min_relevance_score: 0.0, }); let (tx, rx) = tokio::sync::mpsc::channel::(4); @@ -2008,7 +2026,7 @@ mod tests { .await .unwrap(); - let context = build_memory_context(&mem, "age").await; + let context = build_memory_context(&mem, "age", 0.0).await; assert!(context.contains("[Memory context]")); assert!(context.contains("Age is 45")); } diff --git a/src/config/schema.rs b/src/config/schema.rs index 84ba630..cfd77f3 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -753,6 +753,11 @@ pub struct MemoryConfig { /// Weight for keyword BM25 in hybrid search (0.0–1.0) #[serde(default = "default_keyword_weight")] pub keyword_weight: f64, + /// Minimum hybrid score (0.0–1.0) for a memory to be included in context. + /// Memories scoring below this threshold are dropped to prevent irrelevant + /// context from bleeding into conversations. Default: 0.4 + #[serde(default = "default_min_relevance_score")] + pub min_relevance_score: f64, /// Max embedding cache entries before LRU eviction #[serde(default = "default_cache_size")] pub embedding_cache_size: usize, @@ -811,10 +816,13 @@ fn default_embedding_dims() -> usize { 1536 } fn default_vector_weight() -> f64 { - 0.7 + 0.4 } fn default_keyword_weight() -> f64 { - 0.3 + 0.6 +} +fn default_min_relevance_score() -> f64 { + 0.4 } fn default_cache_size() -> usize { 10_000 @@ -843,6 +851,7 @@ impl Default for MemoryConfig { embedding_dimensions: default_embedding_dims(), vector_weight: default_vector_weight(), keyword_weight: default_keyword_weight(), + min_relevance_score: default_min_relevance_score(), embedding_cache_size: default_cache_size(), chunk_max_tokens: default_chunk_size(), response_cache_enabled: false, diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 3c2e9b1..83c1bf1 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -274,8 +274,9 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig { embedding_provider: "none".to_string(), embedding_model: "text-embedding-3-small".to_string(), embedding_dimensions: 1536, - vector_weight: 0.7, - keyword_weight: 0.3, + vector_weight: 0.4, + keyword_weight: 0.6, + min_relevance_score: 0.4, embedding_cache_size: if profile.uses_sqlite_hygiene { 10000 } else {