use super::traits::{Channel, ChannelMessage}; use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; use serde_json::json; use tokio_tungstenite::tungstenite::Message; use uuid::Uuid; /// Discord channel — connects via Gateway WebSocket for real-time messages pub struct DiscordChannel { bot_token: String, guild_id: Option, allowed_users: Vec, listen_to_bots: bool, client: reqwest::Client, typing_handle: std::sync::Mutex>>, } impl DiscordChannel { pub fn new(bot_token: String, guild_id: Option, allowed_users: Vec, listen_to_bots: bool) -> Self { Self { bot_token, guild_id, allowed_users, listen_to_bots, client: reqwest::Client::new(), typing_handle: std::sync::Mutex::new(None), } } /// Check if a Discord user ID is in the allowlist. /// Empty list means deny everyone until explicitly configured. /// `"*"` means allow everyone. fn is_user_allowed(&self, user_id: &str) -> bool { self.allowed_users.iter().any(|u| u == "*" || u == user_id) } fn bot_user_id_from_token(token: &str) -> Option { // Discord bot tokens are base64(bot_user_id).timestamp.hmac let part = token.split('.').next()?; base64_decode(part) } } const BASE64_ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; /// Discord's maximum message length for regular messages. /// /// Discord rejects longer payloads with `50035 Invalid Form Body`. const DISCORD_MAX_MESSAGE_LENGTH: usize = 2000; /// Split a message into chunks that respect Discord's 2000-character limit. /// Tries to split at word boundaries when possible. fn split_message_for_discord(message: &str) -> Vec { if message.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH { return vec![message.to_string()]; } let mut chunks = Vec::new(); let mut remaining = message; while !remaining.is_empty() { // Find the byte offset for the 2000th character boundary. // If there are fewer than 2000 chars left, we can emit the tail directly. let hard_split = remaining .char_indices() .nth(DISCORD_MAX_MESSAGE_LENGTH) .map_or(remaining.len(), |(idx, _)| idx); let chunk_end = if hard_split == remaining.len() { hard_split } else { // Try to find a good break point (newline, then space) let search_area = &remaining[..hard_split]; // Prefer splitting at newline if let Some(pos) = search_area.rfind('\n') { // Don't split if the newline is too close to the end if search_area[..pos].chars().count() >= DISCORD_MAX_MESSAGE_LENGTH / 2 { pos + 1 } else { // Try space as fallback search_area.rfind(' ').map_or(hard_split, |space| space + 1) } } else if let Some(pos) = search_area.rfind(' ') { pos + 1 } else { // Hard split at the limit hard_split } }; chunks.push(remaining[..chunk_end].to_string()); remaining = &remaining[chunk_end..]; } chunks } /// Minimal base64 decode (no extra dep) — only needs to decode the user ID portion #[allow(clippy::cast_possible_truncation)] fn base64_decode(input: &str) -> Option { let padded = match input.len() % 4 { 2 => format!("{input}=="), 3 => format!("{input}="), _ => input.to_string(), }; let mut bytes = Vec::new(); let chars: Vec = padded.bytes().collect(); for chunk in chars.chunks(4) { if chunk.len() < 4 { break; } let mut v = [0usize; 4]; for (i, &b) in chunk.iter().enumerate() { if b == b'=' { v[i] = 0; } else { v[i] = BASE64_ALPHABET.iter().position(|&a| a == b)?; } } bytes.push(((v[0] << 2) | (v[1] >> 4)) as u8); if chunk[2] != b'=' { bytes.push((((v[1] & 0xF) << 4) | (v[2] >> 2)) as u8); } if chunk[3] != b'=' { bytes.push((((v[2] & 0x3) << 6) | v[3]) as u8); } } String::from_utf8(bytes).ok() } #[async_trait] impl Channel for DiscordChannel { fn name(&self) -> &str { "discord" } async fn send(&self, message: &str, channel_id: &str) -> anyhow::Result<()> { let chunks = split_message_for_discord(message); for (i, chunk) in chunks.iter().enumerate() { let url = format!("https://discord.com/api/v10/channels/{channel_id}/messages"); let body = json!({ "content": chunk }); let resp = self .client .post(&url) .header("Authorization", format!("Bot {}", self.bot_token)) .json(&body) .send() .await?; if !resp.status().is_success() { let status = resp.status(); let err = resp .text() .await .unwrap_or_else(|e| format!("")); anyhow::bail!("Discord send message failed ({status}): {err}"); } // Add a small delay between chunks to avoid rate limiting if i < chunks.len() - 1 { tokio::time::sleep(std::time::Duration::from_millis(500)).await; } } Ok(()) } #[allow(clippy::too_many_lines)] async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default(); // Get Gateway URL let gw_resp: serde_json::Value = self .client .get("https://discord.com/api/v10/gateway/bot") .header("Authorization", format!("Bot {}", self.bot_token)) .send() .await? .json() .await?; let gw_url = gw_resp .get("url") .and_then(|u| u.as_str()) .unwrap_or("wss://gateway.discord.gg"); let ws_url = format!("{gw_url}/?v=10&encoding=json"); tracing::info!("Discord: connecting to gateway..."); let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?; let (mut write, mut read) = ws_stream.split(); // Read Hello (opcode 10) let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??; let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?; let heartbeat_interval = hello_data .get("d") .and_then(|d| d.get("heartbeat_interval")) .and_then(serde_json::Value::as_u64) .unwrap_or(41250); // Send Identify (opcode 2) let identify = json!({ "op": 2, "d": { "token": self.bot_token, "intents": 37377, // GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT | DIRECT_MESSAGES "properties": { "os": "linux", "browser": "zeroclaw", "device": "zeroclaw" } } }); write.send(Message::Text(identify.to_string())).await?; tracing::info!("Discord: connected and identified"); // Track the last sequence number for heartbeats and resume. // Only accessed in the select! loop below, so a plain i64 suffices. let mut sequence: i64 = -1; // Spawn heartbeat timer — sends a tick signal, actual heartbeat // is assembled in the select! loop where `sequence` lives. let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1); let hb_interval = heartbeat_interval; tokio::spawn(async move { let mut interval = tokio::time::interval(std::time::Duration::from_millis(hb_interval)); loop { interval.tick().await; if hb_tx.send(()).await.is_err() { break; } } }); let guild_filter = self.guild_id.clone(); loop { tokio::select! { _ = hb_rx.recv() => { let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; let hb = json!({"op": 1, "d": d}); if write.send(Message::Text(hb.to_string())).await.is_err() { break; } } msg = read.next() => { let msg = match msg { Some(Ok(Message::Text(t))) => t, Some(Ok(Message::Close(_))) | None => break, _ => continue, }; let event: serde_json::Value = match serde_json::from_str(&msg) { Ok(e) => e, Err(_) => continue, }; // Track sequence number from all dispatch events if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) { sequence = s; } let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0); match op { // Op 1: Server requests an immediate heartbeat 1 => { let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; let hb = json!({"op": 1, "d": d}); if write.send(Message::Text(hb.to_string())).await.is_err() { break; } continue; } // Op 7: Reconnect 7 => { tracing::warn!("Discord: received Reconnect (op 7), closing for restart"); break; } // Op 9: Invalid Session 9 => { tracing::warn!("Discord: received Invalid Session (op 9), closing for restart"); break; } _ => {} } // Only handle MESSAGE_CREATE (opcode 0, type "MESSAGE_CREATE") let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or(""); if event_type != "MESSAGE_CREATE" { continue; } let Some(d) = event.get("d") else { continue; }; // Skip messages from the bot itself let author_id = d.get("author").and_then(|a| a.get("id")).and_then(|i| i.as_str()).unwrap_or(""); if author_id == bot_user_id { continue; } // Skip bot messages (unless listen_to_bots is enabled) if !self.listen_to_bots && d.get("author").and_then(|a| a.get("bot")).and_then(serde_json::Value::as_bool).unwrap_or(false) { continue; } // Sender validation if !self.is_user_allowed(author_id) { tracing::warn!("Discord: ignoring message from unauthorized user: {author_id}"); continue; } // Guild filter if let Some(ref gid) = guild_filter { let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str); // DMs have no guild_id — let them through; for guild messages, enforce the filter if let Some(g) = msg_guild { if g != gid { continue; } } } let content = d.get("content").and_then(|c| c.as_str()).unwrap_or(""); if content.is_empty() { continue; } let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string(); let channel_msg = ChannelMessage { id: Uuid::new_v4().to_string(), sender: channel_id, content: content.to_string(), channel: "discord".to_string(), timestamp: std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), }; if tx.send(channel_msg).await.is_err() { break; } } } } Ok(()) } async fn health_check(&self) -> bool { self.client .get("https://discord.com/api/v10/users/@me") .header("Authorization", format!("Bot {}", self.bot_token)) .send() .await .map(|r| r.status().is_success()) .unwrap_or(false) } async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> { self.stop_typing(recipient).await?; let client = self.client.clone(); let token = self.bot_token.clone(); let channel_id = recipient.to_string(); let handle = tokio::spawn(async move { let url = format!("https://discord.com/api/v10/channels/{channel_id}/typing"); loop { let _ = client .post(&url) .header("Authorization", format!("Bot {token}")) .send() .await; tokio::time::sleep(std::time::Duration::from_secs(8)).await; } }); if let Ok(mut guard) = self.typing_handle.lock() { *guard = Some(handle); } Ok(()) } async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { if let Ok(mut guard) = self.typing_handle.lock() { if let Some(handle) = guard.take() { handle.abort(); } } Ok(()) } } #[cfg(test)] mod tests { use super::*; #[test] fn discord_channel_name() { let ch = DiscordChannel::new("fake".into(), None, vec![], false); assert_eq!(ch.name(), "discord"); } #[test] fn base64_decode_bot_id() { // "MTIzNDU2" decodes to "123456" let decoded = base64_decode("MTIzNDU2"); assert_eq!(decoded, Some("123456".to_string())); } #[test] fn bot_user_id_extraction() { // Token format: base64(user_id).timestamp.hmac let token = "MTIzNDU2.fake.hmac"; let id = DiscordChannel::bot_user_id_from_token(token); assert_eq!(id, Some("123456".to_string())); } #[test] fn empty_allowlist_denies_everyone() { let ch = DiscordChannel::new("fake".into(), None, vec![], false); assert!(!ch.is_user_allowed("12345")); assert!(!ch.is_user_allowed("anyone")); } #[test] fn wildcard_allows_everyone() { let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false); assert!(ch.is_user_allowed("12345")); assert!(ch.is_user_allowed("anyone")); } #[test] fn specific_allowlist_filters() { let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false); assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("222")); assert!(!ch.is_user_allowed("333")); assert!(!ch.is_user_allowed("unknown")); } #[test] fn allowlist_is_exact_match_not_substring() { let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false); assert!(!ch.is_user_allowed("1111")); assert!(!ch.is_user_allowed("11")); assert!(!ch.is_user_allowed("0111")); } #[test] fn allowlist_empty_string_user_id() { let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false); assert!(!ch.is_user_allowed("")); } #[test] fn allowlist_with_wildcard_and_specific() { let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false); assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("anyone_else")); } #[test] fn allowlist_case_sensitive() { let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false); assert!(ch.is_user_allowed("ABC")); assert!(!ch.is_user_allowed("abc")); assert!(!ch.is_user_allowed("Abc")); } #[test] fn base64_decode_empty_string() { let decoded = base64_decode(""); assert_eq!(decoded, Some(String::new())); } #[test] fn base64_decode_invalid_chars() { let decoded = base64_decode("!!!!"); assert!(decoded.is_none()); } #[test] fn bot_user_id_from_empty_token() { let id = DiscordChannel::bot_user_id_from_token(""); assert_eq!(id, Some(String::new())); } // Message splitting tests #[test] fn split_empty_message() { let chunks = split_message_for_discord(""); assert_eq!(chunks, vec![""]); } #[test] fn split_short_message_under_limit() { let msg = "Hello, world!"; let chunks = split_message_for_discord(msg); assert_eq!(chunks, vec![msg]); } #[test] fn split_message_exactly_2000_chars() { let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH); let chunks = split_message_for_discord(&msg); assert_eq!(chunks.len(), 1); assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); } #[test] fn split_message_just_over_limit() { let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH + 1); let chunks = split_message_for_discord(&msg); assert_eq!(chunks.len(), 2); assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); assert_eq!(chunks[1].chars().count(), 1); } #[test] fn split_very_long_message() { let msg = "word ".repeat(2000); // 10000 characters (5 chars per "word ") let chunks = split_message_for_discord(&msg); // Should split into 5 chunks of <= 2000 chars assert_eq!(chunks.len(), 5); assert!(chunks .iter() .all(|chunk| chunk.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH)); // Verify total content is preserved let reconstructed = chunks.concat(); assert_eq!(reconstructed, msg); } #[test] fn split_prefer_newline_break() { let msg = format!("{}\n{}", "a".repeat(1500), "b".repeat(500)); let chunks = split_message_for_discord(&msg); // Should split at the newline assert_eq!(chunks.len(), 2); assert!(chunks[0].ends_with('\n')); assert!(chunks[1].starts_with('b')); } #[test] fn split_prefer_space_break() { let msg = format!("{} {}", "a".repeat(1500), "b".repeat(600)); let chunks = split_message_for_discord(&msg); assert_eq!(chunks.len(), 2); } #[test] fn split_without_good_break_points_hard_split() { // No spaces or newlines - should hard split at 2000 let msg = "a".repeat(5000); let chunks = split_message_for_discord(&msg); assert_eq!(chunks.len(), 3); assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); assert_eq!(chunks[1].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); assert_eq!(chunks[2].chars().count(), 1000); } #[test] fn split_multiple_breaks() { // Create a message with multiple newlines let part1 = "a".repeat(900); let part2 = "b".repeat(900); let part3 = "c".repeat(900); let msg = format!("{part1}\n{part2}\n{part3}"); let chunks = split_message_for_discord(&msg); // Should split into 2 chunks (first two parts + third part) assert_eq!(chunks.len(), 2); assert!(chunks[0].chars().count() <= DISCORD_MAX_MESSAGE_LENGTH); assert!(chunks[1].chars().count() <= DISCORD_MAX_MESSAGE_LENGTH); } #[test] fn split_preserves_content() { let original = "Hello world! This is a test message with some content. ".repeat(200); let chunks = split_message_for_discord(&original); let reconstructed = chunks.concat(); assert_eq!(reconstructed, original); } #[test] fn split_unicode_content() { // Test with emoji and multi-byte characters let msg = "🦀 Rust is awesome! ".repeat(500); let chunks = split_message_for_discord(&msg); // All chunks should be valid UTF-8 for chunk in &chunks { assert!(std::str::from_utf8(chunk.as_bytes()).is_ok()); assert!(chunk.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH); } // Reconstruct and verify let reconstructed = chunks.concat(); assert_eq!(reconstructed, msg); } #[test] fn split_newline_too_close_to_end() { // If newline is in the first half, don't use it - use space instead or hard split let msg = format!("{}\n{}", "a".repeat(1900), "b".repeat(500)); let chunks = split_message_for_discord(&msg); // Should split at newline since it's in the second half of the window assert_eq!(chunks.len(), 2); } #[test] fn split_multibyte_only_content_without_panics() { let msg = "你".repeat(2500); let chunks = split_message_for_discord(&msg); assert_eq!(chunks.len(), 2); assert_eq!(chunks[0].chars().count(), DISCORD_MAX_MESSAGE_LENGTH); assert_eq!(chunks[1].chars().count(), 500); let reconstructed = chunks.concat(); assert_eq!(reconstructed, msg); } #[test] fn split_chunks_always_within_discord_limit() { let msg = "x".repeat(12_345); let chunks = split_message_for_discord(&msg); assert!(chunks .iter() .all(|chunk| chunk.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH)); } #[test] fn split_message_with_multiple_newlines() { let msg = "Line 1\nLine 2\nLine 3\n".repeat(1000); let chunks = split_message_for_discord(&msg); assert!(chunks.len() > 1); let reconstructed = chunks.concat(); assert_eq!(reconstructed, msg); } #[test] fn typing_handle_starts_as_none() { let ch = DiscordChannel::new("fake".into(), None, vec![], false); let guard = ch.typing_handle.lock().unwrap(); assert!(guard.is_none()); } #[tokio::test] async fn start_typing_sets_handle() { let ch = DiscordChannel::new("fake".into(), None, vec![], false); let _ = ch.start_typing("123456").await; let guard = ch.typing_handle.lock().unwrap(); assert!(guard.is_some()); } #[tokio::test] async fn stop_typing_clears_handle() { let ch = DiscordChannel::new("fake".into(), None, vec![], false); let _ = ch.start_typing("123456").await; let _ = ch.stop_typing("123456").await; let guard = ch.typing_handle.lock().unwrap(); assert!(guard.is_none()); } #[tokio::test] async fn stop_typing_is_idempotent() { let ch = DiscordChannel::new("fake".into(), None, vec![], false); assert!(ch.stop_typing("123456").await.is_ok()); assert!(ch.stop_typing("123456").await.is_ok()); } #[tokio::test] async fn start_typing_replaces_existing_task() { let ch = DiscordChannel::new("fake".into(), None, vec![], false); let _ = ch.start_typing("111").await; let _ = ch.start_typing("222").await; let guard = ch.typing_handle.lock().unwrap(); assert!(guard.is_some()); } }