diff --git a/.tmp_todo_probe b/.tmp_todo_probe new file mode 100644 index 0000000..e69de29 diff --git a/Cargo.lock b/Cargo.lock index 00da71f..0a9ecff 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -293,6 +293,17 @@ dependencies = [ "libc", ] +[[package]] +name = "cron" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8c3e73077b4b4a6ab1ea5047c37c57aee77657bc8ecd6f29b0af082d0b0c07" +dependencies = [ + "chrono", + "nom", + "once_cell", +] + [[package]] name = "crypto-common" version = "0.1.7" @@ -925,6 +936,12 @@ version = "2.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +[[package]] +name = "minimal-lexical" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" + [[package]] name = "mio" version = "1.1.1" @@ -936,6 +953,16 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "nom" +version = "7.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" +dependencies = [ + "memchr", + "minimal-lexical", +] + [[package]] name = "nu-ansi-term" version = "0.50.3" @@ -2368,6 +2395,7 @@ dependencies = [ "chrono", "clap", "console", + "cron", "dialoguer", "directories", "futures-util", diff --git a/Cargo.toml b/Cargo.toml index 08f75b0..147c9b7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ categories = ["command-line-utilities", "api-bindings"] clap = { version = "4.5", features = ["derive"] } # Async runtime - feature-optimized for size -tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs"] } +tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", 
"io-util", "sync", "process", "io-std", "fs", "signal"] } # HTTP client - minimal features reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking"] } @@ -49,6 +49,7 @@ async-trait = "0.1" # Memory / persistence rusqlite = { version = "0.32", features = ["bundled"] } chrono = { version = "0.4", default-features = false, features = ["clock", "std"] } +cron = "0.12" # Interactive CLI prompts dialoguer = { version = "0.11", features = ["fuzzy-select"] } diff --git a/README.md b/README.md index 5efbbf7..8076dd4 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,19 @@ License: MIT

-The fastest, smallest, fully autonomous AI assistant — deploy anywhere, swap anything. +Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything. ``` ~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything ``` +### Why teams pick ZeroClaw + +- **Lean by default:** small Rust binary, fast startup, low memory footprint. +- **Secure by design:** pairing, strict sandboxing, explicit allowlists, workspace scoping. +- **Fully swappable:** core systems are traits (providers, channels, tools, memory, tunnels). +- **No lock-in:** OpenAI-compatible provider support + pluggable custom endpoints. + ## Benchmark Snapshot (ZeroClaw vs OpenClaw) Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. @@ -30,7 +37,17 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. | `--help` max RSS observed | **~7.3 MB** | **~394 MB** | | `status` max RSS observed | **~7.8 MB** | **~1.52 GB** | -> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results include `pnpm install` + `pnpm build` before execution. +> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results were measured after `pnpm install` + `pnpm build`. + +Reproduce ZeroClaw numbers locally: + +```bash +cargo build --release +ls -lh target/release/zeroclaw + +/usr/bin/time -l target/release/zeroclaw --help +/usr/bin/time -l target/release/zeroclaw status +``` ## Quick Start @@ -38,34 +55,48 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. git clone https://github.com/theonlyhennygod/zeroclaw.git cd zeroclaw cargo build --release +cargo install --path . --force # Quick setup (no prompts) -cargo run --release -- onboard --api-key sk-... --provider openrouter +zeroclaw onboard --api-key sk-... 
--provider openrouter # Or interactive wizard -cargo run --release -- onboard --interactive +zeroclaw onboard --interactive + +# Or quickly repair channels/allowlists only +zeroclaw onboard --channels-only # Chat -cargo run --release -- agent -m "Hello, ZeroClaw!" +zeroclaw agent -m "Hello, ZeroClaw!" # Interactive mode -cargo run --release -- agent +zeroclaw agent # Start the gateway (webhook server) -cargo run --release -- gateway # default: 127.0.0.1:8080 -cargo run --release -- gateway --port 0 # random port (security hardened) +zeroclaw gateway # default: 127.0.0.1:8080 +zeroclaw gateway --port 0 # random port (security hardened) + +# Start full autonomous runtime +zeroclaw daemon # Check status -cargo run --release -- status +zeroclaw status + +# Run system diagnostics +zeroclaw doctor # Check channel health -cargo run --release -- channel doctor +zeroclaw channel doctor # Get integration setup details -cargo run --release -- integrations info Telegram +zeroclaw integrations info Telegram + +# Manage background service +zeroclaw service install +zeroclaw service status ``` -> **Tip:** Run `cargo install --path .` to install `zeroclaw` globally, then use `zeroclaw` instead of `cargo run --release --`. +> **Dev fallback (no global install):** prefix commands with `cargo run --release --` (example: `cargo run --release -- status`). 
## Architecture @@ -82,13 +113,20 @@ Every subsystem is a **trait** — swap implementations with a config change, ze | **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Markdown | Any persistence backend | | **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), composio (optional) | Any capability | | **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel | -| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM | +| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM (planned; unsupported kinds fail fast) | | **Security** | `SecurityPolicy` | Gateway pairing, sandbox, allowlists, rate limits, filesystem scoping, encrypted secrets | — | | **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Any tunnel binary | | **Heartbeat** | Engine | HEARTBEAT.md periodic tasks | — | | **Skills** | Loader | TOML manifests + SKILL.md instructions | Community skill packs | | **Integrations** | Registry | 50+ integrations across 9 categories | Plugin system | +### Runtime support (current) + +- ✅ Supported today: `runtime.kind = "native"` +- 🚧 Planned, not implemented yet: Docker / WASM / edge runtimes + +When an unsupported `runtime.kind` is configured, ZeroClaw now exits with a clear error instead of silently falling back to native. + ### Memory System (Full-Stack Search Engine) All custom, zero external dependencies — no Pinecone, no Elasticsearch, no LangChain: @@ -124,7 +162,7 @@ ZeroClaw enforces security at **every layer** — not just the sandbox. It passe |---|------|--------|-----| | 1 | **Gateway not publicly exposed** | ✅ | Binds `127.0.0.1` by default. Refuses `0.0.0.0` without tunnel or explicit `allow_public_bind = true`. | | 2 | **Pairing required** | ✅ | 6-digit one-time code on startup. Exchange via `POST /pair` for bearer token. All `/webhook` requests require `Authorization: Bearer `. 
| -| 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization. | +| 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization + resolved-path workspace checks in file read/write tools. | | 4 | **Access via tunnel only** | ✅ | Gateway refuses public bind without active tunnel. Supports Tailscale, Cloudflare, ngrok, or any custom tunnel. | > **Run your own nmap:** `nmap -p 1-65535 ` — ZeroClaw binds to localhost only, so nothing is exposed unless you explicitly configure a tunnel. @@ -139,6 +177,26 @@ Inbound sender policy is now consistent: This keeps accidental exposure low by default. +Recommended low-friction setup (secure + fast): + +- **Telegram:** allowlist your own `@username` (without `@`) and/or your numeric Telegram user ID. +- **Discord:** allowlist your own Discord user ID. +- **Slack:** allowlist your own Slack member ID (usually starts with `U`). +- Use `"*"` only for temporary open testing. + +If you're not sure which identity to use: + +1. Start channels and send one message to your bot. +2. Read the warning log to see the exact sender identity. +3. Add that value to the allowlist and rerun channels-only setup. 
+ +If you hit authorization warnings in logs (for example: `ignoring message from unauthorized user`), +rerun channel setup only: + +```bash +zeroclaw onboard --channels-only +``` + ## Configuration Config: `~/.zeroclaw/config.toml` (created by `onboard`) @@ -166,6 +224,9 @@ workspace_only = true # default: true — scoped to workspace allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"] forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"] +[runtime] +kind = "native" # only supported value right now; unsupported kinds fail fast + [heartbeat] enabled = false interval_minutes = 30 @@ -198,10 +259,14 @@ enabled = false # opt-in: 1000+ OAuth apps via composio.dev |---------|-------------| | `onboard` | Quick setup (default) | | `onboard --interactive` | Full interactive 7-step wizard | +| `onboard --channels-only` | Reconfigure channels/allowlists only (fast repair flow) | | `agent -m "..."` | Single message mode | | `agent` | Interactive chat mode | | `gateway` | Start webhook server (default: `127.0.0.1:8080`) | | `gateway --port 0` | Random port mode | +| `daemon` | Start long-running autonomous runtime | +| `service install/start/stop/status/uninstall` | Manage user-level background service | +| `doctor` | Diagnose daemon/scheduler/channel freshness | | `status` | Show full system status | | `channel doctor` | Run health checks for configured channels | | `integrations info ` | Show setup/status details for one integration | diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 57e0182..0f611d7 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -39,7 +39,7 @@ pub async fn run( // ── Wire up agnostic subsystems ────────────────────────────── let observer: Arc = Arc::from(observability::create_observer(&config.observability)); - let _runtime = runtime::create_runtime(&config.runtime); + let _runtime = runtime::create_runtime(&config.runtime)?; let security = Arc::new(SecurityPolicy::from_config( &config.autonomy, 
&config.workspace_dir, @@ -72,8 +72,11 @@ pub async fn run( .or(config.default_model.as_deref()) .unwrap_or("anthropic/claude-sonnet-4-20250514"); - let provider: Box = - providers::create_provider(provider_name, config.api_key.as_deref())?; + let provider: Box = providers::create_resilient_provider( + provider_name, + config.api_key.as_deref(), + &config.reliability, + )?; observer.record_event(&ObserverEvent::AgentStart { provider: provider_name.to_string(), @@ -83,12 +86,30 @@ pub async fn run( // ── Build system prompt from workspace MD files (OpenClaw framework) ── let skills = crate::skills::load_skills(&config.workspace_dir); let mut tool_descs: Vec<(&str, &str)> = vec![ - ("shell", "Execute terminal commands"), - ("file_read", "Read file contents"), - ("file_write", "Write file contents"), - ("memory_store", "Save to memory"), - ("memory_recall", "Search memory"), - ("memory_forget", "Delete a memory entry"), + ( + "shell", + "Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.", + ), + ( + "file_read", + "Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.", + ), + ( + "file_write", + "Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.", + ), + ( + "memory_store", + "Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.", + ), + ( + "memory_recall", + "Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.", + ), + ( + "memory_forget", + "Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. 
Don't use when: impact is uncertain.", + ), ]; if config.browser.enabled { tool_descs.push(( diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 7252f7d..32e47e7 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -24,6 +24,46 @@ use std::time::Duration; /// Maximum characters per injected workspace file (matches `OpenClaw` default). const BOOTSTRAP_MAX_CHARS: usize = 20_000; +const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2; +const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60; + +fn spawn_supervised_listener( + ch: Arc, + tx: tokio::sync::mpsc::Sender, + initial_backoff_secs: u64, + max_backoff_secs: u64, +) -> tokio::task::JoinHandle<()> { + tokio::spawn(async move { + let component = format!("channel:{}", ch.name()); + let mut backoff = initial_backoff_secs.max(1); + let max_backoff = max_backoff_secs.max(backoff); + + loop { + crate::health::mark_component_ok(&component); + let result = ch.listen(tx.clone()).await; + + if tx.is_closed() { + break; + } + + match result { + Ok(()) => { + tracing::warn!("Channel {} exited unexpectedly; restarting", ch.name()); + crate::health::mark_component_error(&component, "listener exited unexpectedly"); + } + Err(e) => { + tracing::error!("Channel {} error: {e}; restarting", ch.name()); + crate::health::mark_component_error(&component, e.to_string()); + } + } + + crate::health::bump_component_restart(&component); + tokio::time::sleep(Duration::from_secs(backoff)).await; + backoff = backoff.saturating_mul(2).min(max_backoff); + } + }) +} + /// Load workspace identity files and build a system prompt. 
/// /// Follows the `OpenClaw` framework structure: @@ -334,9 +374,10 @@ pub async fn doctor_channels(config: Config) -> Result<()> { /// Start all configured channels and route messages to the agent #[allow(clippy::too_many_lines)] pub async fn start_channels(config: Config) -> Result<()> { - let provider: Arc = Arc::from(providers::create_provider( + let provider: Arc = Arc::from(providers::create_resilient_provider( config.default_provider.as_deref().unwrap_or("openrouter"), config.api_key.as_deref(), + &config.reliability, )?); let model = config .default_model @@ -355,12 +396,30 @@ pub async fn start_channels(config: Config) -> Result<()> { // Collect tool descriptions for the prompt let mut tool_descs: Vec<(&str, &str)> = vec![ - ("shell", "Execute terminal commands"), - ("file_read", "Read file contents"), - ("file_write", "Write file contents"), - ("memory_store", "Save to memory"), - ("memory_recall", "Search memory"), - ("memory_forget", "Delete a memory entry"), + ( + "shell", + "Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.", + ), + ( + "file_read", + "Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.", + ), + ( + "file_write", + "Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.", + ), + ( + "memory_store", + "Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.", + ), + ( + "memory_recall", + "Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.", + ), + ( + "memory_forget", + "Delete a memory entry. 
Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.", + ), ]; if config.browser.enabled { @@ -446,19 +505,29 @@ pub async fn start_channels(config: Config) -> Result<()> { println!(" Listening for messages... (Ctrl+C to stop)"); println!(); + crate::health::mark_component_ok("channels"); + + let initial_backoff_secs = config + .reliability + .channel_initial_backoff_secs + .max(DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS); + let max_backoff_secs = config + .reliability + .channel_max_backoff_secs + .max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS); + // Single message bus — all channels send messages here let (tx, mut rx) = tokio::sync::mpsc::channel::(100); // Spawn a listener for each channel let mut handles = Vec::new(); for ch in &channels { - let ch = ch.clone(); - let tx = tx.clone(); - handles.push(tokio::spawn(async move { - if let Err(e) = ch.listen(tx).await { - tracing::error!("Channel {} error: {e}", ch.name()); - } - })); + handles.push(spawn_supervised_listener( + ch.clone(), + tx.clone(), + initial_backoff_secs, + max_backoff_secs, + )); } drop(tx); // Drop our copy so rx closes when all channels stop @@ -533,6 +602,8 @@ pub async fn start_channels(config: Config) -> Result<()> { #[cfg(test)] mod tests { use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; use tempfile::TempDir; fn make_workspace() -> TempDir { @@ -777,4 +848,55 @@ mod tests { let state = classify_health_result(&result); assert_eq!(state, ChannelHealthState::Timeout); } + + struct AlwaysFailChannel { + name: &'static str, + calls: Arc, + } + + #[async_trait::async_trait] + impl Channel for AlwaysFailChannel { + fn name(&self) -> &str { + self.name + } + + async fn send(&self, _message: &str, _recipient: &str) -> anyhow::Result<()> { + Ok(()) + } + + async fn listen( + &self, + _tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { + self.calls.fetch_add(1, Ordering::SeqCst); + anyhow::bail!("listen boom") + 
} + } + + #[tokio::test] + async fn supervised_listener_marks_error_and_restarts_on_failures() { + let calls = Arc::new(AtomicUsize::new(0)); + let channel: Arc = Arc::new(AlwaysFailChannel { + name: "test-supervised-fail", + calls: Arc::clone(&calls), + }); + + let (_tx, rx) = tokio::sync::mpsc::channel::(1); + let handle = spawn_supervised_listener(channel, _tx, 1, 1); + + tokio::time::sleep(Duration::from_millis(80)).await; + drop(rx); + handle.abort(); + let _ = handle.await; + + let snapshot = crate::health::snapshot_json(); + let component = &snapshot["components"]["channel:test-supervised-fail"]; + assert_eq!(component["status"], "error"); + assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1); + assert!(component["last_error"] + .as_str() + .unwrap_or("") + .contains("listen boom")); + assert!(calls.load(Ordering::SeqCst) >= 1); + } } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index 56f8a3c..0147c8d 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -25,6 +25,13 @@ impl TelegramChannel { fn is_user_allowed(&self, username: &str) -> bool { self.allowed_users.iter().any(|u| u == "*" || u == username) } + + fn is_any_user_allowed<'a, I>(&self, identities: I) -> bool + where + I: IntoIterator, + { + identities.into_iter().any(|id| self.is_user_allowed(id)) + } } #[async_trait] @@ -95,15 +102,28 @@ impl Channel for TelegramChannel { continue; }; - let username = message + let username_opt = message .get("from") .and_then(|f| f.get("username")) - .and_then(|u| u.as_str()) - .unwrap_or("unknown"); + .and_then(|u| u.as_str()); + let username = username_opt.unwrap_or("unknown"); - if !self.is_user_allowed(username) { + let user_id = message + .get("from") + .and_then(|f| f.get("id")) + .and_then(serde_json::Value::as_i64); + let user_id_str = user_id.map(|id| id.to_string()); + + let mut identities = vec![username]; + if let Some(ref id) = user_id_str { + identities.push(id.as_str()); + } + + if 
!self.is_any_user_allowed(identities.iter().copied()) { tracing::warn!( - "Telegram: ignoring message from unauthorized user: {username}" + "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \ +Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.", + user_id_str.as_deref().unwrap_or("unknown") ); continue; } @@ -211,4 +231,16 @@ mod tests { assert!(ch.is_user_allowed("bob")); assert!(ch.is_user_allowed("anyone")); } + + #[test] + fn telegram_user_allowed_by_numeric_id_identity() { + let ch = TelegramChannel::new("t".into(), vec!["123456789".into()]); + assert!(ch.is_any_user_allowed(["unknown", "123456789"])); + } + + #[test] + fn telegram_user_denied_when_none_of_identities_match() { + let ch = TelegramChannel::new("t".into(), vec!["alice".into(), "987654321".into()]); + assert!(!ch.is_any_user_allowed(["unknown", "123456789"])); + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 9af098c..4632486 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -3,6 +3,6 @@ pub mod schema; pub use schema::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, - ObservabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, - WebhookConfig, + ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, + TelegramConfig, TunnelConfig, WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index 49a9d59..006d120 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -25,6 +25,9 @@ pub struct Config { #[serde(default)] pub runtime: RuntimeConfig, + #[serde(default)] + pub reliability: ReliabilityConfig, + #[serde(default)] pub heartbeat: HeartbeatConfig, @@ -143,6 +146,18 @@ pub struct MemoryConfig { pub backend: String, /// Auto-save conversation context to memory pub auto_save: bool, + /// Run 
memory/session hygiene (archiving + retention cleanup) + #[serde(default = "default_hygiene_enabled")] + pub hygiene_enabled: bool, + /// Archive daily/session files older than this many days + #[serde(default = "default_archive_after_days")] + pub archive_after_days: u32, + /// Purge archived files older than this many days + #[serde(default = "default_purge_after_days")] + pub purge_after_days: u32, + /// For sqlite backend: prune conversation rows older than this many days + #[serde(default = "default_conversation_retention_days")] + pub conversation_retention_days: u32, /// Embedding provider: "none" | "openai" | "custom:URL" #[serde(default = "default_embedding_provider")] pub embedding_provider: String, @@ -169,6 +184,18 @@ pub struct MemoryConfig { fn default_embedding_provider() -> String { "none".into() } +fn default_hygiene_enabled() -> bool { + true +} +fn default_archive_after_days() -> u32 { + 7 +} +fn default_purge_after_days() -> u32 { + 30 +} +fn default_conversation_retention_days() -> u32 { + 30 +} fn default_embedding_model() -> String { "text-embedding-3-small".into() } @@ -193,6 +220,10 @@ impl Default for MemoryConfig { Self { backend: "sqlite".into(), auto_save: true, + hygiene_enabled: default_hygiene_enabled(), + archive_after_days: default_archive_after_days(), + purge_after_days: default_purge_after_days(), + conversation_retention_days: default_conversation_retention_days(), embedding_provider: default_embedding_provider(), embedding_model: default_embedding_model(), embedding_dimensions: default_embedding_dims(), @@ -281,7 +312,9 @@ impl Default for AutonomyConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RuntimeConfig { - /// "native" | "docker" | "cloudflare" + /// Runtime kind (currently supported: "native"). + /// + /// Reserved values (not implemented yet): "docker", "cloudflare". 
pub kind: String, } @@ -293,6 +326,71 @@ impl Default for RuntimeConfig { } } +// ── Reliability / supervision ──────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReliabilityConfig { + /// Retries per provider before failing over. + #[serde(default = "default_provider_retries")] + pub provider_retries: u32, + /// Base backoff (ms) for provider retry delay. + #[serde(default = "default_provider_backoff_ms")] + pub provider_backoff_ms: u64, + /// Fallback provider chain (e.g. `["anthropic", "openai"]`). + #[serde(default)] + pub fallback_providers: Vec, + /// Initial backoff for channel/daemon restarts. + #[serde(default = "default_channel_backoff_secs")] + pub channel_initial_backoff_secs: u64, + /// Max backoff for channel/daemon restarts. + #[serde(default = "default_channel_backoff_max_secs")] + pub channel_max_backoff_secs: u64, + /// Scheduler polling cadence in seconds. + #[serde(default = "default_scheduler_poll_secs")] + pub scheduler_poll_secs: u64, + /// Max retries for cron job execution attempts. 
+ #[serde(default = "default_scheduler_retries")] + pub scheduler_retries: u32, +} + +fn default_provider_retries() -> u32 { + 2 +} + +fn default_provider_backoff_ms() -> u64 { + 500 +} + +fn default_channel_backoff_secs() -> u64 { + 2 +} + +fn default_channel_backoff_max_secs() -> u64 { + 60 +} + +fn default_scheduler_poll_secs() -> u64 { + 15 +} + +fn default_scheduler_retries() -> u32 { + 2 +} + +impl Default for ReliabilityConfig { + fn default() -> Self { + Self { + provider_retries: default_provider_retries(), + provider_backoff_ms: default_provider_backoff_ms(), + fallback_providers: Vec::new(), + channel_initial_backoff_secs: default_channel_backoff_secs(), + channel_max_backoff_secs: default_channel_backoff_max_secs(), + scheduler_poll_secs: default_scheduler_poll_secs(), + scheduler_retries: default_scheduler_retries(), + } + } +} + // ── Heartbeat ──────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -463,6 +561,7 @@ impl Default for Config { observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), + reliability: ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), @@ -558,6 +657,17 @@ mod tests { assert_eq!(h.interval_minutes, 30); } + #[test] + fn memory_config_default_hygiene_settings() { + let m = MemoryConfig::default(); + assert_eq!(m.backend, "sqlite"); + assert!(m.auto_save); + assert!(m.hygiene_enabled); + assert_eq!(m.archive_after_days, 7); + assert_eq!(m.purge_after_days, 30); + assert_eq!(m.conversation_retention_days, 30); + } + #[test] fn channels_config_default() { let c = ChannelsConfig::default(); @@ -591,6 +701,7 @@ mod tests { runtime: RuntimeConfig { kind: "docker".into(), }, + reliability: ReliabilityConfig::default(), heartbeat: HeartbeatConfig { enabled: true, interval_minutes: 15, @@ -650,6 +761,10 @@ default_temperature = 0.7 
assert_eq!(parsed.runtime.kind, "native"); assert!(!parsed.heartbeat.enabled); assert!(parsed.channels_config.cli); + assert!(parsed.memory.hygiene_enabled); + assert_eq!(parsed.memory.archive_after_days, 7); + assert_eq!(parsed.memory.purge_after_days, 30); + assert_eq!(parsed.memory.conversation_retention_days, 30); } #[test] @@ -669,6 +784,7 @@ default_temperature = 0.7 observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), + reliability: ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), diff --git a/src/cron/mod.rs b/src/cron/mod.rs index 8f52701..572670d 100644 --- a/src/cron/mod.rs +++ b/src/cron/mod.rs @@ -1,25 +1,353 @@ use crate::config::Config; -use anyhow::Result; +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use cron::Schedule; +use rusqlite::{params, Connection}; +use std::str::FromStr; +use uuid::Uuid; -pub fn handle_command(command: super::CronCommands, _config: Config) -> Result<()> { +pub mod scheduler; + +#[derive(Debug, Clone)] +pub struct CronJob { + pub id: String, + pub expression: String, + pub command: String, + pub next_run: DateTime, + pub last_run: Option>, + pub last_status: Option, +} + +pub fn handle_command(command: super::CronCommands, config: Config) -> Result<()> { match command { super::CronCommands::List => { - println!("No scheduled tasks yet."); - println!("\nUsage:"); - println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'"); + let jobs = list_jobs(&config)?; + if jobs.is_empty() { + println!("No scheduled tasks yet."); + println!("\nUsage:"); + println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'"); + return Ok(()); + } + + println!("🕒 Scheduled jobs ({}):", jobs.len()); + for job in jobs { + let last_run = job + .last_run + .map(|d| d.to_rfc3339()) + .unwrap_or_else(|| "never".into()); + let last_status = 
job.last_status.unwrap_or_else(|| "n/a".into()); + println!( + "- {} | {} | next={} | last={} ({})\n cmd: {}", + job.id, + job.expression, + job.next_run.to_rfc3339(), + last_run, + last_status, + job.command + ); + } Ok(()) } super::CronCommands::Add { expression, command, } => { - println!("Cron scheduling coming soon!"); - println!(" Expression: {expression}"); - println!(" Command: {command}"); + let job = add_job(&config, &expression, &command)?; + println!("✅ Added cron job {}", job.id); + println!(" Expr: {}", job.expression); + println!(" Next: {}", job.next_run.to_rfc3339()); + println!(" Cmd : {}", job.command); Ok(()) } - super::CronCommands::Remove { id } => { - anyhow::bail!("Remove task '{id}' not yet implemented"); - } + super::CronCommands::Remove { id } => remove_job(&config, &id), + } +} + +pub fn add_job(config: &Config, expression: &str, command: &str) -> Result { + let now = Utc::now(); + let next_run = next_run_for(expression, now)?; + let id = Uuid::new_v4().to_string(); + + with_connection(config, |conn| { + conn.execute( + "INSERT INTO cron_jobs (id, expression, command, created_at, next_run) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![ + id, + expression, + command, + now.to_rfc3339(), + next_run.to_rfc3339() + ], + ) + .context("Failed to insert cron job")?; + Ok(()) + })?; + + Ok(CronJob { + id, + expression: expression.to_string(), + command: command.to_string(), + next_run, + last_run: None, + last_status: None, + }) +} + +pub fn list_jobs(config: &Config) -> Result> { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, next_run, last_run, last_status + FROM cron_jobs ORDER BY next_run ASC", + )?; + + let rows = stmt.query_map([], |row| { + let next_run_raw: String = row.get(3)?; + let last_run_raw: Option = row.get(4)?; + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + next_run_raw, + last_run_raw, + row.get::<_, Option>(5)?, + )) + })?; + + let 
mut jobs = Vec::new(); + for row in rows { + let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; + jobs.push(CronJob { + id, + expression, + command, + next_run: parse_rfc3339(&next_run_raw)?, + last_run: match last_run_raw { + Some(raw) => Some(parse_rfc3339(&raw)?), + None => None, + }, + last_status, + }); + } + Ok(jobs) + }) +} + +pub fn remove_job(config: &Config, id: &str) -> Result<()> { + let changed = with_connection(config, |conn| { + conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![id]) + .context("Failed to delete cron job") + })?; + + if changed == 0 { + anyhow::bail!("Cron job '{id}' not found"); + } + + println!("✅ Removed cron job {id}"); + Ok(()) +} + +pub fn due_jobs(config: &Config, now: DateTime) -> Result> { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, next_run, last_run, last_status + FROM cron_jobs WHERE next_run <= ?1 ORDER BY next_run ASC", + )?; + + let rows = stmt.query_map(params![now.to_rfc3339()], |row| { + let next_run_raw: String = row.get(3)?; + let last_run_raw: Option = row.get(4)?; + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + next_run_raw, + last_run_raw, + row.get::<_, Option>(5)?, + )) + })?; + + let mut jobs = Vec::new(); + for row in rows { + let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; + jobs.push(CronJob { + id, + expression, + command, + next_run: parse_rfc3339(&next_run_raw)?, + last_run: match last_run_raw { + Some(raw) => Some(parse_rfc3339(&raw)?), + None => None, + }, + last_status, + }); + } + Ok(jobs) + }) +} + +pub fn reschedule_after_run( + config: &Config, + job: &CronJob, + success: bool, + output: &str, +) -> Result<()> { + let now = Utc::now(); + let next_run = next_run_for(&job.expression, now)?; + let status = if success { "ok" } else { "error" }; + + with_connection(config, |conn| { + conn.execute( + "UPDATE cron_jobs + SET next_run = 
?1, last_run = ?2, last_status = ?3, last_output = ?4 + WHERE id = ?5", + params![ + next_run.to_rfc3339(), + now.to_rfc3339(), + status, + output, + job.id + ], + ) + .context("Failed to update cron job run state")?; + Ok(()) + }) +} + +fn next_run_for(expression: &str, from: DateTime) -> Result> { + let normalized = normalize_expression(expression)?; + let schedule = Schedule::from_str(&normalized) + .with_context(|| format!("Invalid cron expression: {expression}"))?; + schedule + .after(&from) + .next() + .ok_or_else(|| anyhow::anyhow!("No future occurrence for expression: {expression}")) +} + +fn normalize_expression(expression: &str) -> Result { + let expression = expression.trim(); + let field_count = expression.split_whitespace().count(); + + match field_count { + // standard crontab syntax: minute hour day month weekday + 5 => Ok(format!("0 {expression}")), + // crate-native syntax includes seconds (+ optional year) + 6 | 7 => Ok(expression.to_string()), + _ => anyhow::bail!( + "Invalid cron expression: {expression} (expected 5, 6, or 7 fields, got {field_count})" + ), + } +} + +fn parse_rfc3339(raw: &str) -> Result> { + let parsed = DateTime::parse_from_rfc3339(raw) + .with_context(|| format!("Invalid RFC3339 timestamp in cron DB: {raw}"))?; + Ok(parsed.with_timezone(&Utc)) +} + +fn with_connection(config: &Config, f: impl FnOnce(&Connection) -> Result) -> Result { + let db_path = config.workspace_dir.join("cron").join("jobs.db"); + if let Some(parent) = db_path.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("Failed to create cron directory: {}", parent.display()))?; + } + + let conn = Connection::open(&db_path) + .with_context(|| format!("Failed to open cron DB: {}", db_path.display()))?; + + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS cron_jobs ( + id TEXT PRIMARY KEY, + expression TEXT NOT NULL, + command TEXT NOT NULL, + created_at TEXT NOT NULL, + next_run TEXT NOT NULL, + last_run TEXT, + last_status TEXT, + last_output 
TEXT + ); + CREATE INDEX IF NOT EXISTS idx_cron_jobs_next_run ON cron_jobs(next_run);", + ) + .context("Failed to initialize cron schema")?; + + f(&conn) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use chrono::Duration as ChronoDuration; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + #[test] + fn add_job_accepts_five_field_expression() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap(); + + assert_eq!(job.expression, "*/5 * * * *"); + assert_eq!(job.command, "echo ok"); + } + + #[test] + fn add_job_rejects_invalid_field_count() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let err = add_job(&config, "* * * *", "echo bad").unwrap_err(); + assert!(err.to_string().contains("expected 5, 6, or 7 fields")); + } + + #[test] + fn add_list_remove_roundtrip() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/10 * * * *", "echo roundtrip").unwrap(); + let listed = list_jobs(&config).unwrap(); + assert_eq!(listed.len(), 1); + assert_eq!(listed[0].id, job.id); + + remove_job(&config, &job.id).unwrap(); + assert!(list_jobs(&config).unwrap().is_empty()); + } + + #[test] + fn due_jobs_filters_by_timestamp() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let _job = add_job(&config, "* * * * *", "echo due").unwrap(); + + let due_now = due_jobs(&config, Utc::now()).unwrap(); + assert!(due_now.is_empty(), "new job should not be due immediately"); + + let far_future = Utc::now() + ChronoDuration::days(365); + let due_future = due_jobs(&config, far_future).unwrap(); + assert_eq!(due_future.len(), 1, "job 
should be due in far future"); + } + + #[test] + fn reschedule_after_run_persists_last_status_and_last_run() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/15 * * * *", "echo run").unwrap(); + reschedule_after_run(&config, &job, false, "failed output").unwrap(); + + let listed = list_jobs(&config).unwrap(); + let stored = listed.iter().find(|j| j.id == job.id).unwrap(); + assert_eq!(stored.last_status.as_deref(), Some("error")); + assert!(stored.last_run.is_some()); } } diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs new file mode 100644 index 0000000..459fe59 --- /dev/null +++ b/src/cron/scheduler.rs @@ -0,0 +1,169 @@ +use crate::config::Config; +use crate::cron::{due_jobs, reschedule_after_run, CronJob}; +use anyhow::Result; +use chrono::Utc; +use tokio::process::Command; +use tokio::time::{self, Duration}; + +const MIN_POLL_SECONDS: u64 = 5; + +pub async fn run(config: Config) -> Result<()> { + let poll_secs = config.reliability.scheduler_poll_secs.max(MIN_POLL_SECONDS); + let mut interval = time::interval(Duration::from_secs(poll_secs)); + + crate::health::mark_component_ok("scheduler"); + + loop { + interval.tick().await; + + let jobs = match due_jobs(&config, Utc::now()) { + Ok(jobs) => jobs, + Err(e) => { + crate::health::mark_component_error("scheduler", e.to_string()); + tracing::warn!("Scheduler query failed: {e}"); + continue; + } + }; + + for job in jobs { + crate::health::mark_component_ok("scheduler"); + let (success, output) = execute_job_with_retry(&config, &job).await; + + if !success { + crate::health::mark_component_error("scheduler", format!("job {} failed", job.id)); + } + + if let Err(e) = reschedule_after_run(&config, &job, success, &output) { + crate::health::mark_component_error("scheduler", e.to_string()); + tracing::warn!("Failed to persist scheduler run result: {e}"); + } + } + } +} + +async fn execute_job_with_retry(config: &Config, job: &CronJob) -> (bool, String) { 
+ let mut last_output = String::new(); + let retries = config.reliability.scheduler_retries; + let mut backoff_ms = config.reliability.provider_backoff_ms.max(200); + + for attempt in 0..=retries { + let (success, output) = run_job_command(config, job).await; + last_output = output; + + if success { + return (true, last_output); + } + + if attempt < retries { + let jitter_ms = (Utc::now().timestamp_subsec_millis() % 250) as u64; + time::sleep(Duration::from_millis(backoff_ms + jitter_ms)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(30_000); + } + } + + (false, last_output) +} + +async fn run_job_command(config: &Config, job: &CronJob) -> (bool, String) { + let output = Command::new("sh") + .arg("-lc") + .arg(&job.command) + .current_dir(&config.workspace_dir) + .output() + .await; + + match output { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + let combined = format!( + "status={}\nstdout:\n{}\nstderr:\n{}", + output.status, + stdout.trim(), + stderr.trim() + ); + (output.status.success(), combined) + } + Err(e) => (false, format!("spawn error: {e}")), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + fn test_job(command: &str) -> CronJob { + CronJob { + id: "test-job".into(), + expression: "* * * * *".into(), + command: command.into(), + next_run: Utc::now(), + last_run: None, + last_status: None, + } + } + + #[tokio::test] + async fn run_job_command_success() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = test_job("echo scheduler-ok"); + + let (success, output) = run_job_command(&config, &job).await; + 
assert!(success); + assert!(output.contains("scheduler-ok")); + assert!(output.contains("status=exit status: 0")); + } + + #[tokio::test] + async fn run_job_command_failure() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = test_job("echo scheduler-fail 1>&2; exit 7"); + + let (success, output) = run_job_command(&config, &job).await; + assert!(!success); + assert!(output.contains("scheduler-fail")); + assert!(output.contains("status=exit status: 7")); + } + + #[tokio::test] + async fn execute_job_with_retry_recovers_after_first_failure() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.reliability.scheduler_retries = 1; + config.reliability.provider_backoff_ms = 1; + + let job = test_job( + "if [ -f retry-ok.flag ]; then echo recovered; exit 0; else touch retry-ok.flag; echo first-fail 1>&2; exit 1; fi", + ); + + let (success, output) = execute_job_with_retry(&config, &job).await; + assert!(success); + assert!(output.contains("recovered")); + } + + #[tokio::test] + async fn execute_job_with_retry_exhausts_attempts() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.reliability.scheduler_retries = 1; + config.reliability.provider_backoff_ms = 1; + + let job = test_job("echo still-bad 1>&2; exit 1"); + + let (success, output) = execute_job_with_retry(&config, &job).await; + assert!(!success); + assert!(output.contains("still-bad")); + } +} diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs new file mode 100644 index 0000000..db374bc --- /dev/null +++ b/src/daemon/mod.rs @@ -0,0 +1,287 @@ +use crate::config::Config; +use anyhow::Result; +use chrono::Utc; +use std::future::Future; +use std::path::PathBuf; +use tokio::task::JoinHandle; +use tokio::time::Duration; + +const STATUS_FLUSH_SECONDS: u64 = 5; + +pub async fn run(config: Config, host: String, port: u16) -> Result<()> { + let initial_backoff = config.reliability.channel_initial_backoff_secs.max(1); + let 
max_backoff = config + .reliability + .channel_max_backoff_secs + .max(initial_backoff); + + crate::health::mark_component_ok("daemon"); + + if config.heartbeat.enabled { + let _ = + crate::heartbeat::engine::HeartbeatEngine::ensure_heartbeat_file(&config.workspace_dir) + .await; + } + + let mut handles: Vec> = vec![spawn_state_writer(config.clone())]; + + { + let gateway_cfg = config.clone(); + let gateway_host = host.clone(); + handles.push(spawn_component_supervisor( + "gateway", + initial_backoff, + max_backoff, + move || { + let cfg = gateway_cfg.clone(); + let host = gateway_host.clone(); + async move { crate::gateway::run_gateway(&host, port, cfg).await } + }, + )); + } + + { + if has_supervised_channels(&config) { + let channels_cfg = config.clone(); + handles.push(spawn_component_supervisor( + "channels", + initial_backoff, + max_backoff, + move || { + let cfg = channels_cfg.clone(); + async move { crate::channels::start_channels(cfg).await } + }, + )); + } else { + crate::health::mark_component_ok("channels"); + tracing::info!("No real-time channels configured; channel supervisor disabled"); + } + } + + if config.heartbeat.enabled { + let heartbeat_cfg = config.clone(); + handles.push(spawn_component_supervisor( + "heartbeat", + initial_backoff, + max_backoff, + move || { + let cfg = heartbeat_cfg.clone(); + async move { run_heartbeat_worker(cfg).await } + }, + )); + } + + { + let scheduler_cfg = config.clone(); + handles.push(spawn_component_supervisor( + "scheduler", + initial_backoff, + max_backoff, + move || { + let cfg = scheduler_cfg.clone(); + async move { crate::cron::scheduler::run(cfg).await } + }, + )); + } + + println!("🧠 ZeroClaw daemon started"); + println!(" Gateway: http://{host}:{port}"); + println!(" Components: gateway, channels, heartbeat, scheduler"); + println!(" Ctrl+C to stop"); + + tokio::signal::ctrl_c().await?; + crate::health::mark_component_error("daemon", "shutdown requested"); + + for handle in &handles { + handle.abort(); + 
} + for handle in handles { + let _ = handle.await; + } + + Ok(()) +} + +pub fn state_file_path(config: &Config) -> PathBuf { + config + .config_path + .parent() + .map_or_else(|| PathBuf::from("."), PathBuf::from) + .join("daemon_state.json") +} + +fn spawn_state_writer(config: Config) -> JoinHandle<()> { + tokio::spawn(async move { + let path = state_file_path(&config); + if let Some(parent) = path.parent() { + let _ = tokio::fs::create_dir_all(parent).await; + } + + let mut interval = tokio::time::interval(Duration::from_secs(STATUS_FLUSH_SECONDS)); + loop { + interval.tick().await; + let mut json = crate::health::snapshot_json(); + if let Some(obj) = json.as_object_mut() { + obj.insert( + "written_at".into(), + serde_json::json!(Utc::now().to_rfc3339()), + ); + } + let data = serde_json::to_vec_pretty(&json).unwrap_or_else(|_| b"{}".to_vec()); + let _ = tokio::fs::write(&path, data).await; + } + }) +} + +fn spawn_component_supervisor( + name: &'static str, + initial_backoff_secs: u64, + max_backoff_secs: u64, + mut run_component: F, +) -> JoinHandle<()> +where + F: FnMut() -> Fut + Send + 'static, + Fut: Future> + Send + 'static, +{ + tokio::spawn(async move { + let mut backoff = initial_backoff_secs.max(1); + let max_backoff = max_backoff_secs.max(backoff); + + loop { + crate::health::mark_component_ok(name); + match run_component().await { + Ok(()) => { + crate::health::mark_component_error(name, "component exited unexpectedly"); + tracing::warn!("Daemon component '{name}' exited unexpectedly"); + } + Err(e) => { + crate::health::mark_component_error(name, e.to_string()); + tracing::error!("Daemon component '{name}' failed: {e}"); + } + } + + crate::health::bump_component_restart(name); + tokio::time::sleep(Duration::from_secs(backoff)).await; + backoff = backoff.saturating_mul(2).min(max_backoff); + } + }) +} + +async fn run_heartbeat_worker(config: Config) -> Result<()> { + let observer: std::sync::Arc = + 
std::sync::Arc::from(crate::observability::create_observer(&config.observability)); + let engine = crate::heartbeat::engine::HeartbeatEngine::new( + config.heartbeat.clone(), + config.workspace_dir.clone(), + observer, + ); + + let interval_mins = config.heartbeat.interval_minutes.max(5); + let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60)); + + loop { + interval.tick().await; + + let tasks = engine.collect_tasks().await?; + if tasks.is_empty() { + continue; + } + + for task in tasks { + let prompt = format!("[Heartbeat Task] {task}"); + let temp = config.default_temperature; + if let Err(e) = crate::agent::run(config.clone(), Some(prompt), None, None, temp).await + { + crate::health::mark_component_error("heartbeat", e.to_string()); + tracing::warn!("Heartbeat task failed: {e}"); + } else { + crate::health::mark_component_ok("heartbeat"); + } + } + } +} + +fn has_supervised_channels(config: &Config) -> bool { + config.channels_config.telegram.is_some() + || config.channels_config.discord.is_some() + || config.channels_config.slack.is_some() + || config.channels_config.imessage.is_some() + || config.channels_config.matrix.is_some() +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + #[test] + fn state_file_path_uses_config_directory() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let path = state_file_path(&config); + assert_eq!(path, tmp.path().join("daemon_state.json")); + } + + #[tokio::test] + async fn supervisor_marks_error_and_restart_on_failure() { + let handle = spawn_component_supervisor("daemon-test-fail", 1, 1, || async { + anyhow::bail!("boom") + }); + + 
tokio::time::sleep(Duration::from_millis(50)).await; + handle.abort(); + let _ = handle.await; + + let snapshot = crate::health::snapshot_json(); + let component = &snapshot["components"]["daemon-test-fail"]; + assert_eq!(component["status"], "error"); + assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1); + assert!(component["last_error"] + .as_str() + .unwrap_or("") + .contains("boom")); + } + + #[tokio::test] + async fn supervisor_marks_unexpected_exit_as_error() { + let handle = spawn_component_supervisor("daemon-test-exit", 1, 1, || async { Ok(()) }); + + tokio::time::sleep(Duration::from_millis(50)).await; + handle.abort(); + let _ = handle.await; + + let snapshot = crate::health::snapshot_json(); + let component = &snapshot["components"]["daemon-test-exit"]; + assert_eq!(component["status"], "error"); + assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1); + assert!(component["last_error"] + .as_str() + .unwrap_or("") + .contains("component exited unexpectedly")); + } + + #[test] + fn detects_no_supervised_channels() { + let config = Config::default(); + assert!(!has_supervised_channels(&config)); + } + + #[test] + fn detects_supervised_channels_present() { + let mut config = Config::default(); + config.channels_config.telegram = Some(crate::config::TelegramConfig { + bot_token: "token".into(), + allowed_users: vec![], + }); + assert!(has_supervised_channels(&config)); + } +} diff --git a/src/doctor/mod.rs b/src/doctor/mod.rs new file mode 100644 index 0000000..62417ea --- /dev/null +++ b/src/doctor/mod.rs @@ -0,0 +1,123 @@ +use crate::config::Config; +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; + +const DAEMON_STALE_SECONDS: i64 = 30; +const SCHEDULER_STALE_SECONDS: i64 = 120; +const CHANNEL_STALE_SECONDS: i64 = 300; + +pub fn run(config: &Config) -> Result<()> { + let state_file = crate::daemon::state_file_path(config); + if !state_file.exists() { + println!("🩺 ZeroClaw Doctor"); + println!(" ❌ daemon state file not 
found: {}", state_file.display()); + println!(" 💡 Start daemon with: zeroclaw daemon"); + return Ok(()); + } + + let raw = std::fs::read_to_string(&state_file) + .with_context(|| format!("Failed to read {}", state_file.display()))?; + let snapshot: serde_json::Value = serde_json::from_str(&raw) + .with_context(|| format!("Failed to parse {}", state_file.display()))?; + + println!("🩺 ZeroClaw Doctor"); + println!(" State file: {}", state_file.display()); + + let updated_at = snapshot + .get("updated_at") + .and_then(serde_json::Value::as_str) + .unwrap_or(""); + + if let Ok(ts) = DateTime::parse_from_rfc3339(updated_at) { + let age = Utc::now() + .signed_duration_since(ts.with_timezone(&Utc)) + .num_seconds(); + if age <= DAEMON_STALE_SECONDS { + println!(" ✅ daemon heartbeat fresh ({age}s ago)"); + } else { + println!(" ❌ daemon heartbeat stale ({age}s ago)"); + } + } else { + println!(" ❌ invalid daemon timestamp: {updated_at}"); + } + + let mut channel_count = 0_u32; + let mut stale_channels = 0_u32; + + if let Some(components) = snapshot + .get("components") + .and_then(serde_json::Value::as_object) + { + if let Some(scheduler) = components.get("scheduler") { + let scheduler_ok = scheduler + .get("status") + .and_then(serde_json::Value::as_str) + .map(|s| s == "ok") + .unwrap_or(false); + + let scheduler_last_ok = scheduler + .get("last_ok") + .and_then(serde_json::Value::as_str) + .and_then(parse_rfc3339) + .map(|dt| Utc::now().signed_duration_since(dt).num_seconds()) + .unwrap_or(i64::MAX); + + if scheduler_ok && scheduler_last_ok <= SCHEDULER_STALE_SECONDS { + println!( + " ✅ scheduler healthy (last ok {}s ago)", + scheduler_last_ok + ); + } else { + println!( + " ❌ scheduler unhealthy/stale (status_ok={}, age={}s)", + scheduler_ok, scheduler_last_ok + ); + } + } else { + println!(" ❌ scheduler component missing"); + } + + for (name, component) in components { + if !name.starts_with("channel:") { + continue; + } + + channel_count += 1; + let status_ok = 
component + .get("status") + .and_then(serde_json::Value::as_str) + .map(|s| s == "ok") + .unwrap_or(false); + let age = component + .get("last_ok") + .and_then(serde_json::Value::as_str) + .and_then(parse_rfc3339) + .map(|dt| Utc::now().signed_duration_since(dt).num_seconds()) + .unwrap_or(i64::MAX); + + if status_ok && age <= CHANNEL_STALE_SECONDS { + println!(" ✅ {name} fresh (last ok {age}s ago)"); + } else { + stale_channels += 1; + println!(" ❌ {name} stale/unhealthy (status_ok={status_ok}, age={age}s)"); + } + } + } + + if channel_count == 0 { + println!(" ℹ️ no channel components tracked in state yet"); + } else { + println!( + " Channel summary: {} total, {} stale", + channel_count, stale_channels + ); + } + + Ok(()) +} + +fn parse_rfc3339(raw: &str) -> Option> { + DateTime::parse_from_rfc3339(raw) + .ok() + .map(|dt| dt.with_timezone(&Utc)) +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 6fd27fb..b14398f 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -26,9 +26,10 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let actual_port = listener.local_addr()?.port(); let addr = format!("{host}:{actual_port}"); - let provider: Arc = Arc::from(providers::create_provider( + let provider: Arc = Arc::from(providers::create_resilient_provider( config.default_provider.as_deref().unwrap_or("openrouter"), config.api_key.as_deref(), + &config.reliability, )?); let model = config .default_model @@ -97,6 +98,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { } println!(" Press Ctrl+C to stop.\n"); + crate::health::mark_component_ok("gateway"); + loop { let (mut stream, peer) = listener.accept().await?; let provider = provider.clone(); @@ -175,6 +178,7 @@ async fn handle_request( let body = serde_json::json!({ "status": "ok", "paired": pairing.is_paired(), + "runtime": crate::health::snapshot_json(), }); let _ = send_json(stream, 200, &body).await; } diff --git a/src/health/mod.rs 
b/src/health/mod.rs new file mode 100644 index 0000000..4fcd8b2 --- /dev/null +++ b/src/health/mod.rs @@ -0,0 +1,105 @@ +use chrono::Utc; +use serde::Serialize; +use std::collections::BTreeMap; +use std::sync::{Mutex, OnceLock}; +use std::time::Instant; + +#[derive(Debug, Clone, Serialize)] +pub struct ComponentHealth { + pub status: String, + pub updated_at: String, + pub last_ok: Option, + pub last_error: Option, + pub restart_count: u64, +} + +#[derive(Debug, Clone, Serialize)] +pub struct HealthSnapshot { + pub pid: u32, + pub updated_at: String, + pub uptime_seconds: u64, + pub components: BTreeMap, +} + +struct HealthRegistry { + started_at: Instant, + components: Mutex>, +} + +static REGISTRY: OnceLock = OnceLock::new(); + +fn registry() -> &'static HealthRegistry { + REGISTRY.get_or_init(|| HealthRegistry { + started_at: Instant::now(), + components: Mutex::new(BTreeMap::new()), + }) +} + +fn now_rfc3339() -> String { + Utc::now().to_rfc3339() +} + +fn upsert_component(component: &str, update: F) +where + F: FnOnce(&mut ComponentHealth), +{ + if let Ok(mut map) = registry().components.lock() { + let now = now_rfc3339(); + let entry = map + .entry(component.to_string()) + .or_insert_with(|| ComponentHealth { + status: "starting".into(), + updated_at: now.clone(), + last_ok: None, + last_error: None, + restart_count: 0, + }); + update(entry); + entry.updated_at = now; + } +} + +pub fn mark_component_ok(component: &str) { + upsert_component(component, |entry| { + entry.status = "ok".into(); + entry.last_ok = Some(now_rfc3339()); + entry.last_error = None; + }); +} + +pub fn mark_component_error(component: &str, error: impl ToString) { + let err = error.to_string(); + upsert_component(component, move |entry| { + entry.status = "error".into(); + entry.last_error = Some(err); + }); +} + +pub fn bump_component_restart(component: &str) { + upsert_component(component, |entry| { + entry.restart_count = entry.restart_count.saturating_add(1); + }); +} + +pub fn 
snapshot() -> HealthSnapshot { + let components = registry() + .components + .lock() + .map_or_else(|_| BTreeMap::new(), |map| map.clone()); + + HealthSnapshot { + pid: std::process::id(), + updated_at: now_rfc3339(), + uptime_seconds: registry().started_at.elapsed().as_secs(), + components, + } +} + +pub fn snapshot_json() -> serde_json::Value { + serde_json::to_value(snapshot()).unwrap_or_else(|_| { + serde_json::json!({ + "status": "error", + "message": "failed to serialize health snapshot" + }) + }) +} diff --git a/src/heartbeat/engine.rs b/src/heartbeat/engine.rs index ee31755..86b10e4 100644 --- a/src/heartbeat/engine.rs +++ b/src/heartbeat/engine.rs @@ -61,16 +61,17 @@ impl HeartbeatEngine { /// Single heartbeat tick — read HEARTBEAT.md and return task count async fn tick(&self) -> Result { + Ok(self.collect_tasks().await?.len()) + } + + /// Read HEARTBEAT.md and return all parsed tasks. + pub async fn collect_tasks(&self) -> Result> { let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md"); - if !heartbeat_path.exists() { - return Ok(0); + return Ok(Vec::new()); } - let content = tokio::fs::read_to_string(&heartbeat_path).await?; - let tasks = Self::parse_tasks(&content); - - Ok(tasks.len()) + Ok(Self::parse_tasks(&content)) } /// Parse tasks from HEARTBEAT.md (lines starting with `- `) diff --git a/src/main.rs b/src/main.rs index dbc2d4b..46fb1d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ dead_code )] -use anyhow::Result; +use anyhow::{bail, Result}; use clap::{Parser, Subcommand}; use tracing::{info, Level}; use tracing_subscriber::FmtSubscriber; @@ -17,15 +17,20 @@ mod agent; mod channels; mod config; mod cron; +mod daemon; +mod doctor; mod gateway; +mod health; mod heartbeat; mod integrations; mod memory; +mod migration; mod observability; mod onboard; mod providers; mod runtime; mod security; +mod service; mod skills; mod tools; mod tunnel; @@ -43,6 +48,20 @@ struct Cli { command: Commands, } +#[derive(Subcommand, Debug)] +enum 
ServiceCommands { + /// Install daemon service unit for auto-start and restart + Install, + /// Start daemon service + Start, + /// Stop daemon service + Stop, + /// Check daemon service status + Status, + /// Uninstall daemon service unit + Uninstall, +} + #[derive(Subcommand, Debug)] enum Commands { /// Initialize your workspace and configuration @@ -51,6 +70,10 @@ enum Commands { #[arg(long)] interactive: bool, + /// Reconfigure channels only (fast repair flow) + #[arg(long)] + channels_only: bool, + /// API key (used in quick mode, ignored with --interactive) #[arg(long)] api_key: Option, @@ -71,7 +94,7 @@ enum Commands { provider: Option, /// Model to use - #[arg(short, long)] + #[arg(long)] model: Option, /// Temperature (0.0 - 2.0) @@ -86,10 +109,30 @@ enum Commands { port: u16, /// Host to bind to - #[arg(short, long, default_value = "127.0.0.1")] + #[arg(long, default_value = "127.0.0.1")] host: String, }, + /// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler) + Daemon { + /// Port to listen on (use 0 for random available port) + #[arg(short, long, default_value = "8080")] + port: u16, + + /// Host to bind to + #[arg(long, default_value = "127.0.0.1")] + host: String, + }, + + /// Manage OS service lifecycle (launchd/systemd user service) + Service { + #[command(subcommand)] + service_command: ServiceCommands, + }, + + /// Run diagnostics for daemon/scheduler/channel freshness + Doctor, + /// Show system status (full details) Status, @@ -116,6 +159,26 @@ enum Commands { #[command(subcommand)] skill_command: SkillCommands, }, + + /// Migrate data from other agent runtimes + Migrate { + #[command(subcommand)] + migrate_command: MigrateCommands, + }, +} + +#[derive(Subcommand, Debug)] +enum MigrateCommands { + /// Import memory from an OpenClaw workspace into this ZeroClaw workspace + Openclaw { + /// Optional path to OpenClaw workspace (defaults to ~/.openclaw/workspace) + #[arg(long)] + source: Option, + + /// Validate and 
preview migration without writing any data + #[arg(long)] + dry_run: bool, + }, } #[derive(Subcommand, Debug)] @@ -198,11 +261,21 @@ async fn main() -> Result<()> { // Onboard runs quick setup by default, or the interactive wizard with --interactive if let Commands::Onboard { interactive, + channels_only, api_key, provider, } = &cli.command { - let config = if *interactive { + if *interactive && *channels_only { + bail!("Use either --interactive or --channels-only, not both"); + } + if *channels_only && (api_key.is_some() || provider.is_some()) { + bail!("--channels-only does not accept --api-key or --provider"); + } + + let config = if *channels_only { + onboard::run_channels_repair_wizard()? + } else if *interactive { onboard::run_wizard()? } else { onboard::run_quick_setup(api_key.as_deref(), provider.as_deref())? @@ -236,6 +309,15 @@ async fn main() -> Result<()> { gateway::run_gateway(&host, port, config).await } + Commands::Daemon { port, host } => { + if port == 0 { + info!("🧠 Starting ZeroClaw Daemon on {host} (random port)"); + } else { + info!("🧠 Starting ZeroClaw Daemon on {host}:{port}"); + } + daemon::run(config, host, port).await + } + Commands::Status => { println!("🦀 ZeroClaw Status"); println!(); @@ -307,6 +389,10 @@ async fn main() -> Result<()> { Commands::Cron { cron_command } => cron::handle_command(cron_command, config), + Commands::Service { service_command } => service::handle_command(service_command, &config), + + Commands::Doctor => doctor::run(&config), + Commands::Channel { channel_command } => match channel_command { ChannelCommands::Start => channels::start_channels(config).await, ChannelCommands::Doctor => channels::doctor_channels(config).await, @@ -320,5 +406,20 @@ async fn main() -> Result<()> { Commands::Skills { skill_command } => { skills::handle_command(skill_command, &config.workspace_dir) } + + Commands::Migrate { migrate_command } => { + migration::handle_command(migrate_command, &config).await + } + } +} + +#[cfg(test)] 
+mod tests { + use super::*; + use clap::CommandFactory; + + #[test] + fn cli_definition_has_no_flag_conflicts() { + Cli::command().debug_assert(); } } diff --git a/src/memory/hygiene.rs b/src/memory/hygiene.rs new file mode 100644 index 0000000..17c95fa --- /dev/null +++ b/src/memory/hygiene.rs @@ -0,0 +1,538 @@ +use crate::config::MemoryConfig; +use anyhow::Result; +use chrono::{DateTime, Duration, Local, NaiveDate, Utc}; +use rusqlite::{params, Connection}; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::{Duration as StdDuration, SystemTime}; + +const HYGIENE_INTERVAL_HOURS: i64 = 12; +const STATE_FILE: &str = "memory_hygiene_state.json"; + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct HygieneReport { + archived_memory_files: u64, + archived_session_files: u64, + purged_memory_archives: u64, + purged_session_archives: u64, + pruned_conversation_rows: u64, +} + +impl HygieneReport { + fn total_actions(&self) -> u64 { + self.archived_memory_files + + self.archived_session_files + + self.purged_memory_archives + + self.purged_session_archives + + self.pruned_conversation_rows + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct HygieneState { + last_run_at: Option, + last_report: HygieneReport, +} + +/// Run memory/session hygiene if the cadence window has elapsed. +/// +/// This function is intentionally best-effort: callers should log and continue on failure. +pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> { + if !config.hygiene_enabled { + return Ok(()); + } + + if !should_run_now(workspace_dir)? 
{ + return Ok(()); + } + + let report = HygieneReport { + archived_memory_files: archive_daily_memory_files( + workspace_dir, + config.archive_after_days, + )?, + archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?, + purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?, + purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?, + pruned_conversation_rows: prune_conversation_rows( + workspace_dir, + config.conversation_retention_days, + )?, + }; + + write_state(workspace_dir, &report)?; + + if report.total_actions() > 0 { + tracing::info!( + "memory hygiene complete: archived_memory={} archived_sessions={} purged_memory={} purged_sessions={} pruned_conversation_rows={}", + report.archived_memory_files, + report.archived_session_files, + report.purged_memory_archives, + report.purged_session_archives, + report.pruned_conversation_rows, + ); + } + + Ok(()) +} + +fn should_run_now(workspace_dir: &Path) -> Result { + let path = state_path(workspace_dir); + if !path.exists() { + return Ok(true); + } + + let raw = fs::read_to_string(&path)?; + let state: HygieneState = match serde_json::from_str(&raw) { + Ok(s) => s, + Err(_) => return Ok(true), + }; + + let Some(last_run_at) = state.last_run_at else { + return Ok(true); + }; + + let last = match DateTime::parse_from_rfc3339(&last_run_at) { + Ok(ts) => ts.with_timezone(&Utc), + Err(_) => return Ok(true), + }; + + Ok(Utc::now().signed_duration_since(last) >= Duration::hours(HYGIENE_INTERVAL_HOURS)) +} + +fn write_state(workspace_dir: &Path, report: &HygieneReport) -> Result<()> { + let path = state_path(workspace_dir); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + + let state = HygieneState { + last_run_at: Some(Utc::now().to_rfc3339()), + last_report: report.clone(), + }; + let json = serde_json::to_vec_pretty(&state)?; + fs::write(path, json)?; + Ok(()) +} + +fn state_path(workspace_dir: 
&Path) -> PathBuf { + workspace_dir.join("state").join(STATE_FILE) +} + +fn archive_daily_memory_files(workspace_dir: &Path, archive_after_days: u32) -> Result { + if archive_after_days == 0 { + return Ok(0); + } + + let memory_dir = workspace_dir.join("memory"); + if !memory_dir.is_dir() { + return Ok(0); + } + + let archive_dir = memory_dir.join("archive"); + fs::create_dir_all(&archive_dir)?; + + let cutoff = Local::now().date_naive() - Duration::days(i64::from(archive_after_days)); + let mut moved = 0_u64; + + for entry in fs::read_dir(&memory_dir)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + if path.extension().and_then(|e| e.to_str()) != Some("md") { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let Some(file_date) = memory_date_from_filename(filename) else { + continue; + }; + + if file_date < cutoff { + move_to_archive(&path, &archive_dir)?; + moved += 1; + } + } + + Ok(moved) +} + +fn archive_session_files(workspace_dir: &Path, archive_after_days: u32) -> Result { + if archive_after_days == 0 { + return Ok(0); + } + + let sessions_dir = workspace_dir.join("sessions"); + if !sessions_dir.is_dir() { + return Ok(0); + } + + let archive_dir = sessions_dir.join("archive"); + fs::create_dir_all(&archive_dir)?; + + let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(archive_after_days)); + let cutoff_time = SystemTime::now() + .checked_sub(StdDuration::from_secs( + u64::from(archive_after_days) * 24 * 60 * 60, + )) + .unwrap_or(SystemTime::UNIX_EPOCH); + + let mut moved = 0_u64; + for entry in fs::read_dir(&sessions_dir)? 
{ + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let is_old = if let Some(date) = date_prefix(filename) { + date < cutoff_date + } else { + is_older_than(&path, cutoff_time) + }; + + if is_old { + move_to_archive(&path, &archive_dir)?; + moved += 1; + } + } + + Ok(moved) +} + +fn purge_memory_archives(workspace_dir: &Path, purge_after_days: u32) -> Result { + if purge_after_days == 0 { + return Ok(0); + } + + let archive_dir = workspace_dir.join("memory").join("archive"); + if !archive_dir.is_dir() { + return Ok(0); + } + + let cutoff = Local::now().date_naive() - Duration::days(i64::from(purge_after_days)); + let mut removed = 0_u64; + + for entry in fs::read_dir(&archive_dir)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let Some(file_date) = memory_date_from_filename(filename) else { + continue; + }; + + if file_date < cutoff { + fs::remove_file(&path)?; + removed += 1; + } + } + + Ok(removed) +} + +fn purge_session_archives(workspace_dir: &Path, purge_after_days: u32) -> Result { + if purge_after_days == 0 { + return Ok(0); + } + + let archive_dir = workspace_dir.join("sessions").join("archive"); + if !archive_dir.is_dir() { + return Ok(0); + } + + let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(purge_after_days)); + let cutoff_time = SystemTime::now() + .checked_sub(StdDuration::from_secs( + u64::from(purge_after_days) * 24 * 60 * 60, + )) + .unwrap_or(SystemTime::UNIX_EPOCH); + + let mut removed = 0_u64; + for entry in fs::read_dir(&archive_dir)? 
{ + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let is_old = if let Some(date) = date_prefix(filename) { + date < cutoff_date + } else { + is_older_than(&path, cutoff_time) + }; + + if is_old { + fs::remove_file(&path)?; + removed += 1; + } + } + + Ok(removed) +} + +fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result { + if retention_days == 0 { + return Ok(0); + } + + let db_path = workspace_dir.join("memory").join("brain.db"); + if !db_path.exists() { + return Ok(0); + } + + let conn = Connection::open(db_path)?; + let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339(); + + let affected = conn.execute( + "DELETE FROM memories WHERE category = 'conversation' AND updated_at < ?1", + params![cutoff], + )?; + + Ok(u64::try_from(affected).unwrap_or(0)) +} + +fn memory_date_from_filename(filename: &str) -> Option { + let stem = filename.strip_suffix(".md")?; + let date_part = stem.split('_').next().unwrap_or(stem); + NaiveDate::parse_from_str(date_part, "%Y-%m-%d").ok() +} + +fn date_prefix(filename: &str) -> Option { + if filename.len() < 10 { + return None; + } + NaiveDate::parse_from_str(&filename[..10], "%Y-%m-%d").ok() +} + +fn is_older_than(path: &Path, cutoff: SystemTime) -> bool { + fs::metadata(path) + .and_then(|meta| meta.modified()) + .map(|modified| modified < cutoff) + .unwrap_or(false) +} + +fn move_to_archive(src: &Path, archive_dir: &Path) -> Result<()> { + let Some(filename) = src.file_name().and_then(|f| f.to_str()) else { + return Ok(()); + }; + + let target = unique_archive_target(archive_dir, filename); + fs::rename(src, target)?; + Ok(()) +} + +fn unique_archive_target(archive_dir: &Path, filename: &str) -> PathBuf { + let direct = archive_dir.join(filename); + if !direct.exists() { + return direct; + } + + let (stem, ext) = split_name(filename); + for i 
in 1..10_000 { + let candidate = if ext.is_empty() { + archive_dir.join(format!("{stem}_{i}")) + } else { + archive_dir.join(format!("{stem}_{i}.{ext}")) + }; + if !candidate.exists() { + return candidate; + } + } + + direct +} + +fn split_name(filename: &str) -> (&str, &str) { + match filename.rsplit_once('.') { + Some((stem, ext)) => (stem, ext), + None => (filename, ""), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory::{Memory, MemoryCategory, SqliteMemory}; + use tempfile::TempDir; + + fn default_cfg() -> MemoryConfig { + MemoryConfig::default() + } + + #[test] + fn archives_old_daily_memory_files() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + fs::create_dir_all(workspace.join("memory")).unwrap(); + + let old = (Local::now().date_naive() - Duration::days(10)) + .format("%Y-%m-%d") + .to_string(); + let today = Local::now().date_naive().format("%Y-%m-%d").to_string(); + + let old_file = workspace.join("memory").join(format!("{old}.md")); + let today_file = workspace.join("memory").join(format!("{today}.md")); + fs::write(&old_file, "old note").unwrap(); + fs::write(&today_file, "fresh note").unwrap(); + + run_if_due(&default_cfg(), workspace).unwrap(); + + assert!(!old_file.exists(), "old daily file should be archived"); + assert!( + workspace + .join("memory") + .join("archive") + .join(format!("{old}.md")) + .exists(), + "old daily file should exist in memory/archive" + ); + assert!(today_file.exists(), "today file should remain in place"); + } + + #[test] + fn archives_old_session_files() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + fs::create_dir_all(workspace.join("sessions")).unwrap(); + + let old = (Local::now().date_naive() - Duration::days(10)) + .format("%Y-%m-%d") + .to_string(); + let old_name = format!("{old}-agent.log"); + let old_file = workspace.join("sessions").join(&old_name); + fs::write(&old_file, "old session").unwrap(); + + run_if_due(&default_cfg(), 
workspace).unwrap(); + + assert!(!old_file.exists(), "old session file should be archived"); + assert!( + workspace + .join("sessions") + .join("archive") + .join(&old_name) + .exists(), + "archived session file should exist" + ); + } + + #[test] + fn skips_second_run_within_cadence_window() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + fs::create_dir_all(workspace.join("memory")).unwrap(); + + let old_a = (Local::now().date_naive() - Duration::days(10)) + .format("%Y-%m-%d") + .to_string(); + let file_a = workspace.join("memory").join(format!("{old_a}.md")); + fs::write(&file_a, "first").unwrap(); + + run_if_due(&default_cfg(), workspace).unwrap(); + assert!(!file_a.exists(), "first old file should be archived"); + + let old_b = (Local::now().date_naive() - Duration::days(9)) + .format("%Y-%m-%d") + .to_string(); + let file_b = workspace.join("memory").join(format!("{old_b}.md")); + fs::write(&file_b, "second").unwrap(); + + // Should skip because cadence gate prevents a second immediate run. 
+ run_if_due(&default_cfg(), workspace).unwrap(); + assert!( + file_b.exists(), + "second file should remain because run is throttled" + ); + } + + #[test] + fn purges_old_memory_archives() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + let archive_dir = workspace.join("memory").join("archive"); + fs::create_dir_all(&archive_dir).unwrap(); + + let old = (Local::now().date_naive() - Duration::days(40)) + .format("%Y-%m-%d") + .to_string(); + let keep = (Local::now().date_naive() - Duration::days(5)) + .format("%Y-%m-%d") + .to_string(); + + let old_file = archive_dir.join(format!("{old}.md")); + let keep_file = archive_dir.join(format!("{keep}.md")); + fs::write(&old_file, "expired").unwrap(); + fs::write(&keep_file, "recent").unwrap(); + + run_if_due(&default_cfg(), workspace).unwrap(); + + assert!(!old_file.exists(), "old archived file should be purged"); + assert!(keep_file.exists(), "recent archived file should remain"); + } + + #[tokio::test] + async fn prunes_old_conversation_rows_in_sqlite_backend() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + + let mem = SqliteMemory::new(workspace).unwrap(); + mem.store("conv_old", "outdated", MemoryCategory::Conversation) + .await + .unwrap(); + mem.store("core_keep", "durable", MemoryCategory::Core) + .await + .unwrap(); + drop(mem); + + let db_path = workspace.join("memory").join("brain.db"); + let conn = Connection::open(&db_path).unwrap(); + let old_cutoff = (Local::now() - Duration::days(60)).to_rfc3339(); + conn.execute( + "UPDATE memories SET created_at = ?1, updated_at = ?1 WHERE key = 'conv_old'", + params![old_cutoff], + ) + .unwrap(); + drop(conn); + + let mut cfg = default_cfg(); + cfg.archive_after_days = 0; + cfg.purge_after_days = 0; + cfg.conversation_retention_days = 30; + + run_if_due(&cfg, workspace).unwrap(); + + let mem2 = SqliteMemory::new(workspace).unwrap(); + assert!( + mem2.get("conv_old").await.unwrap().is_none(), + "old conversation rows should 
be pruned" + ); + assert!( + mem2.get("core_keep").await.unwrap().is_some(), + "core memory should remain" + ); + } +} diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 249670b..66912ca 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -1,5 +1,6 @@ pub mod chunker; pub mod embeddings; +pub mod hygiene; pub mod markdown; pub mod sqlite; pub mod traits; @@ -21,6 +22,11 @@ pub fn create_memory( workspace_dir: &Path, api_key: Option<&str>, ) -> anyhow::Result> { + // Best-effort memory hygiene/retention pass (throttled by state file). + if let Err(e) = hygiene::run_if_due(config, workspace_dir) { + tracing::warn!("memory hygiene skipped: {e}"); + } + match config.backend.as_str() { "sqlite" => { let embedder: Arc = diff --git a/src/migration.rs b/src/migration.rs new file mode 100644 index 0000000..ed160c7 --- /dev/null +++ b/src/migration.rs @@ -0,0 +1,553 @@ +use crate::config::Config; +use crate::memory::{MarkdownMemory, Memory, MemoryCategory, SqliteMemory}; +use anyhow::{bail, Context, Result}; +use directories::UserDirs; +use rusqlite::{Connection, OpenFlags, OptionalExtension}; +use std::collections::HashSet; +use std::fs; +use std::path::{Path, PathBuf}; + +#[derive(Debug, Clone)] +struct SourceEntry { + key: String, + content: String, + category: MemoryCategory, +} + +#[derive(Debug, Default)] +struct MigrationStats { + from_sqlite: usize, + from_markdown: usize, + imported: usize, + skipped_unchanged: usize, + renamed_conflicts: usize, +} + +pub async fn handle_command(command: super::MigrateCommands, config: &Config) -> Result<()> { + match command { + super::MigrateCommands::Openclaw { source, dry_run } => { + migrate_openclaw_memory(config, source, dry_run).await + } + } +} + +async fn migrate_openclaw_memory( + config: &Config, + source_workspace: Option, + dry_run: bool, +) -> Result<()> { + let source_workspace = resolve_openclaw_workspace(source_workspace)?; + if !source_workspace.exists() { + bail!( + "OpenClaw workspace not found at 
{}. Pass --source if needed.", + source_workspace.display() + ); + } + + if paths_equal(&source_workspace, &config.workspace_dir) { + bail!("Source workspace matches current ZeroClaw workspace; refusing self-migration"); + } + + let mut stats = MigrationStats::default(); + let entries = collect_source_entries(&source_workspace, &mut stats)?; + + if entries.is_empty() { + println!( + "No importable memory found in {}", + source_workspace.display() + ); + println!("Checked for: memory/brain.db, MEMORY.md, memory/*.md"); + return Ok(()); + } + + if dry_run { + println!("🔎 Dry run: OpenClaw migration preview"); + println!(" Source: {}", source_workspace.display()); + println!(" Target: {}", config.workspace_dir.display()); + println!(" Candidates: {}", entries.len()); + println!(" - from sqlite: {}", stats.from_sqlite); + println!(" - from markdown: {}", stats.from_markdown); + println!(); + println!("Run without --dry-run to import these entries."); + return Ok(()); + } + + if let Some(backup_dir) = backup_target_memory(&config.workspace_dir)? { + println!("🛟 Backup created: {}", backup_dir.display()); + } + + let memory = target_memory_backend(config)?; + + for (idx, entry) in entries.into_iter().enumerate() { + let mut key = entry.key.trim().to_string(); + if key.is_empty() { + key = format!("openclaw_{idx}"); + } + + if let Some(existing) = memory.get(&key).await? 
{ + if existing.content.trim() == entry.content.trim() { + stats.skipped_unchanged += 1; + continue; + } + + let renamed = next_available_key(memory.as_ref(), &key).await?; + key = renamed; + stats.renamed_conflicts += 1; + } + + memory.store(&key, &entry.content, entry.category).await?; + stats.imported += 1; + } + + println!("✅ OpenClaw memory migration complete"); + println!(" Source: {}", source_workspace.display()); + println!(" Target: {}", config.workspace_dir.display()); + println!(" Imported: {}", stats.imported); + println!(" Skipped unchanged:{}", stats.skipped_unchanged); + println!(" Renamed conflicts:{}", stats.renamed_conflicts); + println!(" Source sqlite rows:{}", stats.from_sqlite); + println!(" Source markdown: {}", stats.from_markdown); + + Ok(()) +} + +fn target_memory_backend(config: &Config) -> Result> { + match config.memory.backend.as_str() { + "sqlite" => Ok(Box::new(SqliteMemory::new(&config.workspace_dir)?)), + "markdown" | "none" => Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))), + other => { + tracing::warn!( + "Unknown memory backend '{other}' during migration, defaulting to markdown" + ); + Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))) + } + } +} + +fn collect_source_entries( + source_workspace: &Path, + stats: &mut MigrationStats, +) -> Result> { + let mut entries = Vec::new(); + + let sqlite_path = source_workspace.join("memory").join("brain.db"); + let sqlite_entries = read_openclaw_sqlite_entries(&sqlite_path)?; + stats.from_sqlite = sqlite_entries.len(); + entries.extend(sqlite_entries); + + let markdown_entries = read_openclaw_markdown_entries(source_workspace)?; + stats.from_markdown = markdown_entries.len(); + entries.extend(markdown_entries); + + // De-dup exact duplicates to make re-runs deterministic. 
+ let mut seen = HashSet::new(); + entries.retain(|entry| { + let sig = format!("{}\u{0}{}\u{0}{}", entry.key, entry.content, entry.category); + seen.insert(sig) + }); + + Ok(entries) +} + +fn read_openclaw_sqlite_entries(db_path: &Path) -> Result> { + if !db_path.exists() { + return Ok(Vec::new()); + } + + let conn = Connection::open_with_flags(db_path, OpenFlags::SQLITE_OPEN_READ_ONLY) + .with_context(|| format!("Failed to open source db {}", db_path.display()))?; + + let table_exists: Option = conn + .query_row( + "SELECT name FROM sqlite_master WHERE type='table' AND name='memories' LIMIT 1", + [], + |row| row.get(0), + ) + .optional()?; + + if table_exists.is_none() { + return Ok(Vec::new()); + } + + let columns = table_columns(&conn, "memories")?; + let key_expr = pick_column_expr(&columns, &["key", "id", "name"], "CAST(rowid AS TEXT)"); + let Some(content_expr) = + pick_optional_column_expr(&columns, &["content", "value", "text", "memory"]) + else { + bail!("OpenClaw memories table found but no content-like column was detected"); + }; + let category_expr = pick_column_expr(&columns, &["category", "kind", "type"], "'core'"); + + let sql = format!( + "SELECT {key_expr} AS key, {content_expr} AS content, {category_expr} AS category FROM memories" + ); + + let mut stmt = conn.prepare(&sql)?; + let mut rows = stmt.query([])?; + + let mut entries = Vec::new(); + let mut idx = 0_usize; + + while let Some(row) = rows.next()? 
{ + let key: String = row + .get(0) + .unwrap_or_else(|_| format!("openclaw_sqlite_{idx}")); + let content: String = row.get(1).unwrap_or_default(); + let category_raw: String = row.get(2).unwrap_or_else(|_| "core".to_string()); + + if content.trim().is_empty() { + continue; + } + + entries.push(SourceEntry { + key: normalize_key(&key, idx), + content: content.trim().to_string(), + category: parse_category(&category_raw), + }); + + idx += 1; + } + + Ok(entries) +} + +fn read_openclaw_markdown_entries(source_workspace: &Path) -> Result> { + let mut all = Vec::new(); + + let core_path = source_workspace.join("MEMORY.md"); + if core_path.exists() { + let content = fs::read_to_string(&core_path)?; + all.extend(parse_markdown_file( + &core_path, + &content, + MemoryCategory::Core, + "openclaw_core", + )); + } + + let daily_dir = source_workspace.join("memory"); + if daily_dir.exists() { + for file in fs::read_dir(&daily_dir)? { + let file = file?; + let path = file.path(); + if path.extension().and_then(|ext| ext.to_str()) != Some("md") { + continue; + } + let content = fs::read_to_string(&path)?; + let stem = path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("openclaw_daily"); + all.extend(parse_markdown_file( + &path, + &content, + MemoryCategory::Daily, + stem, + )); + } + } + + Ok(all) +} + +fn parse_markdown_file( + _path: &Path, + content: &str, + default_category: MemoryCategory, + stem: &str, +) -> Vec { + let mut entries = Vec::new(); + + for (idx, raw_line) in content.lines().enumerate() { + let trimmed = raw_line.trim(); + if trimmed.is_empty() || trimmed.starts_with('#') { + continue; + } + + let line = trimmed.strip_prefix("- ").unwrap_or(trimmed); + let (key, text) = match parse_structured_memory_line(line) { + Some((k, v)) => (normalize_key(k, idx), v.trim().to_string()), + None => ( + format!("openclaw_{stem}_{}", idx + 1), + line.trim().to_string(), + ), + }; + + if text.is_empty() { + continue; + } + + entries.push(SourceEntry { + key, + 
content: text, + category: default_category.clone(), + }); + } + + entries +} + +fn parse_structured_memory_line(line: &str) -> Option<(&str, &str)> { + if !line.starts_with("**") { + return None; + } + + let rest = line.strip_prefix("**")?; + let key_end = rest.find("**:")?; + let key = rest.get(..key_end)?.trim(); + let value = rest.get(key_end + 3..)?.trim(); + + if key.is_empty() || value.is_empty() { + return None; + } + + Some((key, value)) +} + +fn parse_category(raw: &str) -> MemoryCategory { + match raw.trim().to_ascii_lowercase().as_str() { + "core" => MemoryCategory::Core, + "daily" => MemoryCategory::Daily, + "conversation" => MemoryCategory::Conversation, + "" => MemoryCategory::Core, + other => MemoryCategory::Custom(other.to_string()), + } +} + +fn normalize_key(key: &str, fallback_idx: usize) -> String { + let trimmed = key.trim(); + if trimmed.is_empty() { + return format!("openclaw_{fallback_idx}"); + } + trimmed.to_string() +} + +async fn next_available_key(memory: &dyn Memory, base: &str) -> Result { + for i in 1..=10_000 { + let candidate = format!("{base}__openclaw_{i}"); + if memory.get(&candidate).await?.is_none() { + return Ok(candidate); + } + } + + bail!("Unable to allocate non-conflicting key for '{base}'") +} + +fn table_columns(conn: &Connection, table: &str) -> Result> { + let pragma = format!("PRAGMA table_info({table})"); + let mut stmt = conn.prepare(&pragma)?; + let rows = stmt.query_map([], |row| row.get::<_, String>(1))?; + + let mut cols = Vec::new(); + for col in rows { + cols.push(col?.to_ascii_lowercase()); + } + + Ok(cols) +} + +fn pick_optional_column_expr(columns: &[String], candidates: &[&str]) -> Option { + candidates + .iter() + .find(|candidate| columns.iter().any(|c| c == *candidate)) + .map(|s| s.to_string()) +} + +fn pick_column_expr(columns: &[String], candidates: &[&str], fallback: &str) -> String { + pick_optional_column_expr(columns, candidates).unwrap_or_else(|| fallback.to_string()) +} + +fn 
resolve_openclaw_workspace(source: Option) -> Result { + if let Some(src) = source { + return Ok(src); + } + + let home = UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + + Ok(home.join(".openclaw").join("workspace")) +} + +fn paths_equal(a: &Path, b: &Path) -> bool { + match (fs::canonicalize(a), fs::canonicalize(b)) { + (Ok(a), Ok(b)) => a == b, + _ => a == b, + } +} + +fn backup_target_memory(workspace_dir: &Path) -> Result> { + let timestamp = chrono::Local::now().format("%Y%m%d-%H%M%S").to_string(); + let backup_root = workspace_dir + .join("memory") + .join("migrations") + .join(format!("openclaw-{timestamp}")); + + let mut copied_any = false; + fs::create_dir_all(&backup_root)?; + + let files_to_copy = [ + workspace_dir.join("memory").join("brain.db"), + workspace_dir.join("MEMORY.md"), + ]; + + for source in files_to_copy { + if source.exists() { + let Some(name) = source.file_name() else { + continue; + }; + fs::copy(&source, backup_root.join(name))?; + copied_any = true; + } + } + + let daily_dir = workspace_dir.join("memory"); + if daily_dir.exists() { + let daily_backup = backup_root.join("daily"); + for file in fs::read_dir(&daily_dir)? 
{ + let file = file?; + let path = file.path(); + if path.extension().and_then(|ext| ext.to_str()) != Some("md") { + continue; + } + fs::create_dir_all(&daily_backup)?; + let Some(name) = path.file_name() else { + continue; + }; + fs::copy(&path, daily_backup.join(name))?; + copied_any = true; + } + } + + if copied_any { + Ok(Some(backup_root)) + } else { + let _ = fs::remove_dir_all(&backup_root); + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{Config, MemoryConfig}; + use rusqlite::params; + use tempfile::TempDir; + + fn test_config(workspace: &Path) -> Config { + Config { + workspace_dir: workspace.to_path_buf(), + config_path: workspace.join("config.toml"), + memory: MemoryConfig { + backend: "sqlite".to_string(), + ..MemoryConfig::default() + }, + ..Config::default() + } + } + + #[test] + fn parse_structured_markdown_line() { + let line = "**user_pref**: likes Rust"; + let parsed = parse_structured_memory_line(line).unwrap(); + assert_eq!(parsed.0, "user_pref"); + assert_eq!(parsed.1, "likes Rust"); + } + + #[test] + fn parse_unstructured_markdown_generates_key() { + let entries = parse_markdown_file( + Path::new("/tmp/MEMORY.md"), + "- plain note", + MemoryCategory::Core, + "core", + ); + assert_eq!(entries.len(), 1); + assert!(entries[0].key.starts_with("openclaw_core_")); + assert_eq!(entries[0].content, "plain note"); + } + + #[test] + fn sqlite_reader_supports_legacy_value_column() { + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("brain.db"); + let conn = Connection::open(&db_path).unwrap(); + + conn.execute_batch("CREATE TABLE memories (key TEXT, value TEXT, type TEXT);") + .unwrap(); + conn.execute( + "INSERT INTO memories (key, value, type) VALUES (?1, ?2, ?3)", + params!["legacy_key", "legacy_value", "daily"], + ) + .unwrap(); + + let rows = read_openclaw_sqlite_entries(&db_path).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].key, "legacy_key"); + assert_eq!(rows[0].content, 
"legacy_value"); + assert_eq!(rows[0].category, MemoryCategory::Daily); + } + + #[tokio::test] + async fn migration_renames_conflicting_key() { + let source = TempDir::new().unwrap(); + let target = TempDir::new().unwrap(); + + // Existing target memory + let target_mem = SqliteMemory::new(target.path()).unwrap(); + target_mem + .store("k", "new value", MemoryCategory::Core) + .await + .unwrap(); + + // Source sqlite with conflicting key + different content + let source_db_dir = source.path().join("memory"); + fs::create_dir_all(&source_db_dir).unwrap(); + let source_db = source_db_dir.join("brain.db"); + let conn = Connection::open(&source_db).unwrap(); + conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);") + .unwrap(); + conn.execute( + "INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)", + params!["k", "old value", "core"], + ) + .unwrap(); + + let config = test_config(target.path()); + migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), false) + .await + .unwrap(); + + let all = target_mem.list(None).await.unwrap(); + assert!(all.iter().any(|e| e.key == "k" && e.content == "new value")); + assert!(all + .iter() + .any(|e| e.key.starts_with("k__openclaw_") && e.content == "old value")); + } + + #[tokio::test] + async fn dry_run_does_not_write() { + let source = TempDir::new().unwrap(); + let target = TempDir::new().unwrap(); + let source_db_dir = source.path().join("memory"); + fs::create_dir_all(&source_db_dir).unwrap(); + + let source_db = source_db_dir.join("brain.db"); + let conn = Connection::open(&source_db).unwrap(); + conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);") + .unwrap(); + conn.execute( + "INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)", + params!["dry", "run", "core"], + ) + .unwrap(); + + let config = test_config(target.path()); + migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), true) + .await + .unwrap(); + 
+ let target_mem = SqliteMemory::new(target.path()).unwrap(); + assert_eq!(target_mem.count().await.unwrap(), 0); + } +} diff --git a/src/onboard/mod.rs b/src/onboard/mod.rs index 0f16b88..a18ce8a 100644 --- a/src/onboard/mod.rs +++ b/src/onboard/mod.rs @@ -1,3 +1,3 @@ pub mod wizard; -pub use wizard::{run_quick_setup, run_wizard}; +pub use wizard::{run_channels_repair_wizard, run_quick_setup, run_wizard}; diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 0153cbd..b4e69ce 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -91,6 +91,7 @@ pub fn run_wizard() -> Result { observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), + reliability: crate::config::ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config, memory: MemoryConfig::default(), // SQLite + auto-save by default @@ -149,6 +150,61 @@ pub fn run_wizard() -> Result { Ok(config) } +/// Interactive repair flow: rerun channel setup only without redoing full onboarding. +pub fn run_channels_repair_wizard() -> Result { + println!("{}", style(BANNER).cyan().bold()); + println!( + " {}", + style("Channels Repair — update channel tokens and allowlists only") + .white() + .bold() + ); + println!(); + + let mut config = Config::load_or_init()?; + + print_step(1, 1, "Channels (How You Talk to ZeroClaw)"); + config.channels_config = setup_channels()?; + config.save()?; + + println!(); + println!( + " {} Channel config saved: {}", + style("✓").green().bold(), + style(config.config_path.display()).green() + ); + + let has_channels = config.channels_config.telegram.is_some() + || config.channels_config.discord.is_some() + || config.channels_config.slack.is_some() + || config.channels_config.imessage.is_some() + || config.channels_config.matrix.is_some(); + + if has_channels && config.api_key.is_some() { + let launch: bool = Confirm::new() + .with_prompt(format!( + " {} Launch channels now? 
(connected channels → AI → reply)", + style("🚀").cyan() + )) + .default(true) + .interact()?; + + if launch { + println!(); + println!( + " {} {}", + style("⚡").cyan(), + style("Starting channel server...").white().bold() + ); + println!(); + // Signal to main.rs to call start_channels after wizard returns + std::env::set_var("ZEROCLAW_AUTOSTART_CHANNELS", "1"); + } + } + + Ok(config) +} + // ── Quick setup (zero prompts) ─────────────────────────────────── /// Non-interactive setup: generates a sensible default config instantly. @@ -187,6 +243,7 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result< observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), + reliability: crate::config::ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), @@ -204,7 +261,9 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result< user_name: std::env::var("USER").unwrap_or_else(|_| "User".into()), timezone: "UTC".into(), agent_name: "ZeroClaw".into(), - communication_style: "Direct and concise".into(), + communication_style: + "Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing." 
+ .into(), }; scaffold_workspace(&workspace_dir, &default_ctx)?; @@ -824,24 +883,33 @@ fn setup_project_context() -> Result { let style_options = vec![ "Direct & concise — skip pleasantries, get to the point", - "Friendly & casual — warm but efficient", + "Friendly & casual — warm, human, and helpful", + "Professional & polished — calm, confident, and clear", + "Expressive & playful — more personality + natural emojis", "Technical & detailed — thorough explanations, code-first", "Balanced — adapt to the situation", + "Custom — write your own style guide", ]; let style_idx = Select::new() .with_prompt(" Communication style") .items(&style_options) - .default(0) + .default(1) .interact()?; let communication_style = match style_idx { 0 => "Be direct and concise. Skip pleasantries. Get to the point.".to_string(), - 1 => "Be friendly and casual. Warm but efficient.".to_string(), - 2 => "Be technical and detailed. Thorough explanations, code-first.".to_string(), - _ => { - "Adapt to the situation. Be concise when needed, thorough when it matters.".to_string() - } + 1 => "Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions.".to_string(), + 2 => "Be professional and polished. Stay calm, structured, and respectful. Use occasional tone-setting emojis only when appropriate.".to_string(), + 3 => "Be expressive and playful when appropriate. Use relevant emojis naturally (0-2 max), and keep serious topics emoji-light.".to_string(), + 4 => "Be technical and detailed. Thorough explanations, code-first.".to_string(), + 5 => "Adapt to the situation. Default to warm and clear communication; be concise when needed, thorough when it matters.".to_string(), + _ => Input::new() + .with_prompt(" Custom communication style") + .default( + "Be warm, natural, and clear. 
Use occasional relevant emojis (1-2 max) and avoid robotic phrasing.".into(), + ) + .interact_text()?, }; println!( @@ -987,17 +1055,38 @@ fn setup_channels() -> Result { } } + print_bullet( + "Allowlist your own Telegram identity first (recommended for secure + fast setup).", + ); + print_bullet( + "Use your @username without '@' (example: argenis), or your numeric Telegram user ID.", + ); + print_bullet("Use '*' only for temporary open testing."); + let users_str: String = Input::new() - .with_prompt(" Allowed usernames (comma-separated, or * for all)") - .default("*".into()) + .with_prompt( + " Allowed Telegram identities (comma-separated: username without '@' and/or numeric user ID, '*' for all)", + ) + .allow_empty(true) .interact_text()?; let allowed_users = if users_str.trim() == "*" { vec!["*".into()] } else { - users_str.split(',').map(|s| s.trim().to_string()).collect() + users_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() }; + if allowed_users.is_empty() { + println!( + " {} No users allowlisted — Telegram inbound messages will be denied until you add your username/user ID or '*'.", + style("⚠").yellow().bold() + ); + } + config.telegram = Some(TelegramConfig { bot_token: token, allowed_users, @@ -1057,9 +1146,15 @@ fn setup_channels() -> Result { .allow_empty(true) .interact_text()?; + print_bullet("Allowlist your own Discord user ID first (recommended)."); + print_bullet( + "Get it in Discord: Settings -> Advanced -> Developer Mode (ON), then right-click your profile -> Copy User ID.", + ); + print_bullet("Use '*' only for temporary open testing."); + let allowed_users_str: String = Input::new() .with_prompt( - " Allowed Discord user IDs (comma-separated, '*' for all, Enter to deny all)", + " Allowed Discord user IDs (comma-separated, recommended: your own ID, '*' for all)", ) .allow_empty(true) .interact_text()?; @@ -1160,9 +1255,15 @@ fn setup_channels() -> Result { .allow_empty(true) .interact_text()?; 
+ print_bullet("Allowlist your own Slack member ID first (recommended)."); + print_bullet( + "Member IDs usually start with 'U' (open your Slack profile -> More -> Copy member ID).", + ); + print_bullet("Use '*' only for temporary open testing."); + let allowed_users_str: String = Input::new() .with_prompt( - " Allowed Slack user IDs (comma-separated, '*' for all, Enter to deny all)", + " Allowed Slack user IDs (comma-separated, recommended: your own member ID, '*' for all)", ) .allow_empty(true) .interact_text()?; @@ -1564,7 +1665,7 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> &ctx.timezone }; let comm_style = if ctx.communication_style.is_empty() { - "Adapt to the situation. Be concise when needed, thorough when it matters." + "Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing." } else { &ctx.communication_style }; @@ -1613,6 +1714,14 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> ## Tools & Skills\n\n\ Skills are listed in the system prompt. Use `read` on a skill's SKILL.md for details.\n\ Keep local notes (SSH hosts, device names, etc.) in `TOOLS.md`.\n\n\ + ## Crash Recovery\n\n\ + - If a run stops unexpectedly, recover context before acting.\n\ + - Check `MEMORY.md` + latest `memory/*.md` notes to avoid duplicate work.\n\ + - Resume from the last confirmed step, not from scratch.\n\n\ + ## Sub-task Scoping\n\n\ + - Break complex work into focused sub-tasks with clear success criteria.\n\ + - Keep sub-tasks small, verify each output, then merge results.\n\ + - Prefer one clear objective per sub-task over broad \"do everything\" asks.\n\n\ ## Make It Yours\n\n\ This is a starting point. 
Add your own conventions, style, and rules.\n" ); @@ -1650,6 +1759,11 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> - Always introduce yourself as {agent} if asked\n\n\ ## Communication\n\n\ {comm_style}\n\n\ + - Sound like a real person, not a support script.\n\ + - Mirror the user's energy: calm when serious, upbeat when casual.\n\ + - Use emojis naturally (0-2 max when they help tone, not every sentence).\n\ + - Match emoji density to the user. Formal user => minimal/no emojis.\n\ + - Prefer specific, grounded phrasing over generic filler.\n\n\ ## Boundaries\n\n\ - Private things stay private. Period.\n\ - When in doubt, ask before acting externally.\n\ @@ -1690,11 +1804,23 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> - Anything environment-specific\n\n\ ## Built-in Tools\n\n\ - **shell** — Execute terminal commands\n\ + - Use when: running local checks, build/test commands, or diagnostics.\n\ + - Don't use when: a safer dedicated tool exists, or command is destructive without approval.\n\ - **file_read** — Read file contents\n\ + - Use when: inspecting project files, configs, or logs.\n\ + - Don't use when: you only need a quick string search (prefer targeted search first).\n\ - **file_write** — Write file contents\n\ + - Use when: applying focused edits, scaffolding files, or updating docs/code.\n\ + - Don't use when: unsure about side effects or when the file should remain user-owned.\n\ - **memory_store** — Save to memory\n\ + - Use when: preserving durable preferences, decisions, or key context.\n\ + - Don't use when: info is transient, noisy, or sensitive without explicit need.\n\ - **memory_recall** — Search memory\n\ - - **memory_forget** — Delete a memory entry\n\n\ + - Use when: you need prior decisions, user preferences, or historical context.\n\ + - Don't use when: the answer is already in current files/conversation.\n\ + - **memory_forget** — Delete a memory entry\n\ + - Use 
when: memory is incorrect, stale, or explicitly requested to be removed.\n\ + - Don't use when: uncertain about impact; verify before deleting.\n\n\ ---\n\ *Add whatever helps you do your job. This is your cheat sheet.*\n"; @@ -2188,7 +2314,7 @@ mod tests { let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); assert!( - soul.contains("Adapt to the situation"), + soul.contains("Be warm, natural, and clear."), "should default communication style" ); } @@ -2329,6 +2455,31 @@ mod tests { "TOOLS.md should list built-in tool: {tool}" ); } + assert!( + tools.contains("Use when:"), + "TOOLS.md should include 'Use when' guidance" + ); + assert!( + tools.contains("Don't use when:"), + "TOOLS.md should include 'Don't use when' guidance" + ); + } + + #[test] + fn soul_md_includes_emoji_awareness_guidance() { + let tmp = TempDir::new().unwrap(); + let ctx = ProjectContext::default(); + scaffold_workspace(tmp.path(), &ctx).unwrap(); + + let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + assert!( + soul.contains("Use emojis naturally (0-2 max"), + "SOUL.md should include emoji usage guidance" + ); + assert!( + soul.contains("Match emoji density to the user"), + "SOUL.md should include emoji-awareness guidance" + ); } // ── scaffold_workspace: special characters in names ───────── @@ -2360,7 +2511,9 @@ mod tests { user_name: "Argenis".into(), timezone: "US/Eastern".into(), agent_name: "Claw".into(), - communication_style: "Be friendly and casual. Warm but efficient.".into(), + communication_style: + "Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions." 
+ .into(), }; scaffold_workspace(tmp.path(), &ctx).unwrap(); @@ -2370,12 +2523,12 @@ mod tests { let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); assert!(soul.contains("You are **Claw**")); - assert!(soul.contains("Be friendly and casual")); + assert!(soul.contains("Be friendly, human, and conversational")); let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); assert!(user_md.contains("**Name:** Argenis")); assert!(user_md.contains("**Timezone:** US/Eastern")); - assert!(user_md.contains("Be friendly and casual")); + assert!(user_md.contains("Be friendly, human, and conversational")); let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap(); assert!(agents.contains("Claw Personal Assistant")); diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 83c5392..09a24ff 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -3,11 +3,13 @@ pub mod compatible; pub mod ollama; pub mod openai; pub mod openrouter; +pub mod reliable; pub mod traits; pub use traits::Provider; use compatible::{AuthStyle, OpenAiCompatibleProvider}; +use reliable::ReliableProvider; /// Factory: create the right provider from config #[allow(clippy::too_many_lines)] @@ -110,6 +112,42 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result, + reliability: &crate::config::ReliabilityConfig, +) -> anyhow::Result> { + let mut providers: Vec<(String, Box)> = Vec::new(); + + providers.push(( + primary_name.to_string(), + create_provider(primary_name, api_key)?, + )); + + for fallback in &reliability.fallback_providers { + if fallback == primary_name || providers.iter().any(|(name, _)| name == fallback) { + continue; + } + + match create_provider(fallback, api_key) { + Ok(provider) => providers.push((fallback.clone(), provider)), + Err(e) => { + tracing::warn!( + fallback_provider = fallback, + "Ignoring invalid fallback provider: {e}" + ); + } + } + } + + Ok(Box::new(ReliableProvider::new( + providers, + 
reliability.provider_retries, + reliability.provider_backoff_ms, + ))) +} + #[cfg(test)] mod tests { use super::*; @@ -294,6 +332,34 @@ mod tests { assert!(create_provider("", None).is_err()); } + #[test] + fn resilient_provider_ignores_duplicate_and_invalid_fallbacks() { + let reliability = crate::config::ReliabilityConfig { + provider_retries: 1, + provider_backoff_ms: 100, + fallback_providers: vec![ + "openrouter".into(), + "nonexistent-provider".into(), + "openai".into(), + "openai".into(), + ], + channel_initial_backoff_secs: 2, + channel_max_backoff_secs: 60, + scheduler_poll_secs: 15, + scheduler_retries: 2, + }; + + let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability); + assert!(provider.is_ok()); + } + + #[test] + fn resilient_provider_errors_for_invalid_primary() { + let reliability = crate::config::ReliabilityConfig::default(); + let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability); + assert!(provider.is_err()); + } + #[test] fn factory_all_providers_create_successfully() { let providers = [ diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs new file mode 100644 index 0000000..c324f21 --- /dev/null +++ b/src/providers/reliable.rs @@ -0,0 +1,229 @@ +use super::Provider; +use async_trait::async_trait; +use std::time::Duration; + +/// Provider wrapper with retry + fallback behavior. 
+pub struct ReliableProvider {
+    providers: Vec<(String, Box<dyn Provider>)>,
+    max_retries: u32,
+    base_backoff_ms: u64,
+}
+
+impl ReliableProvider {
+    pub fn new(
+        providers: Vec<(String, Box<dyn Provider>)>,
+        max_retries: u32,
+        base_backoff_ms: u64,
+    ) -> Self {
+        Self {
+            providers,
+            max_retries,
+            base_backoff_ms: base_backoff_ms.max(50),
+        }
+    }
+}
+
+#[async_trait]
+impl Provider for ReliableProvider {
+    async fn chat_with_system(
+        &self,
+        system_prompt: Option<&str>,
+        message: &str,
+        model: &str,
+        temperature: f64,
+    ) -> anyhow::Result<String> {
+        let mut failures = Vec::new();
+
+        for (provider_name, provider) in &self.providers {
+            let mut backoff_ms = self.base_backoff_ms;
+
+            for attempt in 0..=self.max_retries {
+                match provider
+                    .chat_with_system(system_prompt, message, model, temperature)
+                    .await
+                {
+                    Ok(resp) => {
+                        if attempt > 0 {
+                            tracing::info!(
+                                provider = provider_name,
+                                attempt,
+                                "Provider recovered after retries"
+                            );
+                        }
+                        return Ok(resp);
+                    }
+                    Err(e) => {
+                        failures.push(format!(
+                            "{provider_name} attempt {}/{}: {e}",
+                            attempt + 1,
+                            self.max_retries + 1
+                        ));
+
+                        if attempt < self.max_retries {
+                            tracing::warn!(
+                                provider = provider_name,
+                                attempt = attempt + 1,
+                                max_retries = self.max_retries,
+                                "Provider call failed, retrying"
+                            );
+                            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
+                            backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
+                        }
+                    }
+                }
+            }
+
+            tracing::warn!(provider = provider_name, "Switching to fallback provider");
+        }
+
+        anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::sync::atomic::{AtomicUsize, Ordering};
+    use std::sync::Arc;
+
+    struct MockProvider {
+        calls: Arc<AtomicUsize>,
+        fail_until_attempt: usize,
+        response: &'static str,
+        error: &'static str,
+    }
+
+    #[async_trait]
+    impl Provider for MockProvider {
+        async fn chat_with_system(
+            &self,
+            _system_prompt: Option<&str>,
+            _message: &str,
+            _model: &str,
+            _temperature: f64,
+        ) -> anyhow::Result<String> {
+            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
+            if attempt <= self.fail_until_attempt {
+                anyhow::bail!(self.error);
+            }
+            Ok(self.response.to_string())
+        }
+    }
+
+    #[tokio::test]
+    async fn succeeds_without_retry() {
+        let calls = Arc::new(AtomicUsize::new(0));
+        let provider = ReliableProvider::new(
+            vec![(
+                "primary".into(),
+                Box::new(MockProvider {
+                    calls: Arc::clone(&calls),
+                    fail_until_attempt: 0,
+                    response: "ok",
+                    error: "boom",
+                }),
+            )],
+            2,
+            1,
+        );
+
+        let result = provider.chat("hello", "test", 0.0).await.unwrap();
+        assert_eq!(result, "ok");
+        assert_eq!(calls.load(Ordering::SeqCst), 1);
+    }
+
+    #[tokio::test]
+    async fn retries_then_recovers() {
+        let calls = Arc::new(AtomicUsize::new(0));
+        let provider = ReliableProvider::new(
+            vec![(
+                "primary".into(),
+                Box::new(MockProvider {
+                    calls: Arc::clone(&calls),
+                    fail_until_attempt: 1,
+                    response: "recovered",
+                    error: "temporary",
+                }),
+            )],
+            2,
+            1,
+        );
+
+        let result = provider.chat("hello", "test", 0.0).await.unwrap();
+        assert_eq!(result, "recovered");
+        assert_eq!(calls.load(Ordering::SeqCst), 2);
+    }
+
+    #[tokio::test]
+    async fn falls_back_after_retries_exhausted() {
+        let primary_calls = Arc::new(AtomicUsize::new(0));
+        let fallback_calls = Arc::new(AtomicUsize::new(0));
+
+        let provider = ReliableProvider::new(
+            vec![
+                (
+                    "primary".into(),
+                    Box::new(MockProvider {
+                        calls: Arc::clone(&primary_calls),
+                        fail_until_attempt: usize::MAX,
+                        response: "never",
+                        error: "primary down",
}), + ), + ( + "fallback".into(), + Box::new(MockProvider { + calls: Arc::clone(&fallback_calls), + fail_until_attempt: 0, + response: "from fallback", + error: "fallback down", + }), + ), + ], + 1, + 1, + ); + + let result = provider.chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "from fallback"); + assert_eq!(primary_calls.load(Ordering::SeqCst), 2); + assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn returns_aggregated_error_when_all_providers_fail() { + let provider = ReliableProvider::new( + vec![ + ( + "p1".into(), + Box::new(MockProvider { + calls: Arc::new(AtomicUsize::new(0)), + fail_until_attempt: usize::MAX, + response: "never", + error: "p1 error", + }), + ), + ( + "p2".into(), + Box::new(MockProvider { + calls: Arc::new(AtomicUsize::new(0)), + fail_until_attempt: usize::MAX, + response: "never", + error: "p2 error", + }), + ), + ], + 0, + 1, + ); + + let err = provider + .chat("hello", "test", 0.0) + .await + .expect_err("all providers should fail"); + let msg = err.to_string(); + assert!(msg.contains("All providers failed")); + assert!(msg.contains("p1 attempt 1/1")); + assert!(msg.contains("p2 attempt 1/1")); + } +} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index cb8abd5..9ed0ee0 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -7,17 +7,21 @@ pub use traits::RuntimeAdapter; use crate::config::RuntimeConfig; /// Factory: create the right runtime from config -pub fn create_runtime(config: &RuntimeConfig) -> Box { +pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result> { match config.kind.as_str() { - "native" | "docker" => Box::new(NativeRuntime::new()), - "cloudflare" => { - tracing::warn!("Cloudflare runtime not yet implemented, falling back to native"); - Box::new(NativeRuntime::new()) - } - _ => { - tracing::warn!("Unknown runtime '{}', falling back to native", config.kind); - Box::new(NativeRuntime::new()) - } + "native" => Ok(Box::new(NativeRuntime::new())), + 
"docker" => anyhow::bail!( + "runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands." + ), + "cloudflare" => anyhow::bail!( + "runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now." + ), + other if other.trim().is_empty() => anyhow::bail!( + "runtime.kind cannot be empty. Supported values: native" + ), + other => anyhow::bail!( + "Unknown runtime kind '{other}'. Supported values: native" + ), } } @@ -30,44 +34,52 @@ mod tests { let cfg = RuntimeConfig { kind: "native".into(), }; - let rt = create_runtime(&cfg); + let rt = create_runtime(&cfg).unwrap(); assert_eq!(rt.name(), "native"); assert!(rt.has_shell_access()); } #[test] - fn factory_docker_returns_native() { + fn factory_docker_errors() { let cfg = RuntimeConfig { kind: "docker".into(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("not implemented")), + Ok(_) => panic!("docker runtime should error"), + } } #[test] - fn factory_cloudflare_falls_back() { + fn factory_cloudflare_errors() { let cfg = RuntimeConfig { kind: "cloudflare".into(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("not implemented")), + Ok(_) => panic!("cloudflare runtime should error"), + } } #[test] - fn factory_unknown_falls_back() { + fn factory_unknown_errors() { let cfg = RuntimeConfig { kind: "wasm-edge-unknown".into(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("Unknown runtime kind")), + Ok(_) => panic!("unknown runtime should error"), + } } #[test] - fn factory_empty_falls_back() { + fn factory_empty_errors() { let cfg = RuntimeConfig { kind: String::new(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match 
create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("cannot be empty")), + Ok(_) => panic!("empty runtime should error"), + } } } diff --git a/src/security/policy.rs b/src/security/policy.rs index a8b160e..49d58df 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -258,8 +258,14 @@ impl SecurityPolicy { /// Validate that a resolved path is still inside the workspace. /// Call this AFTER joining `workspace_dir` + relative path and canonicalizing. pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool { - // Must be under workspace_dir (prevents symlink escapes) - resolved.starts_with(&self.workspace_dir) + // Must be under workspace_dir (prevents symlink escapes). + // Prefer canonical workspace root so `/a/../b` style config paths don't + // cause false positives or negatives. + let workspace_root = self + .workspace_dir + .canonicalize() + .unwrap_or_else(|_| self.workspace_dir.clone()); + resolved.starts_with(workspace_root) } /// Check if autonomy level permits any action at all diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 0000000..fc6bf51 --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,284 @@ +use crate::config::Config; +use anyhow::{Context, Result}; +use std::fs; +use std::path::PathBuf; +use std::process::Command; + +const SERVICE_LABEL: &str = "com.zeroclaw.daemon"; + +pub fn handle_command(command: super::ServiceCommands, config: &Config) -> Result<()> { + match command { + super::ServiceCommands::Install => install(config), + super::ServiceCommands::Start => start(config), + super::ServiceCommands::Stop => stop(config), + super::ServiceCommands::Status => status(config), + super::ServiceCommands::Uninstall => uninstall(config), + } +} + +fn install(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + install_macos(config) + } else if cfg!(target_os = "linux") { + install_linux(config) + } else { + anyhow::bail!("Service management is supported on macOS and 
Linux only"); + } +} + +fn start(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + let plist = macos_service_file()?; + run_checked(Command::new("launchctl").arg("load").arg("-w").arg(&plist))?; + run_checked(Command::new("launchctl").arg("start").arg(SERVICE_LABEL))?; + println!("✅ Service started"); + Ok(()) + } else if cfg!(target_os = "linux") { + run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]))?; + run_checked(Command::new("systemctl").args(["--user", "start", "zeroclaw.service"]))?; + println!("✅ Service started"); + Ok(()) + } else { + let _ = config; + anyhow::bail!("Service management is supported on macOS and Linux only") + } +} + +fn stop(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + let plist = macos_service_file()?; + let _ = run_checked(Command::new("launchctl").arg("stop").arg(SERVICE_LABEL)); + let _ = run_checked( + Command::new("launchctl") + .arg("unload") + .arg("-w") + .arg(&plist), + ); + println!("✅ Service stopped"); + Ok(()) + } else if cfg!(target_os = "linux") { + let _ = run_checked(Command::new("systemctl").args(["--user", "stop", "zeroclaw.service"])); + println!("✅ Service stopped"); + Ok(()) + } else { + let _ = config; + anyhow::bail!("Service management is supported on macOS and Linux only") + } +} + +fn status(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + let out = run_capture(Command::new("launchctl").arg("list"))?; + let running = out.lines().any(|line| line.contains(SERVICE_LABEL)); + println!( + "Service: {}", + if running { + "✅ running/loaded" + } else { + "❌ not loaded" + } + ); + println!("Unit: {}", macos_service_file()?.display()); + return Ok(()); + } + + if cfg!(target_os = "linux") { + let out = run_capture(Command::new("systemctl").args([ + "--user", + "is-active", + "zeroclaw.service", + ])) + .unwrap_or_else(|_| "unknown".into()); + println!("Service state: {}", out.trim()); + println!("Unit: {}", 
linux_service_file(config)?.display());
+        return Ok(());
+    }
+
+    anyhow::bail!("Service management is supported on macOS and Linux only")
+}
+
+fn uninstall(config: &Config) -> Result<()> {
+    stop(config)?;
+
+    if cfg!(target_os = "macos") {
+        let file = macos_service_file()?;
+        if file.exists() {
+            fs::remove_file(&file)
+                .with_context(|| format!("Failed to remove {}", file.display()))?;
+        }
+        println!("✅ Service uninstalled ({})", file.display());
+        return Ok(());
+    }
+
+    if cfg!(target_os = "linux") {
+        let file = linux_service_file(config)?;
+        if file.exists() {
+            fs::remove_file(&file)
+                .with_context(|| format!("Failed to remove {}", file.display()))?;
+        }
+        let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]));
+        println!("✅ Service uninstalled ({})", file.display());
+        return Ok(());
+    }
+
+    anyhow::bail!("Service management is supported on macOS and Linux only")
+}
+
+fn install_macos(config: &Config) -> Result<()> {
+    let file = macos_service_file()?;
+    if let Some(parent) = file.parent() {
+        fs::create_dir_all(parent)?;
+    }
+
+    let exe = std::env::current_exe().context("Failed to resolve current executable")?;
+    let logs_dir = config
+        .config_path
+        .parent()
+        .map_or_else(|| PathBuf::from("."), PathBuf::from)
+        .join("logs");
+    fs::create_dir_all(&logs_dir)?;
+
+    let stdout = logs_dir.join("daemon.stdout.log");
+    let stderr = logs_dir.join("daemon.stderr.log");
+
+    let plist = format!(
+        r#"<?xml version="1.0" encoding="UTF-8"?>
+<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
+<plist version="1.0">
+<dict>
+    <key>Label</key>
+    <string>{label}</string>
+    <key>ProgramArguments</key>
+    <array>
+        <string>{exe}</string>
+        <string>daemon</string>
+    </array>
+    <key>RunAtLoad</key>
+    <true/>
+    <key>KeepAlive</key>
+    <true/>
+    <key>StandardOutPath</key>
+    <string>{stdout}</string>
+    <key>StandardErrorPath</key>
+    <string>{stderr}</string>
+</dict>
+</plist>
+"#,
+        label = SERVICE_LABEL,
+        exe = xml_escape(&exe.display().to_string()),
+        stdout = xml_escape(&stdout.display().to_string()),
+        stderr = xml_escape(&stderr.display().to_string())
+    );
+
+    fs::write(&file, plist)?;
+    println!("✅ Installed launchd service: {}", file.display());
+    println!(" Start with: zeroclaw service start");
+    Ok(())
+}
+
+fn install_linux(config: &Config) -> Result<()> {
+    let file = linux_service_file(config)?;
+    if let Some(parent) = file.parent() {
+        fs::create_dir_all(parent)?;
+    }
+
+    let exe = std::env::current_exe().context("Failed to resolve current executable")?;
+    let unit = format!(
+        "[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n",
+        exe.display()
+    );
+
+    fs::write(&file, unit)?;
+    let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]));
+    let _ = run_checked(Command::new("systemctl").args(["--user", "enable", "zeroclaw.service"]));
+    println!("✅ Installed systemd user service: {}", file.display());
+    println!(" Start with: zeroclaw service start");
+    Ok(())
+}
+
+fn macos_service_file() -> Result<PathBuf> {
+    let home = directories::UserDirs::new()
+        .map(|u| u.home_dir().to_path_buf())
+        .context("Could not find home directory")?;
+    Ok(home
+        .join("Library")
+        .join("LaunchAgents")
+        .join(format!("{SERVICE_LABEL}.plist")))
+}
+
+fn linux_service_file(config: &Config) -> Result<PathBuf> {
+    let home = directories::UserDirs::new()
+        .map(|u| u.home_dir().to_path_buf())
+        .context("Could not find home directory")?;
+    let _ = config;
+    Ok(home
+        .join(".config")
+        .join("systemd")
+        .join("user")
+        .join("zeroclaw.service"))
+}
+
+fn run_checked(command: &mut Command) -> Result<()> {
+    let output = command.output().context("Failed to spawn command")?;
+    if !output.status.success() {
+        let stderr = String::from_utf8_lossy(&output.stderr);
+        anyhow::bail!("Command failed: {}", stderr.trim());
+    }
+    Ok(())
+}
+
+fn run_capture(command: &mut Command) -> Result<String> {
+    let output = command.output().context("Failed to spawn command")?;
+    let mut text = String::from_utf8_lossy(&output.stdout).to_string();
+    if text.trim().is_empty() {
+        text = String::from_utf8_lossy(&output.stderr).to_string();
+    }
+    Ok(text)
+}
+
+fn xml_escape(raw: &str) -> String {
+    raw.replace('&', "&amp;")
+        .replace('<', "&lt;")
+        .replace('>', "&gt;")
+        .replace('"', "&quot;")
+        .replace('\'', "&apos;")
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn xml_escape_escapes_reserved_chars() {
+        let escaped = xml_escape("<&>\"' and text");
+        assert_eq!(escaped, "&lt;&amp;&gt;&quot;&apos; and text");
+    }
+
+    #[test]
+    fn run_capture_reads_stdout() {
+        let out = run_capture(Command::new("sh").args(["-lc", "echo hello"]))
+            .expect("stdout capture should succeed");
+        assert_eq!(out.trim(), "hello");
+    }
+
+    #[test]
+    fn run_capture_falls_back_to_stderr() {
+        let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"]))
+            .expect("stderr capture should succeed");
+        assert_eq!(out.trim(), "warn");
+    }
+
+    #[test]
+    fn run_checked_errors_on_non_zero_status() {
+        let err = run_checked(Command::new("sh").args(["-lc", "exit 17"]))
+            .expect_err("non-zero exit should error");
+        assert!(err.to_string().contains("Command failed"));
+    }
+
+    #[test]
+    fn linux_service_file_has_expected_suffix() {
+        let file = linux_service_file(&Config::default()).unwrap();
+        let path = file.to_string_lossy();
+        assert!(path.ends_with(".config/systemd/user/zeroclaw.service"));
+    }
+}
diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs
index 1798d2d..97c46e0 100644
--- a/src/tools/file_read.rs
+++ b/src/tools/file_read.rs
@@ -55,7 +55,30 @@ impl Tool for FileReadTool {
 
         let full_path = self.security.workspace_dir.join(path);
 
-        match tokio::fs::read_to_string(&full_path).await {
+        // Resolve path before reading to block symlink escapes.
+ let resolved_path = match tokio::fs::canonicalize(&full_path).await { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve file path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Resolved path escapes workspace: {}", + resolved_path.display() + )), + }); + } + + match tokio::fs::read_to_string(&resolved_path).await { Ok(contents) => Ok(ToolResult { success: true, output: contents, @@ -127,7 +150,7 @@ mod tests { let tool = FileReadTool::new(test_security(dir.clone())); let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap(); assert!(!result.success); - assert!(result.error.as_ref().unwrap().contains("Failed to read")); + assert!(result.error.as_ref().unwrap().contains("Failed to resolve")); let _ = tokio::fs::remove_dir_all(&dir).await; } @@ -200,4 +223,36 @@ mod tests { let _ = tokio::fs::remove_dir_all(&dir).await; } + + #[cfg(unix)] + #[tokio::test] + async fn file_read_blocks_symlink_escape() { + use std::os::unix::fs::symlink; + + let root = std::env::temp_dir().join("zeroclaw_test_file_read_symlink_escape"); + let workspace = root.join("workspace"); + let outside = root.join("outside"); + + let _ = tokio::fs::remove_dir_all(&root).await; + tokio::fs::create_dir_all(&workspace).await.unwrap(); + tokio::fs::create_dir_all(&outside).await.unwrap(); + + tokio::fs::write(outside.join("secret.txt"), "outside workspace") + .await + .unwrap(); + + symlink(outside.join("secret.txt"), workspace.join("escape.txt")).unwrap(); + + let tool = FileReadTool::new(test_security(workspace.clone())); + let result = tool.execute(json!({"path": "escape.txt"})).await.unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("escapes workspace")); + + let _ = tokio::fs::remove_dir_all(&root).await; + 
} } diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs index f31191d..f147497 100644 --- a/src/tools/file_write.rs +++ b/src/tools/file_write.rs @@ -69,7 +69,54 @@ impl Tool for FileWriteTool { tokio::fs::create_dir_all(parent).await?; } - match tokio::fs::write(&full_path, content).await { + let parent = match full_path.parent() { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Invalid path: missing parent directory".into()), + }); + } + }; + + // Resolve parent before writing to block symlink escapes. + let resolved_parent = match tokio::fs::canonicalize(parent).await { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve file path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_parent) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Resolved path escapes workspace: {}", + resolved_parent.display() + )), + }); + } + + let file_name = match full_path.file_name() { + Some(name) => name, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Invalid path: missing file name".into()), + }); + } + }; + + let resolved_target = resolved_parent.join(file_name); + + match tokio::fs::write(&resolved_target, content).await { Ok(()) => Ok(ToolResult { success: true, output: format!("Written {} bytes to {path}", content.len()), @@ -239,4 +286,36 @@ mod tests { let _ = tokio::fs::remove_dir_all(&dir).await; } + + #[cfg(unix)] + #[tokio::test] + async fn file_write_blocks_symlink_escape() { + use std::os::unix::fs::symlink; + + let root = std::env::temp_dir().join("zeroclaw_test_file_write_symlink_escape"); + let workspace = root.join("workspace"); + let outside = root.join("outside"); + + let _ = tokio::fs::remove_dir_all(&root).await; + tokio::fs::create_dir_all(&workspace).await.unwrap(); + 
tokio::fs::create_dir_all(&outside).await.unwrap(); + + symlink(&outside, workspace.join("escape_dir")).unwrap(); + + let tool = FileWriteTool::new(test_security(workspace.clone())); + let result = tool + .execute(json!({"path": "escape_dir/hijack.txt", "content": "bad"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("escapes workspace")); + assert!(!outside.join("hijack.txt").exists()); + + let _ = tokio::fs::remove_dir_all(&root).await; + } }