fix(channel): preserve memory enrichment for current call while storing raw user turn
This commit is contained in:
parent
6cbdef8c16
commit
b2c5d611be
1 changed files with 142 additions and 6 deletions
|
|
@ -835,11 +835,7 @@ async fn process_channel_message(
|
|||
let started_at = Instant::now();
|
||||
|
||||
// Preserve user turn before the LLM call so interrupted requests keep context.
|
||||
append_sender_turn(
|
||||
ctx.as_ref(),
|
||||
&history_key,
|
||||
ChatMessage::user(&msg.content),
|
||||
);
|
||||
append_sender_turn(ctx.as_ref(), &history_key, ChatMessage::user(&msg.content));
|
||||
|
||||
// Build history from per-sender conversation cache.
|
||||
let prior_turns_raw = ctx
|
||||
|
|
@ -849,7 +845,14 @@ async fn process_channel_message(
|
|||
.get(&history_key)
|
||||
.cloned()
|
||||
.unwrap_or_default();
|
||||
let prior_turns = normalize_cached_channel_turns(prior_turns_raw);
|
||||
let mut prior_turns = normalize_cached_channel_turns(prior_turns_raw);
|
||||
// Keep persisted history clean (raw user text), but inject memory context
|
||||
// for the current provider call by enriching the newest user turn only.
|
||||
if let Some(last_turn) = prior_turns.last_mut() {
|
||||
if last_turn.role == "user" {
|
||||
last_turn.content = enriched_message.clone();
|
||||
}
|
||||
}
|
||||
|
||||
let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel);
|
||||
let mut history = vec![ChatMessage::system(system_prompt)];
|
||||
|
|
@ -3218,6 +3221,66 @@ mod tests {
|
|||
}
|
||||
}
|
||||
|
||||
struct RecallMemory;
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl Memory for RecallMemory {
|
||||
fn name(&self) -> &str {
|
||||
"recall-memory"
|
||||
}
|
||||
|
||||
async fn store(
|
||||
&self,
|
||||
_key: &str,
|
||||
_content: &str,
|
||||
_category: crate::memory::MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(vec![crate::memory::MemoryEntry {
|
||||
id: "entry-1".to_string(),
|
||||
key: "memory_key_1".to_string(),
|
||||
content: "Age is 45".to_string(),
|
||||
category: crate::memory::MemoryCategory::Conversation,
|
||||
timestamp: "2026-02-20T00:00:00Z".to_string(),
|
||||
session_id: None,
|
||||
score: Some(0.9),
|
||||
}])
|
||||
}
|
||||
|
||||
async fn get(&self, _key: &str) -> anyhow::Result<Option<crate::memory::MemoryEntry>> {
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&crate::memory::MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
|
||||
Ok(false)
|
||||
}
|
||||
|
||||
async fn count(&self) -> anyhow::Result<usize> {
|
||||
Ok(1)
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn message_dispatch_processes_messages_in_parallel() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
|
|
@ -4003,6 +4066,79 @@ mod tests {
|
|||
assert!(calls[1][3].1.contains("follow up"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_enriches_current_turn_without_persisting_context() {
|
||||
let channel_impl = Arc::new(RecordingChannel::default());
|
||||
let channel: Arc<dyn Channel> = channel_impl.clone();
|
||||
|
||||
let mut channels_by_name = HashMap::new();
|
||||
channels_by_name.insert(channel.name().to_string(), channel);
|
||||
|
||||
let provider_impl = Arc::new(HistoryCaptureProvider::default());
|
||||
let runtime_ctx = Arc::new(ChannelRuntimeContext {
|
||||
channels_by_name: Arc::new(channels_by_name),
|
||||
provider: provider_impl.clone(),
|
||||
default_provider: Arc::new("test-provider".to_string()),
|
||||
memory: Arc::new(RecallMemory),
|
||||
tools_registry: Arc::new(vec![]),
|
||||
observer: Arc::new(NoopObserver),
|
||||
system_prompt: Arc::new("test-system-prompt".to_string()),
|
||||
model: Arc::new("test-model".to_string()),
|
||||
temperature: 0.0,
|
||||
auto_save_memory: false,
|
||||
max_tool_iterations: 5,
|
||||
min_relevance_score: 0.0,
|
||||
conversation_histories: Arc::new(Mutex::new(HashMap::new())),
|
||||
provider_cache: Arc::new(Mutex::new(HashMap::new())),
|
||||
route_overrides: Arc::new(Mutex::new(HashMap::new())),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
reliability: Arc::new(crate::config::ReliabilityConfig::default()),
|
||||
provider_runtime_options: providers::ProviderRuntimeOptions::default(),
|
||||
workspace_dir: Arc::new(std::env::temp_dir()),
|
||||
message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS,
|
||||
interrupt_on_new_message: false,
|
||||
multimodal: crate::config::MultimodalConfig::default(),
|
||||
});
|
||||
|
||||
process_channel_message(
|
||||
runtime_ctx.clone(),
|
||||
traits::ChannelMessage {
|
||||
id: "msg-ctx-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-ctx".to_string(),
|
||||
content: "hello".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
thread_ts: None,
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
let calls = provider_impl
|
||||
.calls
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
assert_eq!(calls.len(), 1);
|
||||
assert_eq!(calls[0].len(), 2);
|
||||
assert_eq!(calls[0][1].0, "user");
|
||||
assert!(calls[0][1].1.contains("[Memory context]"));
|
||||
assert!(calls[0][1].1.contains("Age is 45"));
|
||||
assert!(calls[0][1].1.contains("hello"));
|
||||
|
||||
let histories = runtime_ctx
|
||||
.conversation_histories
|
||||
.lock()
|
||||
.unwrap_or_else(|e| e.into_inner());
|
||||
let turns = histories
|
||||
.get("test-channel_alice")
|
||||
.expect("history should be stored for sender");
|
||||
assert_eq!(turns[0].role, "user");
|
||||
assert_eq!(turns[0].content, "hello");
|
||||
assert!(!turns[0].content.contains("[Memory context]"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() {
|
||||
let channel_impl = Arc::new(TelegramRecordingChannel::default());
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue