zeroclaw/src/channels/lark.rs

644 lines
20 KiB
Rust

use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API
pub struct LarkChannel {
app_id: String,
app_secret: String,
verification_token: String,
port: u16,
allowed_users: Vec<String>,
client: reqwest::Client,
/// Cached tenant access token
tenant_token: Arc<RwLock<Option<String>>>,
}
impl LarkChannel {
pub fn new(
app_id: String,
app_secret: String,
verification_token: String,
port: u16,
allowed_users: Vec<String>,
) -> Self {
Self {
app_id,
app_secret,
verification_token,
port,
allowed_users,
client: reqwest::Client::new(),
tenant_token: Arc::new(RwLock::new(None)),
}
}
/// Check if a user open_id is allowed
fn is_user_allowed(&self, open_id: &str) -> bool {
self.allowed_users.iter().any(|u| u == "*" || u == open_id)
}
/// Get or refresh tenant access token
async fn get_tenant_access_token(&self) -> anyhow::Result<String> {
// Check cache first
{
let cached = self.tenant_token.read().await;
if let Some(ref token) = *cached {
return Ok(token.clone());
}
}
let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal");
let body = serde_json::json!({
"app_id": self.app_id,
"app_secret": self.app_secret,
});
let resp = self.client.post(&url).json(&body).send().await?;
let data: serde_json::Value = resp.json().await?;
let code = data.get("code").and_then(|c| c.as_i64()).unwrap_or(-1);
if code != 0 {
let msg = data
.get("msg")
.and_then(|m| m.as_str())
.unwrap_or("unknown error");
anyhow::bail!("Lark tenant_access_token failed: {msg}");
}
let token = data
.get("tenant_access_token")
.and_then(|t| t.as_str())
.ok_or_else(|| anyhow::anyhow!("missing tenant_access_token in response"))?
.to_string();
// Cache it
{
let mut cached = self.tenant_token.write().await;
*cached = Some(token.clone());
}
Ok(token)
}
/// Invalidate cached token (called on 401)
async fn invalidate_token(&self) {
let mut cached = self.tenant_token.write().await;
*cached = None;
}
/// Parse an event callback payload and extract text messages
pub fn parse_event_payload(&self, payload: &serde_json::Value) -> Vec<ChannelMessage> {
let mut messages = Vec::new();
// Lark event v2 structure:
// { "header": { "event_type": "im.message.receive_v1" }, "event": { "message": { ... }, "sender": { ... } } }
let event_type = payload
.pointer("/header/event_type")
.and_then(|e| e.as_str())
.unwrap_or("");
if event_type != "im.message.receive_v1" {
return messages;
}
let event = match payload.get("event") {
Some(e) => e,
None => return messages,
};
// Extract sender open_id
let open_id = event
.pointer("/sender/sender_id/open_id")
.and_then(|s| s.as_str())
.unwrap_or("");
if open_id.is_empty() {
return messages;
}
// Check allowlist
if !self.is_user_allowed(open_id) {
tracing::warn!("Lark: ignoring message from unauthorized user: {open_id}");
return messages;
}
// Extract message content (text only)
let msg_type = event
.pointer("/message/message_type")
.and_then(|t| t.as_str())
.unwrap_or("");
if msg_type != "text" {
tracing::debug!("Lark: skipping non-text message type: {msg_type}");
return messages;
}
let content_str = event
.pointer("/message/content")
.and_then(|c| c.as_str())
.unwrap_or("");
// content is a JSON string like "{\"text\":\"hello\"}"
let text = serde_json::from_str::<serde_json::Value>(content_str)
.ok()
.and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from))
.unwrap_or_default();
if text.is_empty() {
return messages;
}
let timestamp = event
.pointer("/message/create_time")
.and_then(|t| t.as_str())
.and_then(|t| t.parse::<u64>().ok())
// Lark timestamps are in milliseconds
.map(|ms| ms / 1000)
.unwrap_or_else(|| {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
});
let chat_id = event
.pointer("/message/chat_id")
.and_then(|c| c.as_str())
.unwrap_or(open_id);
messages.push(ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: chat_id.to_string(),
content: text,
channel: "lark".to_string(),
timestamp,
});
messages
}
}
#[async_trait]
impl Channel for LarkChannel {
fn name(&self) -> &str {
"lark"
}
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
let token = self.get_tenant_access_token().await?;
let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id");
let content = serde_json::json!({ "text": message }).to_string();
let body = serde_json::json!({
"receive_id": recipient,
"msg_type": "text",
"content": content,
});
let resp = self
.client
.post(&url)
.header("Authorization", format!("Bearer {token}"))
.header("Content-Type", "application/json; charset=utf-8")
.json(&body)
.send()
.await?;
if resp.status().as_u16() == 401 {
// Token expired, invalidate and retry once
self.invalidate_token().await;
let new_token = self.get_tenant_access_token().await?;
let retry_resp = self
.client
.post(&url)
.header("Authorization", format!("Bearer {new_token}"))
.header("Content-Type", "application/json; charset=utf-8")
.json(&body)
.send()
.await?;
if !retry_resp.status().is_success() {
let err = retry_resp.text().await.unwrap_or_default();
anyhow::bail!("Lark send failed after token refresh: {err}");
}
return Ok(());
}
if !resp.status().is_success() {
let err = resp.text().await.unwrap_or_default();
anyhow::bail!("Lark send failed: {err}");
}
Ok(())
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
use axum::{extract::State, routing::post, Json, Router};
#[derive(Clone)]
struct AppState {
verification_token: String,
channel: Arc<LarkChannel>,
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
}
async fn handle_event(
State(state): State<AppState>,
Json(payload): Json<serde_json::Value>,
) -> axum::response::Response {
use axum::http::StatusCode;
use axum::response::IntoResponse;
// URL verification challenge
if let Some(challenge) = payload.get("challenge").and_then(|c| c.as_str()) {
// Verify token if present
let token_ok = payload
.get("token")
.and_then(|t| t.as_str())
.map_or(true, |t| t == state.verification_token);
if !token_ok {
return (StatusCode::FORBIDDEN, "invalid token").into_response();
}
let resp = serde_json::json!({ "challenge": challenge });
return (StatusCode::OK, Json(resp)).into_response();
}
// Parse event messages
let messages = state.channel.parse_event_payload(&payload);
for msg in messages {
if state.tx.send(msg).await.is_err() {
tracing::warn!("Lark: message channel closed");
break;
}
}
(StatusCode::OK, "ok").into_response()
}
let state = AppState {
verification_token: self.verification_token.clone(),
channel: Arc::new(LarkChannel::new(
self.app_id.clone(),
self.app_secret.clone(),
self.verification_token.clone(),
self.port,
self.allowed_users.clone(),
)),
tx,
};
let app = Router::new()
.route("/lark", post(handle_event))
.with_state(state);
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
tracing::info!("Lark event callback server listening on {addr}");
let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, app).await?;
Ok(())
}
async fn health_check(&self) -> bool {
self.get_tenant_access_token().await.is_ok()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_channel() -> LarkChannel {
LarkChannel::new(
"cli_test_app_id".into(),
"test_app_secret".into(),
"test_verification_token".into(),
9898,
vec!["ou_testuser123".into()],
)
}
#[test]
fn lark_channel_name() {
let ch = make_channel();
assert_eq!(ch.name(), "lark");
}
#[test]
fn lark_user_allowed_exact() {
let ch = make_channel();
assert!(ch.is_user_allowed("ou_testuser123"));
assert!(!ch.is_user_allowed("ou_other"));
}
#[test]
fn lark_user_allowed_wildcard() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
9898,
vec!["*".into()],
);
assert!(ch.is_user_allowed("ou_anyone"));
}
#[test]
fn lark_user_denied_empty() {
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]);
assert!(!ch.is_user_allowed("ou_anyone"));
}
#[test]
fn lark_parse_challenge() {
let ch = make_channel();
let payload = serde_json::json!({
"challenge": "abc123",
"token": "test_verification_token",
"type": "url_verification"
});
// Challenge payloads should not produce messages
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_valid_text_message() {
let ch = make_channel();
let payload = serde_json::json!({
"header": {
"event_type": "im.message.receive_v1"
},
"event": {
"sender": {
"sender_id": {
"open_id": "ou_testuser123"
}
},
"message": {
"message_type": "text",
"content": "{\"text\":\"Hello ZeroClaw!\"}",
"chat_id": "oc_chat123",
"create_time": "1699999999000"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].content, "Hello ZeroClaw!");
assert_eq!(msgs[0].sender, "oc_chat123");
assert_eq!(msgs[0].channel, "lark");
assert_eq!(msgs[0].timestamp, 1_699_999_999);
}
#[test]
fn lark_parse_unauthorized_user() {
let ch = make_channel();
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_unauthorized" } },
"message": {
"message_type": "text",
"content": "{\"text\":\"spam\"}",
"chat_id": "oc_chat",
"create_time": "1000"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_non_text_message_skipped() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
9898,
vec!["*".into()],
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_type": "image",
"content": "{}",
"chat_id": "oc_chat"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_empty_text_skipped() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
9898,
vec!["*".into()],
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_type": "text",
"content": "{\"text\":\"\"}",
"chat_id": "oc_chat"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_wrong_event_type() {
let ch = make_channel();
let payload = serde_json::json!({
"header": { "event_type": "im.chat.disbanded_v1" },
"event": {}
});
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_missing_sender() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
9898,
vec!["*".into()],
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"message": {
"message_type": "text",
"content": "{\"text\":\"hello\"}",
"chat_id": "oc_chat"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_unicode_message() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
9898,
vec!["*".into()],
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_type": "text",
"content": "{\"text\":\"你好世界 🌍\"}",
"chat_id": "oc_chat",
"create_time": "1000"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].content, "你好世界 🌍");
}
#[test]
fn lark_parse_missing_event() {
let ch = make_channel();
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" }
});
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_parse_invalid_content_json() {
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
9898,
vec!["*".into()],
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_type": "text",
"content": "not valid json",
"chat_id": "oc_chat"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert!(msgs.is_empty());
}
#[test]
fn lark_config_serde() {
use crate::config::schema::LarkConfig;
let lc = LarkConfig {
app_id: "cli_app123".into(),
app_secret: "secret456".into(),
encrypt_key: None,
verification_token: Some("vtoken789".into()),
allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
use_feishu: false,
};
let json = serde_json::to_string(&lc).unwrap();
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.app_id, "cli_app123");
assert_eq!(parsed.app_secret, "secret456");
assert_eq!(parsed.verification_token.as_deref(), Some("vtoken789"));
assert_eq!(parsed.allowed_users.len(), 2);
}
#[test]
fn lark_config_toml_roundtrip() {
use crate::config::schema::LarkConfig;
let lc = LarkConfig {
app_id: "app".into(),
app_secret: "secret".into(),
encrypt_key: None,
verification_token: Some("tok".into()),
allowed_users: vec!["*".into()],
use_feishu: false,
};
let toml_str = toml::to_string(&lc).unwrap();
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed.app_id, "app");
assert_eq!(parsed.verification_token.as_deref(), Some("tok"));
assert_eq!(parsed.allowed_users, vec!["*"]);
}
#[test]
fn lark_config_defaults_optional_fields() {
use crate::config::schema::LarkConfig;
let json = r#"{"app_id":"a","app_secret":"s"}"#;
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
assert!(parsed.verification_token.is_none());
assert!(parsed.allowed_users.is_empty());
}
#[test]
fn lark_parse_fallback_sender_to_open_id() {
// When chat_id is missing, sender should fall back to open_id
let ch = LarkChannel::new(
"id".into(),
"secret".into(),
"token".into(),
9898,
vec!["*".into()],
);
let payload = serde_json::json!({
"header": { "event_type": "im.message.receive_v1" },
"event": {
"sender": { "sender_id": { "open_id": "ou_user" } },
"message": {
"message_type": "text",
"content": "{\"text\":\"hello\"}",
"create_time": "1000"
}
}
});
let msgs = ch.parse_event_payload(&payload);
assert_eq!(msgs.len(), 1);
assert_eq!(msgs[0].sender, "ou_user");
}
}