fix(channel): preserve memory enrichment for current call while storing raw user turn
This commit is contained in:
parent
6cbdef8c16
commit
b2c5d611be
1 changed files with 142 additions and 6 deletions
|
|
@ -835,11 +835,7 @@ async fn process_channel_message(
|
||||||
let started_at = Instant::now();
|
let started_at = Instant::now();
|
||||||
|
|
||||||
// Preserve user turn before the LLM call so interrupted requests keep context.
|
// Preserve user turn before the LLM call so interrupted requests keep context.
|
||||||
append_sender_turn(
|
append_sender_turn(ctx.as_ref(), &history_key, ChatMessage::user(&msg.content));
|
||||||
ctx.as_ref(),
|
|
||||||
&history_key,
|
|
||||||
ChatMessage::user(&msg.content),
|
|
||||||
);
|
|
||||||
|
|
||||||
// Build history from per-sender conversation cache.
|
// Build history from per-sender conversation cache.
|
||||||
let prior_turns_raw = ctx
|
let prior_turns_raw = ctx
|
||||||
|
|
@ -849,7 +845,14 @@ async fn process_channel_message(
|
||||||
.get(&history_key)
|
.get(&history_key)
|
||||||
.cloned()
|
.cloned()
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
let prior_turns = normalize_cached_channel_turns(prior_turns_raw);
|
let mut prior_turns = normalize_cached_channel_turns(prior_turns_raw);
|
||||||
|
// Keep persisted history clean (raw user text), but inject memory context
|
||||||
|
// for the current provider call by enriching the newest user turn only.
|
||||||
|
if let Some(last_turn) = prior_turns.last_mut() {
|
||||||
|
if last_turn.role == "user" {
|
||||||
|
last_turn.content = enriched_message.clone();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel);
|
let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel);
|
||||||
let mut history = vec![ChatMessage::system(system_prompt)];
|
let mut history = vec![ChatMessage::system(system_prompt)];
|
||||||
|
|
@ -3218,6 +3221,66 @@ mod tests {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct RecallMemory;
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Memory for RecallMemory {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"recall-memory"
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn store(
|
||||||
|
&self,
|
||||||
|
_key: &str,
|
||||||
|
_content: &str,
|
||||||
|
_category: crate::memory::MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||||
|
Ok(vec![crate::memory::MemoryEntry {
|
||||||
|
id: "entry-1".to_string(),
|
||||||
|
key: "memory_key_1".to_string(),
|
||||||
|
content: "Age is 45".to_string(),
|
||||||
|
category: crate::memory::MemoryCategory::Conversation,
|
||||||
|
timestamp: "2026-02-20T00:00:00Z".to_string(),
|
||||||
|
session_id: None,
|
||||||
|
score: Some(0.9),
|
||||||
|
}])
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn list(
|
||||||
|
&self,
|
||||||
|
_category: Option<&crate::memory::MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||||
|
Ok(Vec::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn count(&self) -> anyhow::Result<usize> {
|
||||||
|
Ok(1)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn message_dispatch_processes_messages_in_parallel() {
|
async fn message_dispatch_processes_messages_in_parallel() {
|
||||||
let channel_impl = Arc::new(RecordingChannel::default());
|
let channel_impl = Arc::new(RecordingChannel::default());
|
||||||
|
|
@ -4003,6 +4066,79 @@ mod tests {
|
||||||
assert!(calls[1][3].1.contains("follow up"));
|
assert!(calls[1][3].1.contains("follow up"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn process_channel_message_enriches_current_turn_without_persisting_context() {
|
||||||
|
let channel_impl = Arc::new(RecordingChannel::default());
|
||||||
|
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||||
|
|
||||||
|
let mut channels_by_name = HashMap::new();
|
||||||
|
channels_by_name.insert(channel.name().to_string(), channel);
|
||||||
|
|
||||||
|
let provider_impl = Arc::new(HistoryCaptureProvider::default());
|
||||||
|
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||||
|
channels_by_name: Arc::new(channels_by_name),
|
||||||
|
provider: provider_impl.clone(),
|
||||||
|
default_provider: Arc::new("test-provider".to_string()),
|
||||||
|
memory: Arc::new(RecallMemory),
|
||||||
|
tools_registry: Arc::new(vec![]),
|
||||||
|
observer: Arc::new(NoopObserver),
|
||||||
|
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||||
|
model: Arc::new("test-model".to_string()),
|
||||||
|
temperature: 0.0,
|
||||||
|
auto_save_memory: false,
|
||||||
|
max_tool_iterations: 5,
|
||||||
|
min_relevance_score: 0.0,
|
||||||
|
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||||
|
api_key: None,
|
||||||
|
api_url: None,
|
||||||
|
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||||
|
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||||
|
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||||
|
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||||
|
interrupt_on_new_message: false,
|
||||||
|
multimodal: crate::config::MultimodalConfig::default(),
|
||||||
|
});
|
||||||
|
|
||||||
|
process_channel_message(
|
||||||
|
runtime_ctx.clone(),
|
||||||
|
traits::ChannelMessage {
|
||||||
|
id: "msg-ctx-1".to_string(),
|
||||||
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "chat-ctx".to_string(),
|
||||||
|
content: "hello".to_string(),
|
||||||
|
channel: "test-channel".to_string(),
|
||||||
|
timestamp: 1,
|
||||||
|
thread_ts: None,
|
||||||
|
},
|
||||||
|
CancellationToken::new(),
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
let calls = provider_impl
|
||||||
|
.calls
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|e| e.into_inner());
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
assert_eq!(calls[0].len(), 2);
|
||||||
|
assert_eq!(calls[0][1].0, "user");
|
||||||
|
assert!(calls[0][1].1.contains("[Memory context]"));
|
||||||
|
assert!(calls[0][1].1.contains("Age is 45"));
|
||||||
|
assert!(calls[0][1].1.contains("hello"));
|
||||||
|
|
||||||
|
let histories = runtime_ctx
|
||||||
|
.conversation_histories
|
||||||
|
.lock()
|
||||||
|
.unwrap_or_else(|e| e.into_inner());
|
||||||
|
let turns = histories
|
||||||
|
.get("test-channel_alice")
|
||||||
|
.expect("history should be stored for sender");
|
||||||
|
assert_eq!(turns[0].role, "user");
|
||||||
|
assert_eq!(turns[0].content, "hello");
|
||||||
|
assert!(!turns[0].content.contains("[Memory context]"));
|
||||||
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() {
|
async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() {
|
||||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue