Merge branch 'main' into pr-484-clean
This commit is contained in:
commit
ee05d62ce4
90 changed files with 6937 additions and 1403 deletions
|
|
@ -251,6 +251,7 @@ impl Agent {
|
|||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&model_name,
|
||||
|
|
@ -388,7 +389,7 @@ impl Agent {
|
|||
if self.auto_save {
|
||||
let _ = self
|
||||
.memory
|
||||
.store("user_msg", user_message, MemoryCategory::Conversation)
|
||||
.store("user_msg", user_message, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -447,7 +448,7 @@ impl Agent {
|
|||
let summary = truncate_with_ellipsis(&final_text, 100);
|
||||
let _ = self
|
||||
.memory
|
||||
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
||||
.store("assistant_resp", &summary, MemoryCategory::Daily, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -557,6 +558,7 @@ pub async fn run(
|
|||
agent.observer.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: start.elapsed(),
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
|
|
|
|||
|
|
@ -7,14 +7,70 @@ use crate::security::SecurityPolicy;
|
|||
use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use regex::{Regex, RegexSet};
|
||||
use std::fmt::Write;
|
||||
use std::io::Write as _;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||
const MAX_TOOL_ITERATIONS: usize = 10;
|
||||
|
||||
static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
|
||||
RegexSet::new([
|
||||
r"(?i)token",
|
||||
r"(?i)api[_-]?key",
|
||||
r"(?i)password",
|
||||
r"(?i)secret",
|
||||
r"(?i)user[_-]?key",
|
||||
r"(?i)bearer",
|
||||
r"(?i)credential",
|
||||
])
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
static SENSITIVE_KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap()
|
||||
});
|
||||
|
||||
/// Scrub credentials from tool output to prevent accidental exfiltration.
|
||||
/// Replaces known credential patterns with a redacted placeholder while preserving
|
||||
/// a small prefix for context.
|
||||
fn scrub_credentials(input: &str) -> String {
|
||||
SENSITIVE_KV_REGEX
|
||||
.replace_all(input, |caps: ®ex::Captures| {
|
||||
let full_match = &caps[0];
|
||||
let key = &caps[1];
|
||||
let val = caps
|
||||
.get(2)
|
||||
.or(caps.get(3))
|
||||
.or(caps.get(4))
|
||||
.map(|m| m.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Preserve first 4 chars for context, then redact
|
||||
let prefix = if val.len() > 4 { &val[..4] } else { "" };
|
||||
|
||||
if full_match.contains(':') {
|
||||
if full_match.contains('"') {
|
||||
format!("\"{}\": \"{}*[REDACTED]\"", key, prefix)
|
||||
} else {
|
||||
format!("{}: {}*[REDACTED]", key, prefix)
|
||||
}
|
||||
} else if full_match.contains('=') {
|
||||
if full_match.contains('"') {
|
||||
format!("{}=\"{}*[REDACTED]\"", key, prefix)
|
||||
} else {
|
||||
format!("{}={}*[REDACTED]", key, prefix)
|
||||
}
|
||||
} else {
|
||||
format!("{}: {}*[REDACTED]", key, prefix)
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Trigger auto-compaction when non-system message count exceeds this threshold.
|
||||
const MAX_HISTORY_MESSAGES: usize = 50;
|
||||
|
||||
|
|
@ -145,7 +201,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
|||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if !entries.is_empty() {
|
||||
context.push_str("[Memory context]\n");
|
||||
for entry in &entries {
|
||||
|
|
@ -436,6 +492,7 @@ struct ParsedToolCall {
|
|||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
/// When `silent` is true, suppresses stdout (for channel use).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn agent_turn(
|
||||
provider: &dyn Provider,
|
||||
history: &mut Vec<ChatMessage>,
|
||||
|
|
@ -461,6 +518,7 @@ pub(crate) async fn agent_turn(
|
|||
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn run_tool_call_loop(
|
||||
provider: &dyn Provider,
|
||||
history: &mut Vec<ChatMessage>,
|
||||
|
|
@ -606,7 +664,7 @@ pub(crate) async fn run_tool_call_loop(
|
|||
success: r.success,
|
||||
});
|
||||
if r.success {
|
||||
r.output
|
||||
scrub_credentials(&r.output)
|
||||
} else {
|
||||
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
|
||||
}
|
||||
|
|
@ -749,6 +807,7 @@ pub async fn run(
|
|||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
model_name,
|
||||
|
|
@ -912,7 +971,7 @@ pub async fn run(
|
|||
if config.memory.auto_save {
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(&user_key, &msg, MemoryCategory::Conversation)
|
||||
.store(&user_key, &msg, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -955,7 +1014,7 @@ pub async fn run(
|
|||
let summary = truncate_with_ellipsis(&response, 100);
|
||||
let response_key = autosave_memory_key("assistant_resp");
|
||||
let _ = mem
|
||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
||||
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||
.await;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -978,7 +1037,7 @@ pub async fn run(
|
|||
if config.memory.auto_save {
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(&user_key, &msg.content, MemoryCategory::Conversation)
|
||||
.store(&user_key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -1036,7 +1095,7 @@ pub async fn run(
|
|||
let summary = truncate_with_ellipsis(&response, 100);
|
||||
let response_key = autosave_memory_key("assistant_resp");
|
||||
let _ = mem
|
||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
||||
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
|
@ -1048,6 +1107,7 @@ pub async fn run(
|
|||
observer.record_event(&ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
|
||||
Ok(final_output)
|
||||
|
|
@ -1104,6 +1164,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&model_name,
|
||||
|
|
@ -1217,6 +1278,25 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_scrub_credentials() {
|
||||
let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
|
||||
let scrubbed = scrub_credentials(input);
|
||||
assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
|
||||
assert!(scrubbed.contains("token: 1234*[REDACTED]"));
|
||||
assert!(scrubbed.contains("password=\"secr*[REDACTED]\""));
|
||||
assert!(!scrubbed.contains("abcdef"));
|
||||
assert!(!scrubbed.contains("secret123456"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scrub_credentials_json() {
|
||||
let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
|
||||
let scrubbed = scrub_credentials(input);
|
||||
assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
|
||||
assert!(scrubbed.contains("public"));
|
||||
}
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use tempfile::TempDir;
|
||||
|
||||
|
|
@ -1496,16 +1576,16 @@ I will now call the tool with this payload:
|
|||
let key1 = autosave_memory_key("user_msg");
|
||||
let key2 = autosave_memory_key("user_msg");
|
||||
|
||||
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation)
|
||||
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(&key2, "I'm 45", MemoryCategory::Conversation)
|
||||
mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader {
|
|||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit).await?;
|
||||
let entries = memory.recall(user_message, self.limit, None).await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
|
@ -61,11 +61,17 @@ mod tests {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if limit == 0 {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
|
@ -87,6 +93,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ impl Channel for CliChannel {
|
|||
let msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: "user".to_string(),
|
||||
reply_target: "user".to_string(),
|
||||
content: line,
|
||||
channel: "cli".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -90,12 +91,14 @@ mod tests {
|
|||
let msg = ChannelMessage {
|
||||
id: "test-id".into(),
|
||||
sender: "user".into(),
|
||||
reply_target: "user".into(),
|
||||
content: "hello".into(),
|
||||
channel: "cli".into(),
|
||||
timestamp: 1_234_567_890,
|
||||
};
|
||||
assert_eq!(msg.id, "test-id");
|
||||
assert_eq!(msg.sender, "user");
|
||||
assert_eq!(msg.reply_target, "user");
|
||||
assert_eq!(msg.content, "hello");
|
||||
assert_eq!(msg.channel, "cli");
|
||||
assert_eq!(msg.timestamp, 1_234_567_890);
|
||||
|
|
@ -106,6 +109,7 @@ mod tests {
|
|||
let msg = ChannelMessage {
|
||||
id: "id".into(),
|
||||
sender: "s".into(),
|
||||
reply_target: "s".into(),
|
||||
content: "c".into(),
|
||||
channel: "ch".into(),
|
||||
timestamp: 0,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use tokio::sync::RwLock;
|
|||
use tokio_tungstenite::tungstenite::Message;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// DingTalk (钉钉) channel — connects via Stream Mode WebSocket for real-time messages.
|
||||
/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages.
|
||||
/// Replies are sent through per-message session webhook URLs.
|
||||
pub struct DingTalkChannel {
|
||||
client_id: String,
|
||||
|
|
@ -64,6 +64,18 @@ impl DingTalkChannel {
|
|||
let gw: GatewayResponse = resp.json().await?;
|
||||
Ok(gw)
|
||||
}
|
||||
|
||||
fn resolve_reply_target(
|
||||
sender_id: &str,
|
||||
conversation_type: &str,
|
||||
conversation_id: Option<&str>,
|
||||
) -> String {
|
||||
if conversation_type == "1" {
|
||||
sender_id.to_string()
|
||||
} else {
|
||||
conversation_id.unwrap_or(sender_id).to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -193,14 +205,11 @@ impl Channel for DingTalkChannel {
|
|||
.unwrap_or("1");
|
||||
|
||||
// Private chat uses sender ID, group chat uses conversation ID
|
||||
let chat_id = if conversation_type == "1" {
|
||||
sender_id.to_string()
|
||||
} else {
|
||||
data.get("conversationId")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or(sender_id)
|
||||
.to_string()
|
||||
};
|
||||
let chat_id = Self::resolve_reply_target(
|
||||
sender_id,
|
||||
conversation_type,
|
||||
data.get("conversationId").and_then(|c| c.as_str()),
|
||||
);
|
||||
|
||||
// Store session webhook for later replies
|
||||
if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
|
||||
|
|
@ -229,6 +238,7 @@ impl Channel for DingTalkChannel {
|
|||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: sender_id.to_string(),
|
||||
reply_target: chat_id,
|
||||
content: content.to_string(),
|
||||
channel: "dingtalk".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -305,4 +315,22 @@ client_secret = "secret"
|
|||
let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(config.allowed_users.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_reply_target_private_chat_uses_sender_id() {
|
||||
let target = DingTalkChannel::resolve_reply_target("staff_1", "1", Some("conv_1"));
|
||||
assert_eq!(target, "staff_1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_reply_target_group_chat_uses_conversation_id() {
|
||||
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", Some("conv_1"));
|
||||
assert_eq!(target, "conv_1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_reply_target_group_chat_falls_back_to_sender_id() {
|
||||
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", None);
|
||||
assert_eq!(target, "staff_1");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ pub struct DiscordChannel {
|
|||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
listen_to_bots: bool,
|
||||
mention_only: bool,
|
||||
client: reqwest::Client,
|
||||
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
|
|
@ -21,12 +22,14 @@ impl DiscordChannel {
|
|||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
listen_to_bots: bool,
|
||||
mention_only: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
allowed_users,
|
||||
listen_to_bots,
|
||||
mention_only,
|
||||
client: reqwest::Client::new(),
|
||||
typing_handle: std::sync::Mutex::new(None),
|
||||
}
|
||||
|
|
@ -343,6 +346,22 @@ impl Channel for DiscordChannel {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Skip messages that don't @-mention the bot (when mention_only is enabled)
|
||||
if self.mention_only {
|
||||
let mention_tag = format!("<@{bot_user_id}>");
|
||||
if !content.contains(&mention_tag) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Strip the bot mention from content so the agent sees clean text
|
||||
let clean_content = if self.mention_only {
|
||||
let mention_tag = format!("<@{bot_user_id}>");
|
||||
content.replace(&mention_tag, "").trim().to_string()
|
||||
} else {
|
||||
content.to_string()
|
||||
};
|
||||
|
||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
||||
|
||||
|
|
@ -353,6 +372,11 @@ impl Channel for DiscordChannel {
|
|||
format!("discord_{message_id}")
|
||||
},
|
||||
sender: author_id.to_string(),
|
||||
reply_target: if channel_id.is_empty() {
|
||||
author_id.to_string()
|
||||
} else {
|
||||
channel_id
|
||||
},
|
||||
content: content.to_string(),
|
||||
channel: channel_id,
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -423,7 +447,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn discord_channel_name() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
assert_eq!(ch.name(), "discord");
|
||||
}
|
||||
|
||||
|
|
@ -444,21 +468,27 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn empty_allowlist_denies_everyone() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
assert!(!ch.is_user_allowed("12345"));
|
||||
assert!(!ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_allows_everyone() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false);
|
||||
assert!(ch.is_user_allowed("12345"));
|
||||
assert!(ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_allowlist_filters() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false);
|
||||
let ch = DiscordChannel::new(
|
||||
"fake".into(),
|
||||
None,
|
||||
vec!["111".into(), "222".into()],
|
||||
false,
|
||||
false,
|
||||
);
|
||||
assert!(ch.is_user_allowed("111"));
|
||||
assert!(ch.is_user_allowed("222"));
|
||||
assert!(!ch.is_user_allowed("333"));
|
||||
|
|
@ -467,7 +497,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn allowlist_is_exact_match_not_substring() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||
assert!(!ch.is_user_allowed("1111"));
|
||||
assert!(!ch.is_user_allowed("11"));
|
||||
assert!(!ch.is_user_allowed("0111"));
|
||||
|
|
@ -475,20 +505,26 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn allowlist_empty_string_user_id() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||
assert!(!ch.is_user_allowed(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlist_with_wildcard_and_specific() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false);
|
||||
let ch = DiscordChannel::new(
|
||||
"fake".into(),
|
||||
None,
|
||||
vec!["111".into(), "*".into()],
|
||||
false,
|
||||
false,
|
||||
);
|
||||
assert!(ch.is_user_allowed("111"));
|
||||
assert!(ch.is_user_allowed("anyone_else"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlist_case_sensitive() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false);
|
||||
assert!(ch.is_user_allowed("ABC"));
|
||||
assert!(!ch.is_user_allowed("abc"));
|
||||
assert!(!ch.is_user_allowed("Abc"));
|
||||
|
|
@ -663,14 +699,14 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn typing_handle_starts_as_none() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
assert!(guard.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_typing_sets_handle() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let _ = ch.start_typing("123456").await;
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
assert!(guard.is_some());
|
||||
|
|
@ -678,7 +714,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn stop_typing_clears_handle() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let _ = ch.start_typing("123456").await;
|
||||
let _ = ch.stop_typing("123456").await;
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
|
|
@ -687,14 +723,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn stop_typing_is_idempotent() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
assert!(ch.stop_typing("123456").await.is_ok());
|
||||
assert!(ch.stop_typing("123456").await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_typing_replaces_existing_task() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let _ = ch.start_typing("111").await;
|
||||
let _ = ch.start_typing("222").await;
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use lettre::message::SinglePart;
|
||||
use lettre::transport::smtp::authentication::Credentials;
|
||||
use lettre::{Message, SmtpTransport, Transport};
|
||||
use mail_parser::{MessageParser, MimeHeaders};
|
||||
|
|
@ -39,7 +40,7 @@ pub struct EmailConfig {
|
|||
pub imap_folder: String,
|
||||
/// SMTP server hostname
|
||||
pub smtp_host: String,
|
||||
/// SMTP server port (default: 587 for STARTTLS)
|
||||
/// SMTP server port (default: 465 for TLS)
|
||||
#[serde(default = "default_smtp_port")]
|
||||
pub smtp_port: u16,
|
||||
/// Use TLS for SMTP (default: true)
|
||||
|
|
@ -63,7 +64,7 @@ fn default_imap_port() -> u16 {
|
|||
993
|
||||
}
|
||||
fn default_smtp_port() -> u16 {
|
||||
587
|
||||
465
|
||||
}
|
||||
fn default_imap_folder() -> String {
|
||||
"INBOX".into()
|
||||
|
|
@ -389,7 +390,7 @@ impl Channel for EmailChannel {
|
|||
.from(self.config.from_address.parse()?)
|
||||
.to(recipient.parse()?)
|
||||
.subject(subject)
|
||||
.body(body.to_string())?;
|
||||
.singlepart(SinglePart::plain(body.to_string()))?;
|
||||
|
||||
let transport = self.create_smtp_transport()?;
|
||||
transport.send(&email)?;
|
||||
|
|
@ -427,6 +428,7 @@ impl Channel for EmailChannel {
|
|||
} // MutexGuard dropped before await
|
||||
let msg = ChannelMessage {
|
||||
id,
|
||||
reply_target: sender.clone(),
|
||||
sender,
|
||||
content,
|
||||
channel: "email".to_string(),
|
||||
|
|
@ -464,6 +466,18 @@ impl Channel for EmailChannel {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_smtp_port_uses_tls_port() {
|
||||
assert_eq!(default_smtp_port(), 465);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn email_config_default_uses_tls_smtp_defaults() {
|
||||
let config = EmailConfig::default();
|
||||
assert_eq!(config.smtp_port, 465);
|
||||
assert!(config.smtp_tls);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_imap_tls_config_succeeds() {
|
||||
let tls_config =
|
||||
|
|
@ -504,7 +518,7 @@ mod tests {
|
|||
assert_eq!(config.imap_port, 993);
|
||||
assert_eq!(config.imap_folder, "INBOX");
|
||||
assert_eq!(config.smtp_host, "");
|
||||
assert_eq!(config.smtp_port, 587);
|
||||
assert_eq!(config.smtp_port, 465);
|
||||
assert!(config.smtp_tls);
|
||||
assert_eq!(config.username, "");
|
||||
assert_eq!(config.password, "");
|
||||
|
|
@ -765,8 +779,8 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn default_smtp_port_returns_587() {
|
||||
assert_eq!(default_smtp_port(), 587);
|
||||
fn default_smtp_port_returns_465() {
|
||||
assert_eq!(default_smtp_port(), 465);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -822,7 +836,7 @@ mod tests {
|
|||
|
||||
let config: EmailConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(config.imap_port, 993); // default
|
||||
assert_eq!(config.smtp_port, 587); // default
|
||||
assert_eq!(config.smtp_port, 465); // default
|
||||
assert!(config.smtp_tls); // default
|
||||
assert_eq!(config.poll_interval_secs, 60); // default
|
||||
}
|
||||
|
|
|
|||
|
|
@ -172,6 +172,7 @@ end tell"#
|
|||
let msg = ChannelMessage {
|
||||
id: rowid.to_string(),
|
||||
sender: sender.clone(),
|
||||
reply_target: sender.clone(),
|
||||
content: text,
|
||||
channel: "imessage".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
|
|||
|
|
@ -220,32 +220,34 @@ fn split_message(message: &str, max_bytes: usize) -> Vec<String> {
|
|||
chunks
|
||||
}
|
||||
|
||||
/// Configuration for constructing an `IrcChannel`.
|
||||
pub struct IrcChannelConfig {
|
||||
pub server: String,
|
||||
pub port: u16,
|
||||
pub nickname: String,
|
||||
pub username: Option<String>,
|
||||
pub channels: Vec<String>,
|
||||
pub allowed_users: Vec<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub nickserv_password: Option<String>,
|
||||
pub sasl_password: Option<String>,
|
||||
pub verify_tls: bool,
|
||||
}
|
||||
|
||||
impl IrcChannel {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
server: String,
|
||||
port: u16,
|
||||
nickname: String,
|
||||
username: Option<String>,
|
||||
channels: Vec<String>,
|
||||
allowed_users: Vec<String>,
|
||||
server_password: Option<String>,
|
||||
nickserv_password: Option<String>,
|
||||
sasl_password: Option<String>,
|
||||
verify_tls: bool,
|
||||
) -> Self {
|
||||
let username = username.unwrap_or_else(|| nickname.clone());
|
||||
pub fn new(cfg: IrcChannelConfig) -> Self {
|
||||
let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone());
|
||||
Self {
|
||||
server,
|
||||
port,
|
||||
nickname,
|
||||
server: cfg.server,
|
||||
port: cfg.port,
|
||||
nickname: cfg.nickname,
|
||||
username,
|
||||
channels,
|
||||
allowed_users,
|
||||
server_password,
|
||||
nickserv_password,
|
||||
sasl_password,
|
||||
verify_tls,
|
||||
channels: cfg.channels,
|
||||
allowed_users: cfg.allowed_users,
|
||||
server_password: cfg.server_password,
|
||||
nickserv_password: cfg.nickserv_password,
|
||||
sasl_password: cfg.sasl_password,
|
||||
verify_tls: cfg.verify_tls,
|
||||
writer: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
|
@ -563,7 +565,8 @@ impl Channel for IrcChannel {
|
|||
let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed);
|
||||
let channel_msg = ChannelMessage {
|
||||
id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()),
|
||||
sender: reply_to,
|
||||
sender: sender_nick.to_string(),
|
||||
reply_target: reply_to,
|
||||
content,
|
||||
channel: "irc".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -807,18 +810,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn specific_user_allowed() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"bot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec!["alice".into(), "bob".into()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "bot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec!["alice".into(), "bob".into()],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert!(ch.is_user_allowed("alice"));
|
||||
assert!(ch.is_user_allowed("bob"));
|
||||
assert!(!ch.is_user_allowed("eve"));
|
||||
|
|
@ -826,18 +829,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn allowlist_case_insensitive() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"bot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec!["Alice".into()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "bot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec!["Alice".into()],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert!(ch.is_user_allowed("alice"));
|
||||
assert!(ch.is_user_allowed("ALICE"));
|
||||
assert!(ch.is_user_allowed("Alice"));
|
||||
|
|
@ -845,18 +848,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn empty_allowlist_denies_all() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"bot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "bot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec![],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert!(!ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
|
|
@ -864,35 +867,35 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn new_defaults_username_to_nickname() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"mybot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "mybot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec![],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert_eq!(ch.username, "mybot");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_uses_explicit_username() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"mybot".into(),
|
||||
Some("customuser".into()),
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "mybot".into(),
|
||||
username: Some("customuser".into()),
|
||||
channels: vec![],
|
||||
allowed_users: vec![],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert_eq!(ch.username, "customuser");
|
||||
assert_eq!(ch.nickname, "mybot");
|
||||
}
|
||||
|
|
@ -905,18 +908,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn new_stores_all_fields() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.example.com".into(),
|
||||
6697,
|
||||
"zcbot".into(),
|
||||
Some("zeroclaw".into()),
|
||||
vec!["#test".into()],
|
||||
vec!["alice".into()],
|
||||
Some("serverpass".into()),
|
||||
Some("nspass".into()),
|
||||
Some("saslpass".into()),
|
||||
false,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.example.com".into(),
|
||||
port: 6697,
|
||||
nickname: "zcbot".into(),
|
||||
username: Some("zeroclaw".into()),
|
||||
channels: vec!["#test".into()],
|
||||
allowed_users: vec!["alice".into()],
|
||||
server_password: Some("serverpass".into()),
|
||||
nickserv_password: Some("nspass".into()),
|
||||
sasl_password: Some("saslpass".into()),
|
||||
verify_tls: false,
|
||||
});
|
||||
assert_eq!(ch.server, "irc.example.com");
|
||||
assert_eq!(ch.port, 6697);
|
||||
assert_eq!(ch.nickname, "zcbot");
|
||||
|
|
@ -995,17 +998,17 @@ nickname = "bot"
|
|||
// ── Helpers ─────────────────────────────────────────────
|
||||
|
||||
fn make_channel() -> IrcChannel {
|
||||
IrcChannel::new(
|
||||
"irc.example.com".into(),
|
||||
6697,
|
||||
"zcbot".into(),
|
||||
None,
|
||||
vec!["#zeroclaw".into()],
|
||||
vec!["*".into()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)
|
||||
IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.example.com".into(),
|
||||
port: 6697,
|
||||
nickname: "zcbot".into(),
|
||||
username: None,
|
||||
channels: vec!["#zeroclaw".into()],
|
||||
allowed_users: vec!["*".into()],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,21 +1,152 @@
|
|||
use super::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_tungstenite::tungstenite::Message as WsMsg;
|
||||
use uuid::Uuid;
|
||||
|
||||
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
|
||||
const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn";
|
||||
const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis";
|
||||
const LARK_WS_BASE_URL: &str = "https://open.larksuite.com";
|
||||
|
||||
/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Feishu WebSocket long-connection: pbbp2.proto frame codec
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Clone, PartialEq, prost::Message)]
|
||||
struct PbHeader {
|
||||
#[prost(string, tag = "1")]
|
||||
pub key: String,
|
||||
#[prost(string, tag = "2")]
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
/// Feishu WS frame (pbbp2.proto).
|
||||
/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
|
||||
#[derive(Clone, PartialEq, prost::Message)]
|
||||
struct PbFrame {
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub seq_id: u64,
|
||||
#[prost(uint64, tag = "2")]
|
||||
pub log_id: u64,
|
||||
#[prost(int32, tag = "3")]
|
||||
pub service: i32,
|
||||
#[prost(int32, tag = "4")]
|
||||
pub method: i32,
|
||||
#[prost(message, repeated, tag = "5")]
|
||||
pub headers: Vec<PbHeader>,
|
||||
#[prost(bytes = "vec", optional, tag = "8")]
|
||||
pub payload: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl PbFrame {
|
||||
fn header_value<'a>(&'a self, key: &str) -> &'a str {
|
||||
self.headers
|
||||
.iter()
|
||||
.find(|h| h.key == key)
|
||||
.map(|h| h.value.as_str())
|
||||
.unwrap_or("")
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-sent client config (parsed from pong payload)
|
||||
#[derive(Debug, serde::Deserialize, Default, Clone)]
|
||||
struct WsClientConfig {
|
||||
#[serde(rename = "PingInterval")]
|
||||
ping_interval: Option<u64>,
|
||||
}
|
||||
|
||||
/// POST /callback/ws/endpoint response
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct WsEndpointResp {
|
||||
code: i32,
|
||||
#[serde(default)]
|
||||
msg: Option<String>,
|
||||
#[serde(default)]
|
||||
data: Option<WsEndpoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct WsEndpoint {
|
||||
#[serde(rename = "URL")]
|
||||
url: String,
|
||||
#[serde(rename = "ClientConfig")]
|
||||
client_config: Option<WsClientConfig>,
|
||||
}
|
||||
|
||||
/// LarkEvent envelope (method=1 / type=event payload)
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkEvent {
|
||||
header: LarkEventHeader,
|
||||
event: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkEventHeader {
|
||||
event_type: String,
|
||||
#[allow(dead_code)]
|
||||
event_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct MsgReceivePayload {
|
||||
sender: LarkSender,
|
||||
message: LarkMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkSender {
|
||||
sender_id: LarkSenderId,
|
||||
#[serde(default)]
|
||||
sender_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Default)]
|
||||
struct LarkSenderId {
|
||||
open_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkMessage {
|
||||
message_id: String,
|
||||
chat_id: String,
|
||||
chat_type: String,
|
||||
message_type: String,
|
||||
#[serde(default)]
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
mentions: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
|
||||
/// If no binary frame (pong or event) is received within this window, reconnect.
|
||||
const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
|
||||
|
||||
/// Lark/Feishu channel.
|
||||
///
|
||||
/// Supports two receive modes (configured via `receive_mode` in config):
|
||||
/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed.
|
||||
/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint.
|
||||
pub struct LarkChannel {
|
||||
app_id: String,
|
||||
app_secret: String,
|
||||
verification_token: String,
|
||||
port: u16,
|
||||
port: Option<u16>,
|
||||
allowed_users: Vec<String>,
|
||||
/// When true, use Feishu (CN) endpoints; when false, use Lark (international).
|
||||
use_feishu: bool,
|
||||
/// How to receive events: WebSocket long-connection or HTTP webhook.
|
||||
receive_mode: crate::config::schema::LarkReceiveMode,
|
||||
client: reqwest::Client,
|
||||
/// Cached tenant access token
|
||||
tenant_token: Arc<RwLock<Option<String>>>,
|
||||
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
|
||||
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||
}
|
||||
|
||||
impl LarkChannel {
|
||||
|
|
@ -23,7 +154,7 @@ impl LarkChannel {
|
|||
app_id: String,
|
||||
app_secret: String,
|
||||
verification_token: String,
|
||||
port: u16,
|
||||
port: Option<u16>,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
|
@ -32,11 +163,310 @@ impl LarkChannel {
|
|||
verification_token,
|
||||
port,
|
||||
allowed_users,
|
||||
use_feishu: true,
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::default(),
|
||||
client: reqwest::Client::new(),
|
||||
tenant_token: Arc::new(RwLock::new(None)),
|
||||
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`).
|
||||
pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self {
|
||||
let mut ch = Self::new(
|
||||
config.app_id.clone(),
|
||||
config.app_secret.clone(),
|
||||
config.verification_token.clone().unwrap_or_default(),
|
||||
config.port,
|
||||
config.allowed_users.clone(),
|
||||
);
|
||||
ch.use_feishu = config.use_feishu;
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
fn api_base(&self) -> &'static str {
|
||||
if self.use_feishu {
|
||||
FEISHU_BASE_URL
|
||||
} else {
|
||||
LARK_BASE_URL
|
||||
}
|
||||
}
|
||||
|
||||
fn ws_base(&self) -> &'static str {
|
||||
if self.use_feishu {
|
||||
FEISHU_WS_BASE_URL
|
||||
} else {
|
||||
LARK_WS_BASE_URL
|
||||
}
|
||||
}
|
||||
|
||||
fn tenant_access_token_url(&self) -> String {
|
||||
format!("{}/auth/v3/tenant_access_token/internal", self.api_base())
|
||||
}
|
||||
|
||||
fn send_message_url(&self) -> String {
|
||||
format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base())
|
||||
}
|
||||
|
||||
/// POST /callback/ws/endpoint → (wss_url, client_config)
|
||||
async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> {
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/callback/ws/endpoint", self.ws_base()))
|
||||
.header("locale", if self.use_feishu { "zh" } else { "en" })
|
||||
.json(&serde_json::json!({
|
||||
"AppID": self.app_id,
|
||||
"AppSecret": self.app_secret,
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.json::<WsEndpointResp>()
|
||||
.await?;
|
||||
if resp.code != 0 {
|
||||
anyhow::bail!(
|
||||
"Lark WS endpoint failed: code={} msg={}",
|
||||
resp.code,
|
||||
resp.msg.as_deref().unwrap_or("(none)")
|
||||
);
|
||||
}
|
||||
let ep = resp
|
||||
.data
|
||||
.ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?;
|
||||
Ok((ep.url, ep.client_config.unwrap_or_default()))
|
||||
}
|
||||
|
||||
/// WS long-connection event loop. Returns Ok(()) when the connection closes
|
||||
/// (the caller reconnects).
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let (wss_url, client_config) = self.get_ws_endpoint().await?;
|
||||
let service_id = wss_url
|
||||
.split('?')
|
||||
.nth(1)
|
||||
.and_then(|qs| {
|
||||
qs.split('&')
|
||||
.find(|kv| kv.starts_with("service_id="))
|
||||
.and_then(|kv| kv.split('=').nth(1))
|
||||
.and_then(|v| v.parse::<i32>().ok())
|
||||
})
|
||||
.unwrap_or(0);
|
||||
tracing::info!("Lark: connecting to {wss_url}");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
tracing::info!("Lark: WS connected (service_id={service_id})");
|
||||
|
||||
let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10);
|
||||
let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
|
||||
let mut timeout_check = tokio::time::interval(Duration::from_secs(10));
|
||||
hb_interval.tick().await; // consume immediate tick
|
||||
|
||||
let mut seq: u64 = 0;
|
||||
let mut last_recv = Instant::now();
|
||||
|
||||
// Send initial ping immediately (like the official SDK) so the server
|
||||
// starts responding with pongs and we can calibrate the ping_interval.
|
||||
seq = seq.wrapping_add(1);
|
||||
let initial_ping = PbFrame {
|
||||
seq_id: seq,
|
||||
log_id: 0,
|
||||
service: service_id,
|
||||
method: 0,
|
||||
headers: vec![PbHeader {
|
||||
key: "type".into(),
|
||||
value: "ping".into(),
|
||||
}],
|
||||
payload: None,
|
||||
};
|
||||
if write
|
||||
.send(WsMsg::Binary(initial_ping.encode_to_vec()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
anyhow::bail!("Lark: initial ping failed");
|
||||
}
|
||||
// message_id → (fragment_slots, created_at) for multi-part reassembly
|
||||
type FragEntry = (Vec<Option<Vec<u8>>>, Instant);
|
||||
let mut frag_cache: HashMap<String, FragEntry> = HashMap::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
_ = hb_interval.tick() => {
|
||||
seq = seq.wrapping_add(1);
|
||||
let ping = PbFrame {
|
||||
seq_id: seq, log_id: 0, service: service_id, method: 0,
|
||||
headers: vec![PbHeader { key: "type".into(), value: "ping".into() }],
|
||||
payload: None,
|
||||
};
|
||||
if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() {
|
||||
tracing::warn!("Lark: ping failed, reconnecting");
|
||||
break;
|
||||
}
|
||||
// GC stale fragments > 5 min
|
||||
let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now());
|
||||
frag_cache.retain(|_, (_, ts)| *ts > cutoff);
|
||||
}
|
||||
|
||||
_ = timeout_check.tick() => {
|
||||
if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
|
||||
tracing::warn!("Lark: heartbeat timeout, reconnecting");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
msg = read.next() => {
|
||||
let raw = match msg {
|
||||
Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b }
|
||||
Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; }
|
||||
Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; }
|
||||
Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; }
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let frame = match PbFrame::decode(&raw[..]) {
|
||||
Ok(f) => f,
|
||||
Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; }
|
||||
};
|
||||
|
||||
// CONTROL frame
|
||||
if frame.method == 0 {
|
||||
if frame.header_value("type") == "pong" {
|
||||
if let Some(p) = &frame.payload {
|
||||
if let Ok(cfg) = serde_json::from_slice::<WsClientConfig>(p) {
|
||||
if let Some(secs) = cfg.ping_interval {
|
||||
let secs = secs.max(10);
|
||||
if secs != ping_secs {
|
||||
ping_secs = secs;
|
||||
hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
|
||||
tracing::info!("Lark: ping_interval → {ping_secs}s");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// DATA frame
|
||||
let msg_type = frame.header_value("type").to_string();
|
||||
let msg_id = frame.header_value("message_id").to_string();
|
||||
let sum = frame.header_value("sum").parse::<usize>().unwrap_or(1);
|
||||
let seq_num = frame.header_value("seq").parse::<usize>().unwrap_or(0);
|
||||
|
||||
// ACK immediately (Feishu requires within 3 s)
|
||||
{
|
||||
let mut ack = frame.clone();
|
||||
ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
|
||||
ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() });
|
||||
let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await;
|
||||
}
|
||||
|
||||
// Fragment reassembly
|
||||
let sum = if sum == 0 { 1 } else { sum };
|
||||
let payload: Vec<u8> = if sum == 1 || msg_id.is_empty() || seq_num >= sum {
|
||||
frame.payload.clone().unwrap_or_default()
|
||||
} else {
|
||||
let entry = frag_cache.entry(msg_id.clone())
|
||||
.or_insert_with(|| (vec![None; sum], Instant::now()));
|
||||
if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); }
|
||||
entry.0[seq_num] = frame.payload.clone();
|
||||
if entry.0.iter().all(|s| s.is_some()) {
|
||||
let full: Vec<u8> = entry.0.iter()
|
||||
.flat_map(|s| s.as_deref().unwrap_or(&[]))
|
||||
.copied().collect();
|
||||
frag_cache.remove(&msg_id);
|
||||
full
|
||||
} else { continue; }
|
||||
};
|
||||
|
||||
if msg_type != "event" { continue; }
|
||||
|
||||
let event: LarkEvent = match serde_json::from_slice(&payload) {
|
||||
Ok(e) => e,
|
||||
Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; }
|
||||
};
|
||||
if event.header.event_type != "im.message.receive_v1" { continue; }
|
||||
|
||||
let recv: MsgReceivePayload = match serde_json::from_value(event.event) {
|
||||
Ok(r) => r,
|
||||
Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; }
|
||||
};
|
||||
|
||||
if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; }
|
||||
|
||||
let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or("");
|
||||
if !self.is_user_allowed(sender_open_id) {
|
||||
tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)");
|
||||
continue;
|
||||
}
|
||||
|
||||
let lark_msg = &recv.message;
|
||||
|
||||
// Dedup
|
||||
{
|
||||
let now = Instant::now();
|
||||
let mut seen = self.ws_seen_ids.write().await;
|
||||
// GC
|
||||
seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
|
||||
if seen.contains_key(&lark_msg.message_id) {
|
||||
tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
|
||||
continue;
|
||||
}
|
||||
seen.insert(lark_msg.message_id.clone(), now);
|
||||
}
|
||||
|
||||
// Decode content by type (mirrors clawdbot-feishu parsing)
|
||||
let text = match lark_msg.message_type.as_str() {
|
||||
"text" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) {
|
||||
Some(t) => t.to_string(),
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
"post" => match parse_post_content(&lark_msg.content) {
|
||||
Some(t) => t,
|
||||
None => continue,
|
||||
},
|
||||
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
|
||||
};
|
||||
|
||||
// Strip @_user_N placeholders
|
||||
let text = strip_at_placeholders(&text);
|
||||
let text = text.trim().to_string();
|
||||
if text.is_empty() { continue; }
|
||||
|
||||
// Group-chat: only respond when explicitly @-mentioned
|
||||
if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: lark_msg.chat_id.clone(),
|
||||
reply_target: lark_msg.chat_id.clone(),
|
||||
content: text,
|
||||
channel: "lark".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
tracing::debug!("Lark WS: message in {}", lark_msg.chat_id);
|
||||
if tx.send(channel_msg).await.is_err() { break; }
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a user open_id is allowed
|
||||
fn is_user_allowed(&self, open_id: &str) -> bool {
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == open_id)
|
||||
|
|
@ -52,7 +482,7 @@ impl LarkChannel {
|
|||
}
|
||||
}
|
||||
|
||||
let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal");
|
||||
let url = self.tenant_access_token_url();
|
||||
let body = serde_json::json!({
|
||||
"app_id": self.app_id,
|
||||
"app_secret": self.app_secret,
|
||||
|
|
@ -127,31 +557,41 @@ impl LarkChannel {
|
|||
return messages;
|
||||
}
|
||||
|
||||
// Extract message content (text only)
|
||||
// Extract message content (text and post supported)
|
||||
let msg_type = event
|
||||
.pointer("/message/message_type")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if msg_type != "text" {
|
||||
tracing::debug!("Lark: skipping non-text message type: {msg_type}");
|
||||
return messages;
|
||||
}
|
||||
|
||||
let content_str = event
|
||||
.pointer("/message/content")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// content is a JSON string like "{\"text\":\"hello\"}"
|
||||
let text = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
.ok()
|
||||
.and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from))
|
||||
.unwrap_or_default();
|
||||
|
||||
if text.is_empty() {
|
||||
return messages;
|
||||
}
|
||||
let text: String = match msg_type {
|
||||
"text" => {
|
||||
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
v.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(String::from)
|
||||
});
|
||||
match extracted {
|
||||
Some(t) => t,
|
||||
None => return messages,
|
||||
}
|
||||
}
|
||||
"post" => match parse_post_content(content_str) {
|
||||
Some(t) => t,
|
||||
None => return messages,
|
||||
},
|
||||
_ => {
|
||||
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
|
||||
return messages;
|
||||
}
|
||||
};
|
||||
|
||||
let timestamp = event
|
||||
.pointer("/message/create_time")
|
||||
|
|
@ -174,6 +614,7 @@ impl LarkChannel {
|
|||
messages.push(ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: chat_id.to_string(),
|
||||
reply_target: chat_id.to_string(),
|
||||
content: text,
|
||||
channel: "lark".to_string(),
|
||||
timestamp,
|
||||
|
|
@ -191,7 +632,7 @@ impl Channel for LarkChannel {
|
|||
|
||||
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
|
||||
let token = self.get_tenant_access_token().await?;
|
||||
let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id");
|
||||
let url = self.send_message_url();
|
||||
|
||||
let content = serde_json::json!({ "text": message }).to_string();
|
||||
let body = serde_json::json!({
|
||||
|
|
@ -238,6 +679,25 @@ impl Channel for LarkChannel {
|
|||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
use crate::config::schema::LarkReceiveMode;
|
||||
match self.receive_mode {
|
||||
LarkReceiveMode::Websocket => self.listen_ws(tx).await,
|
||||
LarkReceiveMode::Webhook => self.listen_http(tx).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.get_tenant_access_token().await.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl LarkChannel {
|
||||
/// HTTP callback server (legacy — requires a public endpoint).
|
||||
/// Use `listen()` (WS long-connection) for new deployments.
|
||||
pub async fn listen_http(
|
||||
&self,
|
||||
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
use axum::{extract::State, routing::post, Json, Router};
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -282,13 +742,17 @@ impl Channel for LarkChannel {
|
|||
(StatusCode::OK, "ok").into_response()
|
||||
}
|
||||
|
||||
let port = self.port.ok_or_else(|| {
|
||||
anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]")
|
||||
})?;
|
||||
|
||||
let state = AppState {
|
||||
verification_token: self.verification_token.clone(),
|
||||
channel: Arc::new(LarkChannel::new(
|
||||
self.app_id.clone(),
|
||||
self.app_secret.clone(),
|
||||
self.verification_token.clone(),
|
||||
self.port,
|
||||
None,
|
||||
self.allowed_users.clone(),
|
||||
)),
|
||||
tx,
|
||||
|
|
@ -298,7 +762,7 @@ impl Channel for LarkChannel {
|
|||
.route("/lark", post(handle_event))
|
||||
.with_state(state);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
tracing::info!("Lark event callback server listening on {addr}");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
|
@ -306,10 +770,110 @@ impl Channel for LarkChannel {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.get_tenant_access_token().await.is_ok()
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// WS helper functions
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Flatten a Feishu `post` rich-text message to plain text.
|
||||
///
|
||||
/// Returns `None` when the content cannot be parsed or yields no usable text,
|
||||
/// so callers can simply `continue` rather than forwarding a meaningless
|
||||
/// placeholder string to the agent.
|
||||
fn parse_post_content(content: &str) -> Option<String> {
|
||||
let parsed = serde_json::from_str::<serde_json::Value>(content).ok()?;
|
||||
let locale = parsed
|
||||
.get("zh_cn")
|
||||
.or_else(|| parsed.get("en_us"))
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.as_object()
|
||||
.and_then(|m| m.values().find(|v| v.is_object()))
|
||||
})?;
|
||||
|
||||
let mut text = String::new();
|
||||
|
||||
if let Some(title) = locale
|
||||
.get("title")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
{
|
||||
text.push_str(title);
|
||||
text.push_str("\n\n");
|
||||
}
|
||||
|
||||
if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) {
|
||||
for para in paragraphs {
|
||||
if let Some(elements) = para.as_array() {
|
||||
for el in elements {
|
||||
match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") {
|
||||
"text" => {
|
||||
if let Some(t) = el.get("text").and_then(|t| t.as_str()) {
|
||||
text.push_str(t);
|
||||
}
|
||||
}
|
||||
"a" => {
|
||||
text.push_str(
|
||||
el.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
.or_else(|| el.get("href").and_then(|h| h.as_str()))
|
||||
.unwrap_or(""),
|
||||
);
|
||||
}
|
||||
"at" => {
|
||||
let n = el
|
||||
.get("user_name")
|
||||
.and_then(|n| n.as_str())
|
||||
.or_else(|| el.get("user_id").and_then(|i| i.as_str()))
|
||||
.unwrap_or("user");
|
||||
text.push('@');
|
||||
text.push_str(n);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
text.push('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let result = text.trim().to_string();
|
||||
if result.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats.
|
||||
fn strip_at_placeholders(text: &str) -> String {
|
||||
let mut result = String::with_capacity(text.len());
|
||||
let mut chars = text.char_indices().peekable();
|
||||
while let Some((_, ch)) = chars.next() {
|
||||
if ch == '@' {
|
||||
let rest: String = chars.clone().map(|(_, c)| c).collect();
|
||||
if let Some(after) = rest.strip_prefix("_user_") {
|
||||
let skip =
|
||||
"_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count();
|
||||
for _ in 0..=skip {
|
||||
chars.next();
|
||||
}
|
||||
if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) {
|
||||
chars.next();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
result.push(ch);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// In group chats, only respond when the bot is explicitly @-mentioned.
|
||||
fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool {
|
||||
!mentions.is_empty()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -321,7 +885,7 @@ mod tests {
|
|||
"cli_test_app_id".into(),
|
||||
"test_app_secret".into(),
|
||||
"test_verification_token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["ou_testuser123".into()],
|
||||
)
|
||||
}
|
||||
|
|
@ -345,7 +909,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
assert!(ch.is_user_allowed("ou_anyone"));
|
||||
|
|
@ -353,7 +917,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_user_denied_empty() {
|
||||
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]);
|
||||
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]);
|
||||
assert!(!ch.is_user_allowed("ou_anyone"));
|
||||
}
|
||||
|
||||
|
|
@ -426,7 +990,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -451,7 +1015,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -488,7 +1052,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -512,7 +1076,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -550,7 +1114,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -571,7 +1135,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_config_serde() {
|
||||
use crate::config::schema::LarkConfig;
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
let lc = LarkConfig {
|
||||
app_id: "cli_app123".into(),
|
||||
app_secret: "secret456".into(),
|
||||
|
|
@ -579,6 +1143,8 @@ mod tests {
|
|||
verification_token: Some("vtoken789".into()),
|
||||
allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::default(),
|
||||
port: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -590,7 +1156,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_config_toml_roundtrip() {
|
||||
use crate::config::schema::LarkConfig;
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
let lc = LarkConfig {
|
||||
app_id: "app".into(),
|
||||
app_secret: "secret".into(),
|
||||
|
|
@ -598,6 +1164,8 @@ mod tests {
|
|||
verification_token: Some("tok".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
|
|
@ -608,11 +1176,36 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_config_defaults_optional_fields() {
|
||||
use crate::config::schema::LarkConfig;
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
let json = r#"{"app_id":"a","app_secret":"s"}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.verification_token.is_none());
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket);
|
||||
assert!(parsed.port.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_from_config_preserves_mode_and_region() {
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
|
||||
let cfg = LarkConfig {
|
||||
app_id: "cli_app123".into(),
|
||||
app_secret: "secret456".into(),
|
||||
encrypt_key: None,
|
||||
verification_token: Some("vtoken789".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_config(&cfg);
|
||||
|
||||
assert_eq!(ch.api_base(), LARK_BASE_URL);
|
||||
assert_eq!(ch.ws_base(), LARK_WS_BASE_URL);
|
||||
assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook);
|
||||
assert_eq!(ch.port, Some(9898));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -622,7 +1215,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
|
|||
|
|
@ -230,6 +230,7 @@ impl Channel for MatrixChannel {
|
|||
let msg = ChannelMessage {
|
||||
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
|
||||
sender: event.sender.clone(),
|
||||
reply_target: event.sender.clone(),
|
||||
content: body.clone(),
|
||||
channel: "matrix".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
|
|||
|
|
@ -69,10 +69,19 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
|||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||||
}
|
||||
|
||||
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
||||
match channel_name {
|
||||
"telegram" => Some(
|
||||
"When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.",
|
||||
),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if !entries.is_empty() {
|
||||
context.push_str("[Memory context]\n");
|
||||
for entry in &entries {
|
||||
|
|
@ -158,6 +167,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
&autosave_key,
|
||||
&msg.content,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -171,7 +181,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
if let Err(e) = channel.start_typing(&msg.sender).await {
|
||||
if let Err(e) = channel.start_typing(&msg.reply_target).await {
|
||||
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
||||
}
|
||||
}
|
||||
|
|
@ -184,6 +194,10 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
ChatMessage::user(&enriched_message),
|
||||
];
|
||||
|
||||
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
|
||||
history.push(ChatMessage::system(instructions));
|
||||
}
|
||||
|
||||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||
run_tool_call_loop(
|
||||
|
|
@ -200,7 +214,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
.await;
|
||||
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
if let Err(e) = channel.stop_typing(&msg.sender).await {
|
||||
if let Err(e) = channel.stop_typing(&msg.reply_target).await {
|
||||
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
||||
}
|
||||
}
|
||||
|
|
@ -224,7 +238,9 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
started_at.elapsed().as_millis()
|
||||
);
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
|
||||
let _ = channel
|
||||
.send(&format!("⚠️ Error: {e}"), &msg.reply_target)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
|
|
@ -241,7 +257,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
let _ = channel
|
||||
.send(
|
||||
"⚠️ Request timed out while waiting for the model. Please try again.",
|
||||
&msg.sender,
|
||||
&msg.reply_target,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -483,6 +499,16 @@ pub fn build_system_prompt(
|
|||
std::env::consts::OS,
|
||||
);
|
||||
|
||||
// ── 8. Channel Capabilities ─────────────────────────────────────
|
||||
prompt.push_str("## Channel Capabilities\n\n");
|
||||
prompt.push_str(
|
||||
"- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n",
|
||||
);
|
||||
prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n");
|
||||
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
|
||||
|
||||
if prompt.is_empty() {
|
||||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
|
||||
} else {
|
||||
|
|
@ -619,6 +645,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
|||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)),
|
||||
));
|
||||
}
|
||||
|
|
@ -672,32 +699,23 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
|||
if let Some(ref irc) = config.channels_config.irc {
|
||||
channels.push((
|
||||
"IRC",
|
||||
Arc::new(IrcChannel::new(
|
||||
irc.server.clone(),
|
||||
irc.port,
|
||||
irc.nickname.clone(),
|
||||
irc.username.clone(),
|
||||
irc.channels.clone(),
|
||||
irc.allowed_users.clone(),
|
||||
irc.server_password.clone(),
|
||||
irc.nickserv_password.clone(),
|
||||
irc.sasl_password.clone(),
|
||||
irc.verify_tls.unwrap_or(true),
|
||||
)),
|
||||
Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||||
server: irc.server.clone(),
|
||||
port: irc.port,
|
||||
nickname: irc.nickname.clone(),
|
||||
username: irc.username.clone(),
|
||||
channels: irc.channels.clone(),
|
||||
allowed_users: irc.allowed_users.clone(),
|
||||
server_password: irc.server_password.clone(),
|
||||
nickserv_password: irc.nickserv_password.clone(),
|
||||
sasl_password: irc.sasl_password.clone(),
|
||||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref lk) = config.channels_config.lark {
|
||||
channels.push((
|
||||
"Lark",
|
||||
Arc::new(LarkChannel::new(
|
||||
lk.app_id.clone(),
|
||||
lk.app_secret.clone(),
|
||||
lk.verification_token.clone().unwrap_or_default(),
|
||||
9898,
|
||||
lk.allowed_users.clone(),
|
||||
)),
|
||||
));
|
||||
channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
|
||||
}
|
||||
|
||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||
|
|
@ -762,6 +780,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||
&provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
)?);
|
||||
|
||||
|
|
@ -860,6 +879,10 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
"schedule",
|
||||
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
|
||||
));
|
||||
tool_descs.push((
|
||||
"pushover",
|
||||
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
|
||||
));
|
||||
if !config.agents.is_empty() {
|
||||
tool_descs.push((
|
||||
"delegate",
|
||||
|
|
@ -909,6 +932,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
@ -947,28 +971,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
}
|
||||
|
||||
if let Some(ref irc) = config.channels_config.irc {
|
||||
channels.push(Arc::new(IrcChannel::new(
|
||||
irc.server.clone(),
|
||||
irc.port,
|
||||
irc.nickname.clone(),
|
||||
irc.username.clone(),
|
||||
irc.channels.clone(),
|
||||
irc.allowed_users.clone(),
|
||||
irc.server_password.clone(),
|
||||
irc.nickserv_password.clone(),
|
||||
irc.sasl_password.clone(),
|
||||
irc.verify_tls.unwrap_or(true),
|
||||
)));
|
||||
channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||||
server: irc.server.clone(),
|
||||
port: irc.port,
|
||||
nickname: irc.nickname.clone(),
|
||||
username: irc.username.clone(),
|
||||
channels: irc.channels.clone(),
|
||||
allowed_users: irc.allowed_users.clone(),
|
||||
server_password: irc.server_password.clone(),
|
||||
nickserv_password: irc.nickserv_password.clone(),
|
||||
sasl_password: irc.sasl_password.clone(),
|
||||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||||
})));
|
||||
}
|
||||
|
||||
if let Some(ref lk) = config.channels_config.lark {
|
||||
channels.push(Arc::new(LarkChannel::new(
|
||||
lk.app_id.clone(),
|
||||
lk.app_secret.clone(),
|
||||
lk.verification_token.clone().unwrap_or_default(),
|
||||
9898,
|
||||
lk.allowed_users.clone(),
|
||||
)));
|
||||
channels.push(Arc::new(LarkChannel::from_config(lk)));
|
||||
}
|
||||
|
||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||
|
|
@ -1242,6 +1260,7 @@ mod tests {
|
|||
traits::ChannelMessage {
|
||||
id: "msg-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-42".to_string(),
|
||||
content: "What is the BTC price now?".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1251,6 +1270,7 @@ mod tests {
|
|||
|
||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent_messages.len(), 1);
|
||||
assert!(sent_messages[0].starts_with("chat-42:"));
|
||||
assert!(sent_messages[0].contains("BTC is currently around"));
|
||||
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
||||
assert!(!sent_messages[0].contains("mock_price"));
|
||||
|
|
@ -1269,6 +1289,7 @@ mod tests {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: crate::memory::MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1277,6 +1298,7 @@ mod tests {
|
|||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -1288,6 +1310,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&crate::memory::MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -1331,6 +1354,7 @@ mod tests {
|
|||
tx.send(traits::ChannelMessage {
|
||||
id: "1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "alice".to_string(),
|
||||
content: "hello".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1340,6 +1364,7 @@ mod tests {
|
|||
tx.send(traits::ChannelMessage {
|
||||
id: "2".to_string(),
|
||||
sender: "bob".to_string(),
|
||||
reply_target: "bob".to_string(),
|
||||
content: "world".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 2,
|
||||
|
|
@ -1570,6 +1595,25 @@ mod tests {
|
|||
assert!(truncated.is_char_boundary(truncated.len()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_contains_channel_capabilities() {
|
||||
let ws = make_workspace();
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);
|
||||
|
||||
assert!(
|
||||
prompt.contains("## Channel Capabilities"),
|
||||
"missing Channel Capabilities section"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("running as a Discord bot"),
|
||||
"missing Discord context"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("NEVER repeat, describe, or echo credentials"),
|
||||
"missing security instruction"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_workspace_path() {
|
||||
let ws = make_workspace();
|
||||
|
|
@ -1583,6 +1627,7 @@ mod tests {
|
|||
let msg = traits::ChannelMessage {
|
||||
id: "msg_abc123".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "hello".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1596,6 +1641,7 @@ mod tests {
|
|||
let msg1 = traits::ChannelMessage {
|
||||
id: "msg_1".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "first".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1603,6 +1649,7 @@ mod tests {
|
|||
let msg2 = traits::ChannelMessage {
|
||||
id: "msg_2".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "second".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 2,
|
||||
|
|
@ -1622,6 +1669,7 @@ mod tests {
|
|||
let msg1 = traits::ChannelMessage {
|
||||
id: "msg_1".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "I'm Paul".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1629,6 +1677,7 @@ mod tests {
|
|||
let msg2 = traits::ChannelMessage {
|
||||
id: "msg_2".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "I'm 45".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 2,
|
||||
|
|
@ -1638,6 +1687,7 @@ mod tests {
|
|||
&conversation_memory_key(&msg1),
|
||||
&msg1.content,
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1645,13 +1695,14 @@ mod tests {
|
|||
&conversation_memory_key(&msg2),
|
||||
&msg2.content,
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
|
|
@ -1659,7 +1710,7 @@ mod tests {
|
|||
async fn build_memory_context_includes_recalled_entries() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation)
|
||||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ impl Channel for SlackChannel {
|
|||
let channel_msg = ChannelMessage {
|
||||
id: format!("slack_{channel_id}_{ts}"),
|
||||
sender: user.to_string(),
|
||||
reply_target: channel_id.clone(),
|
||||
content: text.to_string(),
|
||||
channel: "slack".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
|
|||
|
|
@ -51,6 +51,133 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
|
|||
chunks
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum TelegramAttachmentKind {
|
||||
Image,
|
||||
Document,
|
||||
Video,
|
||||
Audio,
|
||||
Voice,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct TelegramAttachment {
|
||||
kind: TelegramAttachmentKind,
|
||||
target: String,
|
||||
}
|
||||
|
||||
impl TelegramAttachmentKind {
|
||||
fn from_marker(marker: &str) -> Option<Self> {
|
||||
match marker.trim().to_ascii_uppercase().as_str() {
|
||||
"IMAGE" | "PHOTO" => Some(Self::Image),
|
||||
"DOCUMENT" | "FILE" => Some(Self::Document),
|
||||
"VIDEO" => Some(Self::Video),
|
||||
"AUDIO" => Some(Self::Audio),
|
||||
"VOICE" => Some(Self::Voice),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_http_url(target: &str) -> bool {
|
||||
target.starts_with("http://") || target.starts_with("https://")
|
||||
}
|
||||
|
||||
fn infer_attachment_kind_from_target(target: &str) -> Option<TelegramAttachmentKind> {
|
||||
let normalized = target
|
||||
.split('?')
|
||||
.next()
|
||||
.unwrap_or(target)
|
||||
.split('#')
|
||||
.next()
|
||||
.unwrap_or(target);
|
||||
|
||||
let extension = Path::new(normalized)
|
||||
.extension()
|
||||
.and_then(|ext| ext.to_str())?
|
||||
.to_ascii_lowercase();
|
||||
|
||||
match extension.as_str() {
|
||||
"png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => Some(TelegramAttachmentKind::Image),
|
||||
"mp4" | "mov" | "mkv" | "avi" | "webm" => Some(TelegramAttachmentKind::Video),
|
||||
"mp3" | "m4a" | "wav" | "flac" => Some(TelegramAttachmentKind::Audio),
|
||||
"ogg" | "oga" | "opus" => Some(TelegramAttachmentKind::Voice),
|
||||
"pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx" | "xls"
|
||||
| "xlsx" | "ppt" | "pptx" => Some(TelegramAttachmentKind::Document),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_path_only_attachment(message: &str) -> Option<TelegramAttachment> {
|
||||
let trimmed = message.trim();
|
||||
if trimmed.is_empty() || trimmed.contains('\n') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let candidate = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\''));
|
||||
if candidate.chars().any(char::is_whitespace) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let candidate = candidate.strip_prefix("file://").unwrap_or(candidate);
|
||||
let kind = infer_attachment_kind_from_target(candidate)?;
|
||||
|
||||
if !is_http_url(candidate) && !Path::new(candidate).exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(TelegramAttachment {
|
||||
kind,
|
||||
target: candidate.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_attachment_markers(message: &str) -> (String, Vec<TelegramAttachment>) {
|
||||
let mut cleaned = String::with_capacity(message.len());
|
||||
let mut attachments = Vec::new();
|
||||
let mut cursor = 0;
|
||||
|
||||
while cursor < message.len() {
|
||||
let Some(open_rel) = message[cursor..].find('[') else {
|
||||
cleaned.push_str(&message[cursor..]);
|
||||
break;
|
||||
};
|
||||
|
||||
let open = cursor + open_rel;
|
||||
cleaned.push_str(&message[cursor..open]);
|
||||
|
||||
let Some(close_rel) = message[open..].find(']') else {
|
||||
cleaned.push_str(&message[open..]);
|
||||
break;
|
||||
};
|
||||
|
||||
let close = open + close_rel;
|
||||
let marker = &message[open + 1..close];
|
||||
|
||||
let parsed = marker.split_once(':').and_then(|(kind, target)| {
|
||||
let kind = TelegramAttachmentKind::from_marker(kind)?;
|
||||
let target = target.trim();
|
||||
if target.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(TelegramAttachment {
|
||||
kind,
|
||||
target: target.to_string(),
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(attachment) = parsed {
|
||||
attachments.push(attachment);
|
||||
} else {
|
||||
cleaned.push_str(&message[open..=close]);
|
||||
}
|
||||
|
||||
cursor = close + 1;
|
||||
}
|
||||
|
||||
(cleaned.trim().to_string(), attachments)
|
||||
}
|
||||
|
||||
/// Telegram channel — long-polls the Bot API for updates
|
||||
pub struct TelegramChannel {
|
||||
bot_token: String,
|
||||
|
|
@ -82,6 +209,216 @@ impl TelegramChannel {
|
|||
identities.into_iter().any(|id| self.is_user_allowed(id))
|
||||
}
|
||||
|
||||
fn parse_update_message(&self, update: &serde_json::Value) -> Option<ChannelMessage> {
|
||||
let message = update.get("message")?;
|
||||
|
||||
let text = message.get("text").and_then(serde_json::Value::as_str)?;
|
||||
|
||||
let username = message
|
||||
.get("from")
|
||||
.and_then(|from| from.get("username"))
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
let user_id = message
|
||||
.get("from")
|
||||
.and_then(|from| from.get("id"))
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.map(|id| id.to_string());
|
||||
|
||||
let sender_identity = if username == "unknown" {
|
||||
user_id.clone().unwrap_or_else(|| "unknown".to_string())
|
||||
} else {
|
||||
username.clone()
|
||||
};
|
||||
|
||||
let mut identities = vec![username.as_str()];
|
||||
if let Some(id) = user_id.as_deref() {
|
||||
identities.push(id);
|
||||
}
|
||||
|
||||
if !self.is_any_user_allowed(identities.iter().copied()) {
|
||||
tracing::warn!(
|
||||
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
|
||||
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
|
||||
user_id.as_deref().unwrap_or("unknown")
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let chat_id = message
|
||||
.get("chat")
|
||||
.and_then(|chat| chat.get("id"))
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.map(|id| id.to_string())?;
|
||||
|
||||
let message_id = message
|
||||
.get("message_id")
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.unwrap_or(0);
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
reply_target: chat_id,
|
||||
content: text.to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
||||
let chunks = split_message_for_telegram(message);
|
||||
|
||||
for (index, chunk) in chunks.iter().enumerate() {
|
||||
let text = if chunks.len() > 1 {
|
||||
if index == 0 {
|
||||
format!("{chunk}\n\n(continues...)")
|
||||
} else if index == chunks.len() - 1 {
|
||||
format!("(continued)\n\n{chunk}")
|
||||
} else {
|
||||
format!("(continued)\n\n{chunk}\n\n(continues...)")
|
||||
}
|
||||
} else {
|
||||
chunk.to_string()
|
||||
};
|
||||
|
||||
let markdown_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
"parse_mode": "Markdown"
|
||||
});
|
||||
|
||||
let markdown_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&markdown_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if markdown_resp.status().is_success() {
|
||||
if index < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let markdown_status = markdown_resp.status();
|
||||
let markdown_err = markdown_resp.text().await.unwrap_or_default();
|
||||
tracing::warn!(
|
||||
status = ?markdown_status,
|
||||
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
|
||||
);
|
||||
|
||||
let plain_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
});
|
||||
let plain_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&plain_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !plain_resp.status().is_success() {
|
||||
let plain_status = plain_resp.status();
|
||||
let plain_err = plain_resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
|
||||
markdown_status,
|
||||
markdown_err,
|
||||
plain_status,
|
||||
plain_err
|
||||
);
|
||||
}
|
||||
|
||||
if index < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_media_by_url(
|
||||
&self,
|
||||
method: &str,
|
||||
media_field: &str,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
});
|
||||
body[media_field] = serde_json::Value::String(url.to_string());
|
||||
|
||||
if let Some(cap) = caption {
|
||||
body["caption"] = serde_json::Value::String(cap.to_string());
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(self.api_url(method))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err = resp.text().await?;
|
||||
anyhow::bail!("Telegram {method} by URL failed: {err}");
|
||||
}
|
||||
|
||||
tracing::info!("Telegram {method} sent to {chat_id}: {url}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_attachment(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
attachment: &TelegramAttachment,
|
||||
) -> anyhow::Result<()> {
|
||||
let target = attachment.target.trim();
|
||||
|
||||
if is_http_url(target) {
|
||||
return match attachment.kind {
|
||||
TelegramAttachmentKind::Image => {
|
||||
self.send_photo_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Document => {
|
||||
self.send_document_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Video => {
|
||||
self.send_video_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Audio => {
|
||||
self.send_audio_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Voice => {
|
||||
self.send_voice_by_url(chat_id, target, None).await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let path = Path::new(target);
|
||||
if !path.exists() {
|
||||
anyhow::bail!("Telegram attachment path not found: {target}");
|
||||
}
|
||||
|
||||
match attachment.kind {
|
||||
TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a document/file to a Telegram chat
|
||||
pub async fn send_document(
|
||||
&self,
|
||||
|
|
@ -408,6 +745,39 @@ impl TelegramChannel {
|
|||
tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a video by URL (Telegram will download it)
|
||||
pub async fn send_video_by_url(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_media_by_url("sendVideo", "video", chat_id, url, caption)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send an audio file by URL (Telegram will download it)
|
||||
pub async fn send_audio_by_url(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_media_by_url("sendAudio", "audio", chat_id, url, caption)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send a voice message by URL (Telegram will download it)
|
||||
pub async fn send_voice_by_url(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_media_by_url("sendVoice", "voice", chat_id, url, caption)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -417,82 +787,27 @@ impl Channel for TelegramChannel {
|
|||
}
|
||||
|
||||
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
||||
// Split message if it exceeds Telegram's 4096 character limit
|
||||
let chunks = split_message_for_telegram(message);
|
||||
let (text_without_markers, attachments) = parse_attachment_markers(message);
|
||||
|
||||
for (i, chunk) in chunks.iter().enumerate() {
|
||||
// Add continuation marker for multi-part messages
|
||||
let text = if chunks.len() > 1 {
|
||||
if i == 0 {
|
||||
format!("{chunk}\n\n(continues...)")
|
||||
} else if i == chunks.len() - 1 {
|
||||
format!("(continued)\n\n{chunk}")
|
||||
} else {
|
||||
format!("(continued)\n\n{chunk}\n\n(continues...)")
|
||||
}
|
||||
} else {
|
||||
chunk.to_string()
|
||||
};
|
||||
|
||||
let markdown_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
"parse_mode": "Markdown"
|
||||
});
|
||||
|
||||
let markdown_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&markdown_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if markdown_resp.status().is_success() {
|
||||
// Small delay between chunks to avoid rate limiting
|
||||
if i < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
continue;
|
||||
if !attachments.is_empty() {
|
||||
if !text_without_markers.is_empty() {
|
||||
self.send_text_chunks(&text_without_markers, chat_id)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let markdown_status = markdown_resp.status();
|
||||
let markdown_err = markdown_resp.text().await.unwrap_or_default();
|
||||
tracing::warn!(
|
||||
status = ?markdown_status,
|
||||
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
|
||||
);
|
||||
|
||||
// Retry without parse_mode as a compatibility fallback.
|
||||
let plain_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
});
|
||||
let plain_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&plain_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !plain_resp.status().is_success() {
|
||||
let plain_status = plain_resp.status();
|
||||
let plain_err = plain_resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
|
||||
markdown_status,
|
||||
markdown_err,
|
||||
plain_status,
|
||||
plain_err
|
||||
);
|
||||
for attachment in &attachments {
|
||||
self.send_attachment(chat_id, attachment).await?;
|
||||
}
|
||||
|
||||
// Small delay between chunks to avoid rate limiting
|
||||
if i < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
if let Some(attachment) = parse_path_only_attachment(message) {
|
||||
self.send_attachment(chat_id, &attachment).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.send_text_chunks(message, chat_id).await
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
|
|
@ -533,59 +848,13 @@ impl Channel for TelegramChannel {
|
|||
offset = uid + 1;
|
||||
}
|
||||
|
||||
let Some(message) = update.get("message") else {
|
||||
let Some(msg) = self.parse_update_message(update) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let username_opt = message
|
||||
.get("from")
|
||||
.and_then(|f| f.get("username"))
|
||||
.and_then(|u| u.as_str());
|
||||
let username = username_opt.unwrap_or("unknown");
|
||||
|
||||
let user_id = message
|
||||
.get("from")
|
||||
.and_then(|f| f.get("id"))
|
||||
.and_then(serde_json::Value::as_i64);
|
||||
let user_id_str = user_id.map(|id| id.to_string());
|
||||
|
||||
let mut identities = vec![username];
|
||||
if let Some(ref id) = user_id_str {
|
||||
identities.push(id.as_str());
|
||||
}
|
||||
|
||||
if !self.is_any_user_allowed(identities.iter().copied()) {
|
||||
tracing::warn!(
|
||||
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
|
||||
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
|
||||
user_id_str.as_deref().unwrap_or("unknown")
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let chat_id = message
|
||||
.get("chat")
|
||||
.and_then(|c| c.get("id"))
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.map(|id| id.to_string());
|
||||
|
||||
let Some(chat_id) = chat_id else {
|
||||
tracing::warn!("Telegram: missing chat_id in message, skipping");
|
||||
continue;
|
||||
};
|
||||
|
||||
let message_id = message
|
||||
.get("message_id")
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
|
||||
// Send "typing" indicator immediately when we receive a message
|
||||
let typing_body = serde_json::json!({
|
||||
"chat_id": &chat_id,
|
||||
"chat_id": &msg.reply_target,
|
||||
"action": "typing"
|
||||
});
|
||||
let _ = self
|
||||
|
|
@ -595,17 +864,6 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch
|
|||
.send()
|
||||
.await; // Ignore errors for typing indicator
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: username.to_string(),
|
||||
content: text.to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -716,6 +974,107 @@ mod tests {
|
|||
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_attachment_markers_extracts_multiple_types() {
|
||||
let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]";
|
||||
let (cleaned, attachments) = parse_attachment_markers(message);
|
||||
|
||||
assert_eq!(cleaned, "Here are files and");
|
||||
assert_eq!(attachments.len(), 2);
|
||||
assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image);
|
||||
assert_eq!(attachments[0].target, "/tmp/a.png");
|
||||
assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document);
|
||||
assert_eq!(attachments[1].target, "https://example.com/a.pdf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_attachment_markers_keeps_invalid_markers_in_text() {
|
||||
let message = "Report [UNKNOWN:/tmp/a.bin]";
|
||||
let (cleaned, attachments) = parse_attachment_markers(message);
|
||||
|
||||
assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]");
|
||||
assert!(attachments.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_path_only_attachment_detects_existing_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let image_path = dir.path().join("snap.png");
|
||||
std::fs::write(&image_path, b"fake-png").unwrap();
|
||||
|
||||
let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref())
|
||||
.expect("expected attachment");
|
||||
|
||||
assert_eq!(parsed.kind, TelegramAttachmentKind::Image);
|
||||
assert_eq!(parsed.target, image_path.to_string_lossy());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_path_only_attachment_rejects_sentence_text() {
|
||||
assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_attachment_kind_from_target_detects_document_extension() {
|
||||
assert_eq!(
|
||||
infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"),
|
||||
Some(TelegramAttachmentKind::Document)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_uses_chat_id_as_reply_target() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()]);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"message_id": 33,
|
||||
"text": "hello",
|
||||
"from": {
|
||||
"id": 555,
|
||||
"username": "alice"
|
||||
},
|
||||
"chat": {
|
||||
"id": -100200300
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("message should parse");
|
||||
|
||||
assert_eq!(msg.sender, "alice");
|
||||
assert_eq!(msg.reply_target, "-100200300");
|
||||
assert_eq!(msg.content, "hello");
|
||||
assert_eq!(msg.id, "telegram_-100200300_33");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_allows_numeric_id_without_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["555".into()]);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 2,
|
||||
"message": {
|
||||
"message_id": 9,
|
||||
"text": "ping",
|
||||
"from": {
|
||||
"id": 555
|
||||
},
|
||||
"chat": {
|
||||
"id": 12345
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("numeric allowlist should pass");
|
||||
|
||||
assert_eq!(msg.sender, "555");
|
||||
assert_eq!(msg.reply_target, "12345");
|
||||
}
|
||||
|
||||
// ── File sending API URL tests ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use async_trait::async_trait;
|
|||
pub struct ChannelMessage {
|
||||
pub id: String,
|
||||
pub sender: String,
|
||||
pub reply_target: String,
|
||||
pub content: String,
|
||||
pub channel: String,
|
||||
pub timestamp: u64,
|
||||
|
|
@ -62,6 +63,7 @@ mod tests {
|
|||
tx.send(ChannelMessage {
|
||||
id: "1".into(),
|
||||
sender: "tester".into(),
|
||||
reply_target: "tester".into(),
|
||||
content: "hello".into(),
|
||||
channel: "dummy".into(),
|
||||
timestamp: 123,
|
||||
|
|
@ -76,6 +78,7 @@ mod tests {
|
|||
let message = ChannelMessage {
|
||||
id: "42".into(),
|
||||
sender: "alice".into(),
|
||||
reply_target: "alice".into(),
|
||||
content: "ping".into(),
|
||||
channel: "dummy".into(),
|
||||
timestamp: 999,
|
||||
|
|
@ -84,6 +87,7 @@ mod tests {
|
|||
let cloned = message.clone();
|
||||
assert_eq!(cloned.id, "42");
|
||||
assert_eq!(cloned.sender, "alice");
|
||||
assert_eq!(cloned.reply_target, "alice");
|
||||
assert_eq!(cloned.content, "ping");
|
||||
assert_eq!(cloned.channel, "dummy");
|
||||
assert_eq!(cloned.timestamp, 999);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use uuid::Uuid;
|
|||
/// happens in the gateway when Meta sends webhook events.
|
||||
pub struct WhatsAppChannel {
|
||||
access_token: String,
|
||||
phone_number_id: String,
|
||||
endpoint_id: String,
|
||||
verify_token: String,
|
||||
allowed_numbers: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
|
|
@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
|
|||
impl WhatsAppChannel {
|
||||
pub fn new(
|
||||
access_token: String,
|
||||
phone_number_id: String,
|
||||
endpoint_id: String,
|
||||
verify_token: String,
|
||||
allowed_numbers: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
access_token,
|
||||
phone_number_id,
|
||||
endpoint_id,
|
||||
verify_token,
|
||||
allowed_numbers,
|
||||
client: reqwest::Client::new(),
|
||||
|
|
@ -119,6 +119,7 @@ impl WhatsAppChannel {
|
|||
|
||||
messages.push(ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
reply_target: normalized_from.clone(),
|
||||
sender: normalized_from,
|
||||
content,
|
||||
channel: "whatsapp".to_string(),
|
||||
|
|
@ -142,7 +143,7 @@ impl Channel for WhatsAppChannel {
|
|||
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
|
||||
let url = format!(
|
||||
"https://graph.facebook.com/v18.0/{}/messages",
|
||||
self.phone_number_id
|
||||
self.endpoint_id
|
||||
);
|
||||
|
||||
// Normalize recipient (remove leading + if present for API)
|
||||
|
|
@ -162,7 +163,7 @@ impl Channel for WhatsAppChannel {
|
|||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.bearer_auth(&self.access_token)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
|
|
@ -195,11 +196,11 @@ impl Channel for WhatsAppChannel {
|
|||
|
||||
async fn health_check(&self) -> bool {
|
||||
// Check if we can reach the WhatsApp API
|
||||
let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id);
|
||||
let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
|
||||
|
||||
self.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.bearer_auth(&self.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
|
|
|
|||
|
|
@ -37,9 +37,22 @@ mod tests {
|
|||
guild_id: Some("123".into()),
|
||||
allowed_users: vec![],
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
};
|
||||
|
||||
let lark = LarkConfig {
|
||||
app_id: "app-id".into(),
|
||||
app_secret: "app-secret".into(),
|
||||
encrypt_key: None,
|
||||
verification_token: None,
|
||||
allowed_users: vec![],
|
||||
use_feishu: false,
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
};
|
||||
|
||||
assert_eq!(telegram.allowed_users.len(), 1);
|
||||
assert_eq!(discord.guild_id.as_deref(), Some("123"));
|
||||
assert_eq!(lark.app_id, "app-id");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ pub struct Config {
|
|||
#[serde(skip)]
|
||||
pub config_path: PathBuf,
|
||||
pub api_key: Option<String>,
|
||||
/// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama)
|
||||
pub api_url: Option<String>,
|
||||
pub default_provider: Option<String>,
|
||||
pub default_model: Option<String>,
|
||||
pub default_temperature: f64,
|
||||
|
|
@ -1317,6 +1319,10 @@ pub struct DiscordConfig {
|
|||
/// The bot still ignores its own messages to prevent feedback loops.
|
||||
#[serde(default)]
|
||||
pub listen_to_bots: bool,
|
||||
/// When true, only respond to messages that @-mention the bot.
|
||||
/// Other messages in the guild are silently ignored.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -1395,8 +1401,20 @@ fn default_irc_port() -> u16 {
|
|||
6697
|
||||
}
|
||||
|
||||
/// Lark/Feishu configuration for messaging integration
|
||||
/// Lark is the international version, Feishu is the Chinese version
|
||||
/// How ZeroClaw receives events from Feishu / Lark.
|
||||
///
|
||||
/// - `websocket` (default) — persistent WSS long-connection; no public URL required.
|
||||
/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum LarkReceiveMode {
|
||||
#[default]
|
||||
Websocket,
|
||||
Webhook,
|
||||
}
|
||||
|
||||
/// Lark/Feishu configuration for messaging integration.
|
||||
/// Lark is the international version; Feishu is the Chinese version.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LarkConfig {
|
||||
/// App ID from Lark/Feishu developer console
|
||||
|
|
@ -1415,6 +1433,13 @@ pub struct LarkConfig {
|
|||
/// Whether to use the Feishu (Chinese) endpoint instead of Lark (International)
|
||||
#[serde(default)]
|
||||
pub use_feishu: bool,
|
||||
/// Event receive mode: "websocket" (default) or "webhook"
|
||||
#[serde(default)]
|
||||
pub receive_mode: LarkReceiveMode,
|
||||
/// HTTP port for webhook mode only. Must be set when receive_mode = "webhook".
|
||||
/// Not required (and ignored) for websocket mode.
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
}
|
||||
|
||||
// ── Security Config ─────────────────────────────────────────────────
|
||||
|
|
@ -1594,6 +1619,7 @@ impl Default for Config {
|
|||
workspace_dir: zeroclaw_dir.join("workspace"),
|
||||
config_path: zeroclaw_dir.join("config.toml"),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".to_string()),
|
||||
default_model: Some("anthropic/claude-sonnet-4".to_string()),
|
||||
default_temperature: 0.7,
|
||||
|
|
@ -1623,35 +1649,146 @@ impl Default for Config {
|
|||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load_or_init() -> Result<Self> {
|
||||
let home = UserDirs::new()
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.context("Could not find home directory")?;
|
||||
let zeroclaw_dir = home.join(".zeroclaw");
|
||||
let config_path = zeroclaw_dir.join("config.toml");
|
||||
fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> {
|
||||
let home = UserDirs::new()
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.context("Could not find home directory")?;
|
||||
let config_dir = home.join(".zeroclaw");
|
||||
Ok((config_dir.clone(), config_dir.join("workspace")))
|
||||
}
|
||||
|
||||
if !zeroclaw_dir.exists() {
|
||||
fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?;
|
||||
fs::create_dir_all(zeroclaw_dir.join("workspace"))
|
||||
.context("Failed to create workspace directory")?;
|
||||
fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
|
||||
let workspace_config_dir = workspace_dir.to_path_buf();
|
||||
if workspace_config_dir.join("config.toml").exists() {
|
||||
return workspace_config_dir;
|
||||
}
|
||||
|
||||
let legacy_config_dir = workspace_dir
|
||||
.parent()
|
||||
.map(|parent| parent.join(".zeroclaw"));
|
||||
if let Some(legacy_dir) = legacy_config_dir {
|
||||
if legacy_dir.join("config.toml").exists() {
|
||||
return legacy_dir;
|
||||
}
|
||||
|
||||
if workspace_dir
|
||||
.file_name()
|
||||
.is_some_and(|name| name == std::ffi::OsStr::new("workspace"))
|
||||
{
|
||||
return legacy_dir;
|
||||
}
|
||||
}
|
||||
|
||||
workspace_config_dir
|
||||
}
|
||||
|
||||
fn decrypt_optional_secret(
|
||||
store: &crate::security::SecretStore,
|
||||
value: &mut Option<String>,
|
||||
field_name: &str,
|
||||
) -> Result<()> {
|
||||
if let Some(raw) = value.clone() {
|
||||
if crate::security::SecretStore::is_encrypted(&raw) {
|
||||
*value = Some(
|
||||
store
|
||||
.decrypt(&raw)
|
||||
.with_context(|| format!("Failed to decrypt {field_name}"))?,
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn encrypt_optional_secret(
|
||||
store: &crate::security::SecretStore,
|
||||
value: &mut Option<String>,
|
||||
field_name: &str,
|
||||
) -> Result<()> {
|
||||
if let Some(raw) = value.clone() {
|
||||
if !crate::security::SecretStore::is_encrypted(&raw) {
|
||||
*value = Some(
|
||||
store
|
||||
.encrypt(&raw)
|
||||
.with_context(|| format!("Failed to encrypt {field_name}"))?,
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load_or_init() -> Result<Self> {
|
||||
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
|
||||
let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") {
|
||||
Ok(custom_workspace) if !custom_workspace.is_empty() => {
|
||||
let workspace = PathBuf::from(custom_workspace);
|
||||
(resolve_config_dir_for_workspace(&workspace), workspace)
|
||||
}
|
||||
_ => default_config_and_workspace_dirs()?,
|
||||
};
|
||||
|
||||
let config_path = zeroclaw_dir.join("config.toml");
|
||||
|
||||
fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?;
|
||||
fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?;
|
||||
|
||||
if config_path.exists() {
|
||||
// Warn if config file is world-readable (may contain API keys)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
if let Ok(meta) = fs::metadata(&config_path) {
|
||||
if meta.permissions().mode() & 0o004 != 0 {
|
||||
tracing::warn!(
|
||||
"Config file {:?} is world-readable (mode {:o}). \
|
||||
Consider restricting with: chmod 600 {:?}",
|
||||
config_path,
|
||||
meta.permissions().mode() & 0o777,
|
||||
config_path,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let contents =
|
||||
fs::read_to_string(&config_path).context("Failed to read config file")?;
|
||||
let mut config: Config =
|
||||
toml::from_str(&contents).context("Failed to parse config file")?;
|
||||
// Set computed paths that are skipped during serialization
|
||||
config.config_path = config_path.clone();
|
||||
config.workspace_dir = zeroclaw_dir.join("workspace");
|
||||
config.workspace_dir = workspace_dir;
|
||||
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
|
||||
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.composio.api_key,
|
||||
"config.composio.api_key",
|
||||
)?;
|
||||
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.browser.computer_use.api_key,
|
||||
"config.browser.computer_use.api_key",
|
||||
)?;
|
||||
|
||||
for agent in config.agents.values_mut() {
|
||||
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||
}
|
||||
config.apply_env_overrides();
|
||||
Ok(config)
|
||||
} else {
|
||||
let mut config = Config::default();
|
||||
config.config_path = config_path.clone();
|
||||
config.workspace_dir = zeroclaw_dir.join("workspace");
|
||||
config.workspace_dir = workspace_dir;
|
||||
config.save()?;
|
||||
|
||||
// Restrict permissions on newly created config file (may contain API keys)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
config.apply_env_overrides();
|
||||
Ok(config)
|
||||
}
|
||||
|
|
@ -1732,23 +1869,29 @@ impl Config {
|
|||
}
|
||||
|
||||
pub fn save(&self) -> Result<()> {
|
||||
// Encrypt agent API keys before serialization
|
||||
// Encrypt secrets before serialization
|
||||
let mut config_to_save = self.clone();
|
||||
let zeroclaw_dir = self
|
||||
.config_path
|
||||
.parent()
|
||||
.context("Config path must have a parent directory")?;
|
||||
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
||||
|
||||
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.composio.api_key,
|
||||
"config.composio.api_key",
|
||||
)?;
|
||||
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.browser.computer_use.api_key,
|
||||
"config.browser.computer_use.api_key",
|
||||
)?;
|
||||
|
||||
for agent in config_to_save.agents.values_mut() {
|
||||
if let Some(ref plaintext_key) = agent.api_key {
|
||||
if !crate::security::SecretStore::is_encrypted(plaintext_key) {
|
||||
agent.api_key = Some(
|
||||
store
|
||||
.encrypt(plaintext_key)
|
||||
.context("Failed to encrypt agent API key")?,
|
||||
);
|
||||
}
|
||||
}
|
||||
encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||
}
|
||||
|
||||
let toml_str =
|
||||
|
|
@ -1949,6 +2092,7 @@ default_temperature = 0.7
|
|||
workspace_dir: PathBuf::from("/tmp/test/workspace"),
|
||||
config_path: PathBuf::from("/tmp/test/config.toml"),
|
||||
api_key: Some("sk-test-key".into()),
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".into()),
|
||||
default_model: Some("gpt-4o".into()),
|
||||
default_temperature: 0.5,
|
||||
|
|
@ -2091,6 +2235,7 @@ tool_dispatcher = "xml"
|
|||
workspace_dir: dir.join("workspace"),
|
||||
config_path: config_path.clone(),
|
||||
api_key: Some("sk-roundtrip".into()),
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".into()),
|
||||
default_model: Some("test-model".into()),
|
||||
default_temperature: 0.9,
|
||||
|
|
@ -2123,13 +2268,82 @@ tool_dispatcher = "xml"
|
|||
|
||||
let contents = fs::read_to_string(&config_path).unwrap();
|
||||
let loaded: Config = toml::from_str(&contents).unwrap();
|
||||
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
|
||||
assert!(loaded
|
||||
.api_key
|
||||
.as_deref()
|
||||
.is_some_and(crate::security::SecretStore::is_encrypted));
|
||||
let store = crate::security::SecretStore::new(&dir, true);
|
||||
let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
|
||||
assert_eq!(decrypted, "sk-roundtrip");
|
||||
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
|
||||
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
|
||||
|
||||
let _ = fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_save_encrypts_nested_credentials() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_test_nested_credentials_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
fs::create_dir_all(&dir).unwrap();
|
||||
|
||||
let mut config = Config::default();
|
||||
config.workspace_dir = dir.join("workspace");
|
||||
config.config_path = dir.join("config.toml");
|
||||
config.api_key = Some("root-credential".into());
|
||||
config.composio.api_key = Some("composio-credential".into());
|
||||
config.browser.computer_use.api_key = Some("browser-credential".into());
|
||||
|
||||
config.agents.insert(
|
||||
"worker".into(),
|
||||
DelegateAgentConfig {
|
||||
provider: "openrouter".into(),
|
||||
model: "model-test".into(),
|
||||
system_prompt: None,
|
||||
api_key: Some("agent-credential".into()),
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
},
|
||||
);
|
||||
|
||||
config.save().unwrap();
|
||||
|
||||
let contents = fs::read_to_string(config.config_path.clone()).unwrap();
|
||||
let stored: Config = toml::from_str(&contents).unwrap();
|
||||
let store = crate::security::SecretStore::new(&dir, true);
|
||||
|
||||
let root_encrypted = stored.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
|
||||
assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
|
||||
|
||||
let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(
|
||||
composio_encrypted
|
||||
));
|
||||
assert_eq!(
|
||||
store.decrypt(composio_encrypted).unwrap(),
|
||||
"composio-credential"
|
||||
);
|
||||
|
||||
let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(
|
||||
browser_encrypted
|
||||
));
|
||||
assert_eq!(
|
||||
store.decrypt(browser_encrypted).unwrap(),
|
||||
"browser-credential"
|
||||
);
|
||||
|
||||
let worker = stored.agents.get("worker").unwrap();
|
||||
let worker_encrypted = worker.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
|
||||
assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
|
||||
|
||||
let _ = fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_save_atomic_cleanup() {
|
||||
let dir =
|
||||
|
|
@ -2182,6 +2396,7 @@ tool_dispatcher = "xml"
|
|||
guild_id: Some("12345".into()),
|
||||
allowed_users: vec![],
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -2196,6 +2411,7 @@ tool_dispatcher = "xml"
|
|||
guild_id: None,
|
||||
allowed_users: vec![],
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -2818,6 +3034,96 @@ default_temperature = 0.7
|
|||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_or_init_workspace_override_uses_workspace_root_for_config() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
let temp_home =
|
||||
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||
let workspace_dir = temp_home.join("profile-a");
|
||||
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
std::env::set_var("HOME", &temp_home);
|
||||
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||
|
||||
let config = Config::load_or_init().unwrap();
|
||||
|
||||
assert_eq!(config.workspace_dir, workspace_dir);
|
||||
assert_eq!(config.config_path, workspace_dir.join("config.toml"));
|
||||
assert!(workspace_dir.join("config.toml").exists());
|
||||
|
||||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
if let Some(home) = original_home {
|
||||
std::env::set_var("HOME", home);
|
||||
} else {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
let _ = fs::remove_dir_all(temp_home);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_or_init_workspace_suffix_uses_legacy_config_layout() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
let temp_home =
|
||||
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||
let workspace_dir = temp_home.join("workspace");
|
||||
let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml");
|
||||
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
std::env::set_var("HOME", &temp_home);
|
||||
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||
|
||||
let config = Config::load_or_init().unwrap();
|
||||
|
||||
assert_eq!(config.workspace_dir, workspace_dir);
|
||||
assert_eq!(config.config_path, legacy_config_path);
|
||||
assert!(config.config_path.exists());
|
||||
|
||||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
if let Some(home) = original_home {
|
||||
std::env::set_var("HOME", home);
|
||||
} else {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
let _ = fs::remove_dir_all(temp_home);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_or_init_workspace_override_keeps_existing_legacy_config() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
let temp_home =
|
||||
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||
let workspace_dir = temp_home.join("custom-workspace");
|
||||
let legacy_config_dir = temp_home.join(".zeroclaw");
|
||||
let legacy_config_path = legacy_config_dir.join("config.toml");
|
||||
|
||||
fs::create_dir_all(&legacy_config_dir).unwrap();
|
||||
fs::write(
|
||||
&legacy_config_path,
|
||||
r#"default_temperature = 0.7
|
||||
default_model = "legacy-model"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
std::env::set_var("HOME", &temp_home);
|
||||
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||
|
||||
let config = Config::load_or_init().unwrap();
|
||||
|
||||
assert_eq!(config.workspace_dir, workspace_dir);
|
||||
assert_eq!(config.config_path, legacy_config_path);
|
||||
assert_eq!(config.default_model.as_deref(), Some("legacy-model"));
|
||||
|
||||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
if let Some(home) = original_home {
|
||||
std::env::set_var("HOME", home);
|
||||
} else {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
let _ = fs::remove_dir_all(temp_home);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_override_empty_values_ignored() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
|
|
@ -2975,4 +3281,118 @@ default_temperature = 0.7
|
|||
assert_eq!(parsed.boards[0].board, "nucleo-f401re");
|
||||
assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_serde() {
|
||||
let lc = LarkConfig {
|
||||
app_id: "cli_123456".into(),
|
||||
app_secret: "secret_abc".into(),
|
||||
encrypt_key: Some("encrypt_key".into()),
|
||||
verification_token: Some("verify_token".into()),
|
||||
allowed_users: vec!["user_123".into(), "user_456".into()],
|
||||
use_feishu: true,
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.app_id, "cli_123456");
|
||||
assert_eq!(parsed.app_secret, "secret_abc");
|
||||
assert_eq!(parsed.encrypt_key.as_deref(), Some("encrypt_key"));
|
||||
assert_eq!(parsed.verification_token.as_deref(), Some("verify_token"));
|
||||
assert_eq!(parsed.allowed_users.len(), 2);
|
||||
assert!(parsed.use_feishu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_toml_roundtrip() {
|
||||
let lc = LarkConfig {
|
||||
app_id: "cli_123456".into(),
|
||||
app_secret: "secret_abc".into(),
|
||||
encrypt_key: Some("encrypt_key".into()),
|
||||
verification_token: Some("verify_token".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
assert_eq!(parsed.app_id, "cli_123456");
|
||||
assert_eq!(parsed.app_secret, "secret_abc");
|
||||
assert!(!parsed.use_feishu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_deserializes_without_optional_fields() {
|
||||
let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.encrypt_key.is_none());
|
||||
assert!(parsed.verification_token.is_none());
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert!(!parsed.use_feishu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_defaults_to_lark_endpoint() {
|
||||
let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(
|
||||
!parsed.use_feishu,
|
||||
"use_feishu should default to false (Lark)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_with_wildcard_allowed_users() {
|
||||
let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(parsed.allowed_users, vec!["*"]);
|
||||
}
|
||||
|
||||
// ── Config file permission hardening (Unix only) ───────────────
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn new_config_file_has_restricted_permissions() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
|
||||
// Create a config and save it
|
||||
let mut config = Config::default();
|
||||
config.config_path = config_path.clone();
|
||||
config.save().unwrap();
|
||||
|
||||
// Apply the same permission logic as load_or_init
|
||||
let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600));
|
||||
|
||||
let meta = std::fs::metadata(&config_path).unwrap();
|
||||
let mode = meta.permissions().mode() & 0o777;
|
||||
assert_eq!(
|
||||
mode, 0o600,
|
||||
"New config file should be owner-only (0600), got {mode:o}"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn world_readable_config_is_detectable() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
|
||||
// Create a config file with intentionally loose permissions
|
||||
std::fs::write(&config_path, "# test config").unwrap();
|
||||
std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap();
|
||||
|
||||
let meta = std::fs::metadata(&config_path).unwrap();
|
||||
let mode = meta.permissions().mode();
|
||||
assert!(
|
||||
mode & 0o004 != 0,
|
||||
"Test setup: file should be world-readable (mode {mode:o})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
|
|||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
);
|
||||
channel.send(output, target).await?;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool {
|
|||
|| config.channels_config.matrix.is_some()
|
||||
|| config.channels_config.whatsapp.is_some()
|
||||
|| config.channels_config.email.is_some()
|
||||
|| config.channels_config.lark.is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -49,6 +49,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
|
|||
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
||||
}
|
||||
|
||||
fn hash_webhook_secret(value: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
let digest = Sha256::digest(value.as_bytes());
|
||||
hex::encode(digest)
|
||||
}
|
||||
|
||||
/// How often the rate limiter sweeps stale IP entries from its map.
|
||||
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
||||
|
||||
|
|
@ -178,7 +185,8 @@ pub struct AppState {
|
|||
pub temperature: f64,
|
||||
pub mem: Arc<dyn Memory>,
|
||||
pub auto_save: bool,
|
||||
pub webhook_secret: Option<Arc<str>>,
|
||||
/// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
|
||||
pub webhook_secret_hash: Option<Arc<str>>,
|
||||
pub pairing: Arc<PairingGuard>,
|
||||
pub rate_limiter: Arc<GatewayRateLimiter>,
|
||||
pub idempotency_store: Arc<IdempotencyStore>,
|
||||
|
|
@ -208,6 +216,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
)?);
|
||||
let model = config
|
||||
|
|
@ -251,12 +260,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
&config,
|
||||
));
|
||||
// Extract webhook secret for authentication
|
||||
let webhook_secret: Option<Arc<str>> = config
|
||||
.channels_config
|
||||
.webhook
|
||||
.as_ref()
|
||||
.and_then(|w| w.secret.as_deref())
|
||||
.map(Arc::from);
|
||||
let webhook_secret_hash: Option<Arc<str>> =
|
||||
config.channels_config.webhook.as_ref().and_then(|webhook| {
|
||||
webhook.secret.as_ref().and_then(|raw_secret| {
|
||||
let trimmed_secret = raw_secret.trim();
|
||||
(!trimmed_secret.is_empty())
|
||||
.then(|| Arc::<str>::from(hash_webhook_secret(trimmed_secret)))
|
||||
})
|
||||
});
|
||||
|
||||
// WhatsApp channel (if configured)
|
||||
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
|
||||
|
|
@ -342,9 +353,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
} else {
|
||||
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
||||
}
|
||||
if webhook_secret.is_some() {
|
||||
println!(" 🔒 Webhook secret: ENABLED");
|
||||
}
|
||||
println!(" Press Ctrl+C to stop.\n");
|
||||
|
||||
crate::health::mark_component_ok("gateway");
|
||||
|
|
@ -356,7 +364,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
temperature,
|
||||
mem,
|
||||
auto_save: config.memory.auto_save,
|
||||
webhook_secret,
|
||||
webhook_secret_hash,
|
||||
pairing,
|
||||
rate_limiter,
|
||||
idempotency_store,
|
||||
|
|
@ -482,12 +490,15 @@ async fn handle_webhook(
|
|||
}
|
||||
|
||||
// ── Webhook secret auth (optional, additional layer) ──
|
||||
if let Some(ref secret) = state.webhook_secret {
|
||||
let header_val = headers
|
||||
if let Some(ref secret_hash) = state.webhook_secret_hash {
|
||||
let header_hash = headers
|
||||
.get("X-Webhook-Secret")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
match header_val {
|
||||
Some(val) if constant_time_eq(val, secret.as_ref()) => {}
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(hash_webhook_secret);
|
||||
match header_hash {
|
||||
Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
|
||||
_ => {
|
||||
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
||||
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
||||
|
|
@ -532,7 +543,7 @@ async fn handle_webhook(
|
|||
let key = webhook_memory_key();
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, message, MemoryCategory::Conversation)
|
||||
.store(&key, message, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -685,7 +696,7 @@ async fn handle_whatsapp_message(
|
|||
let key = whatsapp_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation)
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -697,7 +708,7 @@ async fn handle_whatsapp_message(
|
|||
{
|
||||
Ok(response) => {
|
||||
// Send reply via WhatsApp
|
||||
if let Err(e) = wa.send(&response, &msg.sender).await {
|
||||
if let Err(e) = wa.send(&response, &msg.reply_target).await {
|
||||
tracing::error!("Failed to send WhatsApp reply: {e}");
|
||||
}
|
||||
}
|
||||
|
|
@ -706,7 +717,7 @@ async fn handle_whatsapp_message(
|
|||
let _ = wa
|
||||
.send(
|
||||
"Sorry, I couldn't process your message right now.",
|
||||
&msg.sender,
|
||||
&msg.reply_target,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -798,7 +809,9 @@ mod tests {
|
|||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1);
|
||||
guard.1 = Instant::now()
|
||||
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
||||
.unwrap();
|
||||
// Clear timestamps for ip-2 and ip-3 to simulate stale entries
|
||||
guard.0.get_mut("ip-2").unwrap().clear();
|
||||
guard.0.get_mut("ip-3").unwrap().clear();
|
||||
|
|
@ -848,6 +861,7 @@ mod tests {
|
|||
let msg = ChannelMessage {
|
||||
id: "wamid-123".into(),
|
||||
sender: "+1234567890".into(),
|
||||
reply_target: "+1234567890".into(),
|
||||
content: "hello".into(),
|
||||
channel: "whatsapp".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -871,11 +885,17 @@ mod tests {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -886,6 +906,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -938,6 +959,7 @@ mod tests {
|
|||
key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.keys
|
||||
.lock()
|
||||
|
|
@ -946,7 +968,12 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -957,6 +984,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -991,7 +1019,7 @@ mod tests {
|
|||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret: None,
|
||||
webhook_secret_hash: None,
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
|
|
@ -1039,7 +1067,7 @@ mod tests {
|
|||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: true,
|
||||
webhook_secret: None,
|
||||
webhook_secret_hash: None,
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
|
|
@ -1077,6 +1105,125 @@ mod tests {
|
|||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn webhook_secret_hash_is_deterministic_and_nonempty() {
|
||||
let one = hash_webhook_secret("secret-value");
|
||||
let two = hash_webhook_secret("secret-value");
|
||||
let other = hash_webhook_secret("other-value");
|
||||
|
||||
assert_eq!(one, two);
|
||||
assert_ne!(one, other);
|
||||
assert_eq!(one.len(), 64);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_secret_hash_rejects_missing_header() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
|
||||
let state = AppState {
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
};
|
||||
|
||||
let response = handle_webhook(
|
||||
State(state),
|
||||
HeaderMap::new(),
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_secret_hash_rejects_invalid_header() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
|
||||
let state = AppState {
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
|
||||
|
||||
let response = handle_webhook(
|
||||
State(state),
|
||||
headers,
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_secret_hash_accepts_valid_header() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
|
||||
let state = AppState {
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
|
||||
|
||||
let response = handle_webhook(
|
||||
State(state),
|
||||
headers,
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════
|
||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||
// ══════════════════════════════════════════════════════════
|
||||
|
|
|
|||
40
src/main.rs
40
src/main.rs
|
|
@ -34,8 +34,8 @@
|
|||
|
||||
use anyhow::{bail, Result};
|
||||
use clap::{Parser, Subcommand};
|
||||
use tracing::{info, Level};
|
||||
use tracing_subscriber::FmtSubscriber;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
mod agent;
|
||||
mod channels;
|
||||
|
|
@ -147,24 +147,24 @@ enum Commands {
|
|||
|
||||
/// Start the gateway server (webhooks, websockets)
|
||||
Gateway {
|
||||
/// Port to listen on (use 0 for random available port)
|
||||
#[arg(short, long, default_value = "8080")]
|
||||
port: u16,
|
||||
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
|
||||
#[arg(short, long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Host to bind to
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
/// Host to bind to; defaults to config gateway.host
|
||||
#[arg(long)]
|
||||
host: Option<String>,
|
||||
},
|
||||
|
||||
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
|
||||
Daemon {
|
||||
/// Port to listen on (use 0 for random available port)
|
||||
#[arg(short, long, default_value = "8080")]
|
||||
port: u16,
|
||||
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
|
||||
#[arg(short, long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Host to bind to
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
/// Host to bind to; defaults to config gateway.host
|
||||
#[arg(long)]
|
||||
host: Option<String>,
|
||||
},
|
||||
|
||||
/// Manage OS service lifecycle (launchd/systemd user service)
|
||||
|
|
@ -367,9 +367,11 @@ async fn main() -> Result<()> {
|
|||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize logging
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_max_level(Level::INFO)
|
||||
// Initialize logging - respects RUST_LOG env var, defaults to INFO
|
||||
let subscriber = fmt::Subscriber::builder()
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||
)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
|
|
@ -434,6 +436,8 @@ async fn main() -> Result<()> {
|
|||
.map(|_| ()),
|
||||
|
||||
Commands::Gateway { port, host } => {
|
||||
let port = port.unwrap_or(config.gateway.port);
|
||||
let host = host.unwrap_or_else(|| config.gateway.host.clone());
|
||||
if port == 0 {
|
||||
info!("🚀 Starting ZeroClaw Gateway on {host} (random port)");
|
||||
} else {
|
||||
|
|
@ -443,6 +447,8 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
Commands::Daemon { port, host } => {
|
||||
let port = port.unwrap_or(config.gateway.port);
|
||||
let host = host.unwrap_or_else(|| config.gateway.host.clone());
|
||||
if port == 0 {
|
||||
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ pub enum MemoryBackendKind {
|
|||
Unknown,
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||
pub struct MemoryBackendProfile {
|
||||
pub key: &'static str,
|
||||
|
|
|
|||
|
|
@ -502,10 +502,10 @@ mod tests {
|
|||
let workspace = tmp.path();
|
||||
|
||||
let mem = SqliteMemory::new(workspace).unwrap();
|
||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation)
|
||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("core_keep", "durable", MemoryCategory::Core)
|
||||
mem.store("core_keep", "durable", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
drop(mem);
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@ pub struct LucidMemory {
|
|||
impl LucidMemory {
|
||||
const DEFAULT_LUCID_CMD: &'static str = "lucid";
|
||||
const DEFAULT_TOKEN_BUDGET: usize = 200;
|
||||
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120;
|
||||
// Lucid CLI cold start can exceed 120ms on slower machines, which causes
|
||||
// avoidable fallback to local-only memory and premature cooldown.
|
||||
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500;
|
||||
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
|
||||
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
|
||||
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
|
||||
|
|
@ -74,6 +76,7 @@ impl LucidMemory {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn with_options(
|
||||
workspace_dir: &Path,
|
||||
local: SqliteMemory,
|
||||
|
|
@ -307,14 +310,22 @@ impl Memory for LucidMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.local.store(key, content, category.clone()).await?;
|
||||
self.local
|
||||
.store(key, content, category.clone(), session_id)
|
||||
.await?;
|
||||
self.sync_to_lucid_async(key, content, &category).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit).await?;
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit, session_id).await?;
|
||||
if limit == 0
|
||||
|| local_results.len() >= limit
|
||||
|| local_results.len() >= self.local_hit_threshold
|
||||
|
|
@ -351,8 +362,12 @@ impl Memory for LucidMemory {
|
|||
self.local.get(key).await
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.local.list(category).await
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.local.list(category, session_id).await
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
|
|
@ -396,6 +411,38 @@ EOF
|
|||
exit 0
|
||||
fi
|
||||
|
||||
echo "unsupported command" >&2
|
||||
exit 1
|
||||
"#;
|
||||
|
||||
fs::write(&script_path, script).unwrap();
|
||||
let mut perms = fs::metadata(&script_path).unwrap().permissions();
|
||||
perms.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, perms).unwrap();
|
||||
script_path.display().to_string()
|
||||
}
|
||||
|
||||
fn write_delayed_lucid_script(dir: &Path) -> String {
|
||||
let script_path = dir.join("delayed-lucid.sh");
|
||||
let script = r#"#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
if [[ "${1:-}" == "store" ]]; then
|
||||
echo '{"success":true,"id":"mem_1"}'
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "${1:-}" == "context" ]]; then
|
||||
# Simulate a cold start that is slower than 120ms but below the 500ms timeout.
|
||||
sleep 0.2
|
||||
cat <<'EOF'
|
||||
<lucid-context>
|
||||
- [decision] Delayed token refresh guidance
|
||||
</lucid-context>
|
||||
EOF
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "unsupported command" >&2
|
||||
exit 1
|
||||
"#;
|
||||
|
|
@ -449,7 +496,7 @@ exit 1
|
|||
cmd,
|
||||
200,
|
||||
3,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(2),
|
||||
)
|
||||
|
|
@ -468,7 +515,7 @@ exit 1
|
|||
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
|
||||
|
||||
memory
|
||||
.store("lang", "User prefers Rust", MemoryCategory::Core)
|
||||
.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -483,6 +530,30 @@ exit 1
|
|||
let fake_cmd = write_fake_lucid_script(tmp.path());
|
||||
let memory = test_memory(tmp.path(), fake_cmd);
|
||||
|
||||
memory
|
||||
.store(
|
||||
"local_note",
|
||||
"Local sqlite auth fallback note",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_handles_lucid_cold_start_delay_within_timeout() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let delayed_cmd = write_delayed_lucid_script(tmp.path());
|
||||
let memory = test_memory(tmp.path(), delayed_cmd);
|
||||
|
||||
memory
|
||||
.store(
|
||||
"local_note",
|
||||
|
|
@ -497,7 +568,9 @@ exit 1
|
|||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Delayed token refresh guidance")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -513,17 +586,22 @@ exit 1
|
|||
probe_cmd,
|
||||
200,
|
||||
1,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(2),
|
||||
);
|
||||
|
||||
memory
|
||||
.store("pref", "Rust should stay local-first", MemoryCategory::Core)
|
||||
.store(
|
||||
"pref",
|
||||
"Rust should stay local-first",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("rust", 5).await.unwrap();
|
||||
let entries = memory.recall("rust", 5, None).await.unwrap();
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Rust should stay local-first")));
|
||||
|
|
@ -578,13 +656,13 @@ exit 1
|
|||
failing_cmd,
|
||||
200,
|
||||
99,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let first = memory.recall("auth", 5).await.unwrap();
|
||||
let second = memory.recall("auth", 5).await.unwrap();
|
||||
let first = memory.recall("auth", 5, None).await.unwrap();
|
||||
let second = memory.recall("auth", 5, None).await.unwrap();
|
||||
|
||||
assert!(first.is_empty());
|
||||
assert!(second.is_empty());
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let entry = format!("- **{key}**: {content}");
|
||||
let path = match category {
|
||||
|
|
@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
|
|||
self.append_to_file(&path, &entry).await
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
let query_lower = query.to_lowercase();
|
||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
|
|
@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
|
|||
.find(|e| e.key == key || e.content.contains(key)))
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
match category {
|
||||
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
||||
|
|
@ -243,7 +253,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_store_core() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
||||
|
|
@ -253,7 +263,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_store_daily() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily)
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let path = mem.daily_path();
|
||||
|
|
@ -264,17 +274,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_recall_keyword() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core)
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core)
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
|
|
@ -284,18 +294,20 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_recall_no_match() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_count() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core)
|
||||
mem.store("a", "first", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let count = mem.count().await.unwrap();
|
||||
|
|
@ -305,24 +317,24 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_list_by_category() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "core fact", MemoryCategory::Core)
|
||||
mem.store("a", "core fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "daily note", MemoryCategory::Daily)
|
||||
mem.store("b", "daily note", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_forget_is_noop() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "permanent", MemoryCategory::Core)
|
||||
mem.store("a", "permanent", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let removed = mem.forget("a").await.unwrap();
|
||||
|
|
@ -332,7 +344,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_empty_recall() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
let results = mem.recall("anything", 10).await.unwrap();
|
||||
let results = mem.recall("anything", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,11 +25,17 @@ impl Memory for NoneMemory {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -37,7 +43,11 @@ impl Memory for NoneMemory {
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -62,11 +72,14 @@ mod tests {
|
|||
async fn none_memory_is_noop() {
|
||||
let memory = NoneMemory::new();
|
||||
|
||||
memory.store("k", "v", MemoryCategory::Core).await.unwrap();
|
||||
memory
|
||||
.store("k", "v", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(memory.get("k").await.unwrap().is_none());
|
||||
assert!(memory.recall("k", 10).await.unwrap().is_empty());
|
||||
assert!(memory.list(None).await.unwrap().is_empty());
|
||||
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
|
||||
assert!(memory.list(None, None).await.unwrap().is_empty());
|
||||
assert!(!memory.forget("k").await.unwrap());
|
||||
assert_eq!(memory.count().await.unwrap(), 0);
|
||||
assert!(memory.health_check().await);
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ impl ResponseCache {
|
|||
|row| row.get(0),
|
||||
)?;
|
||||
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok((count as usize, hits as u64, tokens_saved as u64))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,19 @@ impl SqliteMemory {
|
|||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
|
||||
)?;
|
||||
|
||||
// Migration: add session_id column if not present (safe to run repeatedly)
|
||||
let has_session_id: bool = conn
|
||||
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
|
||||
.query_row([], |row| row.get::<_, String>(0))?
|
||||
.contains("session_id");
|
||||
if !has_session_id {
|
||||
conn.execute_batch(
|
||||
"ALTER TABLE memories ADD COLUMN session_id TEXT;
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -361,6 +374,7 @@ impl Memory for SqliteMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Compute embedding (async, before lock)
|
||||
let embedding_bytes = self
|
||||
|
|
@ -377,20 +391,26 @@ impl Memory for SqliteMemory {
|
|||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
embedding = excluded.embedding,
|
||||
updated_at = excluded.updated_at",
|
||||
params![id, key, content, cat, embedding_bytes, now, now],
|
||||
updated_at = excluded.updated_at,
|
||||
session_id = excluded.session_id",
|
||||
params![id, key, content, cat, embedding_bytes, now, now, session_id],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if query.trim().is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
|
@ -439,7 +459,7 @@ impl Memory for SqliteMemory {
|
|||
let mut results = Vec::new();
|
||||
for scored in &merged {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
||||
)?;
|
||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||
Ok(MemoryEntry {
|
||||
|
|
@ -448,10 +468,16 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(f64::from(scored.final_score)),
|
||||
})
|
||||
}) {
|
||||
// Filter by session_id if requested
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
|
|
@ -470,7 +496,7 @@ impl Memory for SqliteMemory {
|
|||
.collect();
|
||||
let where_clause = conditions.join(" OR ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE {where_clause}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{}",
|
||||
|
|
@ -493,12 +519,18 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(1.0),
|
||||
})
|
||||
})?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -514,7 +546,7 @@ impl Memory for SqliteMemory {
|
|||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||
)?;
|
||||
|
||||
let mut rows = stmt.query_map(params![key], |row| {
|
||||
|
|
@ -524,7 +556,7 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
})?;
|
||||
|
|
@ -535,7 +567,11 @@ impl Memory for SqliteMemory {
|
|||
}
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
|
|
@ -550,7 +586,7 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
};
|
||||
|
|
@ -558,21 +594,33 @@ impl Memory for SqliteMemory {
|
|||
if let Some(cat) = category {
|
||||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map([], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -632,7 +680,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_store_and_get() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -647,10 +695,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_store_upsert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -662,17 +710,22 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c",
|
||||
"Rust has zero-cost abstractions",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
|
|
@ -682,14 +735,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_multi_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("fast safe", 10).await.unwrap();
|
||||
let results = mem.recall("fast safe", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// Entry with both keywords should score higher
|
||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||
|
|
@ -698,17 +751,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_no_match() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core)
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_forget() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation)
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
|
@ -728,29 +781,37 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_list_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation)
|
||||
mem.store("a", "one", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let all = mem.list(None).await.unwrap();
|
||||
let all = mem.list(None, None).await.unwrap();
|
||||
assert_eq!(all.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_list_by_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily)
|
||||
mem.store("a", "core1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
assert_eq!(core.len(), 2);
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
assert_eq!(daily.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -772,7 +833,7 @@ mod tests {
|
|||
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core)
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -795,7 +856,7 @@ mod tests {
|
|||
];
|
||||
|
||||
for (i, cat) in categories.iter().enumerate() {
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -815,21 +876,28 @@ mod tests {
|
|||
"a",
|
||||
"Rust is a systems programming language",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"b",
|
||||
"Python is great for scripting",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c",
|
||||
"Rust and Rust and Rust everywhere",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
// All results should contain "Rust"
|
||||
for r in &results {
|
||||
|
|
@ -844,17 +912,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_multi_word_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
|
||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
|
||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
|
||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("quick dog", 10).await.unwrap();
|
||||
let results = mem.recall("quick dog", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// "The quick dog runs fast" matches both terms
|
||||
assert!(results[0].content.contains("quick"));
|
||||
|
|
@ -863,16 +931,20 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_empty_query_returns_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
||||
let results = mem.recall("", 10).await.unwrap();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_whitespace_query_returns_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
||||
let results = mem.recall(" ", 10).await.unwrap();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall(" ", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
|
@ -937,9 +1009,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_insert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"test_key",
|
||||
"unique_searchterm_xyz",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let conn = mem.conn.lock();
|
||||
let count: i64 = conn
|
||||
|
|
@ -955,9 +1032,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_delete() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"del_key",
|
||||
"deletable_content_abc",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("del_key").await.unwrap();
|
||||
|
||||
let conn = mem.conn.lock();
|
||||
|
|
@ -974,10 +1056,15 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_update() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("upd_key", "original_content_111", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
|
||||
mem.store(
|
||||
"upd_key",
|
||||
"original_content_111",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1019,10 +1106,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn reindex_rebuilds_fts() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core)
|
||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("r2", "reindex test beta", MemoryCategory::Core)
|
||||
mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1031,7 +1118,7 @@ mod tests {
|
|||
assert_eq!(count, 0);
|
||||
|
||||
// FTS should still work after rebuild
|
||||
let results = mem.recall("reindex", 10).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
}
|
||||
|
||||
|
|
@ -1045,12 +1132,13 @@ mod tests {
|
|||
&format!("k{i}"),
|
||||
&format!("common keyword item {i}"),
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let results = mem.recall("common keyword", 5).await.unwrap();
|
||||
let results = mem.recall("common keyword", 5, None).await.unwrap();
|
||||
assert!(results.len() <= 5);
|
||||
}
|
||||
|
||||
|
|
@ -1059,11 +1147,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_results_have_scores() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("s1", "scored result test", MemoryCategory::Core)
|
||||
mem.store("s1", "scored result test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("scored", 10).await.unwrap();
|
||||
let results = mem.recall("scored", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
for r in &results {
|
||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||
|
|
@ -1075,11 +1163,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_with_quotes_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("q1", "He said hello world", MemoryCategory::Core)
|
||||
mem.store("q1", "He said hello world", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Quotes in query should not crash FTS5
|
||||
let results = mem.recall("\"hello\"", 10).await.unwrap();
|
||||
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
|
||||
// May or may not match depending on FTS5 escaping, but must not error
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
|
@ -1087,31 +1175,34 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_with_asterisk_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core)
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("wild*", 10).await.unwrap();
|
||||
let results = mem.recall("wild*", 10, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_with_parentheses_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("p1", "function call test", MemoryCategory::Core)
|
||||
mem.store("p1", "function call test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("function()", 10).await.unwrap();
|
||||
let results = mem.recall("function()", 10, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_with_sql_injection_attempt() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("safe", "normal content", MemoryCategory::Core)
|
||||
mem.store("safe", "normal content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Should not crash or leak data
|
||||
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
|
||||
let results = mem
|
||||
.recall("'; DROP TABLE memories; --", 10, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
// Table should still exist
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
|
@ -1122,7 +1213,9 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_empty_content() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("empty", "", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("empty").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "");
|
||||
}
|
||||
|
|
@ -1130,7 +1223,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_empty_key() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("", "content for empty key", MemoryCategory::Core)
|
||||
mem.store("", "content for empty key", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("").await.unwrap().unwrap();
|
||||
|
|
@ -1141,7 +1234,7 @@ mod tests {
|
|||
async fn store_very_long_content() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let long_content = "x".repeat(100_000);
|
||||
mem.store("long", &long_content, MemoryCategory::Core)
|
||||
mem.store("long", &long_content, MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("long").await.unwrap().unwrap();
|
||||
|
|
@ -1151,9 +1244,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_unicode_and_emoji() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"emoji_key_🦀",
|
||||
"こんにちは 🚀 Ñoño",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
|
||||
}
|
||||
|
|
@ -1162,7 +1260,7 @@ mod tests {
|
|||
async fn store_content_with_newlines_and_tabs() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
||||
mem.store("whitespace", content, MemoryCategory::Core)
|
||||
mem.store("whitespace", content, MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
||||
|
|
@ -1174,11 +1272,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_single_character_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "x marks the spot", MemoryCategory::Core)
|
||||
mem.store("a", "x marks the spot", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Single char may not match FTS5 but LIKE fallback should work
|
||||
let results = mem.recall("x", 10).await.unwrap();
|
||||
let results = mem.recall("x", 10, None).await.unwrap();
|
||||
// Should not crash; may or may not find results
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
|
@ -1186,23 +1284,23 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_limit_zero() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "some content", MemoryCategory::Core)
|
||||
mem.store("a", "some content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("some", 0).await.unwrap();
|
||||
let results = mem.recall("some", 0, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_limit_one() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "matching content alpha", MemoryCategory::Core)
|
||||
mem.store("a", "matching content alpha", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core)
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("matching content", 1).await.unwrap();
|
||||
let results = mem.recall("matching content", 1, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -1213,21 +1311,22 @@ mod tests {
|
|||
"rust_preferences",
|
||||
"User likes systems programming",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||
let results = mem.recall("rust", 10).await.unwrap();
|
||||
let results = mem.recall("rust", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty(), "Should match by key");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_unicode_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("日本語", 10).await.unwrap();
|
||||
let results = mem.recall("日本語", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
}
|
||||
|
||||
|
|
@ -1238,7 +1337,9 @@ mod tests {
|
|||
let tmp = TempDir::new().unwrap();
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("k1", "v1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
// Open again — init_schema runs again on existing DB
|
||||
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
|
@ -1246,7 +1347,9 @@ mod tests {
|
|||
assert!(entry.is_some());
|
||||
assert_eq!(entry.unwrap().content, "v1");
|
||||
// Store more data — should work fine
|
||||
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
|
||||
mem2.store("k2", "v2", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem2.count().await.unwrap(), 2);
|
||||
}
|
||||
|
||||
|
|
@ -1264,11 +1367,16 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_then_recall_no_ghost_results() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"ghost",
|
||||
"phantom memory content",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("ghost").await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10).await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10, None).await.unwrap();
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"Deleted memory should not appear in recall"
|
||||
|
|
@ -1278,11 +1386,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_and_re_store_same_key() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("cycle", "version 1", MemoryCategory::Core)
|
||||
mem.store("cycle", "version 1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("cycle").await.unwrap();
|
||||
mem.store("cycle", "version 2", MemoryCategory::Core)
|
||||
mem.store("cycle", "version 2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("cycle").await.unwrap().unwrap();
|
||||
|
|
@ -1302,14 +1410,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn reindex_twice_is_safe() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("r1", "reindex data", MemoryCategory::Core)
|
||||
mem.store("r1", "reindex data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.reindex().await.unwrap();
|
||||
let count = mem.reindex().await.unwrap();
|
||||
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||
// Data should still be intact
|
||||
let results = mem.recall("reindex", 10).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -1363,18 +1471,28 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn list_custom_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c3", "other", MemoryCategory::Core)
|
||||
mem.store(
|
||||
"c1",
|
||||
"custom1",
|
||||
MemoryCategory::Custom("project".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c2",
|
||||
"custom2",
|
||||
MemoryCategory::Custom("project".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c3", "other", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let project = mem
|
||||
.list(Some(&MemoryCategory::Custom("project".into())))
|
||||
.list(Some(&MemoryCategory::Custom("project".into())), None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(project.len(), 2);
|
||||
|
|
@ -1383,7 +1501,122 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn list_empty_db() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let all = mem.list(None).await.unwrap();
|
||||
let all = mem.list(None, None).await.unwrap();
|
||||
assert!(all.is_empty());
|
||||
}
|
||||
|
||||
// ── Session isolation ─────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_and_recall_with_session_id() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "no session fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Recall with session-a filter returns only session-a entry
|
||||
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_no_session_filter_returns_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "gamma fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Recall without session filter returns all matching entries
|
||||
let results = mem.recall("fact", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cross_session_recall_isolation() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store(
|
||||
"secret",
|
||||
"session A secret data",
|
||||
MemoryCategory::Core,
|
||||
Some("sess-a"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Session B cannot see session A data
|
||||
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
|
||||
// Session A can see its own data
|
||||
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_with_session_filter() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k4", "none1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// List with session-a filter
|
||||
let results = mem.list(None, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|e| e.session_id.as_deref() == Some("sess-a")));
|
||||
|
||||
// List with session-a + category filter
|
||||
let results = mem
|
||||
.list(Some(&MemoryCategory::Core), Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn schema_migration_idempotent_on_reopen() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
// First open: creates schema + migration
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Second open: migration runs again but is idempotent
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
|
|||
/// Backend name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Store a memory entry
|
||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
||||
-> anyhow::Result<()>;
|
||||
/// Store a memory entry, optionally scoped to a session
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Recall memories matching a query (keyword search)
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Get a specific memory by key
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||
|
||||
/// List all memory keys, optionally filtered by category
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
/// List all memory keys, optionally filtered by category and/or session
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Remove a memory by key
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ async fn migrate_openclaw_memory(
|
|||
stats.renamed_conflicts += 1;
|
||||
}
|
||||
|
||||
memory.store(&key, &entry.content, entry.category).await?;
|
||||
memory
|
||||
.store(&key, &entry.content, entry.category, None)
|
||||
.await?;
|
||||
stats.imported += 1;
|
||||
}
|
||||
|
||||
|
|
@ -488,7 +490,7 @@ mod tests {
|
|||
// Existing target memory
|
||||
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
||||
target_mem
|
||||
.store("k", "new value", MemoryCategory::Core)
|
||||
.store("k", "new value", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -510,7 +512,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let all = target_mem.list(None).await.unwrap();
|
||||
let all = target_mem.list(None, None).await.unwrap();
|
||||
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
|
||||
assert!(all
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -48,9 +48,10 @@ impl Observer for LogObserver {
|
|||
ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used,
|
||||
cost_usd,
|
||||
} => {
|
||||
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
|
||||
info!(duration_ms = ms, tokens = ?tokens_used, "agent.end");
|
||||
info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end");
|
||||
}
|
||||
ObserverEvent::ToolCallStart { tool } => {
|
||||
info!(tool = %tool, "tool.start");
|
||||
|
|
@ -133,10 +134,12 @@ mod tests {
|
|||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(500),
|
||||
tokens_used: Some(100),
|
||||
cost_usd: Some(0.0015),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||
tool: "shell".into(),
|
||||
|
|
|
|||
|
|
@ -48,10 +48,12 @@ mod tests {
|
|||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(100),
|
||||
tokens_used: Some(42),
|
||||
cost_usd: Some(0.001),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||
tool: "shell".into(),
|
||||
|
|
|
|||
|
|
@ -227,6 +227,7 @@ impl Observer for OtelObserver {
|
|||
ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used,
|
||||
cost_usd,
|
||||
} => {
|
||||
let secs = duration.as_secs_f64();
|
||||
let start_time = SystemTime::now()
|
||||
|
|
@ -243,6 +244,9 @@ impl Observer for OtelObserver {
|
|||
if let Some(t) = tokens_used {
|
||||
span.set_attribute(KeyValue::new("tokens_used", *t as i64));
|
||||
}
|
||||
if let Some(c) = cost_usd {
|
||||
span.set_attribute(KeyValue::new("cost_usd", *c));
|
||||
}
|
||||
span.end();
|
||||
|
||||
self.agent_duration.record(secs, &[]);
|
||||
|
|
@ -394,10 +398,12 @@ mod tests {
|
|||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(500),
|
||||
tokens_used: Some(100),
|
||||
cost_usd: Some(0.0015),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||
tool: "shell".into(),
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ pub enum ObserverEvent {
|
|||
AgentEnd {
|
||||
duration: Duration,
|
||||
tokens_used: Option<u64>,
|
||||
cost_usd: Option<f64>,
|
||||
},
|
||||
/// A tool call is about to be executed.
|
||||
ToolCallStart {
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ pub fn run_wizard() -> Result<Config> {
|
|||
} else {
|
||||
Some(api_key)
|
||||
},
|
||||
api_url: None,
|
||||
default_provider: Some(provider),
|
||||
default_model: Some(model),
|
||||
default_temperature: 0.7,
|
||||
|
|
@ -284,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
|||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn run_quick_setup(
|
||||
api_key: Option<&str>,
|
||||
credential_override: Option<&str>,
|
||||
provider: Option<&str>,
|
||||
memory_backend: Option<&str>,
|
||||
) -> Result<Config> {
|
||||
|
|
@ -318,7 +319,8 @@ pub fn run_quick_setup(
|
|||
let config = Config {
|
||||
workspace_dir: workspace_dir.clone(),
|
||||
config_path: config_path.clone(),
|
||||
api_key: api_key.map(String::from),
|
||||
api_key: credential_override.map(String::from),
|
||||
api_url: None,
|
||||
default_provider: Some(provider_name.clone()),
|
||||
default_model: Some(model.clone()),
|
||||
default_temperature: 0.7,
|
||||
|
|
@ -377,7 +379,7 @@ pub fn run_quick_setup(
|
|||
println!(
|
||||
" {} API Key: {}",
|
||||
style("✓").green().bold(),
|
||||
if api_key.is_some() {
|
||||
if credential_override.is_some() {
|
||||
style("set").green()
|
||||
} else {
|
||||
style("not set (use --api-key or edit config.toml)").yellow()
|
||||
|
|
@ -426,7 +428,7 @@ pub fn run_quick_setup(
|
|||
);
|
||||
println!();
|
||||
println!(" {}", style("Next steps:").white().bold());
|
||||
if api_key.is_none() {
|
||||
if credential_override.is_none() {
|
||||
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
|
||||
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
|
||||
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
|
||||
|
|
@ -2269,14 +2271,11 @@ fn setup_memory() -> Result<MemoryConfig> {
|
|||
let backend = backend_key_from_choice(choice);
|
||||
let profile = memory_backend_profile(backend);
|
||||
|
||||
let auto_save = if !profile.auto_save_default {
|
||||
false
|
||||
} else {
|
||||
Confirm::new()
|
||||
let auto_save = profile.auto_save_default
|
||||
&& Confirm::new()
|
||||
.with_prompt(" Auto-save conversations to memory?")
|
||||
.default(true)
|
||||
.interact()?
|
||||
};
|
||||
.interact()?;
|
||||
|
||||
println!(
|
||||
" {} Memory: {} (auto-save: {})",
|
||||
|
|
@ -2587,6 +2586,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
|||
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
||||
allowed_users,
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
});
|
||||
}
|
||||
2 => {
|
||||
|
|
@ -2799,22 +2799,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
|||
.header("Authorization", format!("Bearer {access_token_clone}"))
|
||||
.send()?;
|
||||
let ok = resp.status().is_success();
|
||||
let data: serde_json::Value = resp.json().unwrap_or_default();
|
||||
let user_id = data
|
||||
.get("user_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Ok::<_, reqwest::Error>((ok, user_id))
|
||||
Ok::<_, reqwest::Error>(ok)
|
||||
})
|
||||
.join();
|
||||
match thread_result {
|
||||
Ok(Ok((true, user_id))) => {
|
||||
println!(
|
||||
"\r {} Connected as {user_id} ",
|
||||
style("✅").green().bold()
|
||||
);
|
||||
}
|
||||
Ok(Ok(true)) => println!(
|
||||
"\r {} Connection verified ",
|
||||
style("✅").green().bold()
|
||||
),
|
||||
_ => {
|
||||
println!(
|
||||
"\r {} Connection failed — check homeserver URL and token",
|
||||
|
|
@ -3779,15 +3771,7 @@ fn print_summary(config: &Config) {
|
|||
);
|
||||
|
||||
// Secrets
|
||||
println!(
|
||||
" {} Secrets: {}",
|
||||
style("🔒").cyan(),
|
||||
if config.secrets.encrypt {
|
||||
style("encrypted").green().to_string()
|
||||
} else {
|
||||
style("plaintext").yellow().to_string()
|
||||
}
|
||||
);
|
||||
println!(" {} Secrets: configured", style("🔒").cyan());
|
||||
|
||||
// Gateway
|
||||
println!(
|
||||
|
|
|
|||
|
|
@ -38,6 +38,10 @@ pub fn ensure_arduino_cli() -> Result<()> {
|
|||
anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/");
|
||||
}
|
||||
println!("arduino-cli installed.");
|
||||
if !arduino_cli_available() {
|
||||
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
|
|
@ -54,11 +58,6 @@ pub fn ensure_arduino_cli() -> Result<()> {
|
|||
println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/");
|
||||
anyhow::bail!("arduino-cli not installed.");
|
||||
}
|
||||
|
||||
if !arduino_cli_available() {
|
||||
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensure arduino:avr core is installed.
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ pub struct SerialPeripheral {
|
|||
|
||||
impl SerialPeripheral {
|
||||
/// Create and connect to a serial peripheral.
|
||||
#[allow(clippy::unused_async)]
|
||||
pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> {
|
||||
let path = config
|
||||
.path
|
||||
|
|
|
|||
|
|
@ -106,17 +106,17 @@ struct NativeContentIn {
|
|||
}
|
||||
|
||||
impl AnthropicProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self::with_base_url(api_key, None)
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self::with_base_url(credential, None)
|
||||
}
|
||||
|
||||
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
|
||||
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
|
||||
let base_url = base_url
|
||||
.map(|u| u.trim_end_matches('/'))
|
||||
.unwrap_or("https://api.anthropic.com")
|
||||
.to_string();
|
||||
Self {
|
||||
credential: api_key
|
||||
credential: credential
|
||||
.map(str::trim)
|
||||
.filter(|k| !k.is_empty())
|
||||
.map(ToString::to_string),
|
||||
|
|
@ -410,9 +410,9 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = AnthropicProvider::new(Some("sk-ant-test123"));
|
||||
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||
}
|
||||
|
||||
|
|
@ -431,17 +431,19 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_whitespace_key() {
|
||||
let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
|
||||
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_custom_base_url() {
|
||||
let p =
|
||||
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
|
||||
let p = AnthropicProvider::with_base_url(
|
||||
Some("anthropic-credential"),
|
||||
Some("https://api.example.com"),
|
||||
);
|
||||
assert_eq!(p.base_url, "https://api.example.com");
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
|
|||
pub struct OpenAiCompatibleProvider {
|
||||
pub(crate) name: String,
|
||||
pub(crate) base_url: String,
|
||||
pub(crate) api_key: Option<String>,
|
||||
pub(crate) credential: Option<String>,
|
||||
pub(crate) auth_header: AuthStyle,
|
||||
/// When false, do not fall back to /v1/responses on chat completions 404.
|
||||
/// GLM/Zhipu does not support the responses API.
|
||||
|
|
@ -37,11 +37,16 @@ pub enum AuthStyle {
|
|||
}
|
||||
|
||||
impl OpenAiCompatibleProvider {
|
||||
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
|
||||
pub fn new(
|
||||
name: &str,
|
||||
base_url: &str,
|
||||
credential: Option<&str>,
|
||||
auth_style: AuthStyle,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
auth_header: auth_style,
|
||||
supports_responses_fallback: true,
|
||||
client: Client::builder()
|
||||
|
|
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
|
|||
pub fn new_no_responses_fallback(
|
||||
name: &str,
|
||||
base_url: &str,
|
||||
api_key: Option<&str>,
|
||||
credential: Option<&str>,
|
||||
auth_style: AuthStyle,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
auth_header: auth_style,
|
||||
supports_responses_fallback: false,
|
||||
client: Client::builder()
|
||||
|
|
@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
|
|||
fn apply_auth_header(
|
||||
&self,
|
||||
req: reqwest::RequestBuilder,
|
||||
api_key: &str,
|
||||
credential: &str,
|
||||
) -> reqwest::RequestBuilder {
|
||||
match &self.auth_header {
|
||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
|
||||
AuthStyle::XApiKey => req.header("x-api-key", api_key),
|
||||
AuthStyle::Custom(header) => req.header(header, api_key),
|
||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
|
||||
AuthStyle::XApiKey => req.header("x-api-key", credential),
|
||||
AuthStyle::Custom(header) => req.header(header, credential),
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_via_responses(
|
||||
&self,
|
||||
api_key: &str,
|
||||
credential: &str,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
|
|
@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
|
|||
let url = self.responses_url();
|
||||
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
|
|
@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
let url = self.chat_completions_url();
|
||||
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||
return self
|
||||
.chat_via_responses(api_key, system_prompt, message, model)
|
||||
.chat_via_responses(credential, system_prompt, message, model)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
|
|
@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
|
|
@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
let url = self.chat_completions_url();
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
if let Some(user_msg) = last_user {
|
||||
return self
|
||||
.chat_via_responses(
|
||||
api_key,
|
||||
credential,
|
||||
system.map(|m| m.content.as_str()),
|
||||
&user_msg.content,
|
||||
model,
|
||||
|
|
@ -791,16 +796,20 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
|
||||
let p = make_provider(
|
||||
"venice",
|
||||
"https://api.venice.ai",
|
||||
Some("venice-test-credential"),
|
||||
);
|
||||
assert_eq!(p.name, "venice");
|
||||
assert_eq!(p.base_url, "https://api.venice.ai");
|
||||
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
|
||||
assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = make_provider("test", "https://example.com", None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -894,6 +903,7 @@ mod tests {
|
|||
make_provider("Groq", "https://api.groq.com/openai", None),
|
||||
make_provider("Mistral", "https://api.mistral.ai", None),
|
||||
make_provider("xAI", "https://api.x.ai", None),
|
||||
make_provider("Astrai", "https://as-trai.com/v1", None),
|
||||
];
|
||||
|
||||
for p in providers {
|
||||
|
|
|
|||
705
src/providers/copilot.rs
Normal file
705
src/providers/copilot.rs
Normal file
|
|
@ -0,0 +1,705 @@
|
|||
//! GitHub Copilot provider with OAuth device-flow authentication.
|
||||
//!
|
||||
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
|
||||
//! then exchanges the OAuth token for short-lived Copilot API keys.
|
||||
//! Tokens are cached to disk and auto-refreshed.
|
||||
//!
|
||||
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
|
||||
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
|
||||
//! and other third-party Copilot integrations. The Copilot token endpoint is
|
||||
//! private; there is no public OAuth scope or app registration for it.
|
||||
//! GitHub could change or revoke this at any time, which would break all
|
||||
//! third-party integrations simultaneously.
|
||||
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
/// GitHub OAuth client ID for Copilot (VS Code extension).
|
||||
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
|
||||
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
|
||||
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
|
||||
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
|
||||
const DEFAULT_API: &str = "https://api.githubcopilot.com";
|
||||
|
||||
// ── Token types ──────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default = "default_interval")]
|
||||
interval: u64,
|
||||
#[serde(default = "default_expires_in")]
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
fn default_interval() -> u64 {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_expires_in() -> u64 {
|
||||
900
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AccessTokenResponse {
|
||||
access_token: Option<String>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiKeyInfo {
|
||||
token: String,
|
||||
expires_at: i64,
|
||||
#[serde(default)]
|
||||
endpoints: Option<ApiEndpoints>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiEndpoints {
|
||||
api: Option<String>,
|
||||
}
|
||||
|
||||
struct CachedApiKey {
|
||||
token: String,
|
||||
api_endpoint: String,
|
||||
expires_at: i64,
|
||||
}
|
||||
|
||||
// ── Chat completions types ───────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiChatRequest {
|
||||
model: String,
|
||||
messages: Vec<ApiMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolSpec {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: NativeToolFunctionSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolFunctionSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
function: NativeFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
// ── Provider ─────────────────────────────────────────────────────
|
||||
|
||||
/// GitHub Copilot provider with automatic OAuth and token refresh.
|
||||
///
|
||||
/// On first use, prompts the user to visit github.com/login/device.
|
||||
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
|
||||
/// automatically.
|
||||
pub struct CopilotProvider {
|
||||
github_token: Option<String>,
|
||||
/// Mutex ensures only one caller refreshes tokens at a time,
|
||||
/// preventing duplicate device flow prompts or redundant API calls.
|
||||
refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
|
||||
http: Client,
|
||||
token_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl CopilotProvider {
|
||||
pub fn new(github_token: Option<&str>) -> Self {
|
||||
let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
|
||||
.map(|dir| dir.config_dir().join("copilot"))
|
||||
.unwrap_or_else(|| {
|
||||
// Fall back to a user-specific temp directory to avoid
|
||||
// shared-directory symlink attacks.
|
||||
let user = std::env::var("USER")
|
||||
.or_else(|_| std::env::var("USERNAME"))
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
|
||||
});
|
||||
|
||||
if let Err(err) = std::fs::create_dir_all(&token_dir) {
|
||||
warn!(
|
||||
"Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
|
||||
token_dir
|
||||
);
|
||||
} else {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
if let Err(err) =
|
||||
std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
|
||||
{
|
||||
warn!(
|
||||
"Failed to set Copilot token directory permissions on {:?}: {err}",
|
||||
token_dir
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
github_token: github_token
|
||||
.filter(|token| !token.is_empty())
|
||||
.map(String::from),
|
||||
refresh_lock: Arc::new(Mutex::new(None)),
|
||||
http: Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.connect_timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
token_dir,
|
||||
}
|
||||
}
|
||||
|
||||
/// Required headers for Copilot API requests (editor identification).
|
||||
const COPILOT_HEADERS: [(&str, &str); 4] = [
|
||||
("Editor-Version", "vscode/1.85.1"),
|
||||
("Editor-Plugin-Version", "copilot/1.155.0"),
|
||||
("User-Agent", "GithubCopilot/1.155.0"),
|
||||
("Accept", "application/json"),
|
||||
];
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
if message.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| NativeToolCall {
|
||||
id: Some(tool_call.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: NativeFunctionCall {
|
||||
name: tool_call.name,
|
||||
arguments: tool_call.arguments,
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if message.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
ApiMessage {
|
||||
role: message.role.clone(),
|
||||
content: Some(message.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Send a chat completions request with required Copilot headers.
|
||||
async fn send_chat_request(
|
||||
&self,
|
||||
messages: Vec<ApiMessage>,
|
||||
tools: Option<&[ToolSpec]>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let (token, endpoint) = self.get_api_key().await?;
|
||||
let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
|
||||
|
||||
let native_tools = Self::convert_tools(tools);
|
||||
let request = ApiChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools: native_tools,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&request);
|
||||
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
req = req.header(*header, *value);
|
||||
}
|
||||
|
||||
let response = req.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("GitHub Copilot", response).await);
|
||||
}
|
||||
|
||||
let api_response: ApiChatResponse = response.json().await?;
|
||||
let choice = api_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
|
||||
|
||||
let tool_calls = choice
|
||||
.message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|tool_call| ProviderToolCall {
|
||||
id: tool_call
|
||||
.id
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name: tool_call.function.name,
|
||||
arguments: tool_call.function.arguments,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ProviderChatResponse {
|
||||
text: choice.message.content,
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
|
||||
/// Uses a Mutex to ensure only one caller refreshes at a time.
|
||||
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
|
||||
let mut cached = self.refresh_lock.lock().await;
|
||||
|
||||
if let Some(cached_key) = cached.as_ref() {
|
||||
if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
|
||||
return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(info) = self.load_api_key_from_disk().await {
|
||||
if chrono::Utc::now().timestamp() + 120 < info.expires_at {
|
||||
let endpoint = info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
let token = info.token;
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: info.expires_at,
|
||||
});
|
||||
return Ok((token, endpoint));
|
||||
}
|
||||
}
|
||||
|
||||
let access_token = self.get_github_access_token().await?;
|
||||
let api_key_info = self.exchange_for_api_key(&access_token).await?;
|
||||
self.save_api_key_to_disk(&api_key_info).await;
|
||||
|
||||
let endpoint = api_key_info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: api_key_info.token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: api_key_info.expires_at,
|
||||
});
|
||||
|
||||
Ok((api_key_info.token, endpoint))
|
||||
}
|
||||
|
||||
/// Get a GitHub access token from config, cache, or device flow.
|
||||
async fn get_github_access_token(&self) -> anyhow::Result<String> {
|
||||
if let Some(token) = &self.github_token {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
|
||||
let token = cached.trim();
|
||||
if !token.is_empty() {
|
||||
return Ok(token.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let token = self.device_code_login().await?;
|
||||
write_file_secure(&access_token_path, &token).await;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Run GitHub OAuth device code flow.
|
||||
async fn device_code_login(&self) -> anyhow::Result<String> {
|
||||
let response: DeviceCodeResponse = self
|
||||
.http
|
||||
.post(GITHUB_DEVICE_CODE_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"scope": "read:user"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let mut poll_interval = Duration::from_secs(response.interval.max(5));
|
||||
let expires_in = response.expires_in.max(1);
|
||||
let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
|
||||
|
||||
eprintln!(
|
||||
"\nGitHub Copilot authentication is required.\n\
|
||||
Visit: {}\n\
|
||||
Code: {}\n\
|
||||
Waiting for authorization...\n",
|
||||
response.verification_uri, response.user_code
|
||||
);
|
||||
|
||||
while tokio::time::Instant::now() < expires_at {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let token_response: AccessTokenResponse = self
|
||||
.http
|
||||
.post(GITHUB_ACCESS_TOKEN_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"device_code": response.device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if let Some(token) = token_response.access_token {
|
||||
eprintln!("Authentication succeeded.\n");
|
||||
return Ok(token);
|
||||
}
|
||||
|
||||
match token_response.error.as_deref() {
|
||||
Some("slow_down") => {
|
||||
poll_interval += Duration::from_secs(5);
|
||||
}
|
||||
Some("authorization_pending") | None => {}
|
||||
Some("expired_token") => {
|
||||
anyhow::bail!("GitHub device authorization expired")
|
||||
}
|
||||
Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("Timed out waiting for GitHub authorization")
|
||||
}
|
||||
|
||||
/// Exchange a GitHub access token for a Copilot API key.
|
||||
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
|
||||
let mut request = self.http.get(GITHUB_API_KEY_URL);
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
request = request.header(*header, *value);
|
||||
}
|
||||
request = request.header("Authorization", format!("token {access_token}"));
|
||||
|
||||
let response = request.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let sanitized = super::sanitize_api_error(&body);
|
||||
|
||||
if status.as_u16() == 401 || status.as_u16() == 403 {
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
tokio::fs::remove_file(&access_token_path).await.ok();
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to get Copilot API key ({status}): {sanitized}. \
|
||||
Ensure your GitHub account has an active Copilot subscription."
|
||||
);
|
||||
}
|
||||
|
||||
let info: ApiKeyInfo = response.json().await?;
|
||||
Ok(info)
|
||||
}
|
||||
|
||||
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
let data = tokio::fs::read_to_string(&path).await.ok()?;
|
||||
serde_json::from_str(&data).ok()
|
||||
}
|
||||
|
||||
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
if let Ok(json) = serde_json::to_string_pretty(info) {
|
||||
write_file_secure(&path, &json).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a file with 0600 permissions (owner read/write only).
|
||||
/// Uses `spawn_blocking` to avoid blocking the async runtime.
|
||||
async fn write_file_secure(path: &Path, content: &str) {
|
||||
let path = path.to_path_buf();
|
||||
let content = content.to_string();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::io::Write;
|
||||
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
|
||||
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600)
|
||||
.open(&path)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
|
||||
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
std::fs::write(&path, &content)?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
|
||||
Err(err) => warn!("Failed to spawn blocking write: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CopilotProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut messages = Vec::new();
|
||||
if let Some(system) = system_prompt {
|
||||
messages.push(ApiMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
messages.push(ApiMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(message.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.send_chat_request(messages, None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let response = self
|
||||
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
self.send_chat_request(
|
||||
Self::convert_messages(request.messages),
|
||||
request.tools,
|
||||
model,
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
let _ = self.get_api_key().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_without_token() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
assert!(provider.github_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_with_token() {
|
||||
let provider = CopilotProvider::new(Some("ghp_test"));
|
||||
assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_token_treated_as_none() {
|
||||
let provider = CopilotProvider::new(Some(""));
|
||||
assert!(provider.github_token.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cache_starts_empty() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
let cached = provider.refresh_lock.lock().await;
|
||||
assert!(cached.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copilot_headers_include_required_fields() {
|
||||
let headers = CopilotProvider::COPILOT_HEADERS;
|
||||
assert!(headers
|
||||
.iter()
|
||||
.any(|(header, _)| *header == "Editor-Version"));
|
||||
assert!(headers
|
||||
.iter()
|
||||
.any(|(header, _)| *header == "Editor-Plugin-Version"));
|
||||
assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_interval_and_expiry() {
|
||||
assert_eq!(default_interval(), 5);
|
||||
assert_eq!(default_expires_in(), 900);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_native_tools() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
assert!(provider.supports_native_tools());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod anthropic;
|
||||
pub mod compatible;
|
||||
pub mod copilot;
|
||||
pub mod gemini;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
|
|
@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
|
|||
|
||||
/// Scrub known secret-like token prefixes from provider error strings.
|
||||
///
|
||||
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`.
|
||||
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
|
||||
/// `ghu_`, and `github_pat_`.
|
||||
pub fn scrub_secret_patterns(input: &str) -> String {
|
||||
const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
|
||||
const PREFIXES: [&str; 7] = [
|
||||
"sk-",
|
||||
"xoxb-",
|
||||
"xoxp-",
|
||||
"ghp_",
|
||||
"gho_",
|
||||
"ghu_",
|
||||
"github_pat_",
|
||||
];
|
||||
|
||||
let mut scrubbed = input.to_string();
|
||||
|
||||
|
|
@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
|
|||
///
|
||||
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
|
||||
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
|
||||
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
||||
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
|
||||
return Some(key.to_string());
|
||||
fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
|
||||
if let Some(raw_override) = credential_override {
|
||||
let trimmed_override = raw_override.trim();
|
||||
if !trimmed_override.is_empty() {
|
||||
return Some(trimmed_override.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let provider_env_candidates: Vec<&str> = match name {
|
||||
|
|
@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
|||
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
|
||||
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
|
||||
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
|
||||
"astrai" => vec!["ASTRAI_API_KEY"],
|
||||
_ => vec![],
|
||||
};
|
||||
|
||||
|
|
@ -182,19 +196,28 @@ fn parse_custom_provider_url(
|
|||
}
|
||||
}
|
||||
|
||||
/// Factory: create the right provider from config
|
||||
#[allow(clippy::too_many_lines)]
|
||||
/// Factory: create the right provider from config (without custom URL)
|
||||
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let resolved_key = resolve_api_key(name, api_key);
|
||||
let key = resolved_key.as_deref();
|
||||
create_provider_with_url(name, api_key, None)
|
||||
}
|
||||
|
||||
/// Factory: create the right provider from config with optional custom base URL
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn create_provider_with_url(
|
||||
name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let resolved_credential = resolve_provider_credential(name, api_key);
|
||||
#[allow(clippy::option_as_ref_deref)]
|
||||
let key = resolved_credential.as_ref().map(String::as_str);
|
||||
match name {
|
||||
// ── Primary providers (custom implementations) ───────
|
||||
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
|
||||
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
|
||||
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
|
||||
// Ollama is a local service that doesn't use API keys.
|
||||
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
|
||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
|
||||
// Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
|
||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
|
||||
"gemini" | "google" | "google-gemini" => {
|
||||
Ok(Box::new(gemini::GeminiProvider::new(key)))
|
||||
}
|
||||
|
|
@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
|
||||
"Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
|
||||
|
|
@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer,
|
||||
"copilot" | "github-copilot" => {
|
||||
Ok(Box::new(copilot::CopilotProvider::new(api_key)))
|
||||
},
|
||||
"lmstudio" | "lm-studio" => {
|
||||
let lm_studio_key = api_key
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("lm-studio");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"LM Studio",
|
||||
"http://localhost:1234/v1",
|
||||
Some(lm_studio_key),
|
||||
AuthStyle::Bearer,
|
||||
)))
|
||||
}
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
|
||||
OpenAiCompatibleProvider::new(
|
||||
"NVIDIA NIM",
|
||||
"https://integrate.api.nvidia.com/v1",
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
),
|
||||
)),
|
||||
|
||||
// ── AI inference routers ─────────────────────────────
|
||||
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
|
||||
// ── Bring Your Own Provider (custom URL) ───────────
|
||||
|
|
@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
pub fn create_resilient_provider(
|
||||
primary_name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
reliability: &crate::config::ReliabilityConfig,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||
|
||||
providers.push((
|
||||
primary_name.to_string(),
|
||||
create_provider(primary_name, api_key)?,
|
||||
create_provider_with_url(primary_name, api_key, api_url)?,
|
||||
));
|
||||
|
||||
for fallback in &reliability.fallback_providers {
|
||||
|
|
@ -340,21 +386,13 @@ pub fn create_resilient_provider(
|
|||
continue;
|
||||
}
|
||||
|
||||
if api_key.is_some() && fallback != "ollama" {
|
||||
tracing::warn!(
|
||||
fallback_provider = fallback,
|
||||
primary_provider = primary_name,
|
||||
"Fallback provider will use the primary provider's API key — \
|
||||
this will fail if the providers require different keys"
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback providers don't use the custom api_url (it's specific to primary)
|
||||
match create_provider(fallback, api_key) {
|
||||
Ok(provider) => providers.push((fallback.clone(), provider)),
|
||||
Err(e) => {
|
||||
Err(_error) => {
|
||||
tracing::warn!(
|
||||
fallback_provider = fallback,
|
||||
"Ignoring invalid fallback provider: {e}"
|
||||
"Ignoring invalid fallback provider during initialization"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -377,12 +415,13 @@ pub fn create_resilient_provider(
|
|||
pub fn create_routed_provider(
|
||||
primary_name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
reliability: &crate::config::ReliabilityConfig,
|
||||
model_routes: &[crate::config::ModelRouteConfig],
|
||||
default_model: &str,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
if model_routes.is_empty() {
|
||||
return create_resilient_provider(primary_name, api_key, reliability);
|
||||
return create_resilient_provider(primary_name, api_key, api_url, reliability);
|
||||
}
|
||||
|
||||
// Collect unique provider names needed
|
||||
|
|
@ -396,12 +435,19 @@ pub fn create_routed_provider(
|
|||
// Create each provider (with its own resilience wrapper)
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||
for name in &needed {
|
||||
let key = model_routes
|
||||
let routed_credential = model_routes
|
||||
.iter()
|
||||
.find(|r| &r.provider == name)
|
||||
.and_then(|r| r.api_key.as_deref())
|
||||
.or(api_key);
|
||||
match create_resilient_provider(name, key, reliability) {
|
||||
.and_then(|r| {
|
||||
r.api_key.as_ref().and_then(|raw_key| {
|
||||
let trimmed_key = raw_key.trim();
|
||||
(!trimmed_key.is_empty()).then_some(trimmed_key)
|
||||
})
|
||||
});
|
||||
let key = routed_credential.or(api_key);
|
||||
// Only use api_url for the primary provider
|
||||
let url = if name == primary_name { api_url } else { None };
|
||||
match create_resilient_provider(name, key, url, reliability) {
|
||||
Ok(provider) => providers.push((name.clone(), provider)),
|
||||
Err(e) => {
|
||||
if name == primary_name {
|
||||
|
|
@ -409,7 +455,7 @@ pub fn create_routed_provider(
|
|||
}
|
||||
tracing::warn!(
|
||||
provider = name.as_str(),
|
||||
"Ignoring routed provider that failed to create: {e}"
|
||||
"Ignoring routed provider that failed to initialize"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -441,27 +487,27 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_prefers_explicit_argument() {
|
||||
let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
|
||||
assert_eq!(resolved.as_deref(), Some("explicit-key"));
|
||||
fn resolve_provider_credential_prefers_explicit_argument() {
|
||||
let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
|
||||
assert_eq!(resolved, Some("explicit-key".to_string()));
|
||||
}
|
||||
|
||||
// ── Primary providers ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_openrouter() {
|
||||
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
|
||||
assert!(create_provider("openrouter", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_anthropic() {
|
||||
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openai() {
|
||||
assert!(create_provider("openai", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -556,6 +602,13 @@ mod tests {
|
|||
assert!(create_provider("dashscope-us", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_lmstudio() {
|
||||
assert!(create_provider("lmstudio", Some("key")).is_ok());
|
||||
assert!(create_provider("lm-studio", Some("key")).is_ok());
|
||||
assert!(create_provider("lmstudio", None).is_ok());
|
||||
}
|
||||
|
||||
// ── Extended ecosystem ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
@ -614,6 +667,13 @@ mod tests {
|
|||
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
|
||||
}
|
||||
|
||||
// ── AI inference routers ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_astrai() {
|
||||
assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok());
|
||||
}
|
||||
|
||||
// ── Custom / BYOP provider ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
@ -761,17 +821,33 @@ mod tests {
|
|||
scheduler_retries: 2,
|
||||
};
|
||||
|
||||
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
|
||||
let provider = create_resilient_provider(
|
||||
"openrouter",
|
||||
Some("provider-test-credential"),
|
||||
None,
|
||||
&reliability,
|
||||
);
|
||||
assert!(provider.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resilient_provider_errors_for_invalid_primary() {
|
||||
let reliability = crate::config::ReliabilityConfig::default();
|
||||
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
|
||||
let provider = create_resilient_provider(
|
||||
"totally-invalid",
|
||||
Some("provider-test-credential"),
|
||||
None,
|
||||
&reliability,
|
||||
);
|
||||
assert!(provider.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ollama_with_custom_url() {
|
||||
let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434"));
|
||||
assert!(provider.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_all_providers_create_successfully() {
|
||||
let providers = [
|
||||
|
|
@ -794,6 +870,7 @@ mod tests {
|
|||
"qwen",
|
||||
"qwen-intl",
|
||||
"qwen-us",
|
||||
"lmstudio",
|
||||
"groq",
|
||||
"mistral",
|
||||
"xai",
|
||||
|
|
@ -888,7 +965,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn sanitize_preserves_unicode_boundaries() {
|
||||
let input = format!("{} sk-abcdef123", "こんにちは".repeat(80));
|
||||
let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
|
||||
let result = sanitize_api_error(&input);
|
||||
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
|
||||
assert!(!result.contains("sk-abcdef123"));
|
||||
|
|
@ -900,4 +977,32 @@ mod tests {
|
|||
let result = sanitize_api_error(input);
|
||||
assert_eq!(result, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_personal_access_token() {
|
||||
let input = "auth failed with token ghp_abc123def456";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "auth failed with token [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_oauth_token() {
|
||||
let input = "Bearer gho_1234567890abcdef";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "Bearer [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_user_token() {
|
||||
let input = "token ghu_sessiontoken123";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "token [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_fine_grained_pat() {
|
||||
let input = "failed: github_pat_11AABBC_xyzzy789";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "failed: [REDACTED]");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ pub struct OllamaProvider {
|
|||
client: Client,
|
||||
}
|
||||
|
||||
// ─── Request Structures ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
|
|
@ -27,6 +29,8 @@ struct Options {
|
|||
temperature: f64,
|
||||
}
|
||||
|
||||
// ─── Response Structures ──────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
message: ResponseMessage,
|
||||
|
|
@ -34,9 +38,30 @@ struct ApiChatResponse {
|
|||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<OllamaToolCall>,
|
||||
/// Some models return a "thinking" field with internal reasoning
|
||||
#[serde(default)]
|
||||
thinking: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaToolCall {
|
||||
id: Option<String>,
|
||||
function: OllamaFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaFunction {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
// ─── Implementation ───────────────────────────────────────────────────────────
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(base_url: Option<&str>) -> Self {
|
||||
Self {
|
||||
|
|
@ -45,12 +70,145 @@ impl OllamaProvider {
|
|||
.trim_end_matches('/')
|
||||
.to_string(),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response
|
||||
async fn send_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
tracing::debug!(
|
||||
"Ollama request: url={} model={} message_count={} temperature={}",
|
||||
url,
|
||||
model,
|
||||
request.messages.len(),
|
||||
temperature
|
||||
);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
let status = response.status();
|
||||
tracing::debug!("Ollama response status: {}", status);
|
||||
|
||||
let body = response.bytes().await?;
|
||||
tracing::debug!("Ollama response body length: {} bytes", body.len());
|
||||
|
||||
if !status.is_success() {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
let sanitized = super::sanitize_api_error(&raw);
|
||||
tracing::error!(
|
||||
"Ollama error response: status={} body_excerpt={}",
|
||||
status,
|
||||
sanitized
|
||||
);
|
||||
anyhow::bail!(
|
||||
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
|
||||
status,
|
||||
sanitized
|
||||
);
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
let sanitized = super::sanitize_api_error(&raw);
|
||||
tracing::error!(
|
||||
"Ollama response deserialization failed: {e}. body_excerpt={}",
|
||||
sanitized
|
||||
);
|
||||
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
Ok(chat_response)
|
||||
}
|
||||
|
||||
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
|
||||
///
|
||||
/// Handles quirky model behavior where tool calls are wrapped:
|
||||
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
|
||||
/// - `{"name": "tool.shell", "arguments": {...}}`
|
||||
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
|
||||
let formatted_calls: Vec<serde_json::Value> = tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
|
||||
|
||||
// Arguments must be a JSON string for parse_tool_calls compatibility
|
||||
let args_str =
|
||||
serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
serde_json::json!({
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": args_str
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"content": "",
|
||||
"tool_calls": formatted_calls
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Extract the actual tool name and arguments from potentially nested structures
|
||||
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
|
||||
let name = &tc.function.name;
|
||||
let args = &tc.function.arguments;
|
||||
|
||||
// Pattern 1: Nested tool_call wrapper (various malformed versions)
|
||||
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
|
||||
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
|
||||
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
|
||||
if name == "tool_call"
|
||||
|| name == "tool.call"
|
||||
|| name.starts_with("tool_call>")
|
||||
|| name.starts_with("tool_call<")
|
||||
{
|
||||
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
|
||||
let nested_args = args
|
||||
.get("arguments")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
tracing::debug!(
|
||||
"Unwrapped nested tool call: {} -> {} with args {:?}",
|
||||
name,
|
||||
nested_name,
|
||||
nested_args
|
||||
);
|
||||
return (nested_name.to_string(), nested_args);
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
|
||||
if let Some(stripped) = name.strip_prefix("tool.") {
|
||||
return (stripped.to_string(), args.clone());
|
||||
}
|
||||
|
||||
// Pattern 3: Normal tool call
|
||||
(name.clone(), args.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
|
|||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
let response = self.send_request(messages, model, temperature).await?;
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let err = super::api_error("Ollama", response).await;
|
||||
anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)");
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = response.json().await?;
|
||||
Ok(chat_response.message.content)
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[crate::providers::ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let response = self.send_request(api_messages, model, temperature).await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
// This is a model quirk - it stopped after reasoning without producing output
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
// Return a message indicating the model's thought process but no action
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
||||
// The model may return native tool_calls but we convert them to JSON format
|
||||
// that parse_tool_calls() understands
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -125,46 +352,6 @@ mod tests {
|
|||
assert_eq!(p.base_url, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system() {
|
||||
let req = ChatRequest {
|
||||
model: "llama3".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: "You are ZeroClaw".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
},
|
||||
],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.7 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"stream\":false"));
|
||||
assert!(json.contains("llama3"));
|
||||
assert!(json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.7"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "mistral".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "test".to_string(),
|
||||
}],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.0 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("mistral"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
|
|
@ -180,9 +367,98 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_multiline() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
||||
fn response_with_missing_content_defaults_to_empty() {
|
||||
let json = r#"{"message":{"role":"assistant"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.contains("line1"));
|
||||
assert!(resp.message.content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_thinking_field_extracts_content() {
|
||||
let json =
|
||||
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.content, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_tool_calls_parses_correctly() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_nested_tool_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool_call".into(),
|
||||
arguments: serde_json::json!({
|
||||
"name": "shell",
|
||||
"arguments": {"command": "date"}
|
||||
}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "date");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_prefixed_name() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool.shell".into(),
|
||||
arguments: serde_json::json!({"command": "ls"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "ls");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_normal_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "file_read".into(),
|
||||
arguments: serde_json::json!({"path": "/tmp/test"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "file_read");
|
||||
assert_eq!(args.get("path").unwrap(), "/tmp/test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_tool_calls_produces_valid_json() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tool_calls = vec![OllamaToolCall {
|
||||
id: Some("call_abc".into()),
|
||||
function: OllamaFunction {
|
||||
name: "shell".into(),
|
||||
arguments: serde_json::json!({"command": "date"}),
|
||||
},
|
||||
}];
|
||||
|
||||
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
|
||||
|
||||
assert!(parsed.get("tool_calls").is_some());
|
||||
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
|
||||
assert_eq!(calls.len(), 1);
|
||||
|
||||
let func = calls[0].get("function").unwrap();
|
||||
assert_eq!(func.get("name").unwrap(), "shell");
|
||||
// arguments should be a string (JSON-encoded)
|
||||
assert!(func.get("arguments").unwrap().is_string());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenAiProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
|||
}
|
||||
|
||||
impl OpenAiProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
|
|
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
|
|
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -330,20 +330,20 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
|
||||
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
|
||||
let p = OpenAiProvider::new(Some("openai-test-credential"));
|
||||
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = OpenAiProvider::new(None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_key() {
|
||||
let p = OpenAiProvider::new(Some(""));
|
||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
||||
assert_eq!(p.credential.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenRouterProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
|||
}
|
||||
|
||||
impl OpenRouterProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
|
|||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
||||
// This prevents the first real chat request from timing out on cold start.
|
||||
if let Some(api_key) = self.api_key.as_ref() {
|
||||
if let Some(credential) = self.credential.as_ref() {
|
||||
self.client
|
||||
.get("https://openrouter.ai/api/v1/auth/key")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
|
|
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
|
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
|
|
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
|
|
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
|
|
@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -494,14 +494,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let provider = OpenRouterProvider::new(Some("sk-or-123"));
|
||||
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
|
||||
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
|
||||
assert_eq!(
|
||||
provider.credential.as_deref(),
|
||||
Some("openrouter-test-credential")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
assert!(provider.api_key.is_none());
|
||||
assert!(provider.credential.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
|
|||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
for (name, provider) in &self.providers {
|
||||
tracing::info!(provider = name, "Warming up provider connection pool");
|
||||
if let Err(e) = provider.warmup().await {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
|
||||
if provider.warmup().await.is_err() {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
|
|||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
|
@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
|
|||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
|
|
|||
|
|
@ -193,6 +193,13 @@ pub enum StreamError {
|
|||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Query provider capabilities.
|
||||
///
|
||||
/// Default implementation returns minimal capabilities (no native tool calling).
|
||||
/// Providers should override this to declare their actual capabilities.
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities::default()
|
||||
}
|
||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||
///
|
||||
/// This is the preferred API for non-agentic direct interactions.
|
||||
|
|
@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
|
|||
|
||||
/// Whether provider supports native tool calls over API.
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
false
|
||||
self.capabilities().native_tool_calling
|
||||
}
|
||||
|
||||
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
||||
|
|
@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
struct CapabilityMockProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CapabilityMockProvider {
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("ok".into())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_message_constructors() {
|
||||
let sys = ChatMessage::system("Be helpful");
|
||||
|
|
@ -398,4 +426,32 @@ mod tests {
|
|||
let json = serde_json::to_string(&tool_result).unwrap();
|
||||
assert!(json.contains("\"type\":\"ToolResults\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_capabilities_default() {
|
||||
let caps = ProviderCapabilities::default();
|
||||
assert!(!caps.native_tool_calling);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_capabilities_equality() {
|
||||
let caps1 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
};
|
||||
let caps2 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
};
|
||||
let caps3 = ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
};
|
||||
|
||||
assert_eq!(caps1, caps2);
|
||||
assert_ne!(caps1, caps3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_native_tools_reflects_capabilities_default_mapping() {
|
||||
let provider = CapabilityMockProvider;
|
||||
assert!(provider.supports_native_tools());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,14 +81,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn bubblewrap_sandbox_name() {
|
||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
||||
let sandbox = BubblewrapSandbox;
|
||||
assert_eq!(sandbox.name(), "bubblewrap");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bubblewrap_is_available_only_if_installed() {
|
||||
// Result depends on whether bwrap is installed
|
||||
let available = BubblewrapSandbox::is_available();
|
||||
let sandbox = BubblewrapSandbox;
|
||||
let _available = sandbox.is_available();
|
||||
|
||||
// Either way, the name should still work
|
||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
||||
assert_eq!(sandbox.name(), "bubblewrap");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ fn generate_token() -> String {
|
|||
use rand::RngCore;
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::thread_rng().fill_bytes(&mut bytes);
|
||||
format!("zc_{}", hex::encode(&bytes))
|
||||
format!("zc_{}", hex::encode(bytes))
|
||||
}
|
||||
|
||||
/// SHA-256 hash a bearer token for storage. Returns lowercase hex.
|
||||
|
|
|
|||
|
|
@ -343,6 +343,7 @@ impl SecurityPolicy {
|
|||
/// validates each sub-command against the allowlist
|
||||
/// - Blocks single `&` background chaining (`&&` remains supported)
|
||||
/// - Blocks output redirections (`>`, `>>`) that could write outside workspace
|
||||
/// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
|
||||
pub fn is_command_allowed(&self, command: &str) -> bool {
|
||||
if self.autonomy == AutonomyLevel::ReadOnly {
|
||||
return false;
|
||||
|
|
@ -350,7 +351,12 @@ impl SecurityPolicy {
|
|||
|
||||
// Block subshell/expansion operators — these allow hiding arbitrary
|
||||
// commands inside an allowed command (e.g. `echo $(rm -rf /)`)
|
||||
if command.contains('`') || command.contains("$(") || command.contains("${") {
|
||||
if command.contains('`')
|
||||
|| command.contains("$(")
|
||||
|| command.contains("${")
|
||||
|| command.contains("<(")
|
||||
|| command.contains(">(")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -359,6 +365,15 @@ impl SecurityPolicy {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Block `tee` — it can write to arbitrary files, bypassing the
|
||||
// redirect check above (e.g. `echo secret | tee /etc/crontab`)
|
||||
if command
|
||||
.split_whitespace()
|
||||
.any(|w| w == "tee" || w.ends_with("/tee"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block background command chaining (`&`), which can hide extra
|
||||
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
|
||||
if contains_single_ampersand(command) {
|
||||
|
|
@ -384,13 +399,9 @@ impl SecurityPolicy {
|
|||
// Strip leading env var assignments (e.g. FOO=bar cmd)
|
||||
let cmd_part = skip_env_assignments(segment);
|
||||
|
||||
let base_cmd = cmd_part
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
let mut words = cmd_part.split_whitespace();
|
||||
let base_raw = words.next().unwrap_or("");
|
||||
let base_cmd = base_raw.rsplit('/').next().unwrap_or("");
|
||||
|
||||
if base_cmd.is_empty() {
|
||||
continue;
|
||||
|
|
@ -403,6 +414,12 @@ impl SecurityPolicy {
|
|||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate arguments for the command
|
||||
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
|
||||
if !self.is_args_safe(base_cmd, &args) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// At least one command must be present
|
||||
|
|
@ -414,6 +431,29 @@ impl SecurityPolicy {
|
|||
has_cmd
|
||||
}
|
||||
|
||||
/// Check for dangerous arguments that allow sub-command execution.
|
||||
fn is_args_safe(&self, base: &str, args: &[String]) -> bool {
|
||||
let base = base.to_ascii_lowercase();
|
||||
match base.as_str() {
|
||||
"find" => {
|
||||
// find -exec and find -ok allow arbitrary command execution
|
||||
!args.iter().any(|arg| arg == "-exec" || arg == "-ok")
|
||||
}
|
||||
"git" => {
|
||||
// git config, alias, and -c can be used to set dangerous options
|
||||
// (e.g. git config core.editor "rm -rf /")
|
||||
!args.iter().any(|arg| {
|
||||
arg == "config"
|
||||
|| arg.starts_with("config.")
|
||||
|| arg == "alias"
|
||||
|| arg.starts_with("alias.")
|
||||
|| arg == "-c"
|
||||
})
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a file path is allowed (no path traversal, within workspace)
|
||||
pub fn is_path_allowed(&self, path: &str) -> bool {
|
||||
// Block null bytes (can truncate paths in C-backed syscalls)
|
||||
|
|
@ -982,12 +1022,43 @@ mod tests {
|
|||
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_argument_injection_blocked() {
|
||||
let p = default_policy();
|
||||
// find -exec is a common bypass
|
||||
assert!(!p.is_command_allowed("find . -exec rm -rf {} +"));
|
||||
assert!(!p.is_command_allowed("find / -ok cat {} \\;"));
|
||||
// git config/alias can execute commands
|
||||
assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\""));
|
||||
assert!(!p.is_command_allowed("git alias.st status"));
|
||||
assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit"));
|
||||
// Legitimate commands should still work
|
||||
assert!(p.is_command_allowed("find . -name '*.txt'"));
|
||||
assert!(p.is_command_allowed("git status"));
|
||||
assert!(p.is_command_allowed("git add ."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_injection_dollar_brace_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_injection_tee_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed("echo secret | tee /etc/crontab"));
|
||||
assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile"));
|
||||
assert!(!p.is_command_allowed("tee file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_injection_process_substitution_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed("cat <(echo pwned)"));
|
||||
assert!(!p.is_command_allowed("ls >(cat /etc/passwd)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_env_var_prefix_with_allowed_cmd() {
|
||||
let p = default_policy();
|
||||
|
|
|
|||
|
|
@ -854,7 +854,6 @@ impl BrowserTool {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
#[async_trait]
|
||||
impl Tool for BrowserTool {
|
||||
fn name(&self) -> &str {
|
||||
|
|
@ -1031,165 +1030,21 @@ impl Tool for BrowserTool {
|
|||
return self.execute_computer_use_action(action_str, &args).await;
|
||||
}
|
||||
|
||||
let action = match action_str {
|
||||
"open" => {
|
||||
let url = args
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
|
||||
BrowserAction::Open { url: url.into() }
|
||||
}
|
||||
"snapshot" => BrowserAction::Snapshot {
|
||||
interactive_only: args
|
||||
.get("interactive_only")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true), // Default to interactive for AI
|
||||
compact: args
|
||||
.get("compact")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
depth: args
|
||||
.get("depth")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
|
||||
},
|
||||
"click" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
|
||||
BrowserAction::Click {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"fill" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
|
||||
BrowserAction::Fill {
|
||||
selector: selector.into(),
|
||||
value: value.into(),
|
||||
}
|
||||
}
|
||||
"type" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
|
||||
let text = args
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
|
||||
BrowserAction::Type {
|
||||
selector: selector.into(),
|
||||
text: text.into(),
|
||||
}
|
||||
}
|
||||
"get_text" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
|
||||
BrowserAction::GetText {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"get_title" => BrowserAction::GetTitle,
|
||||
"get_url" => BrowserAction::GetUrl,
|
||||
"screenshot" => BrowserAction::Screenshot {
|
||||
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
|
||||
full_page: args
|
||||
.get("full_page")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(false),
|
||||
},
|
||||
"wait" => BrowserAction::Wait {
|
||||
selector: args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
ms: args.get("ms").and_then(serde_json::Value::as_u64),
|
||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||
},
|
||||
"press" => {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
|
||||
BrowserAction::Press { key: key.into() }
|
||||
}
|
||||
"hover" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
|
||||
BrowserAction::Hover {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"scroll" => {
|
||||
let direction = args
|
||||
.get("direction")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
|
||||
BrowserAction::Scroll {
|
||||
direction: direction.into(),
|
||||
pixels: args
|
||||
.get("pixels")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
|
||||
}
|
||||
}
|
||||
"is_visible" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
|
||||
BrowserAction::IsVisible {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"close" => BrowserAction::Close,
|
||||
"find" => {
|
||||
let by = args
|
||||
.get("by")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
|
||||
let action = args
|
||||
.get("find_action")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
|
||||
BrowserAction::Find {
|
||||
by: by.into(),
|
||||
value: value.into(),
|
||||
action: action.into(),
|
||||
fill_value: args
|
||||
.get("fill_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if is_computer_use_only_action(action_str) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(unavailable_action_for_backend_error(action_str, backend)),
|
||||
});
|
||||
}
|
||||
|
||||
let action = match parse_browser_action(action_str, &args) {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Action '{action_str}' is unavailable for backend '{}'",
|
||||
match backend {
|
||||
ResolvedBackend::AgentBrowser => "agent_browser",
|
||||
ResolvedBackend::RustNative => "rust_native",
|
||||
ResolvedBackend::ComputerUse => "computer_use",
|
||||
}
|
||||
)),
|
||||
error: Some(e.to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
@ -1871,6 +1726,161 @@ mod native_backend {
|
|||
}
|
||||
}
|
||||
|
||||
// ── Action parsing ──────────────────────────────────────────────
|
||||
|
||||
/// Parse a JSON `args` object into a typed `BrowserAction`.
|
||||
fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<BrowserAction> {
|
||||
match action_str {
|
||||
"open" => {
|
||||
let url = args
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
|
||||
Ok(BrowserAction::Open { url: url.into() })
|
||||
}
|
||||
"snapshot" => Ok(BrowserAction::Snapshot {
|
||||
interactive_only: args
|
||||
.get("interactive_only")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
compact: args
|
||||
.get("compact")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
depth: args
|
||||
.get("depth")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
|
||||
}),
|
||||
"click" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
|
||||
Ok(BrowserAction::Click {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"fill" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
|
||||
Ok(BrowserAction::Fill {
|
||||
selector: selector.into(),
|
||||
value: value.into(),
|
||||
})
|
||||
}
|
||||
"type" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
|
||||
let text = args
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
|
||||
Ok(BrowserAction::Type {
|
||||
selector: selector.into(),
|
||||
text: text.into(),
|
||||
})
|
||||
}
|
||||
"get_text" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
|
||||
Ok(BrowserAction::GetText {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"get_title" => Ok(BrowserAction::GetTitle),
|
||||
"get_url" => Ok(BrowserAction::GetUrl),
|
||||
"screenshot" => Ok(BrowserAction::Screenshot {
|
||||
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
|
||||
full_page: args
|
||||
.get("full_page")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(false),
|
||||
}),
|
||||
"wait" => Ok(BrowserAction::Wait {
|
||||
selector: args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
ms: args.get("ms").and_then(serde_json::Value::as_u64),
|
||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||
}),
|
||||
"press" => {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
|
||||
Ok(BrowserAction::Press { key: key.into() })
|
||||
}
|
||||
"hover" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
|
||||
Ok(BrowserAction::Hover {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"scroll" => {
|
||||
let direction = args
|
||||
.get("direction")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
|
||||
Ok(BrowserAction::Scroll {
|
||||
direction: direction.into(),
|
||||
pixels: args
|
||||
.get("pixels")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
|
||||
})
|
||||
}
|
||||
"is_visible" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
|
||||
Ok(BrowserAction::IsVisible {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"close" => Ok(BrowserAction::Close),
|
||||
"find" => {
|
||||
let by = args
|
||||
.get("by")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
|
||||
let action = args
|
||||
.get("find_action")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
|
||||
Ok(BrowserAction::Find {
|
||||
by: by.into(),
|
||||
value: value.into(),
|
||||
action: action.into(),
|
||||
fill_value: args
|
||||
.get("fill_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
})
|
||||
}
|
||||
other => anyhow::bail!("Unsupported browser action: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helper functions ─────────────────────────────────────────────
|
||||
|
||||
fn is_supported_browser_action(action: &str) -> bool {
|
||||
|
|
@ -1901,6 +1911,28 @@ fn is_supported_browser_action(action: &str) -> bool {
|
|||
)
|
||||
}
|
||||
|
||||
/// True for raw input-injection actions that only the `computer_use`
/// backend implements (mouse/keyboard/screen operations).
fn is_computer_use_only_action(action: &str) -> bool {
    const COMPUTER_USE_ONLY: [&str; 6] = [
        "mouse_move",
        "mouse_click",
        "mouse_drag",
        "key_type",
        "key_press",
        "screen_capture",
    ];
    COMPUTER_USE_ONLY.contains(&action)
}
|
||||
|
||||
fn backend_name(backend: ResolvedBackend) -> &'static str {
|
||||
match backend {
|
||||
ResolvedBackend::AgentBrowser => "agent_browser",
|
||||
ResolvedBackend::RustNative => "rust_native",
|
||||
ResolvedBackend::ComputerUse => "computer_use",
|
||||
}
|
||||
}
|
||||
|
||||
fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String {
|
||||
format!(
|
||||
"Action '{action}' is unavailable for backend '{}'",
|
||||
backend_name(backend)
|
||||
)
|
||||
}
|
||||
|
||||
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
|
||||
domains
|
||||
.into_iter()
|
||||
|
|
@ -2342,4 +2374,28 @@ mod tests {
|
|||
let tool = BrowserTool::new(security, vec![], None);
|
||||
assert!(tool.validate_url("https://example.com").is_err());
|
||||
}
|
||||
|
||||
#[test]
fn computer_use_only_action_detection_is_correct() {
    // All six raw-input actions are computer_use-only…
    assert!(is_computer_use_only_action("mouse_move"));
    assert!(is_computer_use_only_action("mouse_click"));
    assert!(is_computer_use_only_action("mouse_drag"));
    assert!(is_computer_use_only_action("key_type"));
    assert!(is_computer_use_only_action("key_press"));
    assert!(is_computer_use_only_action("screen_capture"));
    // …while ordinary DOM-level actions are not.
    assert!(!is_computer_use_only_action("open"));
    assert!(!is_computer_use_only_action("snapshot"));
}
|
||||
|
||||
#[test]
fn unavailable_action_error_preserves_backend_context() {
    // The error message must name both the rejected action and the
    // specific backend label, so the model can pick a supported action.
    assert_eq!(
        unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser),
        "Action 'mouse_move' is unavailable for backend 'agent_browser'"
    );
    assert_eq!(
        unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative),
        "Action 'mouse_move' is unavailable for backend 'rust_native'"
    );
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,12 +112,12 @@ impl ComposioTool {
|
|||
action_name: &str,
|
||||
params: serde_json::Value,
|
||||
entity_id: Option<&str>,
|
||||
connected_account_id: Option<&str>,
|
||||
connected_account_ref: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let tool_slug = normalize_tool_slug(action_name);
|
||||
|
||||
match self
|
||||
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id)
|
||||
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
|
||||
.await
|
||||
{
|
||||
Ok(result) => Ok(result),
|
||||
|
|
@ -130,21 +130,17 @@ impl ComposioTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute_action_v3(
|
||||
&self,
|
||||
fn build_execute_action_v3_request(
|
||||
tool_slug: &str,
|
||||
params: serde_json::Value,
|
||||
entity_id: Option<&str>,
|
||||
connected_account_id: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = if let Some(connected_account_id) = connected_account_id
|
||||
.map(str::trim)
|
||||
.filter(|id| !id.is_empty())
|
||||
{
|
||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}")
|
||||
} else {
|
||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
|
||||
};
|
||||
connected_account_ref: Option<&str>,
|
||||
) -> (String, serde_json::Value) {
|
||||
let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
|
||||
let account_ref = connected_account_ref.and_then(|candidate| {
|
||||
let trimmed_candidate = candidate.trim();
|
||||
(!trimmed_candidate.is_empty()).then_some(trimmed_candidate)
|
||||
});
|
||||
|
||||
let mut body = json!({
|
||||
"arguments": params,
|
||||
|
|
@ -153,6 +149,26 @@ impl ComposioTool {
|
|||
if let Some(entity) = entity_id {
|
||||
body["user_id"] = json!(entity);
|
||||
}
|
||||
if let Some(account_ref) = account_ref {
|
||||
body["connected_account_id"] = json!(account_ref);
|
||||
}
|
||||
|
||||
(url, body)
|
||||
}
|
||||
|
||||
async fn execute_action_v3(
|
||||
&self,
|
||||
tool_slug: &str,
|
||||
params: serde_json::Value,
|
||||
entity_id: Option<&str>,
|
||||
connected_account_ref: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let (url, body) = Self::build_execute_action_v3_request(
|
||||
tool_slug,
|
||||
params,
|
||||
entity_id,
|
||||
connected_account_ref,
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
|
|
@ -474,11 +490,11 @@ impl Tool for ComposioTool {
|
|||
})?;
|
||||
|
||||
let params = args.get("params").cloned().unwrap_or(json!({}));
|
||||
let connected_account_id =
|
||||
let connected_account_ref =
|
||||
args.get("connected_account_id").and_then(|v| v.as_str());
|
||||
|
||||
match self
|
||||
.execute_action(action_name, params, Some(entity_id), connected_account_id)
|
||||
.execute_action(action_name, params, Some(entity_id), connected_account_ref)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
|
|
@ -594,9 +610,38 @@ async fn response_error(resp: reqwest::Response) -> String {
|
|||
}
|
||||
|
||||
if let Some(api_error) = extract_api_error_message(&body) {
|
||||
format!("HTTP {}: {api_error}", status.as_u16())
|
||||
return format!(
|
||||
"HTTP {}: {}",
|
||||
status.as_u16(),
|
||||
sanitize_error_message(&api_error)
|
||||
);
|
||||
}
|
||||
|
||||
format!("HTTP {}", status.as_u16())
|
||||
}
|
||||
|
||||
/// Sanitize an upstream API error message before surfacing it: flatten
/// newlines, redact identifier field names, and truncate to a bounded
/// length.
fn sanitize_error_message(message: &str) -> String {
    let mut sanitized = message.replace('\n', " ");

    // Redact markers that would leak account/entity identifiers from
    // upstream error payloads.
    for marker in [
        "connected_account_id",
        "connectedAccountId",
        "entity_id",
        "entityId",
        "user_id",
        "userId",
    ] {
        sanitized = sanitized.replace(marker, "[redacted]");
    }

    let max_chars = 240;
    if sanitized.chars().count() <= max_chars {
        sanitized
    } else {
        // Truncate by characters — the same unit the guard above uses.
        // (Byte-index slicing would cut well short of `max_chars` visible
        // characters for multi-byte text.)
        let truncated: String = sanitized.chars().take(max_chars).collect();
        format!("{truncated}...")
    }
}
|
||||
|
||||
|
|
@ -948,4 +993,40 @@ mod tests {
|
|||
fn composio_api_base_url_is_v3() {
|
||||
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
|
||||
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||
"gmail-send-email",
|
||||
json!({"to": "test@example.com"}),
|
||||
Some("workspace-user"),
|
||||
Some("account-42"),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
url,
|
||||
"https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
|
||||
);
|
||||
assert_eq!(body["arguments"]["to"], json!("test@example.com"));
|
||||
assert_eq!(body["user_id"], json!("workspace-user"));
|
||||
assert_eq!(body["connected_account_id"], json!("account-42"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_execute_action_v3_request_drops_blank_optional_fields() {
|
||||
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||
"github-list-repos",
|
||||
json!({}),
|
||||
None,
|
||||
Some(" "),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
url,
|
||||
"https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
|
||||
);
|
||||
assert_eq!(body["arguments"], json!({}));
|
||||
assert!(body.get("connected_account_id").is_none());
|
||||
assert!(body.get("user_id").is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
|
|||
/// summarization) to purpose-built sub-agents.
|
||||
pub struct DelegateTool {
|
||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||
/// Global API key fallback (from config.api_key)
|
||||
fallback_api_key: Option<String>,
|
||||
/// Global credential fallback (from config.api_key)
|
||||
fallback_credential: Option<String>,
|
||||
/// Depth at which this tool instance lives in the delegation chain.
|
||||
depth: u32,
|
||||
}
|
||||
|
|
@ -25,11 +25,11 @@ pub struct DelegateTool {
|
|||
impl DelegateTool {
|
||||
pub fn new(
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<String>,
|
||||
fallback_credential: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
agents: Arc::new(agents),
|
||||
fallback_api_key,
|
||||
fallback_credential,
|
||||
depth: 0,
|
||||
}
|
||||
}
|
||||
|
|
@ -39,12 +39,12 @@ impl DelegateTool {
|
|||
/// their DelegateTool via this method with `depth: parent.depth + 1`.
|
||||
pub fn with_depth(
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<String>,
|
||||
fallback_credential: Option<String>,
|
||||
depth: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
agents: Arc::new(agents),
|
||||
fallback_api_key,
|
||||
fallback_credential,
|
||||
depth,
|
||||
}
|
||||
}
|
||||
|
|
@ -165,13 +165,15 @@ impl Tool for DelegateTool {
|
|||
}
|
||||
|
||||
// Create provider for this agent
|
||||
let api_key = agent_config
|
||||
let provider_credential_owned = agent_config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.or(self.fallback_api_key.as_deref());
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
#[allow(clippy::option_as_ref_deref)]
|
||||
let provider_credential = provider_credential_owned.as_ref().map(String::as_str);
|
||||
|
||||
let provider: Box<dyn Provider> =
|
||||
match providers::create_provider(&agent_config.provider, api_key) {
|
||||
match providers::create_provider(&agent_config.provider, provider_credential) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
|
|
@ -268,7 +270,7 @@ mod tests {
|
|||
provider: "openrouter".to_string(),
|
||||
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: Some("sk-test".to_string()),
|
||||
api_key: Some("delegate-test-credential".to_string()),
|
||||
temperature: None,
|
||||
max_depth: 2,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -28,13 +28,22 @@ impl GitOperationsTool {
|
|||
if arg_lower.starts_with("--exec=")
|
||||
|| arg_lower.starts_with("--upload-pack=")
|
||||
|| arg_lower.starts_with("--receive-pack=")
|
||||
|| arg_lower.starts_with("--pager=")
|
||||
|| arg_lower.starts_with("--editor=")
|
||||
|| arg_lower == "--no-verify"
|
||||
|| arg_lower.contains("$(")
|
||||
|| arg_lower.contains('`')
|
||||
|| arg.contains('|')
|
||||
|| arg.contains(';')
|
||||
|| arg.contains('>')
|
||||
{
|
||||
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
|
||||
}
|
||||
// Block `-c` config injection (exact match or `-c=...` prefix).
|
||||
// This must not false-positive on `--cached` or `-cached`.
|
||||
if arg_lower == "-c" || arg_lower.starts_with("-c=") {
|
||||
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
|
||||
}
|
||||
result.push(arg.to_string());
|
||||
}
|
||||
Ok(result)
|
||||
|
|
@ -129,6 +138,9 @@ impl GitOperationsTool {
|
|||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Validate files argument against injection patterns
|
||||
self.sanitize_git_args(files)?;
|
||||
|
||||
let mut git_args = vec!["diff", "--unified=3"];
|
||||
if cached {
|
||||
git_args.push("--cached");
|
||||
|
|
@ -267,6 +279,14 @@ impl GitOperationsTool {
|
|||
})
|
||||
}
|
||||
|
||||
/// Cap a commit message at 2000 characters, appending "..." when it is
/// longer. Counts Unicode scalar values, never slicing mid-character.
fn truncate_commit_message(message: &str) -> String {
    const LIMIT: usize = 2000;
    // Probe for a character past the limit instead of counting the
    // entire string up front.
    if message.chars().nth(LIMIT).is_none() {
        return message.to_string();
    }
    let head: String = message.chars().take(LIMIT - 3).collect();
    format!("{head}...")
}
|
||||
|
||||
async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let message = args
|
||||
.get("message")
|
||||
|
|
@ -286,11 +306,7 @@ impl GitOperationsTool {
|
|||
}
|
||||
|
||||
// Limit message length
|
||||
let message = if sanitized.len() > 2000 {
|
||||
format!("{}...", &sanitized[..1997])
|
||||
} else {
|
||||
sanitized
|
||||
};
|
||||
let message = Self::truncate_commit_message(&sanitized);
|
||||
|
||||
let output = self.run_git_command(&["commit", "-m", &message]).await;
|
||||
|
||||
|
|
@ -314,6 +330,9 @@ impl GitOperationsTool {
|
|||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?;
|
||||
|
||||
// Validate paths against injection patterns
|
||||
self.sanitize_git_args(paths)?;
|
||||
|
||||
let output = self.run_git_command(&["add", "--", paths]).await;
|
||||
|
||||
match output {
|
||||
|
|
@ -574,6 +593,52 @@ mod tests {
|
|||
assert!(tool.sanitize_git_args("arg; rm file").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_pager_editor_injection() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
assert!(tool.sanitize_git_args("--pager=less").is_err());
|
||||
assert!(tool.sanitize_git_args("--editor=vim").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_config_injection() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
// Exact `-c` flag (config injection)
|
||||
assert!(tool.sanitize_git_args("-c core.sshCommand=evil").is_err());
|
||||
assert!(tool.sanitize_git_args("-c=core.pager=less").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_no_verify() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
assert!(tool.sanitize_git_args("--no-verify").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_redirect_in_args() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
assert!(tool.sanitize_git_args("file.txt > /tmp/out").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_cached_not_blocked() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
// --cached must NOT be blocked by the `-c` check
|
||||
assert!(tool.sanitize_git_args("--cached").is_ok());
|
||||
// Other safe flags starting with -c prefix
|
||||
assert!(tool.sanitize_git_args("-cached").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_allows_safe() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
|
@ -583,6 +648,8 @@ mod tests {
|
|||
assert!(tool.sanitize_git_args("main").is_ok());
|
||||
assert!(tool.sanitize_git_args("feature/test-branch").is_ok());
|
||||
assert!(tool.sanitize_git_args("--cached").is_ok());
|
||||
assert!(tool.sanitize_git_args("src/main.rs").is_ok());
|
||||
assert!(tool.sanitize_git_args(".").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -691,4 +758,12 @@ mod tests {
|
|||
.unwrap_or("")
|
||||
.contains("Unknown operation"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_multibyte_commit_message_without_panicking() {
|
||||
let long = "🦀".repeat(2500);
|
||||
let truncated = GitOperationsTool::truncate_commit_message(&long);
|
||||
|
||||
assert_eq!(truncated.chars().count(), 2000);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -124,10 +124,11 @@ impl Tool for HardwareBoardInfoTool {
|
|||
});
|
||||
}
|
||||
Err(e) => {
|
||||
output.push_str(&format!(
|
||||
"probe-rs attach failed: {}. Using static info.\n\n",
|
||||
e
|
||||
));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(
|
||||
output,
|
||||
"probe-rs attach failed: {e}. Using static info.\n\n"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -135,13 +136,15 @@ impl Tool for HardwareBoardInfoTool {
|
|||
if let Some(info) = self.static_info_for_board(board) {
|
||||
output.push_str(&info);
|
||||
if let Some(mem) = memory_map_static(board) {
|
||||
output.push_str(&format!("\n\n**Memory map:**\n{}", mem));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(output, "\n\n**Memory map:**\n{mem}");
|
||||
}
|
||||
} else {
|
||||
output.push_str(&format!(
|
||||
"Board '{}' configured. No static info available.",
|
||||
board
|
||||
));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(
|
||||
output,
|
||||
"Board '{board}' configured. No static info available."
|
||||
);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
|
|
|
|||
|
|
@ -122,14 +122,16 @@ impl Tool for HardwareMemoryMapTool {
|
|||
|
||||
if !probe_ok {
|
||||
if let Some(map) = self.static_map_for_board(board) {
|
||||
output.push_str(&format!("**{}** (from datasheet):\n{}", board, map));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(output, "**{board}** (from datasheet):\n{map}");
|
||||
} else {
|
||||
use std::fmt::Write;
|
||||
let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect();
|
||||
output.push_str(&format!(
|
||||
"No memory map for board '{}'. Known boards: {}",
|
||||
board,
|
||||
let _ = write!(
|
||||
output,
|
||||
"No memory map for board '{board}'. Known boards: {}",
|
||||
known.join(", ")
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -94,14 +94,16 @@ impl Tool for HardwareMemoryReadTool {
|
|||
.get("address")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("0x20000000");
|
||||
let address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
|
||||
let _address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
|
||||
|
||||
let length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
|
||||
let length = length.min(256).max(1);
|
||||
let requested_length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128);
|
||||
let _length = usize::try_from(requested_length)
|
||||
.unwrap_or(256)
|
||||
.clamp(1, 256);
|
||||
|
||||
#[cfg(feature = "probe")]
|
||||
{
|
||||
match probe_read_memory(chip.unwrap(), address, length) {
|
||||
match probe_read_memory(chip.unwrap(), _address, _length) {
|
||||
Ok(output) => {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
|
|
|
|||
|
|
@ -749,4 +749,54 @@ mod tests {
|
|||
let _ = HttpRequestTool::redact_headers_for_display(&headers);
|
||||
assert_eq!(headers[0].1, "Bearer real-token");
|
||||
}
|
||||
|
||||
// ── SSRF: alternate IP notation bypass defense-in-depth ─────────
|
||||
//
|
||||
// Rust's IpAddr::parse() rejects non-standard notations (octal, hex,
|
||||
// decimal integer, zero-padded). These tests document that property
|
||||
// so regressions are caught if the parsing strategy ever changes.
|
||||
|
||||
#[test]
|
||||
fn ssrf_octal_loopback_not_parsed_as_ip() {
|
||||
// 0177.0.0.1 is octal for 127.0.0.1 in some languages, but
|
||||
// Rust's IpAddr rejects it — it falls through as a hostname.
|
||||
assert!(!is_private_or_local_host("0177.0.0.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_hex_loopback_not_parsed_as_ip() {
|
||||
// 0x7f000001 is hex for 127.0.0.1 in some languages.
|
||||
assert!(!is_private_or_local_host("0x7f000001"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_decimal_loopback_not_parsed_as_ip() {
|
||||
// 2130706433 is decimal for 127.0.0.1 in some languages.
|
||||
assert!(!is_private_or_local_host("2130706433"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_zero_padded_loopback_not_parsed_as_ip() {
|
||||
// 127.000.000.001 uses zero-padded octets.
|
||||
assert!(!is_private_or_local_host("127.000.000.001"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_alternate_notations_rejected_by_validate_url() {
|
||||
// Even if is_private_or_local_host doesn't flag these, they
|
||||
// fail the allowlist because they're treated as hostnames.
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
for notation in [
|
||||
"http://0177.0.0.1",
|
||||
"http://0x7f000001",
|
||||
"http://2130706433",
|
||||
"http://127.000.000.001",
|
||||
] {
|
||||
let err = tool.validate_url(notation).unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("allowed_domains"),
|
||||
"Expected allowlist rejection for {notation}, got: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_existing() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation)
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool {
|
|||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(5, |v| v as usize);
|
||||
|
||||
match self.memory.recall(query, limit).await {
|
||||
match self.memory.recall(query, limit, None).await {
|
||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No memories found matching that query.".into(),
|
||||
|
|
@ -112,10 +112,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_finds_match() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core)
|
||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("tz", "Timezone is EST", MemoryCategory::Core)
|
||||
mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -134,6 +134,7 @@ mod tests {
|
|||
&format!("k{i}"),
|
||||
&format!("Rust fact {i}"),
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool {
|
|||
_ => MemoryCategory::Core,
|
||||
};
|
||||
|
||||
match self.memory.store(key, content, category).await {
|
||||
match self.memory.store(key, content, category, None).await {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Stored memory: {key}"),
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ pub mod image_info;
|
|||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod pushover;
|
||||
pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod shell;
|
||||
pub mod traits;
|
||||
|
|
@ -45,7 +47,9 @@ pub use image_info::ImageInfoTool;
|
|||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use pushover::PushoverTool;
|
||||
pub use schedule::ScheduleTool;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use traits::Tool;
|
||||
|
|
@ -141,6 +145,10 @@ pub fn all_tools_with_runtime(
|
|||
security.clone(),
|
||||
workspace_dir.to_path_buf(),
|
||||
)),
|
||||
Box::new(PushoverTool::new(
|
||||
security.clone(),
|
||||
workspace_dir.to_path_buf(),
|
||||
)),
|
||||
];
|
||||
|
||||
if browser_config.enabled {
|
||||
|
|
@ -195,9 +203,13 @@ pub fn all_tools_with_runtime(
|
|||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
tools.push(Box::new(DelegateTool::new(
|
||||
delegate_agents,
|
||||
fallback_api_key.map(String::from),
|
||||
delegate_fallback_credential,
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
@ -261,6 +273,7 @@ mod tests {
|
|||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"browser_open"));
|
||||
assert!(names.contains(&"schedule"));
|
||||
assert!(names.contains(&"pushover"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -298,6 +311,7 @@ mod tests {
|
|||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"browser_open"));
|
||||
assert!(names.contains(&"pushover"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -432,7 +446,7 @@ mod tests {
|
|||
&http,
|
||||
tmp.path(),
|
||||
&agents,
|
||||
Some("sk-test"),
|
||||
Some("delegate-test-credential"),
|
||||
&cfg,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
|
|
|
|||
442
src/tools/pushover.rs
Normal file
442
src/tools/pushover.rs
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json";
|
||||
const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15;
|
||||
|
||||
pub struct PushoverTool {
|
||||
client: Client,
|
||||
security: Arc<SecurityPolicy>,
|
||||
workspace_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl PushoverTool {
|
||||
/// Construct a Pushover tool bound to a security policy and the
/// workspace directory that holds the `.env` credentials file.
pub fn new(security: Arc<SecurityPolicy>, workspace_dir: PathBuf) -> Self {
    // Bounded request timeout; fall back to a default client if the
    // builder fails (default has no timeout, but the tool stays usable).
    let client = Client::builder()
        .timeout(Duration::from_secs(PUSHOVER_REQUEST_TIMEOUT_SECS))
        .build()
        .unwrap_or_else(|_| Client::new());

    Self {
        client,
        security,
        workspace_dir,
    }
}
|
||||
|
||||
/// Normalize a raw `.env` value: trim whitespace, strip one matching
/// pair of surrounding quotes, and drop an inline ` # comment` suffix.
fn parse_env_value(raw: &str) -> String {
    let trimmed = raw.trim();

    let surrounded_by = |quote: char| trimmed.starts_with(quote) && trimmed.ends_with(quote);
    let inner = if trimmed.len() >= 2 && (surrounded_by('"') || surrounded_by('\'')) {
        &trimmed[1..trimmed.len() - 1]
    } else {
        trimmed
    };

    // Inline comments: `KEY=value # comment` keeps only `value`.
    match inner.split_once(" #") {
        Some((value, _comment)) => value.trim().to_string(),
        None => inner.trim().to_string(),
    }
}
|
||||
|
||||
fn get_credentials(&self) -> anyhow::Result<(String, String)> {
|
||||
let env_path = self.workspace_dir.join(".env");
|
||||
let content = std::fs::read_to_string(&env_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?;
|
||||
|
||||
let mut token = None;
|
||||
let mut user_key = None;
|
||||
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.starts_with('#') || line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line);
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
let key = key.trim();
|
||||
let value = Self::parse_env_value(value);
|
||||
|
||||
if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") {
|
||||
token = Some(value);
|
||||
} else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") {
|
||||
user_key = Some(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?;
|
||||
let user_key =
|
||||
user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?;
|
||||
|
||||
Ok((token, user_key))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for PushoverTool {
|
||||
/// Tool identifier exposed to the model for tool-call dispatch.
fn name(&self) -> &str {
    "pushover"
}
|
||||
|
||||
/// Model-facing description, including the credential prerequisites.
fn description(&self) -> &str {
    "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file."
}
||||
|
||||
/// JSON Schema for the tool's arguments: a required `message` plus
/// optional `title`, `priority` (-2..=2), and `sound` override.
fn parameters_schema(&self) -> serde_json::Value {
    json!({
        "type": "object",
        "properties": {
            "message": {
                "type": "string",
                "description": "The notification message to send"
            },
            "title": {
                "type": "string",
                "description": "Optional notification title"
            },
            "priority": {
                "type": "integer",
                "enum": [-2, -1, 0, 1, 2],
                "description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)"
            },
            "sound": {
                "type": "string",
                "description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)"
            }
        },
        "required": ["message"]
    })
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if !self.security.can_act() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Action blocked: autonomy is read-only".into()),
|
||||
});
|
||||
}
|
||||
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Action blocked: rate limit exceeded".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let message = args
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?
|
||||
.to_string();
|
||||
|
||||
let title = args.get("title").and_then(|v| v.as_str()).map(String::from);
|
||||
|
||||
let priority = match args.get("priority").and_then(|v| v.as_i64()) {
|
||||
Some(value) if (-2..=2).contains(&value) => Some(value),
|
||||
Some(value) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'priority': {value}. Expected integer in range -2..=2"
|
||||
)),
|
||||
})
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from);
|
||||
|
||||
let (token, user_key) = self.get_credentials()?;
|
||||
|
||||
let mut form = reqwest::multipart::Form::new()
|
||||
.text("token", token)
|
||||
.text("user", user_key)
|
||||
.text("message", message);
|
||||
|
||||
if let Some(title) = title {
|
||||
form = form.text("title", title);
|
||||
}
|
||||
|
||||
if let Some(priority) = priority {
|
||||
form = form.text("priority", priority.to_string());
|
||||
}
|
||||
|
||||
if let Some(sound) = sound {
|
||||
form = form.text("sound", sound);
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(PUSHOVER_API_URL)
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: body,
|
||||
error: Some(format!("Pushover API returned status {}", status)),
|
||||
});
|
||||
}
|
||||
|
||||
let api_status = serde_json::from_str::<serde_json::Value>(&body)
|
||||
.ok()
|
||||
.and_then(|json| json.get("status").and_then(|value| value.as_i64()));
|
||||
|
||||
if api_status == Some(1) {
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Pushover notification sent successfully. Response: {}",
|
||||
body
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
} else {
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: body,
|
||||
error: Some("Pushover API returned an application-level error".into()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::AutonomyLevel;
    use std::fs;
    use tempfile::TempDir;

    /// Build an `Arc<SecurityPolicy>` with the given autonomy level and
    /// hourly action budget; every other field takes the policy default.
    fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: level,
            max_actions_per_hour,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    // --- Static metadata -------------------------------------------------

    #[test]
    fn pushover_tool_name() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        assert_eq!(tool.name(), "pushover");
    }

    #[test]
    fn pushover_tool_description() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn pushover_tool_has_parameters_schema() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].get("message").is_some());
    }

    #[test]
    fn pushover_tool_requires_message() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::Value::String("message".to_string())));
    }

    // --- .env credential parsing -----------------------------------------

    #[test]
    fn credentials_parsed_from_env_file() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(
            &env_path,
            "PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n",
        )
        .unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();

        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "testtoken123");
        assert_eq!(user_key, "userkey456");
    }

    #[test]
    fn credentials_fail_without_env_file() {
        // No .env written at all: lookup must fail, not default.
        let tmp = TempDir::new().unwrap();
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();

        assert!(result.is_err());
    }

    #[test]
    fn credentials_fail_without_token() {
        // Only the user key present: missing token is an error.
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();

        assert!(result.is_err());
    }

    #[test]
    fn credentials_fail_without_user_key() {
        // Only the token present: missing user key is an error.
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();

        assert!(result.is_err());
    }

    #[test]
    fn credentials_ignore_comments() {
        // `#`-prefixed lines in the .env file must be skipped by the parser.
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();

        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "realtoken");
        assert_eq!(user_key, "realuser");
    }

    #[test]
    fn pushover_tool_supports_priority() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert!(schema["properties"].get("priority").is_some());
    }

    #[test]
    fn pushover_tool_supports_sound() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert!(schema["properties"].get("sound").is_some());
    }

    #[test]
    fn credentials_support_export_and_quoted_values() {
        // Shell-style `export KEY="value"` and single quotes must both parse.
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(
            &env_path,
            "export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n",
        )
        .unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();

        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "quotedtoken");
        assert_eq!(user_key, "quoteduser");
    }

    // --- Security-policy enforcement in execute() -------------------------

    #[tokio::test]
    async fn execute_blocks_readonly_mode() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::ReadOnly, 100),
            PathBuf::from("/tmp"),
        );

        let result = tool.execute(json!({"message": "hello"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("read-only"));
    }

    #[tokio::test]
    async fn execute_blocks_rate_limit() {
        // Zero-action budget: the very first action must be rejected.
        let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp"));

        let result = tool.execute(json!({"message": "hello"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("rate limit"));
    }

    #[tokio::test]
    async fn execute_rejects_priority_out_of_range() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );

        let result = tool
            .execute(json!({"message": "hello", "priority": 5}))
            .await
            .unwrap();

        assert!(!result.success);
        assert!(result.error.unwrap().contains("-2..=2"));
    }
}
|
||||
838
src/tools/schema.rs
Normal file
838
src/tools/schema.rs
Normal file
|
|
@ -0,0 +1,838 @@
|
|||
//! JSON Schema cleaning and validation for LLM tool-calling compatibility.
|
||||
//!
|
||||
//! Different providers support different subsets of JSON Schema. This module
|
||||
//! normalizes tool schemas to improve cross-provider compatibility while
|
||||
//! preserving semantic intent.
|
||||
//!
|
||||
//! ## What this module does
|
||||
//!
|
||||
//! 1. Removes unsupported keywords per provider strategy
|
||||
//! 2. Resolves local `$ref` entries from `$defs` and `definitions`
|
||||
//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum`
|
||||
//! 4. Strips nullable variants from unions and `type` arrays
|
||||
//! 5. Converts `const` to single-value `enum`
|
||||
//! 6. Detects circular references and stops recursion safely
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust
|
||||
//! use serde_json::json;
|
||||
//! use zeroclaw::tools::schema::SchemaCleanr;
|
||||
//!
|
||||
//! let dirty_schema = json!({
|
||||
//! "type": "object",
|
||||
//! "properties": {
|
||||
//! "name": {
|
||||
//! "type": "string",
|
||||
//! "minLength": 1, // Gemini rejects this
|
||||
//! "pattern": "^[a-z]+$" // Gemini rejects this
|
||||
//! },
|
||||
//! "age": {
|
||||
//! "$ref": "#/$defs/Age" // Needs resolution
|
||||
//! }
|
||||
//! },
|
||||
//! "$defs": {
|
||||
//! "Age": {
|
||||
//! "type": "integer",
|
||||
//! "minimum": 0 // Gemini rejects this
|
||||
//! }
|
||||
//! }
|
||||
//! });
|
||||
//!
|
||||
//! let cleaned = SchemaCleanr::clean_for_gemini(dirty_schema);
|
||||
//!
|
||||
//! // Result:
|
||||
//! // {
|
||||
//! // "type": "object",
|
||||
//! // "properties": {
|
||||
//! // "name": { "type": "string" },
|
||||
//! // "age": { "type": "integer" }
|
||||
//! // }
|
||||
//! // }
|
||||
//! ```
|
||||
//!
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Keywords that Gemini rejects for tool schemas.
///
/// Any key listed here is dropped wholesale by the `Gemini` cleaning
/// strategy, leaving only structure that Gemini's function-calling API
/// accepts.
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
    // Schema composition
    "$ref",
    "$schema",
    "$id",
    "$defs",
    "definitions",
    // Property constraints
    "additionalProperties",
    "patternProperties",
    // String constraints
    "minLength",
    "maxLength",
    "pattern",
    "format",
    // Number constraints
    "minimum",
    "maximum",
    "multipleOf",
    // Array constraints
    "minItems",
    "maxItems",
    "uniqueItems",
    // Object constraints
    "minProperties",
    "maxProperties",
    // Non-standard
    "examples", // OpenAPI keyword, not JSON Schema
];
|
||||
|
||||
/// Keywords that should be preserved during cleaning (metadata).
///
/// These keys are copied onto the replacement schema by `preserve_meta`
/// whenever a `$ref` or union is collapsed, so human-readable context
/// survives cleaning.
const SCHEMA_META_KEYS: &[&str] = &["description", "title", "default"];
|
||||
|
||||
/// Schema cleaning strategies for different LLM providers.
///
/// The variant chosen determines which JSON Schema keywords are stripped
/// (see `unsupported_keywords`). `Copy` is derived, so the strategy is
/// passed by value throughout the cleaner.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleaningStrategy {
    /// Gemini (Google AI / Vertex AI) - Most restrictive
    Gemini,
    /// Anthropic Claude - Moderately permissive
    Anthropic,
    /// OpenAI GPT - Most permissive
    OpenAI,
    /// Conservative: Remove only universally unsupported keywords
    Conservative,
}
|
||||
|
||||
impl CleaningStrategy {
|
||||
/// Get the list of unsupported keywords for this strategy.
|
||||
pub fn unsupported_keywords(&self) -> &'static [&'static str] {
|
||||
match self {
|
||||
Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
|
||||
Self::Anthropic => &["$ref", "$defs", "definitions"], // Anthropic doesn't resolve refs
|
||||
Self::OpenAI => &[], // OpenAI is most permissive
|
||||
Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON Schema cleaner optimized for LLM tool calling.
///
/// Stateless namespace type: all functionality is exposed as associated
/// functions (`clean`, `validate`, and the per-provider `clean_for_*`
/// helpers).
pub struct SchemaCleanr;
|
||||
|
||||
impl SchemaCleanr {
    /// Clean schema for Gemini compatibility (strictest).
    ///
    /// This is the most aggressive cleaning strategy, removing all keywords
    /// that Gemini's API rejects.
    pub fn clean_for_gemini(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::Gemini)
    }

    /// Clean schema for Anthropic compatibility.
    pub fn clean_for_anthropic(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::Anthropic)
    }

    /// Clean schema for OpenAI compatibility (most permissive).
    pub fn clean_for_openai(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::OpenAI)
    }

    /// Clean schema with specified strategy.
    ///
    /// Note: only the top-level `$defs`/`definitions` tables are collected
    /// for `$ref` resolution (see `extract_defs`); references into nested
    /// definition scopes will not resolve and fall back to an empty object.
    pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
        // Extract $defs for reference resolution
        let defs = if let Some(obj) = schema.as_object() {
            Self::extract_defs(obj)
        } else {
            HashMap::new()
        };

        Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
    }

    /// Validate that a schema is suitable for LLM tool calling.
    ///
    /// Returns an error if the schema is invalid or missing required fields.
    ///
    /// # Errors
    ///
    /// Fails when the schema is not a JSON object or lacks a `type` field.
    /// An object schema without `properties` only logs a warning.
    pub fn validate(schema: &Value) -> anyhow::Result<()> {
        let obj = schema
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;

        // Must have 'type' field
        if !obj.contains_key("type") {
            anyhow::bail!("Schema missing required 'type' field");
        }

        // If type is 'object', should have 'properties'
        if let Some(Value::String(t)) = obj.get("type") {
            if t == "object" && !obj.contains_key("properties") {
                tracing::warn!("Object schema without 'properties' field may cause issues");
            }
        }

        Ok(())
    }

    // --------------------------------------------------------------------
    // Internal implementation
    // --------------------------------------------------------------------

    /// Extract $defs and definitions into a flat map for reference resolution.
    ///
    /// Reads only the root object; on a key collision, `definitions`
    /// overwrites the `$defs` entry of the same name.
    fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
        let mut defs = HashMap::new();

        // Extract from $defs (JSON Schema 2019-09+)
        if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
            for (key, value) in defs_obj {
                defs.insert(key.clone(), value.clone());
            }
        }

        // Extract from definitions (JSON Schema draft-07)
        if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
            for (key, value) in defs_obj {
                defs.insert(key.clone(), value.clone());
            }
        }

        defs
    }

    /// Recursively clean a schema value.
    ///
    /// Objects go through `clean_object`; array elements are cleaned
    /// individually; scalars pass through untouched.
    fn clean_with_defs(
        schema: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        match schema {
            Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
            Value::Array(arr) => Value::Array(
                arr.into_iter()
                    .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
                    .collect(),
            ),
            other => other,
        }
    }

    /// Clean an object schema.
    ///
    /// Order matters here: `$ref` resolution and union simplification can
    /// each replace the whole object before the keyword-by-keyword pass.
    fn clean_object(
        obj: Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        // Handle $ref resolution
        if let Some(Value::String(ref_value)) = obj.get("$ref") {
            return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
        }

        // Handle anyOf/oneOf simplification
        if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
            if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
                return simplified;
            }
        }

        // Build cleaned object
        let mut cleaned = Map::new();
        let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
        let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");

        for (key, value) in obj {
            // Skip unsupported keywords
            if unsupported.contains(key.as_str()) {
                continue;
            }

            // Special handling for specific keys
            match key.as_str() {
                // Convert const to enum
                // NOTE(review): if a schema carries both `const` and an
                // explicit `enum`, whichever key is visited later wins the
                // "enum" slot — confirm that is the intended precedence.
                "const" => {
                    cleaned.insert("enum".to_string(), json!([value]));
                }
                // Skip type if we have anyOf/oneOf (they define the type)
                "type" if has_union => {
                    // Skip
                }
                // Handle type arrays (remove null)
                "type" if matches!(value, Value::Array(_)) => {
                    let cleaned_value = Self::clean_type_array(value);
                    cleaned.insert(key, cleaned_value);
                }
                // Recursively clean nested schemas
                "properties" => {
                    let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                "items" => {
                    let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                "anyOf" | "oneOf" | "allOf" => {
                    let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                // Keep all other keys, cleaning nested objects/arrays recursively.
                _ => {
                    let cleaned_value = match value {
                        Value::Object(_) | Value::Array(_) => {
                            Self::clean_with_defs(value, defs, strategy, ref_stack)
                        }
                        other => other,
                    };
                    cleaned.insert(key, cleaned_value);
                }
            }
        }

        Value::Object(cleaned)
    }

    /// Resolve a $ref to its definition.
    ///
    /// `ref_stack` holds the refs currently being expanded; pushing before
    /// recursing and popping afterwards breaks cycles while still allowing
    /// the same definition to be reused on sibling branches.
    fn resolve_ref(
        ref_value: &str,
        obj: &Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        // Prevent circular references
        if ref_stack.contains(ref_value) {
            tracing::warn!("Circular $ref detected: {}", ref_value);
            return Self::preserve_meta(obj, Value::Object(Map::new()));
        }

        // Try to resolve local ref (#/$defs/Name or #/definitions/Name)
        if let Some(def_name) = Self::parse_local_ref(ref_value) {
            if let Some(definition) = defs.get(def_name.as_str()) {
                ref_stack.insert(ref_value.to_string());
                let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
                ref_stack.remove(ref_value);
                return Self::preserve_meta(obj, cleaned);
            }
        }

        // Can't resolve: return empty object with metadata
        tracing::warn!("Cannot resolve $ref: {}", ref_value);
        Self::preserve_meta(obj, Value::Object(Map::new()))
    }

    /// Parse a local JSON Pointer ref (#/$defs/Name).
    ///
    /// Returns `None` for external or non-definition refs; the segment is
    /// JSON-Pointer-unescaped before lookup.
    fn parse_local_ref(ref_value: &str) -> Option<String> {
        ref_value
            .strip_prefix("#/$defs/")
            .or_else(|| ref_value.strip_prefix("#/definitions/"))
            .map(Self::decode_json_pointer)
    }

    /// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`).
    fn decode_json_pointer(segment: &str) -> String {
        // Fast path: nothing to unescape.
        if !segment.contains('~') {
            return segment.to_string();
        }

        let mut decoded = String::with_capacity(segment.len());
        let mut chars = segment.chars().peekable();

        while let Some(ch) = chars.next() {
            if ch == '~' {
                match chars.peek().copied() {
                    Some('0') => {
                        chars.next();
                        decoded.push('~');
                    }
                    Some('1') => {
                        chars.next();
                        decoded.push('/');
                    }
                    // Lone trailing '~' is kept verbatim.
                    _ => decoded.push('~'),
                }
            } else {
                decoded.push(ch);
            }
        }

        decoded
    }

    /// Try to simplify anyOf/oneOf to a simpler form.
    ///
    /// Returns `Some` when the union collapses to a single non-null variant
    /// or to a literal `enum`; returns `None` to keep the union as-is.
    /// When both keys are present, `anyOf` takes precedence.
    fn try_simplify_union(
        obj: &Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Option<Value> {
        let union_key = if obj.contains_key("anyOf") {
            "anyOf"
        } else if obj.contains_key("oneOf") {
            "oneOf"
        } else {
            return None;
        };

        let variants = obj.get(union_key)?.as_array()?;

        // Clean all variants first
        let cleaned_variants: Vec<Value> = variants
            .iter()
            .map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
            .collect();

        // Strip null variants
        let non_null: Vec<Value> = cleaned_variants
            .into_iter()
            .filter(|v| !Self::is_null_schema(v))
            .collect();

        // If only one variant remains after stripping nulls, return it
        if non_null.len() == 1 {
            return Some(Self::preserve_meta(obj, non_null[0].clone()));
        }

        // Try to flatten to enum if all variants are literals
        if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
            return Some(Self::preserve_meta(obj, enum_value));
        }

        None
    }

    /// Check if a schema represents null type.
    fn is_null_schema(value: &Value) -> bool {
        if let Some(obj) = value.as_object() {
            // { const: null }
            if let Some(Value::Null) = obj.get("const") {
                return true;
            }
            // { enum: [null] }
            if let Some(Value::Array(arr)) = obj.get("enum") {
                if arr.len() == 1 && matches!(arr[0], Value::Null) {
                    return true;
                }
            }
            // { type: "null" }
            if let Some(Value::String(t)) = obj.get("type") {
                if t == "null" {
                    return true;
                }
            }
        }
        false
    }

    /// Try to flatten anyOf/oneOf with only literal values to enum.
    ///
    /// Example: `anyOf: [{const: "a"}, {const: "b"}]` -> `{type: "string", enum: ["a", "b"]}`
    ///
    /// Every variant must carry both a literal (`const` or single-item
    /// `enum`) and a `type` string, and all types must agree; otherwise
    /// the union is left alone.
    fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
        if variants.is_empty() {
            return None;
        }

        let mut all_values = Vec::new();
        let mut common_type: Option<String> = None;

        for variant in variants {
            let obj = variant.as_object()?;

            // Extract literal value from const or single-item enum
            let literal_value = if let Some(const_val) = obj.get("const") {
                const_val.clone()
            } else if let Some(Value::Array(arr)) = obj.get("enum") {
                if arr.len() == 1 {
                    arr[0].clone()
                } else {
                    return None;
                }
            } else {
                return None;
            };

            // Check type consistency
            let variant_type = obj.get("type")?.as_str()?;
            match &common_type {
                None => common_type = Some(variant_type.to_string()),
                Some(t) if t != variant_type => return None,
                _ => {}
            }

            all_values.push(literal_value);
        }

        common_type.map(|t| {
            json!({
                "type": t,
                "enum": all_values
            })
        })
    }

    /// Clean type array, removing null.
    ///
    /// A single surviving entry is unwrapped to a plain string; an array
    /// that contained only "null" degrades to the string "null".
    fn clean_type_array(value: Value) -> Value {
        if let Value::Array(types) = value {
            let non_null: Vec<Value> = types
                .into_iter()
                .filter(|v| v.as_str() != Some("null"))
                .collect();

            match non_null.len() {
                0 => Value::String("null".to_string()),
                1 => non_null
                    .into_iter()
                    .next()
                    .unwrap_or(Value::String("null".to_string())),
                _ => Value::Array(non_null),
            }
        } else {
            value
        }
    }

    /// Clean properties object.
    ///
    /// Each property value is cleaned recursively; a non-object
    /// `properties` value is passed through unchanged.
    fn clean_properties(
        value: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        if let Value::Object(props) = value {
            let cleaned: Map<String, Value> = props
                .into_iter()
                .map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
                .collect();
            Value::Object(cleaned)
        } else {
            value
        }
    }

    /// Clean union (anyOf/oneOf/allOf).
    ///
    /// Cleans each variant in place without attempting simplification
    /// (that is `try_simplify_union`'s job).
    fn clean_union(
        value: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        if let Value::Array(variants) = value {
            let cleaned: Vec<Value> = variants
                .into_iter()
                .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
                .collect();
            Value::Array(cleaned)
        } else {
            value
        }
    }

    /// Preserve metadata (description, title, default) from source to target.
    ///
    /// Source metadata overwrites any same-named key already on the target.
    fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
        if let Value::Object(target_obj) = &mut target {
            for &key in SCHEMA_META_KEYS {
                if let Some(value) = source.get(key) {
                    target_obj.insert(key.to_string(), value.clone());
                }
            }
        }
        target
    }
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_remove_unsupported_keywords() {
|
||||
let schema = json!({
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 100,
|
||||
"pattern": "^[a-z]+$",
|
||||
"description": "A lowercase string"
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert_eq!(cleaned["description"], "A lowercase string");
|
||||
assert!(cleaned.get("minLength").is_none());
|
||||
assert!(cleaned.get("maxLength").is_none());
|
||||
assert!(cleaned.get("pattern").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_ref() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {
|
||||
"$ref": "#/$defs/Age"
|
||||
}
|
||||
},
|
||||
"$defs": {
|
||||
"Age": {
|
||||
"type": "integer",
|
||||
"minimum": 0
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["properties"]["age"]["type"], "integer");
|
||||
assert!(cleaned["properties"]["age"].get("minimum").is_none()); // Stripped by Gemini strategy
|
||||
assert!(cleaned.get("$defs").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flatten_literal_union() {
|
||||
let schema = json!({
|
||||
"anyOf": [
|
||||
{ "const": "admin", "type": "string" },
|
||||
{ "const": "user", "type": "string" },
|
||||
{ "const": "guest", "type": "string" }
|
||||
]
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert!(cleaned["enum"].is_array());
|
||||
let enum_values = cleaned["enum"].as_array().unwrap();
|
||||
assert_eq!(enum_values.len(), 3);
|
||||
assert!(enum_values.contains(&json!("admin")));
|
||||
assert!(enum_values.contains(&json!("user")));
|
||||
assert!(enum_values.contains(&json!("guest")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_null_from_union() {
|
||||
let schema = json!({
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
// Should simplify to just { type: "string" }
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert!(cleaned.get("oneOf").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_const_to_enum() {
|
||||
let schema = json!({
|
||||
"const": "fixed_value",
|
||||
"description": "A constant"
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["enum"], json!(["fixed_value"]));
|
||||
assert_eq!(cleaned["description"], "A constant");
|
||||
assert!(cleaned.get("const").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_preserve_metadata() {
|
||||
let schema = json!({
|
||||
"$ref": "#/$defs/Name",
|
||||
"description": "User's name",
|
||||
"title": "Name Field",
|
||||
"default": "Anonymous",
|
||||
"$defs": {
|
||||
"Name": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert_eq!(cleaned["description"], "User's name");
|
||||
assert_eq!(cleaned["title"], "Name Field");
|
||||
assert_eq!(cleaned["default"], "Anonymous");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circular_ref_prevention() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parent": {
|
||||
"$ref": "#/$defs/Node"
|
||||
}
|
||||
},
|
||||
"$defs": {
|
||||
"Node": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"child": {
|
||||
"$ref": "#/$defs/Node"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Should not panic on circular reference
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["properties"]["parent"]["type"], "object");
|
||||
// Circular reference should be broken
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_schema() {
|
||||
let valid = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
});
|
||||
|
||||
assert!(SchemaCleanr::validate(&valid).is_ok());
|
||||
|
||||
let invalid = json!({
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
});
|
||||
|
||||
assert!(SchemaCleanr::validate(&invalid).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strategy_differences() {
|
||||
let schema = json!({
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "A string field"
|
||||
});
|
||||
|
||||
// Gemini: Most restrictive (removes minLength)
|
||||
let gemini = SchemaCleanr::clean_for_gemini(schema.clone());
|
||||
assert!(gemini.get("minLength").is_none());
|
||||
assert_eq!(gemini["type"], "string");
|
||||
assert_eq!(gemini["description"], "A string field");
|
||||
|
||||
// OpenAI: Most permissive (keeps minLength)
|
||||
let openai = SchemaCleanr::clean_for_openai(schema.clone());
|
||||
assert_eq!(openai["minLength"], 1); // OpenAI allows validation keywords
|
||||
assert_eq!(openai["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
fn test_nested_properties() {
    // Cleaning must recurse into nested object schemas, not just strip
    // keywords at the top level.
    let schema = json!({
        "type": "object",
        "properties": {
            "user": {
                "type": "object",
                "properties": {
                    "name": {
                        "type": "string",
                        "minLength": 1
                    }
                },
                "additionalProperties": false
            }
        }
    });

    let cleaned = SchemaCleanr::clean_for_gemini(schema);
    let user = &cleaned["properties"]["user"];

    // Unsupported keywords are stripped at every depth.
    assert!(user["properties"]["name"].get("minLength").is_none());
    assert!(user.get("additionalProperties").is_none());
}
#[test]
fn test_type_array_null_removal() {
    // A nullable union like ["string", "null"] collapses to the single
    // non-null type for Gemini.
    let nullable = json!({ "type": ["string", "null"] });

    let cleaned = SchemaCleanr::clean_for_gemini(nullable);

    assert_eq!(cleaned["type"], "string");
}
#[test]
fn test_type_array_only_null_preserved() {
    // When "null" is the only member of the type array there is no
    // non-null type to collapse to, so "null" is kept as-is.
    let null_only = json!({ "type": ["null"] });

    let cleaned = SchemaCleanr::clean_for_gemini(null_only);

    assert_eq!(cleaned["type"], "null");
}
#[test]
fn test_ref_with_json_pointer_escape() {
    // "$ref" paths use JSON Pointer escaping (RFC 6901): "~1" decodes
    // to "/", so "#/$defs/Foo~1Bar" must resolve to the "Foo/Bar" def.
    let schema = json!({
        "$ref": "#/$defs/Foo~1Bar",
        "$defs": {
            "Foo/Bar": {
                "type": "string"
            }
        }
    });

    let cleaned = SchemaCleanr::clean_for_gemini(schema);

    // The $ref was inlined with the referenced definition's contents.
    assert_eq!(cleaned["type"], "string");
}
#[test]
fn test_skip_type_when_non_simplifiable_union_exists() {
    // A oneOf union of two distinct object variants cannot be folded
    // into a single type, so the sibling "type" keyword is dropped
    // while the union itself is preserved.
    let variant_a = json!({
        "type": "object",
        "properties": {
            "a": { "type": "string" }
        }
    });
    let variant_b = json!({
        "type": "object",
        "properties": {
            "b": { "type": "number" }
        }
    });
    let schema = json!({
        "type": "object",
        "oneOf": [variant_a, variant_b]
    });

    let cleaned = SchemaCleanr::clean_for_gemini(schema);

    assert!(cleaned.get("type").is_none());
    assert!(cleaned.get("oneOf").is_some());
}
#[test]
fn test_clean_nested_unknown_schema_keyword() {
    // Schemas nested under less common keywords such as "not" must
    // still receive $ref resolution and keyword stripping.
    let schema = json!({
        "not": {
            "$ref": "#/$defs/Age"
        },
        "$defs": {
            "Age": {
                "type": "integer",
                "minimum": 0
            }
        }
    });

    let cleaned = SchemaCleanr::clean_for_gemini(schema);

    // The $ref under "not" was inlined and "minimum" stripped.
    assert_eq!(cleaned["not"]["type"], "integer");
    assert!(cleaned["not"].get("minimum").is_none());
}
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue