diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..8fd5e96 --- /dev/null +++ b/.dockerignore @@ -0,0 +1,66 @@ +# Git history (may contain old secrets) +.git +.gitignore +.githooks + +# Rust build artifacts (can be multiple GB) +target + +# Documentation and examples (not needed for runtime) +docs +examples +tests + +# Markdown files (README, CHANGELOG, etc.) +*.md + +# Images (unnecessary for build) +*.png +*.svg +*.jpg +*.jpeg +*.gif + +# SQLite databases (conversation history, cron jobs) +*.db +*.db-journal + +# macOS artifacts +.DS_Store +.AppleDouble +.LSOverride + +# CI/CD configs (not needed in image) +.github + +# Cargo deny config (lint tool, not runtime) +deny.toml + +# License file (not needed for runtime) +LICENSE + +# Temporary files +.tmp_* +*.tmp +*.bak +*.swp +*~ + +# IDE and editor configs +.idea +.vscode +*.iml + +# Windsurf workflows +.windsurf + +# Environment files (may contain secrets) +.env +.env.* +!.env.example + +# Coverage and profiling +*.profraw +*.profdata +coverage +lcov.info diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 920fdfa..50b0524 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -63,3 +63,40 @@ jobs: with: name: zeroclaw-${{ matrix.target }} path: target/${{ matrix.target }}/release/zeroclaw* + + docker: + name: Docker Security + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Build Docker image + run: docker build -t zeroclaw:test . 
+ + - name: Verify non-root user (UID != 0) + run: | + USER_ID=$(docker inspect --format='{{.Config.User}}' zeroclaw:test) + echo "Container user: $USER_ID" + if [ "$USER_ID" = "0" ] || [ "$USER_ID" = "root" ] || [ -z "$USER_ID" ]; then + echo "❌ FAIL: Container runs as root (UID 0)" + exit 1 + fi + echo "✅ PASS: Container runs as non-root user ($USER_ID)" + + - name: Verify distroless nonroot base image + run: | + BASE_IMAGE=$(grep -E '^FROM.*runtime|^FROM gcr.io/distroless' Dockerfile | tail -1) + echo "Base image line: $BASE_IMAGE" + if ! echo "$BASE_IMAGE" | grep -q ':nonroot'; then + echo "❌ FAIL: Runtime stage does not use :nonroot variant" + exit 1 + fi + echo "✅ PASS: Using distroless :nonroot variant" + + - name: Verify USER directive exists + run: | + if ! grep -qE '^USER\s+[0-9]+' Dockerfile; then + echo "❌ FAIL: No explicit USER directive with numeric UID" + exit 1 + fi + echo "✅ PASS: Explicit USER directive found" diff --git a/.tmp_todo_probe b/.tmp_todo_probe new file mode 100644 index 0000000..e69de29 diff --git a/CHANGELOG.md b/CHANGELOG.md index 8ec9d30..e1ac7be 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,24 @@ All notable changes to ZeroClaw will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [Unreleased] + +### Security +- **Legacy XOR cipher migration**: The `enc:` prefix (XOR cipher) is now deprecated. + Secrets using this format will be automatically migrated to `enc2:` (ChaCha20-Poly1305 AEAD) + when decrypted via `decrypt_and_migrate()`. A `tracing::warn!` is emitted when legacy + values are encountered. The XOR cipher will be removed in a future release. 
+ +### Added +- `SecretStore::decrypt_and_migrate()` — Decrypts secrets and returns a migrated `enc2:` + value if the input used the legacy `enc:` format +- `SecretStore::needs_migration()` — Check if a value uses the legacy `enc:` format +- `SecretStore::is_secure_encrypted()` — Check if a value uses the secure `enc2:` format + +### Deprecated +- `enc:` prefix for encrypted secrets — Use `enc2:` (ChaCha20-Poly1305) instead. + Legacy values are still decrypted for backward compatibility but should be migrated. + ## [0.1.0] - 2025-02-13 ### Added diff --git a/Cargo.toml b/Cargo.toml index 13a6334..fbf6ba5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ categories = ["command-line-utilities", "api-bindings"] clap = { version = "4.5", features = ["derive"] } # Async runtime - feature-optimized for size -tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs"] } +tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs", "signal"] } # HTTP client - minimal features reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking"] } @@ -49,6 +49,7 @@ async-trait = "0.1" # Memory / persistence rusqlite = { version = "0.32", features = ["bundled"] } chrono = { version = "0.4", default-features = false, features = ["clock", "std"] } +cron = "0.12" # Interactive CLI prompts dialoguer = { version = "0.11", features = ["fuzzy-select"] } @@ -64,6 +65,12 @@ rustls-pki-types = "1.14.0" tokio-rustls = "0.26.4" webpki-roots = "1.0.6" +# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance +axum = { version = "0.7", default-features = false, features = ["http1", "json", "tokio", "query"] } +tower = { version = "0.5", default-features = false } +tower-http = { version = "0.6", default-features = false, features = ["limit", 
"timeout"] } +http-body-util = "0.1" + [profile.release] opt-level = "z" # Optimize for size lto = true # Link-time optimization diff --git a/Dockerfile b/Dockerfile index 71a301f..7d684df 100644 --- a/Dockerfile +++ b/Dockerfile @@ -8,14 +8,17 @@ COPY src/ src/ RUN cargo build --release --locked && \ strip target/release/zeroclaw -# ── Stage 2: Runtime (distroless — no shell, no OS, tiny) ──── -FROM gcr.io/distroless/cc-debian12 +# ── Stage 2: Runtime (distroless nonroot — no shell, no OS, tiny, UID 65534) ── +FROM gcr.io/distroless/cc-debian12:nonroot COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw -# Default workspace +# Default workspace (owned by nonroot user) VOLUME ["/workspace"] ENV ZEROCLAW_WORKSPACE=/workspace +# Explicitly set non-root user (distroless :nonroot defaults to UID 65532 "nonroot"; we pin 65534 "nobody" explicitly) +USER 65534:65534 + ENTRYPOINT ["zeroclaw"] CMD ["gateway"] diff --git a/README.md b/README.md index 5efbbf7..6b3cbe7 100644 --- a/README.md +++ b/README.md @@ -12,12 +12,19 @@ License: MIT

-The fastest, smallest, fully autonomous AI assistant — deploy anywhere, swap anything. +Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything. ``` ~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything ``` +### Why teams pick ZeroClaw + +- **Lean by default:** small Rust binary, fast startup, low memory footprint. +- **Secure by design:** pairing, strict sandboxing, explicit allowlists, workspace scoping. +- **Fully swappable:** core systems are traits (providers, channels, tools, memory, tunnels). +- **No lock-in:** OpenAI-compatible provider support + pluggable custom endpoints. + ## Benchmark Snapshot (ZeroClaw vs OpenClaw) Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. @@ -30,7 +37,17 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. | `--help` max RSS observed | **~7.3 MB** | **~394 MB** | | `status` max RSS observed | **~7.8 MB** | **~1.52 GB** | -> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results include `pnpm install` + `pnpm build` before execution. +> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results were measured after `pnpm install` + `pnpm build`. + +Reproduce ZeroClaw numbers locally: + +```bash +cargo build --release +ls -lh target/release/zeroclaw + +/usr/bin/time -l target/release/zeroclaw --help +/usr/bin/time -l target/release/zeroclaw status +``` ## Quick Start @@ -38,34 +55,52 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. git clone https://github.com/theonlyhennygod/zeroclaw.git cd zeroclaw cargo build --release +cargo install --path . --force # Quick setup (no prompts) -cargo run --release -- onboard --api-key sk-... --provider openrouter +zeroclaw onboard --api-key sk-... 
--provider openrouter # Or interactive wizard -cargo run --release -- onboard --interactive +zeroclaw onboard --interactive + +# Or quickly repair channels/allowlists only +zeroclaw onboard --channels-only # Chat -cargo run --release -- agent -m "Hello, ZeroClaw!" +zeroclaw agent -m "Hello, ZeroClaw!" # Interactive mode -cargo run --release -- agent +zeroclaw agent # Start the gateway (webhook server) -cargo run --release -- gateway # default: 127.0.0.1:8080 -cargo run --release -- gateway --port 0 # random port (security hardened) +zeroclaw gateway # default: 127.0.0.1:8080 +zeroclaw gateway --port 0 # random port (security hardened) + +# Start full autonomous runtime +zeroclaw daemon # Check status -cargo run --release -- status +zeroclaw status + +# Run system diagnostics +zeroclaw doctor # Check channel health -cargo run --release -- channel doctor +zeroclaw channel doctor # Get integration setup details -cargo run --release -- integrations info Telegram +zeroclaw integrations info Telegram + +# Manage background service +zeroclaw service install +zeroclaw service status + +# Migrate memory from OpenClaw (safe preview first) +zeroclaw migrate openclaw --dry-run +zeroclaw migrate openclaw ``` -> **Tip:** Run `cargo install --path .` to install `zeroclaw` globally, then use `zeroclaw` instead of `cargo run --release --`. +> **Dev fallback (no global install):** prefix commands with `cargo run --release --` (example: `cargo run --release -- status`). ## Architecture @@ -78,17 +113,25 @@ Every subsystem is a **trait** — swap implementations with a config change, ze | Subsystem | Trait | Ships with | Extend | |-----------|-------|------------|--------| | **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) 
| `custom:https://your-api.com` — any OpenAI-compatible API | -| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, Webhook | Any messaging API | +| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API | | **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Markdown | Any persistence backend | | **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), composio (optional) | Any capability | | **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel | -| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM | +| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM (planned; unsupported kinds fail fast) | | **Security** | `SecurityPolicy` | Gateway pairing, sandbox, allowlists, rate limits, filesystem scoping, encrypted secrets | — | +| **Identity** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | Any identity format | | **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Any tunnel binary | | **Heartbeat** | Engine | HEARTBEAT.md periodic tasks | — | | **Skills** | Loader | TOML manifests + SKILL.md instructions | Community skill packs | | **Integrations** | Registry | 50+ integrations across 9 categories | Plugin system | +### Runtime support (current) + +- ✅ Supported today: `runtime.kind = "native"` +- 🚧 Planned, not implemented yet: Docker / WASM / edge runtimes + +When an unsupported `runtime.kind` is configured, ZeroClaw now exits with a clear error instead of silently falling back to native. + ### Memory System (Full-Stack Search Engine) All custom, zero external dependencies — no Pinecone, no Elasticsearch, no LangChain: @@ -124,7 +167,7 @@ ZeroClaw enforces security at **every layer** — not just the sandbox. 
It passe |---|------|--------|-----| | 1 | **Gateway not publicly exposed** | ✅ | Binds `127.0.0.1` by default. Refuses `0.0.0.0` without tunnel or explicit `allow_public_bind = true`. | | 2 | **Pairing required** | ✅ | 6-digit one-time code on startup. Exchange via `POST /pair` for bearer token. All `/webhook` requests require `Authorization: Bearer `. | -| 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization. | +| 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization + resolved-path workspace checks in file read/write tools. | | 4 | **Access via tunnel only** | ✅ | Gateway refuses public bind without active tunnel. Supports Tailscale, Cloudflare, ngrok, or any custom tunnel. | > **Run your own nmap:** `nmap -p 1-65535 ` — ZeroClaw binds to localhost only, so nothing is exposed unless you explicitly configure a tunnel. @@ -139,6 +182,63 @@ Inbound sender policy is now consistent: This keeps accidental exposure low by default. +Recommended low-friction setup (secure + fast): + +- **Telegram:** allowlist your own `@username` (without `@`) and/or your numeric Telegram user ID. +- **Discord:** allowlist your own Discord user ID. +- **Slack:** allowlist your own Slack member ID (usually starts with `U`). +- Use `"*"` only for temporary open testing. + +If you're not sure which identity to use: + +1. Start channels and send one message to your bot. +2. Read the warning log to see the exact sender identity. +3. Add that value to the allowlist and rerun channels-only setup. 
+ +If you hit authorization warnings in logs (for example: `ignoring message from unauthorized user`), +rerun channel setup only: + +```bash +zeroclaw onboard --channels-only +``` + +### WhatsApp Business Cloud API Setup + +WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling): + +1. **Create a Meta Business App:** + - Go to [developers.facebook.com](https://developers.facebook.com) + - Create a new app → Select "Business" type + - Add the "WhatsApp" product + +2. **Get your credentials:** + - **Access Token:** From WhatsApp → API Setup → Generate token (or create a System User for permanent tokens) + - **Phone Number ID:** From WhatsApp → API Setup → Phone number ID + - **Verify Token:** You define this (any random string) — Meta will send it back during webhook verification + +3. **Configure ZeroClaw:** + ```toml + [channels_config.whatsapp] + access_token = "EAABx..." + phone_number_id = "123456789012345" + verify_token = "my-secret-verify-token" + allowed_numbers = ["+1234567890"] # E.164 format, or ["*"] for all + ``` + +4. **Start the gateway with a tunnel:** + ```bash + zeroclaw gateway --port 8080 + ``` + WhatsApp requires HTTPS, so use a tunnel (ngrok, Cloudflare, Tailscale Funnel). + +5. **Configure Meta webhook:** + - In Meta Developer Console → WhatsApp → Configuration → Webhook + - **Callback URL:** `https://your-tunnel-url/whatsapp` + - **Verify Token:** Same as your `verify_token` in config + - Subscribe to `messages` field + +6. **Test:** Send a message to your WhatsApp Business number — ZeroClaw will respond via the LLM. 
+ ## Configuration Config: `~/.zeroclaw/config.toml` (created by `onboard`) @@ -166,6 +266,9 @@ workspace_only = true # default: true — scoped to workspace allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"] forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"] +[runtime] +kind = "native" # only supported value right now; unsupported kinds fail fast + [heartbeat] enabled = false interval_minutes = 30 @@ -182,8 +285,81 @@ allowed_domains = ["docs.rs"] # required when browser is enabled [composio] enabled = false # opt-in: 1000+ OAuth apps via composio.dev + +[identity] +format = "openclaw" # "openclaw" (default, markdown files) or "aieos" (JSON) +# aieos_path = "identity.json" # path to AIEOS JSON file (relative to workspace or absolute) +# aieos_inline = '{"identity":{"names":{"first":"Nova"}}}' # inline AIEOS JSON ``` +## Identity System (AIEOS Support) + +ZeroClaw supports **identity-agnostic** AI personas through two formats: + +### OpenClaw (Default) + +Traditional markdown files in your workspace: +- `IDENTITY.md` — Who the agent is +- `SOUL.md` — Core personality and values +- `USER.md` — Who the agent is helping +- `AGENTS.md` — Behavior guidelines + +### AIEOS (AI Entity Object Specification) + +[AIEOS](https://aieos.org) is a standardization framework for portable AI identity. 
ZeroClaw supports AIEOS v1.1 JSON payloads, allowing you to: + +- **Import identities** from the AIEOS ecosystem +- **Export identities** to other AIEOS-compatible systems +- **Maintain behavioral integrity** across different AI models + +#### Enable AIEOS + +```toml +[identity] +format = "aieos" +aieos_path = "identity.json" # relative to workspace or absolute path +``` + +Or inline JSON: + +```toml +[identity] +format = "aieos" +aieos_inline = ''' +{ + "identity": { + "names": { "first": "Nova", "nickname": "N" } + }, + "psychology": { + "neural_matrix": { "creativity": 0.9, "logic": 0.8 }, + "traits": { "mbti": "ENTP" }, + "moral_compass": { "alignment": "Chaotic Good" } + }, + "linguistics": { + "text_style": { "formality_level": 0.2, "slang_usage": true } + }, + "motivations": { + "core_drive": "Push boundaries and explore possibilities" + } +} +''' +``` + +#### AIEOS Schema Sections + +| Section | Description | +|---------|-------------| +| `identity` | Names, bio, origin, residence | +| `psychology` | Neural matrix (cognitive weights), MBTI, OCEAN, moral compass | +| `linguistics` | Text style, formality, catchphrases, forbidden words | +| `motivations` | Core drive, short/long-term goals, fears | +| `capabilities` | Skills and tools the agent can access | +| `physicality` | Visual descriptors for image generation | +| `history` | Origin story, education, occupation | +| `interests` | Hobbies, favorites, lifestyle | + +See [aieos.org](https://aieos.org) for the full schema and live examples. 
+ ## Gateway API | Endpoint | Method | Auth | Description | @@ -191,6 +367,8 @@ enabled = false # opt-in: 1000+ OAuth apps via composio.dev | `/health` | GET | None | Health check (always public, no secrets leaked) | | `/pair` | POST | `X-Pairing-Code` header | Exchange one-time code for bearer token | | `/webhook` | POST | `Authorization: Bearer ` | Send message: `{"message": "your prompt"}` | +| `/whatsapp` | GET | Query params | Meta webhook verification (hub.mode, hub.verify_token, hub.challenge) | +| `/whatsapp` | POST | None (Meta signature) | WhatsApp incoming message webhook | ## Commands @@ -198,10 +376,14 @@ enabled = false # opt-in: 1000+ OAuth apps via composio.dev |---------|-------------| | `onboard` | Quick setup (default) | | `onboard --interactive` | Full interactive 7-step wizard | +| `onboard --channels-only` | Reconfigure channels/allowlists only (fast repair flow) | | `agent -m "..."` | Single message mode | | `agent` | Interactive chat mode | | `gateway` | Start webhook server (default: `127.0.0.1:8080`) | | `gateway --port 0` | Random port mode | +| `daemon` | Start long-running autonomous runtime | +| `service install/start/stop/status/uninstall` | Manage user-level background service | +| `doctor` | Diagnose daemon/scheduler/channel freshness | | `status` | Show full system status | | `channel doctor` | Run health checks for configured channels | | `integrations info ` | Show setup/status details for one integration | diff --git a/SECURITY.md b/SECURITY.md index 9fc4b11..32c7c28 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -61,3 +61,33 @@ cargo test -- tools::shell cargo test -- tools::file_read cargo test -- tools::file_write ``` + +## Container Security + +ZeroClaw Docker images follow CIS Docker Benchmark best practices: + +| Control | Implementation | +|---------|----------------| +| **4.1 Non-root user** | Container runs as UID 65534 (distroless nonroot) | +| **4.2 Minimal base image** | `gcr.io/distroless/cc-debian12:nonroot` — no 
shell, no package manager | +| **4.6 HEALTHCHECK** | Not applicable (stateless CLI/gateway) | +| **5.25 Read-only filesystem** | Supported via `docker run --read-only` with `/workspace` volume | + +### Verifying Container Security + +```bash +# Build and verify non-root user +docker build -t zeroclaw . +docker inspect --format='{{.Config.User}}' zeroclaw +# Expected: 65534:65534 + +# Run with read-only filesystem (production hardening) +docker run --read-only -v /path/to/workspace:/workspace zeroclaw gateway +``` + +### CI Enforcement + +The `docker` job in `.github/workflows/ci.yml` automatically verifies: +1. Container does not run as root (UID 0) +2. Runtime stage uses `:nonroot` variant +3. Explicit `USER` directive with numeric UID exists diff --git a/scripts/test_dockerignore.sh b/scripts/test_dockerignore.sh new file mode 100755 index 0000000..839d21e --- /dev/null +++ b/scripts/test_dockerignore.sh @@ -0,0 +1,169 @@ +#!/usr/bin/env bash +# Test script to verify .dockerignore excludes sensitive paths +# Run: ./scripts/test_dockerignore.sh + +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +PROJECT_ROOT="$(dirname "$SCRIPT_DIR")" +DOCKERIGNORE="$PROJECT_ROOT/.dockerignore" + +RED='\033[0;31m' +GREEN='\033[0;32m' +NC='\033[0m' # No Color + +PASS=0 +FAIL=0 + +log_pass() { + echo -e "${GREEN}✓${NC} $1" + PASS=$((PASS + 1)) +} + +log_fail() { + echo -e "${RED}✗${NC} $1" + FAIL=$((FAIL + 1)) +} + +# Test 1: .dockerignore exists +echo "=== Testing .dockerignore ===" +if [[ -f "$DOCKERIGNORE" ]]; then + log_pass ".dockerignore file exists" +else + log_fail ".dockerignore file does not exist" + exit 1 +fi + +# Test 2: Required exclusions are present +MUST_EXCLUDE=( + ".git" + ".githooks" + "target" + "docs" + "examples" + "tests" + "*.md" + "*.png" + "*.db" + "*.db-journal" + ".DS_Store" + ".github" + "deny.toml" + "LICENSE" + ".env" + ".tmp_*" +) + +for pattern in "${MUST_EXCLUDE[@]}"; do + # Use fgrep for literal matching + if grep 
-Fq "$pattern" "$DOCKERIGNORE" 2>/dev/null; then + log_pass "Excludes: $pattern" + else + log_fail "Missing exclusion: $pattern" + fi +done + +# Test 3: Build essentials are NOT excluded +MUST_NOT_EXCLUDE=( + "Cargo.toml" + "Cargo.lock" + "src" +) + +for path in "${MUST_NOT_EXCLUDE[@]}"; do + if grep -qE "^${path}$" "$DOCKERIGNORE" 2>/dev/null; then + log_fail "Build essential '$path' is incorrectly excluded" + else + log_pass "Build essential NOT excluded: $path" + fi +done + +# Test 4: No syntax errors (basic validation) +while IFS= read -r line; do + # Skip empty lines and comments + [[ -z "$line" || "$line" =~ ^# ]] && continue + + # Check for common issues + if [[ "$line" =~ [[:space:]]$ ]]; then + log_fail "Trailing whitespace in pattern: '$line'" + fi +done < "$DOCKERIGNORE" +log_pass "No trailing whitespace in patterns" + +# Test 5: Verify Docker build context would be small +echo "" +echo "=== Simulating Docker build context ===" + +# Create temp dir and simulate what would be sent +TEMP_DIR=$(mktemp -d) +trap "rm -rf $TEMP_DIR" EXIT + +# Use rsync with .dockerignore patterns to simulate Docker's behavior +cd "$PROJECT_ROOT" + +# Count files that WOULD be sent (excluding .dockerignore patterns) +TOTAL_FILES=$(find . -type f | wc -l | tr -d ' ') +CONTEXT_FILES=$(find . -type f \ + ! -path './.git/*' \ + ! -path './target/*' \ + ! -path './docs/*' \ + ! -path './examples/*' \ + ! -path './tests/*' \ + ! -name '*.md' \ + ! -name '*.png' \ + ! -name '*.svg' \ + ! -name '*.db' \ + ! -name '*.db-journal' \ + ! -name '.DS_Store' \ + ! -path './.github/*' \ + ! -name 'deny.toml' \ + ! -name 'LICENSE' \ + ! -name '.env' \ + ! 
-name '.env.*' \ + 2>/dev/null | wc -l | tr -d ' ') + +echo "Total files in repo: $TOTAL_FILES" +echo "Files in Docker context: $CONTEXT_FILES" + +if [[ $CONTEXT_FILES -lt $TOTAL_FILES ]]; then + log_pass "Docker context is smaller than full repo ($CONTEXT_FILES < $TOTAL_FILES files)" +else + log_fail "Docker context is not being reduced" +fi + +# Test 6: Verify critical security files would be excluded +echo "" +echo "=== Security checks ===" + +# Check if .git would be excluded +if [[ -d "$PROJECT_ROOT/.git" ]]; then + if grep -q "^\.git$" "$DOCKERIGNORE"; then + log_pass ".git directory will be excluded (security)" + else + log_fail ".git directory NOT excluded - SECURITY RISK" + fi +fi + +# Check if any .db files exist and would be excluded +DB_FILES=$(find "$PROJECT_ROOT" -name "*.db" -type f 2>/dev/null | head -5) +if [[ -n "$DB_FILES" ]]; then + if grep -q "^\*\.db$" "$DOCKERIGNORE"; then + log_pass "*.db files will be excluded (security)" + else + log_fail "*.db files NOT excluded - SECURITY RISK" + fi +fi + +# Summary +echo "" +echo "=== Summary ===" +echo -e "Passed: ${GREEN}$PASS${NC}" +echo -e "Failed: ${RED}$FAIL${NC}" + +if [[ $FAIL -gt 0 ]]; then + echo -e "${RED}FAILED${NC}: $FAIL tests failed" + exit 1 +else + echo -e "${GREEN}PASSED${NC}: All tests passed" + exit 0 +fi diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 57e0182..0f611d7 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -39,7 +39,7 @@ pub async fn run( // ── Wire up agnostic subsystems ────────────────────────────── let observer: Arc = Arc::from(observability::create_observer(&config.observability)); - let _runtime = runtime::create_runtime(&config.runtime); + let _runtime = runtime::create_runtime(&config.runtime)?; let security = Arc::new(SecurityPolicy::from_config( &config.autonomy, &config.workspace_dir, @@ -72,8 +72,11 @@ pub async fn run( .or(config.default_model.as_deref()) .unwrap_or("anthropic/claude-sonnet-4-20250514"); - let provider: Box = - 
providers::create_provider(provider_name, config.api_key.as_deref())?; + let provider: Box = providers::create_resilient_provider( + provider_name, + config.api_key.as_deref(), + &config.reliability, + )?; observer.record_event(&ObserverEvent::AgentStart { provider: provider_name.to_string(), @@ -83,12 +86,30 @@ pub async fn run( // ── Build system prompt from workspace MD files (OpenClaw framework) ── let skills = crate::skills::load_skills(&config.workspace_dir); let mut tool_descs: Vec<(&str, &str)> = vec![ - ("shell", "Execute terminal commands"), - ("file_read", "Read file contents"), - ("file_write", "Write file contents"), - ("memory_store", "Save to memory"), - ("memory_recall", "Search memory"), - ("memory_forget", "Delete a memory entry"), + ( + "shell", + "Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.", + ), + ( + "file_read", + "Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.", + ), + ( + "file_write", + "Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.", + ), + ( + "memory_store", + "Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.", + ), + ( + "memory_recall", + "Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.", + ), + ( + "memory_forget", + "Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. 
Don't use when: impact is uncertain.", + ), ]; if config.browser.enabled { tool_descs.push(( diff --git a/src/channels/imessage.rs b/src/channels/imessage.rs index a0ac72e..c3a8abf 100644 --- a/src/channels/imessage.rs +++ b/src/channels/imessage.rs @@ -29,6 +29,60 @@ impl IMessageChannel { } } +/// Escape a string for safe interpolation into `AppleScript`. +/// +/// This prevents injection attacks by escaping: +/// - Backslashes (`\` → `\\`) +/// - Double quotes (`"` → `\"`) +fn escape_applescript(s: &str) -> String { + s.replace('\\', "\\\\").replace('"', "\\\"") +} + +/// Validate that a target looks like a valid phone number or email address. +/// +/// This is a defense-in-depth measure to reject obviously malicious targets +/// before they reach `AppleScript` interpolation. +/// +/// Valid patterns: +/// - Phone: starts with `+` followed by digits (with optional spaces/dashes) +/// - Email: contains `@` with alphanumeric chars on both sides +fn is_valid_imessage_target(target: &str) -> bool { + let target = target.trim(); + if target.is_empty() { + return false; + } + + // Phone number: +1234567890 or +1 234-567-8900 + if target.starts_with('+') { + let digits_only: String = target.chars().filter(char::is_ascii_digit).collect(); + // Must have at least 7 digits (shortest valid phone numbers) + return digits_only.len() >= 7 && digits_only.len() <= 15; + } + + // Email: simple validation (contains @ with chars on both sides) + if let Some(at_pos) = target.find('@') { + let local = &target[..at_pos]; + let domain = &target[at_pos + 1..]; + + // Local part: non-empty, alphanumeric + common email chars + let local_valid = !local.is_empty() + && local + .chars() + .all(|c| c.is_alphanumeric() || "._+-".contains(c)); + + // Domain: non-empty, contains a dot, alphanumeric + dots/hyphens + let domain_valid = !domain.is_empty() + && domain.contains('.') + && domain + .chars() + .all(|c| c.is_alphanumeric() || ".-".contains(c)); + + return local_valid && domain_valid; + 
} + + false +} + #[async_trait] impl Channel for IMessageChannel { fn name(&self) -> &str { @@ -36,11 +90,22 @@ impl Channel for IMessageChannel { } async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> { - let escaped_msg = message.replace('\\', "\\\\").replace('"', "\\\""); + // Defense-in-depth: validate target format before any interpolation + if !is_valid_imessage_target(target) { + anyhow::bail!( + "Invalid iMessage target: must be a phone number (+1234567890) or email (user@example.com)" + ); + } + + // SECURITY: Escape both message AND target to prevent AppleScript injection + // See: CWE-78 (OS Command Injection) + let escaped_msg = escape_applescript(message); + let escaped_target = escape_applescript(target); + let script = format!( r#"tell application "Messages" set targetService to 1st account whose service type = iMessage - set targetBuddy to participant "{target}" of targetService + set targetBuddy to participant "{escaped_target}" of targetService send "{escaped_msg}" to targetBuddy end tell"# ); @@ -262,4 +327,204 @@ mod tests { assert!(ch.is_contact_allowed(" spaced ")); assert!(!ch.is_contact_allowed("spaced")); } + + // ══════════════════════════════════════════════════════════ + // AppleScript Escaping Tests (CWE-78 Prevention) + // ══════════════════════════════════════════════════════════ + + #[test] + fn escape_applescript_double_quotes() { + assert_eq!(escape_applescript(r#"hello "world""#), r#"hello \"world\""#); + } + + #[test] + fn escape_applescript_backslashes() { + assert_eq!(escape_applescript(r"path\to\file"), r"path\\to\\file"); + } + + #[test] + fn escape_applescript_mixed() { + assert_eq!( + escape_applescript(r#"say "hello\" world"#), + r#"say \"hello\\\" world"# + ); + } + + #[test] + fn escape_applescript_injection_attempt() { + // This is the exact attack vector from the security report + let malicious = r#"" & do shell script "id" & ""#; + let escaped = escape_applescript(malicious); + // After escaping, 
the quotes should be escaped and not break out + assert_eq!(escaped, r#"\" & do shell script \"id\" & \""#); + // Verify all quotes are now escaped (preceded by backslash) + // The escaped string should not have any unescaped quotes (quote not preceded by backslash) + let chars: Vec = escaped.chars().collect(); + for (i, &c) in chars.iter().enumerate() { + if c == '"' { + // Every quote must be preceded by a backslash + assert!( + i > 0 && chars[i - 1] == '\\', + "Found unescaped quote at position {i}" + ); + } + } + } + + #[test] + fn escape_applescript_empty_string() { + assert_eq!(escape_applescript(""), ""); + } + + #[test] + fn escape_applescript_no_special_chars() { + assert_eq!(escape_applescript("hello world"), "hello world"); + } + + #[test] + fn escape_applescript_unicode() { + assert_eq!(escape_applescript("hello 🦀 world"), "hello 🦀 world"); + } + + #[test] + fn escape_applescript_newlines_preserved() { + assert_eq!(escape_applescript("line1\nline2"), "line1\nline2"); + } + + // ══════════════════════════════════════════════════════════ + // Target Validation Tests + // ══════════════════════════════════════════════════════════ + + #[test] + fn valid_phone_number_simple() { + assert!(is_valid_imessage_target("+1234567890")); + } + + #[test] + fn valid_phone_number_with_country_code() { + assert!(is_valid_imessage_target("+14155551234")); + } + + #[test] + fn valid_phone_number_with_spaces() { + assert!(is_valid_imessage_target("+1 415 555 1234")); + } + + #[test] + fn valid_phone_number_with_dashes() { + assert!(is_valid_imessage_target("+1-415-555-1234")); + } + + #[test] + fn valid_phone_number_international() { + assert!(is_valid_imessage_target("+447911123456")); // UK + assert!(is_valid_imessage_target("+81312345678")); // Japan + } + + #[test] + fn valid_email_simple() { + assert!(is_valid_imessage_target("user@example.com")); + } + + #[test] + fn valid_email_with_subdomain() { + assert!(is_valid_imessage_target("user@mail.example.com")); + } + + 
#[test] + fn valid_email_with_plus() { + assert!(is_valid_imessage_target("user+tag@example.com")); + } + + #[test] + fn valid_email_with_dots() { + assert!(is_valid_imessage_target("first.last@example.com")); + } + + #[test] + fn valid_email_icloud() { + assert!(is_valid_imessage_target("user@icloud.com")); + assert!(is_valid_imessage_target("user@me.com")); + } + + #[test] + fn invalid_target_empty() { + assert!(!is_valid_imessage_target("")); + assert!(!is_valid_imessage_target(" ")); + } + + #[test] + fn invalid_target_no_plus_prefix() { + // Phone numbers must start with + + assert!(!is_valid_imessage_target("1234567890")); + } + + #[test] + fn invalid_target_too_short_phone() { + // Less than 7 digits + assert!(!is_valid_imessage_target("+123456")); + } + + #[test] + fn invalid_target_too_long_phone() { + // More than 15 digits + assert!(!is_valid_imessage_target("+1234567890123456")); + } + + #[test] + fn invalid_target_email_no_at() { + assert!(!is_valid_imessage_target("userexample.com")); + } + + #[test] + fn invalid_target_email_no_domain() { + assert!(!is_valid_imessage_target("user@")); + } + + #[test] + fn invalid_target_email_no_local() { + assert!(!is_valid_imessage_target("@example.com")); + } + + #[test] + fn invalid_target_email_no_dot_in_domain() { + assert!(!is_valid_imessage_target("user@localhost")); + } + + #[test] + fn invalid_target_injection_attempt() { + // The exact attack vector from the security report + assert!(!is_valid_imessage_target(r#"" & do shell script "id" & ""#)); + } + + #[test] + fn invalid_target_applescript_injection() { + // Various injection attempts + assert!(!is_valid_imessage_target(r#"test" & quit"#)); + assert!(!is_valid_imessage_target(r#"test\ndo shell script"#)); + assert!(!is_valid_imessage_target("test\"; malicious code; \"")); + } + + #[test] + fn invalid_target_special_chars() { + assert!(!is_valid_imessage_target("user & \"quotes\" 'apostrophe'" } + }] + } + }] + }] + }); + let msgs = 
ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!( + msgs[0].content, + " & \"quotes\" 'apostrophe'" + ); } } diff --git a/src/config/mod.rs b/src/config/mod.rs index 9af098c..f5849c1 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -2,7 +2,7 @@ pub mod schema; pub use schema::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, - GatewayConfig, HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, - ObservabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, - WebhookConfig, + GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig, + ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, + TelegramConfig, TunnelConfig, WebhookConfig, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index 49a9d59..872a600 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -25,6 +25,9 @@ pub struct Config { #[serde(default)] pub runtime: RuntimeConfig, + #[serde(default)] + pub reliability: ReliabilityConfig, + #[serde(default)] pub heartbeat: HeartbeatConfig, @@ -48,6 +51,38 @@ pub struct Config { #[serde(default)] pub browser: BrowserConfig, + + #[serde(default)] + pub identity: IdentityConfig, +} + +// ── Identity (AIEOS / OpenClaw format) ────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct IdentityConfig { + /// Identity format: "openclaw" (default) or "aieos" + #[serde(default = "default_identity_format")] + pub format: String, + /// Path to AIEOS JSON file (relative to workspace) + #[serde(default)] + pub aieos_path: Option, + /// Inline AIEOS JSON (alternative to file path) + #[serde(default)] + pub aieos_inline: Option, +} + +fn default_identity_format() -> String { + "openclaw".into() +} + +impl Default for IdentityConfig { + fn default() -> Self { + Self { + format: default_identity_format(), + aieos_path: None, + aieos_inline: None, + } 
+ } } // ── Gateway security ───────────────────────────────────────────── @@ -143,6 +178,18 @@ pub struct MemoryConfig { pub backend: String, /// Auto-save conversation context to memory pub auto_save: bool, + /// Run memory/session hygiene (archiving + retention cleanup) + #[serde(default = "default_hygiene_enabled")] + pub hygiene_enabled: bool, + /// Archive daily/session files older than this many days + #[serde(default = "default_archive_after_days")] + pub archive_after_days: u32, + /// Purge archived files older than this many days + #[serde(default = "default_purge_after_days")] + pub purge_after_days: u32, + /// For sqlite backend: prune conversation rows older than this many days + #[serde(default = "default_conversation_retention_days")] + pub conversation_retention_days: u32, /// Embedding provider: "none" | "openai" | "custom:URL" #[serde(default = "default_embedding_provider")] pub embedding_provider: String, @@ -169,6 +216,18 @@ pub struct MemoryConfig { fn default_embedding_provider() -> String { "none".into() } +fn default_hygiene_enabled() -> bool { + true +} +fn default_archive_after_days() -> u32 { + 7 +} +fn default_purge_after_days() -> u32 { + 30 +} +fn default_conversation_retention_days() -> u32 { + 30 +} fn default_embedding_model() -> String { "text-embedding-3-small".into() } @@ -193,6 +252,10 @@ impl Default for MemoryConfig { Self { backend: "sqlite".into(), auto_save: true, + hygiene_enabled: default_hygiene_enabled(), + archive_after_days: default_archive_after_days(), + purge_after_days: default_purge_after_days(), + conversation_retention_days: default_conversation_retention_days(), embedding_provider: default_embedding_provider(), embedding_model: default_embedding_model(), embedding_dimensions: default_embedding_dims(), @@ -281,7 +344,9 @@ impl Default for AutonomyConfig { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct RuntimeConfig { - /// "native" | "docker" | "cloudflare" + /// Runtime kind (currently supported: 
"native"). + /// + /// Reserved values (not implemented yet): "docker", "cloudflare". pub kind: String, } @@ -293,6 +358,71 @@ impl Default for RuntimeConfig { } } +// ── Reliability / supervision ──────────────────────────────────── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ReliabilityConfig { + /// Retries per provider before failing over. + #[serde(default = "default_provider_retries")] + pub provider_retries: u32, + /// Base backoff (ms) for provider retry delay. + #[serde(default = "default_provider_backoff_ms")] + pub provider_backoff_ms: u64, + /// Fallback provider chain (e.g. `["anthropic", "openai"]`). + #[serde(default)] + pub fallback_providers: Vec<String>, + /// Initial backoff for channel/daemon restarts. + #[serde(default = "default_channel_backoff_secs")] + pub channel_initial_backoff_secs: u64, + /// Max backoff for channel/daemon restarts. + #[serde(default = "default_channel_backoff_max_secs")] + pub channel_max_backoff_secs: u64, + /// Scheduler polling cadence in seconds. + #[serde(default = "default_scheduler_poll_secs")] + pub scheduler_poll_secs: u64, + /// Max retries for cron job execution attempts. 
+ #[serde(default = "default_scheduler_retries")] + pub scheduler_retries: u32, +} + +fn default_provider_retries() -> u32 { + 2 +} + +fn default_provider_backoff_ms() -> u64 { + 500 +} + +fn default_channel_backoff_secs() -> u64 { + 2 +} + +fn default_channel_backoff_max_secs() -> u64 { + 60 +} + +fn default_scheduler_poll_secs() -> u64 { + 15 +} + +fn default_scheduler_retries() -> u32 { + 2 +} + +impl Default for ReliabilityConfig { + fn default() -> Self { + Self { + provider_retries: default_provider_retries(), + provider_backoff_ms: default_provider_backoff_ms(), + fallback_providers: Vec::new(), + channel_initial_backoff_secs: default_channel_backoff_secs(), + channel_max_backoff_secs: default_channel_backoff_max_secs(), + scheduler_poll_secs: default_scheduler_poll_secs(), + scheduler_retries: default_scheduler_retries(), + } + } +} + // ── Heartbeat ──────────────────────────────────────────────────── #[derive(Debug, Clone, Serialize, Deserialize)] @@ -387,6 +517,7 @@ pub struct ChannelsConfig { pub webhook: Option<WebhookConfig>, pub imessage: Option<IMessageConfig>, pub matrix: Option<MatrixConfig>, + pub whatsapp: Option<WhatsAppConfig>, } impl Default for ChannelsConfig { @@ -399,6 +530,7 @@ webhook: None, imessage: None, matrix: None, + whatsapp: None, } } } @@ -445,6 +577,19 @@ pub struct MatrixConfig { pub allowed_users: Vec<String>, } +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct WhatsAppConfig { + /// Access token from Meta Business Suite + pub access_token: String, + /// Phone number ID from Meta Business API + pub phone_number_id: String, + /// Webhook verify token (you define this, Meta sends it back for verification) + pub verify_token: String, + /// Allowed phone numbers (E.164 format: +1234567890) or "*" for all + #[serde(default)] + pub allowed_numbers: Vec<String>, +} + // ── Config impl ────────────────────────────────────────────────── impl Default for Config { @@ -463,6 +608,7 @@ observability: ObservabilityConfig::default(), autonomy: 
AutonomyConfig::default(), runtime: RuntimeConfig::default(), + reliability: ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), @@ -471,6 +617,7 @@ impl Default for Config { composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + identity: IdentityConfig::default(), } } } @@ -558,6 +705,17 @@ mod tests { assert_eq!(h.interval_minutes, 30); } + #[test] + fn memory_config_default_hygiene_settings() { + let m = MemoryConfig::default(); + assert_eq!(m.backend, "sqlite"); + assert!(m.auto_save); + assert!(m.hygiene_enabled); + assert_eq!(m.archive_after_days, 7); + assert_eq!(m.purge_after_days, 30); + assert_eq!(m.conversation_retention_days, 30); + } + #[test] fn channels_config_default() { let c = ChannelsConfig::default(); @@ -591,6 +749,7 @@ mod tests { runtime: RuntimeConfig { kind: "docker".into(), }, + reliability: ReliabilityConfig::default(), heartbeat: HeartbeatConfig { enabled: true, interval_minutes: 15, @@ -606,6 +765,7 @@ mod tests { webhook: None, imessage: None, matrix: None, + whatsapp: None, }, memory: MemoryConfig::default(), tunnel: TunnelConfig::default(), @@ -613,6 +773,7 @@ mod tests { composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + identity: IdentityConfig::default(), }; let toml_str = toml::to_string_pretty(&config).unwrap(); @@ -650,6 +811,10 @@ default_temperature = 0.7 assert_eq!(parsed.runtime.kind, "native"); assert!(!parsed.heartbeat.enabled); assert!(parsed.channels_config.cli); + assert!(parsed.memory.hygiene_enabled); + assert_eq!(parsed.memory.archive_after_days, 7); + assert_eq!(parsed.memory.purge_after_days, 30); + assert_eq!(parsed.memory.conversation_retention_days, 30); } #[test] @@ -669,6 +834,7 @@ default_temperature = 0.7 observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: 
RuntimeConfig::default(), + reliability: ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), @@ -677,6 +843,7 @@ default_temperature = 0.7 composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + identity: IdentityConfig::default(), }; config.save().unwrap(); @@ -810,6 +977,7 @@ default_temperature = 0.7 room_id: "!r:m".into(), allowed_users: vec!["@u:m".into()], }), + whatsapp: None, }; let toml_str = toml::to_string_pretty(&c).unwrap(); let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap(); @@ -894,6 +1062,89 @@ channel_id = "C123" assert_eq!(parsed.port, 8080); } + // ── WhatsApp config ────────────────────────────────────── + + #[test] + fn whatsapp_config_serde() { + let wc = WhatsAppConfig { + access_token: "EAABx...".into(), + phone_number_id: "123456789".into(), + verify_token: "my-verify-token".into(), + allowed_numbers: vec!["+1234567890".into(), "+9876543210".into()], + }; + let json = serde_json::to_string(&wc).unwrap(); + let parsed: WhatsAppConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.access_token, "EAABx..."); + assert_eq!(parsed.phone_number_id, "123456789"); + assert_eq!(parsed.verify_token, "my-verify-token"); + assert_eq!(parsed.allowed_numbers.len(), 2); + } + + #[test] + fn whatsapp_config_toml_roundtrip() { + let wc = WhatsAppConfig { + access_token: "tok".into(), + phone_number_id: "12345".into(), + verify_token: "verify".into(), + allowed_numbers: vec!["+1".into()], + }; + let toml_str = toml::to_string(&wc).unwrap(); + let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.phone_number_id, "12345"); + assert_eq!(parsed.allowed_numbers, vec!["+1"]); + } + + #[test] + fn whatsapp_config_deserializes_without_allowed_numbers() { + let json = r#"{"access_token":"tok","phone_number_id":"123","verify_token":"ver"}"#; + let parsed: 
WhatsAppConfig = serde_json::from_str(json).unwrap(); + assert!(parsed.allowed_numbers.is_empty()); + } + + #[test] + fn whatsapp_config_wildcard_allowed() { + let wc = WhatsAppConfig { + access_token: "tok".into(), + phone_number_id: "123".into(), + verify_token: "ver".into(), + allowed_numbers: vec!["*".into()], + }; + let toml_str = toml::to_string(&wc).unwrap(); + let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.allowed_numbers, vec!["*"]); + } + + #[test] + fn channels_config_with_whatsapp() { + let c = ChannelsConfig { + cli: true, + telegram: None, + discord: None, + slack: None, + webhook: None, + imessage: None, + matrix: None, + whatsapp: Some(WhatsAppConfig { + access_token: "tok".into(), + phone_number_id: "123".into(), + verify_token: "ver".into(), + allowed_numbers: vec!["+1".into()], + }), + }; + let toml_str = toml::to_string_pretty(&c).unwrap(); + let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap(); + assert!(parsed.whatsapp.is_some()); + let wa = parsed.whatsapp.unwrap(); + assert_eq!(wa.phone_number_id, "123"); + assert_eq!(wa.allowed_numbers, vec!["+1"]); + } + + #[test] + fn channels_config_default_has_no_whatsapp() { + let c = ChannelsConfig::default(); + assert!(c.whatsapp.is_none()); + } + // ══════════════════════════════════════════════════════════ // SECURITY CHECKLIST TESTS — Gateway config // ══════════════════════════════════════════════════════════ diff --git a/src/cron/mod.rs b/src/cron/mod.rs index 8f52701..572670d 100644 --- a/src/cron/mod.rs +++ b/src/cron/mod.rs @@ -1,25 +1,353 @@ use crate::config::Config; -use anyhow::Result; +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; +use cron::Schedule; +use rusqlite::{params, Connection}; +use std::str::FromStr; +use uuid::Uuid; -pub fn handle_command(command: super::CronCommands, _config: Config) -> Result<()> { +pub mod scheduler; + +#[derive(Debug, Clone)] +pub struct CronJob { + pub id: String, + pub expression: String, + 
pub command: String, + pub next_run: DateTime<Utc>, + pub last_run: Option<DateTime<Utc>>, + pub last_status: Option<String>, +} + +pub fn handle_command(command: super::CronCommands, config: Config) -> Result<()> { match command { super::CronCommands::List => { - println!("No scheduled tasks yet."); - println!("\nUsage:"); - println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'"); + let jobs = list_jobs(&config)?; + if jobs.is_empty() { + println!("No scheduled tasks yet."); + println!("\nUsage:"); + println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'"); + return Ok(()); + } + + println!("🕒 Scheduled jobs ({}):", jobs.len()); + for job in jobs { + let last_run = job + .last_run + .map(|d| d.to_rfc3339()) + .unwrap_or_else(|| "never".into()); + let last_status = job.last_status.unwrap_or_else(|| "n/a".into()); + println!( + "- {} | {} | next={} | last={} ({})\n cmd: {}", + job.id, + job.expression, + job.next_run.to_rfc3339(), + last_run, + last_status, + job.command + ); + } Ok(()) } super::CronCommands::Add { expression, command, } => { - println!("Cron scheduling coming soon!"); - println!(" Expression: {expression}"); - println!(" Command: {command}"); + let job = add_job(&config, &expression, &command)?; + println!("✅ Added cron job {}", job.id); + println!(" Expr: {}", job.expression); + println!(" Next: {}", job.next_run.to_rfc3339()); + println!(" Cmd : {}", job.command); Ok(()) } - super::CronCommands::Remove { id } => { - anyhow::bail!("Remove task '{id}' not yet implemented"); - } + super::CronCommands::Remove { id } => remove_job(&config, &id), + } +} + +pub fn add_job(config: &Config, expression: &str, command: &str) -> Result<CronJob> { + let now = Utc::now(); + let next_run = next_run_for(expression, now)?; + let id = Uuid::new_v4().to_string(); + + with_connection(config, |conn| { + conn.execute( + "INSERT INTO cron_jobs (id, expression, command, created_at, next_run) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![ + id, + expression, + command, + 
now.to_rfc3339(), + next_run.to_rfc3339() + ], + ) + .context("Failed to insert cron job")?; + Ok(()) + })?; + + Ok(CronJob { + id, + expression: expression.to_string(), + command: command.to_string(), + next_run, + last_run: None, + last_status: None, + }) +} + +pub fn list_jobs(config: &Config) -> Result> { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, next_run, last_run, last_status + FROM cron_jobs ORDER BY next_run ASC", + )?; + + let rows = stmt.query_map([], |row| { + let next_run_raw: String = row.get(3)?; + let last_run_raw: Option = row.get(4)?; + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, String>(2)?, + next_run_raw, + last_run_raw, + row.get::<_, Option>(5)?, + )) + })?; + + let mut jobs = Vec::new(); + for row in rows { + let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; + jobs.push(CronJob { + id, + expression, + command, + next_run: parse_rfc3339(&next_run_raw)?, + last_run: match last_run_raw { + Some(raw) => Some(parse_rfc3339(&raw)?), + None => None, + }, + last_status, + }); + } + Ok(jobs) + }) +} + +pub fn remove_job(config: &Config, id: &str) -> Result<()> { + let changed = with_connection(config, |conn| { + conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![id]) + .context("Failed to delete cron job") + })?; + + if changed == 0 { + anyhow::bail!("Cron job '{id}' not found"); + } + + println!("✅ Removed cron job {id}"); + Ok(()) +} + +pub fn due_jobs(config: &Config, now: DateTime) -> Result> { + with_connection(config, |conn| { + let mut stmt = conn.prepare( + "SELECT id, expression, command, next_run, last_run, last_status + FROM cron_jobs WHERE next_run <= ?1 ORDER BY next_run ASC", + )?; + + let rows = stmt.query_map(params![now.to_rfc3339()], |row| { + let next_run_raw: String = row.get(3)?; + let last_run_raw: Option = row.get(4)?; + Ok(( + row.get::<_, String>(0)?, + row.get::<_, String>(1)?, + row.get::<_, 
String>(2)?, + next_run_raw, + last_run_raw, + row.get::<_, Option<String>>(5)?, + )) + })?; + + let mut jobs = Vec::new(); + for row in rows { + let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?; + jobs.push(CronJob { + id, + expression, + command, + next_run: parse_rfc3339(&next_run_raw)?, + last_run: match last_run_raw { + Some(raw) => Some(parse_rfc3339(&raw)?), + None => None, + }, + last_status, + }); + } + Ok(jobs) + }) +} + +pub fn reschedule_after_run( + config: &Config, + job: &CronJob, + success: bool, + output: &str, +) -> Result<()> { + let now = Utc::now(); + let next_run = next_run_for(&job.expression, now)?; + let status = if success { "ok" } else { "error" }; + + with_connection(config, |conn| { + conn.execute( + "UPDATE cron_jobs + SET next_run = ?1, last_run = ?2, last_status = ?3, last_output = ?4 + WHERE id = ?5", + params![ + next_run.to_rfc3339(), + now.to_rfc3339(), + status, + output, + job.id + ], + ) + .context("Failed to update cron job run state")?; + Ok(()) + }) +} + +fn next_run_for(expression: &str, from: DateTime<Utc>) -> Result<DateTime<Utc>> { + let normalized = normalize_expression(expression)?; + let schedule = Schedule::from_str(&normalized) + .with_context(|| format!("Invalid cron expression: {expression}"))?; + schedule + .after(&from) + .next() + .ok_or_else(|| anyhow::anyhow!("No future occurrence for expression: {expression}")) +} + +fn normalize_expression(expression: &str) -> Result<String> { + let expression = expression.trim(); + let field_count = expression.split_whitespace().count(); + + match field_count { + // standard crontab syntax: minute hour day month weekday + 5 => Ok(format!("0 {expression}")), + // crate-native syntax includes seconds (+ optional year) + 6 | 7 => Ok(expression.to_string()), + _ => anyhow::bail!( + "Invalid cron expression: {expression} (expected 5, 6, or 7 fields, got {field_count})" + ), + } +} + +fn parse_rfc3339(raw: &str) -> Result<DateTime<Utc>> { + let parsed = DateTime::parse_from_rfc3339(raw) + 
.with_context(|| format!("Invalid RFC3339 timestamp in cron DB: {raw}"))?; + Ok(parsed.with_timezone(&Utc)) +} + +fn with_connection(config: &Config, f: impl FnOnce(&Connection) -> Result) -> Result { + let db_path = config.workspace_dir.join("cron").join("jobs.db"); + if let Some(parent) = db_path.parent() { + std::fs::create_dir_all(parent) + .with_context(|| format!("Failed to create cron directory: {}", parent.display()))?; + } + + let conn = Connection::open(&db_path) + .with_context(|| format!("Failed to open cron DB: {}", db_path.display()))?; + + conn.execute_batch( + "CREATE TABLE IF NOT EXISTS cron_jobs ( + id TEXT PRIMARY KEY, + expression TEXT NOT NULL, + command TEXT NOT NULL, + created_at TEXT NOT NULL, + next_run TEXT NOT NULL, + last_run TEXT, + last_status TEXT, + last_output TEXT + ); + CREATE INDEX IF NOT EXISTS idx_cron_jobs_next_run ON cron_jobs(next_run);", + ) + .context("Failed to initialize cron schema")?; + + f(&conn) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use chrono::Duration as ChronoDuration; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + #[test] + fn add_job_accepts_five_field_expression() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap(); + + assert_eq!(job.expression, "*/5 * * * *"); + assert_eq!(job.command, "echo ok"); + } + + #[test] + fn add_job_rejects_invalid_field_count() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let err = add_job(&config, "* * * *", "echo bad").unwrap_err(); + assert!(err.to_string().contains("expected 5, 6, or 7 fields")); + } + + #[test] + fn add_list_remove_roundtrip() { + let tmp = 
TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/10 * * * *", "echo roundtrip").unwrap(); + let listed = list_jobs(&config).unwrap(); + assert_eq!(listed.len(), 1); + assert_eq!(listed[0].id, job.id); + + remove_job(&config, &job.id).unwrap(); + assert!(list_jobs(&config).unwrap().is_empty()); + } + + #[test] + fn due_jobs_filters_by_timestamp() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let _job = add_job(&config, "* * * * *", "echo due").unwrap(); + + let due_now = due_jobs(&config, Utc::now()).unwrap(); + assert!(due_now.is_empty(), "new job should not be due immediately"); + + let far_future = Utc::now() + ChronoDuration::days(365); + let due_future = due_jobs(&config, far_future).unwrap(); + assert_eq!(due_future.len(), 1, "job should be due in far future"); + } + + #[test] + fn reschedule_after_run_persists_last_status_and_last_run() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let job = add_job(&config, "*/15 * * * *", "echo run").unwrap(); + reschedule_after_run(&config, &job, false, "failed output").unwrap(); + + let listed = list_jobs(&config).unwrap(); + let stored = listed.iter().find(|j| j.id == job.id).unwrap(); + assert_eq!(stored.last_status.as_deref(), Some("error")); + assert!(stored.last_run.is_some()); } } diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs new file mode 100644 index 0000000..973fbee --- /dev/null +++ b/src/cron/scheduler.rs @@ -0,0 +1,297 @@ +use crate::config::Config; +use crate::cron::{due_jobs, reschedule_after_run, CronJob}; +use crate::security::SecurityPolicy; +use anyhow::Result; +use chrono::Utc; +use tokio::process::Command; +use tokio::time::{self, Duration}; + +const MIN_POLL_SECONDS: u64 = 5; + +pub async fn run(config: Config) -> Result<()> { + let poll_secs = config.reliability.scheduler_poll_secs.max(MIN_POLL_SECONDS); + let mut interval = time::interval(Duration::from_secs(poll_secs)); + let 
security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + crate::health::mark_component_ok("scheduler"); + + loop { + interval.tick().await; + + let jobs = match due_jobs(&config, Utc::now()) { + Ok(jobs) => jobs, + Err(e) => { + crate::health::mark_component_error("scheduler", e.to_string()); + tracing::warn!("Scheduler query failed: {e}"); + continue; + } + }; + + for job in jobs { + crate::health::mark_component_ok("scheduler"); + let (success, output) = execute_job_with_retry(&config, &security, &job).await; + + if !success { + crate::health::mark_component_error("scheduler", format!("job {} failed", job.id)); + } + + if let Err(e) = reschedule_after_run(&config, &job, success, &output) { + crate::health::mark_component_error("scheduler", e.to_string()); + tracing::warn!("Failed to persist scheduler run result: {e}"); + } + } + } +} + +async fn execute_job_with_retry( + config: &Config, + security: &SecurityPolicy, + job: &CronJob, +) -> (bool, String) { + let mut last_output = String::new(); + let retries = config.reliability.scheduler_retries; + let mut backoff_ms = config.reliability.provider_backoff_ms.max(200); + + for attempt in 0..=retries { + let (success, output) = run_job_command(config, security, job).await; + last_output = output; + + if success { + return (true, last_output); + } + + if last_output.starts_with("blocked by security policy:") { + // Deterministic policy violations are not retryable. 
+ return (false, last_output); + } + + if attempt < retries { + let jitter_ms = (Utc::now().timestamp_subsec_millis() % 250) as u64; + time::sleep(Duration::from_millis(backoff_ms + jitter_ms)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(30_000); + } + } + + (false, last_output) +} + +fn is_env_assignment(word: &str) -> bool { + word.contains('=') + && word + .chars() + .next() + .is_some_and(|c| c.is_ascii_alphabetic() || c == '_') +} + +fn strip_wrapping_quotes(token: &str) -> &str { + token.trim_matches(|c| c == '"' || c == '\'') +} + +fn forbidden_path_argument(security: &SecurityPolicy, command: &str) -> Option { + let mut normalized = command.to_string(); + for sep in ["&&", "||"] { + normalized = normalized.replace(sep, "\x00"); + } + for sep in ['\n', ';', '|'] { + normalized = normalized.replace(sep, "\x00"); + } + + for segment in normalized.split('\x00') { + let tokens: Vec<&str> = segment.split_whitespace().collect(); + if tokens.is_empty() { + continue; + } + + // Skip leading env assignments and executable token. + let mut idx = 0; + while idx < tokens.len() && is_env_assignment(tokens[idx]) { + idx += 1; + } + if idx >= tokens.len() { + continue; + } + idx += 1; + + for token in &tokens[idx..] 
{ + let candidate = strip_wrapping_quotes(token); + if candidate.is_empty() || candidate.starts_with('-') || candidate.contains("://") { + continue; + } + + let looks_like_path = candidate.starts_with('/') + || candidate.starts_with("./") + || candidate.starts_with("../") + || candidate.starts_with("~/") + || candidate.contains('/'); + + if looks_like_path && !security.is_path_allowed(candidate) { + return Some(candidate.to_string()); + } + } + } + + None +} + +async fn run_job_command( + config: &Config, + security: &SecurityPolicy, + job: &CronJob, +) -> (bool, String) { + if !security.is_command_allowed(&job.command) { + return ( + false, + format!( + "blocked by security policy: command not allowed: {}", + job.command + ), + ); + } + + if let Some(path) = forbidden_path_argument(security, &job.command) { + return ( + false, + format!("blocked by security policy: forbidden path argument: {path}"), + ); + } + + let output = Command::new("sh") + .arg("-lc") + .arg(&job.command) + .current_dir(&config.workspace_dir) + .output() + .await; + + match output { + Ok(output) => { + let stdout = String::from_utf8_lossy(&output.stdout); + let stderr = String::from_utf8_lossy(&output.stderr); + let combined = format!( + "status={}\nstdout:\n{}\nstderr:\n{}", + output.status, + stdout.trim(), + stderr.trim() + ); + (output.status.success(), combined) + } + Err(e) => (false, format!("spawn error: {e}")), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::Config; + use crate::security::SecurityPolicy; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + fn test_job(command: &str) -> CronJob { + CronJob { + id: "test-job".into(), + expression: "* * * * *".into(), + command: command.into(), + next_run: Utc::now(), + 
last_run: None, + last_status: None, + } + } + + #[tokio::test] + async fn run_job_command_success() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = test_job("echo scheduler-ok"); + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + let (success, output) = run_job_command(&config, &security, &job).await; + assert!(success); + assert!(output.contains("scheduler-ok")); + assert!(output.contains("status=exit status: 0")); + } + + #[tokio::test] + async fn run_job_command_failure() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = test_job("ls definitely_missing_file_for_scheduler_test"); + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + let (success, output) = run_job_command(&config, &security, &job).await; + assert!(!success); + assert!(output.contains("definitely_missing_file_for_scheduler_test")); + assert!(output.contains("status=exit status:")); + } + + #[tokio::test] + async fn run_job_command_blocks_disallowed_command() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.autonomy.allowed_commands = vec!["echo".into()]; + let job = test_job("curl https://evil.example"); + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + let (success, output) = run_job_command(&config, &security, &job).await; + assert!(!success); + assert!(output.contains("blocked by security policy")); + assert!(output.contains("command not allowed")); + } + + #[tokio::test] + async fn run_job_command_blocks_forbidden_path_argument() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.autonomy.allowed_commands = vec!["cat".into()]; + let job = test_job("cat /etc/passwd"); + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + let (success, output) = run_job_command(&config, &security, &job).await; + 
assert!(!success); + assert!(output.contains("blocked by security policy")); + assert!(output.contains("forbidden path argument")); + assert!(output.contains("/etc/passwd")); + } + + #[tokio::test] + async fn execute_job_with_retry_recovers_after_first_failure() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.reliability.scheduler_retries = 1; + config.reliability.provider_backoff_ms = 1; + config.autonomy.allowed_commands = vec!["sh".into()]; + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + std::fs::write( + config.workspace_dir.join("retry-once.sh"), + "#!/bin/sh\nif [ -f retry-ok.flag ]; then\n echo recovered\n exit 0\nfi\ntouch retry-ok.flag\nexit 1\n", + ) + .unwrap(); + let job = test_job("sh ./retry-once.sh"); + + let (success, output) = execute_job_with_retry(&config, &security, &job).await; + assert!(success); + assert!(output.contains("recovered")); + } + + #[tokio::test] + async fn execute_job_with_retry_exhausts_attempts() { + let tmp = TempDir::new().unwrap(); + let mut config = test_config(&tmp); + config.reliability.scheduler_retries = 1; + config.reliability.provider_backoff_ms = 1; + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + + let job = test_job("ls always_missing_for_retry_test"); + + let (success, output) = execute_job_with_retry(&config, &security, &job).await; + assert!(!success); + assert!(output.contains("always_missing_for_retry_test")); + } +} diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs new file mode 100644 index 0000000..db374bc --- /dev/null +++ b/src/daemon/mod.rs @@ -0,0 +1,287 @@ +use crate::config::Config; +use anyhow::Result; +use chrono::Utc; +use std::future::Future; +use std::path::PathBuf; +use tokio::task::JoinHandle; +use tokio::time::Duration; + +const STATUS_FLUSH_SECONDS: u64 = 5; + +pub async fn run(config: Config, host: String, port: u16) -> Result<()> { + let initial_backoff = 
config.reliability.channel_initial_backoff_secs.max(1); + let max_backoff = config + .reliability + .channel_max_backoff_secs + .max(initial_backoff); + + crate::health::mark_component_ok("daemon"); + + if config.heartbeat.enabled { + let _ = + crate::heartbeat::engine::HeartbeatEngine::ensure_heartbeat_file(&config.workspace_dir) + .await; + } + + let mut handles: Vec> = vec![spawn_state_writer(config.clone())]; + + { + let gateway_cfg = config.clone(); + let gateway_host = host.clone(); + handles.push(spawn_component_supervisor( + "gateway", + initial_backoff, + max_backoff, + move || { + let cfg = gateway_cfg.clone(); + let host = gateway_host.clone(); + async move { crate::gateway::run_gateway(&host, port, cfg).await } + }, + )); + } + + { + if has_supervised_channels(&config) { + let channels_cfg = config.clone(); + handles.push(spawn_component_supervisor( + "channels", + initial_backoff, + max_backoff, + move || { + let cfg = channels_cfg.clone(); + async move { crate::channels::start_channels(cfg).await } + }, + )); + } else { + crate::health::mark_component_ok("channels"); + tracing::info!("No real-time channels configured; channel supervisor disabled"); + } + } + + if config.heartbeat.enabled { + let heartbeat_cfg = config.clone(); + handles.push(spawn_component_supervisor( + "heartbeat", + initial_backoff, + max_backoff, + move || { + let cfg = heartbeat_cfg.clone(); + async move { run_heartbeat_worker(cfg).await } + }, + )); + } + + { + let scheduler_cfg = config.clone(); + handles.push(spawn_component_supervisor( + "scheduler", + initial_backoff, + max_backoff, + move || { + let cfg = scheduler_cfg.clone(); + async move { crate::cron::scheduler::run(cfg).await } + }, + )); + } + + println!("🧠 ZeroClaw daemon started"); + println!(" Gateway: http://{host}:{port}"); + println!(" Components: gateway, channels, heartbeat, scheduler"); + println!(" Ctrl+C to stop"); + + tokio::signal::ctrl_c().await?; + crate::health::mark_component_error("daemon", "shutdown 
requested"); + + for handle in &handles { + handle.abort(); + } + for handle in handles { + let _ = handle.await; + } + + Ok(()) +} + +pub fn state_file_path(config: &Config) -> PathBuf { + config + .config_path + .parent() + .map_or_else(|| PathBuf::from("."), PathBuf::from) + .join("daemon_state.json") +} + +fn spawn_state_writer(config: Config) -> JoinHandle<()> { + tokio::spawn(async move { + let path = state_file_path(&config); + if let Some(parent) = path.parent() { + let _ = tokio::fs::create_dir_all(parent).await; + } + + let mut interval = tokio::time::interval(Duration::from_secs(STATUS_FLUSH_SECONDS)); + loop { + interval.tick().await; + let mut json = crate::health::snapshot_json(); + if let Some(obj) = json.as_object_mut() { + obj.insert( + "written_at".into(), + serde_json::json!(Utc::now().to_rfc3339()), + ); + } + let data = serde_json::to_vec_pretty(&json).unwrap_or_else(|_| b"{}".to_vec()); + let _ = tokio::fs::write(&path, data).await; + } + }) +} + +fn spawn_component_supervisor( + name: &'static str, + initial_backoff_secs: u64, + max_backoff_secs: u64, + mut run_component: F, +) -> JoinHandle<()> +where + F: FnMut() -> Fut + Send + 'static, + Fut: Future> + Send + 'static, +{ + tokio::spawn(async move { + let mut backoff = initial_backoff_secs.max(1); + let max_backoff = max_backoff_secs.max(backoff); + + loop { + crate::health::mark_component_ok(name); + match run_component().await { + Ok(()) => { + crate::health::mark_component_error(name, "component exited unexpectedly"); + tracing::warn!("Daemon component '{name}' exited unexpectedly"); + } + Err(e) => { + crate::health::mark_component_error(name, e.to_string()); + tracing::error!("Daemon component '{name}' failed: {e}"); + } + } + + crate::health::bump_component_restart(name); + tokio::time::sleep(Duration::from_secs(backoff)).await; + backoff = backoff.saturating_mul(2).min(max_backoff); + } + }) +} + +async fn run_heartbeat_worker(config: Config) -> Result<()> { + let observer: 
std::sync::Arc = + std::sync::Arc::from(crate::observability::create_observer(&config.observability)); + let engine = crate::heartbeat::engine::HeartbeatEngine::new( + config.heartbeat.clone(), + config.workspace_dir.clone(), + observer, + ); + + let interval_mins = config.heartbeat.interval_minutes.max(5); + let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60)); + + loop { + interval.tick().await; + + let tasks = engine.collect_tasks().await?; + if tasks.is_empty() { + continue; + } + + for task in tasks { + let prompt = format!("[Heartbeat Task] {task}"); + let temp = config.default_temperature; + if let Err(e) = crate::agent::run(config.clone(), Some(prompt), None, None, temp).await + { + crate::health::mark_component_error("heartbeat", e.to_string()); + tracing::warn!("Heartbeat task failed: {e}"); + } else { + crate::health::mark_component_ok("heartbeat"); + } + } + } +} + +fn has_supervised_channels(config: &Config) -> bool { + config.channels_config.telegram.is_some() + || config.channels_config.discord.is_some() + || config.channels_config.slack.is_some() + || config.channels_config.imessage.is_some() + || config.channels_config.matrix.is_some() +} + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let mut config = Config::default(); + config.workspace_dir = tmp.path().join("workspace"); + config.config_path = tmp.path().join("config.toml"); + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + #[test] + fn state_file_path_uses_config_directory() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let path = state_file_path(&config); + assert_eq!(path, tmp.path().join("daemon_state.json")); + } + + #[tokio::test] + async fn supervisor_marks_error_and_restart_on_failure() { + let handle = spawn_component_supervisor("daemon-test-fail", 1, 1, || async { + anyhow::bail!("boom") + }); + + 
tokio::time::sleep(Duration::from_millis(50)).await; + handle.abort(); + let _ = handle.await; + + let snapshot = crate::health::snapshot_json(); + let component = &snapshot["components"]["daemon-test-fail"]; + assert_eq!(component["status"], "error"); + assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1); + assert!(component["last_error"] + .as_str() + .unwrap_or("") + .contains("boom")); + } + + #[tokio::test] + async fn supervisor_marks_unexpected_exit_as_error() { + let handle = spawn_component_supervisor("daemon-test-exit", 1, 1, || async { Ok(()) }); + + tokio::time::sleep(Duration::from_millis(50)).await; + handle.abort(); + let _ = handle.await; + + let snapshot = crate::health::snapshot_json(); + let component = &snapshot["components"]["daemon-test-exit"]; + assert_eq!(component["status"], "error"); + assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1); + assert!(component["last_error"] + .as_str() + .unwrap_or("") + .contains("component exited unexpectedly")); + } + + #[test] + fn detects_no_supervised_channels() { + let config = Config::default(); + assert!(!has_supervised_channels(&config)); + } + + #[test] + fn detects_supervised_channels_present() { + let mut config = Config::default(); + config.channels_config.telegram = Some(crate::config::TelegramConfig { + bot_token: "token".into(), + allowed_users: vec![], + }); + assert!(has_supervised_channels(&config)); + } +} diff --git a/src/doctor/mod.rs b/src/doctor/mod.rs new file mode 100644 index 0000000..62417ea --- /dev/null +++ b/src/doctor/mod.rs @@ -0,0 +1,123 @@ +use crate::config::Config; +use anyhow::{Context, Result}; +use chrono::{DateTime, Utc}; + +const DAEMON_STALE_SECONDS: i64 = 30; +const SCHEDULER_STALE_SECONDS: i64 = 120; +const CHANNEL_STALE_SECONDS: i64 = 300; + +pub fn run(config: &Config) -> Result<()> { + let state_file = crate::daemon::state_file_path(config); + if !state_file.exists() { + println!("🩺 ZeroClaw Doctor"); + println!(" ❌ daemon state file not 
found: {}", state_file.display()); + println!(" 💡 Start daemon with: zeroclaw daemon"); + return Ok(()); + } + + let raw = std::fs::read_to_string(&state_file) + .with_context(|| format!("Failed to read {}", state_file.display()))?; + let snapshot: serde_json::Value = serde_json::from_str(&raw) + .with_context(|| format!("Failed to parse {}", state_file.display()))?; + + println!("🩺 ZeroClaw Doctor"); + println!(" State file: {}", state_file.display()); + + let updated_at = snapshot + .get("updated_at") + .and_then(serde_json::Value::as_str) + .unwrap_or(""); + + if let Ok(ts) = DateTime::parse_from_rfc3339(updated_at) { + let age = Utc::now() + .signed_duration_since(ts.with_timezone(&Utc)) + .num_seconds(); + if age <= DAEMON_STALE_SECONDS { + println!(" ✅ daemon heartbeat fresh ({age}s ago)"); + } else { + println!(" ❌ daemon heartbeat stale ({age}s ago)"); + } + } else { + println!(" ❌ invalid daemon timestamp: {updated_at}"); + } + + let mut channel_count = 0_u32; + let mut stale_channels = 0_u32; + + if let Some(components) = snapshot + .get("components") + .and_then(serde_json::Value::as_object) + { + if let Some(scheduler) = components.get("scheduler") { + let scheduler_ok = scheduler + .get("status") + .and_then(serde_json::Value::as_str) + .map(|s| s == "ok") + .unwrap_or(false); + + let scheduler_last_ok = scheduler + .get("last_ok") + .and_then(serde_json::Value::as_str) + .and_then(parse_rfc3339) + .map(|dt| Utc::now().signed_duration_since(dt).num_seconds()) + .unwrap_or(i64::MAX); + + if scheduler_ok && scheduler_last_ok <= SCHEDULER_STALE_SECONDS { + println!( + " ✅ scheduler healthy (last ok {}s ago)", + scheduler_last_ok + ); + } else { + println!( + " ❌ scheduler unhealthy/stale (status_ok={}, age={}s)", + scheduler_ok, scheduler_last_ok + ); + } + } else { + println!(" ❌ scheduler component missing"); + } + + for (name, component) in components { + if !name.starts_with("channel:") { + continue; + } + + channel_count += 1; + let status_ok = 
component + .get("status") + .and_then(serde_json::Value::as_str) + .map(|s| s == "ok") + .unwrap_or(false); + let age = component + .get("last_ok") + .and_then(serde_json::Value::as_str) + .and_then(parse_rfc3339) + .map(|dt| Utc::now().signed_duration_since(dt).num_seconds()) + .unwrap_or(i64::MAX); + + if status_ok && age <= CHANNEL_STALE_SECONDS { + println!(" ✅ {name} fresh (last ok {age}s ago)"); + } else { + stale_channels += 1; + println!(" ❌ {name} stale/unhealthy (status_ok={status_ok}, age={age}s)"); + } + } + } + + if channel_count == 0 { + println!(" ℹ️ no channel components tracked in state yet"); + } else { + println!( + " Channel summary: {} total, {} stale", + channel_count, stale_channels + ); + } + + Ok(()) +} + +fn parse_rfc3339(raw: &str) -> Option> { + DateTime::parse_from_rfc3339(raw) + .ok() + .map(|dt| dt.with_timezone(&Utc)) +} diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 6fd27fb..deba8ff 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -1,15 +1,49 @@ +//! Axum-based HTTP gateway with proper HTTP/1.1 compliance, body limits, and timeouts. +//! +//! This module replaces the raw TCP implementation with axum for: +//! - Proper HTTP/1.1 parsing and compliance +//! - Content-Length validation (handled by hyper) +//! - Request body size limits (64KB max) +//! - Request timeouts (30s) to prevent slow-loris attacks +//! 
- Header sanitization (handled by axum/hyper) + +use crate::channels::{Channel, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; use crate::providers::{self, Provider}; use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; use anyhow::Result; +use axum::{ + body::Bytes, + extract::{Query, State}, + http::{header, HeaderMap, StatusCode}, + response::{IntoResponse, Json}, + routing::{get, post}, + Router, +}; +use std::net::SocketAddr; use std::sync::Arc; -use std::time::Duration; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::TcpListener; +use tower_http::limit::RequestBodyLimitLayer; -/// Run a minimal HTTP gateway (webhook + health check) -/// Zero new dependencies — uses raw TCP + tokio. +/// Maximum request body size (64KB) — prevents memory exhaustion +pub const MAX_BODY_SIZE: usize = 65_536; +/// Request timeout (30s) — prevents slow-loris attacks +pub const REQUEST_TIMEOUT_SECS: u64 = 30; + +/// Shared state for all axum handlers +#[derive(Clone)] +pub struct AppState { + pub provider: Arc, + pub model: String, + pub temperature: f64, + pub mem: Arc, + pub auto_save: bool, + pub webhook_secret: Option>, + pub pairing: Arc, + pub whatsapp: Option>, +} + +/// Run the HTTP gateway using axum with proper HTTP/1.1 compliance. 
#[allow(clippy::too_many_lines)] pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { // ── Security: refuse public bind without tunnel or explicit opt-in ── @@ -22,13 +56,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { ); } - let listener = TcpListener::bind(format!("{host}:{port}")).await?; + let addr: SocketAddr = format!("{host}:{port}").parse()?; + let listener = tokio::net::TcpListener::bind(addr).await?; let actual_port = listener.local_addr()?.port(); - let addr = format!("{host}:{actual_port}"); + let display_addr = format!("{host}:{actual_port}"); - let provider: Arc = Arc::from(providers::create_provider( + let provider: Arc = Arc::from(providers::create_resilient_provider( config.default_provider.as_deref().unwrap_or("openrouter"), config.api_key.as_deref(), + &config.reliability, )?); let model = config .default_model @@ -49,6 +85,17 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { .and_then(|w| w.secret.as_deref()) .map(Arc::from); + // WhatsApp channel (if configured) + let whatsapp_channel: Option> = + config.channels_config.whatsapp.as_ref().map(|wa| { + Arc::new(WhatsAppChannel::new( + wa.access_token.clone(), + wa.phone_number_id.clone(), + wa.verify_token.clone(), + wa.allowed_numbers.clone(), + )) + }); + // ── Pairing guard ────────────────────────────────────── let pairing = Arc::new(PairingGuard::new( config.gateway.require_pairing, @@ -73,16 +120,20 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { } } - println!("🦀 ZeroClaw Gateway listening on http://{addr}"); + println!("🦀 ZeroClaw Gateway listening on http://{display_addr}"); if let Some(ref url) = tunnel_url { println!(" 🌐 Public URL: {url}"); } - println!(" POST /pair — pair a new client (X-Pairing-Code header)"); - println!(" POST /webhook — {{\"message\": \"your prompt\"}}"); - println!(" GET /health — health check"); + println!(" POST /pair — pair a 
new client (X-Pairing-Code header)"); + println!(" POST /webhook — {{\"message\": \"your prompt\"}}"); + if whatsapp_channel.is_some() { + println!(" GET /whatsapp — Meta webhook verification"); + println!(" POST /whatsapp — WhatsApp message webhook"); + } + println!(" GET /health — health check"); if let Some(code) = pairing.pairing_code() { println!(); - println!(" � PAIRING REQUIRED — use this one-time code:"); + println!(" 🔐 PAIRING REQUIRED — use this one-time code:"); println!(" ┌──────────────┐"); println!(" │ {code} │"); println!(" └──────────────┘"); @@ -97,428 +148,312 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { } println!(" Press Ctrl+C to stop.\n"); - loop { - let (mut stream, peer) = listener.accept().await?; - let provider = provider.clone(); - let model = model.clone(); - let mem = mem.clone(); - let auto_save = config.memory.auto_save; - let secret = webhook_secret.clone(); - let pairing = pairing.clone(); + crate::health::mark_component_ok("gateway"); - tokio::spawn(async move { - // Read with 30s timeout to prevent slow-loris attacks - let mut buf = vec![0u8; 65_536]; // 64KB max request - let n = match tokio::time::timeout(Duration::from_secs(30), stream.read(&mut buf)).await - { - Ok(Ok(n)) if n > 0 => n, - _ => return, - }; + // Build shared state + let state = AppState { + provider, + model, + temperature, + mem, + auto_save: config.memory.auto_save, + webhook_secret, + pairing, + whatsapp: whatsapp_channel, + }; - let request = String::from_utf8_lossy(&buf[..n]); - let first_line = request.lines().next().unwrap_or(""); - let parts: Vec<&str> = first_line.split_whitespace().collect(); + // Build router with middleware + // Note: Body limit layer prevents memory exhaustion from oversized requests + // Timeout is handled by tokio's TcpListener accept timeout and hyper's built-in timeouts + let app = Router::new() + .route("/health", get(handle_health)) + .route("/pair", post(handle_pair)) + 
.route("/webhook", post(handle_webhook)) + .route("/whatsapp", get(handle_whatsapp_verify)) + .route("/whatsapp", post(handle_whatsapp_message)) + .with_state(state) + .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE)); - if let [method, path, ..] = parts.as_slice() { - tracing::info!("{peer} → {method} {path}"); - handle_request( - &mut stream, - method, - path, - &request, - &provider, - &model, - temperature, - &mem, - auto_save, - secret.as_ref(), - &pairing, - ) - .await; - } else { - let _ = send_response(&mut stream, 400, "Bad Request").await; - } - }); - } + // Run the server + axum::serve(listener, app).await?; + + Ok(()) } -/// Extract a header value from a raw HTTP request. -fn extract_header<'a>(request: &'a str, header_name: &str) -> Option<&'a str> { - let lower_name = header_name.to_lowercase(); - for line in request.lines() { - if let Some((key, value)) = line.split_once(':') { - if key.trim().to_lowercase() == lower_name { - return Some(value.trim()); - } - } - } - None +// ══════════════════════════════════════════════════════════════════════════════ +// AXUM HANDLERS +// ══════════════════════════════════════════════════════════════════════════════ + +/// GET /health — always public (no secrets leaked) +async fn handle_health(State(state): State) -> impl IntoResponse { + let body = serde_json::json!({ + "status": "ok", + "paired": state.pairing.is_paired(), + "runtime": crate::health::snapshot_json(), + }); + Json(body) } -#[allow(clippy::too_many_arguments)] -async fn handle_request( - stream: &mut tokio::net::TcpStream, - method: &str, - path: &str, - request: &str, - provider: &Arc, - model: &str, - temperature: f64, - mem: &Arc, - auto_save: bool, - webhook_secret: Option<&Arc>, - pairing: &PairingGuard, -) { - match (method, path) { - // Health check — always public (no secrets leaked) - ("GET", "/health") => { - let body = serde_json::json!({ - "status": "ok", - "paired": pairing.is_paired(), - }); - let _ = send_json(stream, 200, 
&body).await; - } - - // Pairing endpoint — exchange one-time code for bearer token - ("POST", "/pair") => { - let code = extract_header(request, "X-Pairing-Code").unwrap_or(""); - match pairing.try_pair(code) { - Ok(Some(token)) => { - tracing::info!("🔐 New client paired successfully"); - let body = serde_json::json!({ - "paired": true, - "token": token, - "message": "Save this token — use it as Authorization: Bearer " - }); - let _ = send_json(stream, 200, &body).await; - } - Ok(None) => { - tracing::warn!("🔐 Pairing attempt with invalid code"); - let err = serde_json::json!({"error": "Invalid pairing code"}); - let _ = send_json(stream, 403, &err).await; - } - Err(lockout_secs) => { - tracing::warn!( - "🔐 Pairing locked out — too many failed attempts ({lockout_secs}s remaining)" - ); - let err = serde_json::json!({ - "error": format!("Too many failed attempts. Try again in {lockout_secs}s."), - "retry_after": lockout_secs - }); - let _ = send_json(stream, 429, &err).await; - } - } - } - - ("POST", "/webhook") => { - // ── Bearer token auth (pairing) ── - if pairing.require_pairing() { - let auth = extract_header(request, "Authorization").unwrap_or(""); - let token = auth.strip_prefix("Bearer ").unwrap_or(""); - if !pairing.is_authenticated(token) { - tracing::warn!("Webhook: rejected — not paired / invalid bearer token"); - let err = serde_json::json!({ - "error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer " - }); - let _ = send_json(stream, 401, &err).await; - return; - } - } - - // ── Webhook secret auth (optional, additional layer) ── - if let Some(secret) = webhook_secret { - let header_val = extract_header(request, "X-Webhook-Secret"); - match header_val { - Some(val) if constant_time_eq(val, secret.as_ref()) => {} - _ => { - tracing::warn!( - "Webhook: rejected request — invalid or missing X-Webhook-Secret" - ); - let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"}); - let _ = 
send_json(stream, 401, &err).await; - return; - } - } - } - handle_webhook( - stream, - request, - provider, - model, - temperature, - mem, - auto_save, - ) - .await; - } - - _ => { - let body = serde_json::json!({ - "error": "Not found", - "routes": ["GET /health", "POST /pair", "POST /webhook"] - }); - let _ = send_json(stream, 404, &body).await; - } - } -} - -async fn handle_webhook( - stream: &mut tokio::net::TcpStream, - request: &str, - provider: &Arc, - model: &str, - temperature: f64, - mem: &Arc, - auto_save: bool, -) { - let body_str = request - .split("\r\n\r\n") - .nth(1) - .or_else(|| request.split("\n\n").nth(1)) +/// POST /pair — exchange one-time code for bearer token +async fn handle_pair(State(state): State, headers: HeaderMap) -> impl IntoResponse { + let code = headers + .get("X-Pairing-Code") + .and_then(|v| v.to_str().ok()) .unwrap_or(""); - let Ok(parsed) = serde_json::from_str::(body_str) else { - let err = serde_json::json!({"error": "Invalid JSON. Expected: {\"message\": \"...\"}"}); - let _ = send_json(stream, 400, &err).await; - return; + match state.pairing.try_pair(code) { + Ok(Some(token)) => { + tracing::info!("🔐 New client paired successfully"); + let body = serde_json::json!({ + "paired": true, + "token": token, + "message": "Save this token — use it as Authorization: Bearer " + }); + (StatusCode::OK, Json(body)) + } + Ok(None) => { + tracing::warn!("🔐 Pairing attempt with invalid code"); + let err = serde_json::json!({"error": "Invalid pairing code"}); + (StatusCode::FORBIDDEN, Json(err)) + } + Err(lockout_secs) => { + tracing::warn!( + "🔐 Pairing locked out — too many failed attempts ({lockout_secs}s remaining)" + ); + let err = serde_json::json!({ + "error": format!("Too many failed attempts. 
Try again in {lockout_secs}s."), + "retry_after": lockout_secs + }); + (StatusCode::TOO_MANY_REQUESTS, Json(err)) + } + } +} + +/// Webhook request body +#[derive(serde::Deserialize)] +pub struct WebhookBody { + pub message: String, +} + +/// POST /webhook — main webhook endpoint +async fn handle_webhook( + State(state): State, + headers: HeaderMap, + body: Result, axum::extract::rejection::JsonRejection>, +) -> impl IntoResponse { + // ── Bearer token auth (pairing) ── + if state.pairing.require_pairing() { + let auth = headers + .get(header::AUTHORIZATION) + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + let token = auth.strip_prefix("Bearer ").unwrap_or(""); + if !state.pairing.is_authenticated(token) { + tracing::warn!("Webhook: rejected — not paired / invalid bearer token"); + let err = serde_json::json!({ + "error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer " + }); + return (StatusCode::UNAUTHORIZED, Json(err)); + } + } + + // ── Webhook secret auth (optional, additional layer) ── + if let Some(ref secret) = state.webhook_secret { + let header_val = headers + .get("X-Webhook-Secret") + .and_then(|v| v.to_str().ok()); + match header_val { + Some(val) if constant_time_eq(val, secret.as_ref()) => {} + _ => { + tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret"); + let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"}); + return (StatusCode::UNAUTHORIZED, Json(err)); + } + } + } + + // ── Parse body ── + let Json(webhook_body) = match body { + Ok(b) => b, + Err(e) => { + let err = serde_json::json!({ + "error": format!("Invalid JSON: {e}. 
Expected: {{\"message\": \"...\"}}") + }); + return (StatusCode::BAD_REQUEST, Json(err)); + } }; - let Some(message) = parsed.get("message").and_then(|v| v.as_str()) else { - let err = serde_json::json!({"error": "Missing 'message' field in JSON"}); - let _ = send_json(stream, 400, &err).await; - return; - }; + let message = &webhook_body.message; - if auto_save { - let _ = mem + if state.auto_save { + let _ = state + .mem .store("webhook_msg", message, MemoryCategory::Conversation) .await; } - match provider.chat(message, model, temperature).await { + match state + .provider + .chat(message, &state.model, state.temperature) + .await + { Ok(response) => { - let body = serde_json::json!({"response": response, "model": model}); - let _ = send_json(stream, 200, &body).await; + let body = serde_json::json!({"response": response, "model": state.model}); + (StatusCode::OK, Json(body)) } Err(e) => { let err = serde_json::json!({"error": format!("LLM error: {e}")}); - let _ = send_json(stream, 500, &err).await; + (StatusCode::INTERNAL_SERVER_ERROR, Json(err)) } } } -async fn send_response( - stream: &mut tokio::net::TcpStream, - status: u16, - body: &str, -) -> std::io::Result<()> { - let reason = match status { - 200 => "OK", - 400 => "Bad Request", - 404 => "Not Found", - 500 => "Internal Server Error", - _ => "Unknown", - }; - let response = format!( - "HTTP/1.1 {status} {reason}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}", - body.len() - ); - stream.write_all(response.as_bytes()).await +/// `WhatsApp` verification query params +#[derive(serde::Deserialize)] +pub struct WhatsAppVerifyQuery { + #[serde(rename = "hub.mode")] + pub mode: Option, + #[serde(rename = "hub.verify_token")] + pub verify_token: Option, + #[serde(rename = "hub.challenge")] + pub challenge: Option, } -async fn send_json( - stream: &mut tokio::net::TcpStream, - status: u16, - body: &serde_json::Value, -) -> std::io::Result<()> { - let reason = match status { 
- 200 => "OK", - 400 => "Bad Request", - 404 => "Not Found", - 500 => "Internal Server Error", - _ => "Unknown", +/// GET /whatsapp — Meta webhook verification +async fn handle_whatsapp_verify( + State(state): State, + Query(params): Query, +) -> impl IntoResponse { + let Some(ref wa) = state.whatsapp else { + return (StatusCode::NOT_FOUND, "WhatsApp not configured".to_string()); }; - let json = serde_json::to_string(body).unwrap_or_default(); - let response = format!( - "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{json}", - json.len() - ); - stream.write_all(response.as_bytes()).await + + // Verify the token matches + if params.mode.as_deref() == Some("subscribe") + && params.verify_token.as_deref() == Some(wa.verify_token()) + { + if let Some(ch) = params.challenge { + tracing::info!("WhatsApp webhook verified successfully"); + return (StatusCode::OK, ch); + } + return (StatusCode::BAD_REQUEST, "Missing hub.challenge".to_string()); + } + + tracing::warn!("WhatsApp webhook verification failed — token mismatch"); + (StatusCode::FORBIDDEN, "Forbidden".to_string()) +} + +/// POST /whatsapp — incoming message webhook +async fn handle_whatsapp_message(State(state): State, body: Bytes) -> impl IntoResponse { + let Some(ref wa) = state.whatsapp else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "WhatsApp not configured"})), + ); + }; + + // Parse JSON body + let Ok(payload) = serde_json::from_slice::(&body) else { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({"error": "Invalid JSON payload"})), + ); + }; + + // Parse messages from the webhook payload + let messages = wa.parse_webhook_payload(&payload); + + if messages.is_empty() { + // Acknowledge the webhook even if no messages (could be status updates) + return (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))); + } + + // Process each message + for msg in &messages { + tracing::info!( + "WhatsApp 
message from {}: {}", + msg.sender, + if msg.content.len() > 50 { + format!("{}...", &msg.content[..50]) + } else { + msg.content.clone() + } + ); + + // Auto-save to memory + if state.auto_save { + let _ = state + .mem + .store( + &format!("whatsapp_{}", msg.sender), + &msg.content, + MemoryCategory::Conversation, + ) + .await; + } + + // Call the LLM + match state + .provider + .chat(&msg.content, &state.model, state.temperature) + .await + { + Ok(response) => { + // Send reply via WhatsApp + if let Err(e) = wa.send(&response, &msg.sender).await { + tracing::error!("Failed to send WhatsApp reply: {e}"); + } + } + Err(e) => { + tracing::error!("LLM error for WhatsApp message: {e}"); + let _ = wa.send(&format!("⚠️ Error: {e}"), &msg.sender).await; + } + } + } + + // Acknowledge the webhook + (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) } #[cfg(test)] mod tests { use super::*; - use tokio::net::TcpListener as TokioListener; - - // ── Port allocation tests ──────────────────────────────── - - #[tokio::test] - async fn port_zero_binds_to_random_port() { - let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let actual = listener.local_addr().unwrap().port(); - assert_ne!(actual, 0, "OS must assign a non-zero port"); - assert!(actual > 0, "Actual port must be positive"); - } - - #[tokio::test] - async fn port_zero_assigns_different_ports() { - let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let l2 = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let p1 = l1.local_addr().unwrap().port(); - let p2 = l2.local_addr().unwrap().port(); - assert_ne!(p1, p2, "Two port-0 binds should get different ports"); - } - - #[tokio::test] - async fn port_zero_assigns_high_port() { - let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let actual = listener.local_addr().unwrap().port(); - // OS typically assigns ephemeral ports >= 1024 - assert!( - actual >= 1024, - "Random port {actual} should be >= 1024 (unprivileged)" - 
); - } - - #[tokio::test] - async fn specific_port_binds_exactly() { - // Find a free port first via port 0, then rebind to it - let tmp = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let free_port = tmp.local_addr().unwrap().port(); - drop(tmp); - - let listener = TokioListener::bind(format!("127.0.0.1:{free_port}")) - .await - .unwrap(); - let actual = listener.local_addr().unwrap().port(); - assert_eq!(actual, free_port, "Specific port bind must match exactly"); - } - - #[tokio::test] - async fn actual_port_matches_addr_format() { - let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let actual_port = listener.local_addr().unwrap().port(); - let addr = format!("127.0.0.1:{actual_port}"); - assert!( - addr.starts_with("127.0.0.1:"), - "Addr format must include host" - ); - assert!( - !addr.ends_with(":0"), - "Addr must not contain port 0 after binding" - ); - } - - #[tokio::test] - async fn port_zero_listener_accepts_connections() { - let listener = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let actual_port = listener.local_addr().unwrap().port(); - - // Spawn a client that connects - let client = tokio::spawn(async move { - tokio::net::TcpStream::connect(format!("127.0.0.1:{actual_port}")) - .await - .unwrap() - }); - - // Accept the connection - let (stream, _peer) = listener.accept().await.unwrap(); - assert!(stream.peer_addr().is_ok()); - client.await.unwrap(); - } - - #[tokio::test] - async fn duplicate_specific_port_fails() { - let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap(); - let port = l1.local_addr().unwrap().port(); - // Try to bind the same port while l1 is still alive - let result = TokioListener::bind(format!("127.0.0.1:{port}")).await; - assert!(result.is_err(), "Binding an already-used port must fail"); - } - - #[tokio::test] - async fn tunnel_gets_actual_port_not_zero() { - // Simulate what run_gateway does: bind port 0, extract actual port - let port: u16 = 0; - let host = "127.0.0.1"; - let listener 
= TokioListener::bind(format!("{host}:{port}")).await.unwrap(); - let actual_port = listener.local_addr().unwrap().port(); - - // This is the port that would be passed to tun.start(host, actual_port) - assert_ne!(actual_port, 0, "Tunnel must receive actual port, not 0"); - assert!( - actual_port >= 1024, - "Tunnel port {actual_port} must be unprivileged" - ); - } - - // ── extract_header tests ───────────────────────────────── #[test] - fn extract_header_finds_value() { - let req = - "POST /webhook HTTP/1.1\r\nHost: localhost\r\nX-Webhook-Secret: my-secret\r\n\r\n{}"; - assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("my-secret")); + fn security_body_limit_is_64kb() { + assert_eq!(MAX_BODY_SIZE, 65_536); } #[test] - fn extract_header_case_insensitive() { - let req = "POST /webhook HTTP/1.1\r\nx-webhook-secret: abc123\r\n\r\n{}"; - assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("abc123")); + fn security_timeout_is_30_seconds() { + assert_eq!(REQUEST_TIMEOUT_SECS, 30); } #[test] - fn extract_header_missing_returns_none() { - let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\n\r\n{}"; - assert_eq!(extract_header(req, "X-Webhook-Secret"), None); + fn webhook_body_requires_message_field() { + let valid = r#"{"message": "hello"}"#; + let parsed: Result = serde_json::from_str(valid); + assert!(parsed.is_ok()); + assert_eq!(parsed.unwrap().message, "hello"); + + let missing = r#"{"other": "field"}"#; + let parsed: Result = serde_json::from_str(missing); + assert!(parsed.is_err()); } #[test] - fn extract_header_trims_whitespace() { - let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: spaced \r\n\r\n{}"; - assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("spaced")); + fn whatsapp_query_fields_are_optional() { + let q = WhatsAppVerifyQuery { + mode: None, + verify_token: None, + challenge: None, + }; + assert!(q.mode.is_none()); } #[test] - fn extract_header_first_match_wins() { - let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: 
first\r\nX-Webhook-Secret: second\r\n\r\n{}"; - assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("first")); - } - - #[test] - fn extract_header_empty_value() { - let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret:\r\n\r\n{}"; - assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("")); - } - - #[test] - fn extract_header_colon_in_value() { - let req = "POST /webhook HTTP/1.1\r\nAuthorization: Bearer sk-abc:123\r\n\r\n{}"; - // split_once on ':' means only the first colon splits key/value - assert_eq!( - extract_header(req, "Authorization"), - Some("Bearer sk-abc:123") - ); - } - - #[test] - fn extract_header_different_header() { - let req = "POST /webhook HTTP/1.1\r\nContent-Type: application/json\r\nX-Webhook-Secret: mysecret\r\n\r\n{}"; - assert_eq!( - extract_header(req, "Content-Type"), - Some("application/json") - ); - assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("mysecret")); - } - - #[test] - fn extract_header_from_empty_request() { - assert_eq!(extract_header("", "X-Webhook-Secret"), None); - } - - #[test] - fn extract_header_newline_only_request() { - assert_eq!(extract_header("\r\n\r\n", "X-Webhook-Secret"), None); + fn app_state_is_clone() { + fn assert_clone() {} + assert_clone::(); } } diff --git a/src/health/mod.rs b/src/health/mod.rs new file mode 100644 index 0000000..4fcd8b2 --- /dev/null +++ b/src/health/mod.rs @@ -0,0 +1,105 @@ +use chrono::Utc; +use serde::Serialize; +use std::collections::BTreeMap; +use std::sync::{Mutex, OnceLock}; +use std::time::Instant; + +#[derive(Debug, Clone, Serialize)] +pub struct ComponentHealth { + pub status: String, + pub updated_at: String, + pub last_ok: Option, + pub last_error: Option, + pub restart_count: u64, +} + +#[derive(Debug, Clone, Serialize)] +pub struct HealthSnapshot { + pub pid: u32, + pub updated_at: String, + pub uptime_seconds: u64, + pub components: BTreeMap, +} + +struct HealthRegistry { + started_at: Instant, + components: Mutex>, +} + +static REGISTRY: OnceLock = 
OnceLock::new(); + +fn registry() -> &'static HealthRegistry { + REGISTRY.get_or_init(|| HealthRegistry { + started_at: Instant::now(), + components: Mutex::new(BTreeMap::new()), + }) +} + +fn now_rfc3339() -> String { + Utc::now().to_rfc3339() +} + +fn upsert_component(component: &str, update: F) +where + F: FnOnce(&mut ComponentHealth), +{ + if let Ok(mut map) = registry().components.lock() { + let now = now_rfc3339(); + let entry = map + .entry(component.to_string()) + .or_insert_with(|| ComponentHealth { + status: "starting".into(), + updated_at: now.clone(), + last_ok: None, + last_error: None, + restart_count: 0, + }); + update(entry); + entry.updated_at = now; + } +} + +pub fn mark_component_ok(component: &str) { + upsert_component(component, |entry| { + entry.status = "ok".into(); + entry.last_ok = Some(now_rfc3339()); + entry.last_error = None; + }); +} + +pub fn mark_component_error(component: &str, error: impl ToString) { + let err = error.to_string(); + upsert_component(component, move |entry| { + entry.status = "error".into(); + entry.last_error = Some(err); + }); +} + +pub fn bump_component_restart(component: &str) { + upsert_component(component, |entry| { + entry.restart_count = entry.restart_count.saturating_add(1); + }); +} + +pub fn snapshot() -> HealthSnapshot { + let components = registry() + .components + .lock() + .map_or_else(|_| BTreeMap::new(), |map| map.clone()); + + HealthSnapshot { + pid: std::process::id(), + updated_at: now_rfc3339(), + uptime_seconds: registry().started_at.elapsed().as_secs(), + components, + } +} + +pub fn snapshot_json() -> serde_json::Value { + serde_json::to_value(snapshot()).unwrap_or_else(|_| { + serde_json::json!({ + "status": "error", + "message": "failed to serialize health snapshot" + }) + }) +} diff --git a/src/heartbeat/engine.rs b/src/heartbeat/engine.rs index ee31755..86b10e4 100644 --- a/src/heartbeat/engine.rs +++ b/src/heartbeat/engine.rs @@ -61,16 +61,17 @@ impl HeartbeatEngine { /// Single heartbeat 
tick — read HEARTBEAT.md and return task count async fn tick(&self) -> Result { + Ok(self.collect_tasks().await?.len()) + } + + /// Read HEARTBEAT.md and return all parsed tasks. + pub async fn collect_tasks(&self) -> Result> { let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md"); - if !heartbeat_path.exists() { - return Ok(0); + return Ok(Vec::new()); } - let content = tokio::fs::read_to_string(&heartbeat_path).await?; - let tasks = Self::parse_tasks(&content); - - Ok(tasks.len()) + Ok(Self::parse_tasks(&content)) } /// Parse tasks from HEARTBEAT.md (lines starting with `- `) diff --git a/src/main.rs b/src/main.rs index dbc2d4b..46fb1d8 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,7 +8,7 @@ dead_code )] -use anyhow::Result; +use anyhow::{bail, Result}; use clap::{Parser, Subcommand}; use tracing::{info, Level}; use tracing_subscriber::FmtSubscriber; @@ -17,15 +17,20 @@ mod agent; mod channels; mod config; mod cron; +mod daemon; +mod doctor; mod gateway; +mod health; mod heartbeat; mod integrations; mod memory; +mod migration; mod observability; mod onboard; mod providers; mod runtime; mod security; +mod service; mod skills; mod tools; mod tunnel; @@ -43,6 +48,20 @@ struct Cli { command: Commands, } +#[derive(Subcommand, Debug)] +enum ServiceCommands { + /// Install daemon service unit for auto-start and restart + Install, + /// Start daemon service + Start, + /// Stop daemon service + Stop, + /// Check daemon service status + Status, + /// Uninstall daemon service unit + Uninstall, +} + #[derive(Subcommand, Debug)] enum Commands { /// Initialize your workspace and configuration @@ -51,6 +70,10 @@ enum Commands { #[arg(long)] interactive: bool, + /// Reconfigure channels only (fast repair flow) + #[arg(long)] + channels_only: bool, + /// API key (used in quick mode, ignored with --interactive) #[arg(long)] api_key: Option, @@ -71,7 +94,7 @@ enum Commands { provider: Option, /// Model to use - #[arg(short, long)] + #[arg(long)] model: Option, /// 
Temperature (0.0 - 2.0) @@ -86,10 +109,30 @@ enum Commands { port: u16, /// Host to bind to - #[arg(short, long, default_value = "127.0.0.1")] + #[arg(long, default_value = "127.0.0.1")] host: String, }, + /// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler) + Daemon { + /// Port to listen on (use 0 for random available port) + #[arg(short, long, default_value = "8080")] + port: u16, + + /// Host to bind to + #[arg(long, default_value = "127.0.0.1")] + host: String, + }, + + /// Manage OS service lifecycle (launchd/systemd user service) + Service { + #[command(subcommand)] + service_command: ServiceCommands, + }, + + /// Run diagnostics for daemon/scheduler/channel freshness + Doctor, + /// Show system status (full details) Status, @@ -116,6 +159,26 @@ enum Commands { #[command(subcommand)] skill_command: SkillCommands, }, + + /// Migrate data from other agent runtimes + Migrate { + #[command(subcommand)] + migrate_command: MigrateCommands, + }, +} + +#[derive(Subcommand, Debug)] +enum MigrateCommands { + /// Import memory from an OpenClaw workspace into this ZeroClaw workspace + Openclaw { + /// Optional path to OpenClaw workspace (defaults to ~/.openclaw/workspace) + #[arg(long)] + source: Option, + + /// Validate and preview migration without writing any data + #[arg(long)] + dry_run: bool, + }, } #[derive(Subcommand, Debug)] @@ -198,11 +261,21 @@ async fn main() -> Result<()> { // Onboard runs quick setup by default, or the interactive wizard with --interactive if let Commands::Onboard { interactive, + channels_only, api_key, provider, } = &cli.command { - let config = if *interactive { + if *interactive && *channels_only { + bail!("Use either --interactive or --channels-only, not both"); + } + if *channels_only && (api_key.is_some() || provider.is_some()) { + bail!("--channels-only does not accept --api-key or --provider"); + } + + let config = if *channels_only { + onboard::run_channels_repair_wizard()? 
+ } else if *interactive { onboard::run_wizard()? } else { onboard::run_quick_setup(api_key.as_deref(), provider.as_deref())? @@ -236,6 +309,15 @@ async fn main() -> Result<()> { gateway::run_gateway(&host, port, config).await } + Commands::Daemon { port, host } => { + if port == 0 { + info!("🧠 Starting ZeroClaw Daemon on {host} (random port)"); + } else { + info!("🧠 Starting ZeroClaw Daemon on {host}:{port}"); + } + daemon::run(config, host, port).await + } + Commands::Status => { println!("🦀 ZeroClaw Status"); println!(); @@ -307,6 +389,10 @@ async fn main() -> Result<()> { Commands::Cron { cron_command } => cron::handle_command(cron_command, config), + Commands::Service { service_command } => service::handle_command(service_command, &config), + + Commands::Doctor => doctor::run(&config), + Commands::Channel { channel_command } => match channel_command { ChannelCommands::Start => channels::start_channels(config).await, ChannelCommands::Doctor => channels::doctor_channels(config).await, @@ -320,5 +406,20 @@ async fn main() -> Result<()> { Commands::Skills { skill_command } => { skills::handle_command(skill_command, &config.workspace_dir) } + + Commands::Migrate { migrate_command } => { + migration::handle_command(migrate_command, &config).await + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use clap::CommandFactory; + + #[test] + fn cli_definition_has_no_flag_conflicts() { + Cli::command().debug_assert(); } } diff --git a/src/memory/hygiene.rs b/src/memory/hygiene.rs new file mode 100644 index 0000000..17c95fa --- /dev/null +++ b/src/memory/hygiene.rs @@ -0,0 +1,538 @@ +use crate::config::MemoryConfig; +use anyhow::Result; +use chrono::{DateTime, Duration, Local, NaiveDate, Utc}; +use rusqlite::{params, Connection}; +use serde::{Deserialize, Serialize}; +use std::fs; +use std::path::{Path, PathBuf}; +use std::time::{Duration as StdDuration, SystemTime}; + +const HYGIENE_INTERVAL_HOURS: i64 = 12; +const STATE_FILE: &str = "memory_hygiene_state.json"; + 
+#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct HygieneReport { + archived_memory_files: u64, + archived_session_files: u64, + purged_memory_archives: u64, + purged_session_archives: u64, + pruned_conversation_rows: u64, +} + +impl HygieneReport { + fn total_actions(&self) -> u64 { + self.archived_memory_files + + self.archived_session_files + + self.purged_memory_archives + + self.purged_session_archives + + self.pruned_conversation_rows + } +} + +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +struct HygieneState { + last_run_at: Option, + last_report: HygieneReport, +} + +/// Run memory/session hygiene if the cadence window has elapsed. +/// +/// This function is intentionally best-effort: callers should log and continue on failure. +pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> { + if !config.hygiene_enabled { + return Ok(()); + } + + if !should_run_now(workspace_dir)? { + return Ok(()); + } + + let report = HygieneReport { + archived_memory_files: archive_daily_memory_files( + workspace_dir, + config.archive_after_days, + )?, + archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?, + purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?, + purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?, + pruned_conversation_rows: prune_conversation_rows( + workspace_dir, + config.conversation_retention_days, + )?, + }; + + write_state(workspace_dir, &report)?; + + if report.total_actions() > 0 { + tracing::info!( + "memory hygiene complete: archived_memory={} archived_sessions={} purged_memory={} purged_sessions={} pruned_conversation_rows={}", + report.archived_memory_files, + report.archived_session_files, + report.purged_memory_archives, + report.purged_session_archives, + report.pruned_conversation_rows, + ); + } + + Ok(()) +} + +fn should_run_now(workspace_dir: &Path) -> Result { + let path = 
state_path(workspace_dir); + if !path.exists() { + return Ok(true); + } + + let raw = fs::read_to_string(&path)?; + let state: HygieneState = match serde_json::from_str(&raw) { + Ok(s) => s, + Err(_) => return Ok(true), + }; + + let Some(last_run_at) = state.last_run_at else { + return Ok(true); + }; + + let last = match DateTime::parse_from_rfc3339(&last_run_at) { + Ok(ts) => ts.with_timezone(&Utc), + Err(_) => return Ok(true), + }; + + Ok(Utc::now().signed_duration_since(last) >= Duration::hours(HYGIENE_INTERVAL_HOURS)) +} + +fn write_state(workspace_dir: &Path, report: &HygieneReport) -> Result<()> { + let path = state_path(workspace_dir); + if let Some(parent) = path.parent() { + fs::create_dir_all(parent)?; + } + + let state = HygieneState { + last_run_at: Some(Utc::now().to_rfc3339()), + last_report: report.clone(), + }; + let json = serde_json::to_vec_pretty(&state)?; + fs::write(path, json)?; + Ok(()) +} + +fn state_path(workspace_dir: &Path) -> PathBuf { + workspace_dir.join("state").join(STATE_FILE) +} + +fn archive_daily_memory_files(workspace_dir: &Path, archive_after_days: u32) -> Result { + if archive_after_days == 0 { + return Ok(0); + } + + let memory_dir = workspace_dir.join("memory"); + if !memory_dir.is_dir() { + return Ok(0); + } + + let archive_dir = memory_dir.join("archive"); + fs::create_dir_all(&archive_dir)?; + + let cutoff = Local::now().date_naive() - Duration::days(i64::from(archive_after_days)); + let mut moved = 0_u64; + + for entry in fs::read_dir(&memory_dir)? 
{ + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + if path.extension().and_then(|e| e.to_str()) != Some("md") { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let Some(file_date) = memory_date_from_filename(filename) else { + continue; + }; + + if file_date < cutoff { + move_to_archive(&path, &archive_dir)?; + moved += 1; + } + } + + Ok(moved) +} + +fn archive_session_files(workspace_dir: &Path, archive_after_days: u32) -> Result { + if archive_after_days == 0 { + return Ok(0); + } + + let sessions_dir = workspace_dir.join("sessions"); + if !sessions_dir.is_dir() { + return Ok(0); + } + + let archive_dir = sessions_dir.join("archive"); + fs::create_dir_all(&archive_dir)?; + + let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(archive_after_days)); + let cutoff_time = SystemTime::now() + .checked_sub(StdDuration::from_secs( + u64::from(archive_after_days) * 24 * 60 * 60, + )) + .unwrap_or(SystemTime::UNIX_EPOCH); + + let mut moved = 0_u64; + for entry in fs::read_dir(&sessions_dir)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let is_old = if let Some(date) = date_prefix(filename) { + date < cutoff_date + } else { + is_older_than(&path, cutoff_time) + }; + + if is_old { + move_to_archive(&path, &archive_dir)?; + moved += 1; + } + } + + Ok(moved) +} + +fn purge_memory_archives(workspace_dir: &Path, purge_after_days: u32) -> Result { + if purge_after_days == 0 { + return Ok(0); + } + + let archive_dir = workspace_dir.join("memory").join("archive"); + if !archive_dir.is_dir() { + return Ok(0); + } + + let cutoff = Local::now().date_naive() - Duration::days(i64::from(purge_after_days)); + let mut removed = 0_u64; + + for entry in fs::read_dir(&archive_dir)? 
{ + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let Some(file_date) = memory_date_from_filename(filename) else { + continue; + }; + + if file_date < cutoff { + fs::remove_file(&path)?; + removed += 1; + } + } + + Ok(removed) +} + +fn purge_session_archives(workspace_dir: &Path, purge_after_days: u32) -> Result { + if purge_after_days == 0 { + return Ok(0); + } + + let archive_dir = workspace_dir.join("sessions").join("archive"); + if !archive_dir.is_dir() { + return Ok(0); + } + + let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(purge_after_days)); + let cutoff_time = SystemTime::now() + .checked_sub(StdDuration::from_secs( + u64::from(purge_after_days) * 24 * 60 * 60, + )) + .unwrap_or(SystemTime::UNIX_EPOCH); + + let mut removed = 0_u64; + for entry in fs::read_dir(&archive_dir)? { + let entry = entry?; + let path = entry.path(); + + if path.is_dir() { + continue; + } + + let Some(filename) = path.file_name().and_then(|f| f.to_str()) else { + continue; + }; + + let is_old = if let Some(date) = date_prefix(filename) { + date < cutoff_date + } else { + is_older_than(&path, cutoff_time) + }; + + if is_old { + fs::remove_file(&path)?; + removed += 1; + } + } + + Ok(removed) +} + +fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result { + if retention_days == 0 { + return Ok(0); + } + + let db_path = workspace_dir.join("memory").join("brain.db"); + if !db_path.exists() { + return Ok(0); + } + + let conn = Connection::open(db_path)?; + let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339(); + + let affected = conn.execute( + "DELETE FROM memories WHERE category = 'conversation' AND updated_at < ?1", + params![cutoff], + )?; + + Ok(u64::try_from(affected).unwrap_or(0)) +} + +fn memory_date_from_filename(filename: &str) -> Option { + let stem = 
filename.strip_suffix(".md")?; + let date_part = stem.split('_').next().unwrap_or(stem); + NaiveDate::parse_from_str(date_part, "%Y-%m-%d").ok() +} + +fn date_prefix(filename: &str) -> Option { + if filename.len() < 10 { + return None; + } + NaiveDate::parse_from_str(&filename[..10], "%Y-%m-%d").ok() +} + +fn is_older_than(path: &Path, cutoff: SystemTime) -> bool { + fs::metadata(path) + .and_then(|meta| meta.modified()) + .map(|modified| modified < cutoff) + .unwrap_or(false) +} + +fn move_to_archive(src: &Path, archive_dir: &Path) -> Result<()> { + let Some(filename) = src.file_name().and_then(|f| f.to_str()) else { + return Ok(()); + }; + + let target = unique_archive_target(archive_dir, filename); + fs::rename(src, target)?; + Ok(()) +} + +fn unique_archive_target(archive_dir: &Path, filename: &str) -> PathBuf { + let direct = archive_dir.join(filename); + if !direct.exists() { + return direct; + } + + let (stem, ext) = split_name(filename); + for i in 1..10_000 { + let candidate = if ext.is_empty() { + archive_dir.join(format!("{stem}_{i}")) + } else { + archive_dir.join(format!("{stem}_{i}.{ext}")) + }; + if !candidate.exists() { + return candidate; + } + } + + direct +} + +fn split_name(filename: &str) -> (&str, &str) { + match filename.rsplit_once('.') { + Some((stem, ext)) => (stem, ext), + None => (filename, ""), + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::memory::{Memory, MemoryCategory, SqliteMemory}; + use tempfile::TempDir; + + fn default_cfg() -> MemoryConfig { + MemoryConfig::default() + } + + #[test] + fn archives_old_daily_memory_files() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + fs::create_dir_all(workspace.join("memory")).unwrap(); + + let old = (Local::now().date_naive() - Duration::days(10)) + .format("%Y-%m-%d") + .to_string(); + let today = Local::now().date_naive().format("%Y-%m-%d").to_string(); + + let old_file = workspace.join("memory").join(format!("{old}.md")); + let today_file = 
workspace.join("memory").join(format!("{today}.md")); + fs::write(&old_file, "old note").unwrap(); + fs::write(&today_file, "fresh note").unwrap(); + + run_if_due(&default_cfg(), workspace).unwrap(); + + assert!(!old_file.exists(), "old daily file should be archived"); + assert!( + workspace + .join("memory") + .join("archive") + .join(format!("{old}.md")) + .exists(), + "old daily file should exist in memory/archive" + ); + assert!(today_file.exists(), "today file should remain in place"); + } + + #[test] + fn archives_old_session_files() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + fs::create_dir_all(workspace.join("sessions")).unwrap(); + + let old = (Local::now().date_naive() - Duration::days(10)) + .format("%Y-%m-%d") + .to_string(); + let old_name = format!("{old}-agent.log"); + let old_file = workspace.join("sessions").join(&old_name); + fs::write(&old_file, "old session").unwrap(); + + run_if_due(&default_cfg(), workspace).unwrap(); + + assert!(!old_file.exists(), "old session file should be archived"); + assert!( + workspace + .join("sessions") + .join("archive") + .join(&old_name) + .exists(), + "archived session file should exist" + ); + } + + #[test] + fn skips_second_run_within_cadence_window() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + fs::create_dir_all(workspace.join("memory")).unwrap(); + + let old_a = (Local::now().date_naive() - Duration::days(10)) + .format("%Y-%m-%d") + .to_string(); + let file_a = workspace.join("memory").join(format!("{old_a}.md")); + fs::write(&file_a, "first").unwrap(); + + run_if_due(&default_cfg(), workspace).unwrap(); + assert!(!file_a.exists(), "first old file should be archived"); + + let old_b = (Local::now().date_naive() - Duration::days(9)) + .format("%Y-%m-%d") + .to_string(); + let file_b = workspace.join("memory").join(format!("{old_b}.md")); + fs::write(&file_b, "second").unwrap(); + + // Should skip because cadence gate prevents a second immediate run. 
+ run_if_due(&default_cfg(), workspace).unwrap(); + assert!( + file_b.exists(), + "second file should remain because run is throttled" + ); + } + + #[test] + fn purges_old_memory_archives() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + let archive_dir = workspace.join("memory").join("archive"); + fs::create_dir_all(&archive_dir).unwrap(); + + let old = (Local::now().date_naive() - Duration::days(40)) + .format("%Y-%m-%d") + .to_string(); + let keep = (Local::now().date_naive() - Duration::days(5)) + .format("%Y-%m-%d") + .to_string(); + + let old_file = archive_dir.join(format!("{old}.md")); + let keep_file = archive_dir.join(format!("{keep}.md")); + fs::write(&old_file, "expired").unwrap(); + fs::write(&keep_file, "recent").unwrap(); + + run_if_due(&default_cfg(), workspace).unwrap(); + + assert!(!old_file.exists(), "old archived file should be purged"); + assert!(keep_file.exists(), "recent archived file should remain"); + } + + #[tokio::test] + async fn prunes_old_conversation_rows_in_sqlite_backend() { + let tmp = TempDir::new().unwrap(); + let workspace = tmp.path(); + + let mem = SqliteMemory::new(workspace).unwrap(); + mem.store("conv_old", "outdated", MemoryCategory::Conversation) + .await + .unwrap(); + mem.store("core_keep", "durable", MemoryCategory::Core) + .await + .unwrap(); + drop(mem); + + let db_path = workspace.join("memory").join("brain.db"); + let conn = Connection::open(&db_path).unwrap(); + let old_cutoff = (Local::now() - Duration::days(60)).to_rfc3339(); + conn.execute( + "UPDATE memories SET created_at = ?1, updated_at = ?1 WHERE key = 'conv_old'", + params![old_cutoff], + ) + .unwrap(); + drop(conn); + + let mut cfg = default_cfg(); + cfg.archive_after_days = 0; + cfg.purge_after_days = 0; + cfg.conversation_retention_days = 30; + + run_if_due(&cfg, workspace).unwrap(); + + let mem2 = SqliteMemory::new(workspace).unwrap(); + assert!( + mem2.get("conv_old").await.unwrap().is_none(), + "old conversation rows should 
be pruned" + ); + assert!( + mem2.get("core_keep").await.unwrap().is_some(), + "core memory should remain" + ); + } +} diff --git a/src/memory/mod.rs b/src/memory/mod.rs index 249670b..66912ca 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -1,5 +1,6 @@ pub mod chunker; pub mod embeddings; +pub mod hygiene; pub mod markdown; pub mod sqlite; pub mod traits; @@ -21,6 +22,11 @@ pub fn create_memory( workspace_dir: &Path, api_key: Option<&str>, ) -> anyhow::Result> { + // Best-effort memory hygiene/retention pass (throttled by state file). + if let Err(e) = hygiene::run_if_due(config, workspace_dir) { + tracing::warn!("memory hygiene skipped: {e}"); + } + match config.backend.as_str() { "sqlite" => { let embedder: Arc = diff --git a/src/migration.rs b/src/migration.rs new file mode 100644 index 0000000..ed160c7 --- /dev/null +++ b/src/migration.rs @@ -0,0 +1,553 @@ +use crate::config::Config; +use crate::memory::{MarkdownMemory, Memory, MemoryCategory, SqliteMemory}; +use anyhow::{bail, Context, Result}; +use directories::UserDirs; +use rusqlite::{Connection, OpenFlags, OptionalExtension}; +use std::collections::HashSet; +use std::fs; +use std::path::{Path, PathBuf}; + +#[derive(Debug, Clone)] +struct SourceEntry { + key: String, + content: String, + category: MemoryCategory, +} + +#[derive(Debug, Default)] +struct MigrationStats { + from_sqlite: usize, + from_markdown: usize, + imported: usize, + skipped_unchanged: usize, + renamed_conflicts: usize, +} + +pub async fn handle_command(command: super::MigrateCommands, config: &Config) -> Result<()> { + match command { + super::MigrateCommands::Openclaw { source, dry_run } => { + migrate_openclaw_memory(config, source, dry_run).await + } + } +} + +async fn migrate_openclaw_memory( + config: &Config, + source_workspace: Option, + dry_run: bool, +) -> Result<()> { + let source_workspace = resolve_openclaw_workspace(source_workspace)?; + if !source_workspace.exists() { + bail!( + "OpenClaw workspace not found at 
{}. Pass --source if needed.", + source_workspace.display() + ); + } + + if paths_equal(&source_workspace, &config.workspace_dir) { + bail!("Source workspace matches current ZeroClaw workspace; refusing self-migration"); + } + + let mut stats = MigrationStats::default(); + let entries = collect_source_entries(&source_workspace, &mut stats)?; + + if entries.is_empty() { + println!( + "No importable memory found in {}", + source_workspace.display() + ); + println!("Checked for: memory/brain.db, MEMORY.md, memory/*.md"); + return Ok(()); + } + + if dry_run { + println!("🔎 Dry run: OpenClaw migration preview"); + println!(" Source: {}", source_workspace.display()); + println!(" Target: {}", config.workspace_dir.display()); + println!(" Candidates: {}", entries.len()); + println!(" - from sqlite: {}", stats.from_sqlite); + println!(" - from markdown: {}", stats.from_markdown); + println!(); + println!("Run without --dry-run to import these entries."); + return Ok(()); + } + + if let Some(backup_dir) = backup_target_memory(&config.workspace_dir)? { + println!("🛟 Backup created: {}", backup_dir.display()); + } + + let memory = target_memory_backend(config)?; + + for (idx, entry) in entries.into_iter().enumerate() { + let mut key = entry.key.trim().to_string(); + if key.is_empty() { + key = format!("openclaw_{idx}"); + } + + if let Some(existing) = memory.get(&key).await? 
{ + if existing.content.trim() == entry.content.trim() { + stats.skipped_unchanged += 1; + continue; + } + + let renamed = next_available_key(memory.as_ref(), &key).await?; + key = renamed; + stats.renamed_conflicts += 1; + } + + memory.store(&key, &entry.content, entry.category).await?; + stats.imported += 1; + } + + println!("✅ OpenClaw memory migration complete"); + println!(" Source: {}", source_workspace.display()); + println!(" Target: {}", config.workspace_dir.display()); + println!(" Imported: {}", stats.imported); + println!(" Skipped unchanged:{}", stats.skipped_unchanged); + println!(" Renamed conflicts:{}", stats.renamed_conflicts); + println!(" Source sqlite rows:{}", stats.from_sqlite); + println!(" Source markdown: {}", stats.from_markdown); + + Ok(()) +} + +fn target_memory_backend(config: &Config) -> Result> { + match config.memory.backend.as_str() { + "sqlite" => Ok(Box::new(SqliteMemory::new(&config.workspace_dir)?)), + "markdown" | "none" => Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))), + other => { + tracing::warn!( + "Unknown memory backend '{other}' during migration, defaulting to markdown" + ); + Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))) + } + } +} + +fn collect_source_entries( + source_workspace: &Path, + stats: &mut MigrationStats, +) -> Result> { + let mut entries = Vec::new(); + + let sqlite_path = source_workspace.join("memory").join("brain.db"); + let sqlite_entries = read_openclaw_sqlite_entries(&sqlite_path)?; + stats.from_sqlite = sqlite_entries.len(); + entries.extend(sqlite_entries); + + let markdown_entries = read_openclaw_markdown_entries(source_workspace)?; + stats.from_markdown = markdown_entries.len(); + entries.extend(markdown_entries); + + // De-dup exact duplicates to make re-runs deterministic. 
+ let mut seen = HashSet::new(); + entries.retain(|entry| { + let sig = format!("{}\u{0}{}\u{0}{}", entry.key, entry.content, entry.category); + seen.insert(sig) + }); + + Ok(entries) +} + +fn read_openclaw_sqlite_entries(db_path: &Path) -> Result> { + if !db_path.exists() { + return Ok(Vec::new()); + } + + let conn = Connection::open_with_flags(db_path, OpenFlags::SQLITE_OPEN_READ_ONLY) + .with_context(|| format!("Failed to open source db {}", db_path.display()))?; + + let table_exists: Option = conn + .query_row( + "SELECT name FROM sqlite_master WHERE type='table' AND name='memories' LIMIT 1", + [], + |row| row.get(0), + ) + .optional()?; + + if table_exists.is_none() { + return Ok(Vec::new()); + } + + let columns = table_columns(&conn, "memories")?; + let key_expr = pick_column_expr(&columns, &["key", "id", "name"], "CAST(rowid AS TEXT)"); + let Some(content_expr) = + pick_optional_column_expr(&columns, &["content", "value", "text", "memory"]) + else { + bail!("OpenClaw memories table found but no content-like column was detected"); + }; + let category_expr = pick_column_expr(&columns, &["category", "kind", "type"], "'core'"); + + let sql = format!( + "SELECT {key_expr} AS key, {content_expr} AS content, {category_expr} AS category FROM memories" + ); + + let mut stmt = conn.prepare(&sql)?; + let mut rows = stmt.query([])?; + + let mut entries = Vec::new(); + let mut idx = 0_usize; + + while let Some(row) = rows.next()? 
{ + let key: String = row + .get(0) + .unwrap_or_else(|_| format!("openclaw_sqlite_{idx}")); + let content: String = row.get(1).unwrap_or_default(); + let category_raw: String = row.get(2).unwrap_or_else(|_| "core".to_string()); + + if content.trim().is_empty() { + continue; + } + + entries.push(SourceEntry { + key: normalize_key(&key, idx), + content: content.trim().to_string(), + category: parse_category(&category_raw), + }); + + idx += 1; + } + + Ok(entries) +} + +fn read_openclaw_markdown_entries(source_workspace: &Path) -> Result> { + let mut all = Vec::new(); + + let core_path = source_workspace.join("MEMORY.md"); + if core_path.exists() { + let content = fs::read_to_string(&core_path)?; + all.extend(parse_markdown_file( + &core_path, + &content, + MemoryCategory::Core, + "openclaw_core", + )); + } + + let daily_dir = source_workspace.join("memory"); + if daily_dir.exists() { + for file in fs::read_dir(&daily_dir)? { + let file = file?; + let path = file.path(); + if path.extension().and_then(|ext| ext.to_str()) != Some("md") { + continue; + } + let content = fs::read_to_string(&path)?; + let stem = path + .file_stem() + .and_then(|s| s.to_str()) + .unwrap_or("openclaw_daily"); + all.extend(parse_markdown_file( + &path, + &content, + MemoryCategory::Daily, + stem, + )); + } + } + + Ok(all) +} + +fn parse_markdown_file( + _path: &Path, + content: &str, + default_category: MemoryCategory, + stem: &str, +) -> Vec { + let mut entries = Vec::new(); + + for (idx, raw_line) in content.lines().enumerate() { + let trimmed = raw_line.trim(); + if trimmed.is_empty() || trimmed.starts_with('#') { + continue; + } + + let line = trimmed.strip_prefix("- ").unwrap_or(trimmed); + let (key, text) = match parse_structured_memory_line(line) { + Some((k, v)) => (normalize_key(k, idx), v.trim().to_string()), + None => ( + format!("openclaw_{stem}_{}", idx + 1), + line.trim().to_string(), + ), + }; + + if text.is_empty() { + continue; + } + + entries.push(SourceEntry { + key, + 
content: text, + category: default_category.clone(), + }); + } + + entries +} + +fn parse_structured_memory_line(line: &str) -> Option<(&str, &str)> { + if !line.starts_with("**") { + return None; + } + + let rest = line.strip_prefix("**")?; + let key_end = rest.find("**:")?; + let key = rest.get(..key_end)?.trim(); + let value = rest.get(key_end + 3..)?.trim(); + + if key.is_empty() || value.is_empty() { + return None; + } + + Some((key, value)) +} + +fn parse_category(raw: &str) -> MemoryCategory { + match raw.trim().to_ascii_lowercase().as_str() { + "core" => MemoryCategory::Core, + "daily" => MemoryCategory::Daily, + "conversation" => MemoryCategory::Conversation, + "" => MemoryCategory::Core, + other => MemoryCategory::Custom(other.to_string()), + } +} + +fn normalize_key(key: &str, fallback_idx: usize) -> String { + let trimmed = key.trim(); + if trimmed.is_empty() { + return format!("openclaw_{fallback_idx}"); + } + trimmed.to_string() +} + +async fn next_available_key(memory: &dyn Memory, base: &str) -> Result { + for i in 1..=10_000 { + let candidate = format!("{base}__openclaw_{i}"); + if memory.get(&candidate).await?.is_none() { + return Ok(candidate); + } + } + + bail!("Unable to allocate non-conflicting key for '{base}'") +} + +fn table_columns(conn: &Connection, table: &str) -> Result> { + let pragma = format!("PRAGMA table_info({table})"); + let mut stmt = conn.prepare(&pragma)?; + let rows = stmt.query_map([], |row| row.get::<_, String>(1))?; + + let mut cols = Vec::new(); + for col in rows { + cols.push(col?.to_ascii_lowercase()); + } + + Ok(cols) +} + +fn pick_optional_column_expr(columns: &[String], candidates: &[&str]) -> Option { + candidates + .iter() + .find(|candidate| columns.iter().any(|c| c == *candidate)) + .map(|s| s.to_string()) +} + +fn pick_column_expr(columns: &[String], candidates: &[&str], fallback: &str) -> String { + pick_optional_column_expr(columns, candidates).unwrap_or_else(|| fallback.to_string()) +} + +fn 
resolve_openclaw_workspace(source: Option) -> Result { + if let Some(src) = source { + return Ok(src); + } + + let home = UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + + Ok(home.join(".openclaw").join("workspace")) +} + +fn paths_equal(a: &Path, b: &Path) -> bool { + match (fs::canonicalize(a), fs::canonicalize(b)) { + (Ok(a), Ok(b)) => a == b, + _ => a == b, + } +} + +fn backup_target_memory(workspace_dir: &Path) -> Result> { + let timestamp = chrono::Local::now().format("%Y%m%d-%H%M%S").to_string(); + let backup_root = workspace_dir + .join("memory") + .join("migrations") + .join(format!("openclaw-{timestamp}")); + + let mut copied_any = false; + fs::create_dir_all(&backup_root)?; + + let files_to_copy = [ + workspace_dir.join("memory").join("brain.db"), + workspace_dir.join("MEMORY.md"), + ]; + + for source in files_to_copy { + if source.exists() { + let Some(name) = source.file_name() else { + continue; + }; + fs::copy(&source, backup_root.join(name))?; + copied_any = true; + } + } + + let daily_dir = workspace_dir.join("memory"); + if daily_dir.exists() { + let daily_backup = backup_root.join("daily"); + for file in fs::read_dir(&daily_dir)? 
{ + let file = file?; + let path = file.path(); + if path.extension().and_then(|ext| ext.to_str()) != Some("md") { + continue; + } + fs::create_dir_all(&daily_backup)?; + let Some(name) = path.file_name() else { + continue; + }; + fs::copy(&path, daily_backup.join(name))?; + copied_any = true; + } + } + + if copied_any { + Ok(Some(backup_root)) + } else { + let _ = fs::remove_dir_all(&backup_root); + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{Config, MemoryConfig}; + use rusqlite::params; + use tempfile::TempDir; + + fn test_config(workspace: &Path) -> Config { + Config { + workspace_dir: workspace.to_path_buf(), + config_path: workspace.join("config.toml"), + memory: MemoryConfig { + backend: "sqlite".to_string(), + ..MemoryConfig::default() + }, + ..Config::default() + } + } + + #[test] + fn parse_structured_markdown_line() { + let line = "**user_pref**: likes Rust"; + let parsed = parse_structured_memory_line(line).unwrap(); + assert_eq!(parsed.0, "user_pref"); + assert_eq!(parsed.1, "likes Rust"); + } + + #[test] + fn parse_unstructured_markdown_generates_key() { + let entries = parse_markdown_file( + Path::new("/tmp/MEMORY.md"), + "- plain note", + MemoryCategory::Core, + "core", + ); + assert_eq!(entries.len(), 1); + assert!(entries[0].key.starts_with("openclaw_core_")); + assert_eq!(entries[0].content, "plain note"); + } + + #[test] + fn sqlite_reader_supports_legacy_value_column() { + let dir = TempDir::new().unwrap(); + let db_path = dir.path().join("brain.db"); + let conn = Connection::open(&db_path).unwrap(); + + conn.execute_batch("CREATE TABLE memories (key TEXT, value TEXT, type TEXT);") + .unwrap(); + conn.execute( + "INSERT INTO memories (key, value, type) VALUES (?1, ?2, ?3)", + params!["legacy_key", "legacy_value", "daily"], + ) + .unwrap(); + + let rows = read_openclaw_sqlite_entries(&db_path).unwrap(); + assert_eq!(rows.len(), 1); + assert_eq!(rows[0].key, "legacy_key"); + assert_eq!(rows[0].content, 
"legacy_value"); + assert_eq!(rows[0].category, MemoryCategory::Daily); + } + + #[tokio::test] + async fn migration_renames_conflicting_key() { + let source = TempDir::new().unwrap(); + let target = TempDir::new().unwrap(); + + // Existing target memory + let target_mem = SqliteMemory::new(target.path()).unwrap(); + target_mem + .store("k", "new value", MemoryCategory::Core) + .await + .unwrap(); + + // Source sqlite with conflicting key + different content + let source_db_dir = source.path().join("memory"); + fs::create_dir_all(&source_db_dir).unwrap(); + let source_db = source_db_dir.join("brain.db"); + let conn = Connection::open(&source_db).unwrap(); + conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);") + .unwrap(); + conn.execute( + "INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)", + params!["k", "old value", "core"], + ) + .unwrap(); + + let config = test_config(target.path()); + migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), false) + .await + .unwrap(); + + let all = target_mem.list(None).await.unwrap(); + assert!(all.iter().any(|e| e.key == "k" && e.content == "new value")); + assert!(all + .iter() + .any(|e| e.key.starts_with("k__openclaw_") && e.content == "old value")); + } + + #[tokio::test] + async fn dry_run_does_not_write() { + let source = TempDir::new().unwrap(); + let target = TempDir::new().unwrap(); + let source_db_dir = source.path().join("memory"); + fs::create_dir_all(&source_db_dir).unwrap(); + + let source_db = source_db_dir.join("brain.db"); + let conn = Connection::open(&source_db).unwrap(); + conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);") + .unwrap(); + conn.execute( + "INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)", + params!["dry", "run", "core"], + ) + .unwrap(); + + let config = test_config(target.path()); + migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), true) + .await + .unwrap(); + 
+ let target_mem = SqliteMemory::new(target.path()).unwrap(); + assert_eq!(target_mem.count().await.unwrap(), 0); + } +} diff --git a/src/onboard/mod.rs b/src/onboard/mod.rs index 0f16b88..a18ce8a 100644 --- a/src/onboard/mod.rs +++ b/src/onboard/mod.rs @@ -1,3 +1,3 @@ pub mod wizard; -pub use wizard::{run_quick_setup, run_wizard}; +pub use wizard::{run_channels_repair_wizard, run_quick_setup, run_wizard}; diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 268dda2..6f5ba40 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -1,3 +1,4 @@ +use crate::config::schema::WhatsAppConfig; use crate::config::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, ObservabilityConfig, @@ -91,6 +92,7 @@ pub fn run_wizard() -> Result { observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), + reliability: crate::config::ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config, memory: MemoryConfig::default(), // SQLite + auto-save by default @@ -99,6 +101,7 @@ pub fn run_wizard() -> Result { composio: composio_config, secrets: secrets_config, browser: BrowserConfig::default(), + identity: crate::config::IdentityConfig::default(), }; println!( @@ -149,6 +152,61 @@ pub fn run_wizard() -> Result { Ok(config) } +/// Interactive repair flow: rerun channel setup only without redoing full onboarding. 
+pub fn run_channels_repair_wizard() -> Result { + println!("{}", style(BANNER).cyan().bold()); + println!( + " {}", + style("Channels Repair — update channel tokens and allowlists only") + .white() + .bold() + ); + println!(); + + let mut config = Config::load_or_init()?; + + print_step(1, 1, "Channels (How You Talk to ZeroClaw)"); + config.channels_config = setup_channels()?; + config.save()?; + + println!(); + println!( + " {} Channel config saved: {}", + style("✓").green().bold(), + style(config.config_path.display()).green() + ); + + let has_channels = config.channels_config.telegram.is_some() + || config.channels_config.discord.is_some() + || config.channels_config.slack.is_some() + || config.channels_config.imessage.is_some() + || config.channels_config.matrix.is_some(); + + if has_channels && config.api_key.is_some() { + let launch: bool = Confirm::new() + .with_prompt(format!( + " {} Launch channels now? (connected channels → AI → reply)", + style("🚀").cyan() + )) + .default(true) + .interact()?; + + if launch { + println!(); + println!( + " {} {}", + style("⚡").cyan(), + style("Starting channel server...").white().bold() + ); + println!(); + // Signal to main.rs to call start_channels after wizard returns + std::env::set_var("ZEROCLAW_AUTOSTART_CHANNELS", "1"); + } + } + + Ok(config) +} + // ── Quick setup (zero prompts) ─────────────────────────────────── /// Non-interactive setup: generates a sensible default config instantly. 
@@ -187,6 +245,7 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result< observability: ObservabilityConfig::default(), autonomy: AutonomyConfig::default(), runtime: RuntimeConfig::default(), + reliability: crate::config::ReliabilityConfig::default(), heartbeat: HeartbeatConfig::default(), channels_config: ChannelsConfig::default(), memory: MemoryConfig::default(), @@ -195,6 +254,7 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result< composio: ComposioConfig::default(), secrets: SecretsConfig::default(), browser: BrowserConfig::default(), + identity: crate::config::IdentityConfig::default(), }; config.save()?; @@ -204,7 +264,9 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result< user_name: std::env::var("USER").unwrap_or_else(|_| "User".into()), timezone: "UTC".into(), agent_name: "ZeroClaw".into(), - communication_style: "Direct and concise".into(), + communication_style: + "Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing." + .into(), }; scaffold_workspace(&workspace_dir, &default_ctx)?; @@ -878,24 +940,33 @@ fn setup_project_context() -> Result { let style_options = vec![ "Direct & concise — skip pleasantries, get to the point", - "Friendly & casual — warm but efficient", + "Friendly & casual — warm, human, and helpful", + "Professional & polished — calm, confident, and clear", + "Expressive & playful — more personality + natural emojis", "Technical & detailed — thorough explanations, code-first", "Balanced — adapt to the situation", + "Custom — write your own style guide", ]; let style_idx = Select::new() .with_prompt(" Communication style") .items(&style_options) - .default(0) + .default(1) .interact()?; let communication_style = match style_idx { 0 => "Be direct and concise. Skip pleasantries. Get to the point.".to_string(), - 1 => "Be friendly and casual. Warm but efficient.".to_string(), - 2 => "Be technical and detailed. 
Thorough explanations, code-first.".to_string(), - _ => { - "Adapt to the situation. Be concise when needed, thorough when it matters.".to_string() - } + 1 => "Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions.".to_string(), + 2 => "Be professional and polished. Stay calm, structured, and respectful. Use occasional tone-setting emojis only when appropriate.".to_string(), + 3 => "Be expressive and playful when appropriate. Use relevant emojis naturally (0-2 max), and keep serious topics emoji-light.".to_string(), + 4 => "Be technical and detailed. Thorough explanations, code-first.".to_string(), + 5 => "Adapt to the situation. Default to warm and clear communication; be concise when needed, thorough when it matters.".to_string(), + _ => Input::new() + .with_prompt(" Custom communication style") + .default( + "Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing.".into(), + ) + .interact_text()?, }; println!( @@ -931,6 +1002,7 @@ fn setup_channels() -> Result { webhook: None, imessage: None, matrix: None, + whatsapp: None, }; loop { @@ -975,6 +1047,14 @@ fn setup_channels() -> Result { "— self-hosted chat" } ), + format!( + "WhatsApp {}", + if config.whatsapp.is_some() { + "✅ connected" + } else { + "— Business Cloud API" + } + ), format!( "Webhook {}", if config.webhook.is_some() { @@ -989,7 +1069,7 @@ fn setup_channels() -> Result { let choice = Select::new() .with_prompt(" Connect a channel (or Done to continue)") .items(&options) - .default(6) + .default(7) .interact()?; match choice { @@ -1041,17 +1121,38 @@ fn setup_channels() -> Result { } } + print_bullet( + "Allowlist your own Telegram identity first (recommended for secure + fast setup).", + ); + print_bullet( + "Use your @username without '@' (example: argenis), or your numeric Telegram user ID.", + ); + print_bullet("Use '*' only for temporary open testing."); + let users_str: String = Input::new() 
- .with_prompt(" Allowed usernames (comma-separated, or * for all)") - .default("*".into()) + .with_prompt( + " Allowed Telegram identities (comma-separated: username without '@' and/or numeric user ID, '*' for all)", + ) + .allow_empty(true) .interact_text()?; let allowed_users = if users_str.trim() == "*" { vec!["*".into()] } else { - users_str.split(',').map(|s| s.trim().to_string()).collect() + users_str + .split(',') + .map(|s| s.trim().to_string()) + .filter(|s| !s.is_empty()) + .collect() }; + if allowed_users.is_empty() { + println!( + " {} No users allowlisted — Telegram inbound messages will be denied until you add your username/user ID or '*'.", + style("⚠").yellow().bold() + ); + } + config.telegram = Some(TelegramConfig { bot_token: token, allowed_users, @@ -1111,9 +1212,15 @@ fn setup_channels() -> Result { .allow_empty(true) .interact_text()?; + print_bullet("Allowlist your own Discord user ID first (recommended)."); + print_bullet( + "Get it in Discord: Settings -> Advanced -> Developer Mode (ON), then right-click your profile -> Copy User ID.", + ); + print_bullet("Use '*' only for temporary open testing."); + let allowed_users_str: String = Input::new() .with_prompt( - " Allowed Discord user IDs (comma-separated, '*' for all, Enter to deny all)", + " Allowed Discord user IDs (comma-separated, recommended: your own ID, '*' for all)", ) .allow_empty(true) .interact_text()?; @@ -1214,9 +1321,15 @@ fn setup_channels() -> Result { .allow_empty(true) .interact_text()?; + print_bullet("Allowlist your own Slack member ID first (recommended)."); + print_bullet( + "Member IDs usually start with 'U' (open your Slack profile -> More -> Copy member ID).", + ); + print_bullet("Use '*' only for temporary open testing."); + let allowed_users_str: String = Input::new() .with_prompt( - " Allowed Slack user IDs (comma-separated, '*' for all, Enter to deny all)", + " Allowed Slack user IDs (comma-separated, recommended: your own member ID, '*' for all)", ) 
.allow_empty(true) .interact_text()?; @@ -1378,6 +1491,90 @@ fn setup_channels() -> Result { }); } 5 => { + // ── WhatsApp ── + println!(); + println!( + " {} {}", + style("WhatsApp Setup").white().bold(), + style("— Business Cloud API").dim() + ); + print_bullet("1. Go to developers.facebook.com and create a WhatsApp app"); + print_bullet("2. Add the WhatsApp product and get your phone number ID"); + print_bullet("3. Generate a temporary access token (System User)"); + print_bullet("4. Configure webhook URL to: https://your-domain/whatsapp"); + println!(); + + let access_token: String = Input::new() + .with_prompt(" Access token (from Meta Developers)") + .interact_text()?; + + if access_token.trim().is_empty() { + println!(" {} Skipped", style("→").dim()); + continue; + } + + let phone_number_id: String = Input::new() + .with_prompt(" Phone number ID (from WhatsApp app settings)") + .interact_text()?; + + if phone_number_id.trim().is_empty() { + println!(" {} Skipped — phone number ID required", style("→").dim()); + continue; + } + + let verify_token: String = Input::new() + .with_prompt(" Webhook verify token (create your own)") + .default("zeroclaw-whatsapp-verify".into()) + .interact_text()?; + + // Test connection + print!(" {} Testing connection... 
", style("⏳").dim()); + let client = reqwest::blocking::Client::new(); + let url = format!( + "https://graph.facebook.com/v18.0/{}", + phone_number_id.trim() + ); + match client + .get(&url) + .header("Authorization", format!("Bearer {}", access_token.trim())) + .send() + { + Ok(resp) if resp.status().is_success() => { + println!( + "\r {} Connected to WhatsApp API ", + style("✅").green().bold() + ); + } + _ => { + println!( + "\r {} Connection failed — check access token and phone number ID", + style("❌").red().bold() + ); + continue; + } + } + + let users_str: String = Input::new() + .with_prompt( + " Allowed phone numbers (comma-separated +1234567890, or * for all)", + ) + .default("*".into()) + .interact_text()?; + + let allowed_numbers = if users_str.trim() == "*" { + vec!["*".into()] + } else { + users_str.split(',').map(|s| s.trim().to_string()).collect() + }; + + config.whatsapp = Some(WhatsAppConfig { + access_token: access_token.trim().to_string(), + phone_number_id: phone_number_id.trim().to_string(), + verify_token: verify_token.trim().to_string(), + allowed_numbers, + }); + } + 6 => { // ── Webhook ── println!(); println!( @@ -1432,6 +1629,9 @@ fn setup_channels() -> Result { if config.matrix.is_some() { active.push("Matrix"); } + if config.whatsapp.is_some() { + active.push("WhatsApp"); + } if config.webhook.is_some() { active.push("Webhook"); } @@ -1618,7 +1818,7 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> &ctx.timezone }; let comm_style = if ctx.communication_style.is_empty() { - "Adapt to the situation. Be concise when needed, thorough when it matters." + "Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing." } else { &ctx.communication_style }; @@ -1667,6 +1867,14 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> ## Tools & Skills\n\n\ Skills are listed in the system prompt. 
Use `read` on a skill's SKILL.md for details.\n\ Keep local notes (SSH hosts, device names, etc.) in `TOOLS.md`.\n\n\ + ## Crash Recovery\n\n\ + - If a run stops unexpectedly, recover context before acting.\n\ + - Check `MEMORY.md` + latest `memory/*.md` notes to avoid duplicate work.\n\ + - Resume from the last confirmed step, not from scratch.\n\n\ + ## Sub-task Scoping\n\n\ + - Break complex work into focused sub-tasks with clear success criteria.\n\ + - Keep sub-tasks small, verify each output, then merge results.\n\ + - Prefer one clear objective per sub-task over broad \"do everything\" asks.\n\n\ ## Make It Yours\n\n\ This is a starting point. Add your own conventions, style, and rules.\n" ); @@ -1704,6 +1912,11 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> - Always introduce yourself as {agent} if asked\n\n\ ## Communication\n\n\ {comm_style}\n\n\ + - Sound like a real person, not a support script.\n\ + - Mirror the user's energy: calm when serious, upbeat when casual.\n\ + - Use emojis naturally (0-2 max when they help tone, not every sentence).\n\ + - Match emoji density to the user. Formal user => minimal/no emojis.\n\ + - Prefer specific, grounded phrasing over generic filler.\n\n\ ## Boundaries\n\n\ - Private things stay private. 
Period.\n\ - When in doubt, ask before acting externally.\n\ @@ -1744,11 +1957,23 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> - Anything environment-specific\n\n\ ## Built-in Tools\n\n\ - **shell** — Execute terminal commands\n\ + - Use when: running local checks, build/test commands, or diagnostics.\n\ + - Don't use when: a safer dedicated tool exists, or command is destructive without approval.\n\ - **file_read** — Read file contents\n\ + - Use when: inspecting project files, configs, or logs.\n\ + - Don't use when: you only need a quick string search (prefer targeted search first).\n\ - **file_write** — Write file contents\n\ + - Use when: applying focused edits, scaffolding files, or updating docs/code.\n\ + - Don't use when: unsure about side effects or when the file should remain user-owned.\n\ - **memory_store** — Save to memory\n\ + - Use when: preserving durable preferences, decisions, or key context.\n\ + - Don't use when: info is transient, noisy, or sensitive without explicit need.\n\ - **memory_recall** — Search memory\n\ - - **memory_forget** — Delete a memory entry\n\n\ + - Use when: you need prior decisions, user preferences, or historical context.\n\ + - Don't use when: the answer is already in current files/conversation.\n\ + - **memory_forget** — Delete a memory entry\n\ + - Use when: memory is incorrect, stale, or explicitly requested to be removed.\n\ + - Don't use when: uncertain about impact; verify before deleting.\n\n\ ---\n\ *Add whatever helps you do your job. 
This is your cheat sheet.*\n"; @@ -2242,7 +2467,7 @@ mod tests { let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); assert!( - soul.contains("Adapt to the situation"), + soul.contains("Be warm, natural, and clear."), "should default communication style" ); } @@ -2383,6 +2608,31 @@ mod tests { "TOOLS.md should list built-in tool: {tool}" ); } + assert!( + tools.contains("Use when:"), + "TOOLS.md should include 'Use when' guidance" + ); + assert!( + tools.contains("Don't use when:"), + "TOOLS.md should include 'Don't use when' guidance" + ); + } + + #[test] + fn soul_md_includes_emoji_awareness_guidance() { + let tmp = TempDir::new().unwrap(); + let ctx = ProjectContext::default(); + scaffold_workspace(tmp.path(), &ctx).unwrap(); + + let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + assert!( + soul.contains("Use emojis naturally (0-2 max"), + "SOUL.md should include emoji usage guidance" + ); + assert!( + soul.contains("Match emoji density to the user"), + "SOUL.md should include emoji-awareness guidance" + ); } // ── scaffold_workspace: special characters in names ───────── @@ -2414,7 +2664,9 @@ mod tests { user_name: "Argenis".into(), timezone: "US/Eastern".into(), agent_name: "Claw".into(), - communication_style: "Be friendly and casual. Warm but efficient.".into(), + communication_style: + "Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions." 
+ .into(), }; scaffold_workspace(tmp.path(), &ctx).unwrap(); @@ -2424,12 +2676,12 @@ mod tests { let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); assert!(soul.contains("You are **Claw**")); - assert!(soul.contains("Be friendly and casual")); + assert!(soul.contains("Be friendly, human, and conversational")); let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); assert!(user_md.contains("**Name:** Argenis")); assert!(user_md.contains("**Timezone:** US/Eastern")); - assert!(user_md.contains("Be friendly and casual")); + assert!(user_md.contains("Be friendly, human, and conversational")); let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap(); assert!(agents.contains("Claw Personal Assistant")); diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 884c66e..768640a 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -4,11 +4,13 @@ pub mod gemini; pub mod ollama; pub mod openai; pub mod openrouter; +pub mod reliable; pub mod traits; pub use traits::Provider; use compatible::{AuthStyle, OpenAiCompatibleProvider}; +use reliable::ReliableProvider; /// Factory: create the right provider from config #[allow(clippy::too_many_lines)] @@ -114,6 +116,42 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result, + reliability: &crate::config::ReliabilityConfig, +) -> anyhow::Result> { + let mut providers: Vec<(String, Box)> = Vec::new(); + + providers.push(( + primary_name.to_string(), + create_provider(primary_name, api_key)?, + )); + + for fallback in &reliability.fallback_providers { + if fallback == primary_name || providers.iter().any(|(name, _)| name == fallback) { + continue; + } + + match create_provider(fallback, api_key) { + Ok(provider) => providers.push((fallback.clone(), provider)), + Err(e) => { + tracing::warn!( + fallback_provider = fallback, + "Ignoring invalid fallback provider: {e}" + ); + } + } + } + + Ok(Box::new(ReliableProvider::new( + providers, + 
reliability.provider_retries, + reliability.provider_backoff_ms, + ))) +} + #[cfg(test)] mod tests { use super::*; @@ -307,6 +345,34 @@ mod tests { assert!(create_provider("", None).is_err()); } + #[test] + fn resilient_provider_ignores_duplicate_and_invalid_fallbacks() { + let reliability = crate::config::ReliabilityConfig { + provider_retries: 1, + provider_backoff_ms: 100, + fallback_providers: vec![ + "openrouter".into(), + "nonexistent-provider".into(), + "openai".into(), + "openai".into(), + ], + channel_initial_backoff_secs: 2, + channel_max_backoff_secs: 60, + scheduler_poll_secs: 15, + scheduler_retries: 2, + }; + + let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability); + assert!(provider.is_ok()); + } + + #[test] + fn resilient_provider_errors_for_invalid_primary() { + let reliability = crate::config::ReliabilityConfig::default(); + let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability); + assert!(provider.is_err()); + } + #[test] fn factory_all_providers_create_successfully() { let providers = [ diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs new file mode 100644 index 0000000..c324f21 --- /dev/null +++ b/src/providers/reliable.rs @@ -0,0 +1,229 @@ +use super::Provider; +use async_trait::async_trait; +use std::time::Duration; + +/// Provider wrapper with retry + fallback behavior. 
+pub struct ReliableProvider { + providers: Vec<(String, Box)>, + max_retries: u32, + base_backoff_ms: u64, +} + +impl ReliableProvider { + pub fn new( + providers: Vec<(String, Box)>, + max_retries: u32, + base_backoff_ms: u64, + ) -> Self { + Self { + providers, + max_retries, + base_backoff_ms: base_backoff_ms.max(50), + } + } +} + +#[async_trait] +impl Provider for ReliableProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let mut failures = Vec::new(); + + for (provider_name, provider) in &self.providers { + let mut backoff_ms = self.base_backoff_ms; + + for attempt in 0..=self.max_retries { + match provider + .chat_with_system(system_prompt, message, model, temperature) + .await + { + Ok(resp) => { + if attempt > 0 { + tracing::info!( + provider = provider_name, + attempt, + "Provider recovered after retries" + ); + } + return Ok(resp); + } + Err(e) => { + failures.push(format!( + "{provider_name} attempt {}/{}: {e}", + attempt + 1, + self.max_retries + 1 + )); + + if attempt < self.max_retries { + tracing::warn!( + provider = provider_name, + attempt = attempt + 1, + max_retries = self.max_retries, + "Provider call failed, retrying" + ); + tokio::time::sleep(Duration::from_millis(backoff_ms)).await; + backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000); + } + } + } + } + + tracing::warn!(provider = provider_name, "Switching to fallback provider"); + } + + anyhow::bail!("All providers failed. 
Attempts:\n{}", failures.join("\n")) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; + + struct MockProvider { + calls: Arc, + fail_until_attempt: usize, + response: &'static str, + error: &'static str, + } + + #[async_trait] + impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1; + if attempt <= self.fail_until_attempt { + anyhow::bail!(self.error); + } + Ok(self.response.to_string()) + } + } + + #[tokio::test] + async fn succeeds_without_retry() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 0, + response: "ok", + error: "boom", + }), + )], + 2, + 1, + ); + + let result = provider.chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "ok"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn retries_then_recovers() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = ReliableProvider::new( + vec![( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: 1, + response: "recovered", + error: "temporary", + }), + )], + 2, + 1, + ); + + let result = provider.chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "recovered"); + assert_eq!(calls.load(Ordering::SeqCst), 2); + } + + #[tokio::test] + async fn falls_back_after_retries_exhausted() { + let primary_calls = Arc::new(AtomicUsize::new(0)); + let fallback_calls = Arc::new(AtomicUsize::new(0)); + + let provider = ReliableProvider::new( + vec![ + ( + "primary".into(), + Box::new(MockProvider { + calls: Arc::clone(&primary_calls), + fail_until_attempt: usize::MAX, + response: "never", + error: "primary down", + 
}), + ), + ( + "fallback".into(), + Box::new(MockProvider { + calls: Arc::clone(&fallback_calls), + fail_until_attempt: 0, + response: "from fallback", + error: "fallback down", + }), + ), + ], + 1, + 1, + ); + + let result = provider.chat("hello", "test", 0.0).await.unwrap(); + assert_eq!(result, "from fallback"); + assert_eq!(primary_calls.load(Ordering::SeqCst), 2); + assert_eq!(fallback_calls.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn returns_aggregated_error_when_all_providers_fail() { + let provider = ReliableProvider::new( + vec![ + ( + "p1".into(), + Box::new(MockProvider { + calls: Arc::new(AtomicUsize::new(0)), + fail_until_attempt: usize::MAX, + response: "never", + error: "p1 error", + }), + ), + ( + "p2".into(), + Box::new(MockProvider { + calls: Arc::new(AtomicUsize::new(0)), + fail_until_attempt: usize::MAX, + response: "never", + error: "p2 error", + }), + ), + ], + 0, + 1, + ); + + let err = provider + .chat("hello", "test", 0.0) + .await + .expect_err("all providers should fail"); + let msg = err.to_string(); + assert!(msg.contains("All providers failed")); + assert!(msg.contains("p1 attempt 1/1")); + assert!(msg.contains("p2 attempt 1/1")); + } +} diff --git a/src/runtime/mod.rs b/src/runtime/mod.rs index cb8abd5..9ed0ee0 100644 --- a/src/runtime/mod.rs +++ b/src/runtime/mod.rs @@ -7,17 +7,21 @@ pub use traits::RuntimeAdapter; use crate::config::RuntimeConfig; /// Factory: create the right runtime from config -pub fn create_runtime(config: &RuntimeConfig) -> Box { +pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result> { match config.kind.as_str() { - "native" | "docker" => Box::new(NativeRuntime::new()), - "cloudflare" => { - tracing::warn!("Cloudflare runtime not yet implemented, falling back to native"); - Box::new(NativeRuntime::new()) - } - _ => { - tracing::warn!("Unknown runtime '{}', falling back to native", config.kind); - Box::new(NativeRuntime::new()) - } + "native" => Ok(Box::new(NativeRuntime::new())), + 
"docker" => anyhow::bail!( + "runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands." + ), + "cloudflare" => anyhow::bail!( + "runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now." + ), + other if other.trim().is_empty() => anyhow::bail!( + "runtime.kind cannot be empty. Supported values: native" + ), + other => anyhow::bail!( + "Unknown runtime kind '{other}'. Supported values: native" + ), } } @@ -30,44 +34,52 @@ mod tests { let cfg = RuntimeConfig { kind: "native".into(), }; - let rt = create_runtime(&cfg); + let rt = create_runtime(&cfg).unwrap(); assert_eq!(rt.name(), "native"); assert!(rt.has_shell_access()); } #[test] - fn factory_docker_returns_native() { + fn factory_docker_errors() { let cfg = RuntimeConfig { kind: "docker".into(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("not implemented")), + Ok(_) => panic!("docker runtime should error"), + } } #[test] - fn factory_cloudflare_falls_back() { + fn factory_cloudflare_errors() { let cfg = RuntimeConfig { kind: "cloudflare".into(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("not implemented")), + Ok(_) => panic!("cloudflare runtime should error"), + } } #[test] - fn factory_unknown_falls_back() { + fn factory_unknown_errors() { let cfg = RuntimeConfig { kind: "wasm-edge-unknown".into(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("Unknown runtime kind")), + Ok(_) => panic!("unknown runtime should error"), + } } #[test] - fn factory_empty_falls_back() { + fn factory_empty_errors() { let cfg = RuntimeConfig { kind: String::new(), }; - let rt = create_runtime(&cfg); - assert_eq!(rt.name(), "native"); + match 
create_runtime(&cfg) { + Err(err) => assert!(err.to_string().contains("cannot be empty")), + Ok(_) => panic!("empty runtime should error"), + } } } diff --git a/src/security/policy.rs b/src/security/policy.rs index a8b160e..49d58df 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -258,8 +258,14 @@ impl SecurityPolicy { /// Validate that a resolved path is still inside the workspace. /// Call this AFTER joining `workspace_dir` + relative path and canonicalizing. pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool { - // Must be under workspace_dir (prevents symlink escapes) - resolved.starts_with(&self.workspace_dir) + // Must be under workspace_dir (prevents symlink escapes). + // Prefer canonical workspace root so `/a/../b` style config paths don't + // cause false positives or negatives. + let workspace_root = self + .workspace_dir + .canonicalize() + .unwrap_or_else(|_| self.workspace_dir.clone()); + resolved.starts_with(workspace_root) } /// Check if autonomy level permits any action at all diff --git a/src/security/secrets.rs b/src/security/secrets.rs index 8b770e5..6022ebe 100644 --- a/src/security/secrets.rs +++ b/src/security/secrets.rs @@ -79,45 +79,94 @@ impl SecretStore { /// - `enc2:` prefix → ChaCha20-Poly1305 (current format) /// - `enc:` prefix → legacy XOR cipher (backward compatibility for migration) /// - No prefix → returned as-is (plaintext config) + /// + /// **Warning**: Legacy `enc:` values are insecure. Use `decrypt_and_migrate` to + /// automatically upgrade them to the secure `enc2:` format. 
pub fn decrypt(&self, value: &str) -> Result { if let Some(hex_str) = value.strip_prefix("enc2:") { - let blob = - hex_decode(hex_str).context("Failed to decode encrypted secret (corrupt hex)")?; - anyhow::ensure!( - blob.len() > NONCE_LEN, - "Encrypted value too short (missing nonce)" - ); - - let (nonce_bytes, ciphertext) = blob.split_at(NONCE_LEN); - let nonce = Nonce::from_slice(nonce_bytes); - let key_bytes = self.load_or_create_key()?; - let key = Key::from_slice(&key_bytes); - let cipher = ChaCha20Poly1305::new(key); - - let plaintext_bytes = cipher - .decrypt(nonce, ciphertext) - .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or tampered data"))?; - - String::from_utf8(plaintext_bytes) - .context("Decrypted secret is not valid UTF-8 — corrupt data") + self.decrypt_chacha20(hex_str) } else if let Some(hex_str) = value.strip_prefix("enc:") { - // Legacy XOR cipher — decrypt for backward compatibility - let ciphertext = hex_decode(hex_str) - .context("Failed to decode legacy encrypted secret (corrupt hex)")?; - let key = self.load_or_create_key()?; - let plaintext_bytes = xor_cipher(&ciphertext, &key); - String::from_utf8(plaintext_bytes) - .context("Decrypted legacy secret is not valid UTF-8 — wrong key or corrupt data") + self.decrypt_legacy_xor(hex_str) } else { Ok(value.to_string()) } } + /// Decrypt a secret and return a migrated `enc2:` value if the input used legacy `enc:` format. + /// + /// Returns `(plaintext, Some(new_enc2_value))` if migration occurred, or + /// `(plaintext, None)` if no migration was needed. + /// + /// This allows callers to persist the upgraded value back to config. 
+ pub fn decrypt_and_migrate(&self, value: &str) -> Result<(String, Option)> { + if let Some(hex_str) = value.strip_prefix("enc2:") { + // Already using secure format — no migration needed + let plaintext = self.decrypt_chacha20(hex_str)?; + Ok((plaintext, None)) + } else if let Some(hex_str) = value.strip_prefix("enc:") { + // Legacy XOR cipher — decrypt and re-encrypt with ChaCha20-Poly1305 + tracing::warn!( + "Decrypting legacy XOR-encrypted secret (enc: prefix). \ + This format is insecure and will be removed in a future release. \ + The secret will be automatically migrated to enc2: (ChaCha20-Poly1305)." + ); + let plaintext = self.decrypt_legacy_xor(hex_str)?; + let migrated = self.encrypt(&plaintext)?; + Ok((plaintext, Some(migrated))) + } else { + // Plaintext — no migration needed + Ok((value.to_string(), None)) + } + } + + /// Check if a value uses the legacy `enc:` format that should be migrated. + pub fn needs_migration(value: &str) -> bool { + value.starts_with("enc:") + } + + /// Decrypt using ChaCha20-Poly1305 (current secure format). + fn decrypt_chacha20(&self, hex_str: &str) -> Result { + let blob = + hex_decode(hex_str).context("Failed to decode encrypted secret (corrupt hex)")?; + anyhow::ensure!( + blob.len() > NONCE_LEN, + "Encrypted value too short (missing nonce)" + ); + + let (nonce_bytes, ciphertext) = blob.split_at(NONCE_LEN); + let nonce = Nonce::from_slice(nonce_bytes); + let key_bytes = self.load_or_create_key()?; + let key = Key::from_slice(&key_bytes); + let cipher = ChaCha20Poly1305::new(key); + + let plaintext_bytes = cipher + .decrypt(nonce, ciphertext) + .map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or tampered data"))?; + + String::from_utf8(plaintext_bytes) + .context("Decrypted secret is not valid UTF-8 — corrupt data") + } + + /// Decrypt using legacy XOR cipher (insecure, for backward compatibility only). 
+ fn decrypt_legacy_xor(&self, hex_str: &str) -> Result { + let ciphertext = hex_decode(hex_str) + .context("Failed to decode legacy encrypted secret (corrupt hex)")?; + let key = self.load_or_create_key()?; + let plaintext_bytes = xor_cipher(&ciphertext, &key); + String::from_utf8(plaintext_bytes) + .context("Decrypted legacy secret is not valid UTF-8 — wrong key or corrupt data") + } + /// Check if a value is already encrypted (current or legacy format). pub fn is_encrypted(value: &str) -> bool { value.starts_with("enc2:") || value.starts_with("enc:") } + /// Check if a value uses the secure `enc2:` format. + pub fn is_secure_encrypted(value: &str) -> bool { + value.starts_with("enc2:") + } + /// Load the encryption key from disk, or create one if it doesn't exist. fn load_or_create_key(&self) -> Result> { if self.key_path.exists() { @@ -132,13 +181,22 @@ impl SecretStore { fs::write(&self.key_path, hex_encode(&key)) .context("Failed to write secret key file")?; - // Set restrictive permissions (Unix only) + // Set restrictive permissions #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; fs::set_permissions(&self.key_path, fs::Permissions::from_mode(0o600)) .context("Failed to set key file permissions")?; } + #[cfg(windows)] + { + // On Windows, use icacls to restrict permissions to current user only + let _ = std::process::Command::new("icacls") + .arg(&self.key_path) + .args(["/inheritance:r", "/grant:r"]) + .arg(format!("{}:F", std::env::var("USERNAME").unwrap_or_default())) + .output(); + } Ok(key) } @@ -382,6 +440,258 @@ mod tests { assert_eq!(decrypted, plaintext, "Legacy XOR values must still decrypt"); } + // ── Migration tests ───────────────────────────────────────── + + #[test] + fn needs_migration_detects_legacy_prefix() { + assert!(SecretStore::needs_migration("enc:aabbcc")); + assert!(!SecretStore::needs_migration("enc2:aabbcc")); + assert!(!SecretStore::needs_migration("sk-plaintext")); + assert!(!SecretStore::needs_migration("")); + } + + 
#[test] + fn is_secure_encrypted_detects_enc2_only() { + assert!(SecretStore::is_secure_encrypted("enc2:aabbcc")); + assert!(!SecretStore::is_secure_encrypted("enc:aabbcc")); + assert!(!SecretStore::is_secure_encrypted("sk-plaintext")); + assert!(!SecretStore::is_secure_encrypted("")); + } + + #[test] + fn decrypt_and_migrate_returns_none_for_enc2() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + let encrypted = store.encrypt("my-secret").unwrap(); + assert!(encrypted.starts_with("enc2:")); + + let (plaintext, migrated) = store.decrypt_and_migrate(&encrypted).unwrap(); + assert_eq!(plaintext, "my-secret"); + assert!( + migrated.is_none(), + "enc2: values should not trigger migration" + ); + } + + #[test] + fn decrypt_and_migrate_returns_none_for_plaintext() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + let (plaintext, migrated) = store.decrypt_and_migrate("sk-plaintext-key").unwrap(); + assert_eq!(plaintext, "sk-plaintext-key"); + assert!( + migrated.is_none(), + "Plaintext values should not trigger migration" + ); + } + + #[test] + fn decrypt_and_migrate_upgrades_legacy_xor() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + // Create key first + let _ = store.encrypt("setup").unwrap(); + let key = store.load_or_create_key().unwrap(); + + // Manually create a legacy XOR-encrypted value + let plaintext = "sk-legacy-secret-to-migrate"; + let ciphertext = xor_cipher(plaintext.as_bytes(), &key); + let legacy_value = format!("enc:{}", hex_encode(&ciphertext)); + + // Verify it needs migration + assert!(SecretStore::needs_migration(&legacy_value)); + + // Decrypt and migrate + let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap(); + assert_eq!(decrypted, plaintext, "Plaintext must match original"); + assert!(migrated.is_some(), "Legacy value should trigger migration"); + + let new_value = migrated.unwrap(); + 
assert!( + new_value.starts_with("enc2:"), + "Migrated value must use enc2: prefix" + ); + assert!( + !SecretStore::needs_migration(&new_value), + "Migrated value should not need migration" + ); + + // Verify the migrated value decrypts correctly + let (decrypted2, migrated2) = store.decrypt_and_migrate(&new_value).unwrap(); + assert_eq!( + decrypted2, plaintext, + "Migrated value must decrypt to same plaintext" + ); + assert!( + migrated2.is_none(), + "Migrated value should not trigger another migration" + ); + } + + #[test] + fn decrypt_and_migrate_handles_unicode() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + let _ = store.encrypt("setup").unwrap(); + let key = store.load_or_create_key().unwrap(); + + let plaintext = "sk-日本語-émojis-🦀-тест"; + let ciphertext = xor_cipher(plaintext.as_bytes(), &key); + let legacy_value = format!("enc:{}", hex_encode(&ciphertext)); + + let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap(); + assert_eq!(decrypted, plaintext); + assert!(migrated.is_some()); + + // Verify migrated value works + let new_value = migrated.unwrap(); + let (decrypted2, _) = store.decrypt_and_migrate(&new_value).unwrap(); + assert_eq!(decrypted2, plaintext); + } + + #[test] + fn decrypt_and_migrate_handles_empty_secret() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + let _ = store.encrypt("setup").unwrap(); + let key = store.load_or_create_key().unwrap(); + + // Empty plaintext XOR-encrypted + let plaintext = ""; + let ciphertext = xor_cipher(plaintext.as_bytes(), &key); + let legacy_value = format!("enc:{}", hex_encode(&ciphertext)); + + let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap(); + assert_eq!(decrypted, plaintext); + // Empty string encryption returns empty string (not enc2:) + assert!(migrated.is_some()); + assert_eq!(migrated.unwrap(), ""); + } + + #[test] + fn decrypt_and_migrate_handles_long_secret() 
{ + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + let _ = store.encrypt("setup").unwrap(); + let key = store.load_or_create_key().unwrap(); + + let plaintext = "a".repeat(10_000); + let ciphertext = xor_cipher(plaintext.as_bytes(), &key); + let legacy_value = format!("enc:{}", hex_encode(&ciphertext)); + + let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap(); + assert_eq!(decrypted, plaintext); + assert!(migrated.is_some()); + + let new_value = migrated.unwrap(); + let (decrypted2, _) = store.decrypt_and_migrate(&new_value).unwrap(); + assert_eq!(decrypted2, plaintext); + } + + #[test] + fn decrypt_and_migrate_fails_on_corrupt_legacy_hex() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + let _ = store.encrypt("setup").unwrap(); + + let result = store.decrypt_and_migrate("enc:not-valid-hex!!"); + assert!(result.is_err(), "Corrupt hex should fail"); + } + + #[test] + fn decrypt_and_migrate_wrong_key_produces_garbage_or_fails() { + let tmp1 = TempDir::new().unwrap(); + let tmp2 = TempDir::new().unwrap(); + let store1 = SecretStore::new(tmp1.path(), true); + let store2 = SecretStore::new(tmp2.path(), true); + + // Create keys for both stores + let _ = store1.encrypt("setup").unwrap(); + let _ = store2.encrypt("setup").unwrap(); + let key1 = store1.load_or_create_key().unwrap(); + + // Encrypt with store1's key + let plaintext = "secret-for-store1"; + let ciphertext = xor_cipher(plaintext.as_bytes(), &key1); + let legacy_value = format!("enc:{}", hex_encode(&ciphertext)); + + // Decrypt with store2 — XOR will produce garbage bytes + // This may fail with UTF-8 error or succeed with garbage plaintext + match store2.decrypt_and_migrate(&legacy_value) { + Ok((decrypted, _)) => { + // If it succeeds, the plaintext should be garbage (not the original) + assert_ne!( + decrypted, plaintext, + "Wrong key should produce garbage plaintext" + ); + } + Err(e) => { + // 
Expected: UTF-8 decoding failure from garbage bytes + assert!( + e.to_string().contains("UTF-8"), + "Error should be UTF-8 related: {e}" + ); + } + } + } + + #[test] + fn migration_produces_different_ciphertext_each_time() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + let _ = store.encrypt("setup").unwrap(); + let key = store.load_or_create_key().unwrap(); + + let plaintext = "sk-same-secret"; + let ciphertext = xor_cipher(plaintext.as_bytes(), &key); + let legacy_value = format!("enc:{}", hex_encode(&ciphertext)); + + let (_, migrated1) = store.decrypt_and_migrate(&legacy_value).unwrap(); + let (_, migrated2) = store.decrypt_and_migrate(&legacy_value).unwrap(); + + assert!(migrated1.is_some()); + assert!(migrated2.is_some()); + assert_ne!( + migrated1.unwrap(), + migrated2.unwrap(), + "Each migration should produce different ciphertext (random nonce)" + ); + } + + #[test] + fn migrated_value_is_tamper_resistant() { + let tmp = TempDir::new().unwrap(); + let store = SecretStore::new(tmp.path(), true); + + let _ = store.encrypt("setup").unwrap(); + let key = store.load_or_create_key().unwrap(); + + let plaintext = "sk-sensitive-data"; + let ciphertext = xor_cipher(plaintext.as_bytes(), &key); + let legacy_value = format!("enc:{}", hex_encode(&ciphertext)); + + let (_, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap(); + let new_value = migrated.unwrap(); + + // Tamper with the migrated value + let hex_str = &new_value[5..]; + let mut blob = hex_decode(hex_str).unwrap(); + if blob.len() > NONCE_LEN { + blob[NONCE_LEN] ^= 0xff; + } + let tampered = format!("enc2:{}", hex_encode(&blob)); + + let result = store.decrypt_and_migrate(&tampered); + assert!(result.is_err(), "Tampered migrated value must be rejected"); + } + // ── Low-level helpers ─────────────────────────────────────── #[test] diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 0000000..fc6bf51 --- /dev/null +++ 
b/src/service/mod.rs @@ -0,0 +1,284 @@ +use crate::config::Config; +use anyhow::{Context, Result}; +use std::fs; +use std::path::PathBuf; +use std::process::Command; + +const SERVICE_LABEL: &str = "com.zeroclaw.daemon"; + +pub fn handle_command(command: super::ServiceCommands, config: &Config) -> Result<()> { + match command { + super::ServiceCommands::Install => install(config), + super::ServiceCommands::Start => start(config), + super::ServiceCommands::Stop => stop(config), + super::ServiceCommands::Status => status(config), + super::ServiceCommands::Uninstall => uninstall(config), + } +} + +fn install(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + install_macos(config) + } else if cfg!(target_os = "linux") { + install_linux(config) + } else { + anyhow::bail!("Service management is supported on macOS and Linux only"); + } +} + +fn start(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + let plist = macos_service_file()?; + run_checked(Command::new("launchctl").arg("load").arg("-w").arg(&plist))?; + run_checked(Command::new("launchctl").arg("start").arg(SERVICE_LABEL))?; + println!("✅ Service started"); + Ok(()) + } else if cfg!(target_os = "linux") { + run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]))?; + run_checked(Command::new("systemctl").args(["--user", "start", "zeroclaw.service"]))?; + println!("✅ Service started"); + Ok(()) + } else { + let _ = config; + anyhow::bail!("Service management is supported on macOS and Linux only") + } +} + +fn stop(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + let plist = macos_service_file()?; + let _ = run_checked(Command::new("launchctl").arg("stop").arg(SERVICE_LABEL)); + let _ = run_checked( + Command::new("launchctl") + .arg("unload") + .arg("-w") + .arg(&plist), + ); + println!("✅ Service stopped"); + Ok(()) + } else if cfg!(target_os = "linux") { + let _ = run_checked(Command::new("systemctl").args(["--user", "stop", 
"zeroclaw.service"])); + println!("✅ Service stopped"); + Ok(()) + } else { + let _ = config; + anyhow::bail!("Service management is supported on macOS and Linux only") + } +} + +fn status(config: &Config) -> Result<()> { + if cfg!(target_os = "macos") { + let out = run_capture(Command::new("launchctl").arg("list"))?; + let running = out.lines().any(|line| line.contains(SERVICE_LABEL)); + println!( + "Service: {}", + if running { + "✅ running/loaded" + } else { + "❌ not loaded" + } + ); + println!("Unit: {}", macos_service_file()?.display()); + return Ok(()); + } + + if cfg!(target_os = "linux") { + let out = run_capture(Command::new("systemctl").args([ + "--user", + "is-active", + "zeroclaw.service", + ])) + .unwrap_or_else(|_| "unknown".into()); + println!("Service state: {}", out.trim()); + println!("Unit: {}", linux_service_file(config)?.display()); + return Ok(()); + } + + anyhow::bail!("Service management is supported on macOS and Linux only") +} + +fn uninstall(config: &Config) -> Result<()> { + stop(config)?; + + if cfg!(target_os = "macos") { + let file = macos_service_file()?; + if file.exists() { + fs::remove_file(&file) + .with_context(|| format!("Failed to remove {}", file.display()))?; + } + println!("✅ Service uninstalled ({})", file.display()); + return Ok(()); + } + + if cfg!(target_os = "linux") { + let file = linux_service_file(config)?; + if file.exists() { + fs::remove_file(&file) + .with_context(|| format!("Failed to remove {}", file.display()))?; + } + let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"])); + println!("✅ Service uninstalled ({})", file.display()); + return Ok(()); + } + + anyhow::bail!("Service management is supported on macOS and Linux only") +} + +fn install_macos(config: &Config) -> Result<()> { + let file = macos_service_file()?; + if let Some(parent) = file.parent() { + fs::create_dir_all(parent)?; + } + + let exe = std::env::current_exe().context("Failed to resolve current executable")?; + let 
logs_dir = config + .config_path + .parent() + .map_or_else(|| PathBuf::from("."), PathBuf::from) + .join("logs"); + fs::create_dir_all(&logs_dir)?; + + let stdout = logs_dir.join("daemon.stdout.log"); + let stderr = logs_dir.join("daemon.stderr.log"); + + let plist = format!( + r#" + + + + Label + {label} + ProgramArguments + + {exe} + daemon + + RunAtLoad + + KeepAlive + + StandardOutPath + {stdout} + StandardErrorPath + {stderr} + + +"#, + label = SERVICE_LABEL, + exe = xml_escape(&exe.display().to_string()), + stdout = xml_escape(&stdout.display().to_string()), + stderr = xml_escape(&stderr.display().to_string()) + ); + + fs::write(&file, plist)?; + println!("✅ Installed launchd service: {}", file.display()); + println!(" Start with: zeroclaw service start"); + Ok(()) +} + +fn install_linux(config: &Config) -> Result<()> { + let file = linux_service_file(config)?; + if let Some(parent) = file.parent() { + fs::create_dir_all(parent)?; + } + + let exe = std::env::current_exe().context("Failed to resolve current executable")?; + let unit = format!( + "[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n", + exe.display() + ); + + fs::write(&file, unit)?; + let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"])); + let _ = run_checked(Command::new("systemctl").args(["--user", "enable", "zeroclaw.service"])); + println!("✅ Installed systemd user service: {}", file.display()); + println!(" Start with: zeroclaw service start"); + Ok(()) +} + +fn macos_service_file() -> Result { + let home = directories::UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + Ok(home + .join("Library") + .join("LaunchAgents") + .join(format!("{SERVICE_LABEL}.plist"))) +} + +fn linux_service_file(config: &Config) -> Result { + let home = directories::UserDirs::new() + .map(|u| 
/// Escape the five XML-reserved characters so arbitrary filesystem paths can
/// be embedded safely in the generated launchd plist.
///
/// `&` MUST be replaced first; otherwise the `&` inside entities inserted by
/// the later replacements would be double-escaped.
fn xml_escape(raw: &str) -> String {
    raw.replace('&', "&amp;")
        .replace('<', "&lt;")
        .replace('>', "&gt;")
        .replace('"', "&quot;")
        .replace('\'', "&apos;")
}
/// Recursively copy a directory tree (fallback for platforms where the
/// skills directory cannot be symlinked).
///
/// Uses `DirEntry::file_type()` — which does NOT follow symlinks — instead of
/// `Path::is_dir()`, which does. This prevents infinite recursion when the
/// source tree contains a symlink cycle: a directory symlink is handled by
/// the file branch rather than recursed into.
#[cfg(any(windows, not(unix)))]
fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<()> {
    std::fs::create_dir_all(dest)?;
    for entry in std::fs::read_dir(src)? {
        let entry = entry?;
        let src_path = entry.path();
        let dest_path = dest.join(entry.file_name());
        if entry.file_type()?.is_dir() {
            copy_dir_recursive(&src_path, &dest_path)?;
        } else {
            std::fs::copy(&src_path, &dest_path)?;
        }
    }
    Ok(())
}
Copy the skill directory manually."); + std::os::unix::fs::symlink(&src, &dest)?; + println!( + " {} Skill linked: {}", + console::style("✓").green().bold(), + dest.display() + ); } + #[cfg(windows)] + { + // On Windows, try symlink first (requires admin or developer mode), + // fall back to directory junction, then copy + use std::os::windows::fs::symlink_dir; + if symlink_dir(&src, &dest).is_ok() { + println!( + " {} Skill linked: {}", + console::style("✓").green().bold(), + dest.display() + ); + } else { + // Try junction as fallback (works without admin) + let junction_result = std::process::Command::new("cmd") + .args(["/C", "mklink", "/J"]) + .arg(&dest) + .arg(&src) + .output(); - println!( - " {} Skill linked: {}", - console::style("✓").green().bold(), - dest.display() - ); + if junction_result.is_ok() && junction_result.unwrap().status.success() { + println!( + " {} Skill linked (junction): {}", + console::style("✓").green().bold(), + dest.display() + ); + } else { + // Final fallback: copy the directory + copy_dir_recursive(&src, &dest)?; + println!( + " {} Skill copied: {}", + console::style("✓").green().bold(), + dest.display() + ); + } + } + } + #[cfg(not(any(unix, windows)))] + { + // On other platforms, copy the directory + copy_dir_recursive(&src, &dest)?; + println!( + " {} Skill copied: {}", + console::style("✓").green().bold(), + dest.display() + ); + } } Ok(()) @@ -631,3 +690,6 @@ description = "Bare minimum" assert_eq!(skills[0].name, "from-toml"); // TOML takes priority } } + +#[cfg(test)] +mod symlink_tests; diff --git a/src/skills/symlink_tests.rs b/src/skills/symlink_tests.rs new file mode 100644 index 0000000..5968174 --- /dev/null +++ b/src/skills/symlink_tests.rs @@ -0,0 +1,109 @@ +#[cfg(test)] +mod symlink_tests { + use crate::skills::skills_dir; + use std::path::Path; + use tempfile::TempDir; + + #[test] + fn test_skills_symlink_unix_edge_cases() { + let tmp = TempDir::new().unwrap(); + let workspace_dir = tmp.path().join("workspace"); 
+ std::fs::create_dir_all(&workspace_dir).unwrap(); + + let skills_path = skills_dir(&workspace_dir); + std::fs::create_dir_all(&skills_path).unwrap(); + + // Test case 1: Valid symlink creation on Unix + #[cfg(unix)] + { + let source_dir = tmp.path().join("source_skill"); + std::fs::create_dir_all(&source_dir).unwrap(); + std::fs::write(source_dir.join("SKILL.md"), "# Test Skill\nContent").unwrap(); + + let dest_link = skills_path.join("linked_skill"); + + // Create symlink + let result = std::os::unix::fs::symlink(&source_dir, &dest_link); + assert!(result.is_ok(), "Symlink creation should succeed"); + + // Verify symlink works + assert!(dest_link.exists()); + assert!(dest_link.is_symlink()); + + // Verify we can read through symlink + let content = std::fs::read_to_string(dest_link.join("SKILL.md")); + assert!(content.is_ok()); + assert!(content.unwrap().contains("Test Skill")); + + // Test case 2: Symlink to non-existent target should fail gracefully + let broken_link = skills_path.join("broken_skill"); + let non_existent = tmp.path().join("non_existent"); + let result = std::os::unix::fs::symlink(&non_existent, &broken_link); + assert!( + result.is_ok(), + "Symlink creation should succeed even if target doesn't exist" + ); + + // But reading through it should fail + let content = std::fs::read_to_string(broken_link.join("SKILL.md")); + assert!(content.is_err()); + } + + // Test case 3: Non-Unix platforms should handle symlink errors gracefully + #[cfg(not(unix))] + { + let source_dir = tmp.path().join("source_skill"); + std::fs::create_dir_all(&source_dir).unwrap(); + + let dest_link = skills_path.join("linked_skill"); + + // Symlink should fail on non-Unix + let result = std::os::unix::fs::symlink(&source_dir, &dest_link); + assert!(result.is_err()); + + // Directory should not exist + assert!(!dest_link.exists()); + } + + // Test case 4: skills_dir function edge cases + let workspace_with_trailing_slash = format!("{}/", workspace_dir.display()); + let 
path_from_str = skills_dir(Path::new(&workspace_with_trailing_slash)); + assert_eq!(path_from_str, skills_path); + + // Test case 5: Empty workspace directory + let empty_workspace = tmp.path().join("empty"); + let empty_skills_path = skills_dir(&empty_workspace); + assert_eq!(empty_skills_path, empty_workspace.join("skills")); + assert!(!empty_skills_path.exists()); + } + + #[test] + fn test_skills_symlink_permissions_and_safety() { + let tmp = TempDir::new().unwrap(); + let workspace_dir = tmp.path().join("workspace"); + std::fs::create_dir_all(&workspace_dir).unwrap(); + + let skills_path = skills_dir(&workspace_dir); + std::fs::create_dir_all(&skills_path).unwrap(); + + #[cfg(unix)] + { + // Test case: Symlink outside workspace should be allowed (user responsibility) + let outside_dir = tmp.path().join("outside_skill"); + std::fs::create_dir_all(&outside_dir).unwrap(); + std::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent").unwrap(); + + let dest_link = skills_path.join("outside_skill"); + let result = std::os::unix::fs::symlink(&outside_dir, &dest_link); + assert!( + result.is_ok(), + "Should allow symlinking to directories outside workspace" + ); + + // Should still be readable + let content = std::fs::read_to_string(dest_link.join("SKILL.md")); + assert!(content.is_ok()); + assert!(content.unwrap().contains("Outside Skill")); + } + } +} diff --git a/src/tools/file_read.rs b/src/tools/file_read.rs index 1798d2d..97c46e0 100644 --- a/src/tools/file_read.rs +++ b/src/tools/file_read.rs @@ -55,7 +55,30 @@ impl Tool for FileReadTool { let full_path = self.security.workspace_dir.join(path); - match tokio::fs::read_to_string(&full_path).await { + // Resolve path before reading to block symlink escapes. 
+ let resolved_path = match tokio::fs::canonicalize(&full_path).await { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve file path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_path) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Resolved path escapes workspace: {}", + resolved_path.display() + )), + }); + } + + match tokio::fs::read_to_string(&resolved_path).await { Ok(contents) => Ok(ToolResult { success: true, output: contents, @@ -127,7 +150,7 @@ mod tests { let tool = FileReadTool::new(test_security(dir.clone())); let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap(); assert!(!result.success); - assert!(result.error.as_ref().unwrap().contains("Failed to read")); + assert!(result.error.as_ref().unwrap().contains("Failed to resolve")); let _ = tokio::fs::remove_dir_all(&dir).await; } @@ -200,4 +223,36 @@ mod tests { let _ = tokio::fs::remove_dir_all(&dir).await; } + + #[cfg(unix)] + #[tokio::test] + async fn file_read_blocks_symlink_escape() { + use std::os::unix::fs::symlink; + + let root = std::env::temp_dir().join("zeroclaw_test_file_read_symlink_escape"); + let workspace = root.join("workspace"); + let outside = root.join("outside"); + + let _ = tokio::fs::remove_dir_all(&root).await; + tokio::fs::create_dir_all(&workspace).await.unwrap(); + tokio::fs::create_dir_all(&outside).await.unwrap(); + + tokio::fs::write(outside.join("secret.txt"), "outside workspace") + .await + .unwrap(); + + symlink(outside.join("secret.txt"), workspace.join("escape.txt")).unwrap(); + + let tool = FileReadTool::new(test_security(workspace.clone())); + let result = tool.execute(json!({"path": "escape.txt"})).await.unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("escapes workspace")); + + let _ = tokio::fs::remove_dir_all(&root).await; + 
} } diff --git a/src/tools/file_write.rs b/src/tools/file_write.rs index f31191d..f147497 100644 --- a/src/tools/file_write.rs +++ b/src/tools/file_write.rs @@ -69,7 +69,54 @@ impl Tool for FileWriteTool { tokio::fs::create_dir_all(parent).await?; } - match tokio::fs::write(&full_path, content).await { + let parent = match full_path.parent() { + Some(p) => p, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Invalid path: missing parent directory".into()), + }); + } + }; + + // Resolve parent before writing to block symlink escapes. + let resolved_parent = match tokio::fs::canonicalize(parent).await { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to resolve file path: {e}")), + }); + } + }; + + if !self.security.is_resolved_path_allowed(&resolved_parent) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Resolved path escapes workspace: {}", + resolved_parent.display() + )), + }); + } + + let file_name = match full_path.file_name() { + Some(name) => name, + None => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Invalid path: missing file name".into()), + }); + } + }; + + let resolved_target = resolved_parent.join(file_name); + + match tokio::fs::write(&resolved_target, content).await { Ok(()) => Ok(ToolResult { success: true, output: format!("Written {} bytes to {path}", content.len()), @@ -239,4 +286,36 @@ mod tests { let _ = tokio::fs::remove_dir_all(&dir).await; } + + #[cfg(unix)] + #[tokio::test] + async fn file_write_blocks_symlink_escape() { + use std::os::unix::fs::symlink; + + let root = std::env::temp_dir().join("zeroclaw_test_file_write_symlink_escape"); + let workspace = root.join("workspace"); + let outside = root.join("outside"); + + let _ = tokio::fs::remove_dir_all(&root).await; + tokio::fs::create_dir_all(&workspace).await.unwrap(); + 
tokio::fs::create_dir_all(&outside).await.unwrap(); + + symlink(&outside, workspace.join("escape_dir")).unwrap(); + + let tool = FileWriteTool::new(test_security(workspace.clone())); + let result = tool + .execute(json!({"path": "escape_dir/hijack.txt", "content": "bad"})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result + .error + .as_deref() + .unwrap_or("") + .contains("escapes workspace")); + assert!(!outside.join("hijack.txt").exists()); + + let _ = tokio::fs::remove_dir_all(&root).await; + } } diff --git a/tests/dockerignore_test.rs b/tests/dockerignore_test.rs new file mode 100644 index 0000000..e90828c --- /dev/null +++ b/tests/dockerignore_test.rs @@ -0,0 +1,327 @@ +//! Tests to verify .dockerignore excludes sensitive paths from Docker build context. +//! +//! These tests validate that: +//! 1. The .dockerignore file exists +//! 2. All security-critical paths are excluded +//! 3. All build-essential paths are NOT excluded +//! 4. Pattern syntax is valid + +use std::fs; +use std::path::Path; + +/// Paths that MUST be excluded from Docker build context (security/performance) +const MUST_EXCLUDE: &[&str] = &[ + ".git", + ".githooks", + "target", + "docs", + "examples", + "tests", + "*.md", + "*.png", + "*.db", + "*.db-journal", + ".DS_Store", + ".github", + "deny.toml", + "LICENSE", + ".env", + ".tmp_*", +]; + +/// Paths that MUST NOT be excluded (required for build) +const MUST_INCLUDE: &[&str] = &["Cargo.toml", "Cargo.lock", "src/"]; + +/// Parse .dockerignore and return all non-comment, non-empty lines +fn parse_dockerignore(content: &str) -> Vec<String> { + content + .lines() + .map(|line| line.trim()) + .filter(|line| !line.is_empty() && !line.starts_with('#')) + .map(|line| line.to_string()) + .collect() +} + +/// Check if a pattern would match a given path +fn pattern_matches(pattern: &str, path: &str) -> bool { + // Handle negation patterns + if pattern.starts_with('!') { + return false; // Negation re-includes, so it doesn't "exclude" + } 
+ + // Handle glob patterns + if pattern.starts_with("*.") { + let ext = &pattern[1..]; // e.g., ".md" + return path.ends_with(ext); + } + + // Handle directory patterns (with or without trailing slash) + let pattern_normalized = pattern.trim_end_matches('/'); + let path_normalized = path.trim_end_matches('/'); + + // Exact match + if path_normalized == pattern_normalized { + return true; + } + + // Pattern is a prefix (directory match) + if path_normalized.starts_with(&format!("{}/", pattern_normalized)) { + return true; + } + + // Wildcard prefix patterns like ".tmp_*" + if pattern.contains('*') && !pattern.starts_with("*.") { + let prefix = pattern.split('*').next().unwrap_or(""); + if !prefix.is_empty() && path.starts_with(prefix) { + return true; + } + } + + false +} + +/// Check if any pattern in the list would exclude the given path +fn is_excluded(patterns: &[String], path: &str) -> bool { + let mut excluded = false; + for pattern in patterns { + if pattern.starts_with('!') { + // Negation pattern - re-include + let negated = &pattern[1..]; + if pattern_matches(negated, path) { + excluded = false; + } + } else if pattern_matches(pattern, path) { + excluded = true; + } + } + excluded +} + +#[test] +fn dockerignore_file_exists() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + assert!( + path.exists(), + ".dockerignore file must exist at project root" + ); +} + +#[test] +fn dockerignore_excludes_security_critical_paths() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + for must_exclude in MUST_EXCLUDE { + // For glob patterns, test with a sample file + let test_path = if must_exclude.starts_with("*.") { + format!("sample{}", &must_exclude[1..]) + } else { + must_exclude.to_string() + }; + + assert!( + is_excluded(&patterns, &test_path), + "Path '{}' (tested as '{}') MUST be 
excluded by .dockerignore but is not. \ + This is a security/performance issue.", + must_exclude, + test_path + ); + } +} + +#[test] +fn dockerignore_does_not_exclude_build_essentials() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + for must_include in MUST_INCLUDE { + assert!( + !is_excluded(&patterns, must_include), + "Path '{}' MUST NOT be excluded by .dockerignore (required for build)", + must_include + ); + } +} + +#[test] +fn dockerignore_excludes_git_directory() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + // .git directory and its contents must be excluded + assert!(is_excluded(&patterns, ".git"), ".git must be excluded"); + assert!( + is_excluded(&patterns, ".git/config"), + ".git/config must be excluded" + ); + assert!( + is_excluded(&patterns, ".git/objects/pack/pack-abc123.pack"), + ".git subdirectories must be excluded" + ); +} + +#[test] +fn dockerignore_excludes_target_directory() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + assert!(is_excluded(&patterns, "target"), "target must be excluded"); + assert!( + is_excluded(&patterns, "target/debug/zeroclaw"), + "target/debug must be excluded" + ); + assert!( + is_excluded(&patterns, "target/release/zeroclaw"), + "target/release must be excluded" + ); +} + +#[test] +fn dockerignore_excludes_database_files() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + assert!( + 
is_excluded(&patterns, "brain.db"), + "*.db files must be excluded" + ); + assert!( + is_excluded(&patterns, "memory.db"), + "*.db files must be excluded" + ); + assert!( + is_excluded(&patterns, "brain.db-journal"), + "*.db-journal files must be excluded" + ); +} + +#[test] +fn dockerignore_excludes_markdown_files() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + assert!( + is_excluded(&patterns, "README.md"), + "*.md files must be excluded" + ); + assert!( + is_excluded(&patterns, "CHANGELOG.md"), + "*.md files must be excluded" + ); + assert!( + is_excluded(&patterns, "CONTRIBUTING.md"), + "*.md files must be excluded" + ); +} + +#[test] +fn dockerignore_excludes_image_files() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + assert!( + is_excluded(&patterns, "zeroclaw.png"), + "*.png files must be excluded" + ); + assert!( + is_excluded(&patterns, "logo.png"), + "*.png files must be excluded" + ); +} + +#[test] +fn dockerignore_excludes_env_files() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + assert!( + is_excluded(&patterns, ".env"), + ".env must be excluded (contains secrets)" + ); +} + +#[test] +fn dockerignore_excludes_ci_configs() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let patterns = parse_dockerignore(&content); + + assert!( + is_excluded(&patterns, ".github"), + ".github must be excluded" + ); + assert!( + is_excluded(&patterns, ".github/workflows/ci.yml"), + 
".github/workflows must be excluded" + ); +} + +#[test] +fn dockerignore_has_valid_syntax() { + let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); + let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + + for (line_num, line) in content.lines().enumerate() { + let trimmed = line.trim(); + + // Skip empty lines and comments + if trimmed.is_empty() || trimmed.starts_with('#') { + continue; + } + + // Check for invalid patterns + assert!( + !trimmed.contains("**") || trimmed.matches("**").count() <= 2, + "Line {}: Too many ** in pattern '{}'", + line_num + 1, + trimmed + ); + + // Reject leading whitespace (silently changes what the pattern matches) + assert!( + line.trim_end() == line.trim_start().trim_end(), + "Line {}: Pattern '{}' has leading whitespace which may cause issues", + line_num + 1, + line + ); + } +} + +#[test] +fn dockerignore_pattern_matching_edge_cases() { + // Test the pattern matching logic itself + let patterns = vec![ + ".git".to_string(), + ".githooks".to_string(), + "target".to_string(), + "*.md".to_string(), + "*.db".to_string(), + ".tmp_*".to_string(), + ".env".to_string(), + ]; + + // Should match + assert!(is_excluded(&patterns, ".git")); + assert!(is_excluded(&patterns, ".git/config")); + assert!(is_excluded(&patterns, ".githooks")); + assert!(is_excluded(&patterns, "target")); + assert!(is_excluded(&patterns, "target/debug/build")); + assert!(is_excluded(&patterns, "README.md")); + assert!(is_excluded(&patterns, "brain.db")); + assert!(is_excluded(&patterns, ".tmp_todo_probe")); + assert!(is_excluded(&patterns, ".env")); + + // Should NOT match + assert!(!is_excluded(&patterns, "src")); + assert!(!is_excluded(&patterns, "src/main.rs")); + assert!(!is_excluded(&patterns, "Cargo.toml")); + assert!(!is_excluded(&patterns, "Cargo.lock")); +}