475 lines
17 KiB
Rust
475 lines
17 KiB
Rust
use super::traits::{Channel, ChannelMessage};
|
|
use async_trait::async_trait;
|
|
use futures_util::{SinkExt, StreamExt};
|
|
use serde_json::json;
|
|
use std::collections::HashSet;
|
|
use std::sync::Arc;
|
|
use tokio::sync::RwLock;
|
|
use tokio_tungstenite::tungstenite::Message;
|
|
use uuid::Uuid;
|
|
|
|
const QQ_API_BASE: &str = "https://api.sgroup.qq.com";
|
|
const QQ_AUTH_URL: &str = "https://bots.qq.com/app/getAppAccessToken";
|
|
|
|
/// Deduplication set capacity — evict half of entries when full.
|
|
const DEDUP_CAPACITY: usize = 10_000;
|
|
|
|
/// QQ Official Bot channel — uses Tencent's official QQ Bot API with
|
|
/// OAuth2 authentication and a Discord-like WebSocket gateway protocol.
|
|
pub struct QQChannel {
|
|
app_id: String,
|
|
app_secret: String,
|
|
allowed_users: Vec<String>,
|
|
client: reqwest::Client,
|
|
/// Cached access token + expiry timestamp.
|
|
token_cache: Arc<RwLock<Option<(String, u64)>>>,
|
|
/// Message deduplication set.
|
|
dedup: Arc<RwLock<HashSet<String>>>,
|
|
}
|
|
|
|
impl QQChannel {
|
|
pub fn new(app_id: String, app_secret: String, allowed_users: Vec<String>) -> Self {
|
|
Self {
|
|
app_id,
|
|
app_secret,
|
|
allowed_users,
|
|
client: reqwest::Client::new(),
|
|
token_cache: Arc::new(RwLock::new(None)),
|
|
dedup: Arc::new(RwLock::new(HashSet::new())),
|
|
}
|
|
}
|
|
|
|
fn is_user_allowed(&self, user_id: &str) -> bool {
|
|
self.allowed_users.iter().any(|u| u == "*" || u == user_id)
|
|
}
|
|
|
|
/// Fetch an access token from QQ's OAuth2 endpoint.
|
|
async fn fetch_access_token(&self) -> anyhow::Result<(String, u64)> {
|
|
let body = json!({
|
|
"appId": self.app_id,
|
|
"clientSecret": self.app_secret,
|
|
});
|
|
|
|
let resp = self.client.post(QQ_AUTH_URL).json(&body).send().await?;
|
|
|
|
if !resp.status().is_success() {
|
|
let status = resp.status();
|
|
let err = resp.text().await.unwrap_or_default();
|
|
anyhow::bail!("QQ token request failed ({status}): {err}");
|
|
}
|
|
|
|
let data: serde_json::Value = resp.json().await?;
|
|
let token = data
|
|
.get("access_token")
|
|
.and_then(|t| t.as_str())
|
|
.ok_or_else(|| anyhow::anyhow!("Missing access_token in QQ response"))?
|
|
.to_string();
|
|
|
|
let expires_in = data
|
|
.get("expires_in")
|
|
.and_then(|e| e.as_str())
|
|
.and_then(|e| e.parse::<u64>().ok())
|
|
.unwrap_or(7200);
|
|
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs();
|
|
|
|
// Expire 60 seconds early to avoid edge cases
|
|
let expiry = now + expires_in.saturating_sub(60);
|
|
|
|
Ok((token, expiry))
|
|
}
|
|
|
|
/// Get a valid access token, refreshing if expired.
|
|
async fn get_token(&self) -> anyhow::Result<String> {
|
|
let now = std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs();
|
|
|
|
{
|
|
let cache = self.token_cache.read().await;
|
|
if let Some((ref token, expiry)) = *cache {
|
|
if now < expiry {
|
|
return Ok(token.clone());
|
|
}
|
|
}
|
|
}
|
|
|
|
let (token, expiry) = self.fetch_access_token().await?;
|
|
{
|
|
let mut cache = self.token_cache.write().await;
|
|
*cache = Some((token.clone(), expiry));
|
|
}
|
|
Ok(token)
|
|
}
|
|
|
|
/// Get the WebSocket gateway URL.
|
|
async fn get_gateway_url(&self, token: &str) -> anyhow::Result<String> {
|
|
let resp = self
|
|
.client
|
|
.get(format!("{QQ_API_BASE}/gateway"))
|
|
.header("Authorization", format!("QQBot {token}"))
|
|
.send()
|
|
.await?;
|
|
|
|
if !resp.status().is_success() {
|
|
let status = resp.status();
|
|
let err = resp.text().await.unwrap_or_default();
|
|
anyhow::bail!("QQ gateway request failed ({status}): {err}");
|
|
}
|
|
|
|
let data: serde_json::Value = resp.json().await?;
|
|
let url = data
|
|
.get("url")
|
|
.and_then(|u| u.as_str())
|
|
.ok_or_else(|| anyhow::anyhow!("Missing gateway URL in QQ response"))?
|
|
.to_string();
|
|
|
|
Ok(url)
|
|
}
|
|
|
|
/// Check and insert message ID for deduplication.
|
|
async fn is_duplicate(&self, msg_id: &str) -> bool {
|
|
if msg_id.is_empty() {
|
|
return false;
|
|
}
|
|
|
|
let mut dedup = self.dedup.write().await;
|
|
|
|
if dedup.contains(msg_id) {
|
|
return true;
|
|
}
|
|
|
|
// Evict oldest half when at capacity
|
|
if dedup.len() >= DEDUP_CAPACITY {
|
|
let to_remove: Vec<String> = dedup.iter().take(DEDUP_CAPACITY / 2).cloned().collect();
|
|
for key in to_remove {
|
|
dedup.remove(&key);
|
|
}
|
|
}
|
|
|
|
dedup.insert(msg_id.to_string());
|
|
false
|
|
}
|
|
}
|
|
|
|
#[async_trait]
|
|
impl Channel for QQChannel {
|
|
fn name(&self) -> &str {
|
|
"qq"
|
|
}
|
|
|
|
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
|
|
let token = self.get_token().await?;
|
|
|
|
// Determine if this is a group or private message based on recipient format
|
|
// Format: "user:{openid}" or "group:{group_openid}"
|
|
let (url, body) = if let Some(group_id) = recipient.strip_prefix("group:") {
|
|
(
|
|
format!("{QQ_API_BASE}/v2/groups/{group_id}/messages"),
|
|
json!({
|
|
"content": message,
|
|
"msg_type": 0,
|
|
}),
|
|
)
|
|
} else {
|
|
let user_id = recipient.strip_prefix("user:").unwrap_or(recipient);
|
|
(
|
|
format!("{QQ_API_BASE}/v2/users/{user_id}/messages"),
|
|
json!({
|
|
"content": message,
|
|
"msg_type": 0,
|
|
}),
|
|
)
|
|
};
|
|
|
|
let resp = self
|
|
.client
|
|
.post(&url)
|
|
.header("Authorization", format!("QQBot {token}"))
|
|
.json(&body)
|
|
.send()
|
|
.await?;
|
|
|
|
if !resp.status().is_success() {
|
|
let status = resp.status();
|
|
let err = resp.text().await.unwrap_or_default();
|
|
anyhow::bail!("QQ send message failed ({status}): {err}");
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[allow(clippy::too_many_lines)]
|
|
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
|
tracing::info!("QQ: authenticating...");
|
|
let token = self.get_token().await?;
|
|
|
|
tracing::info!("QQ: fetching gateway URL...");
|
|
let gw_url = self.get_gateway_url(&token).await?;
|
|
|
|
tracing::info!("QQ: connecting to gateway WebSocket...");
|
|
let (ws_stream, _) = tokio_tungstenite::connect_async(&gw_url).await?;
|
|
let (mut write, mut read) = ws_stream.split();
|
|
|
|
// Read Hello (opcode 10)
|
|
let hello = read
|
|
.next()
|
|
.await
|
|
.ok_or(anyhow::anyhow!("QQ: no hello frame"))??;
|
|
let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
|
|
let heartbeat_interval = hello_data
|
|
.get("d")
|
|
.and_then(|d| d.get("heartbeat_interval"))
|
|
.and_then(serde_json::Value::as_u64)
|
|
.unwrap_or(41250);
|
|
|
|
// Send Identify (opcode 2)
|
|
// Intents: PUBLIC_GUILD_MESSAGES (1<<30) | C2C_MESSAGE_CREATE & GROUP_AT_MESSAGE_CREATE (1<<25)
|
|
let intents: u64 = (1 << 25) | (1 << 30);
|
|
let identify = json!({
|
|
"op": 2,
|
|
"d": {
|
|
"token": format!("QQBot {token}"),
|
|
"intents": intents,
|
|
"properties": {
|
|
"os": "linux",
|
|
"browser": "zeroclaw",
|
|
"device": "zeroclaw",
|
|
}
|
|
}
|
|
});
|
|
write.send(Message::Text(identify.to_string())).await?;
|
|
|
|
tracing::info!("QQ: connected and identified");
|
|
|
|
let mut sequence: i64 = -1;
|
|
|
|
// Spawn heartbeat timer
|
|
let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
|
|
let hb_interval = heartbeat_interval;
|
|
tokio::spawn(async move {
|
|
let mut interval = tokio::time::interval(std::time::Duration::from_millis(hb_interval));
|
|
loop {
|
|
interval.tick().await;
|
|
if hb_tx.send(()).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
});
|
|
|
|
loop {
|
|
tokio::select! {
|
|
_ = hb_rx.recv() => {
|
|
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
|
|
let hb = json!({"op": 1, "d": d});
|
|
if write.send(Message::Text(hb.to_string())).await.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
msg = read.next() => {
|
|
let msg = match msg {
|
|
Some(Ok(Message::Text(t))) => t,
|
|
Some(Ok(Message::Close(_))) | None => break,
|
|
_ => continue,
|
|
};
|
|
|
|
let event: serde_json::Value = match serde_json::from_str(&msg) {
|
|
Ok(e) => e,
|
|
Err(_) => continue,
|
|
};
|
|
|
|
// Track sequence number
|
|
if let Some(s) = event.get("s").and_then(serde_json::Value::as_i64) {
|
|
sequence = s;
|
|
}
|
|
|
|
let op = event.get("op").and_then(serde_json::Value::as_u64).unwrap_or(0);
|
|
|
|
match op {
|
|
// Server requests immediate heartbeat
|
|
1 => {
|
|
let d = if sequence >= 0 { json!(sequence) } else { json!(null) };
|
|
let hb = json!({"op": 1, "d": d});
|
|
if write.send(Message::Text(hb.to_string())).await.is_err() {
|
|
break;
|
|
}
|
|
continue;
|
|
}
|
|
// Reconnect
|
|
7 => {
|
|
tracing::warn!("QQ: received Reconnect (op 7)");
|
|
break;
|
|
}
|
|
// Invalid Session
|
|
9 => {
|
|
tracing::warn!("QQ: received Invalid Session (op 9)");
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
|
|
// Only process dispatch events (op 0)
|
|
if op != 0 {
|
|
continue;
|
|
}
|
|
|
|
let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
|
|
let d = match event.get("d") {
|
|
Some(d) => d,
|
|
None => continue,
|
|
};
|
|
|
|
match event_type {
|
|
"C2C_MESSAGE_CREATE" => {
|
|
let msg_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
|
if self.is_duplicate(msg_id).await {
|
|
continue;
|
|
}
|
|
|
|
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("").trim();
|
|
if content.is_empty() {
|
|
continue;
|
|
}
|
|
|
|
let author_id = d.get("author").and_then(|a| a.get("id")).and_then(|i| i.as_str()).unwrap_or("unknown");
|
|
// For QQ, user_openid is the identifier
|
|
let user_openid = d.get("author").and_then(|a| a.get("user_openid")).and_then(|u| u.as_str()).unwrap_or(author_id);
|
|
|
|
if !self.is_user_allowed(user_openid) {
|
|
tracing::warn!("QQ: ignoring C2C message from unauthorized user: {user_openid}");
|
|
continue;
|
|
}
|
|
|
|
let chat_id = format!("user:{user_openid}");
|
|
|
|
let channel_msg = ChannelMessage {
|
|
id: Uuid::new_v4().to_string(),
|
|
sender: user_openid.to_string(),
|
|
reply_target: chat_id,
|
|
content: content.to_string(),
|
|
channel: "qq".to_string(),
|
|
timestamp: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
};
|
|
|
|
if tx.send(channel_msg).await.is_err() {
|
|
tracing::warn!("QQ: message channel closed");
|
|
break;
|
|
}
|
|
}
|
|
"GROUP_AT_MESSAGE_CREATE" => {
|
|
let msg_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
|
if self.is_duplicate(msg_id).await {
|
|
continue;
|
|
}
|
|
|
|
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("").trim();
|
|
if content.is_empty() {
|
|
continue;
|
|
}
|
|
|
|
let author_id = d.get("author").and_then(|a| a.get("member_openid")).and_then(|m| m.as_str()).unwrap_or("unknown");
|
|
|
|
if !self.is_user_allowed(author_id) {
|
|
tracing::warn!("QQ: ignoring group message from unauthorized user: {author_id}");
|
|
continue;
|
|
}
|
|
|
|
let group_openid = d.get("group_openid").and_then(|g| g.as_str()).unwrap_or("unknown");
|
|
let chat_id = format!("group:{group_openid}");
|
|
|
|
let channel_msg = ChannelMessage {
|
|
id: Uuid::new_v4().to_string(),
|
|
sender: author_id.to_string(),
|
|
reply_target: chat_id,
|
|
content: content.to_string(),
|
|
channel: "qq".to_string(),
|
|
timestamp: std::time::SystemTime::now()
|
|
.duration_since(std::time::UNIX_EPOCH)
|
|
.unwrap_or_default()
|
|
.as_secs(),
|
|
};
|
|
|
|
if tx.send(channel_msg).await.is_err() {
|
|
tracing::warn!("QQ: message channel closed");
|
|
break;
|
|
}
|
|
}
|
|
_ => {}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
anyhow::bail!("QQ WebSocket connection closed")
|
|
}
|
|
|
|
async fn health_check(&self) -> bool {
|
|
self.fetch_access_token().await.is_ok()
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
|
|
mod tests {
|
|
use super::*;
|
|
|
|
#[test]
|
|
fn test_name() {
|
|
let ch = QQChannel::new("id".into(), "secret".into(), vec![]);
|
|
assert_eq!(ch.name(), "qq");
|
|
}
|
|
|
|
#[test]
|
|
fn test_user_allowed_wildcard() {
|
|
let ch = QQChannel::new("id".into(), "secret".into(), vec!["*".into()]);
|
|
assert!(ch.is_user_allowed("anyone"));
|
|
}
|
|
|
|
#[test]
|
|
fn test_user_allowed_specific() {
|
|
let ch = QQChannel::new("id".into(), "secret".into(), vec!["user123".into()]);
|
|
assert!(ch.is_user_allowed("user123"));
|
|
assert!(!ch.is_user_allowed("other"));
|
|
}
|
|
|
|
#[test]
|
|
fn test_user_denied_empty() {
|
|
let ch = QQChannel::new("id".into(), "secret".into(), vec![]);
|
|
assert!(!ch.is_user_allowed("anyone"));
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_dedup() {
|
|
let ch = QQChannel::new("id".into(), "secret".into(), vec![]);
|
|
assert!(!ch.is_duplicate("msg1").await);
|
|
assert!(ch.is_duplicate("msg1").await);
|
|
assert!(!ch.is_duplicate("msg2").await);
|
|
}
|
|
|
|
#[tokio::test]
|
|
async fn test_dedup_empty_id() {
|
|
let ch = QQChannel::new("id".into(), "secret".into(), vec![]);
|
|
// Empty IDs should never be considered duplicates
|
|
assert!(!ch.is_duplicate("").await);
|
|
assert!(!ch.is_duplicate("").await);
|
|
}
|
|
|
|
#[test]
|
|
fn test_config_serde() {
|
|
let toml_str = r#"
|
|
app_id = "12345"
|
|
app_secret = "secret_abc"
|
|
allowed_users = ["user1"]
|
|
"#;
|
|
let config: crate::config::schema::QQConfig = toml::from_str(toml_str).unwrap();
|
|
assert_eq!(config.app_id, "12345");
|
|
assert_eq!(config.app_secret, "secret_abc");
|
|
assert_eq!(config.allowed_users, vec!["user1"]);
|
|
}
|
|
}
|