From b2c5d611be5f571c7a9d1c95cc79524ed465ca2f Mon Sep 17 00:00:00 2001 From: Chummy Date: Fri, 20 Feb 2026 10:41:42 +0800 Subject: [PATCH] fix(channel): preserve memory enrichment for current call while storing raw user turn --- src/channels/mod.rs | 148 ++++++++++++++++++++++++++++++++++++++++++-- 1 file changed, 142 insertions(+), 6 deletions(-) diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 9a40bb8..fd848d7 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -835,11 +835,7 @@ async fn process_channel_message( let started_at = Instant::now(); // Preserve user turn before the LLM call so interrupted requests keep context. - append_sender_turn( - ctx.as_ref(), - &history_key, - ChatMessage::user(&msg.content), - ); + append_sender_turn(ctx.as_ref(), &history_key, ChatMessage::user(&msg.content)); // Build history from per-sender conversation cache. let prior_turns_raw = ctx @@ -849,7 +845,14 @@ async fn process_channel_message( .get(&history_key) .cloned() .unwrap_or_default(); - let prior_turns = normalize_cached_channel_turns(prior_turns_raw); + let mut prior_turns = normalize_cached_channel_turns(prior_turns_raw); + // Keep persisted history clean (raw user text), but inject memory context + // for the current provider call by enriching the newest user turn only. + if let Some(last_turn) = prior_turns.last_mut() { + if last_turn.role == "user" { + last_turn.content = enriched_message.clone(); + } + } let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel); let mut history = vec![ChatMessage::system(system_prompt)]; @@ -3218,6 +3221,66 @@ mod tests { } } + struct RecallMemory; + + #[async_trait::async_trait] + impl Memory for RecallMemory { + fn name(&self) -> &str { + "recall-memory" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: crate::memory::MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(vec![crate::memory::MemoryEntry { + id: "entry-1".to_string(), + key: "memory_key_1".to_string(), + content: "Age is 45".to_string(), + category: crate::memory::MemoryCategory::Conversation, + timestamp: "2026-02-20T00:00:00Z".to_string(), + session_id: None, + score: Some(0.9), + }]) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&crate::memory::MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(1) + } + + async fn health_check(&self) -> bool { + true + } + } + #[tokio::test] async fn message_dispatch_processes_messages_in_parallel() { let channel_impl = Arc::new(RecordingChannel::default()); @@ -4003,6 +4066,79 @@ mod tests { assert!(calls[1][3].1.contains("follow up")); } + #[tokio::test] + async fn process_channel_message_enriches_current_turn_without_persisting_context() { + let channel_impl = Arc::new(RecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let provider_impl = Arc::new(HistoryCaptureProvider::default()); + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: provider_impl.clone(), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(RecallMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + }); + + process_channel_message( + runtime_ctx.clone(), + traits::ChannelMessage { + id: "msg-ctx-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-ctx".to_string(), + content: "hello".to_string(), + channel: "test-channel".to_string(), + timestamp: 1, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + + let calls = provider_impl + .calls + .lock() + .unwrap_or_else(|e| e.into_inner()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].len(), 2); + assert_eq!(calls[0][1].0, "user"); + assert!(calls[0][1].1.contains("[Memory context]")); + assert!(calls[0][1].1.contains("Age is 45")); + assert!(calls[0][1].1.contains("hello")); + + let histories = runtime_ctx + .conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()); + let turns = histories + .get("test-channel_alice") + .expect("history should be stored for sender"); + assert_eq!(turns[0].role, "user"); + assert_eq!(turns[0].content, "hello"); + assert!(!turns[0].content.contains("[Memory context]")); + } + #[tokio::test] async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() { let channel_impl = Arc::new(TelegramRecordingChannel::default());