diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 1c33c49..a995a72 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1263,7 +1263,11 @@ I will now call the tool with this payload: let (text, calls) = parse_tool_calls(response); assert!(text.contains("Sure, creating the file now.")); - assert_eq!(calls.len(), 0, "Raw JSON without wrappers should not be parsed"); + assert_eq!( + calls.len(), + 0, + "Raw JSON without wrappers should not be parsed" + ); } #[test] diff --git a/src/channels/discord.rs b/src/channels/discord.rs index c685e96..6b3bae3 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -343,11 +343,16 @@ impl Channel for DiscordChannel { continue; } + let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or(""); let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string(); let channel_msg = ChannelMessage { - id: Uuid::new_v4().to_string(), - sender: channel_id, + id: if message_id.is_empty() { + Uuid::new_v4().to_string() + } else { + format!("discord_{message_id}") + }, + sender: author_id.to_string(), content: content.to_string(), channel: "discord".to_string(), timestamp: std::time::SystemTime::now() @@ -695,4 +700,55 @@ mod tests { let guard = ch.typing_handle.lock().unwrap(); assert!(guard.is_some()); } + + // ── Message ID edge cases ───────────────────────────────────── + + #[test] + fn discord_message_id_format_includes_discord_prefix() { + // Verify that message IDs follow the format: discord_{message_id} + let message_id = "123456789012345678"; + let expected_id = format!("discord_{message_id}"); + assert_eq!(expected_id, "discord_123456789012345678"); + } + + #[test] + fn discord_message_id_is_deterministic() { + // Same message_id = same ID (prevents duplicates after restart) + let message_id = "123456789012345678"; + let id1 = format!("discord_{message_id}"); + let id2 = format!("discord_{message_id}"); + assert_eq!(id1, id2); + } + + #[test] + fn discord_message_id_different_message_different_id() { + // Different message IDs produce different IDs + let id1 = format!("discord_123456789012345678"); + let id2 = format!("discord_987654321098765432"); + assert_ne!(id1, id2); + } + + #[test] + fn discord_message_id_uses_snowflake_id() { + // Discord snowflake IDs are numeric strings + let message_id = "123456789012345678"; // Typical snowflake format + let id = format!("discord_{message_id}"); + assert!(id.starts_with("discord_")); + // Snowflake IDs are numeric + assert!(message_id.chars().all(|c| c.is_ascii_digit())); + } + + #[test] + fn discord_message_id_fallback_to_uuid_on_empty() { + // Edge case: empty message_id falls back to UUID + let message_id = ""; + let id = if message_id.is_empty() { + format!("discord_{}", uuid::Uuid::new_v4()) + } else { + format!("discord_{message_id}") + }; + assert!(id.starts_with("discord_")); + // Should have UUID dashes + assert!(id.contains('-')); + } } diff --git a/src/channels/slack.rs b/src/channels/slack.rs index 5a18cc3..4485af6 100644 --- a/src/channels/slack.rs +++ b/src/channels/slack.rs @@ -160,8 +160,8 @@ impl Channel for SlackChannel { last_ts = ts.to_string(); let channel_msg = ChannelMessage { - id: Uuid::new_v4().to_string(), - sender: channel_id.clone(), + id: format!("slack_{channel_id}_{ts}"), + sender: user.to_string(), content: text.to_string(), channel: "slack".to_string(), timestamp: std::time::SystemTime::now() @@ -252,4 +252,53 @@ mod tests { assert!(ch.is_user_allowed("U111")); assert!(ch.is_user_allowed("anyone")); } + + // ── Message ID edge cases ───────────────────────────────────── + + #[test] + fn slack_message_id_format_includes_channel_and_ts() { + // Verify that message IDs follow the format: slack_{channel_id}_{ts} + let ts = "1234567890.123456"; + let channel_id = "C12345"; + let expected_id = format!("slack_{channel_id}_{ts}"); + assert_eq!(expected_id, "slack_C12345_1234567890.123456"); + } + + #[test] + fn slack_message_id_is_deterministic() { + // Same channel_id + same ts = same ID (prevents duplicates after restart) + let ts = "1234567890.123456"; + let channel_id = "C12345"; + let id1 = format!("slack_{channel_id}_{ts}"); + let id2 = format!("slack_{channel_id}_{ts}"); + assert_eq!(id1, id2); + } + + #[test] + fn slack_message_id_different_ts_different_id() { + // Different timestamps produce different IDs + let channel_id = "C12345"; + let id1 = format!("slack_{channel_id}_1234567890.123456"); + let id2 = format!("slack_{channel_id}_1234567890.123457"); + assert_ne!(id1, id2); + } + + #[test] + fn slack_message_id_different_channel_different_id() { + // Different channels produce different IDs even with same ts + let ts = "1234567890.123456"; + let id1 = format!("slack_C12345_{ts}"); + let id2 = format!("slack_C67890_{ts}"); + assert_ne!(id1, id2); + } + + #[test] + fn slack_message_id_no_uuid_randomness() { + // Verify format doesn't contain random UUID components + let ts = "1234567890.123456"; + let channel_id = "C12345"; + let id = format!("slack_{channel_id}_{ts}"); + assert!(!id.contains('-')); // No UUID dashes + assert!(id.starts_with("slack_")); + } } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index 94ff767..117f42e 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -579,6 +579,11 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch continue; }; + let message_id = message + .get("message_id") + .and_then(|v| v.as_i64()) + .unwrap_or(0); + // Send "typing" indicator immediately when we receive a message let typing_body = serde_json::json!({ "chat_id": &chat_id, @@ -592,8 +597,8 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch .await; // Ignore errors for typing indicator let msg = ChannelMessage { - id: Uuid::new_v4().to_string(), - sender: chat_id, + id: format!("telegram_{chat_id}_{message_id}"), + sender: username.to_string(), content: text.to_string(), channel: "telegram".to_string(), timestamp: std::time::SystemTime::now() @@ -1033,4 +1038,62 @@ mod tests { // Should not panic assert!(result.is_err()); } + + // ── Message ID edge cases ───────────────────────────────────── + + #[test] + fn telegram_message_id_format_includes_chat_and_message_id() { + // Verify that message IDs follow the format: telegram_{chat_id}_{message_id} + let chat_id = "123456"; + let message_id = 789; + let expected_id = format!("telegram_{chat_id}_{message_id}"); + assert_eq!(expected_id, "telegram_123456_789"); + } + + #[test] + fn telegram_message_id_is_deterministic() { + // Same chat_id + same message_id = same ID (prevents duplicates after restart) + let chat_id = "123456"; + let message_id = 789; + let id1 = format!("telegram_{chat_id}_{message_id}"); + let id2 = format!("telegram_{chat_id}_{message_id}"); + assert_eq!(id1, id2); + } + + #[test] + fn telegram_message_id_different_message_different_id() { + // Different message IDs produce different IDs + let chat_id = "123456"; + let id1 = format!("telegram_{chat_id}_789"); + let id2 = format!("telegram_{chat_id}_790"); + assert_ne!(id1, id2); + } + + #[test] + fn telegram_message_id_different_chat_different_id() { + // Different chats produce different IDs even with same message_id + let message_id = 789; + let id1 = format!("telegram_123456_{message_id}"); + let id2 = format!("telegram_789012_{message_id}"); + assert_ne!(id1, id2); + } + + #[test] + fn telegram_message_id_no_uuid_randomness() { + // Verify format doesn't contain random UUID components + let chat_id = "123456"; + let message_id = 789; + let id = format!("telegram_{chat_id}_{message_id}"); + assert!(!id.contains('-')); // No UUID dashes + assert!(id.starts_with("telegram_")); + } + + #[test] + fn telegram_message_id_handles_zero_message_id() { + // Edge case: message_id can be 0 (fallback/missing case) + let chat_id = "123456"; + let message_id = 0; + let id = format!("telegram_{chat_id}_{message_id}"); + assert_eq!(id, "telegram_123456_0"); + } } diff --git a/src/memory/mod.rs b/src/memory/mod.rs index f012c27..45b7451 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -76,7 +76,10 @@ pub fn create_memory( // Auto-hydration: if brain.db is missing but MEMORY_SNAPSHOT.md exists, // restore the "soul" from the snapshot before creating the backend. if config.auto_hydrate - && matches!(classify_memory_backend(&config.backend), MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid) + && matches!( + classify_memory_backend(&config.backend), + MemoryBackendKind::Sqlite | MemoryBackendKind::Lucid + ) && snapshot::should_hydrate(workspace_dir) { tracing::info!("🧬 Cold boot detected — hydrating from MEMORY_SNAPSHOT.md"); @@ -143,10 +146,7 @@ pub fn create_memory_for_migration( } /// Factory: create an optional response cache from config. -pub fn create_response_cache( - config: &MemoryConfig, - workspace_dir: &Path, -) -> Option { +pub fn create_response_cache(config: &MemoryConfig, workspace_dir: &Path) -> Option { if !config.response_cache_enabled { return None; } diff --git a/src/memory/response_cache.rs b/src/memory/response_cache.rs index 843b971..3135b2b 100644 --- a/src/memory/response_cache.rs +++ b/src/memory/response_cache.rs @@ -90,9 +90,7 @@ impl ResponseCache { WHERE prompt_hash = ?1 AND created_at > ?2", )?; - let result: Option = stmt - .query_row(params![key, cutoff], |row| row.get(0)) - .ok(); + let result: Option = stmt.query_row(params![key, cutoff], |row| row.get(0)).ok(); if result.is_some() { // Bump hit count and accessed_at @@ -109,13 +107,7 @@ impl ResponseCache { } /// Store a response in the cache. - pub fn put( - &self, - key: &str, - model: &str, - response: &str, - token_count: u32, - ) -> Result<()> { + pub fn put(&self, key: &str, model: &str, response: &str, token_count: u32) -> Result<()> { let conn = self .conn .lock() @@ -162,19 +154,17 @@ impl ResponseCache { let count: i64 = conn.query_row("SELECT COUNT(*) FROM response_cache", [], |row| row.get(0))?; - let hits: i64 = conn - .query_row( - "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache", - [], - |row| row.get(0), - )?; + let hits: i64 = conn.query_row( + "SELECT COALESCE(SUM(hit_count), 0) FROM response_cache", + [], + |row| row.get(0), + )?; - let tokens_saved: i64 = conn - .query_row( - "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache", - [], - |row| row.get(0), - )?; + let tokens_saved: i64 = conn.query_row( + "SELECT COALESCE(SUM(token_count * hit_count), 0) FROM response_cache", + [], + |row| row.get(0), + )?; #[allow(clippy::cast_sign_loss)] Ok((count as usize, hits as u64, tokens_saved as u64)) @@ -363,7 +353,9 @@ mod tests { let (_tmp, cache) = temp_cache(60); let key = ResponseCache::cache_key("gpt-4", None, "日本語のテスト 🦀"); - cache.put(&key, "gpt-4", "はい、Rustは素晴らしい", 30).unwrap(); + cache + .put(&key, "gpt-4", "はい、Rustは素晴らしい", 30) + .unwrap(); let result = cache.get(&key).unwrap(); assert_eq!(result.as_deref(), Some("はい、Rustは素晴らしい")); diff --git a/src/memory/snapshot.rs b/src/memory/snapshot.rs index edd0748..dcfbe1a 100644 --- a/src/memory/snapshot.rs +++ b/src/memory/snapshot.rs @@ -64,7 +64,10 @@ pub fn export_snapshot(workspace_dir: &Path) -> Result { let now = Local::now().format("%Y-%m-%d %H:%M:%S").to_string(); output.push_str(&format!("**Last exported:** {now}\n\n")); - output.push_str(&format!("**Total core memories:** {}\n\n---\n\n", rows.len())); + output.push_str(&format!( + "**Total core memories:** {}\n\n---\n\n", + rows.len() + )); for (key, content, _category, created_at, updated_at) in &rows { output.push_str(&format!("### 🔑 `{key}`\n\n")); diff --git a/src/security/pairing.rs b/src/security/pairing.rs index 0c0ff6e..806431b 100644 --- a/src/security/pairing.rs +++ b/src/security/pairing.rs @@ -8,8 +8,8 @@ // Already-paired tokens are persisted in config so restarts don't require // re-pairing. -use sha2::{Digest, Sha256}; use parking_lot::Mutex; +use sha2::{Digest, Sha256}; use std::collections::HashSet; use std::time::Instant; @@ -70,9 +70,7 @@ impl PairingGuard { /// The one-time pairing code (only set when no tokens exist yet). pub fn pairing_code(&self) -> Option { - self.pairing_code - .lock() - .clone() + self.pairing_code.lock().clone() } /// Whether pairing is required at all. @@ -85,10 +83,7 @@ impl PairingGuard { pub fn try_pair(&self, code: &str) -> Result, u64> { // Check brute force lockout { - let attempts = self - .failed_attempts - .lock() - ; + let attempts = self.failed_attempts.lock(); if let (count, Some(locked_at)) = &*attempts { if *count >= MAX_PAIR_ATTEMPTS { let elapsed = locked_at.elapsed().as_secs(); @@ -100,25 +95,16 @@ impl PairingGuard { } { - let mut pairing_code = self - .pairing_code - .lock() - ; + let mut pairing_code = self.pairing_code.lock(); if let Some(ref expected) = *pairing_code { if constant_time_eq(code.trim(), expected.trim()) { // Reset failed attempts on success { - let mut attempts = self - .failed_attempts - .lock() - ; + let mut attempts = self.failed_attempts.lock(); *attempts = (0, None); } let token = generate_token(); - let mut tokens = self - .paired_tokens - .lock() - ; + let mut tokens = self.paired_tokens.lock(); tokens.insert(hash_token(&token)); // Consume the pairing code so it cannot be reused @@ -131,10 +117,7 @@ impl PairingGuard { // Increment failed attempts { - let mut attempts = self - .failed_attempts - .lock() - ; + let mut attempts = self.failed_attempts.lock(); attempts.0 += 1; if attempts.0 >= MAX_PAIR_ATTEMPTS { attempts.1 = Some(Instant::now()); @@ -150,28 +133,19 @@ impl PairingGuard { return true; } let hashed = hash_token(token); - let tokens = self - .paired_tokens - .lock() - ; + let tokens = self.paired_tokens.lock(); tokens.contains(&hashed) } /// Returns true if the gateway is already paired (has at least one token). pub fn is_paired(&self) -> bool { - let tokens = self - .paired_tokens - .lock() - ; + let tokens = self.paired_tokens.lock(); !tokens.is_empty() } /// Get all paired token hashes (for persisting to config). pub fn tokens(&self) -> Vec { - let tokens = self - .paired_tokens - .lock() - ; + let tokens = self.paired_tokens.lock(); tokens.iter().cloned().collect() } } diff --git a/src/security/policy.rs b/src/security/policy.rs index 6a6bf8b..66591c2 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -1,5 +1,5 @@ -use serde::{Deserialize, Serialize}; use parking_lot::Mutex; +use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use std::time::Instant; @@ -40,9 +40,7 @@ impl ActionTracker { /// Record an action and return the current count within the window. pub fn record(&self) -> usize { - let mut actions = self - .actions - .lock(); + let mut actions = self.actions.lock(); let cutoff = Instant::now() .checked_sub(std::time::Duration::from_secs(3600)) .unwrap_or_else(Instant::now); @@ -53,9 +51,7 @@ impl ActionTracker { /// Count of actions in the current window without recording. pub fn count(&self) -> usize { - let mut actions = self - .actions - .lock(); + let mut actions = self.actions.lock(); let cutoff = Instant::now() .checked_sub(std::time::Duration::from_secs(3600)) .unwrap_or_else(Instant::now); @@ -66,9 +62,7 @@ impl ActionTracker { impl Clone for ActionTracker { fn clone(&self) -> Self { - let actions = self - .actions - .lock(); + let actions = self.actions.lock(); Self { actions: Mutex::new(actions.clone()), }