diff --git a/Cargo.lock b/Cargo.lock index 6a4bb3f..f0a6be7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -209,6 +209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -227,8 +228,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite 0.28.0", "tower", "tower-layer", "tower-service", @@ -3756,10 +3759,22 @@ dependencies = [ "rustls-pki-types", "tokio", "tokio-rustls", - "tungstenite", + "tungstenite 0.24.0", "webpki-roots 0.26.11", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.28.0", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -3991,6 +4006,23 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http 1.4.0", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror 2.0.18", + "utf-8", +] + [[package]] name = "twox-hash" version = "2.1.2" @@ -4893,6 +4925,7 @@ dependencies = [ "pdf-extract", "probe-rs", "prometheus", + "prost", "rand 0.8.5", "reqwest", "rppal", @@ -4909,7 +4942,7 @@ dependencies = [ "tokio-rustls", "tokio-serial", "tokio-test", - "tokio-tungstenite", + "tokio-tungstenite 0.24.0", "toml", "tower", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index 10c054d..b91c56a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -69,6 +69,9 @@ landlock = { version = "0.4", optional = true } # Async traits async-trait = "0.1" +# Protobuf encode/decode (Feishu WS long-connection frame codec) +prost = { version = "0.14", default-features = false } + # Memory / persistence rusqlite = { version = "0.38", features = ["bundled"] } chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] } @@ -95,7 +98,7 @@ tokio-rustls = "0.26.4" webpki-roots = "1.0.6" # HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance -axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query"] } +axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] } tower = { version = "0.5", default-features = false } tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] } http-body-util = "0.1" diff --git a/src/channels/lark.rs b/src/channels/lark.rs index 4e9e679..3e482f5 100644 --- a/src/channels/lark.rs +++ b/src/channels/lark.rs @@ -1,21 +1,152 @@ use super::traits::{Channel, ChannelMessage}; use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use prost::Message as ProstMessage; +use std::collections::HashMap; use std::sync::Arc; +use std::time::{Duration, Instant}; use tokio::sync::RwLock; +use tokio_tungstenite::tungstenite::Message as WsMsg; use uuid::Uuid; const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis"; +const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn"; +const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis"; +const LARK_WS_BASE_URL: &str = "https://open.larksuite.com"; -/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API +// ───────────────────────────────────────────────────────────────────────────── +// Feishu WebSocket long-connection: pbbp2.proto frame codec +// ───────────────────────────────────────────────────────────────────────────── + +#[derive(Clone, PartialEq, prost::Message)] +struct PbHeader { + #[prost(string, tag = "1")] + pub key: String, + #[prost(string, tag = "2")] + pub value: String, +} + +/// Feishu WS frame (pbbp2.proto). +/// method=0 → CONTROL (ping/pong) method=1 → DATA (events) +#[derive(Clone, PartialEq, prost::Message)] +struct PbFrame { + #[prost(uint64, tag = "1")] + pub seq_id: u64, + #[prost(uint64, tag = "2")] + pub log_id: u64, + #[prost(int32, tag = "3")] + pub service: i32, + #[prost(int32, tag = "4")] + pub method: i32, + #[prost(message, repeated, tag = "5")] + pub headers: Vec, + #[prost(bytes = "vec", optional, tag = "8")] + pub payload: Option>, +} + +impl PbFrame { + fn header_value<'a>(&'a self, key: &str) -> &'a str { + self.headers + .iter() + .find(|h| h.key == key) + .map(|h| h.value.as_str()) + .unwrap_or("") + } +} + +/// Server-sent client config (parsed from pong payload) +#[derive(Debug, serde::Deserialize, Default, Clone)] +struct WsClientConfig { + #[serde(rename = "PingInterval")] + ping_interval: Option, +} + +/// POST /callback/ws/endpoint response +#[derive(Debug, serde::Deserialize)] +struct WsEndpointResp { + code: i32, + #[serde(default)] + msg: Option, + #[serde(default)] + data: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct WsEndpoint { + #[serde(rename = "URL")] + url: String, + #[serde(rename = "ClientConfig")] + client_config: Option, +} + +/// LarkEvent envelope (method=1 / type=event payload) +#[derive(Debug, serde::Deserialize)] +struct LarkEvent { + header: LarkEventHeader, + event: serde_json::Value, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkEventHeader { + event_type: String, + #[allow(dead_code)] + event_id: String, +} + +#[derive(Debug, serde::Deserialize)] +struct MsgReceivePayload { + sender: LarkSender, + message: LarkMessage, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkSender { + sender_id: LarkSenderId, + #[serde(default)] + sender_type: String, +} + +#[derive(Debug, serde::Deserialize, Default)] +struct LarkSenderId { + open_id: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkMessage { + message_id: String, + chat_id: String, + chat_type: String, + message_type: String, + #[serde(default)] + content: String, + #[serde(default)] + mentions: Vec, +} + +/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s). +/// If no binary frame (pong or event) is received within this window, reconnect. +const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300); + +/// Lark/Feishu channel. +/// +/// Supports two receive modes (configured via `receive_mode` in config): +/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed. +/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint. pub struct LarkChannel { app_id: String, app_secret: String, verification_token: String, - port: u16, + port: Option, allowed_users: Vec, + /// When true, use Feishu (CN) endpoints; when false, use Lark (international). + use_feishu: bool, + /// How to receive events: WebSocket long-connection or HTTP webhook. + receive_mode: crate::config::schema::LarkReceiveMode, client: reqwest::Client, /// Cached tenant access token tenant_token: Arc>>, + /// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch + ws_seen_ids: Arc>>, } impl LarkChannel { @@ -23,7 +154,7 @@ impl LarkChannel { app_id: String, app_secret: String, verification_token: String, - port: u16, + port: Option, allowed_users: Vec, ) -> Self { Self { @@ -32,11 +163,295 @@ impl LarkChannel { verification_token, port, allowed_users, + use_feishu: true, + receive_mode: crate::config::schema::LarkReceiveMode::default(), client: reqwest::Client::new(), tenant_token: Arc::new(RwLock::new(None)), + ws_seen_ids: Arc::new(RwLock::new(HashMap::new())), } } + /// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`). + pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self { + let mut ch = Self::new( + config.app_id.clone(), + config.app_secret.clone(), + config.verification_token.clone().unwrap_or_default(), + config.port, + config.allowed_users.clone(), + ); + ch.use_feishu = config.use_feishu; + ch.receive_mode = config.receive_mode.clone(); + ch + } + + fn api_base(&self) -> &'static str { + if self.use_feishu { + FEISHU_BASE_URL + } else { + LARK_BASE_URL + } + } + + fn ws_base(&self) -> &'static str { + if self.use_feishu { + FEISHU_WS_BASE_URL + } else { + LARK_WS_BASE_URL + } + } + + /// POST /callback/ws/endpoint → (wss_url, client_config) + async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> { + let resp = self + .client + .post(format!("{}/callback/ws/endpoint", self.ws_base())) + .header("locale", if self.use_feishu { "zh" } else { "en" }) + .json(&serde_json::json!({ + "AppID": self.app_id, + "AppSecret": self.app_secret, + })) + .send() + .await? + .json::() + .await?; + if resp.code != 0 { + anyhow::bail!( + "Lark WS endpoint failed: code={} msg={}", + resp.code, + resp.msg.as_deref().unwrap_or("(none)") + ); + } + let ep = resp + .data + .ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?; + Ok((ep.url, ep.client_config.unwrap_or_default())) + } + + /// WS long-connection event loop. Returns Ok(()) when the connection closes + /// (the caller reconnects). + #[allow(clippy::too_many_lines)] + async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + let (wss_url, client_config) = self.get_ws_endpoint().await?; + let service_id = wss_url + .split('?') + .nth(1) + .and_then(|qs| { + qs.split('&') + .find(|kv| kv.starts_with("service_id=")) + .and_then(|kv| kv.split('=').nth(1)) + .and_then(|v| v.parse::().ok()) + }) + .unwrap_or(0); + tracing::info!("Lark: connecting to {wss_url}"); + + let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?; + let (mut write, mut read) = ws_stream.split(); + tracing::info!("Lark: WS connected (service_id={service_id})"); + + let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10); + let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs)); + let mut timeout_check = tokio::time::interval(Duration::from_secs(10)); + hb_interval.tick().await; // consume immediate tick + + let mut seq: u64 = 0; + let mut last_recv = Instant::now(); + + // Send initial ping immediately (like the official SDK) so the server + // starts responding with pongs and we can calibrate the ping_interval. + seq = seq.wrapping_add(1); + let initial_ping = PbFrame { + seq_id: seq, + log_id: 0, + service: service_id, + method: 0, + headers: vec![PbHeader { + key: "type".into(), + value: "ping".into(), + }], + payload: None, + }; + if write + .send(WsMsg::Binary(initial_ping.encode_to_vec())) + .await + .is_err() + { + anyhow::bail!("Lark: initial ping failed"); + } + // message_id → (fragment_slots, created_at) for multi-part reassembly + type FragEntry = (Vec>>, Instant); + let mut frag_cache: HashMap = HashMap::new(); + + loop { + tokio::select! { + biased; + + _ = hb_interval.tick() => { + seq = seq.wrapping_add(1); + let ping = PbFrame { + seq_id: seq, log_id: 0, service: service_id, method: 0, + headers: vec![PbHeader { key: "type".into(), value: "ping".into() }], + payload: None, + }; + if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() { + tracing::warn!("Lark: ping failed, reconnecting"); + break; + } + // GC stale fragments > 5 min + let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now()); + frag_cache.retain(|_, (_, ts)| *ts > cutoff); + } + + _ = timeout_check.tick() => { + if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT { + tracing::warn!("Lark: heartbeat timeout, reconnecting"); + break; + } + } + + msg = read.next() => { + let raw = match msg { + Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b } + Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; } + Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; } + Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; } + _ => continue, + }; + + let frame = match PbFrame::decode(&raw[..]) { + Ok(f) => f, + Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; } + }; + + // CONTROL frame + if frame.method == 0 { + if frame.header_value("type") == "pong" { + if let Some(p) = &frame.payload { + if let Ok(cfg) = serde_json::from_slice::(p) { + if let Some(secs) = cfg.ping_interval { + let secs = secs.max(10); + if secs != ping_secs { + ping_secs = secs; + hb_interval = tokio::time::interval(Duration::from_secs(ping_secs)); + tracing::info!("Lark: ping_interval → {ping_secs}s"); + } + } + } + } + } + continue; + } + + // DATA frame + let msg_type = frame.header_value("type").to_string(); + let msg_id = frame.header_value("message_id").to_string(); + let sum = frame.header_value("sum").parse::().unwrap_or(1); + let seq_num = frame.header_value("seq").parse::().unwrap_or(0); + + // ACK immediately (Feishu requires within 3 s) + { + let mut ack = frame.clone(); + ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec()); + ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() }); + let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await; + } + + // Fragment reassembly + let sum = if sum == 0 { 1 } else { sum }; + let payload: Vec = if sum == 1 || msg_id.is_empty() || seq_num >= sum { + frame.payload.clone().unwrap_or_default() + } else { + let entry = frag_cache.entry(msg_id.clone()) + .or_insert_with(|| (vec![None; sum], Instant::now())); + if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); } + entry.0[seq_num] = frame.payload.clone(); + if entry.0.iter().all(|s| s.is_some()) { + let full: Vec = entry.0.iter() + .flat_map(|s| s.as_deref().unwrap_or(&[])) + .copied().collect(); + frag_cache.remove(&msg_id); + full + } else { continue; } + }; + + if msg_type != "event" { continue; } + + let event: LarkEvent = match serde_json::from_slice(&payload) { + Ok(e) => e, + Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; } + }; + if event.header.event_type != "im.message.receive_v1" { continue; } + + let recv: MsgReceivePayload = match serde_json::from_value(event.event) { + Ok(r) => r, + Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; } + }; + + if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; } + + let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or(""); + if !self.is_user_allowed(sender_open_id) { + tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)"); + continue; + } + + let lark_msg = &recv.message; + + // Dedup + { + let now = Instant::now(); + let mut seen = self.ws_seen_ids.write().await; + // GC + seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60)); + if seen.contains_key(&lark_msg.message_id) { + tracing::debug!("Lark WS: dup {}", lark_msg.message_id); + continue; + } + seen.insert(lark_msg.message_id.clone(), now); + } + + // Decode content by type (mirrors clawdbot-feishu parsing) + let text = match lark_msg.message_type.as_str() { + "text" => { + let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) { + Ok(v) => v, + Err(_) => continue, + }; + v.get("text").and_then(|t| t.as_str()).unwrap_or("").to_string() + } + "post" => parse_post_content(&lark_msg.content), + _ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; } + }; + + // Strip @_user_N placeholders + let text = strip_at_placeholders(&text); + let text = text.trim().to_string(); + if text.is_empty() { continue; } + + // Group-chat: only respond when explicitly @-mentioned + if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) { + continue; + } + + let channel_msg = ChannelMessage { + id: Uuid::new_v4().to_string(), + sender: lark_msg.chat_id.clone(), + content: text, + channel: "lark".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + tracing::debug!("Lark WS: message in {}", lark_msg.chat_id); + if tx.send(channel_msg).await.is_err() { break; } + } + } + } + Ok(()) + } + /// Check if a user open_id is allowed fn is_user_allowed(&self, open_id: &str) -> bool { self.allowed_users.iter().any(|u| u == "*" || u == open_id) @@ -238,6 +653,25 @@ impl Channel for LarkChannel { } async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + use crate::config::schema::LarkReceiveMode; + match self.receive_mode { + LarkReceiveMode::Websocket => self.listen_ws(tx).await, + LarkReceiveMode::Webhook => self.listen_http(tx).await, + } + } + + async fn health_check(&self) -> bool { + self.get_tenant_access_token().await.is_ok() + } +} + +impl LarkChannel { + /// HTTP callback server (legacy — requires a public endpoint). + /// Use `listen()` (WS long-connection) for new deployments. + pub async fn listen_http( + &self, + tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { use axum::{extract::State, routing::post, Json, Router}; #[derive(Clone)] @@ -282,13 +716,17 @@ impl Channel for LarkChannel { (StatusCode::OK, "ok").into_response() } + let port = self.port.ok_or_else(|| { + anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]") + })?; + let state = AppState { verification_token: self.verification_token.clone(), channel: Arc::new(LarkChannel::new( self.app_id.clone(), self.app_secret.clone(), self.verification_token.clone(), - self.port, + None, self.allowed_users.clone(), )), tx, @@ -298,7 +736,7 @@ impl Channel for LarkChannel { .route("/lark", post(handle_event)) .with_state(state); - let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port)); + let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port)); tracing::info!("Lark event callback server listening on {addr}"); let listener = tokio::net::TcpListener::bind(addr).await?; @@ -306,10 +744,102 @@ impl Channel for LarkChannel { Ok(()) } +} - async fn health_check(&self) -> bool { - self.get_tenant_access_token().await.is_ok() +// ───────────────────────────────────────────────────────────────────────────── +// WS helper functions +// ───────────────────────────────────────────────────────────────────────────── + +/// Flatten a Feishu `post` rich-text message to plain text. +fn parse_post_content(content: &str) -> String { + let Ok(parsed) = serde_json::from_str::(content) else { + return "[富文本消息]".to_string(); + }; + let locale = parsed + .get("zh_cn") + .or_else(|| parsed.get("en_us")) + .or_else(|| { + parsed + .as_object() + .and_then(|m| m.values().find(|v| v.is_object())) + }); + let Some(locale) = locale else { + return "[富文本消息]".to_string(); + }; + let mut text = String::new(); + if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) { + for para in paragraphs { + if let Some(elements) = para.as_array() { + for el in elements { + match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") { + "text" => { + if let Some(t) = el.get("text").and_then(|t| t.as_str()) { + text.push_str(t); + } + } + "a" => { + text.push_str( + el.get("text") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()) + .or_else(|| el.get("href").and_then(|h| h.as_str())) + .unwrap_or(""), + ); + } + "at" => { + let n = el + .get("user_name") + .and_then(|n| n.as_str()) + .or_else(|| el.get("user_id").and_then(|i| i.as_str())) + .unwrap_or("user"); + text.push('@'); + text.push_str(n); + } + "img" => { + text.push_str("[图片]"); + } + _ => {} + } + } + text.push('\n'); + } + } } + let result = text.trim().to_string(); + if result.is_empty() { + "[富文本消息]".to_string() + } else { + result + } +} + +/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats. +fn strip_at_placeholders(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let mut chars = text.char_indices().peekable(); + while let Some((_, ch)) = chars.next() { + if ch == '@' { + let rest: String = chars.clone().map(|(_, c)| c).collect(); + if let Some(after) = rest.strip_prefix("_user_") { + let skip = + "_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count(); + for _ in 0..=skip { + chars.next(); + } + if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) { + chars.next(); + } + continue; + } + } + result.push(ch); + } + result +} + +/// In group chats, only respond when the bot is explicitly @-mentioned. +fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool { + !mentions.is_empty() } #[cfg(test)] @@ -321,7 +851,7 @@ mod tests { "cli_test_app_id".into(), "test_app_secret".into(), "test_verification_token".into(), - 9898, + None, vec!["ou_testuser123".into()], ) } @@ -345,7 +875,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); assert!(ch.is_user_allowed("ou_anyone")); @@ -353,7 +883,7 @@ mod tests { #[test] fn lark_user_denied_empty() { - let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]); + let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]); assert!(!ch.is_user_allowed("ou_anyone")); } @@ -426,7 +956,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -451,7 +981,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -488,7 +1018,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -512,7 +1042,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -550,7 +1080,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -571,7 +1101,7 @@ mod tests { #[test] fn lark_config_serde() { - use crate::config::schema::LarkConfig; + use crate::config::schema::{LarkConfig, LarkReceiveMode}; let lc = LarkConfig { app_id: "cli_app123".into(), app_secret: "secret456".into(), @@ -579,6 +1109,8 @@ mod tests { verification_token: Some("vtoken789".into()), allowed_users: vec!["ou_user1".into(), "ou_user2".into()], use_feishu: false, + receive_mode: LarkReceiveMode::default(), + port: None, }; let json = serde_json::to_string(&lc).unwrap(); let parsed: LarkConfig = serde_json::from_str(&json).unwrap(); @@ -590,7 +1122,7 @@ mod tests { #[test] fn lark_config_toml_roundtrip() { - use crate::config::schema::LarkConfig; + use crate::config::schema::{LarkConfig, LarkReceiveMode}; let lc = LarkConfig { app_id: "app".into(), app_secret: "secret".into(), @@ -598,6 +1130,8 @@ mod tests { verification_token: Some("tok".into()), allowed_users: vec!["*".into()], use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), }; let toml_str = toml::to_string(&lc).unwrap(); let parsed: LarkConfig = toml::from_str(&toml_str).unwrap(); @@ -622,7 +1156,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ diff --git a/src/channels/mod.rs b/src/channels/mod.rs index d46a998..813a2ba 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -694,7 +694,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> { lk.app_id.clone(), lk.app_secret.clone(), lk.verification_token.clone().unwrap_or_default(), - 9898, + lk.port, lk.allowed_users.clone(), )), )); @@ -963,13 +963,7 @@ pub async fn start_channels(config: Config) -> Result<()> { } if let Some(ref lk) = config.channels_config.lark { - channels.push(Arc::new(LarkChannel::new( - lk.app_id.clone(), - lk.app_secret.clone(), - lk.verification_token.clone().unwrap_or_default(), - 9898, - lk.allowed_users.clone(), - ))); + channels.push(Arc::new(LarkChannel::from_config(lk))); } if let Some(ref dt) = config.channels_config.dingtalk { diff --git a/src/config/mod.rs b/src/config/mod.rs index 4fec9ae..07b5c0b 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -39,7 +39,19 @@ mod tests { listen_to_bots: false, }; + let lark = LarkConfig { + app_id: "app-id".into(), + app_secret: "app-secret".into(), + encrypt_key: None, + verification_token: None, + allowed_users: vec![], + use_feishu: false, + receive_mode: crate::config::schema::LarkReceiveMode::Websocket, + port: None, + }; + assert_eq!(telegram.allowed_users.len(), 1); assert_eq!(discord.guild_id.as_deref(), Some("123")); + assert_eq!(lark.app_id, "app-id"); } } diff --git a/src/config/schema.rs b/src/config/schema.rs index d78e53f..40b4bcb 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -1397,8 +1397,20 @@ fn default_irc_port() -> u16 { 6697 } -/// Lark/Feishu configuration for messaging integration -/// Lark is the international version, Feishu is the Chinese version +/// How ZeroClaw receives events from Feishu / Lark. +/// +/// - `websocket` (default) — persistent WSS long-connection; no public URL required. +/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum LarkReceiveMode { + #[default] + Websocket, + Webhook, +} + +/// Lark/Feishu configuration for messaging integration. +/// Lark is the international version; Feishu is the Chinese version. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LarkConfig { /// App ID from Lark/Feishu developer console @@ -1417,6 +1429,13 @@ pub struct LarkConfig { /// Whether to use the Feishu (Chinese) endpoint instead of Lark (International) #[serde(default)] pub use_feishu: bool, + /// Event receive mode: "websocket" (default) or "webhook" + #[serde(default)] + pub receive_mode: LarkReceiveMode, + /// HTTP port for webhook mode only. Must be set when receive_mode = "webhook". + /// Not required (and ignored) for websocket mode. + #[serde(default)] + pub port: Option, } // ── Security Config ───────────────────────────────────────────────── @@ -3105,4 +3124,239 @@ default_model = "legacy-model" assert_eq!(parsed.boards[0].board, "nucleo-f401re"); assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0")); } + + #[test] + fn lark_config_serde() { + let lc = LarkConfig { + app_id: "cli_123456".into(), + app_secret: "secret_abc".into(), + encrypt_key: Some("encrypt_key".into()), + verification_token: Some("verify_token".into()), + allowed_users: vec!["user_123".into(), "user_456".into()], + use_feishu: true, + receive_mode: LarkReceiveMode::Websocket, + port: None, + }; + let json = serde_json::to_string(&lc).unwrap(); + let parsed: LarkConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.app_id, "cli_123456"); + assert_eq!(parsed.app_secret, "secret_abc"); + assert_eq!(parsed.encrypt_key.as_deref(), Some("encrypt_key")); + assert_eq!(parsed.verification_token.as_deref(), Some("verify_token")); + assert_eq!(parsed.allowed_users.len(), 2); + assert!(parsed.use_feishu); + } + + #[test] + fn lark_config_toml_roundtrip() { + let lc = LarkConfig { + app_id: "cli_123456".into(), + app_secret: "secret_abc".into(), + encrypt_key: Some("encrypt_key".into()), + verification_token: Some("verify_token".into()), + allowed_users: vec!["*".into()], + use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), + }; + let toml_str = toml::to_string(&lc).unwrap(); + let parsed: LarkConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.app_id, "cli_123456"); + assert_eq!(parsed.app_secret, "secret_abc"); + assert!(!parsed.use_feishu); + } + + #[test] + fn lark_config_deserializes_without_optional_fields() { + let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert!(parsed.encrypt_key.is_none()); + assert!(parsed.verification_token.is_none()); + assert!(parsed.allowed_users.is_empty()); + assert!(!parsed.use_feishu); + } + + #[test] + fn lark_config_defaults_to_lark_endpoint() { + let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert!( + !parsed.use_feishu, + "use_feishu should default to false (Lark)" + ); + } + + #[test] + fn lark_config_with_wildcard_allowed_users() { + let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert_eq!(parsed.allowed_users, vec!["*"]); + } + + // ══════════════════════════════════════════════════════════ + // AGENT DELEGATION CONFIG TESTS + // ══════════════════════════════════════════════════════════ + + #[test] + fn agents_config_default_empty() { + let c = Config::default(); + assert!(c.agents.is_empty()); + } + + #[test] + fn agents_config_backward_compat_missing_section() { + let minimal = r#" +workspace_dir = "/tmp/ws" +config_path = "/tmp/config.toml" +default_temperature = 0.7 +"#; + let parsed: Config = toml::from_str(minimal).unwrap(); + assert!(parsed.agents.is_empty()); + } + + #[test] + fn agents_config_toml_roundtrip() { + let toml_str = r#" +default_temperature = 0.7 + +[agents.researcher] +provider = "gemini" +model = "gemini-2.0-flash" +system_prompt = "You are a research assistant." +max_depth = 2 + +[agents.coder] +provider = "openrouter" +model = "anthropic/claude-sonnet-4-20250514" +"#; + let parsed: Config = toml::from_str(toml_str).unwrap(); + assert_eq!(parsed.agents.len(), 2); + + let researcher = &parsed.agents["researcher"]; + assert_eq!(researcher.provider, "gemini"); + assert_eq!(researcher.model, "gemini-2.0-flash"); + assert_eq!( + researcher.system_prompt.as_deref(), + Some("You are a research assistant.") + ); + assert_eq!(researcher.max_depth, 2); + assert!(researcher.api_key.is_none()); + assert!(researcher.temperature.is_none()); + + let coder = &parsed.agents["coder"]; + assert_eq!(coder.provider, "openrouter"); + assert_eq!(coder.model, "anthropic/claude-sonnet-4-20250514"); + assert!(coder.system_prompt.is_none()); + assert_eq!(coder.max_depth, 3); // default + } + + #[test] + fn agents_config_with_api_key_and_temperature() { + let toml_str = r#" +[agents.fast] +provider = "groq" +model = "llama-3.3-70b-versatile" +api_key = "gsk-test-key" +temperature = 0.3 +"#; + let parsed: HashMap = toml::from_str::(toml_str) + .unwrap()["agents"] + .clone() + .try_into() + .unwrap(); + let fast = &parsed["fast"]; + assert_eq!(fast.api_key.as_deref(), Some("gsk-test-key")); + assert!((fast.temperature.unwrap() - 0.3).abs() < f64::EPSILON); + } + + #[test] + fn agent_api_key_encrypted_on_save_and_decrypted_on_load() { + let tmp = TempDir::new().unwrap(); + let zeroclaw_dir = tmp.path(); + let config_path = zeroclaw_dir.join("config.toml"); + + // Create a config with a plaintext agent API key + let mut agents = HashMap::new(); + agents.insert( + "test_agent".to_string(), + DelegateAgentConfig { + provider: "openrouter".to_string(), + model: "test-model".to_string(), + system_prompt: None, + api_key: Some("sk-super-secret".to_string()), + temperature: None, + max_depth: 3, + }, + ); + let config = Config { + config_path: config_path.clone(), + workspace_dir: zeroclaw_dir.join("workspace"), + secrets: SecretsConfig { encrypt: true }, + agents, + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config.save().unwrap(); + + // Read the raw TOML and verify the key is encrypted (not plaintext) + let raw = std::fs::read_to_string(&config_path).unwrap(); + assert!( + !raw.contains("sk-super-secret"), + "Plaintext API key should not appear in saved config" + ); + assert!( + raw.contains("enc2:"), + "Encrypted key should use enc2: prefix" + ); + + // Parse and decrypt — simulate load_or_init by reading + decrypting + let store = crate::security::SecretStore::new(zeroclaw_dir, true); + let mut loaded: Config = toml::from_str(&raw).unwrap(); + for agent in loaded.agents.values_mut() { + if let Some(ref encrypted_key) = agent.api_key { + agent.api_key = Some(store.decrypt(encrypted_key).unwrap()); + } + } + assert_eq!( + loaded.agents["test_agent"].api_key.as_deref(), + Some("sk-super-secret"), + "Decrypted key should match original" + ); + } + + #[test] + fn agent_api_key_not_encrypted_when_disabled() { + let tmp = TempDir::new().unwrap(); + let zeroclaw_dir = tmp.path(); + let config_path = zeroclaw_dir.join("config.toml"); + + let mut agents = HashMap::new(); + agents.insert( + "test_agent".to_string(), + DelegateAgentConfig { + provider: "openrouter".to_string(), + model: "test-model".to_string(), + system_prompt: None, + api_key: Some("sk-plaintext-ok".to_string()), + temperature: None, + max_depth: 3, + }, + ); + let config = Config { + config_path: config_path.clone(), + workspace_dir: zeroclaw_dir.join("workspace"), + secrets: SecretsConfig { encrypt: false }, + agents, + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config.save().unwrap(); + + let raw = std::fs::read_to_string(&config_path).unwrap(); + assert!( + raw.contains("sk-plaintext-ok"), + "With encryption disabled, key should remain plaintext" + ); + assert!(!raw.contains("enc2:"), "No encryption prefix when disabled"); + } } diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index c2f4487..a223597 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool { || config.channels_config.matrix.is_some() || config.channels_config.whatsapp.is_some() || config.channels_config.email.is_some() + || config.channels_config.lark.is_some() } #[cfg(test)]