fix(agent): use config max_tool_iterations, add memory relevance filtering, rebalance search weights
Three fixes for conversation quality issues:

1. loop_.rs and channels now read max_tool_iterations from AgentConfig instead of using a hardcoded constant of 10, making it configurable.
2. Memory recall now filters entries below a configurable min_relevance_score threshold (default 0.4), preventing unrelated memories from bleeding into conversation context.
3. Default hybrid search weights rebalanced from 70/30 vector/keyword to 40/60, reducing cross-topic semantic bleed.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
21c5f58363
commit
8a1e7cc7ef
6 changed files with 90 additions and 24 deletions
|
|
@ -271,7 +271,10 @@ impl Agent {
|
||||||
.memory(memory)
|
.memory(memory)
|
||||||
.observer(observer)
|
.observer(observer)
|
||||||
.tool_dispatcher(tool_dispatcher)
|
.tool_dispatcher(tool_dispatcher)
|
||||||
.memory_loader(Box::new(DefaultMemoryLoader::default()))
|
.memory_loader(Box::new(DefaultMemoryLoader::new(
|
||||||
|
5,
|
||||||
|
config.memory.min_relevance_score,
|
||||||
|
)))
|
||||||
.prompt_builder(SystemPromptBuilder::with_defaults())
|
.prompt_builder(SystemPromptBuilder::with_defaults())
|
||||||
.config(config.agent.clone())
|
.config(config.agent.clone())
|
||||||
.model_name(model_name)
|
.model_name(model_name)
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,7 @@ use std::time::Instant;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// Default maximum agentic tool-use iterations per user message to prevent runaway loops.
|
/// Default maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||||
/// Prefer passing the config-driven value via `run_tool_call_loop`; this constant is only
|
/// Used as a safe fallback when `max_tool_iterations` is unset or configured as zero.
|
||||||
/// used when callers omit the parameter.
|
|
||||||
const DEFAULT_MAX_TOOL_ITERATIONS: usize = 10;
|
const DEFAULT_MAX_TOOL_ITERATIONS: usize = 10;
|
||||||
|
|
||||||
static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
|
static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
|
||||||
|
|
@ -202,15 +201,25 @@ async fn auto_compact_history(
|
||||||
Ok(true)
|
Ok(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build context preamble by searching memory for relevant entries
|
/// Build context preamble by searching memory for relevant entries.
|
||||||
async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
/// Entries with a hybrid score below `min_relevance_score` are dropped to
|
||||||
|
/// prevent unrelated memories from bleeding into the conversation.
|
||||||
|
async fn build_context(mem: &dyn Memory, user_msg: &str, min_relevance_score: f64) -> String {
|
||||||
let mut context = String::new();
|
let mut context = String::new();
|
||||||
|
|
||||||
// Pull relevant memories for this message
|
// Pull relevant memories for this message
|
||||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||||
if !entries.is_empty() {
|
let relevant: Vec<_> = entries
|
||||||
|
.iter()
|
||||||
|
.filter(|e| match e.score {
|
||||||
|
Some(score) => score >= min_relevance_score,
|
||||||
|
None => true,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if !relevant.is_empty() {
|
||||||
context.push_str("[Memory context]\n");
|
context.push_str("[Memory context]\n");
|
||||||
for entry in &entries {
|
for entry in &relevant {
|
||||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||||
}
|
}
|
||||||
context.push('\n');
|
context.push('\n');
|
||||||
|
|
@ -598,6 +607,7 @@ pub(crate) async fn agent_turn(
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
silent: bool,
|
silent: bool,
|
||||||
|
max_tool_iterations: usize,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
run_tool_call_loop(
|
run_tool_call_loop(
|
||||||
provider,
|
provider,
|
||||||
|
|
@ -610,7 +620,7 @@ pub(crate) async fn agent_turn(
|
||||||
silent,
|
silent,
|
||||||
None,
|
None,
|
||||||
"channel",
|
"channel",
|
||||||
DEFAULT_MAX_TOOL_ITERATIONS,
|
max_tool_iterations,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
@ -631,6 +641,12 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
channel_name: &str,
|
channel_name: &str,
|
||||||
max_tool_iterations: usize,
|
max_tool_iterations: usize,
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
|
let max_iterations = if max_tool_iterations == 0 {
|
||||||
|
DEFAULT_MAX_TOOL_ITERATIONS
|
||||||
|
} else {
|
||||||
|
max_tool_iterations
|
||||||
|
};
|
||||||
|
|
||||||
// Build native tool definitions once if the provider supports them.
|
// Build native tool definitions once if the provider supports them.
|
||||||
let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty();
|
let use_native_tools = provider.supports_native_tools() && !tools_registry.is_empty();
|
||||||
let tool_definitions = if use_native_tools {
|
let tool_definitions = if use_native_tools {
|
||||||
|
|
@ -639,7 +655,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
Vec::new()
|
Vec::new()
|
||||||
};
|
};
|
||||||
|
|
||||||
for _iteration in 0..max_tool_iterations {
|
for _iteration in 0..max_iterations {
|
||||||
observer.record_event(&ObserverEvent::LlmRequest {
|
observer.record_event(&ObserverEvent::LlmRequest {
|
||||||
provider: provider_name.to_string(),
|
provider: provider_name.to_string(),
|
||||||
model: model.to_string(),
|
model: model.to_string(),
|
||||||
|
|
@ -857,7 +873,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
anyhow::bail!("Agent exceeded maximum tool iterations ({max_tool_iterations})")
|
anyhow::bail!("Agent exceeded maximum tool iterations ({max_iterations})")
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Build the tool instruction block for the system prompt so the LLM knows
|
/// Build the tool instruction block for the system prompt so the LLM knows
|
||||||
|
|
@ -1142,7 +1158,8 @@ pub async fn run(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inject memory + hardware RAG context into user message
|
// Inject memory + hardware RAG context into user message
|
||||||
let mem_context = build_context(mem.as_ref(), &msg).await;
|
let mem_context =
|
||||||
|
build_context(mem.as_ref(), &msg, config.memory.min_relevance_score).await;
|
||||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||||
let hw_context = hardware_rag
|
let hw_context = hardware_rag
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -1270,7 +1287,8 @@ pub async fn run(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Inject memory + hardware RAG context into user message
|
// Inject memory + hardware RAG context into user message
|
||||||
let mem_context = build_context(mem.as_ref(), &user_input).await;
|
let mem_context =
|
||||||
|
build_context(mem.as_ref(), &user_input, config.memory.min_relevance_score).await;
|
||||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||||
let hw_context = hardware_rag
|
let hw_context = hardware_rag
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -1487,7 +1505,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||||
);
|
);
|
||||||
system_prompt.push_str(&build_tool_instructions(&tools_registry));
|
system_prompt.push_str(&build_tool_instructions(&tools_registry));
|
||||||
|
|
||||||
let mem_context = build_context(mem.as_ref(), message).await;
|
let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await;
|
||||||
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
let rag_limit = if config.agent.compact_context { 2 } else { 5 };
|
||||||
let hw_context = hardware_rag
|
let hw_context = hardware_rag
|
||||||
.as_ref()
|
.as_ref()
|
||||||
|
|
@ -1514,6 +1532,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||||
&model_name,
|
&model_name,
|
||||||
config.default_temperature,
|
config.default_temperature,
|
||||||
true,
|
true,
|
||||||
|
config.agent.max_tool_iterations,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -10,18 +10,23 @@ pub trait MemoryLoader: Send + Sync {
|
||||||
|
|
||||||
pub struct DefaultMemoryLoader {
|
pub struct DefaultMemoryLoader {
|
||||||
limit: usize,
|
limit: usize,
|
||||||
|
min_relevance_score: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for DefaultMemoryLoader {
|
impl Default for DefaultMemoryLoader {
|
||||||
fn default() -> Self {
|
fn default() -> Self {
|
||||||
Self { limit: 5 }
|
Self {
|
||||||
|
limit: 5,
|
||||||
|
min_relevance_score: 0.4,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl DefaultMemoryLoader {
|
impl DefaultMemoryLoader {
|
||||||
pub fn new(limit: usize) -> Self {
|
pub fn new(limit: usize, min_relevance_score: f64) -> Self {
|
||||||
Self {
|
Self {
|
||||||
limit: limit.max(1),
|
limit: limit.max(1),
|
||||||
|
min_relevance_score,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -40,8 +45,19 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||||
|
|
||||||
let mut context = String::from("[Memory context]\n");
|
let mut context = String::from("[Memory context]\n");
|
||||||
for entry in entries {
|
for entry in entries {
|
||||||
|
if let Some(score) = entry.score {
|
||||||
|
if score < self.min_relevance_score {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// If all entries were below threshold, return empty
|
||||||
|
if context == "[Memory context]\n" {
|
||||||
|
return Ok(String::new());
|
||||||
|
}
|
||||||
|
|
||||||
context.push('\n');
|
context.push('\n');
|
||||||
Ok(context)
|
Ok(context)
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,6 +72,7 @@ struct ChannelRuntimeContext {
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
auto_save_memory: bool,
|
auto_save_memory: bool,
|
||||||
max_tool_iterations: usize,
|
max_tool_iterations: usize,
|
||||||
|
min_relevance_score: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||||
|
|
@ -87,13 +88,25 @@ fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
|
async fn build_memory_context(
|
||||||
|
mem: &dyn Memory,
|
||||||
|
user_msg: &str,
|
||||||
|
min_relevance_score: f64,
|
||||||
|
) -> String {
|
||||||
let mut context = String::new();
|
let mut context = String::new();
|
||||||
|
|
||||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||||
if !entries.is_empty() {
|
let relevant: Vec<_> = entries
|
||||||
|
.iter()
|
||||||
|
.filter(|e| match e.score {
|
||||||
|
Some(score) => score >= min_relevance_score,
|
||||||
|
None => true, // keep entries without a score (e.g. non-vector backends)
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if !relevant.is_empty() {
|
||||||
context.push_str("[Memory context]\n");
|
context.push_str("[Memory context]\n");
|
||||||
for entry in &entries {
|
for entry in &relevant {
|
||||||
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
|
||||||
}
|
}
|
||||||
context.push('\n');
|
context.push('\n');
|
||||||
|
|
@ -166,7 +179,8 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
truncate_with_ellipsis(&msg.content, 80)
|
truncate_with_ellipsis(&msg.content, 80)
|
||||||
);
|
);
|
||||||
|
|
||||||
let memory_context = build_memory_context(ctx.memory.as_ref(), &msg.content).await;
|
let memory_context =
|
||||||
|
build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await;
|
||||||
|
|
||||||
if ctx.auto_save_memory {
|
if ctx.auto_save_memory {
|
||||||
let autosave_key = conversation_memory_key(&msg);
|
let autosave_key = conversation_memory_key(&msg);
|
||||||
|
|
@ -1279,6 +1293,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
temperature,
|
temperature,
|
||||||
auto_save_memory: config.memory.auto_save,
|
auto_save_memory: config.memory.auto_save,
|
||||||
max_tool_iterations: config.agent.max_tool_iterations,
|
max_tool_iterations: config.agent.max_tool_iterations,
|
||||||
|
min_relevance_score: config.memory.min_relevance_score,
|
||||||
});
|
});
|
||||||
|
|
||||||
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await;
|
||||||
|
|
@ -1504,6 +1519,7 @@ mod tests {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
auto_save_memory: false,
|
auto_save_memory: false,
|
||||||
max_tool_iterations: 10,
|
max_tool_iterations: 10,
|
||||||
|
min_relevance_score: 0.0,
|
||||||
});
|
});
|
||||||
|
|
||||||
process_channel_message(
|
process_channel_message(
|
||||||
|
|
@ -1546,6 +1562,7 @@ mod tests {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
auto_save_memory: false,
|
auto_save_memory: false,
|
||||||
max_tool_iterations: 10,
|
max_tool_iterations: 10,
|
||||||
|
min_relevance_score: 0.0,
|
||||||
});
|
});
|
||||||
|
|
||||||
process_channel_message(
|
process_channel_message(
|
||||||
|
|
@ -1642,6 +1659,7 @@ mod tests {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
auto_save_memory: false,
|
auto_save_memory: false,
|
||||||
max_tool_iterations: 10,
|
max_tool_iterations: 10,
|
||||||
|
min_relevance_score: 0.0,
|
||||||
});
|
});
|
||||||
|
|
||||||
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
let (tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(4);
|
||||||
|
|
@ -2008,7 +2026,7 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let context = build_memory_context(&mem, "age").await;
|
let context = build_memory_context(&mem, "age", 0.0).await;
|
||||||
assert!(context.contains("[Memory context]"));
|
assert!(context.contains("[Memory context]"));
|
||||||
assert!(context.contains("Age is 45"));
|
assert!(context.contains("Age is 45"));
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -753,6 +753,11 @@ pub struct MemoryConfig {
|
||||||
/// Weight for keyword BM25 in hybrid search (0.0–1.0)
|
/// Weight for keyword BM25 in hybrid search (0.0–1.0)
|
||||||
#[serde(default = "default_keyword_weight")]
|
#[serde(default = "default_keyword_weight")]
|
||||||
pub keyword_weight: f64,
|
pub keyword_weight: f64,
|
||||||
|
/// Minimum hybrid score (0.0–1.0) for a memory to be included in context.
|
||||||
|
/// Memories scoring below this threshold are dropped to prevent irrelevant
|
||||||
|
/// context from bleeding into conversations. Default: 0.4
|
||||||
|
#[serde(default = "default_min_relevance_score")]
|
||||||
|
pub min_relevance_score: f64,
|
||||||
/// Max embedding cache entries before LRU eviction
|
/// Max embedding cache entries before LRU eviction
|
||||||
#[serde(default = "default_cache_size")]
|
#[serde(default = "default_cache_size")]
|
||||||
pub embedding_cache_size: usize,
|
pub embedding_cache_size: usize,
|
||||||
|
|
@ -811,10 +816,13 @@ fn default_embedding_dims() -> usize {
|
||||||
1536
|
1536
|
||||||
}
|
}
|
||||||
fn default_vector_weight() -> f64 {
|
fn default_vector_weight() -> f64 {
|
||||||
0.7
|
0.4
|
||||||
}
|
}
|
||||||
fn default_keyword_weight() -> f64 {
|
fn default_keyword_weight() -> f64 {
|
||||||
0.3
|
0.6
|
||||||
|
}
|
||||||
|
fn default_min_relevance_score() -> f64 {
|
||||||
|
0.4
|
||||||
}
|
}
|
||||||
fn default_cache_size() -> usize {
|
fn default_cache_size() -> usize {
|
||||||
10_000
|
10_000
|
||||||
|
|
@ -843,6 +851,7 @@ impl Default for MemoryConfig {
|
||||||
embedding_dimensions: default_embedding_dims(),
|
embedding_dimensions: default_embedding_dims(),
|
||||||
vector_weight: default_vector_weight(),
|
vector_weight: default_vector_weight(),
|
||||||
keyword_weight: default_keyword_weight(),
|
keyword_weight: default_keyword_weight(),
|
||||||
|
min_relevance_score: default_min_relevance_score(),
|
||||||
embedding_cache_size: default_cache_size(),
|
embedding_cache_size: default_cache_size(),
|
||||||
chunk_max_tokens: default_chunk_size(),
|
chunk_max_tokens: default_chunk_size(),
|
||||||
response_cache_enabled: false,
|
response_cache_enabled: false,
|
||||||
|
|
|
||||||
|
|
@ -274,8 +274,9 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||||
embedding_provider: "none".to_string(),
|
embedding_provider: "none".to_string(),
|
||||||
embedding_model: "text-embedding-3-small".to_string(),
|
embedding_model: "text-embedding-3-small".to_string(),
|
||||||
embedding_dimensions: 1536,
|
embedding_dimensions: 1536,
|
||||||
vector_weight: 0.7,
|
vector_weight: 0.4,
|
||||||
keyword_weight: 0.3,
|
keyword_weight: 0.6,
|
||||||
|
min_relevance_score: 0.4,
|
||||||
embedding_cache_size: if profile.uses_sqlite_hygiene {
|
embedding_cache_size: if profile.uses_sqlite_hygiene {
|
||||||
10000
|
10000
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue