feat: initial release — ZeroClaw v0.1.0

- 22 AI providers (OpenRouter, Anthropic, OpenAI, Mistral, etc.)
- 7 channels (CLI, Telegram, Discord, Slack, iMessage, Matrix, Webhook)
- 5-step onboarding wizard with Project Context personalization
- OpenClaw-aligned system prompt (SOUL.md, IDENTITY.md, USER.md, AGENTS.md, etc.)
- SQLite memory backend with auto-save
- Skills system with on-demand loading
- Security: autonomy levels, command allowlists, cost limits
- 532 tests passing, 0 clippy warnings
This commit is contained in:
argenis de la rosa 2026-02-13 12:19:14 -05:00
commit 05cb353f7f
71 changed files with 15757 additions and 0 deletions

182
src/agent/loop_.rs Normal file
View file

@ -0,0 +1,182 @@
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::observability::{self, Observer, ObserverEvent};
use crate::providers::{self, Provider};
use crate::runtime;
use crate::security::SecurityPolicy;
use crate::tools;
use anyhow::Result;
use std::fmt::Write;
use std::sync::Arc;
use std::time::Instant;
/// Build context preamble by searching memory for relevant entries
async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new();
// Pull relevant memories for this message
if let Ok(entries) = mem.recall(user_msg, 5).await {
if !entries.is_empty() {
context.push_str("[Memory context]\n");
for entry in &entries {
let _ = writeln!(context, "- {}: {}", entry.key, entry.content);
}
context.push('\n');
}
}
context
}
#[allow(clippy::too_many_lines)]
pub async fn run(
config: Config,
message: Option<String>,
provider_override: Option<String>,
model_override: Option<String>,
temperature: f64,
) -> Result<()> {
// ── Wire up agnostic subsystems ──────────────────────────────
let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability));
let _runtime = runtime::create_runtime(&config.runtime);
let security = Arc::new(SecurityPolicy::from_config(
&config.autonomy,
&config.workspace_dir,
));
// ── Memory (the brain) ────────────────────────────────────────
let mem: Arc<dyn Memory> =
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
tracing::info!(backend = mem.name(), "Memory initialized");
// ── Tools (including memory tools) ────────────────────────────
let _tools = tools::all_tools(security, mem.clone());
// ── Resolve provider ─────────────────────────────────────────
let provider_name = provider_override
.as_deref()
.or(config.default_provider.as_deref())
.unwrap_or("openrouter");
let model_name = model_override
.as_deref()
.or(config.default_model.as_deref())
.unwrap_or("anthropic/claude-sonnet-4-20250514");
let provider: Box<dyn Provider> =
providers::create_provider(provider_name, config.api_key.as_deref())?;
observer.record_event(&ObserverEvent::AgentStart {
provider: provider_name.to_string(),
model: model_name.to_string(),
});
// ── Build system prompt from workspace MD files (OpenClaw framework) ──
let skills = crate::skills::load_skills(&config.workspace_dir);
let tool_descs: Vec<(&str, &str)> = vec![
("shell", "Execute terminal commands"),
("file_read", "Read file contents"),
("file_write", "Write file contents"),
("memory_store", "Save to memory"),
("memory_recall", "Search memory"),
("memory_forget", "Delete a memory entry"),
];
let system_prompt = crate::channels::build_system_prompt(
&config.workspace_dir,
model_name,
&tool_descs,
&skills,
);
// ── Execute ──────────────────────────────────────────────────
let start = Instant::now();
if let Some(msg) = message {
// Auto-save user message to memory
if config.memory.auto_save {
let _ = mem
.store("user_msg", &msg, MemoryCategory::Conversation)
.await;
}
// Inject memory context into user message
let context = build_context(mem.as_ref(), &msg).await;
let enriched = if context.is_empty() {
msg.clone()
} else {
format!("{context}{msg}")
};
let response = provider
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature)
.await?;
println!("{response}");
// Auto-save assistant response to daily log
if config.memory.auto_save {
let summary = if response.len() > 100 {
format!("{}...", &response[..100])
} else {
response.clone()
};
let _ = mem
.store("assistant_resp", &summary, MemoryCategory::Daily)
.await;
}
} else {
println!("🦀 ZeroClaw Interactive Mode");
println!("Type /quit to exit.\n");
let (tx, mut rx) = tokio::sync::mpsc::channel(32);
let cli = crate::channels::CliChannel::new();
// Spawn listener
let listen_handle = tokio::spawn(async move {
let _ = crate::channels::Channel::listen(&cli, tx).await;
});
while let Some(msg) = rx.recv().await {
// Auto-save conversation turns
if config.memory.auto_save {
let _ = mem
.store("user_msg", &msg.content, MemoryCategory::Conversation)
.await;
}
// Inject memory context into user message
let context = build_context(mem.as_ref(), &msg.content).await;
let enriched = if context.is_empty() {
msg.content.clone()
} else {
format!("{context}{}", msg.content)
};
let response = provider
.chat_with_system(Some(&system_prompt), &enriched, model_name, temperature)
.await?;
println!("\n{response}\n");
if config.memory.auto_save {
let summary = if response.len() > 100 {
format!("{}...", &response[..100])
} else {
response.clone()
};
let _ = mem
.store("assistant_resp", &summary, MemoryCategory::Daily)
.await;
}
}
listen_handle.abort();
}
let duration = start.elapsed();
observer.record_event(&ObserverEvent::AgentEnd {
duration,
tokens_used: None,
});
Ok(())
}

3
src/agent/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod loop_;
pub use loop_::run;

117
src/channels/cli.rs Normal file
View file

@ -0,0 +1,117 @@
use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use tokio::io::{self, AsyncBufReadExt, BufReader};
use uuid::Uuid;
/// CLI channel — stdin/stdout, always available, zero deps
pub struct CliChannel;
impl CliChannel {
pub fn new() -> Self {
Self
}
}
#[async_trait]
impl Channel for CliChannel {
fn name(&self) -> &str {
"cli"
}
async fn send(&self, message: &str, _recipient: &str) -> anyhow::Result<()> {
println!("{message}");
Ok(())
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let stdin = io::stdin();
let reader = BufReader::new(stdin);
let mut lines = reader.lines();
while let Ok(Some(line)) = lines.next_line().await {
let line = line.trim().to_string();
if line.is_empty() {
continue;
}
if line == "/quit" || line == "/exit" {
break;
}
let msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: "user".to_string(),
content: line,
channel: "cli".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(msg).await.is_err() {
break;
}
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn cli_channel_name() {
assert_eq!(CliChannel::new().name(), "cli");
}
#[tokio::test]
async fn cli_channel_send_does_not_panic() {
let ch = CliChannel::new();
let result = ch.send("hello", "user").await;
assert!(result.is_ok());
}
#[tokio::test]
async fn cli_channel_send_empty_message() {
let ch = CliChannel::new();
let result = ch.send("", "").await;
assert!(result.is_ok());
}
#[tokio::test]
async fn cli_channel_health_check() {
let ch = CliChannel::new();
assert!(ch.health_check().await);
}
#[test]
fn channel_message_struct() {
let msg = ChannelMessage {
id: "test-id".into(),
sender: "user".into(),
content: "hello".into(),
channel: "cli".into(),
timestamp: 1234567890,
};
assert_eq!(msg.id, "test-id");
assert_eq!(msg.sender, "user");
assert_eq!(msg.content, "hello");
assert_eq!(msg.channel, "cli");
assert_eq!(msg.timestamp, 1234567890);
}
#[test]
fn channel_message_clone() {
let msg = ChannelMessage {
id: "id".into(),
sender: "s".into(),
content: "c".into(),
channel: "ch".into(),
timestamp: 0,
};
let cloned = msg.clone();
assert_eq!(cloned.id, msg.id);
assert_eq!(cloned.content, msg.content);
}
}

271
src/channels/discord.rs Normal file
View file

@ -0,0 +1,271 @@
use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use serde_json::json;
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;
/// Discord channel — connects via Gateway WebSocket for real-time messages
pub struct DiscordChannel {
bot_token: String,
guild_id: Option<String>,
client: reqwest::Client,
}
impl DiscordChannel {
pub fn new(bot_token: String, guild_id: Option<String>) -> Self {
Self {
bot_token,
guild_id,
client: reqwest::Client::new(),
}
}
fn bot_user_id_from_token(token: &str) -> Option<String> {
// Discord bot tokens are base64(bot_user_id).timestamp.hmac
let part = token.split('.').next()?;
base64_decode(part)
}
}
const BASE64_ALPHABET: &[u8] =
b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
/// Minimal base64 decode (no extra dep) — only needs to decode the user ID portion
#[allow(clippy::cast_possible_truncation)]
fn base64_decode(input: &str) -> Option<String> {
let padded = match input.len() % 4 {
2 => format!("{input}=="),
3 => format!("{input}="),
_ => input.to_string(),
};
let mut bytes = Vec::new();
let chars: Vec<u8> = padded.bytes().collect();
for chunk in chars.chunks(4) {
if chunk.len() < 4 {
break;
}
let mut v = [0usize; 4];
for (i, &b) in chunk.iter().enumerate() {
if b == b'=' {
v[i] = 0;
} else {
v[i] = BASE64_ALPHABET.iter().position(|&a| a == b)?;
}
}
bytes.push(((v[0] << 2) | (v[1] >> 4)) as u8);
if chunk[2] != b'=' {
bytes.push((((v[1] & 0xF) << 4) | (v[2] >> 2)) as u8);
}
if chunk[3] != b'=' {
bytes.push((((v[2] & 0x3) << 6) | v[3]) as u8);
}
}
String::from_utf8(bytes).ok()
}
#[async_trait]
impl Channel for DiscordChannel {
fn name(&self) -> &str {
"discord"
}
async fn send(&self, message: &str, channel_id: &str) -> anyhow::Result<()> {
let url = format!("https://discord.com/api/v10/channels/{channel_id}/messages");
let body = json!({ "content": message });
self.client
.post(&url)
.header("Authorization", format!("Bot {}", self.bot_token))
.json(&body)
.send()
.await?;
Ok(())
}
#[allow(clippy::too_many_lines)]
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let bot_user_id = Self::bot_user_id_from_token(&self.bot_token).unwrap_or_default();
// Get Gateway URL
let gw_resp: serde_json::Value = self
.client
.get("https://discord.com/api/v10/gateway/bot")
.header("Authorization", format!("Bot {}", self.bot_token))
.send()
.await?
.json()
.await?;
let gw_url = gw_resp
.get("url")
.and_then(|u| u.as_str())
.unwrap_or("wss://gateway.discord.gg");
let ws_url = format!("{gw_url}/?v=10&encoding=json");
tracing::info!("Discord: connecting to gateway...");
let (ws_stream, _) = tokio_tungstenite::connect_async(&ws_url).await?;
let (mut write, mut read) = ws_stream.split();
// Read Hello (opcode 10)
let hello = read.next().await.ok_or(anyhow::anyhow!("No hello"))??;
let hello_data: serde_json::Value = serde_json::from_str(&hello.to_string())?;
let heartbeat_interval = hello_data
.get("d")
.and_then(|d| d.get("heartbeat_interval"))
.and_then(serde_json::Value::as_u64)
.unwrap_or(41250);
// Send Identify (opcode 2)
let identify = json!({
"op": 2,
"d": {
"token": self.bot_token,
"intents": 33281, // GUILDS | GUILD_MESSAGES | MESSAGE_CONTENT | DIRECT_MESSAGES
"properties": {
"os": "linux",
"browser": "zeroclaw",
"device": "zeroclaw"
}
}
});
write.send(Message::Text(identify.to_string())).await?;
tracing::info!("Discord: connected and identified");
// Spawn heartbeat task
let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
let hb_interval = heartbeat_interval;
tokio::spawn(async move {
let mut interval =
tokio::time::interval(std::time::Duration::from_millis(hb_interval));
loop {
interval.tick().await;
if hb_tx.send(()).await.is_err() {
break;
}
}
});
let guild_filter = self.guild_id.clone();
loop {
tokio::select! {
_ = hb_rx.recv() => {
let hb = json!({"op": 1, "d": null});
if write.send(Message::Text(hb.to_string())).await.is_err() {
break;
}
}
msg = read.next() => {
let msg = match msg {
Some(Ok(Message::Text(t))) => t,
Some(Ok(Message::Close(_))) | None => break,
_ => continue,
};
let event: serde_json::Value = match serde_json::from_str(&msg) {
Ok(e) => e,
Err(_) => continue,
};
// Only handle MESSAGE_CREATE (opcode 0, type "MESSAGE_CREATE")
let event_type = event.get("t").and_then(|t| t.as_str()).unwrap_or("");
if event_type != "MESSAGE_CREATE" {
continue;
}
let Some(d) = event.get("d") else {
continue;
};
// Skip messages from the bot itself
let author_id = d.get("author").and_then(|a| a.get("id")).and_then(|i| i.as_str()).unwrap_or("");
if author_id == bot_user_id {
continue;
}
// Skip bot messages
if d.get("author").and_then(|a| a.get("bot")).and_then(serde_json::Value::as_bool).unwrap_or(false) {
continue;
}
// Guild filter
if let Some(ref gid) = guild_filter {
let msg_guild = d.get("guild_id").and_then(serde_json::Value::as_str).unwrap_or("");
if msg_guild != gid {
continue;
}
}
let content = d.get("content").and_then(|c| c.as_str()).unwrap_or("");
if content.is_empty() {
continue;
}
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
let channel_msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: channel_id,
content: content.to_string(),
channel: "discord".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(channel_msg).await.is_err() {
break;
}
}
}
}
Ok(())
}
async fn health_check(&self) -> bool {
self.client
.get("https://discord.com/api/v10/users/@me")
.header("Authorization", format!("Bot {}", self.bot_token))
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn discord_channel_name() {
let ch = DiscordChannel::new("fake".into(), None);
assert_eq!(ch.name(), "discord");
}
#[test]
fn base64_decode_bot_id() {
// "MTIzNDU2" decodes to "123456"
let decoded = base64_decode("MTIzNDU2");
assert_eq!(decoded, Some("123456".to_string()));
}
#[test]
fn bot_user_id_extraction() {
// Token format: base64(user_id).timestamp.hmac
let token = "MTIzNDU2.fake.hmac";
let id = DiscordChannel::bot_user_id_from_token(token);
assert_eq!(id, Some("123456".to_string()));
}
}

265
src/channels/imessage.rs Normal file
View file

@ -0,0 +1,265 @@
use crate::channels::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use directories::UserDirs;
use tokio::sync::mpsc;
/// iMessage channel using macOS `AppleScript` bridge.
/// Polls the Messages database for new messages and sends replies via `osascript`.
#[derive(Clone)]
pub struct IMessageChannel {
allowed_contacts: Vec<String>,
poll_interval_secs: u64,
}
impl IMessageChannel {
pub fn new(allowed_contacts: Vec<String>) -> Self {
Self {
allowed_contacts,
poll_interval_secs: 3,
}
}
fn is_contact_allowed(&self, sender: &str) -> bool {
if self.allowed_contacts.iter().any(|u| u == "*") {
return true;
}
self.allowed_contacts.iter().any(|u| {
u.eq_ignore_ascii_case(sender)
})
}
}
#[async_trait]
impl Channel for IMessageChannel {
fn name(&self) -> &str {
"imessage"
}
async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> {
let escaped_msg = message.replace('\\', "\\\\").replace('"', "\\\"");
let script = format!(
r#"tell application "Messages"
set targetService to 1st account whose service type = iMessage
set targetBuddy to participant "{target}" of targetService
send "{escaped_msg}" to targetBuddy
end tell"#
);
let output = tokio::process::Command::new("osascript")
.arg("-e")
.arg(&script)
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("iMessage send failed: {stderr}");
}
Ok(())
}
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
tracing::info!("iMessage channel listening (AppleScript bridge)...");
// Query the Messages SQLite database for new messages
// The database is at ~/Library/Messages/chat.db
let db_path = UserDirs::new()
.map(|u| u.home_dir().join("Library/Messages/chat.db"))
.ok_or_else(|| anyhow::anyhow!("Cannot find home directory"))?;
if !db_path.exists() {
anyhow::bail!(
"Messages database not found at {}. Ensure Messages.app is set up and Full Disk Access is granted.",
db_path.display()
);
}
// Track the last ROWID we've seen
let mut last_rowid = get_max_rowid(&db_path).await.unwrap_or(0);
loop {
tokio::time::sleep(tokio::time::Duration::from_secs(self.poll_interval_secs)).await;
let new_messages = fetch_new_messages(&db_path, last_rowid).await;
match new_messages {
Ok(messages) => {
for (rowid, sender, text) in messages {
if rowid > last_rowid {
last_rowid = rowid;
}
if !self.is_contact_allowed(&sender) {
continue;
}
if text.trim().is_empty() {
continue;
}
let msg = ChannelMessage {
id: rowid.to_string(),
sender: sender.clone(),
content: text,
channel: "imessage".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(msg).await.is_err() {
return Ok(());
}
}
}
Err(e) => {
tracing::warn!("iMessage poll error: {e}");
}
}
}
}
async fn health_check(&self) -> bool {
if !cfg!(target_os = "macos") {
return false;
}
let db_path = UserDirs::new()
.map(|u| u.home_dir().join("Library/Messages/chat.db"))
.unwrap_or_default();
db_path.exists()
}
}
/// Get the current max ROWID from the messages table
async fn get_max_rowid(db_path: &std::path::Path) -> anyhow::Result<i64> {
let output = tokio::process::Command::new("sqlite3")
.arg(db_path)
.arg("SELECT MAX(ROWID) FROM message WHERE is_from_me = 0;")
.output()
.await?;
let stdout = String::from_utf8_lossy(&output.stdout);
let rowid = stdout.trim().parse::<i64>().unwrap_or(0);
Ok(rowid)
}
/// Fetch messages newer than `since_rowid`
async fn fetch_new_messages(
db_path: &std::path::Path,
since_rowid: i64,
) -> anyhow::Result<Vec<(i64, String, String)>> {
let query = format!(
"SELECT m.ROWID, h.id, m.text \
FROM message m \
JOIN handle h ON m.handle_id = h.ROWID \
WHERE m.ROWID > {since_rowid} \
AND m.is_from_me = 0 \
AND m.text IS NOT NULL \
ORDER BY m.ROWID ASC \
LIMIT 20;"
);
let output = tokio::process::Command::new("sqlite3")
.arg("-separator")
.arg("|")
.arg(db_path)
.arg(&query)
.output()
.await?;
if !output.status.success() {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("sqlite3 query failed: {stderr}");
}
let stdout = String::from_utf8_lossy(&output.stdout);
let mut results = Vec::new();
for line in stdout.lines() {
let parts: Vec<&str> = line.splitn(3, '|').collect();
if parts.len() == 3 {
if let Ok(rowid) = parts[0].parse::<i64>() {
results.push((rowid, parts[1].to_string(), parts[2].to_string()));
}
}
}
Ok(results)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn creates_with_contacts() {
let ch = IMessageChannel::new(vec!["+1234567890".into()]);
assert_eq!(ch.allowed_contacts.len(), 1);
assert_eq!(ch.poll_interval_secs, 3);
}
#[test]
fn creates_with_empty_contacts() {
let ch = IMessageChannel::new(vec![]);
assert!(ch.allowed_contacts.is_empty());
}
#[test]
fn wildcard_allows_anyone() {
let ch = IMessageChannel::new(vec!["*".into()]);
assert!(ch.is_contact_allowed("+1234567890"));
assert!(ch.is_contact_allowed("random@icloud.com"));
assert!(ch.is_contact_allowed(""));
}
#[test]
fn specific_contact_allowed() {
let ch = IMessageChannel::new(vec!["+1234567890".into(), "user@icloud.com".into()]);
assert!(ch.is_contact_allowed("+1234567890"));
assert!(ch.is_contact_allowed("user@icloud.com"));
}
#[test]
fn unknown_contact_denied() {
let ch = IMessageChannel::new(vec!["+1234567890".into()]);
assert!(!ch.is_contact_allowed("+9999999999"));
assert!(!ch.is_contact_allowed("hacker@evil.com"));
}
#[test]
fn contact_case_insensitive() {
let ch = IMessageChannel::new(vec!["User@iCloud.com".into()]);
assert!(ch.is_contact_allowed("user@icloud.com"));
assert!(ch.is_contact_allowed("USER@ICLOUD.COM"));
}
#[test]
fn empty_allowlist_denies_all() {
let ch = IMessageChannel::new(vec![]);
assert!(!ch.is_contact_allowed("+1234567890"));
assert!(!ch.is_contact_allowed("anyone"));
}
#[test]
fn name_returns_imessage() {
let ch = IMessageChannel::new(vec![]);
assert_eq!(ch.name(), "imessage");
}
#[test]
fn wildcard_among_others_still_allows_all() {
let ch = IMessageChannel::new(vec!["+111".into(), "*".into(), "+222".into()]);
assert!(ch.is_contact_allowed("totally-unknown"));
}
#[test]
fn contact_with_spaces_exact_match() {
let ch = IMessageChannel::new(vec![" spaced ".into()]);
assert!(ch.is_contact_allowed(" spaced "));
assert!(!ch.is_contact_allowed("spaced"));
}
}

467
src/channels/matrix.rs Normal file
View file

@ -0,0 +1,467 @@
use crate::channels::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use tokio::sync::mpsc;
/// Matrix channel using the Client-Server API (no SDK needed).
/// Connects to any Matrix homeserver (Element, Synapse, etc.).
#[derive(Clone)]
pub struct MatrixChannel {
homeserver: String,
access_token: String,
room_id: String,
allowed_users: Vec<String>,
client: Client,
}
#[derive(Debug, Deserialize)]
struct SyncResponse {
next_batch: String,
#[serde(default)]
rooms: Rooms,
}
#[derive(Debug, Deserialize, Default)]
struct Rooms {
#[serde(default)]
join: std::collections::HashMap<String, JoinedRoom>,
}
#[derive(Debug, Deserialize)]
struct JoinedRoom {
#[serde(default)]
timeline: Timeline,
}
#[derive(Debug, Deserialize, Default)]
struct Timeline {
#[serde(default)]
events: Vec<TimelineEvent>,
}
#[derive(Debug, Deserialize)]
struct TimelineEvent {
#[serde(rename = "type")]
event_type: String,
sender: String,
#[serde(default)]
content: EventContent,
}
#[derive(Debug, Deserialize, Default)]
struct EventContent {
#[serde(default)]
body: Option<String>,
#[serde(default)]
msgtype: Option<String>,
}
#[derive(Debug, Deserialize)]
struct WhoAmIResponse {
user_id: String,
}
impl MatrixChannel {
pub fn new(
homeserver: String,
access_token: String,
room_id: String,
allowed_users: Vec<String>,
) -> Self {
let homeserver = if homeserver.ends_with('/') {
homeserver[..homeserver.len() - 1].to_string()
} else {
homeserver
};
Self {
homeserver,
access_token,
room_id,
allowed_users,
client: Client::new(),
}
}
fn is_user_allowed(&self, sender: &str) -> bool {
if self.allowed_users.iter().any(|u| u == "*") {
return true;
}
self.allowed_users
.iter()
.any(|u| u.eq_ignore_ascii_case(sender))
}
async fn get_my_user_id(&self) -> anyhow::Result<String> {
let url = format!(
"{}/_matrix/client/v3/account/whoami",
self.homeserver
);
let resp = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.access_token))
.send()
.await?;
if !resp.status().is_success() {
let err = resp.text().await?;
anyhow::bail!("Matrix whoami failed: {err}");
}
let who: WhoAmIResponse = resp.json().await?;
Ok(who.user_id)
}
}
#[async_trait]
impl Channel for MatrixChannel {
fn name(&self) -> &str {
"matrix"
}
async fn send(&self, message: &str, _target: &str) -> anyhow::Result<()> {
let txn_id = format!("zc_{}", chrono::Utc::now().timestamp_millis());
let url = format!(
"{}/_matrix/client/v3/rooms/{}/send/m.room.message/{}",
self.homeserver, self.room_id, txn_id
);
let body = serde_json::json!({
"msgtype": "m.text",
"body": message
});
let resp = self
.client
.put(&url)
.header("Authorization", format!("Bearer {}", self.access_token))
.json(&body)
.send()
.await?;
if !resp.status().is_success() {
let err = resp.text().await?;
anyhow::bail!("Matrix send failed: {err}");
}
Ok(())
}
async fn listen(&self, tx: mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
tracing::info!("Matrix channel listening on room {}...", self.room_id);
let my_user_id = self.get_my_user_id().await?;
// Initial sync to get the since token
let url = format!(
"{}/_matrix/client/v3/sync?timeout=30000&filter={{\"room\":{{\"timeline\":{{\"limit\":1}}}}}}",
self.homeserver
);
let resp = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.access_token))
.send()
.await?;
if !resp.status().is_success() {
let err = resp.text().await?;
anyhow::bail!("Matrix initial sync failed: {err}");
}
let sync: SyncResponse = resp.json().await?;
let mut since = sync.next_batch;
// Long-poll loop
loop {
let url = format!(
"{}/_matrix/client/v3/sync?since={}&timeout=30000",
self.homeserver, since
);
let resp = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.access_token))
.send()
.await;
let resp = match resp {
Ok(r) => r,
Err(e) => {
tracing::warn!("Matrix sync error: {e}, retrying...");
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
continue;
}
};
if !resp.status().is_success() {
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
continue;
}
let sync: SyncResponse = resp.json().await?;
since = sync.next_batch;
// Process events from our room
if let Some(room) = sync.rooms.join.get(&self.room_id) {
for event in &room.timeline.events {
// Skip our own messages
if event.sender == my_user_id {
continue;
}
// Only process text messages
if event.event_type != "m.room.message" {
continue;
}
if event.content.msgtype.as_deref() != Some("m.text") {
continue;
}
let Some(ref body) = event.content.body else {
continue;
};
if !self.is_user_allowed(&event.sender) {
continue;
}
let msg = ChannelMessage {
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
sender: event.sender.clone(),
content: body.clone(),
channel: "matrix".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(msg).await.is_err() {
return Ok(());
}
}
}
}
}
async fn health_check(&self) -> bool {
let url = format!(
"{}/_matrix/client/v3/account/whoami",
self.homeserver
);
let Ok(resp) = self
.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.access_token))
.send()
.await
else {
return false;
};
resp.status().is_success()
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_channel() -> MatrixChannel {
MatrixChannel::new(
"https://matrix.org".to_string(),
"syt_test_token".to_string(),
"!room:matrix.org".to_string(),
vec!["@user:matrix.org".to_string()],
)
}
#[test]
fn creates_with_correct_fields() {
let ch = make_channel();
assert_eq!(ch.homeserver, "https://matrix.org");
assert_eq!(ch.access_token, "syt_test_token");
assert_eq!(ch.room_id, "!room:matrix.org");
assert_eq!(ch.allowed_users.len(), 1);
}
#[test]
fn strips_trailing_slash() {
let ch = MatrixChannel::new(
"https://matrix.org/".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
);
assert_eq!(ch.homeserver, "https://matrix.org");
}
#[test]
fn no_trailing_slash_unchanged() {
let ch = MatrixChannel::new(
"https://matrix.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
);
assert_eq!(ch.homeserver, "https://matrix.org");
}
#[test]
fn multiple_trailing_slashes_strips_one() {
let ch = MatrixChannel::new(
"https://matrix.org//".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
);
assert_eq!(ch.homeserver, "https://matrix.org/");
}
#[test]
fn wildcard_allows_anyone() {
let ch = MatrixChannel::new(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec!["*".to_string()],
);
assert!(ch.is_user_allowed("@anyone:matrix.org"));
assert!(ch.is_user_allowed("@hacker:evil.org"));
}
#[test]
fn specific_user_allowed() {
let ch = make_channel();
assert!(ch.is_user_allowed("@user:matrix.org"));
}
#[test]
fn unknown_user_denied() {
let ch = make_channel();
assert!(!ch.is_user_allowed("@stranger:matrix.org"));
assert!(!ch.is_user_allowed("@evil:hacker.org"));
}
#[test]
fn user_case_insensitive() {
let ch = MatrixChannel::new(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec!["@User:Matrix.org".to_string()],
);
assert!(ch.is_user_allowed("@user:matrix.org"));
assert!(ch.is_user_allowed("@USER:MATRIX.ORG"));
}
#[test]
fn empty_allowlist_denies_all() {
let ch = MatrixChannel::new(
"https://m.org".to_string(),
"tok".to_string(),
"!r:m".to_string(),
vec![],
);
assert!(!ch.is_user_allowed("@anyone:matrix.org"));
}
#[test]
fn name_returns_matrix() {
let ch = make_channel();
assert_eq!(ch.name(), "matrix");
}
#[test]
fn sync_response_deserializes_empty() {
let json = r#"{"next_batch":"s123","rooms":{"join":{}}}"#;
let resp: SyncResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.next_batch, "s123");
assert!(resp.rooms.join.is_empty());
}
#[test]
fn sync_response_deserializes_with_events() {
let json = r#"{
"next_batch": "s456",
"rooms": {
"join": {
"!room:matrix.org": {
"timeline": {
"events": [
{
"type": "m.room.message",
"sender": "@user:matrix.org",
"content": {
"msgtype": "m.text",
"body": "Hello!"
}
}
]
}
}
}
}
}"#;
let resp: SyncResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.next_batch, "s456");
let room = resp.rooms.join.get("!room:matrix.org").unwrap();
assert_eq!(room.timeline.events.len(), 1);
assert_eq!(room.timeline.events[0].sender, "@user:matrix.org");
assert_eq!(room.timeline.events[0].content.body.as_deref(), Some("Hello!"));
assert_eq!(room.timeline.events[0].content.msgtype.as_deref(), Some("m.text"));
}
#[test]
fn sync_response_ignores_non_text_events() {
let json = r#"{
"next_batch": "s789",
"rooms": {
"join": {
"!room:m": {
"timeline": {
"events": [
{
"type": "m.room.member",
"sender": "@user:m",
"content": {}
}
]
}
}
}
}
}"#;
let resp: SyncResponse = serde_json::from_str(json).unwrap();
let room = resp.rooms.join.get("!room:m").unwrap();
assert_eq!(room.timeline.events[0].event_type, "m.room.member");
assert!(room.timeline.events[0].content.body.is_none());
}
#[test]
fn whoami_response_deserializes() {
let json = r#"{"user_id":"@bot:matrix.org"}"#;
let resp: WhoAmIResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.user_id, "@bot:matrix.org");
}
#[test]
fn event_content_defaults() {
let json = r#"{"type":"m.room.message","sender":"@u:m","content":{}}"#;
let event: TimelineEvent = serde_json::from_str(json).unwrap();
assert!(event.content.body.is_none());
assert!(event.content.msgtype.is_none());
}
#[test]
fn sync_response_missing_rooms_defaults() {
let json = r#"{"next_batch":"s0"}"#;
let resp: SyncResponse = serde_json::from_str(json).unwrap();
assert!(resp.rooms.join.is_empty());
}
}

550
src/channels/mod.rs Normal file
View file

@ -0,0 +1,550 @@
pub mod cli;
pub mod discord;
pub mod imessage;
pub mod matrix;
pub mod slack;
pub mod telegram;
pub mod traits;
pub use cli::CliChannel;
pub use discord::DiscordChannel;
pub use imessage::IMessageChannel;
pub use matrix::MatrixChannel;
pub use slack::SlackChannel;
pub use telegram::TelegramChannel;
pub use traits::Channel;
use crate::config::Config;
use crate::memory::{self, Memory};
use crate::providers::{self, Provider};
use anyhow::Result;
use std::sync::Arc;
/// Maximum characters per injected workspace file (matches `OpenClaw` default).
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
/// Load workspace identity files and build a system prompt.
///
/// Follows the `OpenClaw` framework structure:
/// 1. Tooling — tool list + descriptions
/// 2. Safety — guardrail reminder
/// 3. Skills — compact list with paths (loaded on-demand)
/// 4. Workspace — working directory
/// 5. Bootstrap files — AGENTS, SOUL, TOOLS, IDENTITY, USER, HEARTBEAT, BOOTSTRAP, MEMORY
/// 6. Date & Time — timezone for cache stability
/// 7. Runtime — host, OS, model
///
/// Daily memory files (`memory/*.md`) are NOT injected — they are accessed
/// on-demand via `memory_recall` / `memory_search` tools.
pub fn build_system_prompt(
workspace_dir: &std::path::Path,
model_name: &str,
tools: &[(&str, &str)],
skills: &[crate::skills::Skill],
) -> String {
use std::fmt::Write;
let mut prompt = String::with_capacity(8192);
// ── 1. Tooling ──────────────────────────────────────────────
if !tools.is_empty() {
prompt.push_str("## Tools\n\n");
prompt.push_str("You have access to the following tools:\n\n");
for (name, desc) in tools {
let _ = writeln!(prompt, "- **{name}**: {desc}");
}
prompt.push('\n');
}
// ── 2. Safety ───────────────────────────────────────────────
prompt.push_str("## Safety\n\n");
prompt.push_str(
"- Do not exfiltrate private data.\n\
- Do not run destructive commands without asking.\n\
- Do not bypass oversight or approval mechanisms.\n\
- Prefer `trash` over `rm` (recoverable beats gone forever).\n\
- When in doubt, ask before acting externally.\n\n",
);
// ── 3. Skills (compact list — load on-demand) ───────────────
if !skills.is_empty() {
prompt.push_str("## Available Skills\n\n");
prompt.push_str(
"Skills are loaded on demand. Use `read` on the skill path to get full instructions.\n\n",
);
prompt.push_str("<available_skills>\n");
for skill in skills {
let _ = writeln!(prompt, " <skill>");
let _ = writeln!(prompt, " <name>{}</name>", skill.name);
let _ = writeln!(prompt, " <description>{}</description>", skill.description);
let location = workspace_dir.join("skills").join(&skill.name).join("SKILL.md");
let _ = writeln!(prompt, " <location>{}</location>", location.display());
let _ = writeln!(prompt, " </skill>");
}
prompt.push_str("</available_skills>\n\n");
}
// ── 4. Workspace ────────────────────────────────────────────
let _ = writeln!(prompt, "## Workspace\n\nWorking directory: `{}`\n", workspace_dir.display());
// ── 5. Bootstrap files (injected into context) ──────────────
prompt.push_str("## Project Context\n\n");
prompt.push_str("The following workspace files define your identity, behavior, and context.\n\n");
let bootstrap_files = [
"AGENTS.md",
"SOUL.md",
"TOOLS.md",
"IDENTITY.md",
"USER.md",
"HEARTBEAT.md",
];
for filename in &bootstrap_files {
inject_workspace_file(&mut prompt, workspace_dir, filename);
}
// BOOTSTRAP.md — only if it exists (first-run ritual)
let bootstrap_path = workspace_dir.join("BOOTSTRAP.md");
if bootstrap_path.exists() {
inject_workspace_file(&mut prompt, workspace_dir, "BOOTSTRAP.md");
}
// MEMORY.md — curated long-term memory (main session only)
inject_workspace_file(&mut prompt, workspace_dir, "MEMORY.md");
// ── 6. Date & Time ──────────────────────────────────────────
let now = chrono::Local::now();
let tz = now.format("%Z").to_string();
let _ = writeln!(prompt, "## Current Date & Time\n\nTimezone: {tz}\n");
// ── 7. Runtime ──────────────────────────────────────────────
let host = hostname::get()
.map_or_else(|_| "unknown".into(), |h| h.to_string_lossy().to_string());
let _ = writeln!(
prompt,
"## Runtime\n\nHost: {host} | OS: {} | Model: {model_name}\n",
std::env::consts::OS,
);
if prompt.is_empty() {
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
} else {
prompt
}
}
/// Inject a single workspace file into the prompt with truncation and missing-file markers.
fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, filename: &str) {
use std::fmt::Write;
let path = workspace_dir.join(filename);
match std::fs::read_to_string(&path) {
Ok(content) => {
let trimmed = content.trim();
if trimmed.is_empty() {
return;
}
let _ = writeln!(prompt, "### {filename}\n");
if trimmed.len() > BOOTSTRAP_MAX_CHARS {
prompt.push_str(&trimmed[..BOOTSTRAP_MAX_CHARS]);
let _ = writeln!(
prompt,
"\n\n[... truncated at {BOOTSTRAP_MAX_CHARS} chars — use `read` for full file]\n"
);
} else {
prompt.push_str(trimmed);
prompt.push_str("\n\n");
}
}
Err(_) => {
// Missing-file marker (matches OpenClaw behavior)
let _ = writeln!(prompt, "### {filename}\n\n[File not found: {filename}]\n");
}
}
}
pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Result<()> {
match command {
super::ChannelCommands::Start => {
// Handled in main.rs (needs async), this is unreachable
unreachable!("Start is handled in main.rs")
}
super::ChannelCommands::List => {
println!("Channels:");
println!(" ✅ CLI (always available)");
for (name, configured) in [
("Telegram", config.channels_config.telegram.is_some()),
("Discord", config.channels_config.discord.is_some()),
("Slack", config.channels_config.slack.is_some()),
("Webhook", config.channels_config.webhook.is_some()),
("iMessage", config.channels_config.imessage.is_some()),
("Matrix", config.channels_config.matrix.is_some()),
] {
println!(
" {} {name}",
if configured { "" } else { "" }
);
}
println!("\nTo start channels: zeroclaw channel start");
println!("To configure: zeroclaw onboard");
Ok(())
}
super::ChannelCommands::Add {
channel_type,
config: _,
} => {
anyhow::bail!("Channel type '{channel_type}' — use `zeroclaw onboard` to configure channels");
}
super::ChannelCommands::Remove { name } => {
anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly");
}
}
}
/// Start all configured channels and route messages to the agent
#[allow(clippy::too_many_lines)]
pub async fn start_channels(config: Config) -> Result<()> {
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
)?);
let model = config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let temperature = config.default_temperature;
let mem: Arc<dyn Memory> =
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
// Build system prompt from workspace identity files + skills
let workspace = config.workspace_dir.clone();
let skills = crate::skills::load_skills(&workspace);
// Collect tool descriptions for the prompt
let tool_descs: Vec<(&str, &str)> = vec![
("shell", "Execute terminal commands"),
("file_read", "Read file contents"),
("file_write", "Write file contents"),
("memory_store", "Save to memory"),
("memory_recall", "Search memory"),
("memory_forget", "Delete a memory entry"),
];
let system_prompt = build_system_prompt(&workspace, &model, &tool_descs, &skills);
if !skills.is_empty() {
println!(" 🧩 Skills: {}", skills.iter().map(|s| s.name.as_str()).collect::<Vec<_>>().join(", "));
}
// Collect active channels
let mut channels: Vec<Arc<dyn Channel>> = Vec::new();
if let Some(ref tg) = config.channels_config.telegram {
channels.push(Arc::new(TelegramChannel::new(
tg.bot_token.clone(),
tg.allowed_users.clone(),
)));
}
if let Some(ref dc) = config.channels_config.discord {
channels.push(Arc::new(DiscordChannel::new(
dc.bot_token.clone(),
dc.guild_id.clone(),
)));
}
if let Some(ref sl) = config.channels_config.slack {
channels.push(Arc::new(SlackChannel::new(
sl.bot_token.clone(),
sl.channel_id.clone(),
)));
}
if let Some(ref im) = config.channels_config.imessage {
channels.push(Arc::new(IMessageChannel::new(
im.allowed_contacts.clone(),
)));
}
if let Some(ref mx) = config.channels_config.matrix {
channels.push(Arc::new(MatrixChannel::new(
mx.homeserver.clone(),
mx.access_token.clone(),
mx.room_id.clone(),
mx.allowed_users.clone(),
)));
}
if channels.is_empty() {
println!("No channels configured. Run `zeroclaw onboard` to set up channels.");
return Ok(());
}
println!("🦀 ZeroClaw Channel Server");
println!(" 🤖 Model: {model}");
println!(" 🧠 Memory: {} (auto-save: {})", config.memory.backend, if config.memory.auto_save { "on" } else { "off" });
println!(" 📡 Channels: {}", channels.iter().map(|c| c.name()).collect::<Vec<_>>().join(", "));
println!();
println!(" Listening for messages... (Ctrl+C to stop)");
println!();
// Single message bus — all channels send messages here
let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
// Spawn a listener for each channel
let mut handles = Vec::new();
for ch in &channels {
let ch = ch.clone();
let tx = tx.clone();
handles.push(tokio::spawn(async move {
if let Err(e) = ch.listen(tx).await {
tracing::error!("Channel {} error: {e}", ch.name());
}
}));
}
drop(tx); // Drop our copy so rx closes when all channels stop
// Process incoming messages — call the LLM and reply
while let Some(msg) = rx.recv().await {
println!(
" 💬 [{}] from {}: {}",
msg.channel,
msg.sender,
if msg.content.len() > 80 {
format!("{}...", &msg.content[..80])
} else {
msg.content.clone()
}
);
// Auto-save to memory
if config.memory.auto_save {
let _ = mem
.store(
&format!("{}_{}", msg.channel, msg.sender),
&msg.content,
crate::memory::MemoryCategory::Conversation,
)
.await;
}
// Call the LLM with system prompt (identity + soul + tools)
match provider.chat_with_system(Some(&system_prompt), &msg.content, &model, temperature).await {
Ok(response) => {
println!(
" 🤖 Reply: {}",
if response.len() > 80 {
format!("{}...", &response[..80])
} else {
response.clone()
}
);
// Find the channel that sent this message and reply
for ch in &channels {
if ch.name() == msg.channel {
if let Err(e) = ch.send(&response, &msg.sender).await {
eprintln!(" ❌ Failed to reply on {}: {e}", ch.name());
}
break;
}
}
}
Err(e) => {
eprintln!(" ❌ LLM error: {e}");
for ch in &channels {
if ch.name() == msg.channel {
let _ = ch
.send(&format!("⚠️ Error: {e}"), &msg.sender)
.await;
break;
}
}
}
}
}
// Wait for all channel tasks
for h in handles {
let _ = h.await;
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn make_workspace() -> TempDir {
let tmp = TempDir::new().unwrap();
// Create minimal workspace files
std::fs::write(tmp.path().join("SOUL.md"), "# Soul\nBe helpful.").unwrap();
std::fs::write(tmp.path().join("IDENTITY.md"), "# Identity\nName: ZeroClaw").unwrap();
std::fs::write(tmp.path().join("USER.md"), "# User\nName: Test User").unwrap();
std::fs::write(tmp.path().join("AGENTS.md"), "# Agents\nFollow instructions.").unwrap();
std::fs::write(tmp.path().join("TOOLS.md"), "# Tools\nUse shell carefully.").unwrap();
std::fs::write(tmp.path().join("HEARTBEAT.md"), "# Heartbeat\nCheck status.").unwrap();
std::fs::write(tmp.path().join("MEMORY.md"), "# Memory\nUser likes Rust.").unwrap();
tmp
}
#[test]
fn prompt_contains_all_sections() {
let ws = make_workspace();
let tools = vec![("shell", "Run commands"), ("file_read", "Read files")];
let prompt = build_system_prompt(ws.path(), "test-model", &tools, &[]);
// Section headers
assert!(prompt.contains("## Tools"), "missing Tools section");
assert!(prompt.contains("## Safety"), "missing Safety section");
assert!(prompt.contains("## Workspace"), "missing Workspace section");
assert!(prompt.contains("## Project Context"), "missing Project Context");
assert!(prompt.contains("## Current Date & Time"), "missing Date/Time");
assert!(prompt.contains("## Runtime"), "missing Runtime section");
}
#[test]
fn prompt_injects_tools() {
let ws = make_workspace();
let tools = vec![("shell", "Run commands"), ("memory_recall", "Search memory")];
let prompt = build_system_prompt(ws.path(), "gpt-4o", &tools, &[]);
assert!(prompt.contains("**shell**"));
assert!(prompt.contains("Run commands"));
assert!(prompt.contains("**memory_recall**"));
}
#[test]
fn prompt_injects_safety() {
let ws = make_workspace();
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
assert!(prompt.contains("Do not exfiltrate private data"));
assert!(prompt.contains("Do not run destructive commands"));
assert!(prompt.contains("Prefer `trash` over `rm`"));
}
#[test]
fn prompt_injects_workspace_files() {
let ws = make_workspace();
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
assert!(prompt.contains("### SOUL.md"), "missing SOUL.md header");
assert!(prompt.contains("Be helpful"), "missing SOUL content");
assert!(prompt.contains("### IDENTITY.md"), "missing IDENTITY.md");
assert!(prompt.contains("Name: ZeroClaw"), "missing IDENTITY content");
assert!(prompt.contains("### USER.md"), "missing USER.md");
assert!(prompt.contains("### AGENTS.md"), "missing AGENTS.md");
assert!(prompt.contains("### TOOLS.md"), "missing TOOLS.md");
assert!(prompt.contains("### HEARTBEAT.md"), "missing HEARTBEAT.md");
assert!(prompt.contains("### MEMORY.md"), "missing MEMORY.md");
assert!(prompt.contains("User likes Rust"), "missing MEMORY content");
}
#[test]
fn prompt_missing_file_markers() {
let tmp = TempDir::new().unwrap();
// Empty workspace — no files at all
let prompt = build_system_prompt(tmp.path(), "model", &[], &[]);
assert!(prompt.contains("[File not found: SOUL.md]"));
assert!(prompt.contains("[File not found: AGENTS.md]"));
assert!(prompt.contains("[File not found: IDENTITY.md]"));
}
#[test]
fn prompt_bootstrap_only_if_exists() {
let ws = make_workspace();
// No BOOTSTRAP.md — should not appear
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
assert!(!prompt.contains("### BOOTSTRAP.md"), "BOOTSTRAP.md should not appear when missing");
// Create BOOTSTRAP.md — should appear
std::fs::write(ws.path().join("BOOTSTRAP.md"), "# Bootstrap\nFirst run.").unwrap();
let prompt2 = build_system_prompt(ws.path(), "model", &[], &[]);
assert!(prompt2.contains("### BOOTSTRAP.md"), "BOOTSTRAP.md should appear when present");
assert!(prompt2.contains("First run"));
}
#[test]
fn prompt_no_daily_memory_injection() {
let ws = make_workspace();
let memory_dir = ws.path().join("memory");
std::fs::create_dir_all(&memory_dir).unwrap();
let today = chrono::Local::now().format("%Y-%m-%d").to_string();
std::fs::write(memory_dir.join(format!("{today}.md")), "# Daily\nSome note.").unwrap();
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
// Daily notes should NOT be in the system prompt (on-demand via tools)
assert!(!prompt.contains("Daily Notes"), "daily notes should not be auto-injected");
assert!(!prompt.contains("Some note"), "daily content should not be in prompt");
}
#[test]
fn prompt_runtime_metadata() {
let ws = make_workspace();
let prompt = build_system_prompt(ws.path(), "claude-sonnet-4", &[], &[]);
assert!(prompt.contains("Model: claude-sonnet-4"));
assert!(prompt.contains(&format!("OS: {}", std::env::consts::OS)));
assert!(prompt.contains("Host:"));
}
#[test]
fn prompt_skills_compact_list() {
let ws = make_workspace();
let skills = vec![crate::skills::Skill {
name: "code-review".into(),
description: "Review code for bugs".into(),
version: "1.0.0".into(),
author: None,
tags: vec![],
tools: vec![],
prompts: vec!["Long prompt content that should NOT appear in system prompt".into()],
}];
let prompt = build_system_prompt(ws.path(), "model", &[], &skills);
assert!(prompt.contains("<available_skills>"), "missing skills XML");
assert!(prompt.contains("<name>code-review</name>"));
assert!(prompt.contains("<description>Review code for bugs</description>"));
assert!(prompt.contains("SKILL.md</location>"));
assert!(prompt.contains("loaded on demand"), "should mention on-demand loading");
// Full prompt content should NOT be dumped
assert!(!prompt.contains("Long prompt content that should NOT appear"));
}
#[test]
fn prompt_truncation() {
let ws = make_workspace();
// Write a file larger than BOOTSTRAP_MAX_CHARS
let big_content = "x".repeat(BOOTSTRAP_MAX_CHARS + 1000);
std::fs::write(ws.path().join("AGENTS.md"), &big_content).unwrap();
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
assert!(prompt.contains("truncated at"), "large files should be truncated");
assert!(!prompt.contains(&big_content), "full content should not appear");
}
#[test]
fn prompt_empty_files_skipped() {
let ws = make_workspace();
std::fs::write(ws.path().join("TOOLS.md"), "").unwrap();
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
// Empty file should not produce a header
assert!(!prompt.contains("### TOOLS.md"), "empty files should be skipped");
}
#[test]
fn prompt_workspace_path() {
let ws = make_workspace();
let prompt = build_system_prompt(ws.path(), "model", &[], &[]);
assert!(prompt.contains(&format!("Working directory: `{}`", ws.path().display())));
}
}

174
src/channels/slack.rs Normal file
View file

@ -0,0 +1,174 @@
use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use uuid::Uuid;
/// Slack channel — polls conversations.history via Web API
pub struct SlackChannel {
bot_token: String,
channel_id: Option<String>,
client: reqwest::Client,
}
impl SlackChannel {
pub fn new(bot_token: String, channel_id: Option<String>) -> Self {
Self {
bot_token,
channel_id,
client: reqwest::Client::new(),
}
}
/// Get the bot's own user ID so we can ignore our own messages
async fn get_bot_user_id(&self) -> Option<String> {
let resp: serde_json::Value = self
.client
.get("https://slack.com/api/auth.test")
.bearer_auth(&self.bot_token)
.send()
.await
.ok()?
.json()
.await
.ok()?;
resp.get("user_id")
.and_then(|u| u.as_str())
.map(String::from)
}
}
#[async_trait]
impl Channel for SlackChannel {
fn name(&self) -> &str {
"slack"
}
async fn send(&self, message: &str, channel: &str) -> anyhow::Result<()> {
let body = serde_json::json!({
"channel": channel,
"text": message
});
self.client
.post("https://slack.com/api/chat.postMessage")
.bearer_auth(&self.bot_token)
.json(&body)
.send()
.await?;
Ok(())
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let channel_id = self
.channel_id
.clone()
.ok_or_else(|| anyhow::anyhow!("Slack channel_id required for listening"))?;
let bot_user_id = self.get_bot_user_id().await.unwrap_or_default();
let mut last_ts = String::new();
tracing::info!("Slack channel listening on #{channel_id}...");
loop {
tokio::time::sleep(std::time::Duration::from_secs(3)).await;
let mut params = vec![
("channel", channel_id.clone()),
("limit", "10".to_string()),
];
if !last_ts.is_empty() {
params.push(("oldest", last_ts.clone()));
}
let resp = match self
.client
.get("https://slack.com/api/conversations.history")
.bearer_auth(&self.bot_token)
.query(&params)
.send()
.await
{
Ok(r) => r,
Err(e) => {
tracing::warn!("Slack poll error: {e}");
continue;
}
};
let data: serde_json::Value = match resp.json().await {
Ok(d) => d,
Err(e) => {
tracing::warn!("Slack parse error: {e}");
continue;
}
};
if let Some(messages) = data.get("messages").and_then(|m| m.as_array()) {
// Messages come newest-first, reverse to process oldest first
for msg in messages.iter().rev() {
let ts = msg.get("ts").and_then(|t| t.as_str()).unwrap_or("");
let user = msg
.get("user")
.and_then(|u| u.as_str())
.unwrap_or("unknown");
let text = msg.get("text").and_then(|t| t.as_str()).unwrap_or("");
// Skip bot's own messages
if user == bot_user_id {
continue;
}
// Skip empty or already-seen
if text.is_empty() || ts <= last_ts.as_str() {
continue;
}
last_ts = ts.to_string();
let channel_msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: channel_id.clone(),
content: text.to_string(),
channel: "slack".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(channel_msg).await.is_err() {
return Ok(());
}
}
}
}
}
async fn health_check(&self) -> bool {
self.client
.get("https://slack.com/api/auth.test")
.bearer_auth(&self.bot_token)
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn slack_channel_name() {
let ch = SlackChannel::new("xoxb-fake".into(), None);
assert_eq!(ch.name(), "slack");
}
#[test]
fn slack_channel_with_channel_id() {
let ch = SlackChannel::new("xoxb-fake".into(), Some("C12345".into()));
assert_eq!(ch.channel_id, Some("C12345".to_string()));
}
}

182
src/channels/telegram.rs Normal file
View file

@ -0,0 +1,182 @@
use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use uuid::Uuid;
/// Telegram channel — long-polls the Bot API for updates
pub struct TelegramChannel {
bot_token: String,
allowed_users: Vec<String>,
client: reqwest::Client,
}
impl TelegramChannel {
pub fn new(bot_token: String, allowed_users: Vec<String>) -> Self {
Self {
bot_token,
allowed_users,
client: reqwest::Client::new(),
}
}
fn api_url(&self, method: &str) -> String {
format!("https://api.telegram.org/bot{}/{method}", self.bot_token)
}
fn is_user_allowed(&self, username: &str) -> bool {
self.allowed_users.iter().any(|u| u == "*" || u == username)
}
}
#[async_trait]
impl Channel for TelegramChannel {
fn name(&self) -> &str {
"telegram"
}
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
let body = serde_json::json!({
"chat_id": chat_id,
"text": message,
"parse_mode": "Markdown"
});
self.client
.post(self.api_url("sendMessage"))
.json(&body)
.send()
.await?;
Ok(())
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let mut offset: i64 = 0;
tracing::info!("Telegram channel listening for messages...");
loop {
let url = self.api_url("getUpdates");
let body = serde_json::json!({
"offset": offset,
"timeout": 30,
"allowed_updates": ["message"]
});
let resp = match self.client.post(&url).json(&body).send().await {
Ok(r) => r,
Err(e) => {
tracing::warn!("Telegram poll error: {e}");
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
};
let data: serde_json::Value = match resp.json().await {
Ok(d) => d,
Err(e) => {
tracing::warn!("Telegram parse error: {e}");
tokio::time::sleep(std::time::Duration::from_secs(5)).await;
continue;
}
};
if let Some(results) = data.get("result").and_then(serde_json::Value::as_array) {
for update in results {
// Advance offset past this update
if let Some(uid) = update.get("update_id").and_then(serde_json::Value::as_i64) {
offset = uid + 1;
}
let Some(message) = update.get("message") else {
continue;
};
let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
continue;
};
let username = message
.get("from")
.and_then(|f| f.get("username"))
.and_then(|u| u.as_str())
.unwrap_or("unknown");
if !self.is_user_allowed(username) {
tracing::warn!("Telegram: ignoring message from unauthorized user: {username}");
continue;
}
let chat_id = message
.get("chat")
.and_then(|c| c.get("id"))
.and_then(serde_json::Value::as_i64)
.map(|id| id.to_string())
.unwrap_or_default();
let msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: chat_id,
content: text.to_string(),
channel: "telegram".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(msg).await.is_err() {
return Ok(());
}
}
}
}
}
async fn health_check(&self) -> bool {
self.client
.get(self.api_url("getMe"))
.send()
.await
.map(|r| r.status().is_success())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn telegram_channel_name() {
let ch = TelegramChannel::new("fake-token".into(), vec!["*".into()]);
assert_eq!(ch.name(), "telegram");
}
#[test]
fn telegram_api_url() {
let ch = TelegramChannel::new("123:ABC".into(), vec![]);
assert_eq!(
ch.api_url("getMe"),
"https://api.telegram.org/bot123:ABC/getMe"
);
}
#[test]
fn telegram_user_allowed_wildcard() {
let ch = TelegramChannel::new("t".into(), vec!["*".into()]);
assert!(ch.is_user_allowed("anyone"));
}
#[test]
fn telegram_user_allowed_specific() {
let ch = TelegramChannel::new("t".into(), vec!["alice".into(), "bob".into()]);
assert!(ch.is_user_allowed("alice"));
assert!(!ch.is_user_allowed("eve"));
}
#[test]
fn telegram_user_denied_empty() {
let ch = TelegramChannel::new("t".into(), vec![]);
assert!(!ch.is_user_allowed("anyone"));
}
}

29
src/channels/traits.rs Normal file
View file

@ -0,0 +1,29 @@
use async_trait::async_trait;
/// A message received from or sent to a channel
#[derive(Debug, Clone)]
pub struct ChannelMessage {
pub id: String,
pub sender: String,
pub content: String,
pub channel: String,
pub timestamp: u64,
}
/// Core channel trait — implement for any messaging platform
#[async_trait]
pub trait Channel: Send + Sync {
/// Human-readable channel name
fn name(&self) -> &str;
/// Send a message through this channel
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()>;
/// Start listening for incoming messages (long-running)
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()>;
/// Check if channel is healthy
async fn health_check(&self) -> bool {
true
}
}

7
src/config/mod.rs Normal file
View file

@ -0,0 +1,7 @@
pub mod schema;
pub use schema::{
AutonomyConfig, ChannelsConfig, Config, DiscordConfig, HeartbeatConfig, IMessageConfig,
MatrixConfig, MemoryConfig, ObservabilityConfig, RuntimeConfig, SlackConfig, TelegramConfig,
WebhookConfig,
};

580
src/config/schema.rs Normal file
View file

@ -0,0 +1,580 @@
use crate::security::AutonomyLevel;
use anyhow::{Context, Result};
use directories::UserDirs;
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::PathBuf;
// ── Top-level config ──────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
pub workspace_dir: PathBuf,
pub config_path: PathBuf,
pub api_key: Option<String>,
pub default_provider: Option<String>,
pub default_model: Option<String>,
pub default_temperature: f64,
#[serde(default)]
pub observability: ObservabilityConfig,
#[serde(default)]
pub autonomy: AutonomyConfig,
#[serde(default)]
pub runtime: RuntimeConfig,
#[serde(default)]
pub heartbeat: HeartbeatConfig,
#[serde(default)]
pub channels_config: ChannelsConfig,
#[serde(default)]
pub memory: MemoryConfig,
}
// ── Memory ───────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
/// "sqlite" | "markdown" | "none"
pub backend: String,
/// Auto-save conversation context to memory
pub auto_save: bool,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
backend: "sqlite".into(),
auto_save: true,
}
}
}
// ── Observability ─────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ObservabilityConfig {
/// "none" | "log" | "prometheus" | "otel"
pub backend: String,
}
impl Default for ObservabilityConfig {
fn default() -> Self {
Self {
backend: "none".into(),
}
}
}
// ── Autonomy / Security ──────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutonomyConfig {
pub level: AutonomyLevel,
pub workspace_only: bool,
pub allowed_commands: Vec<String>,
pub forbidden_paths: Vec<String>,
pub max_actions_per_hour: u32,
pub max_cost_per_day_cents: u32,
}
impl Default for AutonomyConfig {
fn default() -> Self {
Self {
level: AutonomyLevel::Supervised,
workspace_only: true,
allowed_commands: vec![
"git".into(),
"npm".into(),
"cargo".into(),
"ls".into(),
"cat".into(),
"grep".into(),
"find".into(),
"echo".into(),
"pwd".into(),
"wc".into(),
"head".into(),
"tail".into(),
],
forbidden_paths: vec![
"/etc".into(),
"/root".into(),
"~/.ssh".into(),
"~/.gnupg".into(),
],
max_actions_per_hour: 20,
max_cost_per_day_cents: 500,
}
}
}
// ── Runtime ──────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig {
/// "native" | "docker" | "cloudflare"
pub kind: String,
}
impl Default for RuntimeConfig {
fn default() -> Self {
Self {
kind: "native".into(),
}
}
}
// ── Heartbeat ────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeartbeatConfig {
pub enabled: bool,
pub interval_minutes: u32,
}
impl Default for HeartbeatConfig {
fn default() -> Self {
Self {
enabled: false,
interval_minutes: 30,
}
}
}
// ── Channels ─────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChannelsConfig {
pub cli: bool,
pub telegram: Option<TelegramConfig>,
pub discord: Option<DiscordConfig>,
pub slack: Option<SlackConfig>,
pub webhook: Option<WebhookConfig>,
pub imessage: Option<IMessageConfig>,
pub matrix: Option<MatrixConfig>,
}
impl Default for ChannelsConfig {
fn default() -> Self {
Self {
cli: true,
telegram: None,
discord: None,
slack: None,
webhook: None,
imessage: None,
matrix: None,
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TelegramConfig {
pub bot_token: String,
pub allowed_users: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DiscordConfig {
pub bot_token: String,
pub guild_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SlackConfig {
pub bot_token: String,
pub app_token: Option<String>,
pub channel_id: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebhookConfig {
pub port: u16,
pub secret: Option<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IMessageConfig {
pub allowed_contacts: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MatrixConfig {
pub homeserver: String,
pub access_token: String,
pub room_id: String,
pub allowed_users: Vec<String>,
}
// ── Config impl ──────────────────────────────────────────────────
impl Default for Config {
fn default() -> Self {
let home =
UserDirs::new().map_or_else(|| PathBuf::from("."), |u| u.home_dir().to_path_buf());
let zeroclaw_dir = home.join(".zeroclaw");
Self {
workspace_dir: zeroclaw_dir.join("workspace"),
config_path: zeroclaw_dir.join("config.toml"),
api_key: None,
default_provider: Some("openrouter".to_string()),
default_model: Some("anthropic/claude-sonnet-4-20250514".to_string()),
default_temperature: 0.7,
observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(),
runtime: RuntimeConfig::default(),
heartbeat: HeartbeatConfig::default(),
channels_config: ChannelsConfig::default(),
memory: MemoryConfig::default(),
}
}
}
impl Config {
pub fn load_or_init() -> Result<Self> {
let home = UserDirs::new()
.map(|u| u.home_dir().to_path_buf())
.context("Could not find home directory")?;
let zeroclaw_dir = home.join(".zeroclaw");
let config_path = zeroclaw_dir.join("config.toml");
if !zeroclaw_dir.exists() {
fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?;
fs::create_dir_all(zeroclaw_dir.join("workspace"))
.context("Failed to create workspace directory")?;
}
if config_path.exists() {
let contents =
fs::read_to_string(&config_path).context("Failed to read config file")?;
let config: Config =
toml::from_str(&contents).context("Failed to parse config file")?;
Ok(config)
} else {
let config = Config::default();
config.save()?;
Ok(config)
}
}
pub fn save(&self) -> Result<()> {
let toml_str = toml::to_string_pretty(self).context("Failed to serialize config")?;
fs::write(&self.config_path, toml_str).context("Failed to write config file")?;
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::path::PathBuf;
// ── Defaults ─────────────────────────────────────────────
#[test]
fn config_default_has_sane_values() {
let c = Config::default();
assert_eq!(c.default_provider.as_deref(), Some("openrouter"));
assert!(c.default_model.as_deref().unwrap().contains("claude"));
assert!((c.default_temperature - 0.7).abs() < f64::EPSILON);
assert!(c.api_key.is_none());
assert!(c.workspace_dir.to_string_lossy().contains("workspace"));
assert!(c.config_path.to_string_lossy().contains("config.toml"));
}
#[test]
fn observability_config_default() {
let o = ObservabilityConfig::default();
assert_eq!(o.backend, "none");
}
#[test]
fn autonomy_config_default() {
let a = AutonomyConfig::default();
assert_eq!(a.level, AutonomyLevel::Supervised);
assert!(a.workspace_only);
assert!(a.allowed_commands.contains(&"git".to_string()));
assert!(a.allowed_commands.contains(&"cargo".to_string()));
assert!(a.forbidden_paths.contains(&"/etc".to_string()));
assert_eq!(a.max_actions_per_hour, 20);
assert_eq!(a.max_cost_per_day_cents, 500);
}
#[test]
fn runtime_config_default() {
let r = RuntimeConfig::default();
assert_eq!(r.kind, "native");
}
#[test]
fn heartbeat_config_default() {
let h = HeartbeatConfig::default();
assert!(!h.enabled);
assert_eq!(h.interval_minutes, 30);
}
#[test]
fn channels_config_default() {
let c = ChannelsConfig::default();
assert!(c.cli);
assert!(c.telegram.is_none());
assert!(c.discord.is_none());
}
// ── Serde round-trip ─────────────────────────────────────
#[test]
fn config_toml_roundtrip() {
let config = Config {
workspace_dir: PathBuf::from("/tmp/test/workspace"),
config_path: PathBuf::from("/tmp/test/config.toml"),
api_key: Some("sk-test-key".into()),
default_provider: Some("openrouter".into()),
default_model: Some("gpt-4o".into()),
default_temperature: 0.5,
observability: ObservabilityConfig {
backend: "log".into(),
},
autonomy: AutonomyConfig {
level: AutonomyLevel::Full,
workspace_only: false,
allowed_commands: vec!["docker".into()],
forbidden_paths: vec!["/secret".into()],
max_actions_per_hour: 50,
max_cost_per_day_cents: 1000,
},
runtime: RuntimeConfig {
kind: "docker".into(),
},
heartbeat: HeartbeatConfig {
enabled: true,
interval_minutes: 15,
},
channels_config: ChannelsConfig {
cli: true,
telegram: Some(TelegramConfig {
bot_token: "123:ABC".into(),
allowed_users: vec!["user1".into()],
}),
discord: None,
slack: None,
webhook: None,
imessage: None,
matrix: None,
},
memory: MemoryConfig::default(),
};
let toml_str = toml::to_string_pretty(&config).unwrap();
let parsed: Config = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed.api_key, config.api_key);
assert_eq!(parsed.default_provider, config.default_provider);
assert_eq!(parsed.default_model, config.default_model);
assert!((parsed.default_temperature - config.default_temperature).abs() < f64::EPSILON);
assert_eq!(parsed.observability.backend, "log");
assert_eq!(parsed.autonomy.level, AutonomyLevel::Full);
assert!(!parsed.autonomy.workspace_only);
assert_eq!(parsed.runtime.kind, "docker");
assert!(parsed.heartbeat.enabled);
assert_eq!(parsed.heartbeat.interval_minutes, 15);
assert!(parsed.channels_config.telegram.is_some());
assert_eq!(
parsed.channels_config.telegram.unwrap().bot_token,
"123:ABC"
);
}
#[test]
fn config_minimal_toml_uses_defaults() {
let minimal = r#"
workspace_dir = "/tmp/ws"
config_path = "/tmp/config.toml"
default_temperature = 0.7
"#;
let parsed: Config = toml::from_str(minimal).unwrap();
assert!(parsed.api_key.is_none());
assert!(parsed.default_provider.is_none());
assert_eq!(parsed.observability.backend, "none");
assert_eq!(parsed.autonomy.level, AutonomyLevel::Supervised);
assert_eq!(parsed.runtime.kind, "native");
assert!(!parsed.heartbeat.enabled);
assert!(parsed.channels_config.cli);
}
#[test]
fn config_save_and_load_tmpdir() {
let dir = std::env::temp_dir().join("zeroclaw_test_config");
let _ = fs::remove_dir_all(&dir);
fs::create_dir_all(&dir).unwrap();
let config_path = dir.join("config.toml");
let config = Config {
workspace_dir: dir.join("workspace"),
config_path: config_path.clone(),
api_key: Some("sk-roundtrip".into()),
default_provider: Some("openrouter".into()),
default_model: Some("test-model".into()),
default_temperature: 0.9,
observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(),
runtime: RuntimeConfig::default(),
heartbeat: HeartbeatConfig::default(),
channels_config: ChannelsConfig::default(),
memory: MemoryConfig::default(),
};
config.save().unwrap();
assert!(config_path.exists());
let contents = fs::read_to_string(&config_path).unwrap();
let loaded: Config = toml::from_str(&contents).unwrap();
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
let _ = fs::remove_dir_all(&dir);
}
// ── Telegram / Discord config ────────────────────────────
#[test]
fn telegram_config_serde() {
let tc = TelegramConfig {
bot_token: "123:XYZ".into(),
allowed_users: vec!["alice".into(), "bob".into()],
};
let json = serde_json::to_string(&tc).unwrap();
let parsed: TelegramConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.bot_token, "123:XYZ");
assert_eq!(parsed.allowed_users.len(), 2);
}
#[test]
fn discord_config_serde() {
let dc = DiscordConfig {
bot_token: "discord-token".into(),
guild_id: Some("12345".into()),
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.bot_token, "discord-token");
assert_eq!(parsed.guild_id.as_deref(), Some("12345"));
}
#[test]
fn discord_config_optional_guild() {
let dc = DiscordConfig {
bot_token: "tok".into(),
guild_id: None,
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
assert!(parsed.guild_id.is_none());
}
// ── iMessage / Matrix config ────────────────────────────
#[test]
fn imessage_config_serde() {
let ic = IMessageConfig {
allowed_contacts: vec!["+1234567890".into(), "user@icloud.com".into()],
};
let json = serde_json::to_string(&ic).unwrap();
let parsed: IMessageConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.allowed_contacts.len(), 2);
assert_eq!(parsed.allowed_contacts[0], "+1234567890");
}
#[test]
fn imessage_config_empty_contacts() {
let ic = IMessageConfig {
allowed_contacts: vec![],
};
let json = serde_json::to_string(&ic).unwrap();
let parsed: IMessageConfig = serde_json::from_str(&json).unwrap();
assert!(parsed.allowed_contacts.is_empty());
}
#[test]
fn imessage_config_wildcard() {
let ic = IMessageConfig {
allowed_contacts: vec!["*".into()],
};
let toml_str = toml::to_string(&ic).unwrap();
let parsed: IMessageConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed.allowed_contacts, vec!["*"]);
}
#[test]
fn matrix_config_serde() {
let mc = MatrixConfig {
homeserver: "https://matrix.org".into(),
access_token: "syt_token_abc".into(),
room_id: "!room123:matrix.org".into(),
allowed_users: vec!["@user:matrix.org".into()],
};
let json = serde_json::to_string(&mc).unwrap();
let parsed: MatrixConfig = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.homeserver, "https://matrix.org");
assert_eq!(parsed.access_token, "syt_token_abc");
assert_eq!(parsed.room_id, "!room123:matrix.org");
assert_eq!(parsed.allowed_users.len(), 1);
}
#[test]
fn matrix_config_toml_roundtrip() {
let mc = MatrixConfig {
homeserver: "https://synapse.local:8448".into(),
access_token: "tok".into(),
room_id: "!abc:synapse.local".into(),
allowed_users: vec!["@admin:synapse.local".into(), "*".into()],
};
let toml_str = toml::to_string(&mc).unwrap();
let parsed: MatrixConfig = toml::from_str(&toml_str).unwrap();
assert_eq!(parsed.homeserver, "https://synapse.local:8448");
assert_eq!(parsed.allowed_users.len(), 2);
}
#[test]
fn channels_config_with_imessage_and_matrix() {
let c = ChannelsConfig {
cli: true,
telegram: None,
discord: None,
slack: None,
webhook: None,
imessage: Some(IMessageConfig {
allowed_contacts: vec!["+1".into()],
}),
matrix: Some(MatrixConfig {
homeserver: "https://m.org".into(),
access_token: "tok".into(),
room_id: "!r:m".into(),
allowed_users: vec!["@u:m".into()],
}),
};
let toml_str = toml::to_string_pretty(&c).unwrap();
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
assert!(parsed.imessage.is_some());
assert!(parsed.matrix.is_some());
assert_eq!(
parsed.imessage.unwrap().allowed_contacts,
vec!["+1"]
);
assert_eq!(parsed.matrix.unwrap().homeserver, "https://m.org");
}
#[test]
fn channels_config_default_has_no_imessage_matrix() {
let c = ChannelsConfig::default();
assert!(c.imessage.is_none());
assert!(c.matrix.is_none());
}
}

25
src/cron/mod.rs Normal file
View file

@ -0,0 +1,25 @@
use crate::config::Config;
use anyhow::Result;
pub fn handle_command(command: super::CronCommands, _config: Config) -> Result<()> {
match command {
super::CronCommands::List => {
println!("No scheduled tasks yet.");
println!("\nUsage:");
println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'");
Ok(())
}
super::CronCommands::Add {
expression,
command,
} => {
println!("Cron scheduling coming soon!");
println!(" Expression: {expression}");
println!(" Command: {command}");
Ok(())
}
super::CronCommands::Remove { id } => {
anyhow::bail!("Remove task '{id}' not yet implemented");
}
}
}

180
src/gateway/mod.rs Normal file
View file

@ -0,0 +1,180 @@
use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory};
use crate::providers::{self, Provider};
use anyhow::Result;
use std::sync::Arc;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
/// Run a minimal HTTP gateway (webhook + health check)
/// Zero new dependencies — uses raw TCP + tokio.
#[allow(clippy::too_many_lines)]
pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let addr = format!("{host}:{port}");
let listener = TcpListener::bind(&addr).await?;
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
)?);
let model = config
.default_model
.clone()
.unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into());
let temperature = config.default_temperature;
let mem: Arc<dyn Memory> =
Arc::from(memory::create_memory(&config.memory, &config.workspace_dir)?);
println!("🦀 ZeroClaw Gateway listening on http://{addr}");
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
println!(" GET /health — health check");
println!(" Press Ctrl+C to stop.\n");
loop {
let (mut stream, peer) = listener.accept().await?;
let provider = provider.clone();
let model = model.clone();
let mem = mem.clone();
let auto_save = config.memory.auto_save;
tokio::spawn(async move {
let mut buf = vec![0u8; 8192];
let n = match stream.read(&mut buf).await {
Ok(n) if n > 0 => n,
_ => return,
};
let request = String::from_utf8_lossy(&buf[..n]);
let first_line = request.lines().next().unwrap_or("");
let parts: Vec<&str> = first_line.split_whitespace().collect();
if let [method, path, ..] = parts.as_slice() {
tracing::info!("{peer} → {method} {path}");
handle_request(&mut stream, method, path, &request, &provider, &model, temperature, &mem, auto_save).await;
} else {
let _ = send_response(&mut stream, 400, "Bad Request").await;
}
});
}
}
#[allow(clippy::too_many_arguments)]
async fn handle_request(
stream: &mut tokio::net::TcpStream,
method: &str,
path: &str,
request: &str,
provider: &Arc<dyn Provider>,
model: &str,
temperature: f64,
mem: &Arc<dyn Memory>,
auto_save: bool,
) {
match (method, path) {
("GET", "/health") => {
let body = serde_json::json!({
"status": "ok",
"version": env!("CARGO_PKG_VERSION"),
"memory": mem.name(),
"memory_healthy": mem.health_check().await,
});
let _ = send_json(stream, 200, &body).await;
}
("POST", "/webhook") => {
handle_webhook(stream, request, provider, model, temperature, mem, auto_save).await;
}
_ => {
let body = serde_json::json!({
"error": "Not found",
"routes": ["GET /health", "POST /webhook"]
});
let _ = send_json(stream, 404, &body).await;
}
}
}
async fn handle_webhook(
stream: &mut tokio::net::TcpStream,
request: &str,
provider: &Arc<dyn Provider>,
model: &str,
temperature: f64,
mem: &Arc<dyn Memory>,
auto_save: bool,
) {
let body_str = request
.split("\r\n\r\n")
.nth(1)
.or_else(|| request.split("\n\n").nth(1))
.unwrap_or("");
let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body_str) else {
let err = serde_json::json!({"error": "Invalid JSON. Expected: {\"message\": \"...\"}"});
let _ = send_json(stream, 400, &err).await;
return;
};
let Some(message) = parsed.get("message").and_then(|v| v.as_str()) else {
let err = serde_json::json!({"error": "Missing 'message' field in JSON"});
let _ = send_json(stream, 400, &err).await;
return;
};
if auto_save {
let _ = mem
.store("webhook_msg", message, MemoryCategory::Conversation)
.await;
}
match provider.chat(message, model, temperature).await {
Ok(response) => {
let body = serde_json::json!({"response": response, "model": model});
let _ = send_json(stream, 200, &body).await;
}
Err(e) => {
let err = serde_json::json!({"error": format!("LLM error: {e}")});
let _ = send_json(stream, 500, &err).await;
}
}
}
async fn send_response(
stream: &mut tokio::net::TcpStream,
status: u16,
body: &str,
) -> std::io::Result<()> {
let reason = match status {
200 => "OK",
400 => "Bad Request",
404 => "Not Found",
500 => "Internal Server Error",
_ => "Unknown",
};
let response = format!(
"HTTP/1.1 {status} {reason}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
body.len()
);
stream.write_all(response.as_bytes()).await
}
async fn send_json(
stream: &mut tokio::net::TcpStream,
status: u16,
body: &serde_json::Value,
) -> std::io::Result<()> {
let reason = match status {
200 => "OK",
400 => "Bad Request",
404 => "Not Found",
500 => "Internal Server Error",
_ => "Unknown",
};
let json = serde_json::to_string(body).unwrap_or_default();
let response = format!(
"HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{json}",
json.len()
);
stream.write_all(response.as_bytes()).await
}

296
src/heartbeat/engine.rs Normal file
View file

@ -0,0 +1,296 @@
use crate::config::HeartbeatConfig;
use crate::observability::{Observer, ObserverEvent};
use anyhow::Result;
use std::path::Path;
use std::sync::Arc;
use tokio::time::{self, Duration};
use tracing::{info, warn};
/// Heartbeat engine — reads HEARTBEAT.md and executes tasks periodically
pub struct HeartbeatEngine {
config: HeartbeatConfig,
workspace_dir: std::path::PathBuf,
observer: Arc<dyn Observer>,
}
impl HeartbeatEngine {
pub fn new(
config: HeartbeatConfig,
workspace_dir: std::path::PathBuf,
observer: Arc<dyn Observer>,
) -> Self {
Self {
config,
workspace_dir,
observer,
}
}
/// Start the heartbeat loop (runs until cancelled)
pub async fn run(&self) -> Result<()> {
if !self.config.enabled {
info!("Heartbeat disabled");
return Ok(());
}
let interval_mins = self.config.interval_minutes.max(5);
info!("💓 Heartbeat started: every {} minutes", interval_mins);
let mut interval = time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
loop {
interval.tick().await;
self.observer.record_event(&ObserverEvent::HeartbeatTick);
match self.tick().await {
Ok(tasks) => {
if tasks > 0 {
info!("💓 Heartbeat: processed {} tasks", tasks);
}
}
Err(e) => {
warn!("💓 Heartbeat error: {}", e);
self.observer.record_event(&ObserverEvent::Error {
component: "heartbeat".into(),
message: e.to_string(),
});
}
}
}
}
/// Single heartbeat tick — read HEARTBEAT.md and return task count
async fn tick(&self) -> Result<usize> {
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md");
if !heartbeat_path.exists() {
return Ok(0);
}
let content = tokio::fs::read_to_string(&heartbeat_path).await?;
let tasks = Self::parse_tasks(&content);
Ok(tasks.len())
}
/// Parse tasks from HEARTBEAT.md (lines starting with `- `)
fn parse_tasks(content: &str) -> Vec<String> {
content
.lines()
.filter_map(|line| {
let trimmed = line.trim();
trimmed.strip_prefix("- ").map(ToString::to_string)
})
.collect()
}
/// Create a default HEARTBEAT.md if it doesn't exist
pub async fn ensure_heartbeat_file(workspace_dir: &Path) -> Result<()> {
let path = workspace_dir.join("HEARTBEAT.md");
if !path.exists() {
let default = "# Periodic Tasks\n\n\
# Add tasks below (one per line, starting with `- `)\n\
# The agent will check this file on each heartbeat tick.\n\
#\n\
# Examples:\n\
# - Check my email for important messages\n\
# - Review my calendar for upcoming events\n\
# - Check the weather forecast\n";
tokio::fs::write(&path, default).await?;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn parse_tasks_basic() {
let content = "# Tasks\n\n- Check email\n- Review calendar\nNot a task\n- Third task";
let tasks = HeartbeatEngine::parse_tasks(content);
assert_eq!(tasks.len(), 3);
assert_eq!(tasks[0], "Check email");
assert_eq!(tasks[1], "Review calendar");
assert_eq!(tasks[2], "Third task");
}
#[test]
fn parse_tasks_empty_content() {
assert!(HeartbeatEngine::parse_tasks("").is_empty());
}
#[test]
fn parse_tasks_only_comments() {
let tasks = HeartbeatEngine::parse_tasks("# No tasks here\n\nJust comments\n# Another");
assert!(tasks.is_empty());
}
#[test]
fn parse_tasks_with_leading_whitespace() {
let content = " - Indented task\n\t- Tab indented";
let tasks = HeartbeatEngine::parse_tasks(content);
assert_eq!(tasks.len(), 2);
assert_eq!(tasks[0], "Indented task");
assert_eq!(tasks[1], "Tab indented");
}
#[test]
fn parse_tasks_dash_without_space_ignored() {
let content = "- Real task\n-\n- Another";
let tasks = HeartbeatEngine::parse_tasks(content);
// "-" trimmed = "-", does NOT start with "- " => skipped
// "- Real task" => "Real task"
// "- Another" => "Another"
assert_eq!(tasks.len(), 2);
assert_eq!(tasks[0], "Real task");
assert_eq!(tasks[1], "Another");
}
#[test]
fn parse_tasks_trailing_space_bullet_trimmed_to_dash() {
// "- " trimmed becomes "-" (trim removes trailing space)
// "-" does NOT start with "- " => skipped
let content = "- ";
let tasks = HeartbeatEngine::parse_tasks(content);
assert_eq!(tasks.len(), 0);
}
#[test]
fn parse_tasks_bullet_with_content_after_spaces() {
// "- hello " trimmed becomes "- hello" => starts_with "- " => "hello"
let content = "- hello ";
let tasks = HeartbeatEngine::parse_tasks(content);
assert_eq!(tasks.len(), 1);
assert_eq!(tasks[0], "hello");
}
#[test]
fn parse_tasks_unicode() {
let content = "- Check email 📧\n- Review calendar 📅\n- 日本語タスク";
let tasks = HeartbeatEngine::parse_tasks(content);
assert_eq!(tasks.len(), 3);
assert!(tasks[0].contains("📧"));
assert!(tasks[2].contains("日本語"));
}
#[test]
fn parse_tasks_mixed_markdown() {
let content = "# Periodic Tasks\n\n## Quick\n- Task A\n\n## Long\n- Task B\n\n* Not a dash bullet\n1. Not numbered";
let tasks = HeartbeatEngine::parse_tasks(content);
assert_eq!(tasks.len(), 2);
assert_eq!(tasks[0], "Task A");
assert_eq!(tasks[1], "Task B");
}
#[test]
fn parse_tasks_single_task() {
let tasks = HeartbeatEngine::parse_tasks("- Only one");
assert_eq!(tasks.len(), 1);
assert_eq!(tasks[0], "Only one");
}
#[test]
fn parse_tasks_many_tasks() {
let content: String = (0..100).map(|i| format!("- Task {i}\n")).collect();
let tasks = HeartbeatEngine::parse_tasks(&content);
assert_eq!(tasks.len(), 100);
assert_eq!(tasks[99], "Task 99");
}
#[tokio::test]
async fn ensure_heartbeat_file_creates_file() {
let dir = std::env::temp_dir().join("zeroclaw_test_heartbeat");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
HeartbeatEngine::ensure_heartbeat_file(&dir).await.unwrap();
let path = dir.join("HEARTBEAT.md");
assert!(path.exists());
let content = tokio::fs::read_to_string(&path).await.unwrap();
assert!(content.contains("Periodic Tasks"));
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn ensure_heartbeat_file_does_not_overwrite() {
let dir = std::env::temp_dir().join("zeroclaw_test_heartbeat_no_overwrite");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let path = dir.join("HEARTBEAT.md");
tokio::fs::write(&path, "- My custom task").await.unwrap();
HeartbeatEngine::ensure_heartbeat_file(&dir).await.unwrap();
let content = tokio::fs::read_to_string(&path).await.unwrap();
assert_eq!(content, "- My custom task");
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn tick_returns_zero_when_no_file() {
let dir = std::env::temp_dir().join("zeroclaw_test_tick_no_file");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
let engine = HeartbeatEngine::new(
HeartbeatConfig {
enabled: true,
interval_minutes: 30,
},
dir.clone(),
observer,
);
let count = engine.tick().await.unwrap();
assert_eq!(count, 0);
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn tick_counts_tasks_from_file() {
let dir = std::env::temp_dir().join("zeroclaw_test_tick_count");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
tokio::fs::write(dir.join("HEARTBEAT.md"), "- A\n- B\n- C")
.await
.unwrap();
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
let engine = HeartbeatEngine::new(
HeartbeatConfig {
enabled: true,
interval_minutes: 30,
},
dir.clone(),
observer,
);
let count = engine.tick().await.unwrap();
assert_eq!(count, 3);
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn run_returns_immediately_when_disabled() {
let observer: Arc<dyn Observer> = Arc::new(crate::observability::NoopObserver);
let engine = HeartbeatEngine::new(
HeartbeatConfig {
enabled: false,
interval_minutes: 30,
},
std::env::temp_dir(),
observer,
);
// Should return Ok immediately, not loop forever
let result = engine.run().await;
assert!(result.is_ok());
}
}

1
src/heartbeat/mod.rs Normal file
View file

@ -0,0 +1 @@
pub mod engine;

234
src/integrations/mod.rs Normal file
View file

@ -0,0 +1,234 @@
pub mod registry;
use crate::config::Config;
use anyhow::Result;
/// Integration status
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntegrationStatus {
/// Fully implemented and ready to use
Available,
/// Configured and active
Active,
/// Planned but not yet implemented
ComingSoon,
}
/// Integration category
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IntegrationCategory {
Chat,
AiModel,
Productivity,
MusicAudio,
SmartHome,
ToolsAutomation,
MediaCreative,
Social,
Platform,
}
impl IntegrationCategory {
pub fn label(self) -> &'static str {
match self {
Self::Chat => "Chat Providers",
Self::AiModel => "AI Models",
Self::Productivity => "Productivity",
Self::MusicAudio => "Music & Audio",
Self::SmartHome => "Smart Home",
Self::ToolsAutomation => "Tools & Automation",
Self::MediaCreative => "Media & Creative",
Self::Social => "Social",
Self::Platform => "Platforms",
}
}
pub fn all() -> &'static [Self] {
&[
Self::Chat,
Self::AiModel,
Self::Productivity,
Self::MusicAudio,
Self::SmartHome,
Self::ToolsAutomation,
Self::MediaCreative,
Self::Social,
Self::Platform,
]
}
}
/// A registered integration
pub struct IntegrationEntry {
pub name: &'static str,
pub description: &'static str,
pub category: IntegrationCategory,
pub status_fn: fn(&Config) -> IntegrationStatus,
}
/// Handle the `integrations` CLI command
pub fn handle_command(command: super::IntegrationCommands, config: &Config) -> Result<()> {
match command {
super::IntegrationCommands::List { category } => {
list_integrations(config, category.as_deref())
}
super::IntegrationCommands::Info { name } => show_integration_info(config, &name),
}
}
#[allow(clippy::unnecessary_wraps)]
fn list_integrations(config: &Config, filter_category: Option<&str>) -> Result<()> {
let entries = registry::all_integrations();
let mut available = 0u32;
let mut active = 0u32;
let mut coming = 0u32;
for &cat in IntegrationCategory::all() {
// Filter by category if specified
if let Some(filter) = filter_category {
let filter_lower = filter.to_lowercase();
let cat_lower = cat.label().to_lowercase();
if !cat_lower.contains(&filter_lower) {
continue;
}
}
let cat_entries: Vec<&IntegrationEntry> =
entries.iter().filter(|e| e.category == cat).collect();
if cat_entries.is_empty() {
continue;
}
println!("\n{}", console::style(cat.label()).white().bold());
for entry in &cat_entries {
let status = (entry.status_fn)(config);
let (icon, label) = match status {
IntegrationStatus::Active => {
active += 1;
("", console::style("active").green())
}
IntegrationStatus::Available => {
available += 1;
("", console::style("available").dim())
}
IntegrationStatus::ComingSoon => {
coming += 1;
("🔜", console::style("coming soon").dim())
}
};
println!(
" {icon} {:<22} {:<30} {}",
console::style(entry.name).white().bold(),
entry.description,
label
);
}
}
let total = available + active + coming;
println!();
println!(" {total} integrations: {active} active, {available} available, {coming} coming soon");
println!();
println!(" Configure: zeroclaw onboard");
println!(" Details: zeroclaw integrations info <name>");
println!();
Ok(())
}
fn show_integration_info(config: &Config, name: &str) -> Result<()> {
let entries = registry::all_integrations();
let name_lower = name.to_lowercase();
let Some(entry) = entries.iter().find(|e| e.name.to_lowercase() == name_lower) else {
anyhow::bail!(
"Unknown integration: {name}. Run `zeroclaw integrations list` to see all."
);
};
let status = (entry.status_fn)(config);
let (icon, label) = match status {
IntegrationStatus::Active => ("", "Active"),
IntegrationStatus::Available => ("", "Available"),
IntegrationStatus::ComingSoon => ("🔜", "Coming Soon"),
};
println!();
println!(" {} {}{}", icon, console::style(entry.name).white().bold(), entry.description);
println!(" Category: {}", entry.category.label());
println!(" Status: {label}");
println!();
// Show setup hints based on integration
match entry.name {
"Telegram" => {
println!(" Setup:");
println!(" 1. Message @BotFather on Telegram");
println!(" 2. Create a bot and copy the token");
println!(" 3. Run: zeroclaw onboard");
println!(" 4. Start: zeroclaw channel start");
}
"Discord" => {
println!(" Setup:");
println!(" 1. Go to https://discord.com/developers/applications");
println!(" 2. Create app → Bot → Copy token");
println!(" 3. Enable MESSAGE CONTENT intent");
println!(" 4. Run: zeroclaw onboard");
}
"Slack" => {
println!(" Setup:");
println!(" 1. Go to https://api.slack.com/apps");
println!(" 2. Create app → Bot Token Scopes → Install");
println!(" 3. Run: zeroclaw onboard");
}
"OpenRouter" => {
println!(" Setup:");
println!(" 1. Get API key at https://openrouter.ai/keys");
println!(" 2. Run: zeroclaw onboard");
println!(" Access 200+ models with one key.");
}
"Ollama" => {
println!(" Setup:");
println!(" 1. Install: brew install ollama");
println!(" 2. Pull a model: ollama pull llama3");
println!(" 3. Set provider to 'ollama' in config.toml");
}
"iMessage" => {
println!(" Setup (macOS only):");
println!(" Uses AppleScript bridge to send/receive iMessages.");
println!(" Requires Full Disk Access in System Settings → Privacy.");
}
"GitHub" => {
println!(" Setup:");
println!(" 1. Create a personal access token at https://github.com/settings/tokens");
println!(" 2. Add to config: [integrations.github] token = \"ghp_...\"");
}
"Browser" => {
println!(" Built-in:");
println!(" ZeroClaw can control Chrome/Chromium for web tasks.");
println!(" Uses headless browser automation.");
}
"Cron" => {
println!(" Built-in:");
println!(" Schedule tasks in ~/.zeroclaw/workspace/cron/");
println!(" Run: zeroclaw cron list");
}
"Webhooks" => {
println!(" Built-in:");
println!(" HTTP endpoint for external triggers.");
println!(" Run: zeroclaw gateway");
}
_ => {
if status == IntegrationStatus::ComingSoon {
println!(" This integration is planned. Stay tuned!");
println!(" Track progress: https://github.com/theonlyhennygod/zeroclaw");
}
}
}
println!();
Ok(())
}

View file

@ -0,0 +1,821 @@
use super::{IntegrationCategory, IntegrationEntry, IntegrationStatus};
/// Returns the full catalog of integrations
#[allow(clippy::too_many_lines)]
pub fn all_integrations() -> Vec<IntegrationEntry> {
vec![
// ── Chat Providers ──────────────────────────────────────
IntegrationEntry {
name: "Telegram",
description: "Bot API — long-polling",
category: IntegrationCategory::Chat,
status_fn: |c| {
if c.channels_config.telegram.is_some() {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Discord",
description: "Servers, channels & DMs",
category: IntegrationCategory::Chat,
status_fn: |c| {
if c.channels_config.discord.is_some() {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Slack",
description: "Workspace apps via Web API",
category: IntegrationCategory::Chat,
status_fn: |c| {
if c.channels_config.slack.is_some() {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Webhooks",
description: "HTTP endpoint for triggers",
category: IntegrationCategory::Chat,
status_fn: |c| {
if c.channels_config.webhook.is_some() {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "WhatsApp",
description: "QR pairing via web bridge",
category: IntegrationCategory::Chat,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Signal",
description: "Privacy-focused via signal-cli",
category: IntegrationCategory::Chat,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "iMessage",
description: "macOS AppleScript bridge",
category: IntegrationCategory::Chat,
status_fn: |c| {
if c.channels_config.imessage.is_some() {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Microsoft Teams",
description: "Enterprise chat support",
category: IntegrationCategory::Chat,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Matrix",
description: "Matrix protocol (Element)",
category: IntegrationCategory::Chat,
status_fn: |c| {
if c.channels_config.matrix.is_some() {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Nostr",
description: "Decentralized DMs (NIP-04)",
category: IntegrationCategory::Chat,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "WebChat",
description: "Browser-based chat UI",
category: IntegrationCategory::Chat,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Nextcloud Talk",
description: "Self-hosted Nextcloud chat",
category: IntegrationCategory::Chat,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Zalo",
description: "Zalo Bot API",
category: IntegrationCategory::Chat,
status_fn: |_| IntegrationStatus::ComingSoon,
},
// ── AI Models ───────────────────────────────────────────
IntegrationEntry {
name: "OpenRouter",
description: "200+ models, 1 API key",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("openrouter") && c.api_key.is_some() {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Anthropic",
description: "Claude 3.5/4 Sonnet & Opus",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("anthropic") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "OpenAI",
description: "GPT-4o, GPT-5, o1",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("openai") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Google",
description: "Gemini 2.5 Pro/Flash",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_model.as_deref().is_some_and(|m| m.starts_with("google/")) {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "DeepSeek",
description: "DeepSeek V3 & R1",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_model.as_deref().is_some_and(|m| m.starts_with("deepseek/")) {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "xAI",
description: "Grok 3 & 4",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_model.as_deref().is_some_and(|m| m.starts_with("x-ai/")) {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Mistral",
description: "Mistral Large & Codestral",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_model.as_deref().is_some_and(|m| m.starts_with("mistral")) {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Ollama",
description: "Local models (Llama, etc.)",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("ollama") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Perplexity",
description: "Search-augmented AI",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("perplexity") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Hugging Face",
description: "Open-source models",
category: IntegrationCategory::AiModel,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "LM Studio",
description: "Local model server",
category: IntegrationCategory::AiModel,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Venice",
description: "Privacy-first inference (Llama, Opus)",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("venice") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Vercel AI",
description: "Vercel AI Gateway",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("vercel") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Cloudflare AI",
description: "Cloudflare AI Gateway",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("cloudflare") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Moonshot",
description: "Kimi & Kimi Coding",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("moonshot") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Synthetic",
description: "Synthetic AI models",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("synthetic") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "OpenCode Zen",
description: "Code-focused AI models",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("opencode") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Z.AI",
description: "Z.AI inference",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("zai") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "GLM",
description: "ChatGLM / Zhipu models",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("glm") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "MiniMax",
description: "MiniMax AI models",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("minimax") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Amazon Bedrock",
description: "AWS managed model access",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("bedrock") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Qianfan",
description: "Baidu AI models",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("qianfan") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Groq",
description: "Ultra-fast LPU inference",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("groq") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Together AI",
description: "Open-source model hosting",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("together") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Fireworks AI",
description: "Fast open-source inference",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("fireworks") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Cohere",
description: "Command R+ & embeddings",
category: IntegrationCategory::AiModel,
status_fn: |c| {
if c.default_provider.as_deref() == Some("cohere") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
// ── Productivity ────────────────────────────────────────
IntegrationEntry {
name: "GitHub",
description: "Code, issues, PRs",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Notion",
description: "Workspace & databases",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Apple Notes",
description: "Native macOS/iOS notes",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Apple Reminders",
description: "Task management",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Obsidian",
description: "Knowledge graph notes",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Things 3",
description: "GTD task manager",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Bear Notes",
description: "Markdown notes",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Trello",
description: "Kanban boards",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Linear",
description: "Issue tracking",
category: IntegrationCategory::Productivity,
status_fn: |_| IntegrationStatus::ComingSoon,
},
// ── Music & Audio ───────────────────────────────────────
IntegrationEntry {
name: "Spotify",
description: "Music playback control",
category: IntegrationCategory::MusicAudio,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Sonos",
description: "Multi-room audio",
category: IntegrationCategory::MusicAudio,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Shazam",
description: "Song recognition",
category: IntegrationCategory::MusicAudio,
status_fn: |_| IntegrationStatus::ComingSoon,
},
// ── Smart Home ──────────────────────────────────────────
IntegrationEntry {
name: "Home Assistant",
description: "Home automation hub",
category: IntegrationCategory::SmartHome,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Philips Hue",
description: "Smart lighting",
category: IntegrationCategory::SmartHome,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "8Sleep",
description: "Smart mattress",
category: IntegrationCategory::SmartHome,
status_fn: |_| IntegrationStatus::ComingSoon,
},
// ── Tools & Automation ──────────────────────────────────
IntegrationEntry {
name: "Browser",
description: "Chrome/Chromium control",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::Available,
},
IntegrationEntry {
name: "Shell",
description: "Terminal command execution",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::Active,
},
IntegrationEntry {
name: "File System",
description: "Read/write files",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::Active,
},
IntegrationEntry {
name: "Cron",
description: "Scheduled tasks",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::Available,
},
IntegrationEntry {
name: "Voice",
description: "Voice wake + talk mode",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Gmail",
description: "Email triggers & send",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "1Password",
description: "Secure credentials",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Weather",
description: "Forecasts & conditions",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Canvas",
description: "Visual workspace + A2UI",
category: IntegrationCategory::ToolsAutomation,
status_fn: |_| IntegrationStatus::ComingSoon,
},
// ── Media & Creative ────────────────────────────────────
IntegrationEntry {
name: "Image Gen",
description: "AI image generation",
category: IntegrationCategory::MediaCreative,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "GIF Search",
description: "Find the perfect GIF",
category: IntegrationCategory::MediaCreative,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Screen Capture",
description: "Screenshot & screen control",
category: IntegrationCategory::MediaCreative,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Camera",
description: "Photo/video capture",
category: IntegrationCategory::MediaCreative,
status_fn: |_| IntegrationStatus::ComingSoon,
},
// ── Social ──────────────────────────────────────────────
IntegrationEntry {
name: "Twitter/X",
description: "Tweet, reply, search",
category: IntegrationCategory::Social,
status_fn: |_| IntegrationStatus::ComingSoon,
},
IntegrationEntry {
name: "Email",
description: "Send & read emails",
category: IntegrationCategory::Social,
status_fn: |_| IntegrationStatus::ComingSoon,
},
// ── Platforms ───────────────────────────────────────────
IntegrationEntry {
name: "macOS",
description: "Native support + AppleScript",
category: IntegrationCategory::Platform,
status_fn: |_| {
if cfg!(target_os = "macos") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Linux",
description: "Native support",
category: IntegrationCategory::Platform,
status_fn: |_| {
if cfg!(target_os = "linux") {
IntegrationStatus::Active
} else {
IntegrationStatus::Available
}
},
},
IntegrationEntry {
name: "Windows",
description: "WSL2 recommended",
category: IntegrationCategory::Platform,
status_fn: |_| IntegrationStatus::Available,
},
IntegrationEntry {
name: "iOS",
description: "Chat via Telegram/Discord",
category: IntegrationCategory::Platform,
status_fn: |_| IntegrationStatus::Available,
},
IntegrationEntry {
name: "Android",
description: "Chat via Telegram/Discord",
category: IntegrationCategory::Platform,
status_fn: |_| IntegrationStatus::Available,
},
]
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::Config;
use crate::config::schema::{
ChannelsConfig, IMessageConfig, MatrixConfig, TelegramConfig,
};
#[test]
fn registry_has_entries() {
let entries = all_integrations();
assert!(entries.len() >= 50, "Expected 50+ integrations, got {}", entries.len());
}
#[test]
fn all_categories_represented() {
let entries = all_integrations();
for cat in IntegrationCategory::all() {
let count = entries.iter().filter(|e| e.category == *cat).count();
assert!(count > 0, "Category {:?} has no entries", cat);
}
}
#[test]
fn status_functions_dont_panic() {
let config = Config::default();
let entries = all_integrations();
for entry in &entries {
let _ = (entry.status_fn)(&config);
}
}
#[test]
fn no_duplicate_names() {
let entries = all_integrations();
let mut seen = std::collections::HashSet::new();
for entry in &entries {
assert!(
seen.insert(entry.name),
"Duplicate integration name: {}",
entry.name
);
}
}
#[test]
fn no_empty_names_or_descriptions() {
let entries = all_integrations();
for entry in &entries {
assert!(!entry.name.is_empty(), "Found integration with empty name");
assert!(
!entry.description.is_empty(),
"Integration '{}' has empty description",
entry.name
);
}
}
#[test]
fn telegram_active_when_configured() {
let mut config = Config::default();
config.channels_config.telegram = Some(TelegramConfig {
bot_token: "123:ABC".into(),
allowed_users: vec!["user".into()],
});
let entries = all_integrations();
let tg = entries.iter().find(|e| e.name == "Telegram").unwrap();
assert!(matches!((tg.status_fn)(&config), IntegrationStatus::Active));
}
#[test]
fn telegram_available_when_not_configured() {
let config = Config::default();
let entries = all_integrations();
let tg = entries.iter().find(|e| e.name == "Telegram").unwrap();
assert!(matches!((tg.status_fn)(&config), IntegrationStatus::Available));
}
#[test]
fn imessage_active_when_configured() {
let mut config = Config::default();
config.channels_config.imessage = Some(IMessageConfig {
allowed_contacts: vec!["*".into()],
});
let entries = all_integrations();
let im = entries.iter().find(|e| e.name == "iMessage").unwrap();
assert!(matches!((im.status_fn)(&config), IntegrationStatus::Active));
}
#[test]
fn imessage_available_when_not_configured() {
let config = Config::default();
let entries = all_integrations();
let im = entries.iter().find(|e| e.name == "iMessage").unwrap();
assert!(matches!((im.status_fn)(&config), IntegrationStatus::Available));
}
#[test]
fn matrix_active_when_configured() {
let mut config = Config::default();
config.channels_config.matrix = Some(MatrixConfig {
homeserver: "https://m.org".into(),
access_token: "tok".into(),
room_id: "!r:m".into(),
allowed_users: vec![],
});
let entries = all_integrations();
let mx = entries.iter().find(|e| e.name == "Matrix").unwrap();
assert!(matches!((mx.status_fn)(&config), IntegrationStatus::Active));
}
#[test]
fn matrix_available_when_not_configured() {
let config = Config::default();
let entries = all_integrations();
let mx = entries.iter().find(|e| e.name == "Matrix").unwrap();
assert!(matches!((mx.status_fn)(&config), IntegrationStatus::Available));
}
#[test]
fn coming_soon_integrations_stay_coming_soon() {
let config = Config::default();
let entries = all_integrations();
for name in ["WhatsApp", "Signal", "Nostr", "Spotify", "Home Assistant"] {
let entry = entries.iter().find(|e| e.name == name).unwrap();
assert!(
matches!((entry.status_fn)(&config), IntegrationStatus::ComingSoon),
"{name} should be ComingSoon"
);
}
}
#[test]
fn shell_and_filesystem_always_active() {
let config = Config::default();
let entries = all_integrations();
for name in ["Shell", "File System"] {
let entry = entries.iter().find(|e| e.name == name).unwrap();
assert!(
matches!((entry.status_fn)(&config), IntegrationStatus::Active),
"{name} should always be Active"
);
}
}
#[test]
fn macos_active_on_macos() {
let config = Config::default();
let entries = all_integrations();
let macos = entries.iter().find(|e| e.name == "macOS").unwrap();
let status = (macos.status_fn)(&config);
if cfg!(target_os = "macos") {
assert!(matches!(status, IntegrationStatus::Active));
} else {
assert!(matches!(status, IntegrationStatus::Available));
}
}
#[test]
fn category_counts_reasonable() {
let entries = all_integrations();
let chat_count = entries.iter().filter(|e| e.category == IntegrationCategory::Chat).count();
let ai_count = entries.iter().filter(|e| e.category == IntegrationCategory::AiModel).count();
assert!(chat_count >= 5, "Expected 5+ chat integrations, got {chat_count}");
assert!(ai_count >= 5, "Expected 5+ AI model integrations, got {ai_count}");
}
}

20
src/lib.rs Normal file
View file

@ -0,0 +1,20 @@
#![warn(clippy::all, clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::unnecessary_literal_bound,
clippy::module_name_repetitions,
clippy::struct_field_names,
clippy::must_use_candidate,
clippy::new_without_default,
clippy::return_self_not_must_use,
dead_code
)]
pub mod config;
pub mod heartbeat;
pub mod memory;
pub mod observability;
pub mod providers;
pub mod runtime;
pub mod security;

326
src/main.rs Normal file
View file

@ -0,0 +1,326 @@
#![warn(clippy::all, clippy::pedantic)]
#![allow(
clippy::missing_errors_doc,
clippy::missing_panics_doc,
clippy::unnecessary_literal_bound,
clippy::module_name_repetitions,
clippy::struct_field_names,
dead_code
)]
use anyhow::Result;
use clap::{Parser, Subcommand};
use tracing::{info, Level};
use tracing_subscriber::FmtSubscriber;
mod agent;
mod channels;
mod config;
mod cron;
mod gateway;
mod heartbeat;
mod memory;
mod observability;
mod onboard;
mod providers;
mod runtime;
mod security;
mod integrations;
mod skills;
mod tools;
use config::Config;
/// `ZeroClaw` - Zero overhead. Zero compromise. 100% Rust.
#[derive(Parser, Debug)]
#[command(name = "zeroclaw")]
#[command(author = "theonlyhennygod")]
#[command(version = "0.1.0")]
#[command(about = "The fastest, smallest AI assistant.", long_about = None)]
struct Cli {
#[command(subcommand)]
command: Commands,
}
#[derive(Subcommand, Debug)]
enum Commands {
/// Initialize your workspace and configuration
Onboard,
/// Start the AI agent loop
Agent {
/// Single message mode (don't enter interactive mode)
#[arg(short, long)]
message: Option<String>,
/// Provider to use (openrouter, anthropic, openai)
#[arg(short, long)]
provider: Option<String>,
/// Model to use
#[arg(short, long)]
model: Option<String>,
/// Temperature (0.0 - 2.0)
#[arg(short, long, default_value = "0.7")]
temperature: f64,
},
/// Start the gateway server (webhooks, websockets)
Gateway {
/// Port to listen on
#[arg(short, long, default_value = "8080")]
port: u16,
/// Host to bind to
#[arg(short, long, default_value = "127.0.0.1")]
host: String,
},
/// Show system status
Status {
/// Show detailed status
#[arg(short, long)]
verbose: bool,
},
/// Configure and manage scheduled tasks
Cron {
#[command(subcommand)]
cron_command: CronCommands,
},
/// Manage channels (telegram, discord, slack)
Channel {
#[command(subcommand)]
channel_command: ChannelCommands,
},
/// Tool utilities
Tools {
#[command(subcommand)]
tool_command: ToolCommands,
},
/// Browse 50+ integrations
Integrations {
#[command(subcommand)]
integration_command: IntegrationCommands,
},
/// Manage skills (user-defined capabilities)
Skills {
#[command(subcommand)]
skill_command: SkillCommands,
},
}
#[derive(Subcommand, Debug)]
enum CronCommands {
/// List all scheduled tasks
List,
/// Add a new scheduled task
Add {
/// Cron expression
expression: String,
/// Command to run
command: String,
},
/// Remove a scheduled task
Remove {
/// Task ID
id: String,
},
}
#[derive(Subcommand, Debug)]
enum ChannelCommands {
/// List configured channels
List,
/// Start all configured channels (Telegram, Discord, Slack)
Start,
/// Add a new channel
Add {
/// Channel type
channel_type: String,
/// Configuration JSON
config: String,
},
/// Remove a channel
Remove {
/// Channel name
name: String,
},
}
#[derive(Subcommand, Debug)]
enum SkillCommands {
/// List installed skills
List,
/// Install a skill from a GitHub URL or local path
Install {
/// GitHub URL or local path
source: String,
},
/// Remove an installed skill
Remove {
/// Skill name
name: String,
},
}
#[derive(Subcommand, Debug)]
enum IntegrationCommands {
/// List all integrations and their status
List {
/// Filter by category (e.g. "chat", "ai", "productivity")
#[arg(short, long)]
category: Option<String>,
},
/// Show details about a specific integration
Info {
/// Integration name
name: String,
},
}
#[derive(Subcommand, Debug)]
enum ToolCommands {
/// List available tools
List,
/// Test a tool
Test {
/// Tool name
tool: String,
/// Tool arguments (JSON)
args: String,
},
}
#[tokio::main]
#[allow(clippy::too_many_lines)]
async fn main() -> Result<()> {
let cli = Cli::parse();
// Initialize logging
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::INFO)
.finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
// Onboard runs the interactive wizard — no existing config needed
if matches!(cli.command, Commands::Onboard) {
let config = onboard::run_wizard()?;
// Auto-start channels if user said yes during wizard
if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") {
channels::start_channels(config).await?;
}
return Ok(());
}
// All other commands need config loaded first
let config = Config::load_or_init()?;
match cli.command {
Commands::Onboard => unreachable!(),
Commands::Agent {
message,
provider,
model,
temperature,
} => agent::run(config, message, provider, model, temperature).await,
Commands::Gateway { port, host } => {
info!("🚀 Starting ZeroClaw Gateway on {host}:{port}");
info!("POST http://{host}:{port}/webhook — send JSON messages");
info!("GET http://{host}:{port}/health — health check");
gateway::run_gateway(&host, port, config).await
}
Commands::Status { verbose } => {
println!("🦀 ZeroClaw Status");
println!();
println!("Version: {}", env!("CARGO_PKG_VERSION"));
println!("Workspace: {}", config.workspace_dir.display());
println!("Config: {}", config.config_path.display());
println!();
println!(
"🤖 Provider: {}",
config.default_provider.as_deref().unwrap_or("openrouter")
);
println!(
" Model: {}",
config.default_model.as_deref().unwrap_or("(default)")
);
println!("📊 Observability: {}", config.observability.backend);
println!("🛡️ Autonomy: {:?}", config.autonomy.level);
println!("⚙️ Runtime: {}", config.runtime.kind);
println!(
"💓 Heartbeat: {}",
if config.heartbeat.enabled {
format!("every {}min", config.heartbeat.interval_minutes)
} else {
"disabled".into()
}
);
println!(
"🧠 Memory: {} (auto-save: {})",
config.memory.backend,
if config.memory.auto_save { "on" } else { "off" }
);
if verbose {
println!();
println!("Security:");
println!(" Workspace only: {}", config.autonomy.workspace_only);
println!(
" Allowed commands: {}",
config.autonomy.allowed_commands.join(", ")
);
println!(
" Max actions/hour: {}",
config.autonomy.max_actions_per_hour
);
println!(
" Max cost/day: ${:.2}",
f64::from(config.autonomy.max_cost_per_day_cents) / 100.0
);
println!();
println!("Channels:");
println!(" CLI: ✅ always");
for (name, configured) in [
("Telegram", config.channels_config.telegram.is_some()),
("Discord", config.channels_config.discord.is_some()),
("Slack", config.channels_config.slack.is_some()),
("Webhook", config.channels_config.webhook.is_some()),
] {
println!(
" {name:9} {}",
if configured { "✅ configured" } else { "❌ not configured" }
);
}
}
Ok(())
}
Commands::Cron { cron_command } => cron::handle_command(cron_command, config),
Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await,
other => channels::handle_command(other, &config),
},
Commands::Tools { tool_command } => tools::handle_command(tool_command, config).await,
Commands::Integrations {
integration_command,
} => integrations::handle_command(integration_command, &config),
Commands::Skills { skill_command } => {
skills::handle_command(skill_command, &config.workspace_dir)
}
}
}

344
src/memory/markdown.rs Normal file
View file

@ -0,0 +1,344 @@
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use async_trait::async_trait;
use chrono::Local;
use std::path::{Path, PathBuf};
use tokio::fs;
/// Markdown-based memory — plain files as source of truth
///
/// Layout:
/// workspace/MEMORY.md — curated long-term memory (core)
/// workspace/memory/YYYY-MM-DD.md — daily logs (append-only)
pub struct MarkdownMemory {
workspace_dir: PathBuf,
}
impl MarkdownMemory {
pub fn new(workspace_dir: &Path) -> Self {
Self {
workspace_dir: workspace_dir.to_path_buf(),
}
}
fn memory_dir(&self) -> PathBuf {
self.workspace_dir.join("memory")
}
fn core_path(&self) -> PathBuf {
self.workspace_dir.join("MEMORY.md")
}
fn daily_path(&self) -> PathBuf {
let date = Local::now().format("%Y-%m-%d").to_string();
self.memory_dir().join(format!("{date}.md"))
}
async fn ensure_dirs(&self) -> anyhow::Result<()> {
fs::create_dir_all(self.memory_dir()).await?;
Ok(())
}
async fn append_to_file(&self, path: &Path, content: &str) -> anyhow::Result<()> {
self.ensure_dirs().await?;
let existing = if path.exists() {
fs::read_to_string(path).await.unwrap_or_default()
} else {
String::new()
};
let updated = if existing.is_empty() {
let header = if path == self.core_path() {
"# Long-Term Memory\n\n"
} else {
let date = Local::now().format("%Y-%m-%d").to_string();
&format!("# Daily Log — {date}\n\n")
};
format!("{header}{content}\n")
} else {
format!("{existing}\n{content}\n")
};
fs::write(path, updated).await?;
Ok(())
}
fn parse_entries_from_file(
path: &Path,
content: &str,
category: &MemoryCategory,
) -> Vec<MemoryEntry> {
let filename = path
.file_stem()
.and_then(|s| s.to_str())
.unwrap_or("unknown");
content
.lines()
.filter(|line| {
let trimmed = line.trim();
!trimmed.is_empty() && !trimmed.starts_with('#')
})
.enumerate()
.map(|(i, line)| {
let trimmed = line.trim();
let clean = trimmed.strip_prefix("- ").unwrap_or(trimmed);
MemoryEntry {
id: format!("{filename}:{i}"),
key: format!("{filename}:{i}"),
content: clean.to_string(),
category: category.clone(),
timestamp: filename.to_string(),
session_id: None,
score: None,
}
})
.collect()
}
async fn read_all_entries(&self) -> anyhow::Result<Vec<MemoryEntry>> {
let mut entries = Vec::new();
// Read MEMORY.md (core)
let core_path = self.core_path();
if core_path.exists() {
let content = fs::read_to_string(&core_path).await?;
entries.extend(Self::parse_entries_from_file(
&core_path,
&content,
&MemoryCategory::Core,
));
}
// Read daily logs
let mem_dir = self.memory_dir();
if mem_dir.exists() {
let mut dir = fs::read_dir(&mem_dir).await?;
while let Some(entry) = dir.next_entry().await? {
let path = entry.path();
if path.extension().and_then(|e| e.to_str()) == Some("md") {
let content = fs::read_to_string(&path).await?;
entries.extend(Self::parse_entries_from_file(
&path,
&content,
&MemoryCategory::Daily,
));
}
}
}
entries.sort_by(|a, b| b.timestamp.cmp(&a.timestamp));
Ok(entries)
}
}
#[async_trait]
impl Memory for MarkdownMemory {
fn name(&self) -> &str {
"markdown"
}
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
) -> anyhow::Result<()> {
let entry = format!("- **{key}**: {content}");
let path = match category {
MemoryCategory::Core => self.core_path(),
_ => self.daily_path(),
};
self.append_to_file(&path, &entry).await
}
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
let all = self.read_all_entries().await?;
let query_lower = query.to_lowercase();
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
let mut scored: Vec<MemoryEntry> = all
.into_iter()
.filter_map(|mut entry| {
let content_lower = entry.content.to_lowercase();
let matched = keywords
.iter()
.filter(|kw| content_lower.contains(**kw))
.count();
if matched > 0 {
#[allow(clippy::cast_precision_loss)]
let score = matched as f64 / keywords.len() as f64;
entry.score = Some(score);
Some(entry)
} else {
None
}
})
.collect();
scored.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
scored.truncate(limit);
Ok(scored)
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
let all = self.read_all_entries().await?;
Ok(all
.into_iter()
.find(|e| e.key == key || e.content.contains(key)))
}
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
let all = self.read_all_entries().await?;
match category {
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
None => Ok(all),
}
}
async fn forget(&self, _key: &str) -> anyhow::Result<bool> {
// Markdown memory is append-only by design (audit trail)
// Return false to indicate the entry wasn't removed
Ok(false)
}
async fn count(&self) -> anyhow::Result<usize> {
let all = self.read_all_entries().await?;
Ok(all.len())
}
async fn health_check(&self) -> bool {
self.workspace_dir.exists()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs as sync_fs;
use tempfile::TempDir;
fn temp_workspace() -> (TempDir, MarkdownMemory) {
let tmp = TempDir::new().unwrap();
let mem = MarkdownMemory::new(tmp.path());
(tmp, mem)
}
#[tokio::test]
async fn markdown_name() {
let (_tmp, mem) = temp_workspace();
assert_eq!(mem.name(), "markdown");
}
#[tokio::test]
async fn markdown_health_check() {
let (_tmp, mem) = temp_workspace();
assert!(mem.health_check().await);
}
#[tokio::test]
async fn markdown_store_core() {
let (_tmp, mem) = temp_workspace();
mem.store("pref", "User likes Rust", MemoryCategory::Core)
.await
.unwrap();
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
assert!(content.contains("User likes Rust"));
}
#[tokio::test]
async fn markdown_store_daily() {
let (_tmp, mem) = temp_workspace();
mem.store("note", "Finished tests", MemoryCategory::Daily)
.await
.unwrap();
let path = mem.daily_path();
let content = sync_fs::read_to_string(path).unwrap();
assert!(content.contains("Finished tests"));
}
#[tokio::test]
async fn markdown_recall_keyword() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "Rust is fast", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "Python is slow", MemoryCategory::Core)
.await
.unwrap();
mem.store("c", "Rust and safety", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap();
assert!(results.len() >= 2);
assert!(results
.iter()
.all(|r| r.content.to_lowercase().contains("rust")));
}
#[tokio::test]
async fn markdown_recall_no_match() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "Rust is great", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("javascript", 10).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn markdown_count() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
mem.store("b", "second", MemoryCategory::Core)
.await
.unwrap();
let count = mem.count().await.unwrap();
assert!(count >= 2);
}
#[tokio::test]
async fn markdown_list_by_category() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "core fact", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "daily note", MemoryCategory::Daily)
.await
.unwrap();
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
}
#[tokio::test]
async fn markdown_forget_is_noop() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "permanent", MemoryCategory::Core)
.await
.unwrap();
let removed = mem.forget("a").await.unwrap();
assert!(!removed, "Markdown memory is append-only");
}
#[tokio::test]
async fn markdown_empty_recall() {
let (_tmp, mem) = temp_workspace();
let results = mem.recall("anything", 10).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn markdown_empty_count() {
let (_tmp, mem) = temp_workspace();
assert_eq!(mem.count().await.unwrap(), 0);
}
}

77
src/memory/mod.rs Normal file
View file

@ -0,0 +1,77 @@
pub mod markdown;
pub mod sqlite;
pub mod traits;
pub use markdown::MarkdownMemory;
pub use sqlite::SqliteMemory;
pub use traits::Memory;
#[allow(unused_imports)]
pub use traits::{MemoryCategory, MemoryEntry};
use crate::config::MemoryConfig;
use std::path::Path;
/// Factory: create the right memory backend from config
pub fn create_memory(
config: &MemoryConfig,
workspace_dir: &Path,
) -> anyhow::Result<Box<dyn Memory>> {
match config.backend.as_str() {
"sqlite" => Ok(Box::new(SqliteMemory::new(workspace_dir)?)),
"markdown" | "none" => Ok(Box::new(MarkdownMemory::new(workspace_dir))),
other => {
tracing::warn!("Unknown memory backend '{other}', falling back to markdown");
Ok(Box::new(MarkdownMemory::new(workspace_dir)))
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
#[test]
fn factory_sqlite() {
let tmp = TempDir::new().unwrap();
let cfg = MemoryConfig {
backend: "sqlite".into(),
auto_save: true,
};
let mem = create_memory(&cfg, tmp.path()).unwrap();
assert_eq!(mem.name(), "sqlite");
}
#[test]
fn factory_markdown() {
let tmp = TempDir::new().unwrap();
let cfg = MemoryConfig {
backend: "markdown".into(),
auto_save: true,
};
let mem = create_memory(&cfg, tmp.path()).unwrap();
assert_eq!(mem.name(), "markdown");
}
#[test]
fn factory_none_falls_back_to_markdown() {
let tmp = TempDir::new().unwrap();
let cfg = MemoryConfig {
backend: "none".into(),
auto_save: true,
};
let mem = create_memory(&cfg, tmp.path()).unwrap();
assert_eq!(mem.name(), "markdown");
}
#[test]
fn factory_unknown_falls_back_to_markdown() {
let tmp = TempDir::new().unwrap();
let cfg = MemoryConfig {
backend: "redis".into(),
auto_save: true,
};
let mem = create_memory(&cfg, tmp.path()).unwrap();
assert_eq!(mem.name(), "markdown");
}
}

481
src/memory/sqlite.rs Normal file
View file

@ -0,0 +1,481 @@
use super::traits::{Memory, MemoryCategory, MemoryEntry};
use async_trait::async_trait;
use chrono::Local;
use rusqlite::{params, Connection};
use std::path::{Path, PathBuf};
use std::sync::Mutex;
use uuid::Uuid;
/// SQLite-backed persistent memory — the brain
///
/// Stores memories in a local `SQLite` database with keyword search.
/// Zero external dependencies, works offline, survives restarts.
pub struct SqliteMemory {
conn: Mutex<Connection>,
db_path: PathBuf,
}
impl SqliteMemory {
pub fn new(workspace_dir: &Path) -> anyhow::Result<Self> {
let db_path = workspace_dir.join("memory").join("brain.db");
if let Some(parent) = db_path.parent() {
std::fs::create_dir_all(parent)?;
}
let conn = Connection::open(&db_path)?;
conn.execute_batch(
"CREATE TABLE IF NOT EXISTS memories (
id TEXT PRIMARY KEY,
key TEXT NOT NULL UNIQUE,
content TEXT NOT NULL,
category TEXT NOT NULL DEFAULT 'core',
created_at TEXT NOT NULL,
updated_at TEXT NOT NULL
);
CREATE INDEX IF NOT EXISTS idx_memories_category ON memories(category);
CREATE INDEX IF NOT EXISTS idx_memories_key ON memories(key);",
)?;
Ok(Self {
conn: Mutex::new(conn),
db_path,
})
}
fn category_to_str(cat: &MemoryCategory) -> String {
match cat {
MemoryCategory::Core => "core".into(),
MemoryCategory::Daily => "daily".into(),
MemoryCategory::Conversation => "conversation".into(),
MemoryCategory::Custom(name) => name.clone(),
}
}
fn str_to_category(s: &str) -> MemoryCategory {
match s {
"core" => MemoryCategory::Core,
"daily" => MemoryCategory::Daily,
"conversation" => MemoryCategory::Conversation,
other => MemoryCategory::Custom(other.to_string()),
}
}
}
#[async_trait]
impl Memory for SqliteMemory {
fn name(&self) -> &str {
"sqlite"
}
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
) -> anyhow::Result<()> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let now = Local::now().to_rfc3339();
let cat = Self::category_to_str(&category);
let id = Uuid::new_v4().to_string();
conn.execute(
"INSERT INTO memories (id, key, content, category, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
updated_at = excluded.updated_at",
params![id, key, content, cat, now, now],
)?;
Ok(())
}
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
// Keyword search: split query into words, match any
let keywords: Vec<String> = query.split_whitespace().map(|w| format!("%{w}%")).collect();
if keywords.is_empty() {
return Ok(Vec::new());
}
// Build dynamic WHERE clause for keyword matching
let conditions: Vec<String> = keywords
.iter()
.enumerate()
.map(|(i, _)| format!("(content LIKE ?{} OR key LIKE ?{})", i * 2 + 1, i * 2 + 2))
.collect();
let where_clause = conditions.join(" OR ");
let sql = format!(
"SELECT id, key, content, category, created_at FROM memories
WHERE {where_clause}
ORDER BY updated_at DESC
LIMIT ?{}",
keywords.len() * 2 + 1
);
let mut stmt = conn.prepare(&sql)?;
// Build params: each keyword appears twice (for content and key)
let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
for kw in &keywords {
param_values.push(Box::new(kw.clone()));
param_values.push(Box::new(kw.clone()));
}
#[allow(clippy::cast_possible_wrap)]
param_values.push(Box::new(limit as i64));
let params_ref: Vec<&dyn rusqlite::types::ToSql> =
param_values.iter().map(AsRef::as_ref).collect();
let rows = stmt.query_map(params_ref.as_slice(), |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
score: Some(1.0),
})
})?;
let mut results = Vec::new();
for row in rows {
results.push(row?);
}
// Score by keyword match count
let query_lower = query.to_lowercase();
let kw_list: Vec<&str> = query_lower.split_whitespace().collect();
for entry in &mut results {
let content_lower = entry.content.to_lowercase();
let matched = kw_list
.iter()
.filter(|kw| content_lower.contains(**kw))
.count();
#[allow(clippy::cast_precision_loss)]
{
entry.score = Some(matched as f64 / kw_list.len().max(1) as f64);
}
}
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
Ok(results)
}
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
)?;
let mut rows = stmt.query_map(params![key], |row| {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
score: None,
})
})?;
match rows.next() {
Some(Ok(entry)) => Ok(Some(entry)),
_ => Ok(None),
}
}
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut results = Vec::new();
let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result<MemoryEntry> {
Ok(MemoryEntry {
id: row.get(0)?,
key: row.get(1)?,
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
score: None,
})
};
if let Some(cat) = category {
let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories
WHERE category = ?1 ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map(params![cat_str], row_mapper)?;
for row in rows {
results.push(row?);
}
} else {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories
ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map([], row_mapper)?;
for row in rows {
results.push(row?);
}
}
Ok(results)
}
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let affected = conn.execute("DELETE FROM memories WHERE key = ?1", params![key])?;
Ok(affected > 0)
}
async fn count(&self) -> anyhow::Result<usize> {
let conn = self
.conn
.lock()
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let count: i64 = conn.query_row("SELECT COUNT(*) FROM memories", [], |row| row.get(0))?;
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok(count as usize)
}
async fn health_check(&self) -> bool {
self.conn
.lock()
.map(|c| c.execute_batch("SELECT 1").is_ok())
.unwrap_or(false)
}
}
#[cfg(test)]
mod tests {
use super::*;
use tempfile::TempDir;
fn temp_sqlite() -> (TempDir, SqliteMemory) {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
(tmp, mem)
}
#[tokio::test]
async fn sqlite_name() {
let (_tmp, mem) = temp_sqlite();
assert_eq!(mem.name(), "sqlite");
}
#[tokio::test]
async fn sqlite_health() {
let (_tmp, mem) = temp_sqlite();
assert!(mem.health_check().await);
}
#[tokio::test]
async fn sqlite_store_and_get() {
let (_tmp, mem) = temp_sqlite();
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("user_lang").await.unwrap();
assert!(entry.is_some());
let entry = entry.unwrap();
assert_eq!(entry.key, "user_lang");
assert_eq!(entry.content, "Prefers Rust");
assert_eq!(entry.category, MemoryCategory::Core);
}
#[tokio::test]
async fn sqlite_store_upsert() {
let (_tmp, mem) = temp_sqlite();
mem.store("pref", "likes Rust", MemoryCategory::Core)
.await
.unwrap();
mem.store("pref", "loves Rust", MemoryCategory::Core)
.await
.unwrap();
let entry = mem.get("pref").await.unwrap().unwrap();
assert_eq!(entry.content, "loves Rust");
assert_eq!(mem.count().await.unwrap(), 1);
}
#[tokio::test]
async fn sqlite_recall_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "Python is interpreted", MemoryCategory::Core)
.await
.unwrap();
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
.all(|r| r.content.to_lowercase().contains("rust")));
}
#[tokio::test]
async fn sqlite_recall_multi_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast", MemoryCategory::Core)
.await
.unwrap();
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("fast safe", 10).await.unwrap();
assert!(!results.is_empty());
// Entry with both keywords should score higher
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
}
#[tokio::test]
async fn sqlite_recall_no_match() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust rocks", MemoryCategory::Core)
.await
.unwrap();
let results = mem.recall("javascript", 10).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn sqlite_forget() {
let (_tmp, mem) = temp_sqlite();
mem.store("temp", "temporary data", MemoryCategory::Conversation)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 1);
let removed = mem.forget("temp").await.unwrap();
assert!(removed);
assert_eq!(mem.count().await.unwrap(), 0);
}
#[tokio::test]
async fn sqlite_forget_nonexistent() {
let (_tmp, mem) = temp_sqlite();
let removed = mem.forget("nope").await.unwrap();
assert!(!removed);
}
#[tokio::test]
async fn sqlite_list_all() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
mem.store("c", "three", MemoryCategory::Conversation)
.await
.unwrap();
let all = mem.list(None).await.unwrap();
assert_eq!(all.len(), 3);
}
#[tokio::test]
async fn sqlite_list_by_category() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
mem.store("c", "daily1", MemoryCategory::Daily)
.await
.unwrap();
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
assert_eq!(core.len(), 2);
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
assert_eq!(daily.len(), 1);
}
#[tokio::test]
async fn sqlite_count_empty() {
let (_tmp, mem) = temp_sqlite();
assert_eq!(mem.count().await.unwrap(), 0);
}
#[tokio::test]
async fn sqlite_get_nonexistent() {
let (_tmp, mem) = temp_sqlite();
assert!(mem.get("nope").await.unwrap().is_none());
}
#[tokio::test]
async fn sqlite_db_persists() {
let tmp = TempDir::new().unwrap();
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("persist", "I survive restarts", MemoryCategory::Core)
.await
.unwrap();
}
// Reopen
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
let entry = mem2.get("persist").await.unwrap();
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "I survive restarts");
}
#[tokio::test]
async fn sqlite_category_roundtrip() {
let (_tmp, mem) = temp_sqlite();
let categories = vec![
MemoryCategory::Core,
MemoryCategory::Daily,
MemoryCategory::Conversation,
MemoryCategory::Custom("project".into()),
];
for (i, cat) in categories.iter().enumerate() {
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
.await
.unwrap();
}
for (i, cat) in categories.iter().enumerate() {
let entry = mem.get(&format!("k{i}")).await.unwrap().unwrap();
assert_eq!(&entry.category, cat);
}
}
}

68
src/memory/traits.rs Normal file
View file

@ -0,0 +1,68 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// A single memory entry
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryEntry {
pub id: String,
pub key: String,
pub content: String,
pub category: MemoryCategory,
pub timestamp: String,
pub session_id: Option<String>,
pub score: Option<f64>,
}
/// Memory categories for organization
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum MemoryCategory {
/// Long-term facts, preferences, decisions
Core,
/// Daily session logs
Daily,
/// Conversation context
Conversation,
/// User-defined custom category
Custom(String),
}
impl std::fmt::Display for MemoryCategory {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Core => write!(f, "core"),
Self::Daily => write!(f, "daily"),
Self::Conversation => write!(f, "conversation"),
Self::Custom(name) => write!(f, "{name}"),
}
}
}
/// Core memory trait — implement for any persistence backend
#[async_trait]
pub trait Memory: Send + Sync {
/// Backend name
fn name(&self) -> &str;
/// Store a memory entry
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
-> anyhow::Result<()>;
/// Recall memories matching a query (keyword search)
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
/// Get a specific memory by key
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
/// List all memory keys, optionally filtered by category
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
/// Remove a memory by key
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
/// Count total memories
async fn count(&self) -> anyhow::Result<usize>;
/// Health check
async fn health_check(&self) -> bool;
}

119
src/observability/log.rs Normal file
View file

@ -0,0 +1,119 @@
use super::traits::{Observer, ObserverEvent, ObserverMetric};
use tracing::info;
/// Log-based observer — uses tracing, zero external deps
pub struct LogObserver;
impl LogObserver {
pub fn new() -> Self {
Self
}
}
impl Observer for LogObserver {
fn record_event(&self, event: &ObserverEvent) {
match event {
ObserverEvent::AgentStart { provider, model } => {
info!(provider = %provider, model = %model, "agent.start");
}
ObserverEvent::AgentEnd {
duration,
tokens_used,
} => {
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
info!(duration_ms = ms, tokens = ?tokens_used, "agent.end");
}
ObserverEvent::ToolCall {
tool,
duration,
success,
} => {
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
info!(tool = %tool, duration_ms = ms, success = success, "tool.call");
}
ObserverEvent::ChannelMessage { channel, direction } => {
info!(channel = %channel, direction = %direction, "channel.message");
}
ObserverEvent::HeartbeatTick => {
info!("heartbeat.tick");
}
ObserverEvent::Error { component, message } => {
info!(component = %component, error = %message, "error");
}
}
}
fn record_metric(&self, metric: &ObserverMetric) {
match metric {
ObserverMetric::RequestLatency(d) => {
let ms = u64::try_from(d.as_millis()).unwrap_or(u64::MAX);
info!(latency_ms = ms, "metric.request_latency");
}
ObserverMetric::TokensUsed(t) => {
info!(tokens = t, "metric.tokens_used");
}
ObserverMetric::ActiveSessions(s) => {
info!(sessions = s, "metric.active_sessions");
}
ObserverMetric::QueueDepth(d) => {
info!(depth = d, "metric.queue_depth");
}
}
}
fn name(&self) -> &str {
"log"
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[test]
fn log_observer_name() {
assert_eq!(LogObserver::new().name(), "log");
}
#[test]
fn log_observer_all_events_no_panic() {
let obs = LogObserver::new();
obs.record_event(&ObserverEvent::AgentStart {
provider: "openrouter".into(),
model: "claude-sonnet".into(),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(500),
tokens_used: Some(100),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
});
obs.record_event(&ObserverEvent::ToolCall {
tool: "shell".into(),
duration: Duration::from_millis(10),
success: false,
});
obs.record_event(&ObserverEvent::ChannelMessage {
channel: "telegram".into(),
direction: "outbound".into(),
});
obs.record_event(&ObserverEvent::HeartbeatTick);
obs.record_event(&ObserverEvent::Error {
component: "provider".into(),
message: "timeout".into(),
});
}
#[test]
fn log_observer_all_metrics_no_panic() {
let obs = LogObserver::new();
obs.record_metric(&ObserverMetric::RequestLatency(Duration::from_secs(2)));
obs.record_metric(&ObserverMetric::TokensUsed(0));
obs.record_metric(&ObserverMetric::TokensUsed(u64::MAX));
obs.record_metric(&ObserverMetric::ActiveSessions(1));
obs.record_metric(&ObserverMetric::QueueDepth(999));
}
}

76
src/observability/mod.rs Normal file
View file

@ -0,0 +1,76 @@
pub mod log;
pub mod multi;
pub mod noop;
pub mod traits;
pub use self::log::LogObserver;
pub use noop::NoopObserver;
pub use traits::{Observer, ObserverEvent};
use crate::config::ObservabilityConfig;
/// Factory: create the right observer from config
pub fn create_observer(config: &ObservabilityConfig) -> Box<dyn Observer> {
match config.backend.as_str() {
"log" => Box::new(LogObserver::new()),
"none" | "noop" => Box::new(NoopObserver),
_ => {
tracing::warn!(
"Unknown observability backend '{}', falling back to noop",
config.backend
);
Box::new(NoopObserver)
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn factory_none_returns_noop() {
let cfg = ObservabilityConfig {
backend: "none".into(),
};
assert_eq!(create_observer(&cfg).name(), "noop");
}
#[test]
fn factory_noop_returns_noop() {
let cfg = ObservabilityConfig {
backend: "noop".into(),
};
assert_eq!(create_observer(&cfg).name(), "noop");
}
#[test]
fn factory_log_returns_log() {
let cfg = ObservabilityConfig {
backend: "log".into(),
};
assert_eq!(create_observer(&cfg).name(), "log");
}
#[test]
fn factory_unknown_falls_back_to_noop() {
let cfg = ObservabilityConfig {
backend: "prometheus".into(),
};
assert_eq!(create_observer(&cfg).name(), "noop");
}
#[test]
fn factory_empty_string_falls_back_to_noop() {
let cfg = ObservabilityConfig { backend: "".into() };
assert_eq!(create_observer(&cfg).name(), "noop");
}
#[test]
fn factory_garbage_falls_back_to_noop() {
let cfg = ObservabilityConfig {
backend: "xyzzy_garbage_123".into(),
};
assert_eq!(create_observer(&cfg).name(), "noop");
}
}

154
src/observability/multi.rs Normal file
View file

@ -0,0 +1,154 @@
use super::traits::{Observer, ObserverEvent, ObserverMetric};
/// Combine multiple observers — fan-out events to all backends
pub struct MultiObserver {
observers: Vec<Box<dyn Observer>>,
}
impl MultiObserver {
pub fn new(observers: Vec<Box<dyn Observer>>) -> Self {
Self { observers }
}
}
impl Observer for MultiObserver {
fn record_event(&self, event: &ObserverEvent) {
for obs in &self.observers {
obs.record_event(event);
}
}
fn record_metric(&self, metric: &ObserverMetric) {
for obs in &self.observers {
obs.record_metric(metric);
}
}
fn flush(&self) {
for obs in &self.observers {
obs.flush();
}
}
fn name(&self) -> &str {
"multi"
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::Duration;
/// Test observer that counts calls
struct CountingObserver {
event_count: Arc<AtomicUsize>,
metric_count: Arc<AtomicUsize>,
flush_count: Arc<AtomicUsize>,
}
impl CountingObserver {
fn new(
event_count: Arc<AtomicUsize>,
metric_count: Arc<AtomicUsize>,
flush_count: Arc<AtomicUsize>,
) -> Self {
Self {
event_count,
metric_count,
flush_count,
}
}
}
impl Observer for CountingObserver {
fn record_event(&self, _event: &ObserverEvent) {
self.event_count.fetch_add(1, Ordering::SeqCst);
}
fn record_metric(&self, _metric: &ObserverMetric) {
self.metric_count.fetch_add(1, Ordering::SeqCst);
}
fn flush(&self) {
self.flush_count.fetch_add(1, Ordering::SeqCst);
}
fn name(&self) -> &str {
"counting"
}
}
#[test]
fn multi_name() {
let m = MultiObserver::new(vec![]);
assert_eq!(m.name(), "multi");
}
#[test]
fn multi_empty_no_panic() {
let m = MultiObserver::new(vec![]);
m.record_event(&ObserverEvent::HeartbeatTick);
m.record_metric(&ObserverMetric::TokensUsed(10));
m.flush();
}
#[test]
fn multi_fans_out_events() {
let ec1 = Arc::new(AtomicUsize::new(0));
let mc1 = Arc::new(AtomicUsize::new(0));
let fc1 = Arc::new(AtomicUsize::new(0));
let ec2 = Arc::new(AtomicUsize::new(0));
let mc2 = Arc::new(AtomicUsize::new(0));
let fc2 = Arc::new(AtomicUsize::new(0));
let m = MultiObserver::new(vec![
Box::new(CountingObserver::new(ec1.clone(), mc1.clone(), fc1.clone())),
Box::new(CountingObserver::new(ec2.clone(), mc2.clone(), fc2.clone())),
]);
m.record_event(&ObserverEvent::HeartbeatTick);
m.record_event(&ObserverEvent::HeartbeatTick);
m.record_event(&ObserverEvent::HeartbeatTick);
assert_eq!(ec1.load(Ordering::SeqCst), 3);
assert_eq!(ec2.load(Ordering::SeqCst), 3);
}
#[test]
fn multi_fans_out_metrics() {
let ec1 = Arc::new(AtomicUsize::new(0));
let mc1 = Arc::new(AtomicUsize::new(0));
let fc1 = Arc::new(AtomicUsize::new(0));
let ec2 = Arc::new(AtomicUsize::new(0));
let mc2 = Arc::new(AtomicUsize::new(0));
let fc2 = Arc::new(AtomicUsize::new(0));
let m = MultiObserver::new(vec![
Box::new(CountingObserver::new(ec1.clone(), mc1.clone(), fc1.clone())),
Box::new(CountingObserver::new(ec2.clone(), mc2.clone(), fc2.clone())),
]);
m.record_metric(&ObserverMetric::TokensUsed(100));
m.record_metric(&ObserverMetric::RequestLatency(Duration::from_millis(5)));
assert_eq!(mc1.load(Ordering::SeqCst), 2);
assert_eq!(mc2.load(Ordering::SeqCst), 2);
}
#[test]
fn multi_fans_out_flush() {
let ec = Arc::new(AtomicUsize::new(0));
let mc = Arc::new(AtomicUsize::new(0));
let fc1 = Arc::new(AtomicUsize::new(0));
let fc2 = Arc::new(AtomicUsize::new(0));
let m = MultiObserver::new(vec![
Box::new(CountingObserver::new(ec.clone(), mc.clone(), fc1.clone())),
Box::new(CountingObserver::new(ec.clone(), mc.clone(), fc2.clone())),
]);
m.flush();
assert_eq!(fc1.load(Ordering::SeqCst), 1);
assert_eq!(fc2.load(Ordering::SeqCst), 1);
}
}

72
src/observability/noop.rs Normal file
View file

@ -0,0 +1,72 @@
use super::traits::{Observer, ObserverEvent, ObserverMetric};
/// Zero-overhead observer — all methods compile to nothing
pub struct NoopObserver;
impl Observer for NoopObserver {
#[inline(always)]
fn record_event(&self, _event: &ObserverEvent) {}
#[inline(always)]
fn record_metric(&self, _metric: &ObserverMetric) {}
fn name(&self) -> &str {
"noop"
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::time::Duration;
#[test]
fn noop_name() {
assert_eq!(NoopObserver.name(), "noop");
}
#[test]
fn noop_record_event_does_not_panic() {
let obs = NoopObserver;
obs.record_event(&ObserverEvent::HeartbeatTick);
obs.record_event(&ObserverEvent::AgentStart {
provider: "test".into(),
model: "test".into(),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(100),
tokens_used: Some(42),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
});
obs.record_event(&ObserverEvent::ToolCall {
tool: "shell".into(),
duration: Duration::from_secs(1),
success: true,
});
obs.record_event(&ObserverEvent::ChannelMessage {
channel: "cli".into(),
direction: "inbound".into(),
});
obs.record_event(&ObserverEvent::Error {
component: "test".into(),
message: "boom".into(),
});
}
#[test]
fn noop_record_metric_does_not_panic() {
let obs = NoopObserver;
obs.record_metric(&ObserverMetric::RequestLatency(Duration::from_millis(50)));
obs.record_metric(&ObserverMetric::TokensUsed(1000));
obs.record_metric(&ObserverMetric::ActiveSessions(5));
obs.record_metric(&ObserverMetric::QueueDepth(0));
}
#[test]
fn noop_flush_does_not_panic() {
NoopObserver.flush();
}
}

View file

@ -0,0 +1,52 @@
use std::time::Duration;
/// Events the observer can record
#[derive(Debug, Clone)]
pub enum ObserverEvent {
AgentStart {
provider: String,
model: String,
},
AgentEnd {
duration: Duration,
tokens_used: Option<u64>,
},
ToolCall {
tool: String,
duration: Duration,
success: bool,
},
ChannelMessage {
channel: String,
direction: String,
},
HeartbeatTick,
Error {
component: String,
message: String,
},
}
/// Numeric metrics
#[derive(Debug, Clone)]
pub enum ObserverMetric {
RequestLatency(Duration),
TokensUsed(u64),
ActiveSessions(u64),
QueueDepth(u64),
}
/// Core observability trait — implement for any backend
pub trait Observer: Send + Sync {
/// Record a discrete event
fn record_event(&self, event: &ObserverEvent);
/// Record a numeric metric
fn record_metric(&self, metric: &ObserverMetric);
/// Flush any buffered data (no-op for most backends)
fn flush(&self) {}
/// Human-readable name of this observer
fn name(&self) -> &str;
}

3
src/onboard/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod wizard;
pub use wizard::run_wizard;

1804
src/onboard/wizard.rs Normal file

File diff suppressed because it is too large Load diff

212
src/providers/anthropic.rs Normal file
View file

@ -0,0 +1,212 @@
use crate::providers::traits::Provider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct AnthropicProvider {
api_key: Option<String>,
client: Client,
}
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
max_tokens: u32,
#[serde(skip_serializing_if = "Option::is_none")]
system: Option<String>,
messages: Vec<Message>,
temperature: f64,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
content: Vec<ContentBlock>,
}
#[derive(Debug, Deserialize)]
struct ContentBlock {
text: String,
}
impl AnthropicProvider {
pub fn new(api_key: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
client: Client::new(),
}
}
}
#[async_trait]
impl Provider for AnthropicProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"Anthropic API key not set. Set ANTHROPIC_API_KEY or edit config.toml."
)
})?;
let request = ChatRequest {
model: model.to_string(),
max_tokens: 4096,
system: system_prompt.map(ToString::to_string),
messages: vec![Message {
role: "user".to_string(),
content: message.to_string(),
}],
temperature,
};
let response = self
.client
.post("https://api.anthropic.com/v1/messages")
.header("x-api-key", api_key)
.header("anthropic-version", "2023-06-01")
.header("content-type", "application/json")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("Anthropic API error: {error}");
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.content
.into_iter()
.next()
.map(|c| c.text)
.ok_or_else(|| anyhow::anyhow!("No response from Anthropic"))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn creates_with_key() {
let p = AnthropicProvider::new(Some("sk-ant-test123"));
assert!(p.api_key.is_some());
assert_eq!(p.api_key.as_deref(), Some("sk-ant-test123"));
}
#[test]
fn creates_without_key() {
let p = AnthropicProvider::new(None);
assert!(p.api_key.is_none());
}
#[test]
fn creates_with_empty_key() {
let p = AnthropicProvider::new(Some(""));
assert!(p.api_key.is_some());
assert_eq!(p.api_key.as_deref(), Some(""));
}
#[tokio::test]
async fn chat_fails_without_key() {
let p = AnthropicProvider::new(None);
let result = p.chat_with_system(None, "hello", "claude-3-opus", 0.7).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("API key not set"), "Expected key error, got: {err}");
}
#[tokio::test]
async fn chat_with_system_fails_without_key() {
let p = AnthropicProvider::new(None);
let result = p
.chat_with_system(Some("You are ZeroClaw"), "hello", "claude-3-opus", 0.7)
.await;
assert!(result.is_err());
}
#[test]
fn chat_request_serializes_without_system() {
let req = ChatRequest {
model: "claude-3-opus".to_string(),
max_tokens: 4096,
system: None,
messages: vec![Message {
role: "user".to_string(),
content: "hello".to_string(),
}],
temperature: 0.7,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("system"), "system field should be skipped when None");
assert!(json.contains("claude-3-opus"));
assert!(json.contains("hello"));
}
#[test]
fn chat_request_serializes_with_system() {
let req = ChatRequest {
model: "claude-3-opus".to_string(),
max_tokens: 4096,
system: Some("You are ZeroClaw".to_string()),
messages: vec![Message {
role: "user".to_string(),
content: "hello".to_string(),
}],
temperature: 0.7,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"system\":\"You are ZeroClaw\""));
}
#[test]
fn chat_response_deserializes() {
let json = r#"{"content":[{"type":"text","text":"Hello there!"}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.content.len(), 1);
assert_eq!(resp.content[0].text, "Hello there!");
}
#[test]
fn chat_response_empty_content() {
let json = r#"{"content":[]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.content.is_empty());
}
#[test]
fn chat_response_multiple_blocks() {
let json = r#"{"content":[{"type":"text","text":"First"},{"type":"text","text":"Second"}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.content.len(), 2);
assert_eq!(resp.content[0].text, "First");
assert_eq!(resp.content[1].text, "Second");
}
#[test]
fn temperature_range_serializes() {
for temp in [0.0, 0.5, 1.0, 2.0] {
let req = ChatRequest {
model: "claude-3-opus".to_string(),
max_tokens: 4096,
system: None,
messages: vec![],
temperature: temp,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains(&format!("{temp}")));
}
}
}

245
src/providers/compatible.rs Normal file
View file

@ -0,0 +1,245 @@
//! Generic OpenAI-compatible provider.
//! Most LLM APIs follow the same `/v1/chat/completions` format.
//! This module provides a single implementation that works for all of them.
use crate::providers::traits::Provider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
/// A provider that speaks the OpenAI-compatible chat completions API.
/// Used by: Venice, Vercel AI Gateway, Cloudflare AI Gateway, Moonshot,
/// Synthetic, `OpenCode` Zen, `Z.AI`, `GLM`, `MiniMax`, Bedrock, Qianfan, Groq, Mistral, `xAI`, etc.
pub struct OpenAiCompatibleProvider {
pub(crate) name: String,
pub(crate) base_url: String,
pub(crate) api_key: Option<String>,
pub(crate) auth_header: AuthStyle,
client: Client,
}
/// How the provider expects the API key to be sent.
#[derive(Debug, Clone)]
pub enum AuthStyle {
/// `Authorization: Bearer <key>`
Bearer,
/// `x-api-key: <key>` (used by some Chinese providers)
XApiKey,
/// Custom header name
Custom(String),
}
impl OpenAiCompatibleProvider {
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
Self {
name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string),
auth_header: auth_style,
client: Client::new(),
}
}
}
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
temperature: f64,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
content: String,
}
#[async_trait]
impl Provider for OpenAiCompatibleProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
)
})?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
temperature,
};
let url = format!("{}/v1/chat/completions", self.base_url);
let mut req = self.client.post(&url).json(&request);
match &self.auth_header {
AuthStyle::Bearer => {
req = req.header("Authorization", format!("Bearer {api_key}"));
}
AuthStyle::XApiKey => {
req = req.header("x-api-key", api_key.as_str());
}
AuthStyle::Custom(header) => {
req = req.header(header.as_str(), api_key.as_str());
}
}
let response = req.send().await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("{} API error: {error}", self.name);
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from {}", self.name))
}
}
#[cfg(test)]
mod tests {
use super::*;
fn make_provider(name: &str, url: &str, key: Option<&str>) -> OpenAiCompatibleProvider {
OpenAiCompatibleProvider::new(name, url, key, AuthStyle::Bearer)
}
#[test]
fn creates_with_key() {
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
assert_eq!(p.name, "venice");
assert_eq!(p.base_url, "https://api.venice.ai");
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
}
#[test]
fn creates_without_key() {
let p = make_provider("test", "https://example.com", None);
assert!(p.api_key.is_none());
}
#[test]
fn strips_trailing_slash() {
let p = make_provider("test", "https://example.com/", None);
assert_eq!(p.base_url, "https://example.com");
}
#[tokio::test]
async fn chat_fails_without_key() {
let p = make_provider("Venice", "https://api.venice.ai", None);
let result = p.chat_with_system(None, "hello", "llama-3.3-70b", 0.7).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Venice API key not set"));
}
#[test]
fn request_serializes_correctly() {
let req = ChatRequest {
model: "llama-3.3-70b".to_string(),
messages: vec![
Message { role: "system".to_string(), content: "You are ZeroClaw".to_string() },
Message { role: "user".to_string(), content: "hello".to_string() },
],
temperature: 0.7,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("llama-3.3-70b"));
assert!(json.contains("system"));
assert!(json.contains("user"));
}
#[test]
fn response_deserializes() {
let json = r#"{"choices":[{"message":{"content":"Hello from Venice!"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.content, "Hello from Venice!");
}
#[test]
fn response_empty_choices() {
let json = r#"{"choices":[]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.choices.is_empty());
}
#[test]
fn x_api_key_auth_style() {
let p = OpenAiCompatibleProvider::new(
"moonshot", "https://api.moonshot.cn", Some("ms-key"), AuthStyle::XApiKey,
);
assert!(matches!(p.auth_header, AuthStyle::XApiKey));
}
#[test]
fn custom_auth_style() {
let p = OpenAiCompatibleProvider::new(
"custom", "https://api.example.com", Some("key"), AuthStyle::Custom("X-Custom-Key".into()),
);
assert!(matches!(p.auth_header, AuthStyle::Custom(_)));
}
#[tokio::test]
async fn all_compatible_providers_fail_without_key() {
let providers = vec![
make_provider("Venice", "https://api.venice.ai", None),
make_provider("Moonshot", "https://api.moonshot.cn", None),
make_provider("GLM", "https://open.bigmodel.cn", None),
make_provider("MiniMax", "https://api.minimax.chat", None),
make_provider("Groq", "https://api.groq.com/openai", None),
make_provider("Mistral", "https://api.mistral.ai", None),
make_provider("xAI", "https://api.x.ai", None),
];
for p in providers {
let result = p.chat_with_system(None, "test", "model", 0.7).await;
assert!(result.is_err(), "{} should fail without key", p.name);
assert!(
result.unwrap_err().to_string().contains("API key not set"),
"{} error should mention key", p.name
);
}
}
}

266
src/providers/mod.rs Normal file
View file

@ -0,0 +1,266 @@
pub mod anthropic;
pub mod compatible;
pub mod ollama;
pub mod openai;
pub mod openrouter;
pub mod traits;
pub use traits::Provider;
use compatible::{AuthStyle, OpenAiCompatibleProvider};
/// Factory: create the right provider from config
#[allow(clippy::too_many_lines)]
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
match name {
// ── Primary providers (custom implementations) ───────
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(api_key))),
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(api_key))),
"openai" => Ok(Box::new(openai::OpenAiProvider::new(api_key))),
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(
api_key.filter(|k| !k.is_empty()),
))),
// ── OpenAI-compatible providers ──────────────────────
"venice" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Venice", "https://api.venice.ai", api_key, AuthStyle::Bearer,
))),
"vercel" | "vercel-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Vercel AI Gateway", "https://api.vercel.ai", api_key, AuthStyle::Bearer,
))),
"cloudflare" | "cloudflare-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cloudflare AI Gateway",
"https://gateway.ai.cloudflare.com/v1",
api_key,
AuthStyle::Bearer,
))),
"moonshot" | "kimi" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Moonshot", "https://api.moonshot.cn", api_key, AuthStyle::Bearer,
))),
"synthetic" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Synthetic", "https://api.synthetic.com", api_key, AuthStyle::Bearer,
))),
"opencode" | "opencode-zen" => Ok(Box::new(OpenAiCompatibleProvider::new(
"OpenCode Zen", "https://api.opencode.ai", api_key, AuthStyle::Bearer,
))),
"zai" | "z.ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Z.AI", "https://api.z.ai", api_key, AuthStyle::Bearer,
))),
"glm" | "zhipu" => Ok(Box::new(OpenAiCompatibleProvider::new(
"GLM", "https://open.bigmodel.cn/api/paas", api_key, AuthStyle::Bearer,
))),
"minimax" => Ok(Box::new(OpenAiCompatibleProvider::new(
"MiniMax", "https://api.minimax.chat", api_key, AuthStyle::Bearer,
))),
"bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Amazon Bedrock",
"https://bedrock-runtime.us-east-1.amazonaws.com",
api_key,
AuthStyle::Bearer,
))),
"qianfan" | "baidu" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Qianfan", "https://aip.baidubce.com", api_key, AuthStyle::Bearer,
))),
// ── Extended ecosystem (community favorites) ─────────
"groq" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Groq", "https://api.groq.com/openai", api_key, AuthStyle::Bearer,
))),
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Mistral", "https://api.mistral.ai", api_key, AuthStyle::Bearer,
))),
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
"xAI", "https://api.x.ai", api_key, AuthStyle::Bearer,
))),
"deepseek" => Ok(Box::new(OpenAiCompatibleProvider::new(
"DeepSeek", "https://api.deepseek.com", api_key, AuthStyle::Bearer,
))),
"together" | "together-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Together AI", "https://api.together.xyz", api_key, AuthStyle::Bearer,
))),
"fireworks" | "fireworks-ai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Fireworks AI", "https://api.fireworks.ai/inference", api_key, AuthStyle::Bearer,
))),
"perplexity" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Perplexity", "https://api.perplexity.ai", api_key, AuthStyle::Bearer,
))),
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cohere", "https://api.cohere.com/compatibility", api_key, AuthStyle::Bearer,
))),
_ => anyhow::bail!(
"Unknown provider: {name}. Run `zeroclaw integrations list -c ai` to see all available providers."
),
}
}
#[cfg(test)]
mod tests {
use super::*;
// ── Primary providers ────────────────────────────────────
#[test]
fn factory_openrouter() {
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
assert!(create_provider("openrouter", None).is_ok());
}
#[test]
fn factory_anthropic() {
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
}
#[test]
fn factory_openai() {
assert!(create_provider("openai", Some("sk-test")).is_ok());
}
#[test]
fn factory_ollama() {
assert!(create_provider("ollama", None).is_ok());
}
// ── OpenAI-compatible providers ──────────────────────────
#[test]
fn factory_venice() {
assert!(create_provider("venice", Some("vn-key")).is_ok());
}
#[test]
fn factory_vercel() {
assert!(create_provider("vercel", Some("key")).is_ok());
assert!(create_provider("vercel-ai", Some("key")).is_ok());
}
#[test]
fn factory_cloudflare() {
assert!(create_provider("cloudflare", Some("key")).is_ok());
assert!(create_provider("cloudflare-ai", Some("key")).is_ok());
}
#[test]
fn factory_moonshot() {
assert!(create_provider("moonshot", Some("key")).is_ok());
assert!(create_provider("kimi", Some("key")).is_ok());
}
#[test]
fn factory_synthetic() {
assert!(create_provider("synthetic", Some("key")).is_ok());
}
#[test]
fn factory_opencode() {
assert!(create_provider("opencode", Some("key")).is_ok());
assert!(create_provider("opencode-zen", Some("key")).is_ok());
}
#[test]
fn factory_zai() {
assert!(create_provider("zai", Some("key")).is_ok());
assert!(create_provider("z.ai", Some("key")).is_ok());
}
#[test]
fn factory_glm() {
assert!(create_provider("glm", Some("key")).is_ok());
assert!(create_provider("zhipu", Some("key")).is_ok());
}
#[test]
fn factory_minimax() {
assert!(create_provider("minimax", Some("key")).is_ok());
}
#[test]
fn factory_bedrock() {
assert!(create_provider("bedrock", Some("key")).is_ok());
assert!(create_provider("aws-bedrock", Some("key")).is_ok());
}
#[test]
fn factory_qianfan() {
assert!(create_provider("qianfan", Some("key")).is_ok());
assert!(create_provider("baidu", Some("key")).is_ok());
}
// ── Extended ecosystem ───────────────────────────────────
#[test]
fn factory_groq() {
assert!(create_provider("groq", Some("key")).is_ok());
}
#[test]
fn factory_mistral() {
assert!(create_provider("mistral", Some("key")).is_ok());
}
#[test]
fn factory_xai() {
assert!(create_provider("xai", Some("key")).is_ok());
assert!(create_provider("grok", Some("key")).is_ok());
}
#[test]
fn factory_deepseek() {
assert!(create_provider("deepseek", Some("key")).is_ok());
}
#[test]
fn factory_together() {
assert!(create_provider("together", Some("key")).is_ok());
assert!(create_provider("together-ai", Some("key")).is_ok());
}
#[test]
fn factory_fireworks() {
assert!(create_provider("fireworks", Some("key")).is_ok());
assert!(create_provider("fireworks-ai", Some("key")).is_ok());
}
#[test]
fn factory_perplexity() {
assert!(create_provider("perplexity", Some("key")).is_ok());
}
#[test]
fn factory_cohere() {
assert!(create_provider("cohere", Some("key")).is_ok());
}
// ── Error cases ──────────────────────────────────────────
#[test]
fn factory_unknown_provider_errors() {
let p = create_provider("nonexistent", None);
assert!(p.is_err());
let msg = p.err().unwrap().to_string();
assert!(msg.contains("Unknown provider"));
assert!(msg.contains("nonexistent"));
}
#[test]
fn factory_empty_name_errors() {
assert!(create_provider("", None).is_err());
}
#[test]
fn factory_all_providers_create_successfully() {
let providers = [
"openrouter", "anthropic", "openai", "ollama",
"venice", "vercel", "cloudflare", "moonshot", "synthetic",
"opencode", "zai", "glm", "minimax", "bedrock", "qianfan",
"groq", "mistral", "xai", "deepseek", "together",
"fireworks", "perplexity", "cohere",
];
for name in providers {
assert!(
create_provider(name, Some("test-key")).is_ok(),
"Provider '{name}' should create successfully"
);
}
}
}

177
src/providers/ollama.rs Normal file
View file

@ -0,0 +1,177 @@
use crate::providers::traits::Provider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OllamaProvider {
base_url: String,
client: Client,
}
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
stream: bool,
options: Options,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Serialize)]
struct Options {
temperature: f64,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
content: String,
}
impl OllamaProvider {
pub fn new(base_url: Option<&str>) -> Self {
Self {
base_url: base_url
.unwrap_or("http://localhost:11434")
.trim_end_matches('/')
.to_string(),
client: Client::new(),
}
}
}
#[async_trait]
impl Provider for OllamaProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
stream: false,
options: Options { temperature },
};
let url = format!("{}/api/chat", self.base_url);
let response = self.client.post(&url).json(&request).send().await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("Ollama error: {error}. Is Ollama running? (brew install ollama && ollama serve)");
}
let chat_response: ChatResponse = response.json().await?;
Ok(chat_response.message.content)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_url() {
let p = OllamaProvider::new(None);
assert_eq!(p.base_url, "http://localhost:11434");
}
#[test]
fn custom_url_trailing_slash() {
let p = OllamaProvider::new(Some("http://192.168.1.100:11434/"));
assert_eq!(p.base_url, "http://192.168.1.100:11434");
}
#[test]
fn custom_url_no_trailing_slash() {
let p = OllamaProvider::new(Some("http://myserver:11434"));
assert_eq!(p.base_url, "http://myserver:11434");
}
#[test]
fn empty_url_uses_empty() {
let p = OllamaProvider::new(Some(""));
assert_eq!(p.base_url, "");
}
#[test]
fn request_serializes_with_system() {
let req = ChatRequest {
model: "llama3".to_string(),
messages: vec![
Message { role: "system".to_string(), content: "You are ZeroClaw".to_string() },
Message { role: "user".to_string(), content: "hello".to_string() },
],
stream: false,
options: Options { temperature: 0.7 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"stream\":false"));
assert!(json.contains("llama3"));
assert!(json.contains("system"));
assert!(json.contains("\"temperature\":0.7"));
}
#[test]
fn request_serializes_without_system() {
let req = ChatRequest {
model: "mistral".to_string(),
messages: vec![
Message { role: "user".to_string(), content: "test".to_string() },
],
stream: false,
options: Options { temperature: 0.0 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("\"role\":\"system\""));
assert!(json.contains("mistral"));
}
#[test]
fn response_deserializes() {
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.message.content, "Hello from Ollama!");
}
#[test]
fn response_with_empty_content() {
let json = r#"{"message":{"role":"assistant","content":""}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.is_empty());
}
#[test]
fn response_with_multiline() {
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.contains("line1"));
}
}

211
src/providers/openai.rs Normal file
View file

@ -0,0 +1,211 @@
use crate::providers::traits::Provider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenAiProvider {
api_key: Option<String>,
client: Client,
}
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
temperature: f64,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
content: String,
}
impl OpenAiProvider {
pub fn new(api_key: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
client: Client::new(),
}
}
}
#[async_trait]
impl Provider for OpenAiProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
temperature,
};
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("OpenAI API error: {error}");
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn creates_with_key() {
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
}
#[test]
fn creates_without_key() {
let p = OpenAiProvider::new(None);
assert!(p.api_key.is_none());
}
#[test]
fn creates_with_empty_key() {
let p = OpenAiProvider::new(Some(""));
assert_eq!(p.api_key.as_deref(), Some(""));
}
#[tokio::test]
async fn chat_fails_without_key() {
let p = OpenAiProvider::new(None);
let result = p.chat_with_system(None, "hello", "gpt-4o", 0.7).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("API key not set"));
}
#[tokio::test]
async fn chat_with_system_fails_without_key() {
let p = OpenAiProvider::new(None);
let result = p
.chat_with_system(Some("You are ZeroClaw"), "test", "gpt-4o", 0.5)
.await;
assert!(result.is_err());
}
#[test]
fn request_serializes_with_system_message() {
let req = ChatRequest {
model: "gpt-4o".to_string(),
messages: vec![
Message { role: "system".to_string(), content: "You are ZeroClaw".to_string() },
Message { role: "user".to_string(), content: "hello".to_string() },
],
temperature: 0.7,
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"role\":\"system\""));
assert!(json.contains("\"role\":\"user\""));
assert!(json.contains("gpt-4o"));
}
#[test]
fn request_serializes_without_system() {
let req = ChatRequest {
model: "gpt-4o".to_string(),
messages: vec![
Message { role: "user".to_string(), content: "hello".to_string() },
],
temperature: 0.0,
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("system"));
assert!(json.contains("\"temperature\":0.0"));
}
#[test]
fn response_deserializes_single_choice() {
let json = r#"{"choices":[{"message":{"content":"Hi!"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices.len(), 1);
assert_eq!(resp.choices[0].message.content, "Hi!");
}
#[test]
fn response_deserializes_empty_choices() {
let json = r#"{"choices":[]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.choices.is_empty());
}
#[test]
fn response_deserializes_multiple_choices() {
let json = r#"{"choices":[{"message":{"content":"A"}},{"message":{"content":"B"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices.len(), 2);
assert_eq!(resp.choices[0].message.content, "A");
}
#[test]
fn response_with_unicode() {
let json = r#"{"choices":[{"message":{"content":"こんにちは 🦀"}}]}"#;
let resp: ChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.choices[0].message.content, "こんにちは 🦀");
}
#[test]
fn response_with_long_content() {
let long = "x".repeat(100_000);
let json = format!(r#"{{"choices":[{{"message":{{"content":"{long}"}}}}]}}"#);
let resp: ChatResponse = serde_json::from_str(&json).unwrap();
assert_eq!(resp.choices[0].message.content.len(), 100_000);
}
}

107
src/providers/openrouter.rs Normal file
View file

@ -0,0 +1,107 @@
use crate::providers::traits::Provider;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenRouterProvider {
api_key: Option<String>,
client: Client,
}
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
messages: Vec<Message>,
temperature: f64,
}
#[derive(Debug, Serialize)]
struct Message {
role: String,
content: String,
}
#[derive(Debug, Deserialize)]
struct ChatResponse {
choices: Vec<Choice>,
}
#[derive(Debug, Deserialize)]
struct Choice {
message: ResponseMessage,
}
#[derive(Debug, Deserialize)]
struct ResponseMessage {
content: String,
}
impl OpenRouterProvider {
pub fn new(api_key: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
client: Client::new(),
}
}
}
#[async_trait]
impl Provider for OpenRouterProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let mut messages = Vec::new();
if let Some(sys) = system_prompt {
messages.push(Message {
role: "system".to_string(),
content: sys.to_string(),
});
}
messages.push(Message {
role: "user".to_string(),
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
temperature,
};
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
)
.header("X-Title", "ZeroClaw")
.json(&request)
.send()
.await?;
if !response.status().is_success() {
let error = response.text().await?;
anyhow::bail!("OpenRouter API error: {error}");
}
let chat_response: ChatResponse = response.json().await?;
chat_response
.choices
.into_iter()
.next()
.map(|c| c.message.content)
.ok_or_else(|| anyhow::anyhow!("No response from OpenRouter"))
}
}

22
src/providers/traits.rs Normal file
View file

@ -0,0 +1,22 @@
use async_trait::async_trait;
#[async_trait]
pub trait Provider: Send + Sync {
async fn chat(
&self,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
self.chat_with_system(None, message, model, temperature)
.await
}
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String>;
}

71
src/runtime/mod.rs Normal file
View file

@ -0,0 +1,71 @@
pub mod native;
pub mod traits;
pub use native::NativeRuntime;
pub use traits::RuntimeAdapter;
use crate::config::RuntimeConfig;
/// Factory: create the right runtime from config
pub fn create_runtime(config: &RuntimeConfig) -> Box<dyn RuntimeAdapter> {
match config.kind.as_str() {
"native" | "docker" => Box::new(NativeRuntime::new()),
"cloudflare" => {
tracing::warn!("Cloudflare runtime not yet implemented, falling back to native");
Box::new(NativeRuntime::new())
}
_ => {
tracing::warn!("Unknown runtime '{}', falling back to native", config.kind);
Box::new(NativeRuntime::new())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn factory_native() {
let cfg = RuntimeConfig {
kind: "native".into(),
};
let rt = create_runtime(&cfg);
assert_eq!(rt.name(), "native");
assert!(rt.has_shell_access());
}
#[test]
fn factory_docker_returns_native() {
let cfg = RuntimeConfig {
kind: "docker".into(),
};
let rt = create_runtime(&cfg);
assert_eq!(rt.name(), "native");
}
#[test]
fn factory_cloudflare_falls_back() {
let cfg = RuntimeConfig {
kind: "cloudflare".into(),
};
let rt = create_runtime(&cfg);
assert_eq!(rt.name(), "native");
}
#[test]
fn factory_unknown_falls_back() {
let cfg = RuntimeConfig {
kind: "wasm-edge-unknown".into(),
};
let rt = create_runtime(&cfg);
assert_eq!(rt.name(), "native");
}
#[test]
fn factory_empty_falls_back() {
let cfg = RuntimeConfig { kind: "".into() };
let rt = create_runtime(&cfg);
assert_eq!(rt.name(), "native");
}
}

72
src/runtime/native.rs Normal file
View file

@ -0,0 +1,72 @@
use super::traits::RuntimeAdapter;
use std::path::PathBuf;
/// Native runtime — full access, runs on Mac/Linux/Docker/Raspberry Pi
pub struct NativeRuntime;
impl NativeRuntime {
pub fn new() -> Self {
Self
}
}
impl RuntimeAdapter for NativeRuntime {
fn name(&self) -> &str {
"native"
}
fn has_shell_access(&self) -> bool {
true
}
fn has_filesystem_access(&self) -> bool {
true
}
fn storage_path(&self) -> PathBuf {
directories::UserDirs::new().map_or_else(
|| PathBuf::from(".zeroclaw"),
|u| u.home_dir().join(".zeroclaw"),
)
}
fn supports_long_running(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn native_name() {
assert_eq!(NativeRuntime::new().name(), "native");
}
#[test]
fn native_has_shell_access() {
assert!(NativeRuntime::new().has_shell_access());
}
#[test]
fn native_has_filesystem_access() {
assert!(NativeRuntime::new().has_filesystem_access());
}
#[test]
fn native_supports_long_running() {
assert!(NativeRuntime::new().supports_long_running());
}
#[test]
fn native_memory_budget_unlimited() {
assert_eq!(NativeRuntime::new().memory_budget(), 0);
}
#[test]
fn native_storage_path_contains_zeroclaw() {
let path = NativeRuntime::new().storage_path();
assert!(path.to_string_lossy().contains("zeroclaw"));
}
}

25
src/runtime/traits.rs Normal file
View file

@ -0,0 +1,25 @@
use std::path::PathBuf;
/// Runtime adapter — abstracts platform differences so the same agent
/// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc.
pub trait RuntimeAdapter: Send + Sync {
/// Human-readable runtime name
fn name(&self) -> &str;
/// Whether this runtime supports shell access
fn has_shell_access(&self) -> bool;
/// Whether this runtime supports filesystem access
fn has_filesystem_access(&self) -> bool;
/// Base storage path for this runtime
fn storage_path(&self) -> PathBuf;
/// Whether long-running processes (gateway, heartbeat) are supported
fn supports_long_running(&self) -> bool;
/// Maximum memory budget in bytes (0 = unlimited)
fn memory_budget(&self) -> u64 {
0
}
}

3
src/security/mod.rs Normal file
View file

@ -0,0 +1,3 @@
pub mod policy;
pub use policy::{AutonomyLevel, SecurityPolicy};

365
src/security/policy.rs Normal file
View file

@ -0,0 +1,365 @@
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
/// How much autonomy the agent has
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum AutonomyLevel {
/// Read-only: can observe but not act
ReadOnly,
/// Supervised: acts but requires approval for risky operations
Supervised,
/// Full: autonomous execution within policy bounds
Full,
}
impl Default for AutonomyLevel {
fn default() -> Self {
Self::Supervised
}
}
/// Security policy enforced on all tool executions
#[derive(Debug, Clone)]
pub struct SecurityPolicy {
pub autonomy: AutonomyLevel,
pub workspace_dir: PathBuf,
pub workspace_only: bool,
pub allowed_commands: Vec<String>,
pub forbidden_paths: Vec<String>,
pub max_actions_per_hour: u32,
pub max_cost_per_day_cents: u32,
}
impl Default for SecurityPolicy {
fn default() -> Self {
Self {
autonomy: AutonomyLevel::Supervised,
workspace_dir: PathBuf::from("."),
workspace_only: true,
allowed_commands: vec![
"git".into(),
"npm".into(),
"cargo".into(),
"ls".into(),
"cat".into(),
"grep".into(),
"find".into(),
"echo".into(),
"pwd".into(),
"wc".into(),
"head".into(),
"tail".into(),
],
forbidden_paths: vec![
"/etc".into(),
"/root".into(),
"~/.ssh".into(),
"~/.gnupg".into(),
"/var/run".into(),
],
max_actions_per_hour: 20,
max_cost_per_day_cents: 500,
}
}
}
impl SecurityPolicy {
/// Check if a shell command is allowed
pub fn is_command_allowed(&self, command: &str) -> bool {
if self.autonomy == AutonomyLevel::ReadOnly {
return false;
}
// Extract the base command (first word)
let base_cmd = command
.split_whitespace()
.next()
.unwrap_or("")
.rsplit('/')
.next()
.unwrap_or("");
self.allowed_commands
.iter()
.any(|allowed| allowed == base_cmd)
}
/// Check if a file path is allowed (no path traversal, within workspace)
pub fn is_path_allowed(&self, path: &str) -> bool {
// Block obvious traversal attempts
if path.contains("..") {
return false;
}
// Block absolute paths when workspace_only is set
if self.workspace_only && Path::new(path).is_absolute() {
return false;
}
// Block forbidden paths
for forbidden in &self.forbidden_paths {
if path.starts_with(forbidden.as_str()) {
return false;
}
}
true
}
/// Check if autonomy level permits any action at all
pub fn can_act(&self) -> bool {
self.autonomy != AutonomyLevel::ReadOnly
}
/// Build from config sections
pub fn from_config(
autonomy_config: &crate::config::AutonomyConfig,
workspace_dir: &Path,
) -> Self {
Self {
autonomy: autonomy_config.level,
workspace_dir: workspace_dir.to_path_buf(),
workspace_only: autonomy_config.workspace_only,
allowed_commands: autonomy_config.allowed_commands.clone(),
forbidden_paths: autonomy_config.forbidden_paths.clone(),
max_actions_per_hour: autonomy_config.max_actions_per_hour,
max_cost_per_day_cents: autonomy_config.max_cost_per_day_cents,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
fn default_policy() -> SecurityPolicy {
SecurityPolicy::default()
}
fn readonly_policy() -> SecurityPolicy {
SecurityPolicy {
autonomy: AutonomyLevel::ReadOnly,
..SecurityPolicy::default()
}
}
fn full_policy() -> SecurityPolicy {
SecurityPolicy {
autonomy: AutonomyLevel::Full,
..SecurityPolicy::default()
}
}
// ── AutonomyLevel ────────────────────────────────────────
#[test]
fn autonomy_default_is_supervised() {
assert_eq!(AutonomyLevel::default(), AutonomyLevel::Supervised);
}
#[test]
fn autonomy_serde_roundtrip() {
let json = serde_json::to_string(&AutonomyLevel::Full).unwrap();
assert_eq!(json, "\"full\"");
let parsed: AutonomyLevel = serde_json::from_str("\"readonly\"").unwrap();
assert_eq!(parsed, AutonomyLevel::ReadOnly);
let parsed2: AutonomyLevel = serde_json::from_str("\"supervised\"").unwrap();
assert_eq!(parsed2, AutonomyLevel::Supervised);
}
#[test]
fn can_act_readonly_false() {
assert!(!readonly_policy().can_act());
}
#[test]
fn can_act_supervised_true() {
assert!(default_policy().can_act());
}
#[test]
fn can_act_full_true() {
assert!(full_policy().can_act());
}
// ── is_command_allowed ───────────────────────────────────
#[test]
fn allowed_commands_basic() {
let p = default_policy();
assert!(p.is_command_allowed("ls"));
assert!(p.is_command_allowed("git status"));
assert!(p.is_command_allowed("cargo build --release"));
assert!(p.is_command_allowed("cat file.txt"));
assert!(p.is_command_allowed("grep -r pattern ."));
}
#[test]
fn blocked_commands_basic() {
let p = default_policy();
assert!(!p.is_command_allowed("rm -rf /"));
assert!(!p.is_command_allowed("sudo apt install"));
assert!(!p.is_command_allowed("curl http://evil.com"));
assert!(!p.is_command_allowed("wget http://evil.com"));
assert!(!p.is_command_allowed("python3 exploit.py"));
assert!(!p.is_command_allowed("node malicious.js"));
}
#[test]
fn readonly_blocks_all_commands() {
let p = readonly_policy();
assert!(!p.is_command_allowed("ls"));
assert!(!p.is_command_allowed("cat file.txt"));
assert!(!p.is_command_allowed("echo hello"));
}
#[test]
fn full_autonomy_still_uses_allowlist() {
let p = full_policy();
assert!(p.is_command_allowed("ls"));
assert!(!p.is_command_allowed("rm -rf /"));
}
#[test]
fn command_with_absolute_path_extracts_basename() {
let p = default_policy();
assert!(p.is_command_allowed("/usr/bin/git status"));
assert!(p.is_command_allowed("/bin/ls -la"));
}
#[test]
fn empty_command_blocked() {
let p = default_policy();
assert!(!p.is_command_allowed(""));
assert!(!p.is_command_allowed(" "));
}
#[test]
fn command_with_pipes_uses_first_word() {
let p = default_policy();
assert!(p.is_command_allowed("ls | grep foo"));
assert!(p.is_command_allowed("cat file.txt | wc -l"));
}
#[test]
fn custom_allowlist() {
let p = SecurityPolicy {
allowed_commands: vec!["docker".into(), "kubectl".into()],
..SecurityPolicy::default()
};
assert!(p.is_command_allowed("docker ps"));
assert!(p.is_command_allowed("kubectl get pods"));
assert!(!p.is_command_allowed("ls"));
assert!(!p.is_command_allowed("git status"));
}
#[test]
fn empty_allowlist_blocks_everything() {
let p = SecurityPolicy {
allowed_commands: vec![],
..SecurityPolicy::default()
};
assert!(!p.is_command_allowed("ls"));
assert!(!p.is_command_allowed("echo hello"));
}
// ── is_path_allowed ─────────────────────────────────────
#[test]
fn relative_paths_allowed() {
let p = default_policy();
assert!(p.is_path_allowed("file.txt"));
assert!(p.is_path_allowed("src/main.rs"));
assert!(p.is_path_allowed("deep/nested/dir/file.txt"));
}
#[test]
fn path_traversal_blocked() {
let p = default_policy();
assert!(!p.is_path_allowed("../etc/passwd"));
assert!(!p.is_path_allowed("../../root/.ssh/id_rsa"));
assert!(!p.is_path_allowed("foo/../../../etc/shadow"));
assert!(!p.is_path_allowed(".."));
}
#[test]
fn absolute_paths_blocked_when_workspace_only() {
let p = default_policy();
assert!(!p.is_path_allowed("/etc/passwd"));
assert!(!p.is_path_allowed("/root/.ssh/id_rsa"));
assert!(!p.is_path_allowed("/tmp/file.txt"));
}
#[test]
fn absolute_paths_allowed_when_not_workspace_only() {
let p = SecurityPolicy {
workspace_only: false,
forbidden_paths: vec![],
..SecurityPolicy::default()
};
assert!(p.is_path_allowed("/tmp/file.txt"));
}
#[test]
fn forbidden_paths_blocked() {
let p = SecurityPolicy {
workspace_only: false,
..SecurityPolicy::default()
};
assert!(!p.is_path_allowed("/etc/passwd"));
assert!(!p.is_path_allowed("/root/.bashrc"));
assert!(!p.is_path_allowed("~/.ssh/id_rsa"));
assert!(!p.is_path_allowed("~/.gnupg/pubring.kbx"));
}
#[test]
fn empty_path_allowed() {
let p = default_policy();
assert!(p.is_path_allowed(""));
}
#[test]
fn dotfile_in_workspace_allowed() {
let p = default_policy();
assert!(p.is_path_allowed(".gitignore"));
assert!(p.is_path_allowed(".env"));
}
// ── from_config ─────────────────────────────────────────
#[test]
fn from_config_maps_all_fields() {
let autonomy_config = crate::config::AutonomyConfig {
level: AutonomyLevel::Full,
workspace_only: false,
allowed_commands: vec!["docker".into()],
forbidden_paths: vec!["/secret".into()],
max_actions_per_hour: 100,
max_cost_per_day_cents: 1000,
};
let workspace = PathBuf::from("/tmp/test-workspace");
let policy = SecurityPolicy::from_config(&autonomy_config, &workspace);
assert_eq!(policy.autonomy, AutonomyLevel::Full);
assert!(!policy.workspace_only);
assert_eq!(policy.allowed_commands, vec!["docker"]);
assert_eq!(policy.forbidden_paths, vec!["/secret"]);
assert_eq!(policy.max_actions_per_hour, 100);
assert_eq!(policy.max_cost_per_day_cents, 1000);
assert_eq!(policy.workspace_dir, PathBuf::from("/tmp/test-workspace"));
}
// ── Default policy ──────────────────────────────────────
#[test]
fn default_policy_has_sane_values() {
let p = SecurityPolicy::default();
assert_eq!(p.autonomy, AutonomyLevel::Supervised);
assert!(p.workspace_only);
assert!(!p.allowed_commands.is_empty());
assert!(!p.forbidden_paths.is_empty());
assert!(p.max_actions_per_hour > 0);
assert!(p.max_cost_per_day_cents > 0);
}
}

615
src/skills/mod.rs Normal file
View file

@ -0,0 +1,615 @@
use anyhow::Result;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// A skill is a user-defined or community-built capability.
/// Skills live in `~/.zeroclaw/workspace/skills/<name>/SKILL.md`
/// and can include tool definitions, prompts, and automation scripts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Skill {
pub name: String,
pub description: String,
pub version: String,
#[serde(default)]
pub author: Option<String>,
#[serde(default)]
pub tags: Vec<String>,
#[serde(default)]
pub tools: Vec<SkillTool>,
#[serde(default)]
pub prompts: Vec<String>,
}
/// A tool defined by a skill (shell command, HTTP call, etc.)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SkillTool {
pub name: String,
pub description: String,
/// "shell", "http", "script"
pub kind: String,
/// The command/URL/script to execute
pub command: String,
#[serde(default)]
pub args: HashMap<String, String>,
}
/// Skill manifest parsed from SKILL.toml
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SkillManifest {
skill: SkillMeta,
#[serde(default)]
tools: Vec<SkillTool>,
#[serde(default)]
prompts: Vec<String>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
struct SkillMeta {
name: String,
description: String,
#[serde(default = "default_version")]
version: String,
#[serde(default)]
author: Option<String>,
#[serde(default)]
tags: Vec<String>,
}
fn default_version() -> String {
"0.1.0".to_string()
}
/// Load all skills from the workspace skills directory
pub fn load_skills(workspace_dir: &Path) -> Vec<Skill> {
let skills_dir = workspace_dir.join("skills");
if !skills_dir.exists() {
return Vec::new();
}
let mut skills = Vec::new();
let Ok(entries) = std::fs::read_dir(&skills_dir) else {
return skills;
};
for entry in entries.flatten() {
let path = entry.path();
if !path.is_dir() {
continue;
}
// Try SKILL.toml first, then SKILL.md
let manifest_path = path.join("SKILL.toml");
let md_path = path.join("SKILL.md");
if manifest_path.exists() {
if let Ok(skill) = load_skill_toml(&manifest_path) {
skills.push(skill);
}
} else if md_path.exists() {
if let Ok(skill) = load_skill_md(&md_path, &path) {
skills.push(skill);
}
}
}
skills
}
/// Load a skill from a SKILL.toml manifest
fn load_skill_toml(path: &Path) -> Result<Skill> {
let content = std::fs::read_to_string(path)?;
let manifest: SkillManifest = toml::from_str(&content)?;
Ok(Skill {
name: manifest.skill.name,
description: manifest.skill.description,
version: manifest.skill.version,
author: manifest.skill.author,
tags: manifest.skill.tags,
tools: manifest.tools,
prompts: manifest.prompts,
})
}
/// Load a skill from a SKILL.md file (simpler format)
fn load_skill_md(path: &Path, dir: &Path) -> Result<Skill> {
let content = std::fs::read_to_string(path)?;
let name = dir
.file_name()
.and_then(|n| n.to_str())
.unwrap_or("unknown")
.to_string();
// Extract description from first non-heading line
let description = content
.lines()
.find(|l| !l.starts_with('#') && !l.trim().is_empty())
.unwrap_or("No description")
.trim()
.to_string();
Ok(Skill {
name,
description,
version: "0.1.0".to_string(),
author: None,
tags: Vec::new(),
tools: Vec::new(),
prompts: vec![content],
})
}
/// Build a system prompt addition from all loaded skills
pub fn skills_to_prompt(skills: &[Skill]) -> String {
use std::fmt::Write;
if skills.is_empty() {
return String::new();
}
let mut prompt = String::from("\n## Active Skills\n\n");
for skill in skills {
let _ = writeln!(prompt, "### {} (v{})", skill.name, skill.version);
let _ = writeln!(prompt, "{}", skill.description);
if !skill.tools.is_empty() {
prompt.push_str("Tools:\n");
for tool in &skill.tools {
let _ = writeln!(prompt, "- **{}**: {} ({})", tool.name, tool.description, tool.kind);
}
}
for p in &skill.prompts {
prompt.push_str(p);
prompt.push('\n');
}
prompt.push('\n');
}
prompt
}
/// Get the skills directory path
pub fn skills_dir(workspace_dir: &Path) -> PathBuf {
workspace_dir.join("skills")
}
/// Initialize the skills directory with a README
pub fn init_skills_dir(workspace_dir: &Path) -> Result<()> {
let dir = skills_dir(workspace_dir);
std::fs::create_dir_all(&dir)?;
let readme = dir.join("README.md");
if !readme.exists() {
std::fs::write(
&readme,
"# ZeroClaw Skills\n\n\
Each subdirectory is a skill. Create a `SKILL.toml` or `SKILL.md` file inside.\n\n\
## SKILL.toml format\n\n\
```toml\n\
[skill]\n\
name = \"my-skill\"\n\
description = \"What this skill does\"\n\
version = \"0.1.0\"\n\
author = \"your-name\"\n\
tags = [\"productivity\", \"automation\"]\n\n\
[[tools]]\n\
name = \"my_tool\"\n\
description = \"What this tool does\"\n\
kind = \"shell\"\n\
command = \"echo hello\"\n\
```\n\n\
## SKILL.md format (simpler)\n\n\
Just write a markdown file with instructions for the agent.\n\
The agent will read it and follow the instructions.\n\n\
## Installing community skills\n\n\
```bash\n\
zeroclaw skills install <github-url>\n\
zeroclaw skills list\n\
```\n",
)?;
}
Ok(())
}
/// Handle the `skills` CLI command
pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Result<()> {
match command {
super::SkillCommands::List => {
let skills = load_skills(workspace_dir);
if skills.is_empty() {
println!("No skills installed.");
println!();
println!(" Create one: mkdir -p ~/.zeroclaw/workspace/skills/my-skill");
println!(" echo '# My Skill' > ~/.zeroclaw/workspace/skills/my-skill/SKILL.md");
println!();
println!(" Or install: zeroclaw skills install <github-url>");
} else {
println!("Installed skills ({}):", skills.len());
println!();
for skill in &skills {
println!(
" {} {} — {}",
console::style(&skill.name).white().bold(),
console::style(format!("v{}", skill.version)).dim(),
skill.description
);
if !skill.tools.is_empty() {
println!(
" Tools: {}",
skill.tools.iter().map(|t| t.name.as_str()).collect::<Vec<_>>().join(", ")
);
}
if !skill.tags.is_empty() {
println!(
" Tags: {}",
skill.tags.join(", ")
);
}
}
}
println!();
Ok(())
}
super::SkillCommands::Install { source } => {
println!("Installing skill from: {source}");
let skills_path = skills_dir(workspace_dir);
std::fs::create_dir_all(&skills_path)?;
if source.starts_with("http") || source.contains("github.com") {
// Git clone
let output = std::process::Command::new("git")
.args(["clone", "--depth", "1", &source])
.current_dir(&skills_path)
.output()?;
if output.status.success() {
println!(" {} Skill installed successfully!", console::style("").green().bold());
println!(" Restart `zeroclaw channel start` to activate.");
} else {
let stderr = String::from_utf8_lossy(&output.stderr);
anyhow::bail!("Git clone failed: {stderr}");
}
} else {
// Local path — symlink or copy
let src = PathBuf::from(&source);
if !src.exists() {
anyhow::bail!("Source path does not exist: {source}");
}
let name = src.file_name().unwrap_or_default();
let dest = skills_path.join(name);
#[cfg(unix)]
std::os::unix::fs::symlink(&src, &dest)?;
#[cfg(not(unix))]
{
// On non-unix, copy the directory
anyhow::bail!("Symlink not supported on this platform. Copy the skill directory manually.");
}
println!(" {} Skill linked: {}", console::style("").green().bold(), dest.display());
}
Ok(())
}
super::SkillCommands::Remove { name } => {
let skill_path = skills_dir(workspace_dir).join(&name);
if !skill_path.exists() {
anyhow::bail!("Skill not found: {name}");
}
std::fs::remove_dir_all(&skill_path)?;
println!(" {} Skill '{}' removed.", console::style("").green().bold(), name);
Ok(())
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
#[test]
fn load_empty_skills_dir() {
let dir = tempfile::tempdir().unwrap();
let skills = load_skills(dir.path());
assert!(skills.is_empty());
}
#[test]
fn load_skill_from_toml() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let skill_dir = skills_dir.join("test-skill");
fs::create_dir_all(&skill_dir).unwrap();
fs::write(
skill_dir.join("SKILL.toml"),
r#"
[skill]
name = "test-skill"
description = "A test skill"
version = "1.0.0"
tags = ["test"]
[[tools]]
name = "hello"
description = "Says hello"
kind = "shell"
command = "echo hello"
"#,
)
.unwrap();
let skills = load_skills(dir.path());
assert_eq!(skills.len(), 1);
assert_eq!(skills[0].name, "test-skill");
assert_eq!(skills[0].tools.len(), 1);
assert_eq!(skills[0].tools[0].name, "hello");
}
#[test]
fn load_skill_from_md() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let skill_dir = skills_dir.join("md-skill");
fs::create_dir_all(&skill_dir).unwrap();
fs::write(
skill_dir.join("SKILL.md"),
"# My Skill\nThis skill does cool things.\n",
)
.unwrap();
let skills = load_skills(dir.path());
assert_eq!(skills.len(), 1);
assert_eq!(skills[0].name, "md-skill");
assert!(skills[0].description.contains("cool things"));
}
#[test]
fn skills_to_prompt_empty() {
let prompt = skills_to_prompt(&[]);
assert!(prompt.is_empty());
}
#[test]
fn skills_to_prompt_with_skills() {
let skills = vec![Skill {
name: "test".to_string(),
description: "A test".to_string(),
version: "1.0.0".to_string(),
author: None,
tags: vec![],
tools: vec![],
prompts: vec!["Do the thing.".to_string()],
}];
let prompt = skills_to_prompt(&skills);
assert!(prompt.contains("test"));
assert!(prompt.contains("Do the thing"));
}
#[test]
fn init_skills_creates_readme() {
let dir = tempfile::tempdir().unwrap();
init_skills_dir(dir.path()).unwrap();
assert!(dir.path().join("skills").join("README.md").exists());
}
#[test]
fn init_skills_idempotent() {
let dir = tempfile::tempdir().unwrap();
init_skills_dir(dir.path()).unwrap();
init_skills_dir(dir.path()).unwrap(); // second call should not fail
assert!(dir.path().join("skills").join("README.md").exists());
}
#[test]
fn load_nonexistent_dir() {
let dir = tempfile::tempdir().unwrap();
let fake = dir.path().join("nonexistent");
let skills = load_skills(&fake);
assert!(skills.is_empty());
}
#[test]
fn load_ignores_files_in_skills_dir() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
fs::create_dir_all(&skills_dir).unwrap();
// A file, not a directory — should be ignored
fs::write(skills_dir.join("not-a-skill.txt"), "hello").unwrap();
let skills = load_skills(dir.path());
assert!(skills.is_empty());
}
#[test]
fn load_ignores_dir_without_manifest() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let empty_skill = skills_dir.join("empty-skill");
fs::create_dir_all(&empty_skill).unwrap();
// Directory exists but no SKILL.toml or SKILL.md
let skills = load_skills(dir.path());
assert!(skills.is_empty());
}
#[test]
fn load_multiple_skills() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
for name in ["alpha", "beta", "gamma"] {
let skill_dir = skills_dir.join(name);
fs::create_dir_all(&skill_dir).unwrap();
fs::write(
skill_dir.join("SKILL.md"),
format!("# {name}\nSkill {name} description.\n"),
)
.unwrap();
}
let skills = load_skills(dir.path());
assert_eq!(skills.len(), 3);
}
#[test]
fn toml_skill_with_multiple_tools() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let skill_dir = skills_dir.join("multi-tool");
fs::create_dir_all(&skill_dir).unwrap();
fs::write(
skill_dir.join("SKILL.toml"),
r#"
[skill]
name = "multi-tool"
description = "Has many tools"
version = "2.0.0"
author = "tester"
tags = ["automation", "devops"]
[[tools]]
name = "build"
description = "Build the project"
kind = "shell"
command = "cargo build"
[[tools]]
name = "test"
description = "Run tests"
kind = "shell"
command = "cargo test"
[[tools]]
name = "deploy"
description = "Deploy via HTTP"
kind = "http"
command = "https://api.example.com/deploy"
"#,
)
.unwrap();
let skills = load_skills(dir.path());
assert_eq!(skills.len(), 1);
let s = &skills[0];
assert_eq!(s.name, "multi-tool");
assert_eq!(s.version, "2.0.0");
assert_eq!(s.author.as_deref(), Some("tester"));
assert_eq!(s.tags, vec!["automation", "devops"]);
assert_eq!(s.tools.len(), 3);
assert_eq!(s.tools[0].name, "build");
assert_eq!(s.tools[1].kind, "shell");
assert_eq!(s.tools[2].kind, "http");
}
#[test]
fn toml_skill_minimal() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let skill_dir = skills_dir.join("minimal");
fs::create_dir_all(&skill_dir).unwrap();
fs::write(
skill_dir.join("SKILL.toml"),
r#"
[skill]
name = "minimal"
description = "Bare minimum"
"#,
)
.unwrap();
let skills = load_skills(dir.path());
assert_eq!(skills.len(), 1);
assert_eq!(skills[0].version, "0.1.0"); // default version
assert!(skills[0].author.is_none());
assert!(skills[0].tags.is_empty());
assert!(skills[0].tools.is_empty());
}
#[test]
fn toml_skill_invalid_syntax_skipped() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let skill_dir = skills_dir.join("broken");
fs::create_dir_all(&skill_dir).unwrap();
fs::write(skill_dir.join("SKILL.toml"), "this is not valid toml {{{{").unwrap();
let skills = load_skills(dir.path());
assert!(skills.is_empty()); // broken skill is skipped
}
#[test]
fn md_skill_heading_only() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let skill_dir = skills_dir.join("heading-only");
fs::create_dir_all(&skill_dir).unwrap();
fs::write(skill_dir.join("SKILL.md"), "# Just a Heading\n").unwrap();
let skills = load_skills(dir.path());
assert_eq!(skills.len(), 1);
assert_eq!(skills[0].description, "No description");
}
#[test]
fn skills_to_prompt_includes_tools() {
let skills = vec![Skill {
name: "weather".to_string(),
description: "Get weather".to_string(),
version: "1.0.0".to_string(),
author: None,
tags: vec![],
tools: vec![SkillTool {
name: "get_weather".to_string(),
description: "Fetch forecast".to_string(),
kind: "shell".to_string(),
command: "curl wttr.in".to_string(),
args: HashMap::new(),
}],
prompts: vec![],
}];
let prompt = skills_to_prompt(&skills);
assert!(prompt.contains("weather"));
assert!(prompt.contains("get_weather"));
assert!(prompt.contains("Fetch forecast"));
assert!(prompt.contains("shell"));
}
#[test]
fn skills_dir_path() {
let base = std::path::Path::new("/home/user/.zeroclaw");
let dir = skills_dir(base);
assert_eq!(dir, PathBuf::from("/home/user/.zeroclaw/skills"));
}
#[test]
fn toml_prefers_over_md() {
let dir = tempfile::tempdir().unwrap();
let skills_dir = dir.path().join("skills");
let skill_dir = skills_dir.join("dual");
fs::create_dir_all(&skill_dir).unwrap();
fs::write(
skill_dir.join("SKILL.toml"),
"[skill]\nname = \"from-toml\"\ndescription = \"TOML wins\"\n",
)
.unwrap();
fs::write(skill_dir.join("SKILL.md"), "# From MD\nMD description\n").unwrap();
let skills = load_skills(dir.path());
assert_eq!(skills.len(), 1);
assert_eq!(skills[0].name, "from-toml"); // TOML takes priority
}
}

