diff --git a/src/channels/discord.rs b/src/channels/discord.rs index 5473288..1babfd1 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -11,6 +11,7 @@ pub struct DiscordChannel { guild_id: Option, allowed_users: Vec, client: reqwest::Client, + typing_handle: std::sync::Mutex>>, } impl DiscordChannel { @@ -20,6 +21,7 @@ impl DiscordChannel { guild_id, allowed_users, client: reqwest::Client::new(), + typing_handle: std::sync::Mutex::new(None), } } @@ -357,6 +359,41 @@ impl Channel for DiscordChannel { .map(|r| r.status().is_success()) .unwrap_or(false) } + + async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> { + self.stop_typing(recipient).await?; + + let client = self.client.clone(); + let token = self.bot_token.clone(); + let channel_id = recipient.to_string(); + + let handle = tokio::spawn(async move { + let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing"); + loop { + let _ = client + .post(&url) + .header("Authorization", format!("Bot {token}")) + .send() + .await; + tokio::time::sleep(std::time::Duration::from_secs(8)).await; + } + }); + + if let Ok(mut guard) = self.typing_handle.lock() { + *guard = Some(handle); + } + + Ok(()) + } + + async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { + if let Ok(mut guard) = self.typing_handle.lock() { + if let Some(handle) = guard.take() { + handle.abort(); + } + } + Ok(()) + } } #[cfg(test)] @@ -581,4 +618,44 @@ mod tests { let reconstructed = chunks.concat(); assert_eq!(reconstructed, msg); } + + #[test] + fn typing_handle_starts_as_none() { + let ch = DiscordChannel::new("fake".into(), None, vec![]); + let guard = ch.typing_handle.lock().unwrap(); + assert!(guard.is_none()); + } + + #[tokio::test] + async fn start_typing_sets_handle() { + let ch = DiscordChannel::new("fake".into(), None, vec![]); + let _ = ch.start_typing("123456").await; + let guard = ch.typing_handle.lock().unwrap(); + assert!(guard.is_some()); + } + + #[tokio::test] + async fn stop_typing_clears_handle() { + let ch = DiscordChannel::new("fake".into(), None, vec![]); + let _ = ch.start_typing("123456").await; + let _ = ch.stop_typing("123456").await; + let guard = ch.typing_handle.lock().unwrap(); + assert!(guard.is_none()); + } + + #[tokio::test] + async fn stop_typing_is_idempotent() { + let ch = DiscordChannel::new("fake".into(), None, vec![]); + assert!(ch.stop_typing("123456").await.is_ok()); + assert!(ch.stop_typing("123456").await.is_ok()); + } + + #[tokio::test] + async fn start_typing_replaces_existing_task() { + let ch = DiscordChannel::new("fake".into(), None, vec![]); + let _ = ch.start_typing("111").await; + let _ = ch.start_typing("222").await; + let guard = ch.typing_handle.lock().unwrap(); + assert!(guard.is_some()); + } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 313398e..8e67179 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -692,6 +692,15 @@ pub async fn start_channels(config: Config) -> Result<()> { .await; } + let target_channel = channels.iter().find(|ch| ch.name() == msg.channel); + + // Show typing indicator while processing + if let Some(ch) = target_channel { + if let Err(e) = ch.start_typing(&msg.sender).await { + tracing::debug!("Failed to start typing on {}: {e}", ch.name()); + } + } + // Call the LLM with system prompt (identity + soul + tools) println!(" ⏳ Processing message..."); let started_at = Instant::now(); @@ -702,6 +711,13 @@ pub async fn start_channels(config: Config) -> Result<()> { ) .await; + // Stop typing before sending the response + if let Some(ch) = target_channel { + if let Err(e) = ch.stop_typing(&msg.sender).await { + tracing::debug!("Failed to stop typing on {}: {e}", ch.name()); + } + } + match llm_result { Ok(Ok(response)) => { println!( @@ -709,13 +725,9 @@ pub async fn start_channels(config: Config) -> Result<()> { started_at.elapsed().as_millis(), truncate_with_ellipsis(&response, 80) ); - // Find the channel that sent this message and reply - for ch in &channels { - if ch.name() == msg.channel { - if let Err(e) = ch.send(&response, &msg.sender).await { - eprintln!(" ❌ Failed to reply on {}: {e}", ch.name()); - } - break; + if let Some(ch) = target_channel { + if let Err(e) = ch.send(&response, &msg.sender).await { + eprintln!(" ❌ Failed to reply on {}: {e}", ch.name()); } } } @@ -724,11 +736,8 @@ pub async fn start_channels(config: Config) -> Result<()> { " ❌ LLM error after {}ms: {e}", started_at.elapsed().as_millis() ); - for ch in &channels { - if ch.name() == msg.channel { - let _ = ch.send(&format!("⚠️ Error: {e}"), &msg.sender).await; - break; - } + if let Some(ch) = target_channel { + let _ = ch.send(&format!("⚠️ Error: {e}"), &msg.sender).await; } } Err(_) => { @@ -741,16 +750,13 @@ pub async fn start_channels(config: Config) -> Result<()> { timeout_msg, started_at.elapsed().as_millis() ); - for ch in &channels { - if ch.name() == msg.channel { - let _ = ch - .send( - "⚠️ Request timed out while waiting for the model. Please try again.", - &msg.sender, - ) - .await; - break; - } + if let Some(ch) = target_channel { + let _ = ch + .send( + "⚠️ Request timed out while waiting for the model. Please try again.", + &msg.sender, + ) + .await; } } } diff --git a/src/channels/traits.rs b/src/channels/traits.rs index 4709a1b..ae6239b 100644 --- a/src/channels/traits.rs +++ b/src/channels/traits.rs @@ -26,4 +26,15 @@ pub trait Channel: Send + Sync { async fn health_check(&self) -> bool { true } + + /// Signal that the bot is processing a response (e.g. "typing" indicator). + /// Implementations should repeat the indicator as needed for their platform. + async fn start_typing(&self, _recipient: &str) -> anyhow::Result<()> { + Ok(()) + } + + /// Stop any active typing indicator. + async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { + Ok(()) + } }