merge: resolve conflicts between feat/whatsapp-email-channels and main

- Keep main's WhatsApp implementation (webhook-based, simpler)
- Preserve email channel fixes from our branch
- Merge all main branch updates (daemon, cron, health, etc.)
- Resolve Cargo.lock conflicts
This commit is contained in:
argenis de la rosa 2026-02-14 14:59:16 -05:00
commit 4e6da51924
40 changed files with 6925 additions and 780 deletions

66
.dockerignore Normal file
View file

@ -0,0 +1,66 @@
# Git history (may contain old secrets)
.git
.gitignore
.githooks
# Rust build artifacts (can be multiple GB)
target
# Documentation and examples (not needed for runtime)
docs
examples
tests
# Markdown files (README, CHANGELOG, etc.)
*.md
# Images (unnecessary for build)
*.png
*.svg
*.jpg
*.jpeg
*.gif
# SQLite databases (conversation history, cron jobs)
*.db
*.db-journal
# macOS artifacts
.DS_Store
.AppleDouble
.LSOverride
# CI/CD configs (not needed in image)
.github
# Cargo deny config (lint tool, not runtime)
deny.toml
# License file (not needed for runtime)
LICENSE
# Temporary files
.tmp_*
*.tmp
*.bak
*.swp
*~
# IDE and editor configs
.idea
.vscode
*.iml
# Windsurf workflows
.windsurf
# Environment files (may contain secrets)
.env
.env.*
!.env.example
# Coverage and profiling
*.profraw
*.profdata
coverage
lcov.info

View file

@ -63,3 +63,40 @@ jobs:
with: with:
name: zeroclaw-${{ matrix.target }} name: zeroclaw-${{ matrix.target }}
path: target/${{ matrix.target }}/release/zeroclaw* path: target/${{ matrix.target }}/release/zeroclaw*
docker:
name: Docker Security
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- name: Build Docker image
run: docker build -t zeroclaw:test .
- name: Verify non-root user (UID != 0)
run: |
USER_ID=$(docker inspect --format='{{.Config.User}}' zeroclaw:test)
echo "Container user: $USER_ID"
if [ "$USER_ID" = "0" ] || [ "$USER_ID" = "root" ] || [ -z "$USER_ID" ]; then
echo "❌ FAIL: Container runs as root (UID 0)"
exit 1
fi
echo "✅ PASS: Container runs as non-root user ($USER_ID)"
- name: Verify distroless nonroot base image
run: |
BASE_IMAGE=$(grep -E '^FROM.*runtime|^FROM gcr.io/distroless' Dockerfile | tail -1)
echo "Base image line: $BASE_IMAGE"
if ! echo "$BASE_IMAGE" | grep -q ':nonroot'; then
echo "❌ FAIL: Runtime stage does not use :nonroot variant"
exit 1
fi
echo "✅ PASS: Using distroless :nonroot variant"
- name: Verify USER directive exists
run: |
if ! grep -qE '^USER\s+[0-9]+' Dockerfile; then
echo "❌ FAIL: No explicit USER directive with numeric UID"
exit 1
fi
echo "✅ PASS: Explicit USER directive found"

0
.tmp_todo_probe Normal file
View file

View file

@ -5,6 +5,24 @@ All notable changes to ZeroClaw will be documented in this file.
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/), The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
## [Unreleased]
### Security
- **Legacy XOR cipher migration**: The `enc:` prefix (XOR cipher) is now deprecated.
Secrets using this format will be automatically migrated to `enc2:` (ChaCha20-Poly1305 AEAD)
when decrypted via `decrypt_and_migrate()`. A `tracing::warn!` is emitted when legacy
values are encountered. The XOR cipher will be removed in a future release.
### Added
- `SecretStore::decrypt_and_migrate()` — Decrypts secrets and returns a migrated `enc2:`
value if the input used the legacy `enc:` format
- `SecretStore::needs_migration()` — Check if a value uses the legacy `enc:` format
- `SecretStore::is_secure_encrypted()` — Check if a value uses the secure `enc2:` format
### Deprecated
- `enc:` prefix for encrypted secrets — Use `enc2:` (ChaCha20-Poly1305) instead.
Legacy values are still decrypted for backward compatibility but should be migrated.
## [0.1.0] - 2025-02-13 ## [0.1.0] - 2025-02-13
### Added ### Added

View file

@ -15,7 +15,7 @@ categories = ["command-line-utilities", "api-bindings"]
clap = { version = "4.5", features = ["derive"] } clap = { version = "4.5", features = ["derive"] }
# Async runtime - feature-optimized for size # Async runtime - feature-optimized for size
tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs"] } tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs", "signal"] }
# HTTP client - minimal features # HTTP client - minimal features
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking"] } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking"] }
@ -49,6 +49,7 @@ async-trait = "0.1"
# Memory / persistence # Memory / persistence
rusqlite = { version = "0.32", features = ["bundled"] } rusqlite = { version = "0.32", features = ["bundled"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] } chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }
cron = "0.12"
# Interactive CLI prompts # Interactive CLI prompts
dialoguer = { version = "0.11", features = ["fuzzy-select"] } dialoguer = { version = "0.11", features = ["fuzzy-select"] }
@ -64,6 +65,12 @@ rustls-pki-types = "1.14.0"
tokio-rustls = "0.26.4" tokio-rustls = "0.26.4"
webpki-roots = "1.0.6" webpki-roots = "1.0.6"
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
axum = { version = "0.7", default-features = false, features = ["http1", "json", "tokio", "query"] }
tower = { version = "0.5", default-features = false }
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
http-body-util = "0.1"
[profile.release] [profile.release]
opt-level = "z" # Optimize for size opt-level = "z" # Optimize for size
lto = true # Link-time optimization lto = true # Link-time optimization

View file

@ -8,14 +8,17 @@ COPY src/ src/
RUN cargo build --release --locked && \ RUN cargo build --release --locked && \
strip target/release/zeroclaw strip target/release/zeroclaw
# ── Stage 2: Runtime (distroless — no shell, no OS, tiny) ──── # ── Stage 2: Runtime (distroless nonroot — no shell, no OS, tiny, UID 65534) ──
FROM gcr.io/distroless/cc-debian12 FROM gcr.io/distroless/cc-debian12:nonroot
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
# Default workspace # Default workspace (owned by nonroot user)
VOLUME ["/workspace"] VOLUME ["/workspace"]
ENV ZEROCLAW_WORKSPACE=/workspace ENV ZEROCLAW_WORKSPACE=/workspace
# Explicitly set non-root user (distroless:nonroot defaults to 65534, but be explicit)
USER 65534:65534
ENTRYPOINT ["zeroclaw"] ENTRYPOINT ["zeroclaw"]
CMD ["gateway"] CMD ["gateway"]

212
README.md
View file

@ -12,12 +12,19 @@
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a> <a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
</p> </p>
The fastest, smallest, fully autonomous AI assistant — deploy anywhere, swap anything. Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
``` ```
~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything ~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything
``` ```
### Why teams pick ZeroClaw
- **Lean by default:** small Rust binary, fast startup, low memory footprint.
- **Secure by design:** pairing, strict sandboxing, explicit allowlists, workspace scoping.
- **Fully swappable:** core systems are traits (providers, channels, tools, memory, tunnels).
- **No lock-in:** OpenAI-compatible provider support + pluggable custom endpoints.
## Benchmark Snapshot (ZeroClaw vs OpenClaw) ## Benchmark Snapshot (ZeroClaw vs OpenClaw)
Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each. Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each.
@ -30,7 +37,17 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each.
| `--help` max RSS observed | **~7.3 MB** | **~394 MB** | | `--help` max RSS observed | **~7.3 MB** | **~394 MB** |
| `status` max RSS observed | **~7.8 MB** | **~1.52 GB** | | `status` max RSS observed | **~7.8 MB** | **~1.52 GB** |
> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results include `pnpm install` + `pnpm build` before execution. > Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results were measured after `pnpm install` + `pnpm build`.
Reproduce ZeroClaw numbers locally:
```bash
cargo build --release
ls -lh target/release/zeroclaw
/usr/bin/time -l target/release/zeroclaw --help
/usr/bin/time -l target/release/zeroclaw status
```
## Quick Start ## Quick Start
@ -38,34 +55,52 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each.
git clone https://github.com/theonlyhennygod/zeroclaw.git git clone https://github.com/theonlyhennygod/zeroclaw.git
cd zeroclaw cd zeroclaw
cargo build --release cargo build --release
cargo install --path . --force
# Quick setup (no prompts) # Quick setup (no prompts)
cargo run --release -- onboard --api-key sk-... --provider openrouter zeroclaw onboard --api-key sk-... --provider openrouter
# Or interactive wizard # Or interactive wizard
cargo run --release -- onboard --interactive zeroclaw onboard --interactive
# Or quickly repair channels/allowlists only
zeroclaw onboard --channels-only
# Chat # Chat
cargo run --release -- agent -m "Hello, ZeroClaw!" zeroclaw agent -m "Hello, ZeroClaw!"
# Interactive mode # Interactive mode
cargo run --release -- agent zeroclaw agent
# Start the gateway (webhook server) # Start the gateway (webhook server)
cargo run --release -- gateway # default: 127.0.0.1:8080 zeroclaw gateway # default: 127.0.0.1:8080
cargo run --release -- gateway --port 0 # random port (security hardened) zeroclaw gateway --port 0 # random port (security hardened)
# Start full autonomous runtime
zeroclaw daemon
# Check status # Check status
cargo run --release -- status zeroclaw status
# Run system diagnostics
zeroclaw doctor
# Check channel health # Check channel health
cargo run --release -- channel doctor zeroclaw channel doctor
# Get integration setup details # Get integration setup details
cargo run --release -- integrations info Telegram zeroclaw integrations info Telegram
# Manage background service
zeroclaw service install
zeroclaw service status
# Migrate memory from OpenClaw (safe preview first)
zeroclaw migrate openclaw --dry-run
zeroclaw migrate openclaw
``` ```
> **Tip:** Run `cargo install --path .` to install `zeroclaw` globally, then use `zeroclaw` instead of `cargo run --release --`. > **Dev fallback (no global install):** prefix commands with `cargo run --release --` (example: `cargo run --release -- status`).
## Architecture ## Architecture
@ -78,17 +113,25 @@ Every subsystem is a **trait** — swap implementations with a config change, ze
| Subsystem | Trait | Ships with | Extend | | Subsystem | Trait | Ships with | Extend |
|-----------|-------|------------|--------| |-----------|-------|------------|--------|
| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API | | **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, Webhook | Any messaging API | | **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Markdown | Any persistence backend | | **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Markdown | Any persistence backend |
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), composio (optional) | Any capability | | **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), composio (optional) | Any capability |
| **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel | | **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel |
| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM | | **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM (planned; unsupported kinds fail fast) |
| **Security** | `SecurityPolicy` | Gateway pairing, sandbox, allowlists, rate limits, filesystem scoping, encrypted secrets | — | | **Security** | `SecurityPolicy` | Gateway pairing, sandbox, allowlists, rate limits, filesystem scoping, encrypted secrets | — |
| **Identity** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | Any identity format |
| **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Any tunnel binary | | **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Any tunnel binary |
| **Heartbeat** | Engine | HEARTBEAT.md periodic tasks | — | | **Heartbeat** | Engine | HEARTBEAT.md periodic tasks | — |
| **Skills** | Loader | TOML manifests + SKILL.md instructions | Community skill packs | | **Skills** | Loader | TOML manifests + SKILL.md instructions | Community skill packs |
| **Integrations** | Registry | 50+ integrations across 9 categories | Plugin system | | **Integrations** | Registry | 50+ integrations across 9 categories | Plugin system |
### Runtime support (current)
- ✅ Supported today: `runtime.kind = "native"`
- 🚧 Planned, not implemented yet: Docker / WASM / edge runtimes
When an unsupported `runtime.kind` is configured, ZeroClaw now exits with a clear error instead of silently falling back to native.
### Memory System (Full-Stack Search Engine) ### Memory System (Full-Stack Search Engine)
All custom, zero external dependencies — no Pinecone, no Elasticsearch, no LangChain: All custom, zero external dependencies — no Pinecone, no Elasticsearch, no LangChain:
@ -124,7 +167,7 @@ ZeroClaw enforces security at **every layer** — not just the sandbox. It passe
|---|------|--------|-----| |---|------|--------|-----|
| 1 | **Gateway not publicly exposed** | ✅ | Binds `127.0.0.1` by default. Refuses `0.0.0.0` without tunnel or explicit `allow_public_bind = true`. | | 1 | **Gateway not publicly exposed** | ✅ | Binds `127.0.0.1` by default. Refuses `0.0.0.0` without tunnel or explicit `allow_public_bind = true`. |
| 2 | **Pairing required** | ✅ | 6-digit one-time code on startup. Exchange via `POST /pair` for bearer token. All `/webhook` requests require `Authorization: Bearer <token>`. | | 2 | **Pairing required** | ✅ | 6-digit one-time code on startup. Exchange via `POST /pair` for bearer token. All `/webhook` requests require `Authorization: Bearer <token>`. |
| 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization. | | 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization + resolved-path workspace checks in file read/write tools. |
| 4 | **Access via tunnel only** | ✅ | Gateway refuses public bind without active tunnel. Supports Tailscale, Cloudflare, ngrok, or any custom tunnel. | | 4 | **Access via tunnel only** | ✅ | Gateway refuses public bind without active tunnel. Supports Tailscale, Cloudflare, ngrok, or any custom tunnel. |
> **Run your own nmap:** `nmap -p 1-65535 <your-host>` — ZeroClaw binds to localhost only, so nothing is exposed unless you explicitly configure a tunnel. > **Run your own nmap:** `nmap -p 1-65535 <your-host>` — ZeroClaw binds to localhost only, so nothing is exposed unless you explicitly configure a tunnel.
@ -139,6 +182,63 @@ Inbound sender policy is now consistent:
This keeps accidental exposure low by default. This keeps accidental exposure low by default.
Recommended low-friction setup (secure + fast):
- **Telegram:** allowlist your own `@username` (without `@`) and/or your numeric Telegram user ID.
- **Discord:** allowlist your own Discord user ID.
- **Slack:** allowlist your own Slack member ID (usually starts with `U`).
- Use `"*"` only for temporary open testing.
If you're not sure which identity to use:
1. Start channels and send one message to your bot.
2. Read the warning log to see the exact sender identity.
3. Add that value to the allowlist and rerun channels-only setup.
If you hit authorization warnings in logs (for example: `ignoring message from unauthorized user`),
rerun channel setup only:
```bash
zeroclaw onboard --channels-only
```
### WhatsApp Business Cloud API Setup
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
1. **Create a Meta Business App:**
- Go to [developers.facebook.com](https://developers.facebook.com)
- Create a new app → Select "Business" type
- Add the "WhatsApp" product
2. **Get your credentials:**
- **Access Token:** From WhatsApp → API Setup → Generate token (or create a System User for permanent tokens)
- **Phone Number ID:** From WhatsApp → API Setup → Phone number ID
- **Verify Token:** You define this (any random string) — Meta will send it back during webhook verification
3. **Configure ZeroClaw:**
```toml
[channels_config.whatsapp]
access_token = "EAABx..."
phone_number_id = "123456789012345"
verify_token = "my-secret-verify-token"
allowed_numbers = ["+1234567890"] # E.164 format, or ["*"] for all
```
4. **Start the gateway with a tunnel:**
```bash
zeroclaw gateway --port 8080
```
WhatsApp requires HTTPS, so use a tunnel (ngrok, Cloudflare, Tailscale Funnel).
5. **Configure Meta webhook:**
- In Meta Developer Console → WhatsApp → Configuration → Webhook
- **Callback URL:** `https://your-tunnel-url/whatsapp`
- **Verify Token:** Same as your `verify_token` in config
- Subscribe to `messages` field
6. **Test:** Send a message to your WhatsApp Business number — ZeroClaw will respond via the LLM.
## Configuration ## Configuration
Config: `~/.zeroclaw/config.toml` (created by `onboard`) Config: `~/.zeroclaw/config.toml` (created by `onboard`)
@ -166,6 +266,9 @@ workspace_only = true # default: true — scoped to workspace
allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"] allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"]
forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"] forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"]
[runtime]
kind = "native" # only supported value right now; unsupported kinds fail fast
[heartbeat] [heartbeat]
enabled = false enabled = false
interval_minutes = 30 interval_minutes = 30
@ -182,8 +285,81 @@ allowed_domains = ["docs.rs"] # required when browser is enabled
[composio] [composio]
enabled = false # opt-in: 1000+ OAuth apps via composio.dev enabled = false # opt-in: 1000+ OAuth apps via composio.dev
[identity]
format = "openclaw" # "openclaw" (default, markdown files) or "aieos" (JSON)
# aieos_path = "identity.json" # path to AIEOS JSON file (relative to workspace or absolute)
# aieos_inline = '{"identity":{"names":{"first":"Nova"}}}' # inline AIEOS JSON
``` ```
## Identity System (AIEOS Support)
ZeroClaw supports **identity-agnostic** AI personas through two formats:
### OpenClaw (Default)
Traditional markdown files in your workspace:
- `IDENTITY.md` — Who the agent is
- `SOUL.md` — Core personality and values
- `USER.md` — Who the agent is helping
- `AGENTS.md` — Behavior guidelines
### AIEOS (AI Entity Object Specification)
[AIEOS](https://aieos.org) is a standardization framework for portable AI identity. ZeroClaw supports AIEOS v1.1 JSON payloads, allowing you to:
- **Import identities** from the AIEOS ecosystem
- **Export identities** to other AIEOS-compatible systems
- **Maintain behavioral integrity** across different AI models
#### Enable AIEOS
```toml
[identity]
format = "aieos"
aieos_path = "identity.json" # relative to workspace or absolute path
```
Or inline JSON:
```toml
[identity]
format = "aieos"
aieos_inline = '''
{
"identity": {
"names": { "first": "Nova", "nickname": "N" }
},
"psychology": {
"neural_matrix": { "creativity": 0.9, "logic": 0.8 },
"traits": { "mbti": "ENTP" },
"moral_compass": { "alignment": "Chaotic Good" }
},
"linguistics": {
"text_style": { "formality_level": 0.2, "slang_usage": true }
},
"motivations": {
"core_drive": "Push boundaries and explore possibilities"
}
}
'''
```
#### AIEOS Schema Sections
| Section | Description |
|---------|-------------|
| `identity` | Names, bio, origin, residence |
| `psychology` | Neural matrix (cognitive weights), MBTI, OCEAN, moral compass |
| `linguistics` | Text style, formality, catchphrases, forbidden words |
| `motivations` | Core drive, short/long-term goals, fears |
| `capabilities` | Skills and tools the agent can access |
| `physicality` | Visual descriptors for image generation |
| `history` | Origin story, education, occupation |
| `interests` | Hobbies, favorites, lifestyle |
See [aieos.org](https://aieos.org) for the full schema and live examples.
## Gateway API ## Gateway API
| Endpoint | Method | Auth | Description | | Endpoint | Method | Auth | Description |
@ -191,6 +367,8 @@ enabled = false # opt-in: 1000+ OAuth apps via composio.dev
| `/health` | GET | None | Health check (always public, no secrets leaked) | | `/health` | GET | None | Health check (always public, no secrets leaked) |
| `/pair` | POST | `X-Pairing-Code` header | Exchange one-time code for bearer token | | `/pair` | POST | `X-Pairing-Code` header | Exchange one-time code for bearer token |
| `/webhook` | POST | `Authorization: Bearer <token>` | Send message: `{"message": "your prompt"}` | | `/webhook` | POST | `Authorization: Bearer <token>` | Send message: `{"message": "your prompt"}` |
| `/whatsapp` | GET | Query params | Meta webhook verification (hub.mode, hub.verify_token, hub.challenge) |
| `/whatsapp` | POST | None (Meta signature) | WhatsApp incoming message webhook |
## Commands ## Commands
@ -198,10 +376,14 @@ enabled = false # opt-in: 1000+ OAuth apps via composio.dev
|---------|-------------| |---------|-------------|
| `onboard` | Quick setup (default) | | `onboard` | Quick setup (default) |
| `onboard --interactive` | Full interactive 7-step wizard | | `onboard --interactive` | Full interactive 7-step wizard |
| `onboard --channels-only` | Reconfigure channels/allowlists only (fast repair flow) |
| `agent -m "..."` | Single message mode | | `agent -m "..."` | Single message mode |
| `agent` | Interactive chat mode | | `agent` | Interactive chat mode |
| `gateway` | Start webhook server (default: `127.0.0.1:8080`) | | `gateway` | Start webhook server (default: `127.0.0.1:8080`) |
| `gateway --port 0` | Random port mode | | `gateway --port 0` | Random port mode |
| `daemon` | Start long-running autonomous runtime |
| `service install/start/stop/status/uninstall` | Manage user-level background service |
| `doctor` | Diagnose daemon/scheduler/channel freshness |
| `status` | Show full system status | | `status` | Show full system status |
| `channel doctor` | Run health checks for configured channels | | `channel doctor` | Run health checks for configured channels |
| `integrations info <name>` | Show setup/status details for one integration | | `integrations info <name>` | Show setup/status details for one integration |

View file

@ -61,3 +61,33 @@ cargo test -- tools::shell
cargo test -- tools::file_read cargo test -- tools::file_read
cargo test -- tools::file_write cargo test -- tools::file_write
``` ```
## Container Security
ZeroClaw Docker images follow CIS Docker Benchmark best practices:
| Control | Implementation |
|---------|----------------|
| **4.1 Non-root user** | Container runs as UID 65534 (distroless nonroot) |
| **4.2 Minimal base image** | `gcr.io/distroless/cc-debian12:nonroot` — no shell, no package manager |
| **4.6 HEALTHCHECK** | Not applicable (stateless CLI/gateway) |
| **5.25 Read-only filesystem** | Supported via `docker run --read-only` with `/workspace` volume |
### Verifying Container Security
```bash
# Build and verify non-root user
docker build -t zeroclaw .
docker inspect --format='{{.Config.User}}' zeroclaw
# Expected: 65534:65534
# Run with read-only filesystem (production hardening)
docker run --read-only -v /path/to/workspace:/workspace zeroclaw gateway
```
### CI Enforcement
The `docker` job in `.github/workflows/ci.yml` automatically verifies:
1. Container does not run as root (UID 0)
2. Runtime stage uses `:nonroot` variant
3. Explicit `USER` directive with numeric UID exists

169
scripts/test_dockerignore.sh Executable file
View file

@ -0,0 +1,169 @@
#!/usr/bin/env bash
# Test script to verify .dockerignore excludes sensitive paths
# Run: ./scripts/test_dockerignore.sh
set -euo pipefail

# Resolve the repo root relative to this script so it works from any CWD.
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
DOCKERIGNORE="$PROJECT_ROOT/.dockerignore"

RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color

PASS=0
FAIL=0

# Record a passing check and bump the pass counter.
# (Counters use $((...)) assignment, which is safe under `set -e`.)
log_pass() {
    echo -e "${GREEN}${NC} $1"
    PASS=$((PASS + 1))
}

# Record a failing check and bump the fail counter.
log_fail() {
    echo -e "${RED}${NC} $1"
    FAIL=$((FAIL + 1))
}

# Test 1: .dockerignore exists
echo "=== Testing .dockerignore ==="
if [[ -f "$DOCKERIGNORE" ]]; then
    log_pass ".dockerignore file exists"
else
    log_fail ".dockerignore file does not exist"
    exit 1
fi

# Test 2: Required exclusions are present
MUST_EXCLUDE=(
    ".git"
    ".githooks"
    "target"
    "docs"
    "examples"
    "tests"
    "*.md"
    "*.png"
    "*.db"
    "*.db-journal"
    ".DS_Store"
    ".github"
    "deny.toml"
    "LICENSE"
    ".env"
    ".tmp_*"
)

for pattern in "${MUST_EXCLUDE[@]}"; do
    # -F: literal matching, since patterns contain glob metacharacters (*.md)
    if grep -Fq "$pattern" "$DOCKERIGNORE" 2>/dev/null; then
        log_pass "Excludes: $pattern"
    else
        log_fail "Missing exclusion: $pattern"
    fi
done

# Test 3: Build essentials are NOT excluded
MUST_NOT_EXCLUDE=(
    "Cargo.toml"
    "Cargo.lock"
    "src"
)

for path in "${MUST_NOT_EXCLUDE[@]}"; do
    # Anchored match: only an exact whole-line entry counts as an exclusion.
    if grep -qE "^${path}$" "$DOCKERIGNORE" 2>/dev/null; then
        log_fail "Build essential '$path' is incorrectly excluded"
    else
        log_pass "Build essential NOT excluded: $path"
    fi
done

# Test 4: No syntax errors (basic validation)
# Bug fix: previously the pass line was printed unconditionally, even when
# trailing-whitespace failures had been logged above. Track it with a flag.
WHITESPACE_OK=1
while IFS= read -r line; do
    # Skip empty lines and comments
    [[ -z "$line" || "$line" =~ ^# ]] && continue
    # Trailing whitespace silently changes what a pattern matches
    if [[ "$line" =~ [[:space:]]$ ]]; then
        log_fail "Trailing whitespace in pattern: '$line'"
        WHITESPACE_OK=0
    fi
done < "$DOCKERIGNORE"
if [[ $WHITESPACE_OK -eq 1 ]]; then
    log_pass "No trailing whitespace in patterns"
fi

# Test 5: Verify Docker build context would be small
echo ""
echo "=== Simulating Docker build context ==="

cd "$PROJECT_ROOT"

# Count files that WOULD be sent (approximating .dockerignore with find filters)
TOTAL_FILES=$(find . -type f | wc -l | tr -d ' ')
CONTEXT_FILES=$(find . -type f \
    ! -path './.git/*' \
    ! -path './target/*' \
    ! -path './docs/*' \
    ! -path './examples/*' \
    ! -path './tests/*' \
    ! -name '*.md' \
    ! -name '*.png' \
    ! -name '*.svg' \
    ! -name '*.db' \
    ! -name '*.db-journal' \
    ! -name '.DS_Store' \
    ! -path './.github/*' \
    ! -name 'deny.toml' \
    ! -name 'LICENSE' \
    ! -name '.env' \
    ! -name '.env.*' \
    2>/dev/null | wc -l | tr -d ' ')

echo "Total files in repo: $TOTAL_FILES"
echo "Files in Docker context: $CONTEXT_FILES"

if [[ $CONTEXT_FILES -lt $TOTAL_FILES ]]; then
    log_pass "Docker context is smaller than full repo ($CONTEXT_FILES < $TOTAL_FILES files)"
else
    log_fail "Docker context is not being reduced"
fi

# Test 6: Verify critical security files would be excluded
echo ""
echo "=== Security checks ==="

# Check if .git would be excluded (anchored: the literal line ".git")
if [[ -d "$PROJECT_ROOT/.git" ]]; then
    if grep -q "^\.git$" "$DOCKERIGNORE"; then
        log_pass ".git directory will be excluded (security)"
    else
        log_fail ".git directory NOT excluded - SECURITY RISK"
    fi
fi

# Check if any .db files exist and would be excluded.
# `|| true` guards against `head` closing the pipe early (SIGPIPE), which
# would otherwise abort the whole script under `set -o pipefail` + `set -e`.
DB_FILES=$(find "$PROJECT_ROOT" -name "*.db" -type f 2>/dev/null | head -5 || true)
if [[ -n "$DB_FILES" ]]; then
    if grep -q "^\*\.db$" "$DOCKERIGNORE"; then
        log_pass "*.db files will be excluded (security)"
    else
        log_fail "*.db files NOT excluded - SECURITY RISK"
    fi
fi

# Summary
echo ""
echo "=== Summary ==="
echo -e "Passed: ${GREEN}$PASS${NC}"
echo -e "Failed: ${RED}$FAIL${NC}"

if [[ $FAIL -gt 0 ]]; then
    echo -e "${RED}FAILED${NC}: $FAIL tests failed"
    exit 1
else
    echo -e "${GREEN}PASSED${NC}: All tests passed"
    exit 0
fi

View file

@ -39,7 +39,7 @@ pub async fn run(
// ── Wire up agnostic subsystems ────────────────────────────── // ── Wire up agnostic subsystems ──────────────────────────────
let observer: Arc<dyn Observer> = let observer: Arc<dyn Observer> =
Arc::from(observability::create_observer(&config.observability)); Arc::from(observability::create_observer(&config.observability));
let _runtime = runtime::create_runtime(&config.runtime); let _runtime = runtime::create_runtime(&config.runtime)?;
let security = Arc::new(SecurityPolicy::from_config( let security = Arc::new(SecurityPolicy::from_config(
&config.autonomy, &config.autonomy,
&config.workspace_dir, &config.workspace_dir,
@ -72,8 +72,11 @@ pub async fn run(
.or(config.default_model.as_deref()) .or(config.default_model.as_deref())
.unwrap_or("anthropic/claude-sonnet-4-20250514"); .unwrap_or("anthropic/claude-sonnet-4-20250514");
let provider: Box<dyn Provider> = let provider: Box<dyn Provider> = providers::create_resilient_provider(
providers::create_provider(provider_name, config.api_key.as_deref())?; provider_name,
config.api_key.as_deref(),
&config.reliability,
)?;
observer.record_event(&ObserverEvent::AgentStart { observer.record_event(&ObserverEvent::AgentStart {
provider: provider_name.to_string(), provider: provider_name.to_string(),
@ -83,12 +86,30 @@ pub async fn run(
// ── Build system prompt from workspace MD files (OpenClaw framework) ── // ── Build system prompt from workspace MD files (OpenClaw framework) ──
let skills = crate::skills::load_skills(&config.workspace_dir); let skills = crate::skills::load_skills(&config.workspace_dir);
let mut tool_descs: Vec<(&str, &str)> = vec![ let mut tool_descs: Vec<(&str, &str)> = vec![
("shell", "Execute terminal commands"), (
("file_read", "Read file contents"), "shell",
("file_write", "Write file contents"), "Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.",
("memory_store", "Save to memory"), ),
("memory_recall", "Search memory"), (
("memory_forget", "Delete a memory entry"), "file_read",
"Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.",
),
(
"file_write",
"Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.",
),
(
"memory_store",
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
),
(
"memory_recall",
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
),
(
"memory_forget",
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
),
]; ];
if config.browser.enabled { if config.browser.enabled {
tool_descs.push(( tool_descs.push((

View file

@ -29,6 +29,60 @@ impl IMessageChannel {
} }
} }
/// Escape a string for safe interpolation into `AppleScript`.
///
/// This prevents injection attacks by escaping:
/// - Backslashes (`\` → `\\`)
/// - Double quotes (`"` → `\"`)
///
/// Implemented as a single pass so backslashes inserted for quotes are
/// never themselves re-escaped (equivalent to replacing `\` first, then `"`).
fn escape_applescript(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\\' => escaped.push_str("\\\\"),
            '"' => escaped.push_str("\\\""),
            other => escaped.push(other),
        }
    }
    escaped
}
/// Validate that a target looks like a valid phone number or email address.
///
/// This is a defense-in-depth measure to reject obviously malicious targets
/// before they reach `AppleScript` interpolation.
///
/// Valid patterns:
/// - Phone: starts with `+` followed by digits (with optional spaces/dashes)
/// - Email: contains `@` with alphanumeric chars on both sides
fn is_valid_imessage_target(target: &str) -> bool {
    let target = target.trim();
    if target.is_empty() {
        return false;
    }

    // Phone number: +1234567890 or +1 234-567-8900
    if let Some(rest) = target.strip_prefix('+') {
        // FIX: the previous check only counted digits, so a target such as
        // `+1234567" & do shell script "890` (10 digits total) passed
        // validation. Enforce the character set the doc comment promises:
        // digits plus space/dash separators — nothing else.
        if !rest
            .chars()
            .all(|c| c.is_ascii_digit() || c == ' ' || c == '-')
        {
            return false;
        }
        let digit_count = rest.chars().filter(char::is_ascii_digit).count();
        // At least 7 digits (shortest valid phone numbers),
        // at most 15 (ITU-T E.164 maximum length).
        return (7..=15).contains(&digit_count);
    }

    // Email: simple validation (contains @ with chars on both sides)
    if let Some(at_pos) = target.find('@') {
        let local = &target[..at_pos];
        let domain = &target[at_pos + 1..];
        // Local part: non-empty, alphanumeric + common email chars
        let local_valid = !local.is_empty()
            && local
                .chars()
                .all(|c| c.is_alphanumeric() || "._+-".contains(c));
        // Domain: non-empty, contains a dot, alphanumeric + dots/hyphens
        let domain_valid = !domain.is_empty()
            && domain.contains('.')
            && domain
                .chars()
                .all(|c| c.is_alphanumeric() || ".-".contains(c));
        return local_valid && domain_valid;
    }

    false
}
#[async_trait] #[async_trait]
impl Channel for IMessageChannel { impl Channel for IMessageChannel {
fn name(&self) -> &str { fn name(&self) -> &str {
@ -36,11 +90,22 @@ impl Channel for IMessageChannel {
} }
async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> { async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> {
let escaped_msg = message.replace('\\', "\\\\").replace('"', "\\\""); // Defense-in-depth: validate target format before any interpolation
if !is_valid_imessage_target(target) {
anyhow::bail!(
"Invalid iMessage target: must be a phone number (+1234567890) or email (user@example.com)"
);
}
// SECURITY: Escape both message AND target to prevent AppleScript injection
// See: CWE-78 (OS Command Injection)
let escaped_msg = escape_applescript(message);
let escaped_target = escape_applescript(target);
let script = format!( let script = format!(
r#"tell application "Messages" r#"tell application "Messages"
set targetService to 1st account whose service type = iMessage set targetService to 1st account whose service type = iMessage
set targetBuddy to participant "{target}" of targetService set targetBuddy to participant "{escaped_target}" of targetService
send "{escaped_msg}" to targetBuddy send "{escaped_msg}" to targetBuddy
end tell"# end tell"#
); );
@ -262,4 +327,204 @@ mod tests {
assert!(ch.is_contact_allowed(" spaced ")); assert!(ch.is_contact_allowed(" spaced "));
assert!(!ch.is_contact_allowed("spaced")); assert!(!ch.is_contact_allowed("spaced"));
} }
// ══════════════════════════════════════════════════════════
// AppleScript Escaping Tests (CWE-78 Prevention)
// ══════════════════════════════════════════════════════════
#[test]
fn escape_applescript_double_quotes() {
assert_eq!(escape_applescript(r#"hello "world""#), r#"hello \"world\""#);
}
#[test]
fn escape_applescript_backslashes() {
assert_eq!(escape_applescript(r"path\to\file"), r"path\\to\\file");
}
#[test]
fn escape_applescript_mixed() {
assert_eq!(
escape_applescript(r#"say "hello\" world"#),
r#"say \"hello\\\" world"#
);
}
#[test]
fn escape_applescript_injection_attempt() {
// This is the exact attack vector from the security report
let malicious = r#"" & do shell script "id" & ""#;
let escaped = escape_applescript(malicious);
// After escaping, the quotes should be escaped and not break out
assert_eq!(escaped, r#"\" & do shell script \"id\" & \""#);
// Verify all quotes are now escaped (preceded by backslash)
// The escaped string should not have any unescaped quotes (quote not preceded by backslash)
let chars: Vec<char> = escaped.chars().collect();
for (i, &c) in chars.iter().enumerate() {
if c == '"' {
// Every quote must be preceded by a backslash
assert!(
i > 0 && chars[i - 1] == '\\',
"Found unescaped quote at position {i}"
);
}
}
}
#[test]
fn escape_applescript_empty_string() {
assert_eq!(escape_applescript(""), "");
}
#[test]
fn escape_applescript_no_special_chars() {
assert_eq!(escape_applescript("hello world"), "hello world");
}
#[test]
fn escape_applescript_unicode() {
assert_eq!(escape_applescript("hello 🦀 world"), "hello 🦀 world");
}
#[test]
fn escape_applescript_newlines_preserved() {
assert_eq!(escape_applescript("line1\nline2"), "line1\nline2");
}
// ══════════════════════════════════════════════════════════
// Target Validation Tests
// ══════════════════════════════════════════════════════════
#[test]
fn valid_phone_number_simple() {
assert!(is_valid_imessage_target("+1234567890"));
}
#[test]
fn valid_phone_number_with_country_code() {
assert!(is_valid_imessage_target("+14155551234"));
}
#[test]
fn valid_phone_number_with_spaces() {
assert!(is_valid_imessage_target("+1 415 555 1234"));
}
#[test]
fn valid_phone_number_with_dashes() {
assert!(is_valid_imessage_target("+1-415-555-1234"));
}
#[test]
fn valid_phone_number_international() {
assert!(is_valid_imessage_target("+447911123456")); // UK
assert!(is_valid_imessage_target("+81312345678")); // Japan
}
#[test]
fn valid_email_simple() {
assert!(is_valid_imessage_target("user@example.com"));
}
#[test]
fn valid_email_with_subdomain() {
assert!(is_valid_imessage_target("user@mail.example.com"));
}
#[test]
fn valid_email_with_plus() {
assert!(is_valid_imessage_target("user+tag@example.com"));
}
#[test]
fn valid_email_with_dots() {
assert!(is_valid_imessage_target("first.last@example.com"));
}
#[test]
fn valid_email_icloud() {
assert!(is_valid_imessage_target("user@icloud.com"));
assert!(is_valid_imessage_target("user@me.com"));
}
#[test]
fn invalid_target_empty() {
assert!(!is_valid_imessage_target(""));
assert!(!is_valid_imessage_target(" "));
}
#[test]
fn invalid_target_no_plus_prefix() {
// Phone numbers must start with +
assert!(!is_valid_imessage_target("1234567890"));
}
#[test]
fn invalid_target_too_short_phone() {
// Less than 7 digits
assert!(!is_valid_imessage_target("+123456"));
}
#[test]
fn invalid_target_too_long_phone() {
// More than 15 digits
assert!(!is_valid_imessage_target("+1234567890123456"));
}
#[test]
fn invalid_target_email_no_at() {
assert!(!is_valid_imessage_target("userexample.com"));
}
#[test]
fn invalid_target_email_no_domain() {
assert!(!is_valid_imessage_target("user@"));
}
#[test]
fn invalid_target_email_no_local() {
assert!(!is_valid_imessage_target("@example.com"));
}
#[test]
fn invalid_target_email_no_dot_in_domain() {
assert!(!is_valid_imessage_target("user@localhost"));
}
#[test]
fn invalid_target_injection_attempt() {
// The exact attack vector from the security report
assert!(!is_valid_imessage_target(r#"" & do shell script "id" & ""#));
}
#[test]
fn invalid_target_applescript_injection() {
// Various injection attempts
assert!(!is_valid_imessage_target(r#"test" & quit"#));
assert!(!is_valid_imessage_target(r#"test\ndo shell script"#));
assert!(!is_valid_imessage_target("test\"; malicious code; \""));
}
#[test]
fn invalid_target_special_chars() {
assert!(!is_valid_imessage_target("user<script>@example.com"));
assert!(!is_valid_imessage_target("user@example.com; rm -rf /"));
}
#[test]
fn invalid_target_null_byte() {
assert!(!is_valid_imessage_target("user\0@example.com"));
}
#[test]
fn invalid_target_newline() {
assert!(!is_valid_imessage_target("user\n@example.com"));
}
#[test]
fn target_with_leading_trailing_whitespace_trimmed() {
// Should trim and validate
assert!(is_valid_imessage_target(" +1234567890 "));
assert!(is_valid_imessage_target(" user@example.com "));
}
} }

View file

@ -7,6 +7,7 @@ pub mod slack;
pub mod telegram; pub mod telegram;
pub mod whatsapp; pub mod whatsapp;
pub mod traits; pub mod traits;
pub mod whatsapp;
pub use cli::CliChannel; pub use cli::CliChannel;
pub use discord::DiscordChannel; pub use discord::DiscordChannel;
@ -17,6 +18,7 @@ pub use telegram::TelegramChannel;
#[allow(unused_imports)] #[allow(unused_imports)]
pub use whatsapp::WhatsAppChannel; pub use whatsapp::WhatsAppChannel;
pub use traits::Channel; pub use traits::Channel;
pub use whatsapp::WhatsAppChannel;
use crate::config::Config; use crate::config::Config;
use crate::memory::{self, Memory}; use crate::memory::{self, Memory};
@ -28,6 +30,46 @@ use std::time::Duration;
/// Maximum characters per injected workspace file (matches `OpenClaw` default). /// Maximum characters per injected workspace file (matches `OpenClaw` default).
const BOOTSTRAP_MAX_CHARS: usize = 20_000; const BOOTSTRAP_MAX_CHARS: usize = 20_000;
const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2;
const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60;
/// Spawn a background task that keeps a channel listener alive.
///
/// The task runs `ch.listen(...)` in a loop: each time the listener returns
/// (cleanly or with an error) it records the failure in the health registry,
/// sleeps with exponential backoff, and restarts the listener. The loop only
/// exits once the shared message bus (`tx`) is closed, i.e. the receiving
/// side has shut down.
fn spawn_supervised_listener(
    ch: Arc<dyn Channel>,
    tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
    initial_backoff_secs: u64,
    max_backoff_secs: u64,
) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let component = format!("channel:{}", ch.name());
        // Clamp to sane values: backoff is at least 1s, and the cap is never
        // below the initial backoff.
        let mut backoff = initial_backoff_secs.max(1);
        let max_backoff = max_backoff_secs.max(backoff);
        loop {
            crate::health::mark_component_ok(&component);
            let result = ch.listen(tx.clone()).await;
            // A closed bus means the channel runtime is shutting down —
            // stop supervising instead of restarting.
            if tx.is_closed() {
                break;
            }
            match result {
                Ok(()) => {
                    // A listener is expected to block indefinitely; a clean
                    // return is still unexpected and treated as a failure.
                    tracing::warn!("Channel {} exited unexpectedly; restarting", ch.name());
                    crate::health::mark_component_error(&component, "listener exited unexpectedly");
                }
                Err(e) => {
                    tracing::error!("Channel {} error: {e}; restarting", ch.name());
                    crate::health::mark_component_error(&component, e.to_string());
                }
            }
            crate::health::bump_component_restart(&component);
            tokio::time::sleep(Duration::from_secs(backoff)).await;
            // Exponential backoff, capped at max_backoff.
            backoff = backoff.saturating_mul(2).min(max_backoff);
        }
    })
}
/// Load workspace identity files and build a system prompt. /// Load workspace identity files and build a system prompt.
/// ///
/// Follows the `OpenClaw` framework structure: /// Follows the `OpenClaw` framework structure:
@ -150,6 +192,38 @@ pub fn build_system_prompt(
} }
} }
/// Inject OpenClaw (markdown) identity files into the prompt.
///
/// Appends a "Project Context" header followed by the workspace identity
/// files in a fixed order, delegating per-file truncation and missing-file
/// markers to `inject_workspace_file`.
///
/// FIX: removed the dead `use std::fmt::Write;` (previously silenced with
/// `#[allow(unused_imports)]`) — nothing in this function formats into the
/// string, so neither the import nor the lint suppression is needed.
fn inject_openclaw_identity(prompt: &mut String, workspace_dir: &std::path::Path) {
    prompt.push_str("## Project Context\n\n");
    prompt
        .push_str("The following workspace files define your identity, behavior, and context.\n\n");

    // Core identity files, injected in a fixed, documented order.
    let bootstrap_files = [
        "AGENTS.md",
        "SOUL.md",
        "TOOLS.md",
        "IDENTITY.md",
        "USER.md",
        "HEARTBEAT.md",
    ];
    for filename in &bootstrap_files {
        inject_workspace_file(prompt, workspace_dir, filename);
    }

    // BOOTSTRAP.md — only if it exists (first-run ritual)
    let bootstrap_path = workspace_dir.join("BOOTSTRAP.md");
    if bootstrap_path.exists() {
        inject_workspace_file(prompt, workspace_dir, "BOOTSTRAP.md");
    }

    // MEMORY.md — curated long-term memory (main session only)
    inject_workspace_file(prompt, workspace_dir, "MEMORY.md");
}
/// Inject a single workspace file into the prompt with truncation and missing-file markers. /// Inject a single workspace file into the prompt with truncation and missing-file markers.
fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, filename: &str) { fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, filename: &str) {
use std::fmt::Write; use std::fmt::Write;
@ -200,6 +274,7 @@ pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Resul
("Webhook", config.channels_config.webhook.is_some()), ("Webhook", config.channels_config.webhook.is_some()),
("iMessage", config.channels_config.imessage.is_some()), ("iMessage", config.channels_config.imessage.is_some()),
("Matrix", config.channels_config.matrix.is_some()), ("Matrix", config.channels_config.matrix.is_some()),
("WhatsApp", config.channels_config.whatsapp.is_some()),
] { ] {
println!(" {} {name}", if configured { "" } else { "" }); println!(" {} {name}", if configured { "" } else { "" });
} }
@ -294,6 +369,18 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
)); ));
} }
if let Some(ref wa) = config.channels_config.whatsapp {
channels.push((
"WhatsApp",
Arc::new(WhatsAppChannel::new(
wa.access_token.clone(),
wa.phone_number_id.clone(),
wa.verify_token.clone(),
wa.allowed_numbers.clone(),
)),
));
}
if channels.is_empty() { if channels.is_empty() {
println!("No real-time channels configured. Run `zeroclaw onboard` first."); println!("No real-time channels configured. Run `zeroclaw onboard` first.");
return Ok(()); return Ok(());
@ -338,9 +425,10 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
/// Start all configured channels and route messages to the agent /// Start all configured channels and route messages to the agent
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub async fn start_channels(config: Config) -> Result<()> { pub async fn start_channels(config: Config) -> Result<()> {
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider( let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"), config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(), config.api_key.as_deref(),
&config.reliability,
)?); )?);
let model = config let model = config
.default_model .default_model
@ -359,12 +447,30 @@ pub async fn start_channels(config: Config) -> Result<()> {
// Collect tool descriptions for the prompt // Collect tool descriptions for the prompt
let mut tool_descs: Vec<(&str, &str)> = vec![ let mut tool_descs: Vec<(&str, &str)> = vec![
("shell", "Execute terminal commands"), (
("file_read", "Read file contents"), "shell",
("file_write", "Write file contents"), "Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.",
("memory_store", "Save to memory"), ),
("memory_recall", "Search memory"), (
("memory_forget", "Delete a memory entry"), "file_read",
"Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.",
),
(
"file_write",
"Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.",
),
(
"memory_store",
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
),
(
"memory_recall",
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
),
(
"memory_forget",
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
),
]; ];
if config.browser.enabled { if config.browser.enabled {
@ -426,6 +532,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
))); )));
} }
if let Some(ref wa) = config.channels_config.whatsapp {
channels.push(Arc::new(WhatsAppChannel::new(
wa.access_token.clone(),
wa.phone_number_id.clone(),
wa.verify_token.clone(),
wa.allowed_numbers.clone(),
)));
}
if channels.is_empty() { if channels.is_empty() {
println!("No channels configured. Run `zeroclaw onboard` to set up channels."); println!("No channels configured. Run `zeroclaw onboard` to set up channels.");
return Ok(()); return Ok(());
@ -450,19 +565,29 @@ pub async fn start_channels(config: Config) -> Result<()> {
println!(" Listening for messages... (Ctrl+C to stop)"); println!(" Listening for messages... (Ctrl+C to stop)");
println!(); println!();
crate::health::mark_component_ok("channels");
let initial_backoff_secs = config
.reliability
.channel_initial_backoff_secs
.max(DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS);
let max_backoff_secs = config
.reliability
.channel_max_backoff_secs
.max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS);
// Single message bus — all channels send messages here // Single message bus — all channels send messages here
let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100); let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
// Spawn a listener for each channel // Spawn a listener for each channel
let mut handles = Vec::new(); let mut handles = Vec::new();
for ch in &channels { for ch in &channels {
let ch = ch.clone(); handles.push(spawn_supervised_listener(
let tx = tx.clone(); ch.clone(),
handles.push(tokio::spawn(async move { tx.clone(),
if let Err(e) = ch.listen(tx).await { initial_backoff_secs,
tracing::error!("Channel {} error: {e}", ch.name()); max_backoff_secs,
} ));
}));
} }
drop(tx); // Drop our copy so rx closes when all channels stop drop(tx); // Drop our copy so rx closes when all channels stop
@ -537,6 +662,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::Arc;
use tempfile::TempDir; use tempfile::TempDir;
fn make_workspace() -> TempDir { fn make_workspace() -> TempDir {
@ -781,4 +908,55 @@ mod tests {
let state = classify_health_result(&result); let state = classify_health_result(&result);
assert_eq!(state, ChannelHealthState::Timeout); assert_eq!(state, ChannelHealthState::Timeout);
} }
struct AlwaysFailChannel {
name: &'static str,
calls: Arc<AtomicUsize>,
}
#[async_trait::async_trait]
impl Channel for AlwaysFailChannel {
fn name(&self) -> &str {
self.name
}
async fn send(&self, _message: &str, _recipient: &str) -> anyhow::Result<()> {
Ok(())
}
async fn listen(
&self,
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
) -> anyhow::Result<()> {
self.calls.fetch_add(1, Ordering::SeqCst);
anyhow::bail!("listen boom")
}
}
#[tokio::test]
async fn supervised_listener_marks_error_and_restarts_on_failures() {
let calls = Arc::new(AtomicUsize::new(0));
let channel: Arc<dyn Channel> = Arc::new(AlwaysFailChannel {
name: "test-supervised-fail",
calls: Arc::clone(&calls),
});
let (_tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(1);
let handle = spawn_supervised_listener(channel, _tx, 1, 1);
tokio::time::sleep(Duration::from_millis(80)).await;
drop(rx);
handle.abort();
let _ = handle.await;
let snapshot = crate::health::snapshot_json();
let component = &snapshot["components"]["channel:test-supervised-fail"];
assert_eq!(component["status"], "error");
assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
assert!(component["last_error"]
.as_str()
.unwrap_or("")
.contains("listen boom"));
assert!(calls.load(Ordering::SeqCst) >= 1);
}
} }

View file

@ -25,6 +25,13 @@ impl TelegramChannel {
fn is_user_allowed(&self, username: &str) -> bool { fn is_user_allowed(&self, username: &str) -> bool {
self.allowed_users.iter().any(|u| u == "*" || u == username) self.allowed_users.iter().any(|u| u == "*" || u == username)
} }
/// Returns `true` if any of the supplied identities (e.g. an @username or a
/// numeric user ID rendered as a string) passes the allowlist check.
fn is_any_user_allowed<'a, I>(&self, identities: I) -> bool
where
    I: IntoIterator<Item = &'a str>,
{
    for identity in identities {
        if self.is_user_allowed(identity) {
            return true;
        }
    }
    false
}
} }
#[async_trait] #[async_trait]
@ -95,15 +102,28 @@ impl Channel for TelegramChannel {
continue; continue;
}; };
let username = message let username_opt = message
.get("from") .get("from")
.and_then(|f| f.get("username")) .and_then(|f| f.get("username"))
.and_then(|u| u.as_str()) .and_then(|u| u.as_str());
.unwrap_or("unknown"); let username = username_opt.unwrap_or("unknown");
if !self.is_user_allowed(username) { let user_id = message
.get("from")
.and_then(|f| f.get("id"))
.and_then(serde_json::Value::as_i64);
let user_id_str = user_id.map(|id| id.to_string());
let mut identities = vec![username];
if let Some(ref id) = user_id_str {
identities.push(id.as_str());
}
if !self.is_any_user_allowed(identities.iter().copied()) {
tracing::warn!( tracing::warn!(
"Telegram: ignoring message from unauthorized user: {username}" "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
user_id_str.as_deref().unwrap_or("unknown")
); );
continue; continue;
} }
@ -211,4 +231,16 @@ mod tests {
assert!(ch.is_user_allowed("bob")); assert!(ch.is_user_allowed("bob"));
assert!(ch.is_user_allowed("anyone")); assert!(ch.is_user_allowed("anyone"));
} }
#[test]
fn telegram_user_allowed_by_numeric_id_identity() {
let ch = TelegramChannel::new("t".into(), vec!["123456789".into()]);
assert!(ch.is_any_user_allowed(["unknown", "123456789"]));
}
#[test]
fn telegram_user_denied_when_none_of_identities_match() {
let ch = TelegramChannel::new("t".into(), vec!["alice".into(), "987654321".into()]);
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
}
} }

File diff suppressed because it is too large Load diff

View file

@ -2,7 +2,7 @@ pub mod schema;
pub use schema::{ pub use schema::{
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
GatewayConfig, HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig,
ObservabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig, ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig,
WebhookConfig, TelegramConfig, TunnelConfig, WebhookConfig,
}; };

View file

@ -25,6 +25,9 @@ pub struct Config {
#[serde(default)] #[serde(default)]
pub runtime: RuntimeConfig, pub runtime: RuntimeConfig,
#[serde(default)]
pub reliability: ReliabilityConfig,
#[serde(default)] #[serde(default)]
pub heartbeat: HeartbeatConfig, pub heartbeat: HeartbeatConfig,
@ -48,6 +51,38 @@ pub struct Config {
#[serde(default)] #[serde(default)]
pub browser: BrowserConfig, pub browser: BrowserConfig,
#[serde(default)]
pub identity: IdentityConfig,
}
// ── Identity (AIEOS / OpenClaw format) ──────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdentityConfig {
/// Identity format: "openclaw" (default) or "aieos"
#[serde(default = "default_identity_format")]
pub format: String,
/// Path to AIEOS JSON file (relative to workspace)
#[serde(default)]
pub aieos_path: Option<String>,
/// Inline AIEOS JSON (alternative to file path)
#[serde(default)]
pub aieos_inline: Option<String>,
}
fn default_identity_format() -> String {
"openclaw".into()
}
impl Default for IdentityConfig {
fn default() -> Self {
Self {
format: default_identity_format(),
aieos_path: None,
aieos_inline: None,
}
}
} }
// ── Gateway security ───────────────────────────────────────────── // ── Gateway security ─────────────────────────────────────────────
@ -143,6 +178,18 @@ pub struct MemoryConfig {
pub backend: String, pub backend: String,
/// Auto-save conversation context to memory /// Auto-save conversation context to memory
pub auto_save: bool, pub auto_save: bool,
/// Run memory/session hygiene (archiving + retention cleanup)
#[serde(default = "default_hygiene_enabled")]
pub hygiene_enabled: bool,
/// Archive daily/session files older than this many days
#[serde(default = "default_archive_after_days")]
pub archive_after_days: u32,
/// Purge archived files older than this many days
#[serde(default = "default_purge_after_days")]
pub purge_after_days: u32,
/// For sqlite backend: prune conversation rows older than this many days
#[serde(default = "default_conversation_retention_days")]
pub conversation_retention_days: u32,
/// Embedding provider: "none" | "openai" | "custom:URL" /// Embedding provider: "none" | "openai" | "custom:URL"
#[serde(default = "default_embedding_provider")] #[serde(default = "default_embedding_provider")]
pub embedding_provider: String, pub embedding_provider: String,
@ -169,6 +216,18 @@ pub struct MemoryConfig {
fn default_embedding_provider() -> String { fn default_embedding_provider() -> String {
"none".into() "none".into()
} }
/// Default: memory/session hygiene runs unless explicitly disabled.
fn default_hygiene_enabled() -> bool {
    true
}
/// Default: archive daily/session files older than 7 days.
fn default_archive_after_days() -> u32 {
    7
}
/// Default: purge archived files older than 30 days.
fn default_purge_after_days() -> u32 {
    30
}
/// Default: prune sqlite conversation rows older than 30 days.
fn default_conversation_retention_days() -> u32 {
    30
}
fn default_embedding_model() -> String { fn default_embedding_model() -> String {
"text-embedding-3-small".into() "text-embedding-3-small".into()
} }
@ -193,6 +252,10 @@ impl Default for MemoryConfig {
Self { Self {
backend: "sqlite".into(), backend: "sqlite".into(),
auto_save: true, auto_save: true,
hygiene_enabled: default_hygiene_enabled(),
archive_after_days: default_archive_after_days(),
purge_after_days: default_purge_after_days(),
conversation_retention_days: default_conversation_retention_days(),
embedding_provider: default_embedding_provider(), embedding_provider: default_embedding_provider(),
embedding_model: default_embedding_model(), embedding_model: default_embedding_model(),
embedding_dimensions: default_embedding_dims(), embedding_dimensions: default_embedding_dims(),
@ -281,7 +344,9 @@ impl Default for AutonomyConfig {
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RuntimeConfig { pub struct RuntimeConfig {
/// "native" | "docker" | "cloudflare" /// Runtime kind (currently supported: "native").
///
/// Reserved values (not implemented yet): "docker", "cloudflare".
pub kind: String, pub kind: String,
} }
@ -293,6 +358,71 @@ impl Default for RuntimeConfig {
} }
} }
// ── Reliability / supervision ────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReliabilityConfig {
/// Retries per provider before failing over.
#[serde(default = "default_provider_retries")]
pub provider_retries: u32,
/// Base backoff (ms) for provider retry delay.
#[serde(default = "default_provider_backoff_ms")]
pub provider_backoff_ms: u64,
/// Fallback provider chain (e.g. `["anthropic", "openai"]`).
#[serde(default)]
pub fallback_providers: Vec<String>,
/// Initial backoff for channel/daemon restarts.
#[serde(default = "default_channel_backoff_secs")]
pub channel_initial_backoff_secs: u64,
/// Max backoff for channel/daemon restarts.
#[serde(default = "default_channel_backoff_max_secs")]
pub channel_max_backoff_secs: u64,
/// Scheduler polling cadence in seconds.
#[serde(default = "default_scheduler_poll_secs")]
pub scheduler_poll_secs: u64,
/// Max retries for cron job execution attempts.
#[serde(default = "default_scheduler_retries")]
pub scheduler_retries: u32,
}
/// Default: retry each provider twice before failing over.
fn default_provider_retries() -> u32 {
    2
}
/// Default: 500 ms base delay between provider retries.
fn default_provider_backoff_ms() -> u64 {
    500
}
/// Default: start channel/daemon restart backoff at 2 seconds.
fn default_channel_backoff_secs() -> u64 {
    2
}
/// Default: cap channel/daemon restart backoff at 60 seconds.
fn default_channel_backoff_max_secs() -> u64 {
    60
}
/// Default: poll the scheduler every 15 seconds.
fn default_scheduler_poll_secs() -> u64 {
    15
}
/// Default: retry cron job execution twice.
fn default_scheduler_retries() -> u32 {
    2
}
impl Default for ReliabilityConfig {
    fn default() -> Self {
        // Reuse the serde `default = "..."` helper functions so a missing
        // `[reliability]` section in the TOML and a programmatic
        // `ReliabilityConfig::default()` always produce identical settings.
        Self {
            provider_retries: default_provider_retries(),
            provider_backoff_ms: default_provider_backoff_ms(),
            // No fallback chain by default — single-provider operation.
            fallback_providers: Vec::new(),
            channel_initial_backoff_secs: default_channel_backoff_secs(),
            channel_max_backoff_secs: default_channel_backoff_max_secs(),
            scheduler_poll_secs: default_scheduler_poll_secs(),
            scheduler_retries: default_scheduler_retries(),
        }
    }
}
// ── Heartbeat ──────────────────────────────────────────────────── // ── Heartbeat ────────────────────────────────────────────────────
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -387,6 +517,7 @@ pub struct ChannelsConfig {
pub webhook: Option<WebhookConfig>, pub webhook: Option<WebhookConfig>,
pub imessage: Option<IMessageConfig>, pub imessage: Option<IMessageConfig>,
pub matrix: Option<MatrixConfig>, pub matrix: Option<MatrixConfig>,
pub whatsapp: Option<WhatsAppConfig>,
} }
impl Default for ChannelsConfig { impl Default for ChannelsConfig {
@ -399,6 +530,7 @@ impl Default for ChannelsConfig {
webhook: None, webhook: None,
imessage: None, imessage: None,
matrix: None, matrix: None,
whatsapp: None,
} }
} }
} }
@ -445,6 +577,19 @@ pub struct MatrixConfig {
pub allowed_users: Vec<String>, pub allowed_users: Vec<String>,
} }
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WhatsAppConfig {
/// Access token from Meta Business Suite
pub access_token: String,
/// Phone number ID from Meta Business API
pub phone_number_id: String,
/// Webhook verify token (you define this, Meta sends it back for verification)
pub verify_token: String,
/// Allowed phone numbers (E.164 format: +1234567890) or "*" for all
#[serde(default)]
pub allowed_numbers: Vec<String>,
}
// ── Config impl ────────────────────────────────────────────────── // ── Config impl ──────────────────────────────────────────────────
impl Default for Config { impl Default for Config {
@ -463,6 +608,7 @@ impl Default for Config {
observability: ObservabilityConfig::default(), observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(), autonomy: AutonomyConfig::default(),
runtime: RuntimeConfig::default(), runtime: RuntimeConfig::default(),
reliability: ReliabilityConfig::default(),
heartbeat: HeartbeatConfig::default(), heartbeat: HeartbeatConfig::default(),
channels_config: ChannelsConfig::default(), channels_config: ChannelsConfig::default(),
memory: MemoryConfig::default(), memory: MemoryConfig::default(),
@ -471,6 +617,7 @@ impl Default for Config {
composio: ComposioConfig::default(), composio: ComposioConfig::default(),
secrets: SecretsConfig::default(), secrets: SecretsConfig::default(),
browser: BrowserConfig::default(), browser: BrowserConfig::default(),
identity: IdentityConfig::default(),
} }
} }
} }
@ -558,6 +705,17 @@ mod tests {
assert_eq!(h.interval_minutes, 30); assert_eq!(h.interval_minutes, 30);
} }
#[test]
fn memory_config_default_hygiene_settings() {
let m = MemoryConfig::default();
assert_eq!(m.backend, "sqlite");
assert!(m.auto_save);
assert!(m.hygiene_enabled);
assert_eq!(m.archive_after_days, 7);
assert_eq!(m.purge_after_days, 30);
assert_eq!(m.conversation_retention_days, 30);
}
#[test] #[test]
fn channels_config_default() { fn channels_config_default() {
let c = ChannelsConfig::default(); let c = ChannelsConfig::default();
@ -591,6 +749,7 @@ mod tests {
runtime: RuntimeConfig { runtime: RuntimeConfig {
kind: "docker".into(), kind: "docker".into(),
}, },
reliability: ReliabilityConfig::default(),
heartbeat: HeartbeatConfig { heartbeat: HeartbeatConfig {
enabled: true, enabled: true,
interval_minutes: 15, interval_minutes: 15,
@ -606,6 +765,7 @@ mod tests {
webhook: None, webhook: None,
imessage: None, imessage: None,
matrix: None, matrix: None,
whatsapp: None,
}, },
memory: MemoryConfig::default(), memory: MemoryConfig::default(),
tunnel: TunnelConfig::default(), tunnel: TunnelConfig::default(),
@ -613,6 +773,7 @@ mod tests {
composio: ComposioConfig::default(), composio: ComposioConfig::default(),
secrets: SecretsConfig::default(), secrets: SecretsConfig::default(),
browser: BrowserConfig::default(), browser: BrowserConfig::default(),
identity: IdentityConfig::default(),
}; };
let toml_str = toml::to_string_pretty(&config).unwrap(); let toml_str = toml::to_string_pretty(&config).unwrap();
@ -650,6 +811,10 @@ default_temperature = 0.7
assert_eq!(parsed.runtime.kind, "native"); assert_eq!(parsed.runtime.kind, "native");
assert!(!parsed.heartbeat.enabled); assert!(!parsed.heartbeat.enabled);
assert!(parsed.channels_config.cli); assert!(parsed.channels_config.cli);
assert!(parsed.memory.hygiene_enabled);
assert_eq!(parsed.memory.archive_after_days, 7);
assert_eq!(parsed.memory.purge_after_days, 30);
assert_eq!(parsed.memory.conversation_retention_days, 30);
} }
#[test] #[test]
@ -669,6 +834,7 @@ default_temperature = 0.7
observability: ObservabilityConfig::default(), observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(), autonomy: AutonomyConfig::default(),
runtime: RuntimeConfig::default(), runtime: RuntimeConfig::default(),
reliability: ReliabilityConfig::default(),
heartbeat: HeartbeatConfig::default(), heartbeat: HeartbeatConfig::default(),
channels_config: ChannelsConfig::default(), channels_config: ChannelsConfig::default(),
memory: MemoryConfig::default(), memory: MemoryConfig::default(),
@ -677,6 +843,7 @@ default_temperature = 0.7
composio: ComposioConfig::default(), composio: ComposioConfig::default(),
secrets: SecretsConfig::default(), secrets: SecretsConfig::default(),
browser: BrowserConfig::default(), browser: BrowserConfig::default(),
identity: IdentityConfig::default(),
}; };
config.save().unwrap(); config.save().unwrap();
@ -810,6 +977,7 @@ default_temperature = 0.7
room_id: "!r:m".into(), room_id: "!r:m".into(),
allowed_users: vec!["@u:m".into()], allowed_users: vec!["@u:m".into()],
}), }),
whatsapp: None,
}; };
let toml_str = toml::to_string_pretty(&c).unwrap(); let toml_str = toml::to_string_pretty(&c).unwrap();
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap(); let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
@ -894,6 +1062,89 @@ channel_id = "C123"
assert_eq!(parsed.port, 8080); assert_eq!(parsed.port, 8080);
} }
// ── WhatsApp config ──────────────────────────────────────
#[test]
// JSON round-trip preserves every WhatsAppConfig field, including the
// phone-number allow-list.
fn whatsapp_config_serde() {
    let wc = WhatsAppConfig {
        access_token: "EAABx...".into(),
        phone_number_id: "123456789".into(),
        verify_token: "my-verify-token".into(),
        allowed_numbers: vec!["+1234567890".into(), "+9876543210".into()],
    };
    let json = serde_json::to_string(&wc).unwrap();
    let parsed: WhatsAppConfig = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed.access_token, "EAABx...");
    assert_eq!(parsed.phone_number_id, "123456789");
    assert_eq!(parsed.verify_token, "my-verify-token");
    assert_eq!(parsed.allowed_numbers.len(), 2);
}
#[test]
// TOML serialization and deserialization round-trips a WhatsAppConfig.
fn whatsapp_config_toml_roundtrip() {
    let wc = WhatsAppConfig {
        access_token: "tok".into(),
        phone_number_id: "12345".into(),
        verify_token: "verify".into(),
        allowed_numbers: vec!["+1".into()],
    };
    let toml_str = toml::to_string(&wc).unwrap();
    let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
    assert_eq!(parsed.phone_number_id, "12345");
    assert_eq!(parsed.allowed_numbers, vec!["+1"]);
}
#[test]
// A config missing `allowed_numbers` parses with an empty allow-list
// rather than failing deserialization.
fn whatsapp_config_deserializes_without_allowed_numbers() {
    let json = r#"{"access_token":"tok","phone_number_id":"123","verify_token":"ver"}"#;
    let parsed: WhatsAppConfig = serde_json::from_str(json).unwrap();
    assert!(parsed.allowed_numbers.is_empty());
}
#[test]
// The "*" wildcard entry in the allow-list survives a TOML round-trip.
fn whatsapp_config_wildcard_allowed() {
    let wc = WhatsAppConfig {
        access_token: "tok".into(),
        phone_number_id: "123".into(),
        verify_token: "ver".into(),
        allowed_numbers: vec!["*".into()],
    };
    let toml_str = toml::to_string(&wc).unwrap();
    let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
    assert_eq!(parsed.allowed_numbers, vec!["*"]);
}
#[test]
// A ChannelsConfig carrying a WhatsApp section round-trips through TOML
// with the nested fields intact.
fn channels_config_with_whatsapp() {
    let c = ChannelsConfig {
        cli: true,
        telegram: None,
        discord: None,
        slack: None,
        webhook: None,
        imessage: None,
        matrix: None,
        whatsapp: Some(WhatsAppConfig {
            access_token: "tok".into(),
            phone_number_id: "123".into(),
            verify_token: "ver".into(),
            allowed_numbers: vec!["+1".into()],
        }),
    };
    let toml_str = toml::to_string_pretty(&c).unwrap();
    let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
    assert!(parsed.whatsapp.is_some());
    let wa = parsed.whatsapp.unwrap();
    assert_eq!(wa.phone_number_id, "123");
    assert_eq!(wa.allowed_numbers, vec!["+1"]);
}
#[test]
// WhatsApp is opt-in: the default ChannelsConfig has no WhatsApp section.
fn channels_config_default_has_no_whatsapp() {
    let c = ChannelsConfig::default();
    assert!(c.whatsapp.is_none());
}
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════
// SECURITY CHECKLIST TESTS — Gateway config // SECURITY CHECKLIST TESTS — Gateway config
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════

View file

@ -1,25 +1,353 @@
use crate::config::Config; use crate::config::Config;
use anyhow::Result; use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
use cron::Schedule;
use rusqlite::{params, Connection};
use std::str::FromStr;
use uuid::Uuid;
pub fn handle_command(command: super::CronCommands, _config: Config) -> Result<()> { pub mod scheduler;
/// A scheduled job as persisted in the cron SQLite database.
#[derive(Debug, Clone)]
pub struct CronJob {
    /// UUID primary key, assigned at creation.
    pub id: String,
    /// Cron expression as entered by the user (5, 6, or 7 fields).
    pub expression: String,
    /// Shell command executed when the job fires.
    pub command: String,
    /// Next scheduled execution time (UTC).
    pub next_run: DateTime<Utc>,
    /// Time of the most recent run, if the job has ever run.
    pub last_run: Option<DateTime<Utc>>,
    /// Outcome of the most recent run ("ok"/"error"), if any.
    pub last_status: Option<String>,
}
pub fn handle_command(command: super::CronCommands, config: Config) -> Result<()> {
match command { match command {
super::CronCommands::List => { super::CronCommands::List => {
let jobs = list_jobs(&config)?;
if jobs.is_empty() {
println!("No scheduled tasks yet."); println!("No scheduled tasks yet.");
println!("\nUsage:"); println!("\nUsage:");
println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'"); println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'");
return Ok(());
}
println!("🕒 Scheduled jobs ({}):", jobs.len());
for job in jobs {
let last_run = job
.last_run
.map(|d| d.to_rfc3339())
.unwrap_or_else(|| "never".into());
let last_status = job.last_status.unwrap_or_else(|| "n/a".into());
println!(
"- {} | {} | next={} | last={} ({})\n cmd: {}",
job.id,
job.expression,
job.next_run.to_rfc3339(),
last_run,
last_status,
job.command
);
}
Ok(()) Ok(())
} }
super::CronCommands::Add { super::CronCommands::Add {
expression, expression,
command, command,
} => { } => {
println!("Cron scheduling coming soon!"); let job = add_job(&config, &expression, &command)?;
println!(" Expression: {expression}"); println!("✅ Added cron job {}", job.id);
println!(" Command: {command}"); println!(" Expr: {}", job.expression);
println!(" Next: {}", job.next_run.to_rfc3339());
println!(" Cmd : {}", job.command);
Ok(()) Ok(())
} }
super::CronCommands::Remove { id } => { super::CronCommands::Remove { id } => remove_job(&config, &id),
anyhow::bail!("Remove task '{id}' not yet implemented");
} }
} }
/// Validate `expression`, compute its first future occurrence, and persist
/// a new job row. Returns the in-memory `CronJob` on success.
///
/// # Errors
/// Fails if the cron expression is invalid or the database insert fails.
pub fn add_job(config: &Config, expression: &str, command: &str) -> Result<CronJob> {
    let now = Utc::now();
    // Validate the expression up front: an invalid expression errors here,
    // before anything is written to the database.
    let next_run = next_run_for(expression, now)?;
    let id = Uuid::new_v4().to_string();
    with_connection(config, |conn| {
        conn.execute(
            "INSERT INTO cron_jobs (id, expression, command, created_at, next_run)
             VALUES (?1, ?2, ?3, ?4, ?5)",
            params![
                id,
                expression,
                command,
                now.to_rfc3339(),
                next_run.to_rfc3339()
            ],
        )
        .context("Failed to insert cron job")?;
        Ok(())
    })?;
    Ok(CronJob {
        id,
        expression: expression.to_string(),
        command: command.to_string(),
        next_run,
        last_run: None,
        last_status: None,
    })
}
/// Load every job from the cron database, ordered by soonest `next_run`.
///
/// # Errors
/// Fails on database errors or if a stored timestamp is not valid RFC 3339.
pub fn list_jobs(config: &Config) -> Result<Vec<CronJob>> {
    with_connection(config, |conn| {
        let mut stmt = conn.prepare(
            "SELECT id, expression, command, next_run, last_run, last_status
             FROM cron_jobs ORDER BY next_run ASC",
        )?;
        // Pull raw rows first (timestamps as TEXT); RFC 3339 parsing happens
        // outside the rusqlite closure so parse failures surface as anyhow
        // errors rather than rusqlite ones.
        let rows = stmt.query_map([], |row| {
            let next_run_raw: String = row.get(3)?;
            let last_run_raw: Option<String> = row.get(4)?;
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, String>(2)?,
                next_run_raw,
                last_run_raw,
                row.get::<_, Option<String>>(5)?,
            ))
        })?;
        let mut jobs = Vec::new();
        for row in rows {
            let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?;
            jobs.push(CronJob {
                id,
                expression,
                command,
                next_run: parse_rfc3339(&next_run_raw)?,
                last_run: match last_run_raw {
                    Some(raw) => Some(parse_rfc3339(&raw)?),
                    None => None,
                },
                last_status,
            });
        }
        Ok(jobs)
    })
}
/// Delete the job with the given id, printing a confirmation on success.
///
/// # Errors
/// Fails when no job with `id` exists or when the database delete fails.
pub fn remove_job(config: &Config, id: &str) -> Result<()> {
    let deleted_rows = with_connection(config, |conn| {
        conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![id])
            .context("Failed to delete cron job")
    })?;
    // rusqlite reports the number of affected rows; zero means the id
    // matched nothing.
    match deleted_rows {
        0 => anyhow::bail!("Cron job '{id}' not found"),
        _ => {
            println!("✅ Removed cron job {id}");
            Ok(())
        }
    }
}
/// Return every job whose `next_run` is at or before `now`, soonest first.
///
/// The comparison is done in SQL on the stored RFC 3339 TEXT values; this
/// is chronologically correct because all timestamps are written via
/// `Utc::now().to_rfc3339()` with a uniform offset, so lexicographic and
/// chronological order coincide.
///
/// # Errors
/// Fails on database errors or if a stored timestamp is not valid RFC 3339.
pub fn due_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJob>> {
    with_connection(config, |conn| {
        let mut stmt = conn.prepare(
            "SELECT id, expression, command, next_run, last_run, last_status
             FROM cron_jobs WHERE next_run <= ?1 ORDER BY next_run ASC",
        )?;
        let rows = stmt.query_map(params![now.to_rfc3339()], |row| {
            let next_run_raw: String = row.get(3)?;
            let last_run_raw: Option<String> = row.get(4)?;
            Ok((
                row.get::<_, String>(0)?,
                row.get::<_, String>(1)?,
                row.get::<_, String>(2)?,
                next_run_raw,
                last_run_raw,
                row.get::<_, Option<String>>(5)?,
            ))
        })?;
        let mut jobs = Vec::new();
        for row in rows {
            let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?;
            jobs.push(CronJob {
                id,
                expression,
                command,
                next_run: parse_rfc3339(&next_run_raw)?,
                last_run: match last_run_raw {
                    Some(raw) => Some(parse_rfc3339(&raw)?),
                    None => None,
                },
                last_status,
            });
        }
        Ok(jobs)
    })
}
/// Record the outcome of one run and advance the job's `next_run`.
///
/// The next occurrence is computed from *now* rather than from the previous
/// `next_run`, so windows missed while the daemon was down are skipped, not
/// replayed. `last_status` is stored as "ok" or "error" and `output` goes
/// into `last_output`.
///
/// # Errors
/// Fails if the job's expression no longer parses or the update fails.
pub fn reschedule_after_run(
    config: &Config,
    job: &CronJob,
    success: bool,
    output: &str,
) -> Result<()> {
    let now = Utc::now();
    let next_run = next_run_for(&job.expression, now)?;
    let status = if success { "ok" } else { "error" };
    with_connection(config, |conn| {
        conn.execute(
            "UPDATE cron_jobs
             SET next_run = ?1, last_run = ?2, last_status = ?3, last_output = ?4
             WHERE id = ?5",
            params![
                next_run.to_rfc3339(),
                now.to_rfc3339(),
                status,
                output,
                job.id
            ],
        )
        .context("Failed to update cron job run state")?;
        Ok(())
    })
}
/// Compute the first occurrence of `expression` strictly after `from`.
///
/// The expression is normalized to the `cron` crate's seconds-first syntax
/// before parsing.
fn next_run_for(expression: &str, from: DateTime<Utc>) -> Result<DateTime<Utc>> {
    let normalized = normalize_expression(expression)?;
    let schedule = Schedule::from_str(&normalized)
        .with_context(|| format!("Invalid cron expression: {expression}"))?;
    match schedule.after(&from).next() {
        Some(occurrence) => Ok(occurrence),
        None => Err(anyhow::anyhow!(
            "No future occurrence for expression: {expression}"
        )),
    }
}
/// Normalize a user-supplied cron expression into the `cron` crate's
/// seconds-first syntax.
///
/// Standard five-field crontab input (minute hour day month weekday) gets a
/// fixed `0` seconds field prepended; six- or seven-field input (seconds,
/// optional year) passes through unchanged. Anything else is rejected.
fn normalize_expression(expression: &str) -> Result<String> {
    let expression = expression.trim();
    let field_count = expression.split_whitespace().count();
    if field_count == 5 {
        Ok(format!("0 {expression}"))
    } else if (6..=7).contains(&field_count) {
        Ok(expression.to_string())
    } else {
        anyhow::bail!(
            "Invalid cron expression: {expression} (expected 5, 6, or 7 fields, got {field_count})"
        )
    }
}
/// Parse an RFC 3339 timestamp stored in the cron DB into a UTC datetime.
fn parse_rfc3339(raw: &str) -> Result<DateTime<Utc>> {
    DateTime::parse_from_rfc3339(raw)
        .map(|parsed| parsed.with_timezone(&Utc))
        .with_context(|| format!("Invalid RFC3339 timestamp in cron DB: {raw}"))
}
/// Open the cron SQLite DB at `<workspace>/cron/jobs.db` (creating the
/// directory if needed), ensure the schema exists, then run `f` with the
/// live connection. A fresh connection is opened on every call.
fn with_connection<T>(config: &Config, f: impl FnOnce(&Connection) -> Result<T>) -> Result<T> {
    let db_path = config.workspace_dir.join("cron").join("jobs.db");
    if let Some(parent) = db_path.parent() {
        std::fs::create_dir_all(parent)
            .with_context(|| format!("Failed to create cron directory: {}", parent.display()))?;
    }
    let conn = Connection::open(&db_path)
        .with_context(|| format!("Failed to open cron DB: {}", db_path.display()))?;
    // Idempotent schema setup; timestamps are stored as RFC 3339 TEXT.
    conn.execute_batch(
        "CREATE TABLE IF NOT EXISTS cron_jobs (
            id TEXT PRIMARY KEY,
            expression TEXT NOT NULL,
            command TEXT NOT NULL,
            created_at TEXT NOT NULL,
            next_run TEXT NOT NULL,
            last_run TEXT,
            last_status TEXT,
            last_output TEXT
        );
        CREATE INDEX IF NOT EXISTS idx_cron_jobs_next_run ON cron_jobs(next_run);",
    )
    .context("Failed to initialize cron schema")?;
    f(&conn)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::Config;
use chrono::Duration as ChronoDuration;
use tempfile::TempDir;
// Build a Config rooted in a temp dir so each test gets an isolated cron DB.
fn test_config(tmp: &TempDir) -> Config {
    let mut config = Config::default();
    config.workspace_dir = tmp.path().join("workspace");
    config.config_path = tmp.path().join("config.toml");
    std::fs::create_dir_all(&config.workspace_dir).unwrap();
    config
}
#[test]
// Five-field crontab syntax is accepted and stored verbatim.
fn add_job_accepts_five_field_expression() {
    let tmp = TempDir::new().unwrap();
    let config = test_config(&tmp);
    let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap();
    assert_eq!(job.expression, "*/5 * * * *");
    assert_eq!(job.command, "echo ok");
}
#[test]
// A four-field expression is rejected with the field-count error message.
fn add_job_rejects_invalid_field_count() {
    let tmp = TempDir::new().unwrap();
    let config = test_config(&tmp);
    let err = add_job(&config, "* * * *", "echo bad").unwrap_err();
    assert!(err.to_string().contains("expected 5, 6, or 7 fields"));
}
#[test]
// A job can be added, listed back by id, and removed again.
fn add_list_remove_roundtrip() {
    let tmp = TempDir::new().unwrap();
    let config = test_config(&tmp);
    let job = add_job(&config, "*/10 * * * *", "echo roundtrip").unwrap();
    let listed = list_jobs(&config).unwrap();
    assert_eq!(listed.len(), 1);
    assert_eq!(listed[0].id, job.id);
    remove_job(&config, &job.id).unwrap();
    assert!(list_jobs(&config).unwrap().is_empty());
}
#[test]
// due_jobs returns only jobs whose next_run is at or before the given time:
// a freshly added job is not due now, but is due a year from now.
fn due_jobs_filters_by_timestamp() {
    let tmp = TempDir::new().unwrap();
    let config = test_config(&tmp);
    let _job = add_job(&config, "* * * * *", "echo due").unwrap();
    let due_now = due_jobs(&config, Utc::now()).unwrap();
    assert!(due_now.is_empty(), "new job should not be due immediately");
    let far_future = Utc::now() + ChronoDuration::days(365);
    let due_future = due_jobs(&config, far_future).unwrap();
    assert_eq!(due_future.len(), 1, "job should be due in far future");
}
#[test]
// A failed run records last_status = "error" and sets last_run.
fn reschedule_after_run_persists_last_status_and_last_run() {
    let tmp = TempDir::new().unwrap();
    let config = test_config(&tmp);
    let job = add_job(&config, "*/15 * * * *", "echo run").unwrap();
    reschedule_after_run(&config, &job, false, "failed output").unwrap();
    let listed = list_jobs(&config).unwrap();
    let stored = listed.iter().find(|j| j.id == job.id).unwrap();
    assert_eq!(stored.last_status.as_deref(), Some("error"));
    assert!(stored.last_run.is_some());
}
} }

297
src/cron/scheduler.rs Normal file
View file

@ -0,0 +1,297 @@
use crate::config::Config;
use crate::cron::{due_jobs, reschedule_after_run, CronJob};
use crate::security::SecurityPolicy;
use anyhow::Result;
use chrono::Utc;
use tokio::process::Command;
use tokio::time::{self, Duration};
/// Floor for the scheduler polling interval (seconds); smaller configured
/// values are clamped up in `run`.
const MIN_POLL_SECONDS: u64 = 5;
/// Scheduler main loop: poll the cron DB every `scheduler_poll_secs`
/// (clamped to at least `MIN_POLL_SECONDS`) and execute any due jobs under
/// the configured security policy.
///
/// Never returns during normal operation: DB query errors are logged,
/// reported to `crate::health`, and the loop keeps polling.
pub async fn run(config: Config) -> Result<()> {
    let poll_secs = config.reliability.scheduler_poll_secs.max(MIN_POLL_SECONDS);
    let mut interval = time::interval(Duration::from_secs(poll_secs));
    let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
    crate::health::mark_component_ok("scheduler");
    loop {
        interval.tick().await;
        let jobs = match due_jobs(&config, Utc::now()) {
            Ok(jobs) => jobs,
            Err(e) => {
                crate::health::mark_component_error("scheduler", e.to_string());
                tracing::warn!("Scheduler query failed: {e}");
                continue;
            }
        };
        for job in jobs {
            // Reset to OK before each job so a later failure is attributable
            // to that specific job.
            crate::health::mark_component_ok("scheduler");
            let (success, output) = execute_job_with_retry(&config, &security, &job).await;
            if !success {
                crate::health::mark_component_error("scheduler", format!("job {} failed", job.id));
            }
            // Advance next_run even on failure so a broken job cannot wedge
            // the scheduler into re-running it every poll.
            if let Err(e) = reschedule_after_run(&config, &job, success, &output) {
                crate::health::mark_component_error("scheduler", e.to_string());
                tracing::warn!("Failed to persist scheduler run result: {e}");
            }
        }
    }
}
/// Run a job, retrying up to `scheduler_retries` times on failure.
///
/// Backoff starts at `provider_backoff_ms` (floored at 200ms), doubles per
/// attempt capped at 30s, plus up to 250ms of clock-derived jitter.
/// Security-policy blocks are deterministic and returned immediately
/// without retrying. Returns `(success, output_of_last_attempt)`.
async fn execute_job_with_retry(
    config: &Config,
    security: &SecurityPolicy,
    job: &CronJob,
) -> (bool, String) {
    let mut last_output = String::new();
    let retries = config.reliability.scheduler_retries;
    let mut backoff_ms = config.reliability.provider_backoff_ms.max(200);
    for attempt in 0..=retries {
        let (success, output) = run_job_command(config, security, job).await;
        last_output = output;
        if success {
            return (true, last_output);
        }
        if last_output.starts_with("blocked by security policy:") {
            // Deterministic policy violations are not retryable.
            return (false, last_output);
        }
        if attempt < retries {
            // Cheap jitter source: sub-second millis of the current time.
            let jitter_ms = (Utc::now().timestamp_subsec_millis() % 250) as u64;
            time::sleep(Duration::from_millis(backoff_ms + jitter_ms)).await;
            backoff_ms = (backoff_ms.saturating_mul(2)).min(30_000);
        }
    }
    (false, last_output)
}
/// True when a shell token looks like a leading environment assignment
/// (`NAME=value`): it contains `=` and starts with an ASCII letter or `_`.
fn is_env_assignment(word: &str) -> bool {
    if !word.contains('=') {
        return false;
    }
    matches!(word.chars().next(), Some(c) if c.is_ascii_alphabetic() || c == '_')
}
/// Strip leading/trailing single or double quotes from a token.
/// Note: `trim_matches` removes *runs* of quote characters on both ends,
/// mixed or not, not just one matched pair.
fn strip_wrapping_quotes(token: &str) -> &str {
    let is_quote = |c: char| c == '"' || c == '\'';
    token.trim_matches(is_quote)
}
/// Best-effort scan of a command line for path-like arguments outside the
/// security policy's allowed paths. Returns the first offending path.
///
/// This is a heuristic, not a shell parser: the command is split on common
/// separators (`&&`, `||`, `;`, `|`, newline) and then on whitespace.
/// NOTE(review): tokens starting with '-' are skipped entirely, so a path
/// embedded as `--file=/etc/x` is not inspected — confirm this is acceptable
/// as defense-in-depth alongside `is_command_allowed`.
fn forbidden_path_argument(security: &SecurityPolicy, command: &str) -> Option<String> {
    let mut normalized = command.to_string();
    // Replace two-character separators first so "&&" doesn't leave a stray '&'.
    for sep in ["&&", "||"] {
        normalized = normalized.replace(sep, "\x00");
    }
    for sep in ['\n', ';', '|'] {
        normalized = normalized.replace(sep, "\x00");
    }
    for segment in normalized.split('\x00') {
        let tokens: Vec<&str> = segment.split_whitespace().collect();
        if tokens.is_empty() {
            continue;
        }
        // Skip leading env assignments and executable token.
        let mut idx = 0;
        while idx < tokens.len() && is_env_assignment(tokens[idx]) {
            idx += 1;
        }
        if idx >= tokens.len() {
            continue;
        }
        idx += 1;
        for token in &tokens[idx..] {
            let candidate = strip_wrapping_quotes(token);
            // Skip empty tokens, option flags, and URLs.
            if candidate.is_empty() || candidate.starts_with('-') || candidate.contains("://") {
                continue;
            }
            let looks_like_path = candidate.starts_with('/')
                || candidate.starts_with("./")
                || candidate.starts_with("../")
                || candidate.starts_with("~/")
                || candidate.contains('/');
            if looks_like_path && !security.is_path_allowed(candidate) {
                return Some(candidate.to_string());
            }
        }
    }
    None
}
/// Execute a job's command once after enforcing the security policy.
///
/// Policy violations short-circuit with a "blocked by security policy:"
/// message, which `execute_job_with_retry` treats as non-retryable. The
/// command itself runs via `sh -lc` in the workspace directory; stdout and
/// stderr are captured into one combined string prefixed with the exit
/// status. Returns `(success, combined_output)`.
async fn run_job_command(
    config: &Config,
    security: &SecurityPolicy,
    job: &CronJob,
) -> (bool, String) {
    if !security.is_command_allowed(&job.command) {
        return (
            false,
            format!(
                "blocked by security policy: command not allowed: {}",
                job.command
            ),
        );
    }
    if let Some(path) = forbidden_path_argument(security, &job.command) {
        return (
            false,
            format!("blocked by security policy: forbidden path argument: {path}"),
        );
    }
    // `-lc`: run as a login shell so the user's PATH/profile apply.
    let output = Command::new("sh")
        .arg("-lc")
        .arg(&job.command)
        .current_dir(&config.workspace_dir)
        .output()
        .await;
    match output {
        Ok(output) => {
            let stdout = String::from_utf8_lossy(&output.stdout);
            let stderr = String::from_utf8_lossy(&output.stderr);
            let combined = format!(
                "status={}\nstdout:\n{}\nstderr:\n{}",
                output.status,
                stdout.trim(),
                stderr.trim()
            );
            (output.status.success(), combined)
        }
        Err(e) => (false, format!("spawn error: {e}")),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::security::SecurityPolicy;
    use tempfile::TempDir;
    // Build a Config rooted in a temp dir so each test is hermetic.
    fn test_config(tmp: &TempDir) -> Config {
        let mut config = Config::default();
        config.workspace_dir = tmp.path().join("workspace");
        config.config_path = tmp.path().join("config.toml");
        std::fs::create_dir_all(&config.workspace_dir).unwrap();
        config
    }
    // Minimal due-now job wrapping the given shell command.
    fn test_job(command: &str) -> CronJob {
        CronJob {
            id: "test-job".into(),
            expression: "* * * * *".into(),
            command: command.into(),
            next_run: Utc::now(),
            last_run: None,
            last_status: None,
        }
    }
    #[tokio::test]
    // Happy path: a successful command reports success and its stdout.
    async fn run_job_command_success() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);
        let job = test_job("echo scheduler-ok");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(success);
        assert!(output.contains("scheduler-ok"));
        assert!(output.contains("status=exit status: 0"));
    }
    #[tokio::test]
    // A non-zero exit reports failure and includes stderr content.
    async fn run_job_command_failure() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);
        let job = test_job("ls definitely_missing_file_for_scheduler_test");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("definitely_missing_file_for_scheduler_test"));
        assert!(output.contains("status=exit status:"));
    }
    #[tokio::test]
    // Commands outside the allow-list are blocked before execution.
    async fn run_job_command_blocks_disallowed_command() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.autonomy.allowed_commands = vec!["echo".into()];
        let job = test_job("curl https://evil.example");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("blocked by security policy"));
        assert!(output.contains("command not allowed"));
    }
    #[tokio::test]
    // Path arguments outside the workspace policy are blocked.
    async fn run_job_command_blocks_forbidden_path_argument() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.autonomy.allowed_commands = vec!["cat".into()];
        let job = test_job("cat /etc/passwd");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("blocked by security policy"));
        assert!(output.contains("forbidden path argument"));
        assert!(output.contains("/etc/passwd"));
    }
    #[tokio::test]
    // A script that fails once then succeeds is recovered by one retry.
    async fn execute_job_with_retry_recovers_after_first_failure() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.reliability.scheduler_retries = 1;
        config.reliability.provider_backoff_ms = 1;
        config.autonomy.allowed_commands = vec!["sh".into()];
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
        // First invocation creates the flag file and exits 1; second sees the
        // flag and succeeds.
        std::fs::write(
            config.workspace_dir.join("retry-once.sh"),
            "#!/bin/sh\nif [ -f retry-ok.flag ]; then\n echo recovered\n exit 0\nfi\ntouch retry-ok.flag\nexit 1\n",
        )
        .unwrap();
        let job = test_job("sh ./retry-once.sh");
        let (success, output) = execute_job_with_retry(&config, &security, &job).await;
        assert!(success);
        assert!(output.contains("recovered"));
    }
    #[tokio::test]
    // A persistently failing command still fails after retries are spent.
    async fn execute_job_with_retry_exhausts_attempts() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.reliability.scheduler_retries = 1;
        config.reliability.provider_backoff_ms = 1;
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
        let job = test_job("ls always_missing_for_retry_test");
        let (success, output) = execute_job_with_retry(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("always_missing_for_retry_test"));
    }
}

287
src/daemon/mod.rs Normal file
View file

@ -0,0 +1,287 @@
use crate::config::Config;
use anyhow::Result;
use chrono::Utc;
use std::future::Future;
use std::path::PathBuf;
use tokio::task::JoinHandle;
use tokio::time::Duration;
/// How often the daemon state writer flushes the health snapshot to disk.
const STATUS_FLUSH_SECONDS: u64 = 5;
/// Daemon entry point: start the periodic state writer plus a supervised
/// task per component — gateway (always), channels (only when a polling
/// channel is configured), heartbeat (when enabled), and the cron
/// scheduler — then block until Ctrl+C and abort everything.
///
/// All supervisors share the same restart backoff settings from
/// `config.reliability`.
///
/// # Errors
/// Returns an error only if waiting for the Ctrl+C signal fails.
pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
    let initial_backoff = config.reliability.channel_initial_backoff_secs.max(1);
    let max_backoff = config
        .reliability
        .channel_max_backoff_secs
        .max(initial_backoff);
    crate::health::mark_component_ok("daemon");
    if config.heartbeat.enabled {
        // Best-effort: create the heartbeat file ahead of the worker start.
        let _ =
            crate::heartbeat::engine::HeartbeatEngine::ensure_heartbeat_file(&config.workspace_dir)
                .await;
    }
    let mut handles: Vec<JoinHandle<()>> = vec![spawn_state_writer(config.clone())];
    {
        let gateway_cfg = config.clone();
        let gateway_host = host.clone();
        handles.push(spawn_component_supervisor(
            "gateway",
            initial_backoff,
            max_backoff,
            move || {
                let cfg = gateway_cfg.clone();
                let host = gateway_host.clone();
                async move { crate::gateway::run_gateway(&host, port, cfg).await }
            },
        ));
    }
    {
        // Only long-lived polling/socket channels need a supervisor;
        // webhook-style traffic is handled by the gateway.
        if has_supervised_channels(&config) {
            let channels_cfg = config.clone();
            handles.push(spawn_component_supervisor(
                "channels",
                initial_backoff,
                max_backoff,
                move || {
                    let cfg = channels_cfg.clone();
                    async move { crate::channels::start_channels(cfg).await }
                },
            ));
        } else {
            crate::health::mark_component_ok("channels");
            tracing::info!("No real-time channels configured; channel supervisor disabled");
        }
    }
    if config.heartbeat.enabled {
        let heartbeat_cfg = config.clone();
        handles.push(spawn_component_supervisor(
            "heartbeat",
            initial_backoff,
            max_backoff,
            move || {
                let cfg = heartbeat_cfg.clone();
                async move { run_heartbeat_worker(cfg).await }
            },
        ));
    }
    {
        let scheduler_cfg = config.clone();
        handles.push(spawn_component_supervisor(
            "scheduler",
            initial_backoff,
            max_backoff,
            move || {
                let cfg = scheduler_cfg.clone();
                async move { crate::cron::scheduler::run(cfg).await }
            },
        ));
    }
    println!("🧠 ZeroClaw daemon started");
    println!(" Gateway: http://{host}:{port}");
    println!(" Components: gateway, channels, heartbeat, scheduler");
    println!(" Ctrl+C to stop");
    tokio::signal::ctrl_c().await?;
    crate::health::mark_component_error("daemon", "shutdown requested");
    // Abort all supervised tasks, then await each so aborts are observed.
    for handle in &handles {
        handle.abort();
    }
    for handle in handles {
        let _ = handle.await;
    }
    Ok(())
}
/// Path of the daemon state snapshot: `daemon_state.json` next to the
/// config file, or under `.` when the config path has no parent.
pub fn state_file_path(config: &Config) -> PathBuf {
    let base = match config.config_path.parent() {
        Some(dir) => PathBuf::from(dir),
        None => PathBuf::from("."),
    };
    base.join("daemon_state.json")
}
/// Spawn a task that writes the health snapshot (plus a `written_at`
/// timestamp) to the daemon state file every `STATUS_FLUSH_SECONDS`.
/// All I/O failures are ignored — the writer is best-effort telemetry.
fn spawn_state_writer(config: Config) -> JoinHandle<()> {
    tokio::spawn(async move {
        let path = state_file_path(&config);
        if let Some(parent) = path.parent() {
            let _ = tokio::fs::create_dir_all(parent).await;
        }
        let mut interval = tokio::time::interval(Duration::from_secs(STATUS_FLUSH_SECONDS));
        loop {
            interval.tick().await;
            let mut json = crate::health::snapshot_json();
            if let Some(obj) = json.as_object_mut() {
                obj.insert(
                    "written_at".into(),
                    serde_json::json!(Utc::now().to_rfc3339()),
                );
            }
            // Fall back to "{}" rather than skipping the write on
            // serialization failure.
            let data = serde_json::to_vec_pretty(&json).unwrap_or_else(|_| b"{}".to_vec());
            let _ = tokio::fs::write(&path, data).await;
        }
    })
}
/// Spawn a task that runs `run_component` forever, restarting it with
/// exponential backoff (doubling from `initial_backoff_secs`, capped at
/// `max_backoff_secs`) whenever it returns. A clean `Ok(())` exit is also
/// treated as a failure, since daemon components are expected to run
/// indefinitely. Health status and restart counts go to `crate::health`.
///
/// NOTE(review): `backoff` only grows — it is never reset after a long
/// healthy stretch, so a component that fails occasionally will eventually
/// always wait `max_backoff_secs`. Confirm this is intended.
fn spawn_component_supervisor<F, Fut>(
    name: &'static str,
    initial_backoff_secs: u64,
    max_backoff_secs: u64,
    mut run_component: F,
) -> JoinHandle<()>
where
    F: FnMut() -> Fut + Send + 'static,
    Fut: Future<Output = Result<()>> + Send + 'static,
{
    tokio::spawn(async move {
        let mut backoff = initial_backoff_secs.max(1);
        let max_backoff = max_backoff_secs.max(backoff);
        loop {
            crate::health::mark_component_ok(name);
            match run_component().await {
                Ok(()) => {
                    crate::health::mark_component_error(name, "component exited unexpectedly");
                    tracing::warn!("Daemon component '{name}' exited unexpectedly");
                }
                Err(e) => {
                    crate::health::mark_component_error(name, e.to_string());
                    tracing::error!("Daemon component '{name}' failed: {e}");
                }
            }
            crate::health::bump_component_restart(name);
            tokio::time::sleep(Duration::from_secs(backoff)).await;
            backoff = backoff.saturating_mul(2).min(max_backoff);
        }
    })
}
/// Heartbeat worker loop: every `interval_minutes` (clamped to at least 5),
/// collect tasks from the heartbeat engine and feed each one to the agent
/// as a "[Heartbeat Task]" prompt. Per-task failures are logged and
/// reported to health; the loop itself only exits if `collect_tasks` errors.
async fn run_heartbeat_worker(config: Config) -> Result<()> {
    let observer: std::sync::Arc<dyn crate::observability::Observer> =
        std::sync::Arc::from(crate::observability::create_observer(&config.observability));
    let engine = crate::heartbeat::engine::HeartbeatEngine::new(
        config.heartbeat.clone(),
        config.workspace_dir.clone(),
        observer,
    );
    let interval_mins = config.heartbeat.interval_minutes.max(5);
    let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60));
    loop {
        interval.tick().await;
        let tasks = engine.collect_tasks().await?;
        if tasks.is_empty() {
            continue;
        }
        for task in tasks {
            let prompt = format!("[Heartbeat Task] {task}");
            let temp = config.default_temperature;
            if let Err(e) = crate::agent::run(config.clone(), Some(prompt), None, None, temp).await
            {
                crate::health::mark_component_error("heartbeat", e.to_string());
                tracing::warn!("Heartbeat task failed: {e}");
            } else {
                crate::health::mark_component_ok("heartbeat");
            }
        }
    }
}
/// Whether any long-lived channel (telegram, discord, slack, imessage,
/// matrix) is configured and therefore needs the channel supervisor.
/// Webhook-driven channels are intentionally excluded here.
fn has_supervised_channels(config: &Config) -> bool {
    let channels = &config.channels_config;
    [
        channels.telegram.is_some(),
        channels.discord.is_some(),
        channels.slack.is_some(),
        channels.imessage.is_some(),
        channels.matrix.is_some(),
    ]
    .into_iter()
    .any(|configured| configured)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;
    // Build a Config rooted in a temp dir so tests don't touch real paths.
    fn test_config(tmp: &TempDir) -> Config {
        let mut config = Config::default();
        config.workspace_dir = tmp.path().join("workspace");
        config.config_path = tmp.path().join("config.toml");
        std::fs::create_dir_all(&config.workspace_dir).unwrap();
        config
    }
    #[test]
    // daemon_state.json lives next to the config file.
    fn state_file_path_uses_config_directory() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);
        let path = state_file_path(&config);
        assert_eq!(path, tmp.path().join("daemon_state.json"));
    }
    #[tokio::test]
    // A component that errors is marked unhealthy and counts a restart.
    async fn supervisor_marks_error_and_restart_on_failure() {
        let handle = spawn_component_supervisor("daemon-test-fail", 1, 1, || async {
            anyhow::bail!("boom")
        });
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.abort();
        let _ = handle.await;
        let snapshot = crate::health::snapshot_json();
        let component = &snapshot["components"]["daemon-test-fail"];
        assert_eq!(component["status"], "error");
        assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
        assert!(component["last_error"]
            .as_str()
            .unwrap_or("")
            .contains("boom"));
    }
    #[tokio::test]
    // A component that returns Ok(()) is also treated as a failure.
    async fn supervisor_marks_unexpected_exit_as_error() {
        let handle = spawn_component_supervisor("daemon-test-exit", 1, 1, || async { Ok(()) });
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.abort();
        let _ = handle.await;
        let snapshot = crate::health::snapshot_json();
        let component = &snapshot["components"]["daemon-test-exit"];
        assert_eq!(component["status"], "error");
        assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
        assert!(component["last_error"]
            .as_str()
            .unwrap_or("")
            .contains("component exited unexpectedly"));
    }
    #[test]
    // Default config has no channel that needs supervision.
    fn detects_no_supervised_channels() {
        let config = Config::default();
        assert!(!has_supervised_channels(&config));
    }
    #[test]
    // Configuring telegram makes the channel supervisor necessary.
    fn detects_supervised_channels_present() {
        let mut config = Config::default();
        config.channels_config.telegram = Some(crate::config::TelegramConfig {
            bot_token: "token".into(),
            allowed_users: vec![],
        });
        assert!(has_supervised_channels(&config));
    }
}

123
src/doctor/mod.rs Normal file
View file

@ -0,0 +1,123 @@
use crate::config::Config;
use anyhow::{Context, Result};
use chrono::{DateTime, Utc};
/// Daemon snapshot older than this (seconds) is reported as stale.
const DAEMON_STALE_SECONDS: i64 = 30;
/// Scheduler last-ok older than this (seconds) is reported as unhealthy.
const SCHEDULER_STALE_SECONDS: i64 = 120;
/// Channel last-ok older than this (seconds) is reported as stale.
const CHANNEL_STALE_SECONDS: i64 = 300;
/// Health report CLI: read the daemon's state snapshot and print freshness
/// checks for the daemon heartbeat, the scheduler, and each tracked
/// `channel:*` component.
///
/// A missing state file is not an error — it prints a hint to start the
/// daemon and returns Ok.
///
/// # Errors
/// Fails only if the state file exists but cannot be read or parsed.
pub fn run(config: &Config) -> Result<()> {
    let state_file = crate::daemon::state_file_path(config);
    if !state_file.exists() {
        println!("🩺 ZeroClaw Doctor");
        println!(" ❌ daemon state file not found: {}", state_file.display());
        println!(" 💡 Start daemon with: zeroclaw daemon");
        return Ok(());
    }
    let raw = std::fs::read_to_string(&state_file)
        .with_context(|| format!("Failed to read {}", state_file.display()))?;
    let snapshot: serde_json::Value = serde_json::from_str(&raw)
        .with_context(|| format!("Failed to parse {}", state_file.display()))?;
    println!("🩺 ZeroClaw Doctor");
    println!(" State file: {}", state_file.display());
    // Daemon-level freshness: the state writer stamps "updated_at" regularly,
    // so an old timestamp means the daemon stopped flushing.
    let updated_at = snapshot
        .get("updated_at")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("");
    if let Ok(ts) = DateTime::parse_from_rfc3339(updated_at) {
        let age = Utc::now()
            .signed_duration_since(ts.with_timezone(&Utc))
            .num_seconds();
        if age <= DAEMON_STALE_SECONDS {
            println!(" ✅ daemon heartbeat fresh ({age}s ago)");
        } else {
            println!(" ❌ daemon heartbeat stale ({age}s ago)");
        }
    } else {
        println!(" ❌ invalid daemon timestamp: {updated_at}");
    }
    let mut channel_count = 0_u32;
    let mut stale_channels = 0_u32;
    if let Some(components) = snapshot
        .get("components")
        .and_then(serde_json::Value::as_object)
    {
        if let Some(scheduler) = components.get("scheduler") {
            let scheduler_ok = scheduler
                .get("status")
                .and_then(serde_json::Value::as_str)
                .map(|s| s == "ok")
                .unwrap_or(false);
            // Missing/invalid last_ok degrades to i64::MAX, i.e. "stale".
            let scheduler_last_ok = scheduler
                .get("last_ok")
                .and_then(serde_json::Value::as_str)
                .and_then(parse_rfc3339)
                .map(|dt| Utc::now().signed_duration_since(dt).num_seconds())
                .unwrap_or(i64::MAX);
            if scheduler_ok && scheduler_last_ok <= SCHEDULER_STALE_SECONDS {
                println!(
                    " ✅ scheduler healthy (last ok {}s ago)",
                    scheduler_last_ok
                );
            } else {
                println!(
                    " ❌ scheduler unhealthy/stale (status_ok={}, age={}s)",
                    scheduler_ok, scheduler_last_ok
                );
            }
        } else {
            println!(" ❌ scheduler component missing");
        }
        // Per-channel freshness: components named "channel:<name>".
        for (name, component) in components {
            if !name.starts_with("channel:") {
                continue;
            }
            channel_count += 1;
            let status_ok = component
                .get("status")
                .and_then(serde_json::Value::as_str)
                .map(|s| s == "ok")
                .unwrap_or(false);
            let age = component
                .get("last_ok")
                .and_then(serde_json::Value::as_str)
                .and_then(parse_rfc3339)
                .map(|dt| Utc::now().signed_duration_since(dt).num_seconds())
                .unwrap_or(i64::MAX);
            if status_ok && age <= CHANNEL_STALE_SECONDS {
                println!("{name} fresh (last ok {age}s ago)");
            } else {
                stale_channels += 1;
                println!("{name} stale/unhealthy (status_ok={status_ok}, age={age}s)");
            }
        }
    }
    if channel_count == 0 {
        println!(" no channel components tracked in state yet");
    } else {
        println!(
            " Channel summary: {} total, {} stale",
            channel_count, stale_channels
        );
    }
    Ok(())
}
/// Parse an RFC 3339 timestamp and normalize it to UTC.
/// Returns `None` when `raw` is not valid RFC 3339.
fn parse_rfc3339(raw: &str) -> Option<DateTime<Utc>> {
    let parsed = DateTime::parse_from_rfc3339(raw).ok()?;
    Some(parsed.with_timezone(&Utc))
}

View file

@ -1,15 +1,49 @@
//! Axum-based HTTP gateway with proper HTTP/1.1 compliance, body limits, and timeouts.
//!
//! This module replaces the raw TCP implementation with axum for:
//! - Proper HTTP/1.1 parsing and compliance
//! - Content-Length validation (handled by hyper)
//! - Request body size limits (64KB max)
//! - Request timeouts (30s) to prevent slow-loris attacks
//! - Header sanitization (handled by axum/hyper)
use crate::channels::{Channel, WhatsAppChannel};
use crate::config::Config; use crate::config::Config;
use crate::memory::{self, Memory, MemoryCategory}; use crate::memory::{self, Memory, MemoryCategory};
use crate::providers::{self, Provider}; use crate::providers::{self, Provider};
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
use anyhow::Result; use anyhow::Result;
use axum::{
body::Bytes,
extract::{Query, State},
http::{header, HeaderMap, StatusCode},
response::{IntoResponse, Json},
routing::{get, post},
Router,
};
use std::net::SocketAddr;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use tower_http::limit::RequestBodyLimitLayer;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpListener;
/// Run a minimal HTTP gateway (webhook + health check) /// Maximum request body size (64KB) — prevents memory exhaustion
/// Zero new dependencies — uses raw TCP + tokio. pub const MAX_BODY_SIZE: usize = 65_536;
/// Request timeout (30s) — prevents slow-loris attacks
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
/// Shared state for all axum handlers
#[derive(Clone)]
pub struct AppState {
pub provider: Arc<dyn Provider>,
pub model: String,
pub temperature: f64,
pub mem: Arc<dyn Memory>,
pub auto_save: bool,
pub webhook_secret: Option<Arc<str>>,
pub pairing: Arc<PairingGuard>,
pub whatsapp: Option<Arc<WhatsAppChannel>>,
}
/// Run the HTTP gateway using axum with proper HTTP/1.1 compliance.
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
// ── Security: refuse public bind without tunnel or explicit opt-in ── // ── Security: refuse public bind without tunnel or explicit opt-in ──
@ -22,13 +56,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
); );
} }
let listener = TcpListener::bind(format!("{host}:{port}")).await?; let addr: SocketAddr = format!("{host}:{port}").parse()?;
let listener = tokio::net::TcpListener::bind(addr).await?;
let actual_port = listener.local_addr()?.port(); let actual_port = listener.local_addr()?.port();
let addr = format!("{host}:{actual_port}"); let display_addr = format!("{host}:{actual_port}");
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider( let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"), config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(), config.api_key.as_deref(),
&config.reliability,
)?); )?);
let model = config let model = config
.default_model .default_model
@ -49,6 +85,17 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
.and_then(|w| w.secret.as_deref()) .and_then(|w| w.secret.as_deref())
.map(Arc::from); .map(Arc::from);
// WhatsApp channel (if configured)
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
config.channels_config.whatsapp.as_ref().map(|wa| {
Arc::new(WhatsAppChannel::new(
wa.access_token.clone(),
wa.phone_number_id.clone(),
wa.verify_token.clone(),
wa.allowed_numbers.clone(),
))
});
// ── Pairing guard ────────────────────────────────────── // ── Pairing guard ──────────────────────────────────────
let pairing = Arc::new(PairingGuard::new( let pairing = Arc::new(PairingGuard::new(
config.gateway.require_pairing, config.gateway.require_pairing,
@ -73,16 +120,20 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
} }
} }
println!("🦀 ZeroClaw Gateway listening on http://{addr}"); println!("🦀 ZeroClaw Gateway listening on http://{display_addr}");
if let Some(ref url) = tunnel_url { if let Some(ref url) = tunnel_url {
println!(" 🌐 Public URL: {url}"); println!(" 🌐 Public URL: {url}");
} }
println!(" POST /pair — pair a new client (X-Pairing-Code header)"); println!(" POST /pair — pair a new client (X-Pairing-Code header)");
println!(" POST /webhook — {{\"message\": \"your prompt\"}}"); println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
if whatsapp_channel.is_some() {
println!(" GET /whatsapp — Meta webhook verification");
println!(" POST /whatsapp — WhatsApp message webhook");
}
println!(" GET /health — health check"); println!(" GET /health — health check");
if let Some(code) = pairing.pairing_code() { if let Some(code) = pairing.pairing_code() {
println!(); println!();
println!(" <20> PAIRING REQUIRED — use this one-time code:"); println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
println!(" ┌──────────────┐"); println!(" ┌──────────────┐");
println!("{code}"); println!("{code}");
println!(" └──────────────┘"); println!(" └──────────────┘");
@ -97,92 +148,60 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
} }
println!(" Press Ctrl+C to stop.\n"); println!(" Press Ctrl+C to stop.\n");
loop { crate::health::mark_component_ok("gateway");
let (mut stream, peer) = listener.accept().await?;
let provider = provider.clone();
let model = model.clone();
let mem = mem.clone();
let auto_save = config.memory.auto_save;
let secret = webhook_secret.clone();
let pairing = pairing.clone();
tokio::spawn(async move { // Build shared state
// Read with 30s timeout to prevent slow-loris attacks let state = AppState {
let mut buf = vec![0u8; 65_536]; // 64KB max request provider,
let n = match tokio::time::timeout(Duration::from_secs(30), stream.read(&mut buf)).await model,
{ temperature,
Ok(Ok(n)) if n > 0 => n, mem,
_ => return, auto_save: config.memory.auto_save,
webhook_secret,
pairing,
whatsapp: whatsapp_channel,
}; };
let request = String::from_utf8_lossy(&buf[..n]); // Build router with middleware
let first_line = request.lines().next().unwrap_or(""); // Note: Body limit layer prevents memory exhaustion from oversized requests
let parts: Vec<&str> = first_line.split_whitespace().collect(); // Timeout is handled by tokio's TcpListener accept timeout and hyper's built-in timeouts
let app = Router::new()
.route("/health", get(handle_health))
.route("/pair", post(handle_pair))
.route("/webhook", post(handle_webhook))
.route("/whatsapp", get(handle_whatsapp_verify))
.route("/whatsapp", post(handle_whatsapp_message))
.with_state(state)
.layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE));
if let [method, path, ..] = parts.as_slice() { // Run the server
tracing::info!("{peer} → {method} {path}"); axum::serve(listener, app).await?;
handle_request(
&mut stream, Ok(())
method,
path,
&request,
&provider,
&model,
temperature,
&mem,
auto_save,
secret.as_ref(),
&pairing,
)
.await;
} else {
let _ = send_response(&mut stream, 400, "Bad Request").await;
}
});
}
} }
/// Extract a header value from a raw HTTP request. // ══════════════════════════════════════════════════════════════════════════════
fn extract_header<'a>(request: &'a str, header_name: &str) -> Option<&'a str> { // AXUM HANDLERS
let lower_name = header_name.to_lowercase(); // ══════════════════════════════════════════════════════════════════════════════
for line in request.lines() {
if let Some((key, value)) = line.split_once(':') {
if key.trim().to_lowercase() == lower_name {
return Some(value.trim());
}
}
}
None
}
#[allow(clippy::too_many_arguments)] /// GET /health — always public (no secrets leaked)
async fn handle_request( async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
stream: &mut tokio::net::TcpStream,
method: &str,
path: &str,
request: &str,
provider: &Arc<dyn Provider>,
model: &str,
temperature: f64,
mem: &Arc<dyn Memory>,
auto_save: bool,
webhook_secret: Option<&Arc<str>>,
pairing: &PairingGuard,
) {
match (method, path) {
// Health check — always public (no secrets leaked)
("GET", "/health") => {
let body = serde_json::json!({ let body = serde_json::json!({
"status": "ok", "status": "ok",
"paired": pairing.is_paired(), "paired": state.pairing.is_paired(),
"runtime": crate::health::snapshot_json(),
}); });
let _ = send_json(stream, 200, &body).await; Json(body)
} }
// Pairing endpoint — exchange one-time code for bearer token /// POST /pair — exchange one-time code for bearer token
("POST", "/pair") => { async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
let code = extract_header(request, "X-Pairing-Code").unwrap_or(""); let code = headers
match pairing.try_pair(code) { .get("X-Pairing-Code")
.and_then(|v| v.to_str().ok())
.unwrap_or("");
match state.pairing.try_pair(code) {
Ok(Some(token)) => { Ok(Some(token)) => {
tracing::info!("🔐 New client paired successfully"); tracing::info!("🔐 New client paired successfully");
let body = serde_json::json!({ let body = serde_json::json!({
@ -190,12 +209,12 @@ async fn handle_request(
"token": token, "token": token,
"message": "Save this token — use it as Authorization: Bearer <token>" "message": "Save this token — use it as Authorization: Bearer <token>"
}); });
let _ = send_json(stream, 200, &body).await; (StatusCode::OK, Json(body))
} }
Ok(None) => { Ok(None) => {
tracing::warn!("🔐 Pairing attempt with invalid code"); tracing::warn!("🔐 Pairing attempt with invalid code");
let err = serde_json::json!({"error": "Invalid pairing code"}); let err = serde_json::json!({"error": "Invalid pairing code"});
let _ = send_json(stream, 403, &err).await; (StatusCode::FORBIDDEN, Json(err))
} }
Err(lockout_secs) => { Err(lockout_secs) => {
tracing::warn!( tracing::warn!(
@ -205,320 +224,236 @@ async fn handle_request(
"error": format!("Too many failed attempts. Try again in {lockout_secs}s."), "error": format!("Too many failed attempts. Try again in {lockout_secs}s."),
"retry_after": lockout_secs "retry_after": lockout_secs
}); });
let _ = send_json(stream, 429, &err).await; (StatusCode::TOO_MANY_REQUESTS, Json(err))
} }
} }
} }
("POST", "/webhook") => { /// Webhook request body
#[derive(serde::Deserialize)]
pub struct WebhookBody {
pub message: String,
}
/// POST /webhook — main webhook endpoint
async fn handle_webhook(
State(state): State<AppState>,
headers: HeaderMap,
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
) -> impl IntoResponse {
// ── Bearer token auth (pairing) ── // ── Bearer token auth (pairing) ──
if pairing.require_pairing() { if state.pairing.require_pairing() {
let auth = extract_header(request, "Authorization").unwrap_or(""); let auth = headers
.get(header::AUTHORIZATION)
.and_then(|v| v.to_str().ok())
.unwrap_or("");
let token = auth.strip_prefix("Bearer ").unwrap_or(""); let token = auth.strip_prefix("Bearer ").unwrap_or("");
if !pairing.is_authenticated(token) { if !state.pairing.is_authenticated(token) {
tracing::warn!("Webhook: rejected — not paired / invalid bearer token"); tracing::warn!("Webhook: rejected — not paired / invalid bearer token");
let err = serde_json::json!({ let err = serde_json::json!({
"error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer <token>" "error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer <token>"
}); });
let _ = send_json(stream, 401, &err).await; return (StatusCode::UNAUTHORIZED, Json(err));
return;
} }
} }
// ── Webhook secret auth (optional, additional layer) ── // ── Webhook secret auth (optional, additional layer) ──
if let Some(secret) = webhook_secret { if let Some(ref secret) = state.webhook_secret {
let header_val = extract_header(request, "X-Webhook-Secret"); let header_val = headers
.get("X-Webhook-Secret")
.and_then(|v| v.to_str().ok());
match header_val { match header_val {
Some(val) if constant_time_eq(val, secret.as_ref()) => {} Some(val) if constant_time_eq(val, secret.as_ref()) => {}
_ => { _ => {
tracing::warn!( tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
"Webhook: rejected request — invalid or missing X-Webhook-Secret"
);
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"}); let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
let _ = send_json(stream, 401, &err).await; return (StatusCode::UNAUTHORIZED, Json(err));
return;
} }
} }
} }
handle_webhook(
stream,
request,
provider,
model,
temperature,
mem,
auto_save,
)
.await;
}
_ => { // ── Parse body ──
let body = serde_json::json!({ let Json(webhook_body) = match body {
"error": "Not found", Ok(b) => b,
"routes": ["GET /health", "POST /pair", "POST /webhook"] Err(e) => {
let err = serde_json::json!({
"error": format!("Invalid JSON: {e}. Expected: {{\"message\": \"...\"}}")
}); });
let _ = send_json(stream, 404, &body).await; return (StatusCode::BAD_REQUEST, Json(err));
} }
}
}
async fn handle_webhook(
stream: &mut tokio::net::TcpStream,
request: &str,
provider: &Arc<dyn Provider>,
model: &str,
temperature: f64,
mem: &Arc<dyn Memory>,
auto_save: bool,
) {
let body_str = request
.split("\r\n\r\n")
.nth(1)
.or_else(|| request.split("\n\n").nth(1))
.unwrap_or("");
let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body_str) else {
let err = serde_json::json!({"error": "Invalid JSON. Expected: {\"message\": \"...\"}"});
let _ = send_json(stream, 400, &err).await;
return;
}; };
let Some(message) = parsed.get("message").and_then(|v| v.as_str()) else { let message = &webhook_body.message;
let err = serde_json::json!({"error": "Missing 'message' field in JSON"});
let _ = send_json(stream, 400, &err).await;
return;
};
if auto_save { if state.auto_save {
let _ = mem let _ = state
.mem
.store("webhook_msg", message, MemoryCategory::Conversation) .store("webhook_msg", message, MemoryCategory::Conversation)
.await; .await;
} }
match provider.chat(message, model, temperature).await { match state
.provider
.chat(message, &state.model, state.temperature)
.await
{
Ok(response) => { Ok(response) => {
let body = serde_json::json!({"response": response, "model": model}); let body = serde_json::json!({"response": response, "model": state.model});
let _ = send_json(stream, 200, &body).await; (StatusCode::OK, Json(body))
} }
Err(e) => { Err(e) => {
let err = serde_json::json!({"error": format!("LLM error: {e}")}); let err = serde_json::json!({"error": format!("LLM error: {e}")});
let _ = send_json(stream, 500, &err).await; (StatusCode::INTERNAL_SERVER_ERROR, Json(err))
} }
} }
} }
async fn send_response( /// `WhatsApp` verification query params
stream: &mut tokio::net::TcpStream, #[derive(serde::Deserialize)]
status: u16, pub struct WhatsAppVerifyQuery {
body: &str, #[serde(rename = "hub.mode")]
) -> std::io::Result<()> { pub mode: Option<String>,
let reason = match status { #[serde(rename = "hub.verify_token")]
200 => "OK", pub verify_token: Option<String>,
400 => "Bad Request", #[serde(rename = "hub.challenge")]
404 => "Not Found", pub challenge: Option<String>,
500 => "Internal Server Error",
_ => "Unknown",
};
let response = format!(
"HTTP/1.1 {status} {reason}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
body.len()
);
stream.write_all(response.as_bytes()).await
} }
async fn send_json( /// GET /whatsapp — Meta webhook verification
stream: &mut tokio::net::TcpStream, async fn handle_whatsapp_verify(
status: u16, State(state): State<AppState>,
body: &serde_json::Value, Query(params): Query<WhatsAppVerifyQuery>,
) -> std::io::Result<()> { ) -> impl IntoResponse {
let reason = match status { let Some(ref wa) = state.whatsapp else {
200 => "OK", return (StatusCode::NOT_FOUND, "WhatsApp not configured".to_string());
400 => "Bad Request",
404 => "Not Found",
500 => "Internal Server Error",
_ => "Unknown",
}; };
let json = serde_json::to_string(body).unwrap_or_default();
let response = format!( // Verify the token matches
"HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{json}", if params.mode.as_deref() == Some("subscribe")
json.len() && params.verify_token.as_deref() == Some(wa.verify_token())
{
if let Some(ch) = params.challenge {
tracing::info!("WhatsApp webhook verified successfully");
return (StatusCode::OK, ch);
}
return (StatusCode::BAD_REQUEST, "Missing hub.challenge".to_string());
}
tracing::warn!("WhatsApp webhook verification failed — token mismatch");
(StatusCode::FORBIDDEN, "Forbidden".to_string())
}
/// POST /whatsapp — incoming message webhook
async fn handle_whatsapp_message(State(state): State<AppState>, body: Bytes) -> impl IntoResponse {
let Some(ref wa) = state.whatsapp else {
return (
StatusCode::NOT_FOUND,
Json(serde_json::json!({"error": "WhatsApp not configured"})),
); );
stream.write_all(response.as_bytes()).await };
// Parse JSON body
let Ok(payload) = serde_json::from_slice::<serde_json::Value>(&body) else {
return (
StatusCode::BAD_REQUEST,
Json(serde_json::json!({"error": "Invalid JSON payload"})),
);
};
// Parse messages from the webhook payload
let messages = wa.parse_webhook_payload(&payload);
if messages.is_empty() {
// Acknowledge the webhook even if no messages (could be status updates)
return (StatusCode::OK, Json(serde_json::json!({"status": "ok"})));
}
// Process each message
for msg in &messages {
tracing::info!(
"WhatsApp message from {}: {}",
msg.sender,
if msg.content.len() > 50 {
format!("{}...", &msg.content[..50])
} else {
msg.content.clone()
}
);
// Auto-save to memory
if state.auto_save {
let _ = state
.mem
.store(
&format!("whatsapp_{}", msg.sender),
&msg.content,
MemoryCategory::Conversation,
)
.await;
}
// Call the LLM
match state
.provider
.chat(&msg.content, &state.model, state.temperature)
.await
{
Ok(response) => {
// Send reply via WhatsApp
if let Err(e) = wa.send(&response, &msg.sender).await {
tracing::error!("Failed to send WhatsApp reply: {e}");
}
}
Err(e) => {
tracing::error!("LLM error for WhatsApp message: {e}");
let _ = wa.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
}
}
}
// Acknowledge the webhook
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
} }
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
use tokio::net::TcpListener as TokioListener;
// ── Port allocation tests ────────────────────────────────
#[tokio::test]
async fn port_zero_binds_to_random_port() {
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
let actual = listener.local_addr().unwrap().port();
assert_ne!(actual, 0, "OS must assign a non-zero port");
assert!(actual > 0, "Actual port must be positive");
}
#[tokio::test]
async fn port_zero_assigns_different_ports() {
let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap();
let l2 = TokioListener::bind("127.0.0.1:0").await.unwrap();
let p1 = l1.local_addr().unwrap().port();
let p2 = l2.local_addr().unwrap().port();
assert_ne!(p1, p2, "Two port-0 binds should get different ports");
}
#[tokio::test]
async fn port_zero_assigns_high_port() {
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
let actual = listener.local_addr().unwrap().port();
// OS typically assigns ephemeral ports >= 1024
assert!(
actual >= 1024,
"Random port {actual} should be >= 1024 (unprivileged)"
);
}
#[tokio::test]
async fn specific_port_binds_exactly() {
// Find a free port first via port 0, then rebind to it
let tmp = TokioListener::bind("127.0.0.1:0").await.unwrap();
let free_port = tmp.local_addr().unwrap().port();
drop(tmp);
let listener = TokioListener::bind(format!("127.0.0.1:{free_port}"))
.await
.unwrap();
let actual = listener.local_addr().unwrap().port();
assert_eq!(actual, free_port, "Specific port bind must match exactly");
}
#[tokio::test]
async fn actual_port_matches_addr_format() {
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
let actual_port = listener.local_addr().unwrap().port();
let addr = format!("127.0.0.1:{actual_port}");
assert!(
addr.starts_with("127.0.0.1:"),
"Addr format must include host"
);
assert!(
!addr.ends_with(":0"),
"Addr must not contain port 0 after binding"
);
}
#[tokio::test]
async fn port_zero_listener_accepts_connections() {
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
let actual_port = listener.local_addr().unwrap().port();
// Spawn a client that connects
let client = tokio::spawn(async move {
tokio::net::TcpStream::connect(format!("127.0.0.1:{actual_port}"))
.await
.unwrap()
});
// Accept the connection
let (stream, _peer) = listener.accept().await.unwrap();
assert!(stream.peer_addr().is_ok());
client.await.unwrap();
}
#[tokio::test]
async fn duplicate_specific_port_fails() {
let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap();
let port = l1.local_addr().unwrap().port();
// Try to bind the same port while l1 is still alive
let result = TokioListener::bind(format!("127.0.0.1:{port}")).await;
assert!(result.is_err(), "Binding an already-used port must fail");
}
#[tokio::test]
async fn tunnel_gets_actual_port_not_zero() {
// Simulate what run_gateway does: bind port 0, extract actual port
let port: u16 = 0;
let host = "127.0.0.1";
let listener = TokioListener::bind(format!("{host}:{port}")).await.unwrap();
let actual_port = listener.local_addr().unwrap().port();
// This is the port that would be passed to tun.start(host, actual_port)
assert_ne!(actual_port, 0, "Tunnel must receive actual port, not 0");
assert!(
actual_port >= 1024,
"Tunnel port {actual_port} must be unprivileged"
);
}
// ── extract_header tests ─────────────────────────────────
#[test] #[test]
fn extract_header_finds_value() { fn security_body_limit_is_64kb() {
let req = assert_eq!(MAX_BODY_SIZE, 65_536);
"POST /webhook HTTP/1.1\r\nHost: localhost\r\nX-Webhook-Secret: my-secret\r\n\r\n{}";
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("my-secret"));
} }
#[test] #[test]
fn extract_header_case_insensitive() { fn security_timeout_is_30_seconds() {
let req = "POST /webhook HTTP/1.1\r\nx-webhook-secret: abc123\r\n\r\n{}"; assert_eq!(REQUEST_TIMEOUT_SECS, 30);
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("abc123"));
} }
#[test] #[test]
fn extract_header_missing_returns_none() { fn webhook_body_requires_message_field() {
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\n\r\n{}"; let valid = r#"{"message": "hello"}"#;
assert_eq!(extract_header(req, "X-Webhook-Secret"), None); let parsed: Result<WebhookBody, _> = serde_json::from_str(valid);
assert!(parsed.is_ok());
assert_eq!(parsed.unwrap().message, "hello");
let missing = r#"{"other": "field"}"#;
let parsed: Result<WebhookBody, _> = serde_json::from_str(missing);
assert!(parsed.is_err());
} }
#[test] #[test]
fn extract_header_trims_whitespace() { fn whatsapp_query_fields_are_optional() {
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: spaced \r\n\r\n{}"; let q = WhatsAppVerifyQuery {
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("spaced")); mode: None,
verify_token: None,
challenge: None,
};
assert!(q.mode.is_none());
} }
#[test] #[test]
fn extract_header_first_match_wins() { fn app_state_is_clone() {
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: first\r\nX-Webhook-Secret: second\r\n\r\n{}"; fn assert_clone<T: Clone>() {}
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("first")); assert_clone::<AppState>();
}
#[test]
fn extract_header_empty_value() {
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret:\r\n\r\n{}";
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some(""));
}
#[test]
fn extract_header_colon_in_value() {
let req = "POST /webhook HTTP/1.1\r\nAuthorization: Bearer sk-abc:123\r\n\r\n{}";
// split_once on ':' means only the first colon splits key/value
assert_eq!(
extract_header(req, "Authorization"),
Some("Bearer sk-abc:123")
);
}
#[test]
fn extract_header_different_header() {
let req = "POST /webhook HTTP/1.1\r\nContent-Type: application/json\r\nX-Webhook-Secret: mysecret\r\n\r\n{}";
assert_eq!(
extract_header(req, "Content-Type"),
Some("application/json")
);
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("mysecret"));
}
#[test]
fn extract_header_from_empty_request() {
assert_eq!(extract_header("", "X-Webhook-Secret"), None);
}
#[test]
fn extract_header_newline_only_request() {
assert_eq!(extract_header("\r\n\r\n", "X-Webhook-Secret"), None);
} }
} }

105
src/health/mod.rs Normal file
View file

@ -0,0 +1,105 @@
use chrono::Utc;
use serde::Serialize;
use std::collections::BTreeMap;
use std::sync::{Mutex, OnceLock};
use std::time::Instant;
/// Health record for one named runtime component (e.g. gateway, scheduler,
/// `channel:*` entries), serialized into the `/health` snapshot.
#[derive(Debug, Clone, Serialize)]
pub struct ComponentHealth {
    // "starting" on first registration, then "ok" or "error".
    pub status: String,
    // RFC 3339 timestamp of the most recent update to this record.
    pub updated_at: String,
    // RFC 3339 timestamp of the last successful check, if any.
    pub last_ok: Option<String>,
    // Message from the most recent failure; cleared when the component
    // reports ok again.
    pub last_error: Option<String>,
    // Times the component was restarted (monotonic, saturating add).
    pub restart_count: u64,
}
/// Point-in-time view of process health: pid, uptime, and a per-component
/// map; produced by `snapshot()` / `snapshot_json()`.
#[derive(Debug, Clone, Serialize)]
pub struct HealthSnapshot {
    // PID of the current process.
    pub pid: u32,
    // RFC 3339 timestamp at which this snapshot was taken.
    pub updated_at: String,
    // Seconds elapsed since the registry was first initialized.
    pub uptime_seconds: u64,
    // Component name -> health record; BTreeMap keeps key order stable.
    pub components: BTreeMap<String, ComponentHealth>,
}
/// Process-wide mutable health state backing the public `mark_*` and
/// `snapshot` helpers; a single instance lives in the `REGISTRY` static.
struct HealthRegistry {
    // Captured when the registry is first accessed; used for uptime.
    started_at: Instant,
    // Guarded map of component name -> health record.
    components: Mutex<BTreeMap<String, ComponentHealth>>,
}
/// Lazily-initialized global health registry (created on first access).
static REGISTRY: OnceLock<HealthRegistry> = OnceLock::new();
/// Return the global registry, initializing it on first call.
/// Note: `started_at` is captured here, so uptime counts from first use
/// of the health module, not from process start.
fn registry() -> &'static HealthRegistry {
    REGISTRY.get_or_init(|| HealthRegistry {
        started_at: Instant::now(),
        components: Mutex::new(BTreeMap::new()),
    })
}
/// Current UTC time formatted as an RFC 3339 string.
fn now_rfc3339() -> String {
    let now = Utc::now();
    now.to_rfc3339()
}
/// Insert-or-update the health record for `component`, then refresh its
/// `updated_at` stamp. New components start in the "starting" state.
/// A poisoned registry lock is treated as "skip the bookkeeping" rather
/// than a panic — health reporting must never take the process down.
fn upsert_component<F>(component: &str, update: F)
where
    F: FnOnce(&mut ComponentHealth),
{
    let Ok(mut components) = registry().components.lock() else {
        return;
    };
    let stamp = now_rfc3339();
    let entry = components
        .entry(component.to_string())
        .or_insert_with(|| ComponentHealth {
            status: "starting".into(),
            updated_at: stamp.clone(),
            last_ok: None,
            last_error: None,
            restart_count: 0,
        });
    update(entry);
    entry.updated_at = stamp;
}
/// Record a successful check for `component`: status becomes "ok",
/// `last_ok` is stamped with the current time, and any previous error
/// message is cleared.
pub fn mark_component_ok(component: &str) {
    upsert_component(component, |entry| {
        entry.last_error = None;
        entry.last_ok = Some(now_rfc3339());
        entry.status = "ok".into();
    });
}
/// Record a failure for `component`: status becomes "error" and the
/// stringified `error` is stored as `last_error`. The previous `last_ok`
/// timestamp is left untouched.
pub fn mark_component_error(component: &str, error: impl ToString) {
    let message = error.to_string();
    upsert_component(component, move |entry| {
        entry.last_error = Some(message);
        entry.status = "error".into();
    });
}
/// Increment the restart counter for `component` (saturates at `u64::MAX`).
pub fn bump_component_restart(component: &str) {
    upsert_component(component, |entry| {
        let next = entry.restart_count.saturating_add(1);
        entry.restart_count = next;
    });
}
/// Build a point-in-time health snapshot of the whole process.
/// A poisoned component lock yields an empty component map instead of
/// propagating the panic.
pub fn snapshot() -> HealthSnapshot {
    let reg = registry();
    let components = match reg.components.lock() {
        Ok(guard) => guard.clone(),
        Err(_) => BTreeMap::new(),
    };
    HealthSnapshot {
        pid: std::process::id(),
        updated_at: now_rfc3339(),
        uptime_seconds: reg.started_at.elapsed().as_secs(),
        components,
    }
}
/// Health snapshot as a JSON value; on (unlikely) serialization failure,
/// returns a small error object instead of panicking.
pub fn snapshot_json() -> serde_json::Value {
    match serde_json::to_value(snapshot()) {
        Ok(value) => value,
        Err(_) => serde_json::json!({
            "status": "error",
            "message": "failed to serialize health snapshot"
        }),
    }
}

View file

@ -61,16 +61,17 @@ impl HeartbeatEngine {
/// Single heartbeat tick — read HEARTBEAT.md and return task count /// Single heartbeat tick — read HEARTBEAT.md and return task count
async fn tick(&self) -> Result<usize> { async fn tick(&self) -> Result<usize> {
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md"); Ok(self.collect_tasks().await?.len())
if !heartbeat_path.exists() {
return Ok(0);
} }
/// Read HEARTBEAT.md and return all parsed tasks.
pub async fn collect_tasks(&self) -> Result<Vec<String>> {
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md");
if !heartbeat_path.exists() {
return Ok(Vec::new());
}
let content = tokio::fs::read_to_string(&heartbeat_path).await?; let content = tokio::fs::read_to_string(&heartbeat_path).await?;
let tasks = Self::parse_tasks(&content); Ok(Self::parse_tasks(&content))
Ok(tasks.len())
} }
/// Parse tasks from HEARTBEAT.md (lines starting with `- `) /// Parse tasks from HEARTBEAT.md (lines starting with `- `)

View file

@ -8,7 +8,7 @@
dead_code dead_code
)] )]
use anyhow::Result; use anyhow::{bail, Result};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use tracing::{info, Level}; use tracing::{info, Level};
use tracing_subscriber::FmtSubscriber; use tracing_subscriber::FmtSubscriber;
@ -17,15 +17,20 @@ mod agent;
mod channels; mod channels;
mod config; mod config;
mod cron; mod cron;
mod daemon;
mod doctor;
mod gateway; mod gateway;
mod health;
mod heartbeat; mod heartbeat;
mod integrations; mod integrations;
mod memory; mod memory;
mod migration;
mod observability; mod observability;
mod onboard; mod onboard;
mod providers; mod providers;
mod runtime; mod runtime;
mod security; mod security;
mod service;
mod skills; mod skills;
mod tools; mod tools;
mod tunnel; mod tunnel;
@ -43,6 +48,20 @@ struct Cli {
command: Commands, command: Commands,
} }
#[derive(Subcommand, Debug)]
enum ServiceCommands {
/// Install daemon service unit for auto-start and restart
Install,
/// Start daemon service
Start,
/// Stop daemon service
Stop,
/// Check daemon service status
Status,
/// Uninstall daemon service unit
Uninstall,
}
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
enum Commands { enum Commands {
/// Initialize your workspace and configuration /// Initialize your workspace and configuration
@ -51,6 +70,10 @@ enum Commands {
#[arg(long)] #[arg(long)]
interactive: bool, interactive: bool,
/// Reconfigure channels only (fast repair flow)
#[arg(long)]
channels_only: bool,
/// API key (used in quick mode, ignored with --interactive) /// API key (used in quick mode, ignored with --interactive)
#[arg(long)] #[arg(long)]
api_key: Option<String>, api_key: Option<String>,
@ -71,7 +94,7 @@ enum Commands {
provider: Option<String>, provider: Option<String>,
/// Model to use /// Model to use
#[arg(short, long)] #[arg(long)]
model: Option<String>, model: Option<String>,
/// Temperature (0.0 - 2.0) /// Temperature (0.0 - 2.0)
@ -86,10 +109,30 @@ enum Commands {
port: u16, port: u16,
/// Host to bind to /// Host to bind to
#[arg(short, long, default_value = "127.0.0.1")] #[arg(long, default_value = "127.0.0.1")]
host: String, host: String,
}, },
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
Daemon {
/// Port to listen on (use 0 for random available port)
#[arg(short, long, default_value = "8080")]
port: u16,
/// Host to bind to
#[arg(long, default_value = "127.0.0.1")]
host: String,
},
/// Manage OS service lifecycle (launchd/systemd user service)
Service {
#[command(subcommand)]
service_command: ServiceCommands,
},
/// Run diagnostics for daemon/scheduler/channel freshness
Doctor,
/// Show system status (full details) /// Show system status (full details)
Status, Status,
@ -116,6 +159,26 @@ enum Commands {
#[command(subcommand)] #[command(subcommand)]
skill_command: SkillCommands, skill_command: SkillCommands,
}, },
/// Migrate data from other agent runtimes
Migrate {
#[command(subcommand)]
migrate_command: MigrateCommands,
},
}
#[derive(Subcommand, Debug)]
enum MigrateCommands {
/// Import memory from an OpenClaw workspace into this ZeroClaw workspace
Openclaw {
/// Optional path to OpenClaw workspace (defaults to ~/.openclaw/workspace)
#[arg(long)]
source: Option<std::path::PathBuf>,
/// Validate and preview migration without writing any data
#[arg(long)]
dry_run: bool,
},
} }
#[derive(Subcommand, Debug)] #[derive(Subcommand, Debug)]
@ -198,11 +261,21 @@ async fn main() -> Result<()> {
// Onboard runs quick setup by default, or the interactive wizard with --interactive // Onboard runs quick setup by default, or the interactive wizard with --interactive
if let Commands::Onboard { if let Commands::Onboard {
interactive, interactive,
channels_only,
api_key, api_key,
provider, provider,
} = &cli.command } = &cli.command
{ {
let config = if *interactive { if *interactive && *channels_only {
bail!("Use either --interactive or --channels-only, not both");
}
if *channels_only && (api_key.is_some() || provider.is_some()) {
bail!("--channels-only does not accept --api-key or --provider");
}
let config = if *channels_only {
onboard::run_channels_repair_wizard()?
} else if *interactive {
onboard::run_wizard()? onboard::run_wizard()?
} else { } else {
onboard::run_quick_setup(api_key.as_deref(), provider.as_deref())? onboard::run_quick_setup(api_key.as_deref(), provider.as_deref())?
@ -236,6 +309,15 @@ async fn main() -> Result<()> {
gateway::run_gateway(&host, port, config).await gateway::run_gateway(&host, port, config).await
} }
Commands::Daemon { port, host } => {
if port == 0 {
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
} else {
info!("🧠 Starting ZeroClaw Daemon on {host}:{port}");
}
daemon::run(config, host, port).await
}
Commands::Status => { Commands::Status => {
println!("🦀 ZeroClaw Status"); println!("🦀 ZeroClaw Status");
println!(); println!();
@ -307,6 +389,10 @@ async fn main() -> Result<()> {
Commands::Cron { cron_command } => cron::handle_command(cron_command, config), Commands::Cron { cron_command } => cron::handle_command(cron_command, config),
Commands::Service { service_command } => service::handle_command(service_command, &config),
Commands::Doctor => doctor::run(&config),
Commands::Channel { channel_command } => match channel_command { Commands::Channel { channel_command } => match channel_command {
ChannelCommands::Start => channels::start_channels(config).await, ChannelCommands::Start => channels::start_channels(config).await,
ChannelCommands::Doctor => channels::doctor_channels(config).await, ChannelCommands::Doctor => channels::doctor_channels(config).await,
@ -320,5 +406,20 @@ async fn main() -> Result<()> {
Commands::Skills { skill_command } => { Commands::Skills { skill_command } => {
skills::handle_command(skill_command, &config.workspace_dir) skills::handle_command(skill_command, &config.workspace_dir)
} }
Commands::Migrate { migrate_command } => {
migration::handle_command(migrate_command, &config).await
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use clap::CommandFactory;
#[test]
fn cli_definition_has_no_flag_conflicts() {
Cli::command().debug_assert();
} }
} }

538
src/memory/hygiene.rs Normal file
View file

@ -0,0 +1,538 @@
use crate::config::MemoryConfig;
use anyhow::Result;
use chrono::{DateTime, Duration, Local, NaiveDate, Utc};
use rusqlite::{params, Connection};
use serde::{Deserialize, Serialize};
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{Duration as StdDuration, SystemTime};
const HYGIENE_INTERVAL_HOURS: i64 = 12;
const STATE_FILE: &str = "memory_hygiene_state.json";
/// Counters describing what a single hygiene pass actually did.
/// Serialized inside [`HygieneState`] so the last run is inspectable on disk.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct HygieneReport {
    // Daily memory markdown files moved into memory/archive.
    archived_memory_files: u64,
    // Session files moved into sessions/archive.
    archived_session_files: u64,
    // Archived memory files deleted for good.
    purged_memory_archives: u64,
    // Archived session files deleted for good.
    purged_session_archives: u64,
    // Conversation-category rows deleted from the SQLite brain.
    pruned_conversation_rows: u64,
}
impl HygieneReport {
    /// Sum of every counter; a non-zero total means the pass changed something.
    fn total_actions(&self) -> u64 {
        [
            self.archived_memory_files,
            self.archived_session_files,
            self.purged_memory_archives,
            self.purged_session_archives,
            self.pruned_conversation_rows,
        ]
        .iter()
        .sum()
    }
}
/// Persistent throttle state written to `state/memory_hygiene_state.json`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
struct HygieneState {
    // RFC3339 timestamp of the last completed pass; None forces a run.
    last_run_at: Option<String>,
    // What the last pass did (for diagnostics).
    last_report: HygieneReport,
}
/// Run memory/session hygiene if the cadence window has elapsed.
///
/// This function is intentionally best-effort: callers should log and continue on failure.
pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
    // Feature gate: hygiene can be switched off entirely in config.
    if !config.hygiene_enabled {
        return Ok(());
    }
    // Throttle: at most one pass per HYGIENE_INTERVAL_HOURS, tracked in a state file.
    if !should_run_now(workspace_dir)? {
        return Ok(());
    }
    // Each step is independently gated by its config threshold (0 disables it).
    let report = HygieneReport {
        archived_memory_files: archive_daily_memory_files(
            workspace_dir,
            config.archive_after_days,
        )?,
        archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?,
        purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?,
        purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?,
        pruned_conversation_rows: prune_conversation_rows(
            workspace_dir,
            config.conversation_retention_days,
        )?,
    };
    // Persist the timestamp even for a no-op pass so the throttle still applies.
    write_state(workspace_dir, &report)?;
    if report.total_actions() > 0 {
        tracing::info!(
            "memory hygiene complete: archived_memory={} archived_sessions={} purged_memory={} purged_sessions={} pruned_conversation_rows={}",
            report.archived_memory_files,
            report.archived_session_files,
            report.purged_memory_archives,
            report.purged_session_archives,
            report.pruned_conversation_rows,
        );
    }
    Ok(())
}
/// Decide whether the cadence window has elapsed since the last recorded run.
///
/// Missing, unreadable-JSON, or malformed state always permits a run; only a
/// filesystem read error propagates.
fn should_run_now(workspace_dir: &Path) -> Result<bool> {
    let path = state_path(workspace_dir);
    if !path.exists() {
        return Ok(true);
    }
    let raw = fs::read_to_string(&path)?;
    let Ok(state) = serde_json::from_str::<HygieneState>(&raw) else {
        return Ok(true);
    };
    let Some(stamp) = state.last_run_at else {
        return Ok(true);
    };
    let Ok(parsed) = DateTime::parse_from_rfc3339(&stamp) else {
        return Ok(true);
    };
    let elapsed = Utc::now().signed_duration_since(parsed.with_timezone(&Utc));
    Ok(elapsed >= Duration::hours(HYGIENE_INTERVAL_HOURS))
}
/// Persist a fresh timestamp plus the report so the next run can throttle itself.
fn write_state(workspace_dir: &Path, report: &HygieneReport) -> Result<()> {
    let path = state_path(workspace_dir);
    // Ensure the state/ directory exists before writing.
    if let Some(parent) = path.parent() {
        fs::create_dir_all(parent)?;
    }
    let payload = serde_json::to_vec_pretty(&HygieneState {
        last_run_at: Some(Utc::now().to_rfc3339()),
        last_report: report.clone(),
    })?;
    fs::write(path, payload)?;
    Ok(())
}
/// Location of the hygiene throttle/state file under `<workspace>/state/`.
fn state_path(workspace_dir: &Path) -> PathBuf {
    let mut path = workspace_dir.to_path_buf();
    path.push("state");
    path.push(STATE_FILE);
    path
}
fn archive_daily_memory_files(workspace_dir: &Path, archive_after_days: u32) -> Result<u64> {
if archive_after_days == 0 {
return Ok(0);
}
let memory_dir = workspace_dir.join("memory");
if !memory_dir.is_dir() {
return Ok(0);
}
let archive_dir = memory_dir.join("archive");
fs::create_dir_all(&archive_dir)?;
let cutoff = Local::now().date_naive() - Duration::days(i64::from(archive_after_days));
let mut moved = 0_u64;
for entry in fs::read_dir(&memory_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
continue;
}
if path.extension().and_then(|e| e.to_str()) != Some("md") {
continue;
}
let Some(filename) = path.file_name().and_then(|f| f.to_str()) else {
continue;
};
let Some(file_date) = memory_date_from_filename(filename) else {
continue;
};
if file_date < cutoff {
move_to_archive(&path, &archive_dir)?;
moved += 1;
}
}
Ok(moved)
}
/// Move stale session files into `sessions/archive/`. Returns how many moved.
///
/// Files with a `YYYY-MM-DD` name prefix are judged by that date; everything
/// else falls back to the filesystem modification time.
fn archive_session_files(workspace_dir: &Path, archive_after_days: u32) -> Result<u64> {
    if archive_after_days == 0 {
        return Ok(0);
    }
    let sessions_dir = workspace_dir.join("sessions");
    if !sessions_dir.is_dir() {
        return Ok(0);
    }
    let archive_dir = sessions_dir.join("archive");
    fs::create_dir_all(&archive_dir)?;
    let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(archive_after_days));
    // mtime cutoff; saturates to the epoch if the subtraction would underflow.
    let cutoff_time = SystemTime::now()
        .checked_sub(StdDuration::from_secs(
            u64::from(archive_after_days) * 24 * 60 * 60,
        ))
        .unwrap_or(SystemTime::UNIX_EPOCH);
    let mut archived = 0_u64;
    for entry in fs::read_dir(&sessions_dir)? {
        let path = entry?.path();
        if path.is_dir() {
            continue;
        }
        let Some(name) = path.file_name().and_then(|f| f.to_str()) else {
            continue;
        };
        let stale = match date_prefix(name) {
            Some(date) => date < cutoff_date,
            None => is_older_than(&path, cutoff_time),
        };
        if stale {
            move_to_archive(&path, &archive_dir)?;
            archived += 1;
        }
    }
    Ok(archived)
}
fn purge_memory_archives(workspace_dir: &Path, purge_after_days: u32) -> Result<u64> {
if purge_after_days == 0 {
return Ok(0);
}
let archive_dir = workspace_dir.join("memory").join("archive");
if !archive_dir.is_dir() {
return Ok(0);
}
let cutoff = Local::now().date_naive() - Duration::days(i64::from(purge_after_days));
let mut removed = 0_u64;
for entry in fs::read_dir(&archive_dir)? {
let entry = entry?;
let path = entry.path();
if path.is_dir() {
continue;
}
let Some(filename) = path.file_name().and_then(|f| f.to_str()) else {
continue;
};
let Some(file_date) = memory_date_from_filename(filename) else {
continue;
};
if file_date < cutoff {
fs::remove_file(&path)?;
removed += 1;
}
}
Ok(removed)
}
/// Permanently delete stale files in `sessions/archive/`. Returns the count.
///
/// Date-prefixed names are judged by their date, everything else by mtime.
fn purge_session_archives(workspace_dir: &Path, purge_after_days: u32) -> Result<u64> {
    if purge_after_days == 0 {
        return Ok(0);
    }
    let archive_dir = workspace_dir.join("sessions").join("archive");
    if !archive_dir.is_dir() {
        return Ok(0);
    }
    let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(purge_after_days));
    // mtime cutoff; saturates to the epoch on underflow.
    let cutoff_time = SystemTime::now()
        .checked_sub(StdDuration::from_secs(
            u64::from(purge_after_days) * 24 * 60 * 60,
        ))
        .unwrap_or(SystemTime::UNIX_EPOCH);
    let mut deleted = 0_u64;
    for entry in fs::read_dir(&archive_dir)? {
        let path = entry?.path();
        if path.is_dir() {
            continue;
        }
        let Some(name) = path.file_name().and_then(|f| f.to_str()) else {
            continue;
        };
        let expired = match date_prefix(name) {
            Some(date) => date < cutoff_date,
            None => is_older_than(&path, cutoff_time),
        };
        if expired {
            fs::remove_file(&path)?;
            deleted += 1;
        }
    }
    Ok(deleted)
}
/// Delete `conversation`-category rows older than `retention_days` from the
/// SQLite brain. Returns the number of rows removed.
///
/// A retention of 0 disables pruning; a missing database is a no-op.
fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result<u64> {
    if retention_days == 0 {
        return Ok(0);
    }
    let db_path = workspace_dir.join("memory").join("brain.db");
    if !db_path.exists() {
        return Ok(0);
    }
    let conn = Connection::open(db_path)?;
    // NOTE(review): the WHERE clause compares RFC3339 strings lexicographically —
    // assumes updated_at is stored in the same RFC3339 layout (incl. offset);
    // confirm against the writer side in the sqlite backend.
    let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339();
    let affected = conn.execute(
        "DELETE FROM memories WHERE category = 'conversation' AND updated_at < ?1",
        params![cutoff],
    )?;
    Ok(u64::try_from(affected).unwrap_or(0))
}
/// Extract a `YYYY-MM-DD` date from a daily-memory filename such as
/// `2026-02-14.md` or `2026-02-14_notes.md`; `None` for any other shape.
fn memory_date_from_filename(filename: &str) -> Option<NaiveDate> {
    let stem = filename.strip_suffix(".md")?;
    // Anything after the first underscore is a free-form suffix.
    let date_part = match stem.split_once('_') {
        Some((head, _tail)) => head,
        None => stem,
    };
    NaiveDate::parse_from_str(date_part, "%Y-%m-%d").ok()
}
/// Parse a leading `YYYY-MM-DD` prefix from a filename, if present.
///
/// Uses `str::get(..10)` instead of `&filename[..10]`: the original slice
/// panicked when a multi-byte UTF-8 character straddled byte 10 (the length
/// check counts bytes, not characters). `get` also covers the too-short case
/// by returning `None`.
fn date_prefix(filename: &str) -> Option<NaiveDate> {
    // None when the name is shorter than 10 bytes or byte 10 is not a char boundary.
    let prefix = filename.get(..10)?;
    NaiveDate::parse_from_str(prefix, "%Y-%m-%d").ok()
}
/// True when `path`'s modification time is strictly older than `cutoff`.
/// Unreadable metadata counts as "not old" so I/O errors never trigger
/// archiving or deletion.
fn is_older_than(path: &Path, cutoff: SystemTime) -> bool {
    match fs::metadata(path).and_then(|meta| meta.modified()) {
        Ok(modified) => modified < cutoff,
        Err(_) => false,
    }
}
/// Move `src` into `archive_dir`, picking a destination name that does not
/// clobber an existing file. Sources whose names are not valid UTF-8 are
/// silently left in place.
fn move_to_archive(src: &Path, archive_dir: &Path) -> Result<()> {
    if let Some(filename) = src.file_name().and_then(|f| f.to_str()) {
        let destination = unique_archive_target(archive_dir, filename);
        fs::rename(src, destination)?;
    }
    Ok(())
}
/// Choose a destination path inside `archive_dir` for `filename` that does not
/// overwrite an existing file, appending `_1`, `_2`, … before the extension as
/// needed. After 10 000 occupied candidates it gives up and returns the plain
/// name (the subsequent rename will then overwrite).
fn unique_archive_target(archive_dir: &Path, filename: &str) -> PathBuf {
    let direct = archive_dir.join(filename);
    if !direct.exists() {
        return direct;
    }
    // Split on the last dot, so `a.b.md` becomes ("a.b", "md").
    let (stem, ext) = filename.rsplit_once('.').unwrap_or((filename, ""));
    (1..10_000)
        .map(|i| {
            if ext.is_empty() {
                archive_dir.join(format!("{stem}_{i}"))
            } else {
                archive_dir.join(format!("{stem}_{i}.{ext}"))
            }
        })
        .find(|candidate| !candidate.exists())
        .unwrap_or(direct)
}
/// Split `filename` into `(stem, extension)` on the last `.`; a name with no
/// dot yields an empty extension.
fn split_name(filename: &str) -> (&str, &str) {
    filename.rsplit_once('.').unwrap_or((filename, ""))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::{Memory, MemoryCategory, SqliteMemory};
    use tempfile::TempDir;

    // Stock config: hygiene enabled with default thresholds.
    // NOTE(review): these tests assume the default archive threshold is below
    // 10 days — confirm against MemoryConfig::default().
    fn default_cfg() -> MemoryConfig {
        MemoryConfig::default()
    }

    // Old dated daily files move to memory/archive; today's file stays.
    #[test]
    fn archives_old_daily_memory_files() {
        let tmp = TempDir::new().unwrap();
        let workspace = tmp.path();
        fs::create_dir_all(workspace.join("memory")).unwrap();
        let old = (Local::now().date_naive() - Duration::days(10))
            .format("%Y-%m-%d")
            .to_string();
        let today = Local::now().date_naive().format("%Y-%m-%d").to_string();
        let old_file = workspace.join("memory").join(format!("{old}.md"));
        let today_file = workspace.join("memory").join(format!("{today}.md"));
        fs::write(&old_file, "old note").unwrap();
        fs::write(&today_file, "fresh note").unwrap();
        run_if_due(&default_cfg(), workspace).unwrap();
        assert!(!old_file.exists(), "old daily file should be archived");
        assert!(
            workspace
                .join("memory")
                .join("archive")
                .join(format!("{old}.md"))
                .exists(),
            "old daily file should exist in memory/archive"
        );
        assert!(today_file.exists(), "today file should remain in place");
    }

    // Date-prefixed session files past the threshold move to sessions/archive.
    #[test]
    fn archives_old_session_files() {
        let tmp = TempDir::new().unwrap();
        let workspace = tmp.path();
        fs::create_dir_all(workspace.join("sessions")).unwrap();
        let old = (Local::now().date_naive() - Duration::days(10))
            .format("%Y-%m-%d")
            .to_string();
        let old_name = format!("{old}-agent.log");
        let old_file = workspace.join("sessions").join(&old_name);
        fs::write(&old_file, "old session").unwrap();
        run_if_due(&default_cfg(), workspace).unwrap();
        assert!(!old_file.exists(), "old session file should be archived");
        assert!(
            workspace
                .join("sessions")
                .join("archive")
                .join(&old_name)
                .exists(),
            "archived session file should exist"
        );
    }

    // The state-file throttle must suppress a second immediate pass.
    #[test]
    fn skips_second_run_within_cadence_window() {
        let tmp = TempDir::new().unwrap();
        let workspace = tmp.path();
        fs::create_dir_all(workspace.join("memory")).unwrap();
        let old_a = (Local::now().date_naive() - Duration::days(10))
            .format("%Y-%m-%d")
            .to_string();
        let file_a = workspace.join("memory").join(format!("{old_a}.md"));
        fs::write(&file_a, "first").unwrap();
        run_if_due(&default_cfg(), workspace).unwrap();
        assert!(!file_a.exists(), "first old file should be archived");
        let old_b = (Local::now().date_naive() - Duration::days(9))
            .format("%Y-%m-%d")
            .to_string();
        let file_b = workspace.join("memory").join(format!("{old_b}.md"));
        fs::write(&file_b, "second").unwrap();
        // Should skip because cadence gate prevents a second immediate run.
        run_if_due(&default_cfg(), workspace).unwrap();
        assert!(
            file_b.exists(),
            "second file should remain because run is throttled"
        );
    }

    // Archived files beyond the purge threshold are deleted; recent ones stay.
    #[test]
    fn purges_old_memory_archives() {
        let tmp = TempDir::new().unwrap();
        let workspace = tmp.path();
        let archive_dir = workspace.join("memory").join("archive");
        fs::create_dir_all(&archive_dir).unwrap();
        let old = (Local::now().date_naive() - Duration::days(40))
            .format("%Y-%m-%d")
            .to_string();
        let keep = (Local::now().date_naive() - Duration::days(5))
            .format("%Y-%m-%d")
            .to_string();
        let old_file = archive_dir.join(format!("{old}.md"));
        let keep_file = archive_dir.join(format!("{keep}.md"));
        fs::write(&old_file, "expired").unwrap();
        fs::write(&keep_file, "recent").unwrap();
        run_if_due(&default_cfg(), workspace).unwrap();
        assert!(!old_file.exists(), "old archived file should be purged");
        assert!(keep_file.exists(), "recent archived file should remain");
    }

    // SQL pruning removes only conversation rows past retention; core rows stay.
    #[tokio::test]
    async fn prunes_old_conversation_rows_in_sqlite_backend() {
        let tmp = TempDir::new().unwrap();
        let workspace = tmp.path();
        let mem = SqliteMemory::new(workspace).unwrap();
        mem.store("conv_old", "outdated", MemoryCategory::Conversation)
            .await
            .unwrap();
        mem.store("core_keep", "durable", MemoryCategory::Core)
            .await
            .unwrap();
        drop(mem);
        // Backdate the conversation row directly in the DB.
        let db_path = workspace.join("memory").join("brain.db");
        let conn = Connection::open(&db_path).unwrap();
        let old_cutoff = (Local::now() - Duration::days(60)).to_rfc3339();
        conn.execute(
            "UPDATE memories SET created_at = ?1, updated_at = ?1 WHERE key = 'conv_old'",
            params![old_cutoff],
        )
        .unwrap();
        drop(conn);
        // Disable file archiving/purging so only row pruning runs.
        let mut cfg = default_cfg();
        cfg.archive_after_days = 0;
        cfg.purge_after_days = 0;
        cfg.conversation_retention_days = 30;
        run_if_due(&cfg, workspace).unwrap();
        let mem2 = SqliteMemory::new(workspace).unwrap();
        assert!(
            mem2.get("conv_old").await.unwrap().is_none(),
            "old conversation rows should be pruned"
        );
        assert!(
            mem2.get("core_keep").await.unwrap().is_some(),
            "core memory should remain"
        );
    }
}

View file

@ -1,5 +1,6 @@
pub mod chunker; pub mod chunker;
pub mod embeddings; pub mod embeddings;
pub mod hygiene;
pub mod markdown; pub mod markdown;
pub mod sqlite; pub mod sqlite;
pub mod traits; pub mod traits;
@ -21,6 +22,11 @@ pub fn create_memory(
workspace_dir: &Path, workspace_dir: &Path,
api_key: Option<&str>, api_key: Option<&str>,
) -> anyhow::Result<Box<dyn Memory>> { ) -> anyhow::Result<Box<dyn Memory>> {
// Best-effort memory hygiene/retention pass (throttled by state file).
if let Err(e) = hygiene::run_if_due(config, workspace_dir) {
tracing::warn!("memory hygiene skipped: {e}");
}
match config.backend.as_str() { match config.backend.as_str() {
"sqlite" => { "sqlite" => {
let embedder: Arc<dyn embeddings::EmbeddingProvider> = let embedder: Arc<dyn embeddings::EmbeddingProvider> =

553
src/migration.rs Normal file
View file

@ -0,0 +1,553 @@
use crate::config::Config;
use crate::memory::{MarkdownMemory, Memory, MemoryCategory, SqliteMemory};
use anyhow::{bail, Context, Result};
use directories::UserDirs;
use rusqlite::{Connection, OpenFlags, OptionalExtension};
use std::collections::HashSet;
use std::fs;
use std::path::{Path, PathBuf};
/// One memory record harvested from an OpenClaw workspace (sqlite or markdown).
#[derive(Debug, Clone)]
struct SourceEntry {
    // Key in the source store (may be renamed on import if it conflicts).
    key: String,
    // Trimmed memory content.
    content: String,
    // Category mapped onto ZeroClaw's taxonomy.
    category: MemoryCategory,
}
#[derive(Debug, Default)]
struct MigrationStats {
from_sqlite: usize,
from_markdown: usize,
imported: usize,
skipped_unchanged: usize,
renamed_conflicts: usize,
}
/// Entry point for `zeroclaw migrate …` subcommands.
pub async fn handle_command(command: super::MigrateCommands, config: &Config) -> Result<()> {
    match command {
        super::MigrateCommands::Openclaw { source, dry_run } => {
            migrate_openclaw_memory(config, source, dry_run).await
        }
    }
}
/// Import OpenClaw memory into the current ZeroClaw workspace.
///
/// Reads candidates from the source sqlite DB and markdown files, backs up the
/// target memory first, then writes through the configured backend. Identical
/// key+content pairs are skipped; conflicting keys are imported under a renamed
/// key so existing target data is never overwritten. `dry_run` only prints a
/// preview and writes nothing.
async fn migrate_openclaw_memory(
    config: &Config,
    source_workspace: Option<PathBuf>,
    dry_run: bool,
) -> Result<()> {
    let source_workspace = resolve_openclaw_workspace(source_workspace)?;
    if !source_workspace.exists() {
        bail!(
            "OpenClaw workspace not found at {}. Pass --source <path> if needed.",
            source_workspace.display()
        );
    }
    // Guard against importing a workspace into itself.
    if paths_equal(&source_workspace, &config.workspace_dir) {
        bail!("Source workspace matches current ZeroClaw workspace; refusing self-migration");
    }
    let mut stats = MigrationStats::default();
    let entries = collect_source_entries(&source_workspace, &mut stats)?;
    if entries.is_empty() {
        println!(
            "No importable memory found in {}",
            source_workspace.display()
        );
        println!("Checked for: memory/brain.db, MEMORY.md, memory/*.md");
        return Ok(());
    }
    if dry_run {
        println!("🔎 Dry run: OpenClaw migration preview");
        println!(" Source: {}", source_workspace.display());
        println!(" Target: {}", config.workspace_dir.display());
        println!(" Candidates: {}", entries.len());
        println!(" - from sqlite: {}", stats.from_sqlite);
        println!(" - from markdown: {}", stats.from_markdown);
        println!();
        println!("Run without --dry-run to import these entries.");
        return Ok(());
    }
    // Snapshot existing target memory so a bad import can be rolled back manually.
    if let Some(backup_dir) = backup_target_memory(&config.workspace_dir)? {
        println!("🛟 Backup created: {}", backup_dir.display());
    }
    let memory = target_memory_backend(config)?;
    for (idx, entry) in entries.into_iter().enumerate() {
        let mut key = entry.key.trim().to_string();
        if key.is_empty() {
            key = format!("openclaw_{idx}");
        }
        // Same content under the same key: no-op. Different content: rename.
        if let Some(existing) = memory.get(&key).await? {
            if existing.content.trim() == entry.content.trim() {
                stats.skipped_unchanged += 1;
                continue;
            }
            let renamed = next_available_key(memory.as_ref(), &key).await?;
            key = renamed;
            stats.renamed_conflicts += 1;
        }
        memory.store(&key, &entry.content, entry.category).await?;
        stats.imported += 1;
    }
    println!("✅ OpenClaw memory migration complete");
    println!(" Source: {}", source_workspace.display());
    println!(" Target: {}", config.workspace_dir.display());
    println!(" Imported: {}", stats.imported);
    println!(" Skipped unchanged:{}", stats.skipped_unchanged);
    println!(" Renamed conflicts:{}", stats.renamed_conflicts);
    println!(" Source sqlite rows:{}", stats.from_sqlite);
    println!(" Source markdown: {}", stats.from_markdown);
    Ok(())
}
/// Build the memory backend the migration writes into, mirroring the
/// workspace's configured backend. An unknown backend name falls back to
/// markdown with a warning instead of failing the migration.
fn target_memory_backend(config: &Config) -> Result<Box<dyn Memory>> {
    match config.memory.backend.as_str() {
        "sqlite" => Ok(Box::new(SqliteMemory::new(&config.workspace_dir)?)),
        "markdown" | "none" => Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))),
        other => {
            tracing::warn!(
                "Unknown memory backend '{other}' during migration, defaulting to markdown"
            );
            Ok(Box::new(MarkdownMemory::new(&config.workspace_dir)))
        }
    }
}
/// Gather candidate entries from both sqlite and markdown sources, recording
/// per-source counts in `stats` and dropping exact duplicates.
fn collect_source_entries(
    source_workspace: &Path,
    stats: &mut MigrationStats,
) -> Result<Vec<SourceEntry>> {
    let mut entries = Vec::new();
    let sqlite_path = source_workspace.join("memory").join("brain.db");
    let sqlite_entries = read_openclaw_sqlite_entries(&sqlite_path)?;
    stats.from_sqlite = sqlite_entries.len();
    entries.extend(sqlite_entries);
    let markdown_entries = read_openclaw_markdown_entries(source_workspace)?;
    stats.from_markdown = markdown_entries.len();
    entries.extend(markdown_entries);
    // De-dup exact duplicates to make re-runs deterministic.
    // NUL-joined signature prevents collisions like ("a","bc") vs ("ab","c").
    let mut seen = HashSet::new();
    entries.retain(|entry| {
        let sig = format!("{}\u{0}{}\u{0}{}", entry.key, entry.content, entry.category);
        seen.insert(sig)
    });
    Ok(entries)
}
/// Read memory rows from an OpenClaw sqlite DB, tolerating schema drift.
///
/// Opens the DB read-only, detects key/content/category columns from several
/// historical column names, and skips rows with blank content. A missing DB
/// or missing `memories` table yields an empty list rather than an error.
fn read_openclaw_sqlite_entries(db_path: &Path) -> Result<Vec<SourceEntry>> {
    if !db_path.exists() {
        return Ok(Vec::new());
    }
    let conn = Connection::open_with_flags(db_path, OpenFlags::SQLITE_OPEN_READ_ONLY)
        .with_context(|| format!("Failed to open source db {}", db_path.display()))?;
    let table_exists: Option<String> = conn
        .query_row(
            "SELECT name FROM sqlite_master WHERE type='table' AND name='memories' LIMIT 1",
            [],
            |row| row.get(0),
        )
        .optional()?;
    if table_exists.is_none() {
        return Ok(Vec::new());
    }
    let columns = table_columns(&conn, "memories")?;
    // Candidate lists cover known OpenClaw schema variants; rowid is the last
    // resort for the key, a constant 'core' for the category.
    let key_expr = pick_column_expr(&columns, &["key", "id", "name"], "CAST(rowid AS TEXT)");
    let Some(content_expr) =
        pick_optional_column_expr(&columns, &["content", "value", "text", "memory"])
    else {
        bail!("OpenClaw memories table found but no content-like column was detected");
    };
    let category_expr = pick_column_expr(&columns, &["category", "kind", "type"], "'core'");
    let sql = format!(
        "SELECT {key_expr} AS key, {content_expr} AS content, {category_expr} AS category FROM memories"
    );
    let mut stmt = conn.prepare(&sql)?;
    let mut rows = stmt.query([])?;
    let mut entries = Vec::new();
    // idx counts *kept* rows and seeds fallback keys for unreadable key cells.
    let mut idx = 0_usize;
    while let Some(row) = rows.next()? {
        let key: String = row
            .get(0)
            .unwrap_or_else(|_| format!("openclaw_sqlite_{idx}"));
        let content: String = row.get(1).unwrap_or_default();
        let category_raw: String = row.get(2).unwrap_or_else(|_| "core".to_string());
        if content.trim().is_empty() {
            continue;
        }
        entries.push(SourceEntry {
            key: normalize_key(&key, idx),
            content: content.trim().to_string(),
            category: parse_category(&category_raw),
        });
        idx += 1;
    }
    Ok(entries)
}
/// Read markdown-based memory: the root `MEMORY.md` (core category) plus every
/// `memory/*.md` daily file (daily category, keyed off the file stem).
fn read_openclaw_markdown_entries(source_workspace: &Path) -> Result<Vec<SourceEntry>> {
    let mut all = Vec::new();
    let core_path = source_workspace.join("MEMORY.md");
    if core_path.exists() {
        let content = fs::read_to_string(&core_path)?;
        all.extend(parse_markdown_file(
            &core_path,
            &content,
            MemoryCategory::Core,
            "openclaw_core",
        ));
    }
    let daily_dir = source_workspace.join("memory");
    if daily_dir.exists() {
        for file in fs::read_dir(&daily_dir)? {
            let file = file?;
            let path = file.path();
            // Only markdown files; brain.db and others are handled elsewhere.
            if path.extension().and_then(|ext| ext.to_str()) != Some("md") {
                continue;
            }
            let content = fs::read_to_string(&path)?;
            let stem = path
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or("openclaw_daily");
            all.extend(parse_markdown_file(
                &path,
                &content,
                MemoryCategory::Daily,
                stem,
            ));
        }
    }
    Ok(all)
}
/// Convert one markdown memory file into entries.
///
/// Bullet lines of the form `**key**: value` keep their key; any other
/// non-empty, non-heading line gets a synthetic `openclaw_<stem>_<n>` key
/// (1-based line number). Headings and blank lines are ignored.
fn parse_markdown_file(
    _path: &Path,
    content: &str,
    default_category: MemoryCategory,
    stem: &str,
) -> Vec<SourceEntry> {
    content
        .lines()
        .enumerate()
        .filter_map(|(idx, raw_line)| {
            let trimmed = raw_line.trim();
            // Headings and blank lines carry no memory content.
            if trimmed.is_empty() || trimmed.starts_with('#') {
                return None;
            }
            let line = trimmed.strip_prefix("- ").unwrap_or(trimmed);
            let (key, text) = match parse_structured_memory_line(line) {
                Some((k, v)) => (normalize_key(k, idx), v.trim().to_string()),
                None => (
                    format!("openclaw_{stem}_{}", idx + 1),
                    line.trim().to_string(),
                ),
            };
            if text.is_empty() {
                return None;
            }
            Some(SourceEntry {
                key,
                content: text,
                category: default_category.clone(),
            })
        })
        .collect()
}
/// Parse a `**key**: value` line into `(key, value)`, both trimmed.
/// Returns `None` unless the line starts with `**`, contains `**:`, and both
/// parts are non-empty.
fn parse_structured_memory_line(line: &str) -> Option<(&str, &str)> {
    let rest = line.strip_prefix("**")?;
    let (key_part, value_part) = rest.split_once("**:")?;
    let key = key_part.trim();
    let value = value_part.trim();
    (!key.is_empty() && !value.is_empty()).then_some((key, value))
}
/// Map an OpenClaw category label onto ZeroClaw's `MemoryCategory`.
/// Matching is case- and whitespace-insensitive; blank input defaults to
/// `Core`; unknown labels are preserved as `Custom`.
fn parse_category(raw: &str) -> MemoryCategory {
    let normalized = raw.trim().to_ascii_lowercase();
    match normalized.as_str() {
        "core" | "" => MemoryCategory::Core,
        "daily" => MemoryCategory::Daily,
        "conversation" => MemoryCategory::Conversation,
        other => MemoryCategory::Custom(other.to_string()),
    }
}
/// Return the trimmed key, substituting `openclaw_<idx>` when it is blank.
fn normalize_key(key: &str, fallback_idx: usize) -> String {
    let trimmed = key.trim();
    if trimmed.is_empty() {
        format!("openclaw_{fallback_idx}")
    } else {
        trimmed.to_string()
    }
}
/// Find an unused key of the form `<base>__openclaw_<n>` for conflict renames.
/// Bails after 10 000 attempts instead of looping forever on a pathological
/// store.
async fn next_available_key(memory: &dyn Memory, base: &str) -> Result<String> {
    for i in 1..=10_000 {
        let candidate = format!("{base}__openclaw_{i}");
        if memory.get(&candidate).await?.is_none() {
            return Ok(candidate);
        }
    }
    bail!("Unable to allocate non-conflicting key for '{base}'")
}
/// List the (lower-cased) column names of `table` via `PRAGMA table_info`.
/// The table name is interpolated directly because PRAGMA statements cannot be
/// parameterized; the only caller passes the literal "memories".
fn table_columns(conn: &Connection, table: &str) -> Result<Vec<String>> {
    let pragma = format!("PRAGMA table_info({table})");
    let mut stmt = conn.prepare(&pragma)?;
    // Column index 1 of table_info rows holds the column name.
    let rows = stmt.query_map([], |row| row.get::<_, String>(1))?;
    let mut cols = Vec::new();
    for col in rows {
        cols.push(col?.to_ascii_lowercase());
    }
    Ok(cols)
}
/// First candidate column name that actually exists in `columns` (already
/// lower-cased by `table_columns`), returned owned; `None` when nothing matches.
fn pick_optional_column_expr(columns: &[String], candidates: &[&str]) -> Option<String> {
    for candidate in candidates {
        if columns.iter().any(|col| col == candidate) {
            return Some((*candidate).to_string());
        }
    }
    None
}
/// First candidate column present in `columns`, or the SQL `fallback`
/// expression when no candidate exists.
fn pick_column_expr(columns: &[String], candidates: &[&str], fallback: &str) -> String {
    candidates
        .iter()
        .find(|candidate| columns.iter().any(|col| col == *candidate))
        .map_or_else(|| fallback.to_string(), |c| c.to_string())
}
/// Resolve the source workspace: an explicit `--source` path wins, otherwise
/// default to `~/.openclaw/workspace`. Errors only when no home directory can
/// be determined.
fn resolve_openclaw_workspace(source: Option<PathBuf>) -> Result<PathBuf> {
    if let Some(src) = source {
        return Ok(src);
    }
    let home = UserDirs::new()
        .map(|u| u.home_dir().to_path_buf())
        .context("Could not find home directory")?;
    Ok(home.join(".openclaw").join("workspace"))
}
/// Compare two paths, canonicalizing both when possible so symlinks and `..`
/// segments do not defeat the self-migration guard. Falls back to a literal
/// comparison when either path cannot be canonicalized (e.g. does not exist).
fn paths_equal(a: &Path, b: &Path) -> bool {
    if let (Ok(canon_a), Ok(canon_b)) = (fs::canonicalize(a), fs::canonicalize(b)) {
        canon_a == canon_b
    } else {
        a == b
    }
}
/// Copy the target workspace's memory artifacts (brain.db, MEMORY.md and all
/// `memory/*.md` daily files) into a timestamped directory under
/// `memory/migrations/` before importing.
///
/// Returns the backup path, or `None` (after removing the empty directory)
/// when there was nothing to back up.
fn backup_target_memory(workspace_dir: &Path) -> Result<Option<PathBuf>> {
    let timestamp = chrono::Local::now().format("%Y%m%d-%H%M%S").to_string();
    let backup_root = workspace_dir
        .join("memory")
        .join("migrations")
        .join(format!("openclaw-{timestamp}"));
    let mut copied_any = false;
    fs::create_dir_all(&backup_root)?;
    let files_to_copy = [
        workspace_dir.join("memory").join("brain.db"),
        workspace_dir.join("MEMORY.md"),
    ];
    for source in files_to_copy {
        if source.exists() {
            let Some(name) = source.file_name() else {
                continue;
            };
            fs::copy(&source, backup_root.join(name))?;
            copied_any = true;
        }
    }
    // Daily markdown files go under a `daily/` subfolder of the backup.
    let daily_dir = workspace_dir.join("memory");
    if daily_dir.exists() {
        let daily_backup = backup_root.join("daily");
        for file in fs::read_dir(&daily_dir)? {
            let file = file?;
            let path = file.path();
            if path.extension().and_then(|ext| ext.to_str()) != Some("md") {
                continue;
            }
            // Created lazily so an .md-free workspace leaves no empty folder.
            fs::create_dir_all(&daily_backup)?;
            let Some(name) = path.file_name() else {
                continue;
            };
            fs::copy(&path, daily_backup.join(name))?;
            copied_any = true;
        }
    }
    if copied_any {
        Ok(Some(backup_root))
    } else {
        // Nothing copied: best-effort cleanup of the empty backup dir.
        let _ = fs::remove_dir_all(&backup_root);
        Ok(None)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::{Config, MemoryConfig};
    use rusqlite::params;
    use tempfile::TempDir;

    // Minimal config pointing at a temp workspace with the sqlite backend.
    fn test_config(workspace: &Path) -> Config {
        Config {
            workspace_dir: workspace.to_path_buf(),
            config_path: workspace.join("config.toml"),
            memory: MemoryConfig {
                backend: "sqlite".to_string(),
                ..MemoryConfig::default()
            },
            ..Config::default()
        }
    }

    #[test]
    fn parse_structured_markdown_line() {
        let line = "**user_pref**: likes Rust";
        let parsed = parse_structured_memory_line(line).unwrap();
        assert_eq!(parsed.0, "user_pref");
        assert_eq!(parsed.1, "likes Rust");
    }

    // Unstructured lines must still import, under a synthetic key.
    #[test]
    fn parse_unstructured_markdown_generates_key() {
        let entries = parse_markdown_file(
            Path::new("/tmp/MEMORY.md"),
            "- plain note",
            MemoryCategory::Core,
            "core",
        );
        assert_eq!(entries.len(), 1);
        assert!(entries[0].key.starts_with("openclaw_core_"));
        assert_eq!(entries[0].content, "plain note");
    }

    // Schema drift: old DBs used `value`/`type` instead of `content`/`category`.
    #[test]
    fn sqlite_reader_supports_legacy_value_column() {
        let dir = TempDir::new().unwrap();
        let db_path = dir.path().join("brain.db");
        let conn = Connection::open(&db_path).unwrap();
        conn.execute_batch("CREATE TABLE memories (key TEXT, value TEXT, type TEXT);")
            .unwrap();
        conn.execute(
            "INSERT INTO memories (key, value, type) VALUES (?1, ?2, ?3)",
            params!["legacy_key", "legacy_value", "daily"],
        )
        .unwrap();
        let rows = read_openclaw_sqlite_entries(&db_path).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].key, "legacy_key");
        assert_eq!(rows[0].content, "legacy_value");
        assert_eq!(rows[0].category, MemoryCategory::Daily);
    }

    // A key conflict with differing content must import under a renamed key
    // without touching the existing target value.
    #[tokio::test]
    async fn migration_renames_conflicting_key() {
        let source = TempDir::new().unwrap();
        let target = TempDir::new().unwrap();
        // Existing target memory
        let target_mem = SqliteMemory::new(target.path()).unwrap();
        target_mem
            .store("k", "new value", MemoryCategory::Core)
            .await
            .unwrap();
        // Source sqlite with conflicting key + different content
        let source_db_dir = source.path().join("memory");
        fs::create_dir_all(&source_db_dir).unwrap();
        let source_db = source_db_dir.join("brain.db");
        let conn = Connection::open(&source_db).unwrap();
        conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);")
            .unwrap();
        conn.execute(
            "INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)",
            params!["k", "old value", "core"],
        )
        .unwrap();
        let config = test_config(target.path());
        migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), false)
            .await
            .unwrap();
        let all = target_mem.list(None).await.unwrap();
        assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
        assert!(all
            .iter()
            .any(|e| e.key.starts_with("k__openclaw_") && e.content == "old value"));
    }

    // --dry-run must leave the target completely untouched.
    #[tokio::test]
    async fn dry_run_does_not_write() {
        let source = TempDir::new().unwrap();
        let target = TempDir::new().unwrap();
        let source_db_dir = source.path().join("memory");
        fs::create_dir_all(&source_db_dir).unwrap();
        let source_db = source_db_dir.join("brain.db");
        let conn = Connection::open(&source_db).unwrap();
        conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);")
            .unwrap();
        conn.execute(
            "INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)",
            params!["dry", "run", "core"],
        )
        .unwrap();
        let config = test_config(target.path());
        migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), true)
            .await
            .unwrap();
        let target_mem = SqliteMemory::new(target.path()).unwrap();
        assert_eq!(target_mem.count().await.unwrap(), 0);
    }
}

View file

@ -1,3 +1,3 @@
pub mod wizard; pub mod wizard;
pub use wizard::{run_quick_setup, run_wizard}; pub use wizard::{run_channels_repair_wizard, run_quick_setup, run_wizard};

View file

@ -1,3 +1,4 @@
use crate::config::schema::WhatsAppConfig;
use crate::config::{ use crate::config::{
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, ObservabilityConfig, HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, ObservabilityConfig,
@ -91,6 +92,7 @@ pub fn run_wizard() -> Result<Config> {
observability: ObservabilityConfig::default(), observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(), autonomy: AutonomyConfig::default(),
runtime: RuntimeConfig::default(), runtime: RuntimeConfig::default(),
reliability: crate::config::ReliabilityConfig::default(),
heartbeat: HeartbeatConfig::default(), heartbeat: HeartbeatConfig::default(),
channels_config, channels_config,
memory: MemoryConfig::default(), // SQLite + auto-save by default memory: MemoryConfig::default(), // SQLite + auto-save by default
@ -99,6 +101,7 @@ pub fn run_wizard() -> Result<Config> {
composio: composio_config, composio: composio_config,
secrets: secrets_config, secrets: secrets_config,
browser: BrowserConfig::default(), browser: BrowserConfig::default(),
identity: crate::config::IdentityConfig::default(),
}; };
println!( println!(
@ -149,6 +152,61 @@ pub fn run_wizard() -> Result<Config> {
Ok(config) Ok(config)
} }
/// Interactive repair flow: rerun channel setup only without redoing full onboarding.
///
/// Loads the user's existing config (preserving all non-channel settings),
/// re-runs only the channel setup prompts, persists the result, and — when at
/// least one chat channel plus an API key is configured — optionally signals
/// `main.rs` via the `ZEROCLAW_AUTOSTART_CHANNELS` env var to start the
/// channel server immediately after the wizard returns.
///
/// # Errors
/// Returns an error if config load/save or any interactive prompt fails.
pub fn run_channels_repair_wizard() -> Result<Config> {
    println!("{}", style(BANNER).cyan().bold());
    println!(
        " {}",
        style("Channels Repair — update channel tokens and allowlists only")
            .white()
            .bold()
    );
    println!();

    // Start from the saved config so provider, memory, etc. are untouched.
    let mut config = Config::load_or_init()?;

    print_step(1, 1, "Channels (How You Talk to ZeroClaw)");
    config.channels_config = setup_channels()?;
    config.save()?;

    println!();
    println!(
        " {} Channel config saved: {}",
        style("").green().bold(),
        style(config.config_path.display()).green()
    );

    // Fix: include WhatsApp — setup_channels() can configure it, and omitting it
    // here meant a WhatsApp-only repair never offered the launch prompt below.
    let has_channels = config.channels_config.telegram.is_some()
        || config.channels_config.discord.is_some()
        || config.channels_config.slack.is_some()
        || config.channels_config.imessage.is_some()
        || config.channels_config.matrix.is_some()
        || config.channels_config.whatsapp.is_some();

    if has_channels && config.api_key.is_some() {
        let launch: bool = Confirm::new()
            .with_prompt(format!(
                " {} Launch channels now? (connected channels → AI → reply)",
                style("🚀").cyan()
            ))
            .default(true)
            .interact()?;

        if launch {
            println!();
            println!(
                " {} {}",
                style("").cyan(),
                style("Starting channel server...").white().bold()
            );
            println!();
            // Signal to main.rs to call start_channels after wizard returns
            std::env::set_var("ZEROCLAW_AUTOSTART_CHANNELS", "1");
        }
    }

    Ok(config)
}
// ── Quick setup (zero prompts) ─────────────────────────────────── // ── Quick setup (zero prompts) ───────────────────────────────────
/// Non-interactive setup: generates a sensible default config instantly. /// Non-interactive setup: generates a sensible default config instantly.
@ -187,6 +245,7 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result<
observability: ObservabilityConfig::default(), observability: ObservabilityConfig::default(),
autonomy: AutonomyConfig::default(), autonomy: AutonomyConfig::default(),
runtime: RuntimeConfig::default(), runtime: RuntimeConfig::default(),
reliability: crate::config::ReliabilityConfig::default(),
heartbeat: HeartbeatConfig::default(), heartbeat: HeartbeatConfig::default(),
channels_config: ChannelsConfig::default(), channels_config: ChannelsConfig::default(),
memory: MemoryConfig::default(), memory: MemoryConfig::default(),
@ -195,6 +254,7 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result<
composio: ComposioConfig::default(), composio: ComposioConfig::default(),
secrets: SecretsConfig::default(), secrets: SecretsConfig::default(),
browser: BrowserConfig::default(), browser: BrowserConfig::default(),
identity: crate::config::IdentityConfig::default(),
}; };
config.save()?; config.save()?;
@ -204,7 +264,9 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result<
user_name: std::env::var("USER").unwrap_or_else(|_| "User".into()), user_name: std::env::var("USER").unwrap_or_else(|_| "User".into()),
timezone: "UTC".into(), timezone: "UTC".into(),
agent_name: "ZeroClaw".into(), agent_name: "ZeroClaw".into(),
communication_style: "Direct and concise".into(), communication_style:
"Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing."
.into(),
}; };
scaffold_workspace(&workspace_dir, &default_ctx)?; scaffold_workspace(&workspace_dir, &default_ctx)?;
@ -878,24 +940,33 @@ fn setup_project_context() -> Result<ProjectContext> {
let style_options = vec![ let style_options = vec![
"Direct & concise — skip pleasantries, get to the point", "Direct & concise — skip pleasantries, get to the point",
"Friendly & casual — warm but efficient", "Friendly & casual — warm, human, and helpful",
"Professional & polished — calm, confident, and clear",
"Expressive & playful — more personality + natural emojis",
"Technical & detailed — thorough explanations, code-first", "Technical & detailed — thorough explanations, code-first",
"Balanced — adapt to the situation", "Balanced — adapt to the situation",
"Custom — write your own style guide",
]; ];
let style_idx = Select::new() let style_idx = Select::new()
.with_prompt(" Communication style") .with_prompt(" Communication style")
.items(&style_options) .items(&style_options)
.default(0) .default(1)
.interact()?; .interact()?;
let communication_style = match style_idx { let communication_style = match style_idx {
0 => "Be direct and concise. Skip pleasantries. Get to the point.".to_string(), 0 => "Be direct and concise. Skip pleasantries. Get to the point.".to_string(),
1 => "Be friendly and casual. Warm but efficient.".to_string(), 1 => "Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions.".to_string(),
2 => "Be technical and detailed. Thorough explanations, code-first.".to_string(), 2 => "Be professional and polished. Stay calm, structured, and respectful. Use occasional tone-setting emojis only when appropriate.".to_string(),
_ => { 3 => "Be expressive and playful when appropriate. Use relevant emojis naturally (0-2 max), and keep serious topics emoji-light.".to_string(),
"Adapt to the situation. Be concise when needed, thorough when it matters.".to_string() 4 => "Be technical and detailed. Thorough explanations, code-first.".to_string(),
} 5 => "Adapt to the situation. Default to warm and clear communication; be concise when needed, thorough when it matters.".to_string(),
_ => Input::new()
.with_prompt(" Custom communication style")
.default(
"Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing.".into(),
)
.interact_text()?,
}; };
println!( println!(
@ -931,6 +1002,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
webhook: None, webhook: None,
imessage: None, imessage: None,
matrix: None, matrix: None,
whatsapp: None,
}; };
loop { loop {
@ -975,6 +1047,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
"— self-hosted chat" "— self-hosted chat"
} }
), ),
format!(
"WhatsApp {}",
if config.whatsapp.is_some() {
"✅ connected"
} else {
"— Business Cloud API"
}
),
format!( format!(
"Webhook {}", "Webhook {}",
if config.webhook.is_some() { if config.webhook.is_some() {
@ -989,7 +1069,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
let choice = Select::new() let choice = Select::new()
.with_prompt(" Connect a channel (or Done to continue)") .with_prompt(" Connect a channel (or Done to continue)")
.items(&options) .items(&options)
.default(6) .default(7)
.interact()?; .interact()?;
match choice { match choice {
@ -1041,17 +1121,38 @@ fn setup_channels() -> Result<ChannelsConfig> {
} }
} }
print_bullet(
"Allowlist your own Telegram identity first (recommended for secure + fast setup).",
);
print_bullet(
"Use your @username without '@' (example: argenis), or your numeric Telegram user ID.",
);
print_bullet("Use '*' only for temporary open testing.");
let users_str: String = Input::new() let users_str: String = Input::new()
.with_prompt(" Allowed usernames (comma-separated, or * for all)") .with_prompt(
.default("*".into()) " Allowed Telegram identities (comma-separated: username without '@' and/or numeric user ID, '*' for all)",
)
.allow_empty(true)
.interact_text()?; .interact_text()?;
let allowed_users = if users_str.trim() == "*" { let allowed_users = if users_str.trim() == "*" {
vec!["*".into()] vec!["*".into()]
} else { } else {
users_str.split(',').map(|s| s.trim().to_string()).collect() users_str
.split(',')
.map(|s| s.trim().to_string())
.filter(|s| !s.is_empty())
.collect()
}; };
if allowed_users.is_empty() {
println!(
" {} No users allowlisted — Telegram inbound messages will be denied until you add your username/user ID or '*'.",
style("").yellow().bold()
);
}
config.telegram = Some(TelegramConfig { config.telegram = Some(TelegramConfig {
bot_token: token, bot_token: token,
allowed_users, allowed_users,
@ -1111,9 +1212,15 @@ fn setup_channels() -> Result<ChannelsConfig> {
.allow_empty(true) .allow_empty(true)
.interact_text()?; .interact_text()?;
print_bullet("Allowlist your own Discord user ID first (recommended).");
print_bullet(
"Get it in Discord: Settings -> Advanced -> Developer Mode (ON), then right-click your profile -> Copy User ID.",
);
print_bullet("Use '*' only for temporary open testing.");
let allowed_users_str: String = Input::new() let allowed_users_str: String = Input::new()
.with_prompt( .with_prompt(
" Allowed Discord user IDs (comma-separated, '*' for all, Enter to deny all)", " Allowed Discord user IDs (comma-separated, recommended: your own ID, '*' for all)",
) )
.allow_empty(true) .allow_empty(true)
.interact_text()?; .interact_text()?;
@ -1214,9 +1321,15 @@ fn setup_channels() -> Result<ChannelsConfig> {
.allow_empty(true) .allow_empty(true)
.interact_text()?; .interact_text()?;
print_bullet("Allowlist your own Slack member ID first (recommended).");
print_bullet(
"Member IDs usually start with 'U' (open your Slack profile -> More -> Copy member ID).",
);
print_bullet("Use '*' only for temporary open testing.");
let allowed_users_str: String = Input::new() let allowed_users_str: String = Input::new()
.with_prompt( .with_prompt(
" Allowed Slack user IDs (comma-separated, '*' for all, Enter to deny all)", " Allowed Slack user IDs (comma-separated, recommended: your own member ID, '*' for all)",
) )
.allow_empty(true) .allow_empty(true)
.interact_text()?; .interact_text()?;
@ -1378,6 +1491,90 @@ fn setup_channels() -> Result<ChannelsConfig> {
}); });
} }
5 => { 5 => {
// ── WhatsApp ──
println!();
println!(
" {} {}",
style("WhatsApp Setup").white().bold(),
style("— Business Cloud API").dim()
);
print_bullet("1. Go to developers.facebook.com and create a WhatsApp app");
print_bullet("2. Add the WhatsApp product and get your phone number ID");
print_bullet("3. Generate a temporary access token (System User)");
print_bullet("4. Configure webhook URL to: https://your-domain/whatsapp");
println!();
let access_token: String = Input::new()
.with_prompt(" Access token (from Meta Developers)")
.interact_text()?;
if access_token.trim().is_empty() {
println!(" {} Skipped", style("").dim());
continue;
}
let phone_number_id: String = Input::new()
.with_prompt(" Phone number ID (from WhatsApp app settings)")
.interact_text()?;
if phone_number_id.trim().is_empty() {
println!(" {} Skipped — phone number ID required", style("").dim());
continue;
}
let verify_token: String = Input::new()
.with_prompt(" Webhook verify token (create your own)")
.default("zeroclaw-whatsapp-verify".into())
.interact_text()?;
// Test connection
print!(" {} Testing connection... ", style("").dim());
let client = reqwest::blocking::Client::new();
let url = format!(
"https://graph.facebook.com/v18.0/{}",
phone_number_id.trim()
);
match client
.get(&url)
.header("Authorization", format!("Bearer {}", access_token.trim()))
.send()
{
Ok(resp) if resp.status().is_success() => {
println!(
"\r {} Connected to WhatsApp API ",
style("").green().bold()
);
}
_ => {
println!(
"\r {} Connection failed — check access token and phone number ID",
style("").red().bold()
);
continue;
}
}
let users_str: String = Input::new()
.with_prompt(
" Allowed phone numbers (comma-separated +1234567890, or * for all)",
)
.default("*".into())
.interact_text()?;
let allowed_numbers = if users_str.trim() == "*" {
vec!["*".into()]
} else {
users_str.split(',').map(|s| s.trim().to_string()).collect()
};
config.whatsapp = Some(WhatsAppConfig {
access_token: access_token.trim().to_string(),
phone_number_id: phone_number_id.trim().to_string(),
verify_token: verify_token.trim().to_string(),
allowed_numbers,
});
}
6 => {
// ── Webhook ── // ── Webhook ──
println!(); println!();
println!( println!(
@ -1432,6 +1629,9 @@ fn setup_channels() -> Result<ChannelsConfig> {
if config.matrix.is_some() { if config.matrix.is_some() {
active.push("Matrix"); active.push("Matrix");
} }
if config.whatsapp.is_some() {
active.push("WhatsApp");
}
if config.webhook.is_some() { if config.webhook.is_some() {
active.push("Webhook"); active.push("Webhook");
} }
@ -1618,7 +1818,7 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
&ctx.timezone &ctx.timezone
}; };
let comm_style = if ctx.communication_style.is_empty() { let comm_style = if ctx.communication_style.is_empty() {
"Adapt to the situation. Be concise when needed, thorough when it matters." "Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing."
} else { } else {
&ctx.communication_style &ctx.communication_style
}; };
@ -1667,6 +1867,14 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
## Tools & Skills\n\n\ ## Tools & Skills\n\n\
Skills are listed in the system prompt. Use `read` on a skill's SKILL.md for details.\n\ Skills are listed in the system prompt. Use `read` on a skill's SKILL.md for details.\n\
Keep local notes (SSH hosts, device names, etc.) in `TOOLS.md`.\n\n\ Keep local notes (SSH hosts, device names, etc.) in `TOOLS.md`.\n\n\
## Crash Recovery\n\n\
- If a run stops unexpectedly, recover context before acting.\n\
- Check `MEMORY.md` + latest `memory/*.md` notes to avoid duplicate work.\n\
- Resume from the last confirmed step, not from scratch.\n\n\
## Sub-task Scoping\n\n\
- Break complex work into focused sub-tasks with clear success criteria.\n\
- Keep sub-tasks small, verify each output, then merge results.\n\
- Prefer one clear objective per sub-task over broad \"do everything\" asks.\n\n\
## Make It Yours\n\n\ ## Make It Yours\n\n\
This is a starting point. Add your own conventions, style, and rules.\n" This is a starting point. Add your own conventions, style, and rules.\n"
); );
@ -1704,6 +1912,11 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
- Always introduce yourself as {agent} if asked\n\n\ - Always introduce yourself as {agent} if asked\n\n\
## Communication\n\n\ ## Communication\n\n\
{comm_style}\n\n\ {comm_style}\n\n\
- Sound like a real person, not a support script.\n\
- Mirror the user's energy: calm when serious, upbeat when casual.\n\
- Use emojis naturally (0-2 max when they help tone, not every sentence).\n\
- Match emoji density to the user. Formal user => minimal/no emojis.\n\
- Prefer specific, grounded phrasing over generic filler.\n\n\
## Boundaries\n\n\ ## Boundaries\n\n\
- Private things stay private. Period.\n\ - Private things stay private. Period.\n\
- When in doubt, ask before acting externally.\n\ - When in doubt, ask before acting externally.\n\
@ -1744,11 +1957,23 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
- Anything environment-specific\n\n\ - Anything environment-specific\n\n\
## Built-in Tools\n\n\ ## Built-in Tools\n\n\
- **shell** Execute terminal commands\n\ - **shell** Execute terminal commands\n\
- Use when: running local checks, build/test commands, or diagnostics.\n\
- Don't use when: a safer dedicated tool exists, or command is destructive without approval.\n\
- **file_read** Read file contents\n\ - **file_read** Read file contents\n\
- Use when: inspecting project files, configs, or logs.\n\
- Don't use when: you only need a quick string search (prefer targeted search first).\n\
- **file_write** Write file contents\n\ - **file_write** Write file contents\n\
- Use when: applying focused edits, scaffolding files, or updating docs/code.\n\
- Don't use when: unsure about side effects or when the file should remain user-owned.\n\
- **memory_store** Save to memory\n\ - **memory_store** Save to memory\n\
- Use when: preserving durable preferences, decisions, or key context.\n\
- Don't use when: info is transient, noisy, or sensitive without explicit need.\n\
- **memory_recall** Search memory\n\ - **memory_recall** Search memory\n\
- **memory_forget** Delete a memory entry\n\n\ - Use when: you need prior decisions, user preferences, or historical context.\n\
- Don't use when: the answer is already in current files/conversation.\n\
- **memory_forget** Delete a memory entry\n\
- Use when: memory is incorrect, stale, or explicitly requested to be removed.\n\
- Don't use when: uncertain about impact; verify before deleting.\n\n\
---\n\ ---\n\
*Add whatever helps you do your job. This is your cheat sheet.*\n"; *Add whatever helps you do your job. This is your cheat sheet.*\n";
@ -2242,7 +2467,7 @@ mod tests {
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
assert!( assert!(
soul.contains("Adapt to the situation"), soul.contains("Be warm, natural, and clear."),
"should default communication style" "should default communication style"
); );
} }
@ -2383,6 +2608,31 @@ mod tests {
"TOOLS.md should list built-in tool: {tool}" "TOOLS.md should list built-in tool: {tool}"
); );
} }
assert!(
tools.contains("Use when:"),
"TOOLS.md should include 'Use when' guidance"
);
assert!(
tools.contains("Don't use when:"),
"TOOLS.md should include 'Don't use when' guidance"
);
}
#[test]
fn soul_md_includes_emoji_awareness_guidance() {
let tmp = TempDir::new().unwrap();
let ctx = ProjectContext::default();
scaffold_workspace(tmp.path(), &ctx).unwrap();
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
assert!(
soul.contains("Use emojis naturally (0-2 max"),
"SOUL.md should include emoji usage guidance"
);
assert!(
soul.contains("Match emoji density to the user"),
"SOUL.md should include emoji-awareness guidance"
);
} }
// ── scaffold_workspace: special characters in names ───────── // ── scaffold_workspace: special characters in names ─────────
@ -2414,7 +2664,9 @@ mod tests {
user_name: "Argenis".into(), user_name: "Argenis".into(),
timezone: "US/Eastern".into(), timezone: "US/Eastern".into(),
agent_name: "Claw".into(), agent_name: "Claw".into(),
communication_style: "Be friendly and casual. Warm but efficient.".into(), communication_style:
"Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions."
.into(),
}; };
scaffold_workspace(tmp.path(), &ctx).unwrap(); scaffold_workspace(tmp.path(), &ctx).unwrap();
@ -2424,12 +2676,12 @@ mod tests {
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
assert!(soul.contains("You are **Claw**")); assert!(soul.contains("You are **Claw**"));
assert!(soul.contains("Be friendly and casual")); assert!(soul.contains("Be friendly, human, and conversational"));
let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap();
assert!(user_md.contains("**Name:** Argenis")); assert!(user_md.contains("**Name:** Argenis"));
assert!(user_md.contains("**Timezone:** US/Eastern")); assert!(user_md.contains("**Timezone:** US/Eastern"));
assert!(user_md.contains("Be friendly and casual")); assert!(user_md.contains("Be friendly, human, and conversational"));
let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap(); let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap();
assert!(agents.contains("Claw Personal Assistant")); assert!(agents.contains("Claw Personal Assistant"));

View file

@ -4,11 +4,13 @@ pub mod gemini;
pub mod ollama; pub mod ollama;
pub mod openai; pub mod openai;
pub mod openrouter; pub mod openrouter;
pub mod reliable;
pub mod traits; pub mod traits;
pub use traits::Provider; pub use traits::Provider;
use compatible::{AuthStyle, OpenAiCompatibleProvider}; use compatible::{AuthStyle, OpenAiCompatibleProvider};
use reliable::ReliableProvider;
/// Factory: create the right provider from config /// Factory: create the right provider from config
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
@ -114,6 +116,42 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
} }
} }
/// Create provider chain with retry and fallback behavior.
///
/// The primary provider must construct successfully; each configured fallback
/// is best-effort — duplicates of the primary (or of an earlier fallback) are
/// skipped, and fallbacks that fail to construct are logged and ignored.
pub fn create_resilient_provider(
    primary_name: &str,
    api_key: Option<&str>,
    reliability: &crate::config::ReliabilityConfig,
) -> anyhow::Result<Box<dyn Provider>> {
    let mut chain: Vec<(String, Box<dyn Provider>)> = vec![(
        primary_name.to_string(),
        create_provider(primary_name, api_key)?,
    )];

    for fallback in &reliability.fallback_providers {
        // Skip the primary and anything already present in the chain.
        let duplicate =
            fallback == primary_name || chain.iter().any(|(name, _)| name == fallback);
        if duplicate {
            continue;
        }
        match create_provider(fallback, api_key) {
            Ok(provider) => chain.push((fallback.clone(), provider)),
            Err(e) => {
                tracing::warn!(
                    fallback_provider = fallback,
                    "Ignoring invalid fallback provider: {e}"
                );
            }
        }
    }

    Ok(Box::new(ReliableProvider::new(
        chain,
        reliability.provider_retries,
        reliability.provider_backoff_ms,
    )))
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -307,6 +345,34 @@ mod tests {
assert!(create_provider("", None).is_err()); assert!(create_provider("", None).is_err());
} }
#[test]
fn resilient_provider_ignores_duplicate_and_invalid_fallbacks() {
let reliability = crate::config::ReliabilityConfig {
provider_retries: 1,
provider_backoff_ms: 100,
fallback_providers: vec![
"openrouter".into(),
"nonexistent-provider".into(),
"openai".into(),
"openai".into(),
],
channel_initial_backoff_secs: 2,
channel_max_backoff_secs: 60,
scheduler_poll_secs: 15,
scheduler_retries: 2,
};
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
assert!(provider.is_ok());
}
#[test]
    // An unknown primary provider is a hard error — there is no silent
    // fallback when the primary itself cannot be constructed.
    fn resilient_provider_errors_for_invalid_primary() {
        let reliability = crate::config::ReliabilityConfig::default();
        let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
        assert!(provider.is_err());
    }
#[test] #[test]
fn factory_all_providers_create_successfully() { fn factory_all_providers_create_successfully() {
let providers = [ let providers = [

229
src/providers/reliable.rs Normal file
View file

@ -0,0 +1,229 @@
use super::Provider;
use async_trait::async_trait;
use std::time::Duration;
/// Provider wrapper with retry + fallback behavior.
///
/// Holds an ordered chain of `(name, provider)` pairs; each provider is
/// retried with exponential backoff before the next one in the chain is tried.
pub struct ReliableProvider {
    // Ordered chain: index 0 is the primary, the rest are fallbacks.
    providers: Vec<(String, Box<dyn Provider>)>,
    // Extra attempts per provider beyond the first call.
    max_retries: u32,
    // Starting delay for exponential backoff (floored at 50ms in `new`).
    base_backoff_ms: u64,
}
impl ReliableProvider {
    /// Build a retry/fallback wrapper over an ordered provider chain.
    ///
    /// `base_backoff_ms` is floored at 50ms so the exponential backoff can
    /// never degenerate into a busy loop.
    pub fn new(
        providers: Vec<(String, Box<dyn Provider>)>,
        max_retries: u32,
        base_backoff_ms: u64,
    ) -> Self {
        let base_backoff_ms = base_backoff_ms.max(50);
        Self {
            providers,
            max_retries,
            base_backoff_ms,
        }
    }
}
#[async_trait]
impl Provider for ReliableProvider {
    /// Try each provider in chain order; retry each with exponential backoff
    /// before moving to the next. Returns the first successful response, or an
    /// aggregated error listing every failed attempt.
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let mut failures = Vec::new();
        for (idx, (provider_name, provider)) in self.providers.iter().enumerate() {
            // Backoff restarts from the base value for each provider.
            let mut backoff_ms = self.base_backoff_ms;
            for attempt in 0..=self.max_retries {
                match provider
                    .chat_with_system(system_prompt, message, model, temperature)
                    .await
                {
                    Ok(resp) => {
                        if attempt > 0 {
                            tracing::info!(
                                provider = provider_name,
                                attempt,
                                "Provider recovered after retries"
                            );
                        }
                        return Ok(resp);
                    }
                    Err(e) => {
                        // Record every failure so the final error is actionable.
                        failures.push(format!(
                            "{provider_name} attempt {}/{}: {e}",
                            attempt + 1,
                            self.max_retries + 1
                        ));
                        if attempt < self.max_retries {
                            tracing::warn!(
                                provider = provider_name,
                                attempt = attempt + 1,
                                max_retries = self.max_retries,
                                "Provider call failed, retrying"
                            );
                            tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
                            // Double the wait each retry, capped at 10s.
                            backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
                        }
                    }
                }
            }
            // Fix: only log a fallback switch when another provider actually
            // remains; previously this fired after the last provider too.
            if idx + 1 < self.providers.len() {
                tracing::warn!(provider = provider_name, "Switching to fallback provider");
            }
        }
        anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;

    // Test double: fails its first `fail_until_attempt` calls with `error`,
    // then returns `response`. `calls` counts total invocations.
    struct MockProvider {
        calls: Arc<AtomicUsize>,
        fail_until_attempt: usize,
        response: &'static str,
        error: &'static str,
    }

    #[async_trait]
    impl Provider for MockProvider {
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            // Attempts are 1-based: fetch_add returns the pre-increment count.
            let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
            if attempt <= self.fail_until_attempt {
                anyhow::bail!(self.error);
            }
            Ok(self.response.to_string())
        }
    }

    // A healthy provider is called exactly once — no retries, no backoff.
    #[tokio::test]
    async fn succeeds_without_retry() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![(
                "primary".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&calls),
                    fail_until_attempt: 0,
                    response: "ok",
                    error: "boom",
                }),
            )],
            2,
            1,
        );
        let result = provider.chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "ok");
        assert_eq!(calls.load(Ordering::SeqCst), 1);
    }

    // One transient failure followed by success: two calls, result returned.
    #[tokio::test]
    async fn retries_then_recovers() {
        let calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![(
                "primary".into(),
                Box::new(MockProvider {
                    calls: Arc::clone(&calls),
                    fail_until_attempt: 1,
                    response: "recovered",
                    error: "temporary",
                }),
            )],
            2,
            1,
        );
        let result = provider.chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "recovered");
        assert_eq!(calls.load(Ordering::SeqCst), 2);
    }

    // Primary always fails: with max_retries = 1 it is called twice, then the
    // fallback answers on its first call.
    #[tokio::test]
    async fn falls_back_after_retries_exhausted() {
        let primary_calls = Arc::new(AtomicUsize::new(0));
        let fallback_calls = Arc::new(AtomicUsize::new(0));
        let provider = ReliableProvider::new(
            vec![
                (
                    "primary".into(),
                    Box::new(MockProvider {
                        calls: Arc::clone(&primary_calls),
                        fail_until_attempt: usize::MAX,
                        response: "never",
                        error: "primary down",
                    }),
                ),
                (
                    "fallback".into(),
                    Box::new(MockProvider {
                        calls: Arc::clone(&fallback_calls),
                        fail_until_attempt: 0,
                        response: "from fallback",
                        error: "fallback down",
                    }),
                ),
            ],
            1,
            1,
        );
        let result = provider.chat("hello", "test", 0.0).await.unwrap();
        assert_eq!(result, "from fallback");
        assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
        assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
    }

    // When every provider fails, the error aggregates each per-provider,
    // per-attempt failure line (max_retries = 0 ⇒ "attempt 1/1").
    #[tokio::test]
    async fn returns_aggregated_error_when_all_providers_fail() {
        let provider = ReliableProvider::new(
            vec![
                (
                    "p1".into(),
                    Box::new(MockProvider {
                        calls: Arc::new(AtomicUsize::new(0)),
                        fail_until_attempt: usize::MAX,
                        response: "never",
                        error: "p1 error",
                    }),
                ),
                (
                    "p2".into(),
                    Box::new(MockProvider {
                        calls: Arc::new(AtomicUsize::new(0)),
                        fail_until_attempt: usize::MAX,
                        response: "never",
                        error: "p2 error",
                    }),
                ),
            ],
            0,
            1,
        );
        let err = provider
            .chat("hello", "test", 0.0)
            .await
            .expect_err("all providers should fail");
        let msg = err.to_string();
        assert!(msg.contains("All providers failed"));
        assert!(msg.contains("p1 attempt 1/1"));
        assert!(msg.contains("p2 attempt 1/1"));
    }
}

View file

@ -7,17 +7,21 @@ pub use traits::RuntimeAdapter;
use crate::config::RuntimeConfig; use crate::config::RuntimeConfig;
/// Factory: create the right runtime from config /// Factory: create the right runtime from config
pub fn create_runtime(config: &RuntimeConfig) -> Box<dyn RuntimeAdapter> { pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result<Box<dyn RuntimeAdapter>> {
match config.kind.as_str() { match config.kind.as_str() {
"native" | "docker" => Box::new(NativeRuntime::new()), "native" => Ok(Box::new(NativeRuntime::new())),
"cloudflare" => { "docker" => anyhow::bail!(
tracing::warn!("Cloudflare runtime not yet implemented, falling back to native"); "runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands."
Box::new(NativeRuntime::new()) ),
} "cloudflare" => anyhow::bail!(
_ => { "runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now."
tracing::warn!("Unknown runtime '{}', falling back to native", config.kind); ),
Box::new(NativeRuntime::new()) other if other.trim().is_empty() => anyhow::bail!(
} "runtime.kind cannot be empty. Supported values: native"
),
other => anyhow::bail!(
"Unknown runtime kind '{other}'. Supported values: native"
),
} }
} }
@ -30,44 +34,52 @@ mod tests {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "native".into(), kind: "native".into(),
}; };
let rt = create_runtime(&cfg); let rt = create_runtime(&cfg).unwrap();
assert_eq!(rt.name(), "native"); assert_eq!(rt.name(), "native");
assert!(rt.has_shell_access()); assert!(rt.has_shell_access());
} }
#[test] #[test]
fn factory_docker_returns_native() { fn factory_docker_errors() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "docker".into(), kind: "docker".into(),
}; };
let rt = create_runtime(&cfg); match create_runtime(&cfg) {
assert_eq!(rt.name(), "native"); Err(err) => assert!(err.to_string().contains("not implemented")),
Ok(_) => panic!("docker runtime should error"),
}
} }
#[test] #[test]
fn factory_cloudflare_falls_back() { fn factory_cloudflare_errors() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "cloudflare".into(), kind: "cloudflare".into(),
}; };
let rt = create_runtime(&cfg); match create_runtime(&cfg) {
assert_eq!(rt.name(), "native"); Err(err) => assert!(err.to_string().contains("not implemented")),
Ok(_) => panic!("cloudflare runtime should error"),
}
} }
#[test] #[test]
fn factory_unknown_falls_back() { fn factory_unknown_errors() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: "wasm-edge-unknown".into(), kind: "wasm-edge-unknown".into(),
}; };
let rt = create_runtime(&cfg); match create_runtime(&cfg) {
assert_eq!(rt.name(), "native"); Err(err) => assert!(err.to_string().contains("Unknown runtime kind")),
Ok(_) => panic!("unknown runtime should error"),
}
} }
#[test] #[test]
fn factory_empty_falls_back() { fn factory_empty_errors() {
let cfg = RuntimeConfig { let cfg = RuntimeConfig {
kind: String::new(), kind: String::new(),
}; };
let rt = create_runtime(&cfg); match create_runtime(&cfg) {
assert_eq!(rt.name(), "native"); Err(err) => assert!(err.to_string().contains("cannot be empty")),
Ok(_) => panic!("empty runtime should error"),
}
} }
} }

View file

@ -258,8 +258,14 @@ impl SecurityPolicy {
/// Validate that a resolved path is still inside the workspace. /// Validate that a resolved path is still inside the workspace.
/// Call this AFTER joining `workspace_dir` + relative path and canonicalizing. /// Call this AFTER joining `workspace_dir` + relative path and canonicalizing.
pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool { pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool {
// Must be under workspace_dir (prevents symlink escapes) // Must be under workspace_dir (prevents symlink escapes).
resolved.starts_with(&self.workspace_dir) // Prefer canonical workspace root so `/a/../b` style config paths don't
// cause false positives or negatives.
let workspace_root = self
.workspace_dir
.canonicalize()
.unwrap_or_else(|_| self.workspace_dir.clone());
resolved.starts_with(workspace_root)
} }
/// Check if autonomy level permits any action at all /// Check if autonomy level permits any action at all

View file

@ -79,8 +79,53 @@ impl SecretStore {
/// - `enc2:` prefix → ChaCha20-Poly1305 (current format) /// - `enc2:` prefix → ChaCha20-Poly1305 (current format)
/// - `enc:` prefix → legacy XOR cipher (backward compatibility for migration) /// - `enc:` prefix → legacy XOR cipher (backward compatibility for migration)
/// - No prefix → returned as-is (plaintext config) /// - No prefix → returned as-is (plaintext config)
///
/// **Warning**: Legacy `enc:` values are insecure. Use `decrypt_and_migrate` to
/// automatically upgrade them to the secure `enc2:` format.
pub fn decrypt(&self, value: &str) -> Result<String> { pub fn decrypt(&self, value: &str) -> Result<String> {
if let Some(hex_str) = value.strip_prefix("enc2:") { if let Some(hex_str) = value.strip_prefix("enc2:") {
self.decrypt_chacha20(hex_str)
} else if let Some(hex_str) = value.strip_prefix("enc:") {
self.decrypt_legacy_xor(hex_str)
} else {
Ok(value.to_string())
}
}
/// Decrypt a secret and return a migrated `enc2:` value if the input used legacy `enc:` format.
///
/// Returns `(plaintext, Some(new_enc2_value))` if migration occurred, or
/// `(plaintext, None)` if no migration was needed.
///
/// This allows callers to persist the upgraded value back to config.
pub fn decrypt_and_migrate(&self, value: &str) -> Result<(String, Option<String>)> {
if let Some(hex_str) = value.strip_prefix("enc2:") {
// Already using secure format — no migration needed
let plaintext = self.decrypt_chacha20(hex_str)?;
Ok((plaintext, None))
} else if let Some(hex_str) = value.strip_prefix("enc:") {
// Legacy XOR cipher — decrypt and re-encrypt with ChaCha20-Poly1305
tracing::warn!(
"Decrypting legacy XOR-encrypted secret (enc: prefix). \
This format is insecure and will be removed in a future release. \
The secret will be automatically migrated to enc2: (ChaCha20-Poly1305)."
);
let plaintext = self.decrypt_legacy_xor(hex_str)?;
let migrated = self.encrypt(&plaintext)?;
Ok((plaintext, Some(migrated)))
} else {
// Plaintext — no migration needed
Ok((value.to_string(), None))
}
}
/// Check if a value uses the legacy `enc:` format that should be migrated.
pub fn needs_migration(value: &str) -> bool {
value.starts_with("enc:")
}
/// Decrypt using ChaCha20-Poly1305 (current secure format).
fn decrypt_chacha20(&self, hex_str: &str) -> Result<String> {
let blob = let blob =
hex_decode(hex_str).context("Failed to decode encrypted secret (corrupt hex)")?; hex_decode(hex_str).context("Failed to decode encrypted secret (corrupt hex)")?;
anyhow::ensure!( anyhow::ensure!(
@ -100,17 +145,16 @@ impl SecretStore {
String::from_utf8(plaintext_bytes) String::from_utf8(plaintext_bytes)
.context("Decrypted secret is not valid UTF-8 — corrupt data") .context("Decrypted secret is not valid UTF-8 — corrupt data")
} else if let Some(hex_str) = value.strip_prefix("enc:") { }
// Legacy XOR cipher — decrypt for backward compatibility
/// Decrypt using legacy XOR cipher (insecure, for backward compatibility only).
fn decrypt_legacy_xor(&self, hex_str: &str) -> Result<String> {
let ciphertext = hex_decode(hex_str) let ciphertext = hex_decode(hex_str)
.context("Failed to decode legacy encrypted secret (corrupt hex)")?; .context("Failed to decode legacy encrypted secret (corrupt hex)")?;
let key = self.load_or_create_key()?; let key = self.load_or_create_key()?;
let plaintext_bytes = xor_cipher(&ciphertext, &key); let plaintext_bytes = xor_cipher(&ciphertext, &key);
String::from_utf8(plaintext_bytes) String::from_utf8(plaintext_bytes)
.context("Decrypted legacy secret is not valid UTF-8 — wrong key or corrupt data") .context("Decrypted legacy secret is not valid UTF-8 — wrong key or corrupt data")
} else {
Ok(value.to_string())
}
} }
/// Check if a value is already encrypted (current or legacy format). /// Check if a value is already encrypted (current or legacy format).
@ -118,6 +162,11 @@ impl SecretStore {
value.starts_with("enc2:") || value.starts_with("enc:") value.starts_with("enc2:") || value.starts_with("enc:")
} }
/// Check if a value uses the secure `enc2:` format.
pub fn is_secure_encrypted(value: &str) -> bool {
value.starts_with("enc2:")
}
/// Load the encryption key from disk, or create one if it doesn't exist. /// Load the encryption key from disk, or create one if it doesn't exist.
fn load_or_create_key(&self) -> Result<Vec<u8>> { fn load_or_create_key(&self) -> Result<Vec<u8>> {
if self.key_path.exists() { if self.key_path.exists() {
@ -132,13 +181,22 @@ impl SecretStore {
fs::write(&self.key_path, hex_encode(&key)) fs::write(&self.key_path, hex_encode(&key))
.context("Failed to write secret key file")?; .context("Failed to write secret key file")?;
// Set restrictive permissions (Unix only) // Set restrictive permissions
#[cfg(unix)] #[cfg(unix)]
{ {
use std::os::unix::fs::PermissionsExt; use std::os::unix::fs::PermissionsExt;
fs::set_permissions(&self.key_path, fs::Permissions::from_mode(0o600)) fs::set_permissions(&self.key_path, fs::Permissions::from_mode(0o600))
.context("Failed to set key file permissions")?; .context("Failed to set key file permissions")?;
} }
#[cfg(windows)]
{
// On Windows, use icacls to restrict permissions to current user only
let _ = std::process::Command::new("icacls")
.arg(&self.key_path)
.args(["/inheritance:r", "/grant:r"])
.arg(format!("{}:F", std::env::var("USERNAME").unwrap_or_default()))
.output();
}
Ok(key) Ok(key)
} }
@ -382,6 +440,258 @@ mod tests {
assert_eq!(decrypted, plaintext, "Legacy XOR values must still decrypt"); assert_eq!(decrypted, plaintext, "Legacy XOR values must still decrypt");
} }
// ── Migration tests ─────────────────────────────────────────
#[test]
fn needs_migration_detects_legacy_prefix() {
assert!(SecretStore::needs_migration("enc:aabbcc"));
assert!(!SecretStore::needs_migration("enc2:aabbcc"));
assert!(!SecretStore::needs_migration("sk-plaintext"));
assert!(!SecretStore::needs_migration(""));
}
#[test]
fn is_secure_encrypted_detects_enc2_only() {
assert!(SecretStore::is_secure_encrypted("enc2:aabbcc"));
assert!(!SecretStore::is_secure_encrypted("enc:aabbcc"));
assert!(!SecretStore::is_secure_encrypted("sk-plaintext"));
assert!(!SecretStore::is_secure_encrypted(""));
}
#[test]
fn decrypt_and_migrate_returns_none_for_enc2() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let encrypted = store.encrypt("my-secret").unwrap();
assert!(encrypted.starts_with("enc2:"));
let (plaintext, migrated) = store.decrypt_and_migrate(&encrypted).unwrap();
assert_eq!(plaintext, "my-secret");
assert!(
migrated.is_none(),
"enc2: values should not trigger migration"
);
}
#[test]
fn decrypt_and_migrate_returns_none_for_plaintext() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let (plaintext, migrated) = store.decrypt_and_migrate("sk-plaintext-key").unwrap();
assert_eq!(plaintext, "sk-plaintext-key");
assert!(
migrated.is_none(),
"Plaintext values should not trigger migration"
);
}
#[test]
fn decrypt_and_migrate_upgrades_legacy_xor() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
// Create key first
let _ = store.encrypt("setup").unwrap();
let key = store.load_or_create_key().unwrap();
// Manually create a legacy XOR-encrypted value
let plaintext = "sk-legacy-secret-to-migrate";
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
// Verify it needs migration
assert!(SecretStore::needs_migration(&legacy_value));
// Decrypt and migrate
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
assert_eq!(decrypted, plaintext, "Plaintext must match original");
assert!(migrated.is_some(), "Legacy value should trigger migration");
let new_value = migrated.unwrap();
assert!(
new_value.starts_with("enc2:"),
"Migrated value must use enc2: prefix"
);
assert!(
!SecretStore::needs_migration(&new_value),
"Migrated value should not need migration"
);
// Verify the migrated value decrypts correctly
let (decrypted2, migrated2) = store.decrypt_and_migrate(&new_value).unwrap();
assert_eq!(
decrypted2, plaintext,
"Migrated value must decrypt to same plaintext"
);
assert!(
migrated2.is_none(),
"Migrated value should not trigger another migration"
);
}
#[test]
fn decrypt_and_migrate_handles_unicode() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let _ = store.encrypt("setup").unwrap();
let key = store.load_or_create_key().unwrap();
let plaintext = "sk-日本語-émojis-🦀-тест";
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
assert_eq!(decrypted, plaintext);
assert!(migrated.is_some());
// Verify migrated value works
let new_value = migrated.unwrap();
let (decrypted2, _) = store.decrypt_and_migrate(&new_value).unwrap();
assert_eq!(decrypted2, plaintext);
}
#[test]
fn decrypt_and_migrate_handles_empty_secret() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let _ = store.encrypt("setup").unwrap();
let key = store.load_or_create_key().unwrap();
// Empty plaintext XOR-encrypted
let plaintext = "";
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
assert_eq!(decrypted, plaintext);
// Empty string encryption returns empty string (not enc2:)
assert!(migrated.is_some());
assert_eq!(migrated.unwrap(), "");
}
#[test]
fn decrypt_and_migrate_handles_long_secret() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let _ = store.encrypt("setup").unwrap();
let key = store.load_or_create_key().unwrap();
let plaintext = "a".repeat(10_000);
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
assert_eq!(decrypted, plaintext);
assert!(migrated.is_some());
let new_value = migrated.unwrap();
let (decrypted2, _) = store.decrypt_and_migrate(&new_value).unwrap();
assert_eq!(decrypted2, plaintext);
}
#[test]
fn decrypt_and_migrate_fails_on_corrupt_legacy_hex() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let _ = store.encrypt("setup").unwrap();
let result = store.decrypt_and_migrate("enc:not-valid-hex!!");
assert!(result.is_err(), "Corrupt hex should fail");
}
#[test]
fn decrypt_and_migrate_wrong_key_produces_garbage_or_fails() {
let tmp1 = TempDir::new().unwrap();
let tmp2 = TempDir::new().unwrap();
let store1 = SecretStore::new(tmp1.path(), true);
let store2 = SecretStore::new(tmp2.path(), true);
// Create keys for both stores
let _ = store1.encrypt("setup").unwrap();
let _ = store2.encrypt("setup").unwrap();
let key1 = store1.load_or_create_key().unwrap();
// Encrypt with store1's key
let plaintext = "secret-for-store1";
let ciphertext = xor_cipher(plaintext.as_bytes(), &key1);
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
// Decrypt with store2 — XOR will produce garbage bytes
// This may fail with UTF-8 error or succeed with garbage plaintext
match store2.decrypt_and_migrate(&legacy_value) {
Ok((decrypted, _)) => {
// If it succeeds, the plaintext should be garbage (not the original)
assert_ne!(
decrypted, plaintext,
"Wrong key should produce garbage plaintext"
);
}
Err(e) => {
// Expected: UTF-8 decoding failure from garbage bytes
assert!(
e.to_string().contains("UTF-8"),
"Error should be UTF-8 related: {e}"
);
}
}
}
#[test]
fn migration_produces_different_ciphertext_each_time() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let _ = store.encrypt("setup").unwrap();
let key = store.load_or_create_key().unwrap();
let plaintext = "sk-same-secret";
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
let (_, migrated1) = store.decrypt_and_migrate(&legacy_value).unwrap();
let (_, migrated2) = store.decrypt_and_migrate(&legacy_value).unwrap();
assert!(migrated1.is_some());
assert!(migrated2.is_some());
assert_ne!(
migrated1.unwrap(),
migrated2.unwrap(),
"Each migration should produce different ciphertext (random nonce)"
);
}
#[test]
fn migrated_value_is_tamper_resistant() {
let tmp = TempDir::new().unwrap();
let store = SecretStore::new(tmp.path(), true);
let _ = store.encrypt("setup").unwrap();
let key = store.load_or_create_key().unwrap();
let plaintext = "sk-sensitive-data";
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
let (_, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
let new_value = migrated.unwrap();
// Tamper with the migrated value
let hex_str = &new_value[5..];
let mut blob = hex_decode(hex_str).unwrap();
if blob.len() > NONCE_LEN {
blob[NONCE_LEN] ^= 0xff;
}
let tampered = format!("enc2:{}", hex_encode(&blob));
let result = store.decrypt_and_migrate(&tampered);
assert!(result.is_err(), "Tampered migrated value must be rejected");
}
// ── Low-level helpers ─────────────────────────────────────── // ── Low-level helpers ───────────────────────────────────────
#[test] #[test]

284
src/service/mod.rs Normal file
View file

@ -0,0 +1,284 @@
use crate::config::Config;
use anyhow::{Context, Result};
use std::fs;
use std::path::PathBuf;
use std::process::Command;
const SERVICE_LABEL: &str = "com.zeroclaw.daemon";
/// Dispatch a `zeroclaw service …` subcommand to its platform-specific handler.
///
/// NOTE(review): assumes `super::ServiceCommands` mirrors the CLI verbs
/// one-to-one — confirm against the clap definition in the parent module.
pub fn handle_command(command: super::ServiceCommands, config: &Config) -> Result<()> {
    match command {
        super::ServiceCommands::Install => install(config),
        super::ServiceCommands::Start => start(config),
        super::ServiceCommands::Stop => stop(config),
        super::ServiceCommands::Status => status(config),
        super::ServiceCommands::Uninstall => uninstall(config),
    }
}
/// Install the background-service unit for the current platform.
///
/// macOS gets a launchd user agent, Linux a systemd user unit; all other
/// platforms return an error. Platform is decided at compile time via `cfg!`.
fn install(config: &Config) -> Result<()> {
    if cfg!(target_os = "macos") {
        install_macos(config)
    } else if cfg!(target_os = "linux") {
        install_linux(config)
    } else {
        anyhow::bail!("Service management is supported on macOS and Linux only");
    }
}
/// Start the installed service.
///
/// On macOS this loads the plist (`launchctl load -w`) and then starts the job;
/// on Linux it reloads the user daemon and starts the unit. Both paths error if
/// any underlying command exits non-zero.
///
/// NOTE(review): `launchctl load`/`unload` are the legacy subcommands; modern
/// macOS prefers `bootstrap`/`bootout` — confirm minimum supported OS version.
fn start(config: &Config) -> Result<()> {
    if cfg!(target_os = "macos") {
        let plist = macos_service_file()?;
        run_checked(Command::new("launchctl").arg("load").arg("-w").arg(&plist))?;
        run_checked(Command::new("launchctl").arg("start").arg(SERVICE_LABEL))?;
        println!("✅ Service started");
        Ok(())
    } else if cfg!(target_os = "linux") {
        run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]))?;
        run_checked(Command::new("systemctl").args(["--user", "start", "zeroclaw.service"]))?;
        println!("✅ Service started");
        Ok(())
    } else {
        // Silence the unused-parameter warning on unsupported targets.
        let _ = config;
        anyhow::bail!("Service management is supported on macOS and Linux only")
    }
}
/// Stop the service, best-effort.
///
/// Failures from `launchctl`/`systemctl` are deliberately ignored (`let _ =`)
/// so that stopping an already-stopped or never-loaded service still reports
/// success — only resolving the plist path can fail on macOS.
fn stop(config: &Config) -> Result<()> {
    if cfg!(target_os = "macos") {
        let plist = macos_service_file()?;
        let _ = run_checked(Command::new("launchctl").arg("stop").arg(SERVICE_LABEL));
        let _ = run_checked(
            Command::new("launchctl")
                .arg("unload")
                .arg("-w")
                .arg(&plist),
        );
        println!("✅ Service stopped");
        Ok(())
    } else if cfg!(target_os = "linux") {
        let _ = run_checked(Command::new("systemctl").args(["--user", "stop", "zeroclaw.service"]));
        println!("✅ Service stopped");
        Ok(())
    } else {
        // Silence the unused-parameter warning on unsupported targets.
        let _ = config;
        anyhow::bail!("Service management is supported on macOS and Linux only")
    }
}
/// Print the service's current state and the path of its unit file.
///
/// macOS: greps `launchctl list` output for the service label (substring
/// match). Linux: reports `systemctl --user is-active`, falling back to the
/// literal string "unknown" if the command cannot be spawned.
fn status(config: &Config) -> Result<()> {
    if cfg!(target_os = "macos") {
        let out = run_capture(Command::new("launchctl").arg("list"))?;
        // Substring match on any line; label collisions are unlikely given the
        // reverse-DNS naming.
        let running = out.lines().any(|line| line.contains(SERVICE_LABEL));
        println!(
            "Service: {}",
            if running {
                "✅ running/loaded"
            } else {
                "❌ not loaded"
            }
        );
        println!("Unit: {}", macos_service_file()?.display());
        return Ok(());
    }
    if cfg!(target_os = "linux") {
        let out = run_capture(Command::new("systemctl").args([
            "--user",
            "is-active",
            "zeroclaw.service",
        ]))
        .unwrap_or_else(|_| "unknown".into());
        println!("Service state: {}", out.trim());
        println!("Unit: {}", linux_service_file(config)?.display());
        return Ok(());
    }
    anyhow::bail!("Service management is supported on macOS and Linux only")
}
/// Stop the service and remove its unit file.
///
/// `stop` is invoked first and its error propagates; note that `stop` itself
/// ignores launchctl/systemctl failures, so this only fails early on
/// unsupported platforms or an unresolvable home directory. Removing a
/// non-existent unit file is treated as success (idempotent uninstall).
fn uninstall(config: &Config) -> Result<()> {
    stop(config)?;
    if cfg!(target_os = "macos") {
        let file = macos_service_file()?;
        if file.exists() {
            fs::remove_file(&file)
                .with_context(|| format!("Failed to remove {}", file.display()))?;
        }
        println!("✅ Service uninstalled ({})", file.display());
        return Ok(());
    }
    if cfg!(target_os = "linux") {
        let file = linux_service_file(config)?;
        if file.exists() {
            fs::remove_file(&file)
                .with_context(|| format!("Failed to remove {}", file.display()))?;
        }
        // Best-effort reload so systemd forgets the removed unit.
        let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]));
        println!("✅ Service uninstalled ({})", file.display());
        return Ok(());
    }
    anyhow::bail!("Service management is supported on macOS and Linux only")
}
/// Install a launchd user agent that runs `zeroclaw daemon` and keeps it alive.
///
/// Writes the plist to `~/Library/LaunchAgents/<label>.plist` and routes
/// stdout/stderr to log files in a `logs/` directory next to the config file.
/// Paths are XML-escaped before interpolation so exotic install locations
/// cannot break the plist.
fn install_macos(config: &Config) -> Result<()> {
    let file = macos_service_file()?;
    if let Some(parent) = file.parent() {
        fs::create_dir_all(parent)?;
    }
    let exe = std::env::current_exe().context("Failed to resolve current executable")?;
    // Logs live beside the config file; fall back to "." if it has no parent.
    let logs_dir = config
        .config_path
        .parent()
        .map_or_else(|| PathBuf::from("."), PathBuf::from)
        .join("logs");
    fs::create_dir_all(&logs_dir)?;
    let stdout = logs_dir.join("daemon.stdout.log");
    let stderr = logs_dir.join("daemon.stderr.log");
    // BUG FIX: the original template used `\"` inside a raw string (`r#"…"#`).
    // Raw strings do not process escapes, so the emitted plist literally
    // contained backslashes (e.g. `version=\"1.0\"`) — malformed XML that
    // launchd rejects. Plain quotes are correct inside a raw string.
    let plist = format!(
        r#"<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
    <key>Label</key>
    <string>{label}</string>
    <key>ProgramArguments</key>
    <array>
        <string>{exe}</string>
        <string>daemon</string>
    </array>
    <key>RunAtLoad</key>
    <true/>
    <key>KeepAlive</key>
    <true/>
    <key>StandardOutPath</key>
    <string>{stdout}</string>
    <key>StandardErrorPath</key>
    <string>{stderr}</string>
</dict>
</plist>
"#,
        label = SERVICE_LABEL,
        exe = xml_escape(&exe.display().to_string()),
        stdout = xml_escape(&stdout.display().to_string()),
        stderr = xml_escape(&stderr.display().to_string())
    );
    fs::write(&file, plist)?;
    println!("✅ Installed launchd service: {}", file.display());
    println!(" Start with: zeroclaw service start");
    Ok(())
}
/// Write a systemd user unit for the daemon and enable it (best-effort).
///
/// The unit restarts the daemon on failure every 3 seconds and is wired to
/// `default.target` so it starts with the user session.
fn install_linux(config: &Config) -> Result<()> {
    let unit_path = linux_service_file(config)?;
    if let Some(dir) = unit_path.parent() {
        fs::create_dir_all(dir)?;
    }

    let binary = std::env::current_exe().context("Failed to resolve current executable")?;
    let unit_text = format!(
        "[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n",
        binary.display()
    );
    fs::write(&unit_path, unit_text)?;

    // Reload + enable are best-effort: a missing user session must not fail
    // the install itself.
    let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]));
    let _ = run_checked(Command::new("systemctl").args(["--user", "enable", "zeroclaw.service"]));

    println!("✅ Installed systemd user service: {}", unit_path.display());
    println!(" Start with: zeroclaw service start");
    Ok(())
}
/// Path of the launchd plist: `~/Library/LaunchAgents/<label>.plist`.
fn macos_service_file() -> Result<PathBuf> {
    let dirs = directories::UserDirs::new().context("Could not find home directory")?;
    let plist_name = format!("{SERVICE_LABEL}.plist");
    Ok(dirs
        .home_dir()
        .join("Library")
        .join("LaunchAgents")
        .join(plist_name))
}
/// Path of the systemd user unit: `~/.config/systemd/user/zeroclaw.service`.
///
/// `config` is accepted for signature symmetry with the other platform helpers
/// but is currently unused.
fn linux_service_file(config: &Config) -> Result<PathBuf> {
    let home = directories::UserDirs::new()
        .map(|u| u.home_dir().to_path_buf())
        .context("Could not find home directory")?;
    let _ = config;
    Ok(home
        .join(".config")
        .join("systemd")
        .join("user")
        .join("zeroclaw.service"))
}
/// Run a command to completion, returning an error (with trimmed stderr) on a
/// non-zero exit status.
fn run_checked(command: &mut Command) -> Result<()> {
    let output = command.output().context("Failed to spawn command")?;
    if output.status.success() {
        return Ok(());
    }
    let stderr = String::from_utf8_lossy(&output.stderr);
    anyhow::bail!("Command failed: {}", stderr.trim())
}
/// Capture a command's stdout, falling back to stderr when stdout is blank.
///
/// Exit status is intentionally ignored — callers that care use `run_checked`.
fn run_capture(command: &mut Command) -> Result<String> {
    let output = command.output().context("Failed to spawn command")?;
    let stdout = String::from_utf8_lossy(&output.stdout).to_string();
    if stdout.trim().is_empty() {
        Ok(String::from_utf8_lossy(&output.stderr).to_string())
    } else {
        Ok(stdout)
    }
}
/// Escape the five XML-reserved characters so arbitrary paths can be embedded
/// safely in the generated plist.
fn xml_escape(raw: &str) -> String {
    let mut escaped = String::with_capacity(raw.len());
    for ch in raw.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            other => escaped.push(other),
        }
    }
    escaped
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All five XML-reserved characters must be entity-escaped; `&` escaping
    /// must not double-escape the entities produced for the other characters.
    #[test]
    fn xml_escape_escapes_reserved_chars() {
        let escaped = xml_escape("<&>\"' and text");
        assert_eq!(escaped, "&lt;&amp;&gt;&quot;&apos; and text");
    }

    // NOTE(review): the `sh -lc` tests below assume a POSIX shell on PATH —
    // they will not run on Windows hosts.
    /// `run_capture` prefers stdout when it is non-blank.
    #[test]
    fn run_capture_reads_stdout() {
        let out = run_capture(Command::new("sh").args(["-lc", "echo hello"]))
            .expect("stdout capture should succeed");
        assert_eq!(out.trim(), "hello");
    }

    /// `run_capture` falls back to stderr when stdout is empty.
    #[test]
    fn run_capture_falls_back_to_stderr() {
        let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"]))
            .expect("stderr capture should succeed");
        assert_eq!(out.trim(), "warn");
    }

    /// Non-zero exit codes surface as errors containing "Command failed".
    #[test]
    fn run_checked_errors_on_non_zero_status() {
        let err = run_checked(Command::new("sh").args(["-lc", "exit 17"]))
            .expect_err("non-zero exit should error");
        assert!(err.to_string().contains("Command failed"));
    }

    /// The unit path is anchored under `~/.config/systemd/user/`.
    #[test]
    fn linux_service_file_has_expected_suffix() {
        let file = linux_service_file(&Config::default()).unwrap();
        let path = file.to_string_lossy();
        assert!(path.ends_with(".config/systemd/user/zeroclaw.service"));
    }
}

View file

@ -221,6 +221,23 @@ pub fn init_skills_dir(workspace_dir: &Path) -> Result<()> {
Ok(()) Ok(())
} }
/// Recursively copy a directory (used as fallback when symlinks aren't available)
#[cfg(any(windows, not(unix)))]
fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<()> {
std::fs::create_dir_all(dest)?;
for entry in std::fs::read_dir(src)? {
let entry = entry?;
let src_path = entry.path();
let dest_path = dest.join(entry.file_name());
if src_path.is_dir() {
copy_dir_recursive(&src_path, &dest_path)?;
} else {
std::fs::copy(&src_path, &dest_path)?;
}
}
Ok(())
}
/// Handle the `skills` CLI command /// Handle the `skills` CLI command
pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Result<()> { pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Result<()> {
match command { match command {
@ -295,19 +312,61 @@ pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Re
let dest = skills_path.join(name); let dest = skills_path.join(name);
#[cfg(unix)] #[cfg(unix)]
std::os::unix::fs::symlink(&src, &dest)?;
#[cfg(not(unix))]
{ {
// On non-unix, copy the directory std::os::unix::fs::symlink(&src, &dest)?;
anyhow::bail!("Symlink not supported on this platform. Copy the skill directory manually.");
}
println!( println!(
" {} Skill linked: {}", " {} Skill linked: {}",
console::style("").green().bold(), console::style("").green().bold(),
dest.display() dest.display()
); );
} }
#[cfg(windows)]
{
// On Windows, try symlink first (requires admin or developer mode),
// fall back to directory junction, then copy
use std::os::windows::fs::symlink_dir;
if symlink_dir(&src, &dest).is_ok() {
println!(
" {} Skill linked: {}",
console::style("").green().bold(),
dest.display()
);
} else {
// Try junction as fallback (works without admin)
let junction_result = std::process::Command::new("cmd")
.args(["/C", "mklink", "/J"])
.arg(&dest)
.arg(&src)
.output();
if junction_result.is_ok() && junction_result.unwrap().status.success() {
println!(
" {} Skill linked (junction): {}",
console::style("").green().bold(),
dest.display()
);
} else {
// Final fallback: copy the directory
copy_dir_recursive(&src, &dest)?;
println!(
" {} Skill copied: {}",
console::style("").green().bold(),
dest.display()
);
}
}
}
#[cfg(not(any(unix, windows)))]
{
// On other platforms, copy the directory
copy_dir_recursive(&src, &dest)?;
println!(
" {} Skill copied: {}",
console::style("").green().bold(),
dest.display()
);
}
}
Ok(()) Ok(())
} }
@ -631,3 +690,6 @@ description = "Bare minimum"
assert_eq!(skills[0].name, "from-toml"); // TOML takes priority assert_eq!(skills[0].name, "from-toml"); // TOML takes priority
} }
} }
#[cfg(test)]
mod symlink_tests;

109
src/skills/symlink_tests.rs Normal file
View file

@ -0,0 +1,109 @@
#[cfg(test)]
mod symlink_tests {
use crate::skills::skills_dir;
use std::path::Path;
use tempfile::TempDir;
#[test]
fn test_skills_symlink_unix_edge_cases() {
let tmp = TempDir::new().unwrap();
let workspace_dir = tmp.path().join("workspace");
std::fs::create_dir_all(&workspace_dir).unwrap();
let skills_path = skills_dir(&workspace_dir);
std::fs::create_dir_all(&skills_path).unwrap();
// Test case 1: Valid symlink creation on Unix
#[cfg(unix)]
{
let source_dir = tmp.path().join("source_skill");
std::fs::create_dir_all(&source_dir).unwrap();
std::fs::write(source_dir.join("SKILL.md"), "# Test Skill\nContent").unwrap();
let dest_link = skills_path.join("linked_skill");
// Create symlink
let result = std::os::unix::fs::symlink(&source_dir, &dest_link);
assert!(result.is_ok(), "Symlink creation should succeed");
// Verify symlink works
assert!(dest_link.exists());
assert!(dest_link.is_symlink());
// Verify we can read through symlink
let content = std::fs::read_to_string(dest_link.join("SKILL.md"));
assert!(content.is_ok());
assert!(content.unwrap().contains("Test Skill"));
// Test case 2: Symlink to non-existent target should fail gracefully
let broken_link = skills_path.join("broken_skill");
let non_existent = tmp.path().join("non_existent");
let result = std::os::unix::fs::symlink(&non_existent, &broken_link);
assert!(
result.is_ok(),
"Symlink creation should succeed even if target doesn't exist"
);
// But reading through it should fail
let content = std::fs::read_to_string(broken_link.join("SKILL.md"));
assert!(content.is_err());
}
// Test case 3: Non-Unix platforms should handle symlink errors gracefully
#[cfg(not(unix))]
{
let source_dir = tmp.path().join("source_skill");
std::fs::create_dir_all(&source_dir).unwrap();
let dest_link = skills_path.join("linked_skill");
// Symlink should fail on non-Unix
let result = std::os::unix::fs::symlink(&source_dir, &dest_link);
assert!(result.is_err());
// Directory should not exist
assert!(!dest_link.exists());
}
// Test case 4: skills_dir function edge cases
let workspace_with_trailing_slash = format!("{}/", workspace_dir.display());
let path_from_str = skills_dir(Path::new(&workspace_with_trailing_slash));
assert_eq!(path_from_str, skills_path);
// Test case 5: Empty workspace directory
let empty_workspace = tmp.path().join("empty");
let empty_skills_path = skills_dir(&empty_workspace);
assert_eq!(empty_skills_path, empty_workspace.join("skills"));
assert!(!empty_skills_path.exists());
}
#[test]
fn test_skills_symlink_permissions_and_safety() {
let tmp = TempDir::new().unwrap();
let workspace_dir = tmp.path().join("workspace");
std::fs::create_dir_all(&workspace_dir).unwrap();
let skills_path = skills_dir(&workspace_dir);
std::fs::create_dir_all(&skills_path).unwrap();
#[cfg(unix)]
{
// Test case: Symlink outside workspace should be allowed (user responsibility)
let outside_dir = tmp.path().join("outside_skill");
std::fs::create_dir_all(&outside_dir).unwrap();
std::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent").unwrap();
let dest_link = skills_path.join("outside_skill");
let result = std::os::unix::fs::symlink(&outside_dir, &dest_link);
assert!(
result.is_ok(),
"Should allow symlinking to directories outside workspace"
);
// Should still be readable
let content = std::fs::read_to_string(dest_link.join("SKILL.md"));
assert!(content.is_ok());
assert!(content.unwrap().contains("Outside Skill"));
}
}
}

View file

@ -55,7 +55,30 @@ impl Tool for FileReadTool {
let full_path = self.security.workspace_dir.join(path); let full_path = self.security.workspace_dir.join(path);
match tokio::fs::read_to_string(&full_path).await { // Resolve path before reading to block symlink escapes.
let resolved_path = match tokio::fs::canonicalize(&full_path).await {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to resolve file path: {e}")),
});
}
};
if !self.security.is_resolved_path_allowed(&resolved_path) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Resolved path escapes workspace: {}",
resolved_path.display()
)),
});
}
match tokio::fs::read_to_string(&resolved_path).await {
Ok(contents) => Ok(ToolResult { Ok(contents) => Ok(ToolResult {
success: true, success: true,
output: contents, output: contents,
@ -127,7 +150,7 @@ mod tests {
let tool = FileReadTool::new(test_security(dir.clone())); let tool = FileReadTool::new(test_security(dir.clone()));
let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap(); let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap();
assert!(!result.success); assert!(!result.success);
assert!(result.error.as_ref().unwrap().contains("Failed to read")); assert!(result.error.as_ref().unwrap().contains("Failed to resolve"));
let _ = tokio::fs::remove_dir_all(&dir).await; let _ = tokio::fs::remove_dir_all(&dir).await;
} }
@ -200,4 +223,36 @@ mod tests {
let _ = tokio::fs::remove_dir_all(&dir).await; let _ = tokio::fs::remove_dir_all(&dir).await;
} }
    /// A symlink inside the workspace that points outside it must be rejected
    /// after canonicalization, not silently followed.
    #[cfg(unix)]
    #[tokio::test]
    async fn file_read_blocks_symlink_escape() {
        use std::os::unix::fs::symlink;

        // Fixed temp layout: workspace/escape.txt -> outside/secret.txt.
        // Removed up-front so stale state from a previous run can't interfere.
        let root = std::env::temp_dir().join("zeroclaw_test_file_read_symlink_escape");
        let workspace = root.join("workspace");
        let outside = root.join("outside");
        let _ = tokio::fs::remove_dir_all(&root).await;
        tokio::fs::create_dir_all(&workspace).await.unwrap();
        tokio::fs::create_dir_all(&outside).await.unwrap();
        tokio::fs::write(outside.join("secret.txt"), "outside workspace")
            .await
            .unwrap();
        symlink(outside.join("secret.txt"), workspace.join("escape.txt")).unwrap();
        let tool = FileReadTool::new(test_security(workspace.clone()));
        let result = tool.execute(json!({"path": "escape.txt"})).await.unwrap();
        // The tool returns a failed ToolResult (not an Err) with a
        // workspace-escape message.
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("escapes workspace"));
        let _ = tokio::fs::remove_dir_all(&root).await;
    }
} }

View file

@ -69,7 +69,54 @@ impl Tool for FileWriteTool {
tokio::fs::create_dir_all(parent).await?; tokio::fs::create_dir_all(parent).await?;
} }
match tokio::fs::write(&full_path, content).await { let parent = match full_path.parent() {
Some(p) => p,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Invalid path: missing parent directory".into()),
});
}
};
// Resolve parent before writing to block symlink escapes.
let resolved_parent = match tokio::fs::canonicalize(parent).await {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!("Failed to resolve file path: {e}")),
});
}
};
if !self.security.is_resolved_path_allowed(&resolved_parent) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Resolved path escapes workspace: {}",
resolved_parent.display()
)),
});
}
let file_name = match full_path.file_name() {
Some(name) => name,
None => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some("Invalid path: missing file name".into()),
});
}
};
let resolved_target = resolved_parent.join(file_name);
match tokio::fs::write(&resolved_target, content).await {
Ok(()) => Ok(ToolResult { Ok(()) => Ok(ToolResult {
success: true, success: true,
output: format!("Written {} bytes to {path}", content.len()), output: format!("Written {} bytes to {path}", content.len()),
@ -239,4 +286,36 @@ mod tests {
let _ = tokio::fs::remove_dir_all(&dir).await; let _ = tokio::fs::remove_dir_all(&dir).await;
} }
    /// Writing through a symlinked directory that resolves outside the
    /// workspace must be rejected, and no file may be created at the target.
    #[cfg(unix)]
    #[tokio::test]
    async fn file_write_blocks_symlink_escape() {
        use std::os::unix::fs::symlink;

        // Fixed temp layout: workspace/escape_dir -> outside/. Cleaned
        // up-front so stale state from a previous run can't interfere.
        let root = std::env::temp_dir().join("zeroclaw_test_file_write_symlink_escape");
        let workspace = root.join("workspace");
        let outside = root.join("outside");
        let _ = tokio::fs::remove_dir_all(&root).await;
        tokio::fs::create_dir_all(&workspace).await.unwrap();
        tokio::fs::create_dir_all(&outside).await.unwrap();
        symlink(&outside, workspace.join("escape_dir")).unwrap();
        let tool = FileWriteTool::new(test_security(workspace.clone()));
        let result = tool
            .execute(json!({"path": "escape_dir/hijack.txt", "content": "bad"}))
            .await
            .unwrap();
        // Failed ToolResult with a workspace-escape message…
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap_or("")
            .contains("escapes workspace"));
        // …and, critically, nothing was written outside the workspace.
        assert!(!outside.join("hijack.txt").exists());
        let _ = tokio::fs::remove_dir_all(&root).await;
    }
} }

327
tests/dockerignore_test.rs Normal file
View file

@ -0,0 +1,327 @@
//! Tests to verify .dockerignore excludes sensitive paths from Docker build context.
//!
//! These tests validate that:
//! 1. The .dockerignore file exists
//! 2. All security-critical paths are excluded
//! 3. All build-essential paths are NOT excluded
//! 4. Pattern syntax is valid
use std::fs;
use std::path::Path;
/// Paths that MUST be excluded from Docker build context (security/performance)
const MUST_EXCLUDE: &[&str] = &[
".git",
".githooks",
"target",
"docs",
"examples",
"tests",
"*.md",
"*.png",
"*.db",
"*.db-journal",
".DS_Store",
".github",
"deny.toml",
"LICENSE",
".env",
".tmp_*",
];
/// Paths that MUST NOT be excluded (required for build)
const MUST_INCLUDE: &[&str] = &["Cargo.toml", "Cargo.lock", "src/"];
/// Extract the active patterns from .dockerignore text: trim each line and
/// drop blank lines and `#` comments.
fn parse_dockerignore(content: &str) -> Vec<String> {
    let mut patterns = Vec::new();
    for raw_line in content.lines() {
        let line = raw_line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        patterns.push(line.to_string());
    }
    patterns
}
/// Check whether a single `.dockerignore` `pattern` would match `path`.
///
/// Deliberately simplified matcher covering only the pattern shapes this
/// repo's `.dockerignore` uses: literals / directory prefixes, `*.ext`
/// suffix globs, and `prefix*` globs. It does not implement Docker's full
/// match semantics (e.g. `*` here may cross `/` boundaries).
fn pattern_matches(pattern: &str, path: &str) -> bool {
    // A negation (`!pattern`) re-includes files, so by itself it never excludes.
    if pattern.starts_with('!') {
        return false;
    }

    // Suffix glob such as `*.md`: match on the extension part.
    if let Some(suffix) = pattern.strip_prefix('*') {
        if suffix.starts_with('.') {
            return path.ends_with(suffix);
        }
    }

    // Literal / directory pattern: exact name or anything nested beneath it.
    let dir_pattern = pattern.trim_end_matches('/');
    let candidate = path.trim_end_matches('/');
    if candidate == dir_pattern || candidate.starts_with(&format!("{}/", dir_pattern)) {
        return true;
    }

    // Prefix glob such as `.tmp_*`: match on the literal text before `*`.
    if pattern.contains('*') && !pattern.starts_with("*.") {
        let literal = pattern.split('*').next().unwrap_or("");
        if !literal.is_empty() && path.starts_with(literal) {
            return true;
        }
    }

    false
}

/// Apply every pattern to `path` in file order, honouring `!` negations.
///
/// Later patterns win (last-match semantics): a matching `!pattern` after an
/// excluding pattern re-includes the path.
fn is_excluded(patterns: &[String], path: &str) -> bool {
    patterns.iter().fold(false, |state, pattern| {
        match pattern.strip_prefix('!') {
            Some(negated) if pattern_matches(negated, path) => false, // re-included
            Some(_) => state,
            None if pattern_matches(pattern, path) => true, // excluded
            None => state,
        }
    })
}
#[test]
fn dockerignore_file_exists() {
    // Without this file the whole repo (git history, target/, local env
    // files) becomes part of the Docker build context.
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    assert!(
        dockerignore.exists(),
        ".dockerignore file must exist at project root"
    );
}
#[test]
fn dockerignore_excludes_security_critical_paths() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    for entry in MUST_EXCLUDE {
        // Glob entries can't be checked verbatim; probe with a sample name.
        let probe = match entry.strip_prefix('*') {
            Some(suffix) if suffix.starts_with('.') => format!("sample{}", suffix),
            _ => entry.to_string(),
        };
        assert!(
            is_excluded(&patterns, &probe),
            "Path '{}' (tested as '{}') MUST be excluded by .dockerignore but is not. \
             This is a security/performance issue.",
            entry,
            probe
        );
    }
}
#[test]
fn dockerignore_does_not_exclude_build_essentials() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    // Excluding any of these would break `cargo build` inside the image.
    for essential in MUST_INCLUDE {
        assert!(
            !is_excluded(&patterns, essential),
            "Path '{}' MUST NOT be excluded by .dockerignore (required for build)",
            essential
        );
    }
}
#[test]
fn dockerignore_excludes_git_directory() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    // .git and everything nested inside it must stay out of the context.
    let check = |path: &str, why: &str| assert!(is_excluded(&patterns, path), "{}", why);
    check(".git", ".git must be excluded");
    check(".git/config", ".git/config must be excluded");
    check(
        ".git/objects/pack/pack-abc123.pack",
        ".git subdirectories must be excluded",
    );
}
#[test]
fn dockerignore_excludes_target_directory() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    // Build artifacts (potentially GBs) must never reach the Docker daemon.
    let check = |path: &str, why: &str| assert!(is_excluded(&patterns, path), "{}", why);
    check("target", "target must be excluded");
    check("target/debug/zeroclaw", "target/debug must be excluded");
    check("target/release/zeroclaw", "target/release must be excluded");
}
#[test]
fn dockerignore_excludes_database_files() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    // Local SQLite files (and their journals) hold runtime state/history.
    let check = |path: &str, why: &str| assert!(is_excluded(&patterns, path), "{}", why);
    check("brain.db", "*.db files must be excluded");
    check("memory.db", "*.db files must be excluded");
    check("brain.db-journal", "*.db-journal files must be excluded");
}
#[test]
fn dockerignore_excludes_markdown_files() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    // Documentation is dead weight in the runtime image.
    let check = |path: &str, why: &str| assert!(is_excluded(&patterns, path), "{}", why);
    check("README.md", "*.md files must be excluded");
    check("CHANGELOG.md", "*.md files must be excluded");
    check("CONTRIBUTING.md", "*.md files must be excluded");
}
#[test]
fn dockerignore_excludes_image_files() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    let check = |path: &str, why: &str| assert!(is_excluded(&patterns, path), "{}", why);
    check("zeroclaw.png", "*.png files must be excluded");
    check("logo.png", "*.png files must be excluded");
}
#[test]
fn dockerignore_excludes_env_files() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");

    // .env may contain secrets and must never be baked into an image layer.
    assert!(
        is_excluded(&parse_dockerignore(&content), ".env"),
        ".env must be excluded (contains secrets)"
    );
}
#[test]
fn dockerignore_excludes_ci_configs() {
    let dockerignore = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&dockerignore).expect("Failed to read .dockerignore");
    let patterns = parse_dockerignore(&content);

    // CI configuration is irrelevant to the runtime image.
    let check = |path: &str, why: &str| assert!(is_excluded(&patterns, path), "{}", why);
    check(".github", ".github must be excluded");
    check(".github/workflows/ci.yml", ".github/workflows must be excluded");
}
#[test]
fn dockerignore_has_valid_syntax() {
    // Sanity-check each active .dockerignore line for suspicious formatting.
    let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
    let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
    for (line_num, line) in content.lines().enumerate() {
        let trimmed = line.trim();
        // Skip empty lines and comments
        if trimmed.is_empty() || trimmed.starts_with('#') {
            continue;
        }
        // Flag patterns with more than two `**` globs — almost certainly a typo.
        assert!(
            !trimmed.contains("**") || trimmed.matches("**").count() <= 2,
            "Line {}: Too many ** in pattern '{}'",
            line_num + 1,
            trimmed
        );
        // Reject leading whitespace: once trailing spaces are stripped, the
        // line must equal its fully trimmed form, i.e. nothing precedes the
        // pattern text.
        // NOTE(review): this equality does NOT detect trailing spaces — both
        // sides are `trim_end`ed — even though the assertion lives under a
        // whitespace-hygiene check; confirm whether trailing blanks should
        // also be rejected.
        assert!(
            line.trim_end() == line.trim_start().trim_end(),
            "Line {}: Pattern '{}' has leading whitespace which may cause issues",
            line_num + 1,
            line
        );
    }
}
#[test]
fn dockerignore_pattern_matching_edge_cases() {
    // Exercise the simplified matcher itself, independent of the real file.
    let patterns: Vec<String> = [
        ".git",
        ".githooks",
        "target",
        "*.md",
        "*.db",
        ".tmp_*",
        ".env",
    ]
    .iter()
    .map(|p| p.to_string())
    .collect();

    // Paths every one of these patterns should knock out.
    let excluded = [
        ".git",
        ".git/config",
        ".githooks",
        "target",
        "target/debug/build",
        "README.md",
        "brain.db",
        ".tmp_todo_probe",
        ".env",
    ];
    for &path in excluded.iter() {
        assert!(is_excluded(&patterns, path));
    }

    // Build-essential paths that must survive the filter.
    let included = ["src", "src/main.rs", "Cargo.toml", "Cargo.lock"];
    for &path in included.iter() {
        assert!(!is_excluded(&patterns, path));
    }
}