fix(channel): preserve memory enrichment for current call while storing raw user turn

This commit is contained in:
Chummy 2026-02-20 10:41:42 +08:00
parent 6cbdef8c16
commit b2c5d611be

View file

@ -835,11 +835,7 @@ async fn process_channel_message(
let started_at = Instant::now(); let started_at = Instant::now();
// Preserve user turn before the LLM call so interrupted requests keep context. // Preserve user turn before the LLM call so interrupted requests keep context.
append_sender_turn( append_sender_turn(ctx.as_ref(), &history_key, ChatMessage::user(&msg.content));
ctx.as_ref(),
&history_key,
ChatMessage::user(&msg.content),
);
// Build history from per-sender conversation cache. // Build history from per-sender conversation cache.
let prior_turns_raw = ctx let prior_turns_raw = ctx
@ -849,7 +845,14 @@ async fn process_channel_message(
.get(&history_key) .get(&history_key)
.cloned() .cloned()
.unwrap_or_default(); .unwrap_or_default();
let prior_turns = normalize_cached_channel_turns(prior_turns_raw); let mut prior_turns = normalize_cached_channel_turns(prior_turns_raw);
// Keep persisted history clean (raw user text), but inject memory context
// for the current provider call by enriching the newest user turn only.
if let Some(last_turn) = prior_turns.last_mut() {
if last_turn.role == "user" {
last_turn.content = enriched_message.clone();
}
}
let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel); let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel);
let mut history = vec![ChatMessage::system(system_prompt)]; let mut history = vec![ChatMessage::system(system_prompt)];
@ -3218,6 +3221,66 @@ mod tests {
} }
} }
struct RecallMemory;
#[async_trait::async_trait]
impl Memory for RecallMemory {
fn name(&self) -> &str {
"recall-memory"
}
async fn store(
&self,
_key: &str,
_content: &str,
_category: crate::memory::MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
async fn recall(
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(vec![crate::memory::MemoryEntry {
id: "entry-1".to_string(),
key: "memory_key_1".to_string(),
content: "Age is 45".to_string(),
category: crate::memory::MemoryCategory::Conversation,
timestamp: "2026-02-20T00:00:00Z".to_string(),
session_id: None,
score: Some(0.9),
}])
}
async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
Ok(None)
}
async fn list(
&self,
_category: Option<&crate::memory::MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(Vec::new())
}
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
Ok(false)
}
async fn count(&self) -> anyhow::Result<usize> {
Ok(1)
}
async fn health_check(&self) -> bool {
true
}
}
#[tokio::test] #[tokio::test]
async fn message_dispatch_processes_messages_in_parallel() { async fn message_dispatch_processes_messages_in_parallel() {
let channel_impl = Arc::new(RecordingChannel::default()); let channel_impl = Arc::new(RecordingChannel::default());
@ -4003,6 +4066,79 @@ mod tests {
assert!(calls[1][3].1.contains("follow up")); assert!(calls[1][3].1.contains("follow up"));
} }
#[tokio::test]
async fn process_channel_message_enriches_current_turn_without_persisting_context() {
let channel_impl = Arc::new(RecordingChannel::default());
let channel: Arc<dyn Channel> = channel_impl.clone();
let mut channels_by_name = HashMap::new();
channels_by_name.insert(channel.name().to_string(), channel);
let provider_impl = Arc::new(HistoryCaptureProvider::default());
let runtime_ctx = Arc::new(ChannelRuntimeContext {
channels_by_name: Arc::new(channels_by_name),
provider: provider_impl.clone(),
default_provider: Arc::new("test-provider".to_string()),
memory: Arc::new(RecallMemory),
tools_registry: Arc::new(vec![]),
observer: Arc::new(NoopObserver),
system_prompt: Arc::new("test-system-prompt".to_string()),
model: Arc::new("test-model".to_string()),
temperature: 0.0,
auto_save_memory: false,
max_tool_iterations: 5,
min_relevance_score: 0.0,
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
provider_cache: Arc::new(Mutex::new(HashMap::new())),
route_overrides: Arc::new(Mutex::new(HashMap::new())),
api_key: None,
api_url: None,
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
workspace_dir: Arc::new(std::env::temp_dir()),
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
interrupt_on_new_message: false,
multimodal: crate::config::MultimodalConfig::default(),
});
process_channel_message(
runtime_ctx.clone(),
traits::ChannelMessage {
id: "msg-ctx-1".to_string(),
sender: "alice".to_string(),
reply_target: "chat-ctx".to_string(),
content: "hello".to_string(),
channel: "test-channel".to_string(),
timestamp: 1,
thread_ts: None,
},
CancellationToken::new(),
)
.await;
let calls = provider_impl
.calls
.lock()
.unwrap_or_else(|e| e.into_inner());
assert_eq!(calls.len(), 1);
assert_eq!(calls[0].len(), 2);
assert_eq!(calls[0][1].0, "user");
assert!(calls[0][1].1.contains("[Memory context]"));
assert!(calls[0][1].1.contains("Age is 45"));
assert!(calls[0][1].1.contains("hello"));
let histories = runtime_ctx
.conversation_histories
.lock()
.unwrap_or_else(|e| e.into_inner());
let turns = histories
.get("test-channel_alice")
.expect("history should be stored for sender");
assert_eq!(turns[0].role, "user");
assert_eq!(turns[0].content, "hello");
assert!(!turns[0].content.contains("[Memory context]"));
}
#[tokio::test] #[tokio::test]
async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() { async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() {
let channel_impl = Arc::new(TelegramRecordingChannel::default()); let channel_impl = Arc::new(TelegramRecordingChannel::default());