From 9639446fb9fd4afe2f6cd1fd0c07aedd46c91b6a Mon Sep 17 00:00:00 2001 From: Chummy <183474434+chumyin@users.noreply.github.com> Date: Mon, 16 Feb 2026 10:58:06 +0800 Subject: [PATCH] fix(memory): prevent autosave overwrite collisions Generate unique autosave memory keys across channels, agent loop, and gateway webhook/WhatsApp flows to avoid ON CONFLICT(key) overwrites in SQLite memory. Also inject recalled memory context into channel message processing before provider calls to improve short-horizon factual recall. Refs #221 --- src/agent/loop_.rs | 50 +++++++++++++-- src/channels/mod.rs | 127 ++++++++++++++++++++++++++++++++++++- src/gateway/mod.rs | 150 +++++++++++++++++++++++++++++++++++++++++++- 3 files changed, 319 insertions(+), 8 deletions(-) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 361396f..4783896 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -11,6 +11,7 @@ use std::fmt::Write; use std::io::Write as IoWrite; use std::sync::Arc; use std::time::Instant; +use uuid::Uuid; /// Maximum agentic tool-use iterations per user message to prevent runaway loops. const MAX_TOOL_ITERATIONS: usize = 10; @@ -19,6 +20,10 @@ const MAX_TOOL_ITERATIONS: usize = 10; /// When exceeded, the oldest messages are dropped (system prompt is always preserved). const MAX_HISTORY_MESSAGES: usize = 50; +fn autosave_memory_key(prefix: &str) -> String { + format!("{prefix}_{}", Uuid::new_v4()) +} + /// Trim conversation history to prevent unbounded growth. /// Preserves the system prompt (first message if role=system) and the most recent messages. fn trim_history(history: &mut Vec) { @@ -397,8 +402,9 @@ pub async fn run( if let Some(msg) = message { // Auto-save user message to memory if config.memory.auto_save { + let user_key = autosave_memory_key("user_msg"); let _ = mem - .store("user_msg", &msg, MemoryCategory::Conversation) + .store(&user_key, &msg, MemoryCategory::Conversation) .await; } @@ -429,8 +435,9 @@ pub async fn run( // Auto-save assistant response to daily log if config.memory.auto_save { let summary = truncate_with_ellipsis(&response, 100); + let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store("assistant_resp", &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily) .await; } } else { @@ -451,8 +458,9 @@ pub async fn run( while let Some(msg) = rx.recv().await { // Auto-save conversation turns if config.memory.auto_save { + let user_key = autosave_memory_key("user_msg"); let _ = mem - .store("user_msg", &msg.content, MemoryCategory::Conversation) + .store(&user_key, &msg.content, MemoryCategory::Conversation) .await; } @@ -489,8 +497,9 @@ pub async fn run( if config.memory.auto_save { let summary = truncate_with_ellipsis(&response, 100); + let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store("assistant_resp", &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily) .await; } } @@ -510,6 +519,8 @@ pub async fn run( #[cfg(test)] mod tests { use super::*; + use crate::memory::{Memory, MemoryCategory, SqliteMemory}; + use tempfile::TempDir; #[test] fn parse_tool_calls_extracts_single_call() { @@ -664,4 +675,35 @@ After text."#; trim_history(&mut history); assert_eq!(history.len(), 3); } + + #[test] + fn autosave_memory_key_has_prefix_and_uniqueness() { + let key1 = autosave_memory_key("user_msg"); + let key2 = autosave_memory_key("user_msg"); + + assert!(key1.starts_with("user_msg_")); + assert!(key2.starts_with("user_msg_")); + assert_ne!(key1, key2); + } + + #[tokio::test] + async fn autosave_memory_keys_preserve_multiple_turns() { + let tmp = TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + let key1 = autosave_memory_key("user_msg"); + let key2 = autosave_memory_key("user_msg"); + + mem.store(&key1, "I'm Paul", MemoryCategory::Conversation) + .await + .unwrap(); + mem.store(&key2, "I'm 45", MemoryCategory::Conversation) + .await + .unwrap(); + + assert_eq!(mem.count().await.unwrap(), 2); + + let recalled = mem.recall("45", 5).await.unwrap(); + assert!(recalled.iter().any(|entry| entry.content.contains("45"))); + } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 8e67179..8a9e3dc 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -26,6 +26,7 @@ use crate::memory::{self, Memory}; use crate::providers::{self, Provider}; use crate::util::truncate_with_ellipsis; use anyhow::Result; +use std::fmt::Write; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -36,6 +37,26 @@ const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2; const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60; const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 90; +fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { + format!("{}_{}_{}", msg.channel, msg.sender, msg.id) +} + +async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String { + let mut context = String::new(); + + if let Ok(entries) = mem.recall(user_msg, 5).await { + if !entries.is_empty() { + context.push_str("[Memory context]\n"); + for entry in &entries { + let _ = writeln!(context, "- {}: {}", entry.key, entry.content); + } + context.push('\n'); + } + } + + context +} + fn spawn_supervised_listener( ch: Arc, tx: tokio::sync::mpsc::Sender, @@ -681,17 +702,26 @@ pub async fn start_channels(config: Config) -> Result<()> { truncate_with_ellipsis(&msg.content, 80) ); + let memory_context = build_memory_context(mem.as_ref(), &msg.content).await; + // Auto-save to memory if config.memory.auto_save { + let autosave_key = conversation_memory_key(&msg); let _ = mem .store( - &format!("{}_{}", msg.channel, msg.sender), + &autosave_key, &msg.content, crate::memory::MemoryCategory::Conversation, ) .await; } + let enriched_message = if memory_context.is_empty() { + msg.content.clone() + } else { + format!("{memory_context}{}", msg.content) + }; + let target_channel = channels.iter().find(|ch| ch.name() == msg.channel); // Show typing indicator while processing @@ -707,7 +737,12 @@ pub async fn start_channels(config: Config) -> Result<()> { let llm_result = tokio::time::timeout( Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), - provider.chat_with_system(Some(&system_prompt), &msg.content, &model, temperature), + provider.chat_with_system( + Some(&system_prompt), + &enriched_message, + &model, + temperature, + ), ) .await; @@ -773,6 +808,7 @@ pub async fn start_channels(config: Config) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use crate::memory::{Memory, MemoryCategory, SqliteMemory}; use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; use tempfile::TempDir; @@ -998,6 +1034,93 @@ mod tests { assert!(prompt.contains(&format!("Working directory: `{}`", ws.path().display()))); } + #[test] + fn conversation_memory_key_uses_message_id() { + let msg = traits::ChannelMessage { + id: "msg_abc123".into(), + sender: "U123".into(), + content: "hello".into(), + channel: "slack".into(), + timestamp: 1, + }; + + assert_eq!(conversation_memory_key(&msg), "slack_U123_msg_abc123"); + } + + #[test] + fn conversation_memory_key_is_unique_per_message() { + let msg1 = traits::ChannelMessage { + id: "msg_1".into(), + sender: "U123".into(), + content: "first".into(), + channel: "slack".into(), + timestamp: 1, + }; + let msg2 = traits::ChannelMessage { + id: "msg_2".into(), + sender: "U123".into(), + content: "second".into(), + channel: "slack".into(), + timestamp: 2, + }; + + assert_ne!(conversation_memory_key(&msg1), conversation_memory_key(&msg2)); + } + + #[tokio::test] + async fn autosave_keys_preserve_multiple_conversation_facts() { + let tmp = TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + let msg1 = traits::ChannelMessage { + id: "msg_1".into(), + sender: "U123".into(), + content: "I'm Paul".into(), + channel: "slack".into(), + timestamp: 1, + }; + let msg2 = traits::ChannelMessage { + id: "msg_2".into(), + sender: "U123".into(), + content: "I'm 45".into(), + channel: "slack".into(), + timestamp: 2, + }; + + mem.store( + &conversation_memory_key(&msg1), + &msg1.content, + MemoryCategory::Conversation, + ) + .await + .unwrap(); + mem.store( + &conversation_memory_key(&msg2), + &msg2.content, + MemoryCategory::Conversation, + ) + .await + .unwrap(); + + assert_eq!(mem.count().await.unwrap(), 2); + + let recalled = mem.recall("45", 5).await.unwrap(); + assert!(recalled.iter().any(|entry| entry.content.contains("45"))); + } + + #[tokio::test] + async fn build_memory_context_includes_recalled_entries() { + let tmp = TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("age_fact", "Age is 45", MemoryCategory::Conversation) + .await + .unwrap(); + + let context = build_memory_context(&mem, "age").await; + assert!(context.contains("[Memory context]")); + assert!(context.contains("Age is 45")); + } + // ── AIEOS Identity Tests (Issue #168) ───────────────────────── #[test] diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 4f85437..79f9adb 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -28,6 +28,7 @@ use std::sync::{Arc, Mutex}; use std::time::{Duration, Instant}; use tower_http::limit::RequestBodyLimitLayer; use tower_http::timeout::TimeoutLayer; +use uuid::Uuid; /// Maximum request body size (64KB) — prevents memory exhaustion pub const MAX_BODY_SIZE: usize = 65_536; @@ -36,6 +37,14 @@ pub const REQUEST_TIMEOUT_SECS: u64 = 30; /// Sliding window used by gateway rate limiting. pub const RATE_LIMIT_WINDOW_SECS: u64 = 60; +fn webhook_memory_key() -> String { + format!("webhook_msg_{}", Uuid::new_v4()) +} + +fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String { + format!("whatsapp_{}_{}", msg.sender, msg.id) +} + #[derive(Debug)] struct SlidingWindowRateLimiter { limit_per_window: u32, @@ -475,9 +484,10 @@ async fn handle_webhook( let message = &webhook_body.message; if state.auto_save { + let key = webhook_memory_key(); let _ = state .mem - .store("webhook_msg", message, MemoryCategory::Conversation) + .store(&key, message, MemoryCategory::Conversation) .await; } @@ -627,10 +637,11 @@ async fn handle_whatsapp_message( // Auto-save to memory if state.auto_save { + let key = whatsapp_memory_key(msg); let _ = state .mem .store( - &format!("whatsapp_{}", msg.sender), + &key, &msg.content, MemoryCategory::Conversation, ) @@ -668,12 +679,14 @@ async fn handle_whatsapp_message( #[cfg(test)] mod tests { use super::*; + use crate::channels::traits::ChannelMessage; use crate::memory::{Memory, MemoryCategory, MemoryEntry}; use crate::providers::Provider; use async_trait::async_trait; use axum::http::HeaderValue; use axum::response::IntoResponse; use http_body_util::BodyExt; + use std::sync::Mutex; use std::sync::atomic::{AtomicUsize, Ordering}; #[test] @@ -730,6 +743,30 @@ mod tests { assert!(store.record_if_new("req-2")); } + #[test] + fn webhook_memory_key_is_unique() { + let key1 = webhook_memory_key(); + let key2 = webhook_memory_key(); + + assert!(key1.starts_with("webhook_msg_")); + assert!(key2.starts_with("webhook_msg_")); + assert_ne!(key1, key2); + } + + #[test] + fn whatsapp_memory_key_includes_sender_and_message_id() { + let msg = ChannelMessage { + id: "wamid-123".into(), + sender: "+1234567890".into(), + content: "hello".into(), + channel: "whatsapp".into(), + timestamp: 1, + }; + + let key = whatsapp_memory_key(&msg); + assert_eq!(key, "whatsapp_+1234567890_wamid-123"); + } + #[derive(Default)] struct MockMemory; @@ -795,6 +832,63 @@ mod tests { } } + #[derive(Default)] + struct TrackingMemory { + keys: Mutex>, + } + + #[async_trait] + impl Memory for TrackingMemory { + fn name(&self) -> &str { + "tracking" + } + + async fn store( + &self, + key: &str, + _content: &str, + _category: MemoryCategory, + ) -> anyhow::Result<()> { + self.keys + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .push(key.to_string()); + Ok(()) + } + + async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + let size = self + .keys + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .len(); + Ok(size) + } + + async fn health_check(&self) -> bool { + true + } + } + #[tokio::test] async fn webhook_idempotency_skips_duplicate_provider_calls() { let provider_impl = Arc::new(MockProvider::default()); @@ -841,6 +935,58 @@ mod tests { assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1); } + #[tokio::test] + async fn webhook_autosave_stores_distinct_keys_per_request() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + + let tracking_impl = Arc::new(TrackingMemory::default()); + let memory: Arc = tracking_impl.clone(); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: true, + webhook_secret: None, + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let headers = HeaderMap::new(); + + let body1 = Ok(Json(WebhookBody { + message: "hello one".into(), + })); + let first = handle_webhook(State(state.clone()), headers.clone(), body1) + .await + .into_response(); + assert_eq!(first.status(), StatusCode::OK); + + let body2 = Ok(Json(WebhookBody { + message: "hello two".into(), + })); + let second = handle_webhook(State(state), headers, body2) + .await + .into_response(); + assert_eq!(second.status(), StatusCode::OK); + + let keys = tracking_impl + .keys + .lock() + .unwrap_or_else(std::sync::PoisonError::into_inner) + .clone(); + assert_eq!(keys.len(), 2); + assert_ne!(keys[0], keys[1]); + assert!(keys[0].starts_with("webhook_msg_")); + assert!(keys[1].starts_with("webhook_msg_")); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); + } + // ══════════════════════════════════════════════════════════ // WhatsApp Signature Verification Tests (CWE-345 Prevention) // ══════════════════════════════════════════════════════════