diff --git a/src/channels/discord.rs b/src/channels/discord.rs index 3ae69ba..bcb447d 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; use parking_lot::Mutex; use serde_json::json; +use std::collections::HashMap; use tokio_tungstenite::tungstenite::Message; use uuid::Uuid; @@ -13,7 +14,7 @@ pub struct DiscordChannel { allowed_users: Vec, listen_to_bots: bool, mention_only: bool, - typing_handle: Mutex>>, + typing_handles: Mutex>>, } impl DiscordChannel { @@ -30,7 +31,7 @@ impl DiscordChannel { allowed_users, listen_to_bots, mention_only, - typing_handle: Mutex::new(None), + typing_handles: Mutex::new(HashMap::new()), } } @@ -457,15 +458,15 @@ impl Channel for DiscordChannel { } }); - let mut guard = self.typing_handle.lock(); - *guard = Some(handle); + let mut guard = self.typing_handles.lock(); + guard.insert(recipient.to_string(), handle); Ok(()) } - async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { - let mut guard = self.typing_handle.lock(); - if let Some(handle) = guard.take() { + async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> { + let mut guard = self.typing_handles.lock(); + if let Some(handle) = guard.remove(recipient) { handle.abort(); } Ok(()) @@ -754,18 +755,18 @@ mod tests { } #[test] - fn typing_handle_starts_as_none() { + fn typing_handles_start_empty() { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); - let guard = ch.typing_handle.lock(); - assert!(guard.is_none()); + let guard = ch.typing_handles.lock(); + assert!(guard.is_empty()); } #[tokio::test] async fn start_typing_sets_handle() { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("123456").await; - let guard = ch.typing_handle.lock(); - assert!(guard.is_some()); + let guard = ch.typing_handles.lock(); + assert!(guard.contains_key("123456")); } #[tokio::test] @@ -773,8 +774,8 @@ mod tests { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("123456").await; let _ = ch.stop_typing("123456").await; - let guard = ch.typing_handle.lock(); - assert!(guard.is_none()); + let guard = ch.typing_handles.lock(); + assert!(!guard.contains_key("123456")); } #[tokio::test] @@ -785,12 +786,21 @@ mod tests { } #[tokio::test] - async fn start_typing_replaces_existing_task() { + async fn concurrent_typing_handles_are_independent() { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("111").await; let _ = ch.start_typing("222").await; - let guard = ch.typing_handle.lock(); - assert!(guard.is_some()); + { + let guard = ch.typing_handles.lock(); + assert_eq!(guard.len(), 2); + assert!(guard.contains_key("111")); + assert!(guard.contains_key("222")); + } + // Stopping one does not affect the other + let _ = ch.stop_typing("111").await; + let guard = ch.typing_handles.lock(); + assert_eq!(guard.len(), 1); + assert!(guard.contains_key("222")); } // ── Message ID edge cases ─────────────────────────────────────