zeroclaw/src/channels/discord.rs
argenis de la rosa a5887ad2dc docs+tests: architecture diagram, security docs, 75 new edge-case tests
README:
- Add ASCII architecture flow diagram showing all layers
- Add Security Architecture section (Layer 1: Channel Auth,
  Layer 2: Rate Limiting, Layer 3: Tool Sandbox)
- Update test count to 629

New edge-case tests (75 new):
- SecurityPolicy: command injection (semicolon, backtick, dollar-paren,
  env prefix, newline), path traversal (encoded dots, double-dot in
  filename, null byte, symlink, tilde-ssh, /var/run), rate limiter
  boundaries (exactly-at, zero, high), autonomy+command combos,
  from_config fresh tracker
- Discord: exact match not substring, empty user ID, wildcard+specific,
  case sensitivity, base64 edge cases
- Slack: exact match, empty user ID, case sensitivity, wildcard combo
- Telegram: exact match, empty string, case sensitivity, wildcard combo
- Gateway: first-match-wins, empty value, colon in value, different
  headers, empty request, newline-only request
- Config schema: backward compat (Discord/Slack without allowed_users),
  TOML roundtrip, webhook secret presence/absence

629 tests passing, 0 clippy warnings
2026-02-13 16:00:15 -05:00

358 lines
12 KiB
Rust

use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;
/// Discord channel — connects via Gateway WebSocket for real-time messages
///
/// Outbound messages and the health check go through the Discord REST API
/// (v10); inbound messages are read from the Gateway WebSocket in `listen`.
pub struct DiscordChannel {
/// Bot token; sent as the `Authorization: Bot <token>` header on REST calls
/// and inside the Gateway Identify payload.
bot_token: String,
/// Optional guild filter: when set, incoming messages whose `guild_id`
/// differs from this value are skipped.
guild_id: Option<String>,
/// Allowlist of Discord user IDs. An empty list or an entry equal to `"*"`
/// allows every sender.
allowed_users: Vec<String>,
/// HTTP client used for the REST requests.
client: reqwest::Client,
}
impl DiscordChannel {
pub fn new(bot_token: String, guild_id: Option<String>, allowed_users: Vec<String>) -> Self {
Self {
bot_token,
guild_id,
allowed_users,
client: reqwest::Client::new(),
}
}
/// Check if a Discord user ID is in the allowlist.
/// Empty list or `["*"]` means allow everyone.
fn is_user_allowed(&self, user_id: &str) -> bool {
if self.allowed_users.is_empty() {
return true;
}
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
}
fn bot_user_id_from_token(token: &str) -> Option<String> {
// Discord bot tokens are base64(bot_user_id).timestamp.hmac
let part = token.split('.').next()?;
base64_decode(part)
}
}
/// Standard (RFC 4648) base64 alphabet; a symbol's index is its 6-bit value.
const BASE64_ALPHABET: &[u8] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

/// Minimal base64 decode (no extra dep) — only needs to decode the user ID portion
///
/// Accepts the standard alphabet with or without trailing `=` padding.
///
/// Returns `None` when the input is not well-formed base64 or the decoded
/// bytes are not valid UTF-8:
/// * a symbol outside the alphabet (including `=` anywhere but the end),
/// * more than two trailing `=` characters,
/// * a dangling single symbol (`len % 4 == 1` after stripping padding), which
///   cannot encode a whole byte.
///
/// The empty string decodes to `Some(String::new())`.
// The only cast narrows an index into the 64-entry alphabet, which always fits in a `u8`.
#[allow(clippy::cast_possible_truncation)]
fn base64_decode(input: &str) -> Option<String> {
    // Padding carries no data: strip it and validate the payload on its own.
    let data = input.trim_end_matches('=');
    if input.len() - data.len() > 2 || data.len() % 4 == 1 {
        return None;
    }

    let mut bytes = Vec::with_capacity(data.len() / 4 * 3 + 2);
    for chunk in data.as_bytes().chunks(4) {
        // Missing symbols in a short final chunk stay zero and are never emitted.
        let mut v = [0u8; 4];
        for (slot, &b) in v.iter_mut().zip(chunk) {
            *slot = BASE64_ALPHABET.iter().position(|&a| a == b)? as u8;
        }

        // `<<` on `u8` discards the bits shifted out, which is exactly the
        // masking each output byte needs.
        bytes.push((v[0] << 2) | (v[1] >> 4));
        if chunk.len() > 2 {
            bytes.push((v[1] << 4) | (v[2] >> 2));
        }
        if chunk.len() > 3 {
            bytes.push((v[2] << 6) | v[3]);
        }
    }
    String::from_utf8(bytes).ok()
}
#[async_trait]
impl Channel for DiscordChannel {
    /// Name used to tag messages that originate from this channel.
    fn name(&self) -> &str {
        "discord"
    }

    /// Posts `message` to the Discord channel `channel_id` through the REST API.
    ///
    /// # Errors
    /// Fails when the request cannot be sent or when Discord answers with a
    /// non-success HTTP status (e.g. missing permissions, unknown channel).
    async fn send(&self, message: &str, channel_id: &str) -> anyhow::Result<()> {
        let url = format!("https://discord.com/api/v10/channels/{channel_id}/messages");
        let body = json!({ "content": message });
        self.client
            .post(&url)
            .header("Authorization", format!("Bot {}", self.bot_token))
            .json(&body)
            .send()
            .await?
            // A 4xx/5xx reply is still an `Ok` response at the transport level;
            // surface it so a rejected message is not reported as delivered.
            .error_for_status()?;
        Ok(())
    }

    /// Connects to the Discord Gateway and forwards every accepted
    /// `MESSAGE_CREATE` event to `tx` until the socket closes or the receiver
    /// is dropped.
    ///
    /// A message is forwarded only if it was not written by this bot or any
    /// other bot, its author passes the allowlist, it matches the optional
    /// guild filter, and its content is non-empty. The forwarded
    /// `ChannelMessage::sender` carries the Discord channel ID the message
    /// arrived on.
    ///
    /// # Errors
    /// Fails when the gateway URL lookup, the WebSocket handshake, the Hello
    /// frame or the Identify frame fails. Once the session is established the
    /// method returns `Ok(())` when the connection ends.
    #[allow(clippy::too_many_lines)]
    async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
        // Gateway intent bit flags, spelled out so the value sent to Discord
        // cannot drift from its description.
        const INTENT_GUILDS: u64 = 1;
        const INTENT_GUILD_MESSAGES: u64 = 1 << 9;
        const INTENT_DIRECT_MESSAGES: u64 = 1 << 12;
        const INTENT_MESSAGE_CONTENT: u64 = 1 << 15;
        const INTENTS: u64 = INTENT_GUILDS
            | INTENT_GUILD_MESSAGES
            | INTENT_DIRECT_MESSAGES
            | INTENT_MESSAGE_CONTENT;

        let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();

        // Get Gateway URL (falls back to the well-known default if the reply
        // does not contain one).
        let gw_resp: serde_json::Value = self
            .client
            .get("https://discord.com/api/v10/gateway/bot")
            .header("Authorization", format!("Bot {}", self.bot_token))
            .send()
            .await?
            .json()
            .await?;
        let gw_url = gw_resp
            .get("url")
            .and_then(|u| u.as_str())
            .unwrap_or("wss://gateway.discord.gg");
        let ws_url = format!("{gw_url}/?v=10&encoding=json");

        tracing::info!("Discord: connecting to gateway...");
        let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
        let (mut write, mut read) = ws_stream.split();

        // Read Hello (opcode 10)
        let hello = read
            .next()
            .await
            .ok_or_else(|| anyhow::anyhow!("No hello"))??;
        let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
        let heartbeat_interval = hello_data
            .get("d")
            .and_then(|d| d.get("heartbeat_interval"))
            .and_then(serde_json::Value::as_u64)
            .unwrap_or(41250);

        // Send Identify (opcode 2)
        let identify = json!({
            "op": 2,
            "d": {
                "token": self.bot_token,
                "intents": INTENTS, // GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT | DIRECT_MESSAGES
                "properties": {
                    "os": "linux",
                    "browser": "zeroclaw",
                    "device": "zeroclaw"
                }
            }
        });
        write.send(Message::Text(identify.to_string())).await?;
        tracing::info!("Discord: connected and identified");

        // Spawn heartbeat task: it only emits ticks; the frames themselves are
        // written from the loop below, which owns the socket's write half.
        // The task stops by itself once `hb_rx` is dropped.
        let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
        tokio::spawn(async move {
            let mut interval =
                tokio::time::interval(std::time::Duration::from_millis(heartbeat_interval));
            loop {
                interval.tick().await;
                if hb_tx.send(()).await.is_err() {
                    break;
                }
            }
        });

        let guild_filter = self.guild_id.as_deref();
        // Last sequence number seen on a dispatch; echoed back in heartbeats
        // (`null` until the first one arrives).
        let mut last_seq: Option<u64> = None;

        loop {
            tokio::select! {
                _ = hb_rx.recv() => {
                    let hb = json!({"op": 1, "d": last_seq});
                    if write.send(Message::Text(hb.to_string())).await.is_err() {
                        break;
                    }
                }
                msg = read.next() => {
                    let text = match msg {
                        Some(Ok(Message::Text(t))) => t,
                        Some(Ok(Message::Close(_))) | None => break,
                        _ => continue,
                    };

                    let event: serde_json::Value = match serde_json::from_str(&text) {
                        Ok(e) => e,
                        Err(_) => continue,
                    };

                    if let Some(seq) = event.get("s").and_then(serde_json::Value::as_u64) {
                        last_seq = Some(seq);
                    }

                    // The gateway may ask for a heartbeat (opcode 1); answer
                    // immediately instead of waiting for the next tick.
                    if event.get("op").and_then(serde_json::Value::as_u64) == Some(1) {
                        let hb = json!({"op": 1, "d": last_seq});
                        if write.send(Message::Text(hb.to_string())).await.is_err() {
                            break;
                        }
                        continue;
                    }

                    // Only handle MESSAGE_CREATE (opcode 0, type "MESSAGE_CREATE")
                    let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
                    if event_type != "MESSAGE_CREATE" {
                        continue;
                    }

                    let Some(d) = event.get("d") else {
                        continue;
                    };

                    // Skip messages from the bot itself
                    let author_id = d
                        .get("author")
                        .and_then(|a| a.get("id"))
                        .and_then(|i| i.as_str())
                        .unwrap_or("");
                    if author_id == bot_user_id {
                        continue;
                    }

                    // Skip bot messages
                    if d
                        .get("author")
                        .and_then(|a| a.get("bot"))
                        .and_then(serde_json::Value::as_bool)
                        .unwrap_or(false)
                    {
                        continue;
                    }

                    // Sender validation
                    if !self.is_user_allowed(author_id) {
                        tracing::warn!("Discord: ignoring message from unauthorized user: {author_id}");
                        continue;
                    }

                    // Guild filter: events without a `guild_id` never match a
                    // configured guild.
                    if let Some(gid) = guild_filter {
                        let msg_guild = d
                            .get("guild_id")
                            .and_then(serde_json::Value::as_str)
                            .unwrap_or("");
                        if msg_guild != gid {
                            continue;
                        }
                    }

                    let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
                    if content.is_empty() {
                        continue;
                    }

                    let channel_id = d
                        .get("channel_id")
                        .and_then(|c| c.as_str())
                        .unwrap_or("")
                        .to_string();

                    let channel_msg = ChannelMessage {
                        id: Uuid::new_v4().to_string(),
                        sender: channel_id,
                        content: content.to_string(),
                        channel: "discord".to_string(),
                        timestamp: std::time::SystemTime::now()
                            .duration_since(std::time::UNIX_EPOCH)
                            .unwrap_or_default()
                            .as_secs(),
                    };

                    // A closed receiver means nobody is listening any more.
                    if tx.send(channel_msg).await.is_err() {
                        break;
                    }
                }
            }
        }