203
src/tools/file_read.rs Normal file
View file

@ -0,0 +1,203 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Read file contents with path sandboxing
pub struct FileReadTool {
security: Arc<SecurityPolicy>,
}
impl FileReadTool {
pub fn new(security: Arc<SecurityPolicy>) -> Self {
Self { security }
}
}
#[async_trait]
impl Tool for FileReadTool {
fn name(&self) -> &str {
"file_read"
}
fn description(&self) -> &str {
"Read the contents of a file in the workspace"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Relative path to the file within the workspace"
}
},
"required": ["path"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let path = args
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
// Security check: validate path is within workspace
if !self.security.is_path_allowed(path) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Path not allowed by security policy: {path}")),
});
}
let full_path = self.security.workspace_dir.join(path);
match tokio::fs::read_to_string(&full_path).await {
Ok(contents) => Ok(ToolResult {
success: true,
output: contents,
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to read file: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security(workspace: std::path::PathBuf) -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
workspace_dir: workspace,
..SecurityPolicy::default()
})
}
#[test]
fn file_read_name() {
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
assert_eq!(tool.name(), "file_read");
}
#[test]
fn file_read_schema_has_path() {
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
let schema = tool.parameters_schema();
assert!(schema["properties"]["path"].is_object());
assert!(schema["required"]
.as_array()
.unwrap()
.contains(&json!("path")));
}
#[tokio::test]
async fn file_read_existing_file() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_read");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
tokio::fs::write(dir.join("test.txt"), "hello world")
.await
.unwrap();
let tool = FileReadTool::new(test_security(dir.clone()));
let result = tool.execute(json!({"path": "test.txt"})).await.unwrap();
assert!(result.success);
assert_eq!(result.output, "hello world");
assert!(result.error.is_none());
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_read_nonexistent_file() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_missing");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let tool = FileReadTool::new(test_security(dir.clone()));
let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("Failed to read"));
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_read_blocks_path_traversal() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_traversal");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let tool = FileReadTool::new(test_security(dir.clone()));
let result = tool
.execute(json!({"path": "../../../etc/passwd"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed"));
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_read_blocks_absolute_path() {
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
let result = tool.execute(json!({"path": "/etc/passwd"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed"));
}
#[tokio::test]
async fn file_read_missing_path_param() {
let tool = FileReadTool::new(test_security(std::env::temp_dir()));
let result = tool.execute(json!({})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn file_read_empty_file() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_empty");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
tokio::fs::write(dir.join("empty.txt"), "").await.unwrap();
let tool = FileReadTool::new(test_security(dir.clone()));
let result = tool.execute(json!({"path": "empty.txt"})).await.unwrap();
assert!(result.success);
assert_eq!(result.output, "");
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_read_nested_path() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_read_nested");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(dir.join("sub/dir"))
.await
.unwrap();
tokio::fs::write(dir.join("sub/dir/deep.txt"), "deep content")
.await
.unwrap();
let tool = FileReadTool::new(test_security(dir.clone()));
let result = tool
.execute(json!({"path": "sub/dir/deep.txt"}))
.await
.unwrap();
assert!(result.success);
assert_eq!(result.output, "deep content");
let _ = tokio::fs::remove_dir_all(&dir).await;
}
}

242
src/tools/file_write.rs Normal file
View file

@ -0,0 +1,242 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Write file contents with path sandboxing
pub struct FileWriteTool {
security: Arc<SecurityPolicy>,
}
impl FileWriteTool {
pub fn new(security: Arc<SecurityPolicy>) -> Self {
Self { security }
}
}
#[async_trait]
impl Tool for FileWriteTool {
fn name(&self) -> &str {
"file_write"
}
fn description(&self) -> &str {
"Write contents to a file in the workspace"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Relative path to the file within the workspace"
},
"content": {
"type": "string",
"description": "Content to write to the file"
}
},
"required": ["path", "content"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let path = args
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
let content = args
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?;
// Security check: validate path is within workspace
if !self.security.is_path_allowed(path) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Path not allowed by security policy: {path}")),
});
}
let full_path = self.security.workspace_dir.join(path);
// Ensure parent directory exists
if let Some(parent) = full_path.parent() {
tokio::fs::create_dir_all(parent).await?;
}
match tokio::fs::write(&full_path, content).await {
Ok(()) => Ok(ToolResult {
success: true,
output: format!("Written {} bytes to {path}", content.len()),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to write file: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security(workspace: std::path::PathBuf) -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy: AutonomyLevel::Supervised,
workspace_dir: workspace,
..SecurityPolicy::default()
})
}
#[test]
fn file_write_name() {
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
assert_eq!(tool.name(), "file_write");
}
#[test]
fn file_write_schema_has_path_and_content() {
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
let schema = tool.parameters_schema();
assert!(schema["properties"]["path"].is_object());
assert!(schema["properties"]["content"].is_object());
let required = schema["required"].as_array().unwrap();
assert!(required.contains(&json!("path")));
assert!(required.contains(&json!("content")));
}
#[tokio::test]
async fn file_write_creates_file() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_write");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let tool = FileWriteTool::new(test_security(dir.clone()));
let result = tool
.execute(json!({"path": "out.txt", "content": "written!"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("8 bytes"));
let content = tokio::fs::read_to_string(dir.join("out.txt"))
.await
.unwrap();
assert_eq!(content, "written!");
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_write_creates_parent_dirs() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_nested");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let tool = FileWriteTool::new(test_security(dir.clone()));
let result = tool
.execute(json!({"path": "a/b/c/deep.txt", "content": "deep"}))
.await
.unwrap();
assert!(result.success);
let content = tokio::fs::read_to_string(dir.join("a/b/c/deep.txt"))
.await
.unwrap();
assert_eq!(content, "deep");
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_write_overwrites_existing() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_overwrite");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
tokio::fs::write(dir.join("exist.txt"), "old")
.await
.unwrap();
let tool = FileWriteTool::new(test_security(dir.clone()));
let result = tool
.execute(json!({"path": "exist.txt", "content": "new"}))
.await
.unwrap();
assert!(result.success);
let content = tokio::fs::read_to_string(dir.join("exist.txt"))
.await
.unwrap();
assert_eq!(content, "new");
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_write_blocks_path_traversal() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_traversal");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let tool = FileWriteTool::new(test_security(dir.clone()));
let result = tool
.execute(json!({"path": "../../etc/evil", "content": "bad"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed"));
let _ = tokio::fs::remove_dir_all(&dir).await;
}
#[tokio::test]
async fn file_write_blocks_absolute_path() {
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
let result = tool
.execute(json!({"path": "/etc/evil", "content": "bad"}))
.await
.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed"));
}
#[tokio::test]
async fn file_write_missing_path_param() {
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
let result = tool.execute(json!({"content": "data"})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn file_write_missing_content_param() {
let tool = FileWriteTool::new(test_security(std::env::temp_dir()));
let result = tool.execute(json!({"path": "file.txt"})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn file_write_empty_content() {
let dir = std::env::temp_dir().join("zeroclaw_test_file_write_empty");
let _ = tokio::fs::remove_dir_all(&dir).await;
tokio::fs::create_dir_all(&dir).await.unwrap();
let tool = FileWriteTool::new(test_security(dir.clone()));
let result = tool
.execute(json!({"path": "empty.txt", "content": ""}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("0 bytes"));
let _ = tokio::fs::remove_dir_all(&dir).await;
}
}

118
src/tools/memory_forget.rs Normal file
View file

@ -0,0 +1,118 @@
use super::traits::{Tool, ToolResult};
use crate::memory::Memory;
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Let the agent forget/delete a memory entry
pub struct MemoryForgetTool {
memory: Arc<dyn Memory>,
}
impl MemoryForgetTool {
pub fn new(memory: Arc<dyn Memory>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryForgetTool {
fn name(&self) -> &str {
"memory_forget"
}
fn description(&self) -> &str {
"Remove a memory by key. Use to delete outdated facts or sensitive data. Returns whether the memory was found and removed."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "The key of the memory to forget"
}
},
"required": ["key"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?;
match self.memory.forget(key).await {
Ok(true) => Ok(ToolResult {
success: true,
output: format!("Forgot memory: {key}"),
error: None,
}),
Ok(false) => Ok(ToolResult {
success: true,
output: format!("No memory found with key: {key}"),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to forget memory: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::{MemoryCategory, SqliteMemory};
use tempfile::TempDir;
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
(tmp, Arc::new(mem))
}
#[test]
fn name_and_schema() {
let (_tmp, mem) = test_mem();
let tool = MemoryForgetTool::new(mem);
assert_eq!(tool.name(), "memory_forget");
assert!(tool.parameters_schema()["properties"]["key"].is_object());
}
#[tokio::test]
async fn forget_existing() {
let (_tmp, mem) = test_mem();
mem.store("temp", "temporary", MemoryCategory::Conversation)
.await
.unwrap();
let tool = MemoryForgetTool::new(mem.clone());
let result = tool.execute(json!({"key": "temp"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("Forgot"));
assert!(mem.get("temp").await.unwrap().is_none());
}
#[tokio::test]
async fn forget_nonexistent() {
let (_tmp, mem) = test_mem();
let tool = MemoryForgetTool::new(mem);
let result = tool.execute(json!({"key": "nope"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("No memory found"));
}
#[tokio::test]
async fn forget_missing_key() {
let (_tmp, mem) = test_mem();
let tool = MemoryForgetTool::new(mem);
let result = tool.execute(json!({})).await;
assert!(result.is_err());
}
}

163
src/tools/memory_recall.rs Normal file
View file

@ -0,0 +1,163 @@
use super::traits::{Tool, ToolResult};
use crate::memory::Memory;
use async_trait::async_trait;
use serde_json::json;
use std::fmt::Write;
use std::sync::Arc;
/// Let the agent search its own memory
pub struct MemoryRecallTool {
memory: Arc<dyn Memory>,
}
impl MemoryRecallTool {
pub fn new(memory: Arc<dyn Memory>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryRecallTool {
fn name(&self) -> &str {
"memory_recall"
}
fn description(&self) -> &str {
"Search long-term memory for relevant facts, preferences, or context. Returns scored results ranked by relevance."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "Keywords or phrase to search for in memory"
},
"limit": {
"type": "integer",
"description": "Max results to return (default: 5)"
}
},
"required": ["query"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let query = args
.get("query")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;
#[allow(clippy::cast_possible_truncation)]
let limit = args
.get("limit")
.and_then(serde_json::Value::as_u64)
.map_or(5, |v| v as usize);
match self.memory.recall(query, limit).await {
Ok(entries) if entries.is_empty() => Ok(ToolResult {
success: true,
output: "No memories found matching that query.".into(),
error: None,
}),
Ok(entries) => {
let mut output = format!("Found {} memories:\n", entries.len());
for entry in &entries {
let score = entry.score.map_or_else(String::new, |s| format!(" [{s:.0}%]"));
let _ = writeln!(
output,
"- [{}] {}: {}{score}",
entry.category, entry.key, entry.content
);
}
Ok(ToolResult {
success: true,
output,
error: None,
})
}
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Memory recall failed: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::{MemoryCategory, SqliteMemory};
use tempfile::TempDir;
fn seeded_mem() -> (TempDir, Arc<dyn Memory>) {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
(tmp, Arc::new(mem))
}
#[tokio::test]
async fn recall_empty() {
let (_tmp, mem) = seeded_mem();
let tool = MemoryRecallTool::new(mem);
let result = tool
.execute(json!({"query": "anything"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("No memories found"));
}
#[tokio::test]
async fn recall_finds_match() {
let (_tmp, mem) = seeded_mem();
mem.store("lang", "User prefers Rust", MemoryCategory::Core)
.await
.unwrap();
mem.store("tz", "Timezone is EST", MemoryCategory::Core)
.await
.unwrap();
let tool = MemoryRecallTool::new(mem);
let result = tool.execute(json!({"query": "Rust"})).await.unwrap();
assert!(result.success);
assert!(result.output.contains("Rust"));
assert!(result.output.contains("Found 1"));
}
#[tokio::test]
async fn recall_respects_limit() {
let (_tmp, mem) = seeded_mem();
for i in 0..10 {
mem.store(&format!("k{i}"), &format!("Rust fact {i}"), MemoryCategory::Core)
.await
.unwrap();
}
let tool = MemoryRecallTool::new(mem);
let result = tool
.execute(json!({"query": "Rust", "limit": 3}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("Found 3"));
}
#[tokio::test]
async fn recall_missing_query() {
let (_tmp, mem) = seeded_mem();
let tool = MemoryRecallTool::new(mem);
let result = tool.execute(json!({})).await;
assert!(result.is_err());
}
#[test]
fn name_and_schema() {
let (_tmp, mem) = seeded_mem();
let tool = MemoryRecallTool::new(mem);
assert_eq!(tool.name(), "memory_recall");
assert!(tool.parameters_schema()["properties"]["query"].is_object());
}
}

146
src/tools/memory_store.rs Normal file
View file

@ -0,0 +1,146 @@
use super::traits::{Tool, ToolResult};
use crate::memory::{Memory, MemoryCategory};
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Let the agent store memories — its own brain writes
pub struct MemoryStoreTool {
memory: Arc<dyn Memory>,
}
impl MemoryStoreTool {
pub fn new(memory: Arc<dyn Memory>) -> Self {
Self { memory }
}
}
#[async_trait]
impl Tool for MemoryStoreTool {
fn name(&self) -> &str {
"memory_store"
}
fn description(&self) -> &str {
"Store a fact, preference, or note in long-term memory. Use category 'core' for permanent facts, 'daily' for session notes, 'conversation' for chat context."
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"key": {
"type": "string",
"description": "Unique key for this memory (e.g. 'user_lang', 'project_stack')"
},
"content": {
"type": "string",
"description": "The information to remember"
},
"category": {
"type": "string",
"enum": ["core", "daily", "conversation"],
"description": "Memory category: core (permanent), daily (session), conversation (chat)"
}
},
"required": ["key", "content"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'key' parameter"))?;
let content = args
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?;
let category = match args.get("category").and_then(|v| v.as_str()) {
Some("daily") => MemoryCategory::Daily,
Some("conversation") => MemoryCategory::Conversation,
_ => MemoryCategory::Core,
};
match self.memory.store(key, content, category).await {
Ok(()) => Ok(ToolResult {
success: true,
output: format!("Stored memory: {key}"),
error: None,
}),
Err(e) => Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to store memory: {e}")),
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::memory::SqliteMemory;
use tempfile::TempDir;
fn test_mem() -> (TempDir, Arc<dyn Memory>) {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
(tmp, Arc::new(mem))
}
#[test]
fn name_and_schema() {
let (_tmp, mem) = test_mem();
let tool = MemoryStoreTool::new(mem);
assert_eq!(tool.name(), "memory_store");
let schema = tool.parameters_schema();
assert!(schema["properties"]["key"].is_object());
assert!(schema["properties"]["content"].is_object());
}
#[tokio::test]
async fn store_core() {
let (_tmp, mem) = test_mem();
let tool = MemoryStoreTool::new(mem.clone());
let result = tool
.execute(json!({"key": "lang", "content": "Prefers Rust"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.contains("lang"));
let entry = mem.get("lang").await.unwrap();
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "Prefers Rust");
}
#[tokio::test]
async fn store_with_category() {
let (_tmp, mem) = test_mem();
let tool = MemoryStoreTool::new(mem.clone());
let result = tool
.execute(json!({"key": "note", "content": "Fixed bug", "category": "daily"}))
.await
.unwrap();
assert!(result.success);
}
#[tokio::test]
async fn store_missing_key() {
let (_tmp, mem) = test_mem();
let tool = MemoryStoreTool::new(mem);
let result = tool.execute(json!({"content": "no key"})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn store_missing_content() {
let (_tmp, mem) = test_mem();
let tool = MemoryStoreTool::new(mem);
let result = tool.execute(json!({"key": "no_content"})).await;
assert!(result.is_err());
}
}

189
src/tools/mod.rs Normal file
View file

@ -0,0 +1,189 @@
pub mod file_read;
pub mod file_write;
pub mod memory_forget;
pub mod memory_recall;
pub mod memory_store;
pub mod shell;
pub mod traits;
pub use file_read::FileReadTool;
pub use file_write::FileWriteTool;
pub use memory_forget::MemoryForgetTool;
pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool;
pub use shell::ShellTool;
pub use traits::Tool;
#[allow(unused_imports)]
pub use traits::{ToolResult, ToolSpec};
use crate::config::Config;
use crate::memory::Memory;
use crate::security::SecurityPolicy;
use anyhow::Result;
use std::sync::Arc;
/// Create the default tool registry
pub fn default_tools(security: Arc<SecurityPolicy>) -> Vec<Box<dyn Tool>> {
vec![
Box::new(ShellTool::new(security.clone())),
Box::new(FileReadTool::new(security.clone())),
Box::new(FileWriteTool::new(security)),
]
}
/// Create full tool registry including memory tools
pub fn all_tools(
security: Arc<SecurityPolicy>,
memory: Arc<dyn Memory>,
) -> Vec<Box<dyn Tool>> {
vec![
Box::new(ShellTool::new(security.clone())),
Box::new(FileReadTool::new(security.clone())),
Box::new(FileWriteTool::new(security)),
Box::new(MemoryStoreTool::new(memory.clone())),
Box::new(MemoryRecallTool::new(memory.clone())),
Box::new(MemoryForgetTool::new(memory)),
]
}
pub async fn handle_command(command: super::ToolCommands, config: Config) -> Result<()> {
let security = Arc::new(SecurityPolicy {
workspace_dir: config.workspace_dir.clone(),
..SecurityPolicy::default()
});
let mem: Arc<dyn Memory> =
Arc::from(crate::memory::create_memory(&config.memory, &config.workspace_dir)?);
let tools_list = all_tools(security, mem);
match command {
super::ToolCommands::List => {
println!("Available tools ({}):", tools_list.len());
for tool in &tools_list {
println!(" - {}: {}", tool.name(), tool.description());
}
Ok(())
}
super::ToolCommands::Test { tool, args } => {
let matched = tools_list.iter().find(|t| t.name() == tool);
match matched {
Some(t) => {
let parsed: serde_json::Value = serde_json::from_str(&args)?;
let result = t.execute(parsed).await?;
println!("Success: {}", result.success);
println!("Output: {}", result.output);
if let Some(err) = result.error {
println!("Error: {err}");
}
Ok(())
}
None => anyhow::bail!("Unknown tool: {tool}"),
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn default_tools_has_three() {
let security = Arc::new(SecurityPolicy::default());
let tools = default_tools(security);
assert_eq!(tools.len(), 3);
}
#[test]
fn default_tools_names() {
let security = Arc::new(SecurityPolicy::default());
let tools = default_tools(security);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(names.contains(&"shell"));
assert!(names.contains(&"file_read"));
assert!(names.contains(&"file_write"));
}
#[test]
fn default_tools_all_have_descriptions() {
let security = Arc::new(SecurityPolicy::default());
let tools = default_tools(security);
for tool in &tools {
assert!(
!tool.description().is_empty(),
"Tool {} has empty description",
tool.name()
);
}
}
#[test]
fn default_tools_all_have_schemas() {
let security = Arc::new(SecurityPolicy::default());
let tools = default_tools(security);
for tool in &tools {
let schema = tool.parameters_schema();
assert!(
schema.is_object(),
"Tool {} schema is not an object",
tool.name()
);
assert!(
schema["properties"].is_object(),
"Tool {} schema has no properties",
tool.name()
);
}
}
#[test]
fn tool_spec_generation() {
let security = Arc::new(SecurityPolicy::default());
let tools = default_tools(security);
for tool in &tools {
let spec = tool.spec();
assert_eq!(spec.name, tool.name());
assert_eq!(spec.description, tool.description());
assert!(spec.parameters.is_object());
}
}
#[test]
fn tool_result_serde() {
let result = ToolResult {
success: true,
output: "hello".into(),
error: None,
};
let json = serde_json::to_string(&result).unwrap();
let parsed: ToolResult = serde_json::from_str(&json).unwrap();
assert!(parsed.success);
assert_eq!(parsed.output, "hello");
assert!(parsed.error.is_none());
}
#[test]
fn tool_result_with_error_serde() {
let result = ToolResult {
success: false,
output: String::new(),
error: Some("boom".into()),
};
let json = serde_json::to_string(&result).unwrap();
let parsed: ToolResult = serde_json::from_str(&json).unwrap();
assert!(!parsed.success);
assert_eq!(parsed.error.as_deref(), Some("boom"));
}
#[test]
fn tool_spec_serde() {
let spec = ToolSpec {
name: "test".into(),
description: "A test tool".into(),
parameters: serde_json::json!({"type": "object"}),
};
let json = serde_json::to_string(&spec).unwrap();
let parsed: ToolSpec = serde_json::from_str(&json).unwrap();
assert_eq!(parsed.name, "test");
assert_eq!(parsed.description, "A test tool");
}
}

166
src/tools/shell.rs Normal file
View file

@ -0,0 +1,166 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use serde_json::json;
use std::sync::Arc;
/// Shell command execution tool with sandboxing
pub struct ShellTool {
security: Arc<SecurityPolicy>,
}
impl ShellTool {
pub fn new(security: Arc<SecurityPolicy>) -> Self {
Self { security }
}
}
#[async_trait]
impl Tool for ShellTool {
fn name(&self) -> &str {
"shell"
}
fn description(&self) -> &str {
"Execute a shell command in the workspace directory"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"command": {
"type": "string",
"description": "The shell command to execute"
}
},
"required": ["command"]
})
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let command = args
.get("command")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'command' parameter"))?;
// Security check: validate command against allowlist
if !self.security.is_command_allowed(command) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Command not allowed by security policy: {command}")),
});
}
let output = tokio::process::Command::new("sh")
.arg("-c")
.arg(command)
.current_dir(&self.security.workspace_dir)
.output()
.await?;
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
let stderr = String::from_utf8_lossy(&output.stderr).to_string();
Ok(ToolResult {
success: output.status.success(),
output: stdout,
error: if stderr.is_empty() {
None
} else {
Some(stderr)
},
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::security::{AutonomyLevel, SecurityPolicy};
fn test_security(autonomy: AutonomyLevel) -> Arc<SecurityPolicy> {
Arc::new(SecurityPolicy {
autonomy,
workspace_dir: std::env::temp_dir(),
..SecurityPolicy::default()
})
}
#[test]
fn shell_tool_name() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
assert_eq!(tool.name(), "shell");
}
#[test]
fn shell_tool_description() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
assert!(!tool.description().is_empty());
}
#[test]
fn shell_tool_schema_has_command() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
let schema = tool.parameters_schema();
assert!(schema["properties"]["command"].is_object());
assert!(schema["required"]
.as_array()
.unwrap()
.contains(&json!("command")));
}
#[tokio::test]
async fn shell_executes_allowed_command() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
let result = tool
.execute(json!({"command": "echo hello"}))
.await
.unwrap();
assert!(result.success);
assert!(result.output.trim().contains("hello"));
assert!(result.error.is_none());
}
#[tokio::test]
async fn shell_blocks_disallowed_command() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed"));
}
#[tokio::test]
async fn shell_blocks_readonly() {
let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly));
let result = tool.execute(json!({"command": "ls"})).await.unwrap();
assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("not allowed"));
}
#[tokio::test]
async fn shell_missing_command_param() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
let result = tool.execute(json!({})).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("command"));
}
#[tokio::test]
async fn shell_wrong_type_param() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
let result = tool.execute(json!({"command": 123})).await;
assert!(result.is_err());
}
#[tokio::test]
async fn shell_captures_exit_code() {
let tool = ShellTool::new(test_security(AutonomyLevel::Supervised));
let result = tool
.execute(json!({"command": "ls /nonexistent_dir_xyz"}))
.await
.unwrap();
assert!(!result.success);
}
}

43
src/tools/traits.rs Normal file
View file

@ -0,0 +1,43 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
/// Result of a tool execution
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResult {
pub success: bool,
pub output: String,
pub error: Option<String>,
}
/// Description of a tool for the LLM
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSpec {
pub name: String,
pub description: String,
pub parameters: serde_json::Value,
}
/// Core tool trait — implement for any capability
#[async_trait]
pub trait Tool: Send + Sync {
/// Tool name (used in LLM function calling)
fn name(&self) -> &str;
/// Human-readable description
fn description(&self) -> &str;
/// JSON schema for parameters
fn parameters_schema(&self) -> serde_json::Value;
/// Execute the tool with given arguments
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult>;
/// Get the full spec for LLM registration
fn spec(&self) -> ToolSpec {
ToolSpec {
name: self.name().to_string(),
description: self.description().to_string(),
parameters: self.parameters_schema(),
}
}
}