        Ok(())
    }

    /// Reports whether the bot token is accepted by Discord, by requesting
    /// `/users/@me`. Any transport error or non-success status yields `false`.
    async fn health_check(&self) -> bool {
        self.client
            .get("https://discord.com/api/v10/users/@me")
            .header("Authorization", format!("Bot {}", self.bot_token))
            .send()
            .await
            .map(|r| r.status().is_success())
            .unwrap_or(false)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a channel with a placeholder token, no guild filter and the
    /// given allowlist entries.
    fn channel_with(allowed: &[&str]) -> DiscordChannel {
        let allowed_users = allowed.iter().map(|id| (*id).to_string()).collect();
        DiscordChannel::new("fake".into(), None, allowed_users)
    }

    #[test]
    fn discord_channel_name() {
        assert_eq!(channel_with(&[]).name(), "discord");
    }

    #[test]
    fn base64_decode_bot_id() {
        // "MTIzNDU2" is the base64 encoding of "123456".
        assert_eq!(base64_decode("MTIzNDU2").as_deref(), Some("123456"));
    }

    #[test]
    fn bot_user_id_extraction() {
        // Token layout: base64(user_id).timestamp.hmac
        let token = "MTIzNDU2.fake.hmac";
        assert_eq!(
            DiscordChannel::bot_user_id_from_token(token).as_deref(),
            Some("123456")
        );
    }

    #[test]
    fn empty_allowlist_allows_everyone() {
        let ch = channel_with(&[]);
        for id in ["12345", "anyone"] {
            assert!(ch.is_user_allowed(id));
        }
    }

    #[test]
    fn wildcard_allows_everyone() {
        let ch = channel_with(&["*"]);
        for id in ["12345", "anyone"] {
            assert!(ch.is_user_allowed(id));
        }
    }

    #[test]
    fn specific_allowlist_filters() {
        let ch = channel_with(&["111", "222"]);
        for id in ["111", "222"] {
            assert!(ch.is_user_allowed(id));
        }
        for id in ["333", "unknown"] {
            assert!(!ch.is_user_allowed(id));
        }
    }

    #[test]
    fn allowlist_is_exact_match_not_substring() {
        let ch = channel_with(&["111"]);
        for id in ["1111", "11", "0111"] {
            assert!(!ch.is_user_allowed(id));
        }
    }

    #[test]
    fn allowlist_empty_string_user_id() {
        assert!(!channel_with(&["111"]).is_user_allowed(""));
    }

    #[test]
    fn allowlist_with_wildcard_and_specific() {
        let ch = channel_with(&["111", "*"]);
        for id in ["111", "anyone_else"] {
            assert!(ch.is_user_allowed(id));
        }
    }

    #[test]
    fn allowlist_case_sensitive() {
        let ch = channel_with(&["ABC"]);
        assert!(ch.is_user_allowed("ABC"));
        for id in ["abc", "Abc"] {
            assert!(!ch.is_user_allowed(id));
        }
    }

    #[test]
    fn base64_decode_empty_string() {
        assert_eq!(base64_decode(""), Some(String::new()));
    }

    #[test]
    fn base64_decode_invalid_chars() {
        assert!(base64_decode("!!!!").is_none());
    }

    #[test]
    fn bot_user_id_from_empty_token() {
        assert_eq!(
            DiscordChannel::bot_user_id_from_token(""),
            Some(String::new())
        );
    }
}