merge: resolve conflicts between feat/whatsapp-email-channels and main
- Keep main's WhatsApp implementation (webhook-based, simpler) - Preserve email channel fixes from our branch - Merge all main branch updates (daemon, cron, health, etc.) - Resolve Cargo.lock conflicts
This commit is contained in:
commit
4e6da51924
40 changed files with 6925 additions and 780 deletions
66
.dockerignore
Normal file
66
.dockerignore
Normal file
|
|
@ -0,0 +1,66 @@
|
||||||
|
# Git history (may contain old secrets)
|
||||||
|
.git
|
||||||
|
.gitignore
|
||||||
|
.githooks
|
||||||
|
|
||||||
|
# Rust build artifacts (can be multiple GB)
|
||||||
|
target
|
||||||
|
|
||||||
|
# Documentation and examples (not needed for runtime)
|
||||||
|
docs
|
||||||
|
examples
|
||||||
|
tests
|
||||||
|
|
||||||
|
# Markdown files (README, CHANGELOG, etc.)
|
||||||
|
*.md
|
||||||
|
|
||||||
|
# Images (unnecessary for build)
|
||||||
|
*.png
|
||||||
|
*.svg
|
||||||
|
*.jpg
|
||||||
|
*.jpeg
|
||||||
|
*.gif
|
||||||
|
|
||||||
|
# SQLite databases (conversation history, cron jobs)
|
||||||
|
*.db
|
||||||
|
*.db-journal
|
||||||
|
|
||||||
|
# macOS artifacts
|
||||||
|
.DS_Store
|
||||||
|
.AppleDouble
|
||||||
|
.LSOverride
|
||||||
|
|
||||||
|
# CI/CD configs (not needed in image)
|
||||||
|
.github
|
||||||
|
|
||||||
|
# Cargo deny config (lint tool, not runtime)
|
||||||
|
deny.toml
|
||||||
|
|
||||||
|
# License file (not needed for runtime)
|
||||||
|
LICENSE
|
||||||
|
|
||||||
|
# Temporary files
|
||||||
|
.tmp_*
|
||||||
|
*.tmp
|
||||||
|
*.bak
|
||||||
|
*.swp
|
||||||
|
*~
|
||||||
|
|
||||||
|
# IDE and editor configs
|
||||||
|
.idea
|
||||||
|
.vscode
|
||||||
|
*.iml
|
||||||
|
|
||||||
|
# Windsurf workflows
|
||||||
|
.windsurf
|
||||||
|
|
||||||
|
# Environment files (may contain secrets)
|
||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
!.env.example
|
||||||
|
|
||||||
|
# Coverage and profiling
|
||||||
|
*.profraw
|
||||||
|
*.profdata
|
||||||
|
coverage
|
||||||
|
lcov.info
|
||||||
37
.github/workflows/ci.yml
vendored
37
.github/workflows/ci.yml
vendored
|
|
@ -63,3 +63,40 @@ jobs:
|
||||||
with:
|
with:
|
||||||
name: zeroclaw-${{ matrix.target }}
|
name: zeroclaw-${{ matrix.target }}
|
||||||
path: target/${{ matrix.target }}/release/zeroclaw*
|
path: target/${{ matrix.target }}/release/zeroclaw*
|
||||||
|
|
||||||
|
docker:
|
||||||
|
name: Docker Security
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Build Docker image
|
||||||
|
run: docker build -t zeroclaw:test .
|
||||||
|
|
||||||
|
- name: Verify non-root user (UID != 0)
|
||||||
|
run: |
|
||||||
|
USER_ID=$(docker inspect --format='{{.Config.User}}' zeroclaw:test)
|
||||||
|
echo "Container user: $USER_ID"
|
||||||
|
if [ "$USER_ID" = "0" ] || [ "$USER_ID" = "root" ] || [ -z "$USER_ID" ]; then
|
||||||
|
echo "❌ FAIL: Container runs as root (UID 0)"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ PASS: Container runs as non-root user ($USER_ID)"
|
||||||
|
|
||||||
|
- name: Verify distroless nonroot base image
|
||||||
|
run: |
|
||||||
|
BASE_IMAGE=$(grep -E '^FROM.*runtime|^FROM gcr.io/distroless' Dockerfile | tail -1)
|
||||||
|
echo "Base image line: $BASE_IMAGE"
|
||||||
|
if ! echo "$BASE_IMAGE" | grep -q ':nonroot'; then
|
||||||
|
echo "❌ FAIL: Runtime stage does not use :nonroot variant"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ PASS: Using distroless :nonroot variant"
|
||||||
|
|
||||||
|
- name: Verify USER directive exists
|
||||||
|
run: |
|
||||||
|
if ! grep -qE '^USER\s+[0-9]+' Dockerfile; then
|
||||||
|
echo "❌ FAIL: No explicit USER directive with numeric UID"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
echo "✅ PASS: Explicit USER directive found"
|
||||||
|
|
|
||||||
0
.tmp_todo_probe
Normal file
0
.tmp_todo_probe
Normal file
18
CHANGELOG.md
18
CHANGELOG.md
|
|
@ -5,6 +5,24 @@ All notable changes to ZeroClaw will be documented in this file.
|
||||||
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.1.0/),
|
||||||
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
||||||
|
|
||||||
|
## [Unreleased]
|
||||||
|
|
||||||
|
### Security
|
||||||
|
- **Legacy XOR cipher migration**: The `enc:` prefix (XOR cipher) is now deprecated.
|
||||||
|
Secrets using this format will be automatically migrated to `enc2:` (ChaCha20-Poly1305 AEAD)
|
||||||
|
when decrypted via `decrypt_and_migrate()`. A `tracing::warn!` is emitted when legacy
|
||||||
|
values are encountered. The XOR cipher will be removed in a future release.
|
||||||
|
|
||||||
|
### Added
|
||||||
|
- `SecretStore::decrypt_and_migrate()` — Decrypts secrets and returns a migrated `enc2:`
|
||||||
|
value if the input used the legacy `enc:` format
|
||||||
|
- `SecretStore::needs_migration()` — Check if a value uses the legacy `enc:` format
|
||||||
|
- `SecretStore::is_secure_encrypted()` — Check if a value uses the secure `enc2:` format
|
||||||
|
|
||||||
|
### Deprecated
|
||||||
|
- `enc:` prefix for encrypted secrets — Use `enc2:` (ChaCha20-Poly1305) instead.
|
||||||
|
Legacy values are still decrypted for backward compatibility but should be migrated.
|
||||||
|
|
||||||
## [0.1.0] - 2025-02-13
|
## [0.1.0] - 2025-02-13
|
||||||
|
|
||||||
### Added
|
### Added
|
||||||
|
|
|
||||||
|
|
@ -15,7 +15,7 @@ categories = ["command-line-utilities", "api-bindings"]
|
||||||
clap = { version = "4.5", features = ["derive"] }
|
clap = { version = "4.5", features = ["derive"] }
|
||||||
|
|
||||||
# Async runtime - feature-optimized for size
|
# Async runtime - feature-optimized for size
|
||||||
tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs"] }
|
tokio = { version = "1.42", default-features = false, features = ["rt-multi-thread", "macros", "time", "net", "io-util", "sync", "process", "io-std", "fs", "signal"] }
|
||||||
|
|
||||||
# HTTP client - minimal features
|
# HTTP client - minimal features
|
||||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking"] }
|
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking"] }
|
||||||
|
|
@ -49,6 +49,7 @@ async-trait = "0.1"
|
||||||
# Memory / persistence
|
# Memory / persistence
|
||||||
rusqlite = { version = "0.32", features = ["bundled"] }
|
rusqlite = { version = "0.32", features = ["bundled"] }
|
||||||
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }
|
chrono = { version = "0.4", default-features = false, features = ["clock", "std"] }
|
||||||
|
cron = "0.12"
|
||||||
|
|
||||||
# Interactive CLI prompts
|
# Interactive CLI prompts
|
||||||
dialoguer = { version = "0.11", features = ["fuzzy-select"] }
|
dialoguer = { version = "0.11", features = ["fuzzy-select"] }
|
||||||
|
|
@ -64,6 +65,12 @@ rustls-pki-types = "1.14.0"
|
||||||
tokio-rustls = "0.26.4"
|
tokio-rustls = "0.26.4"
|
||||||
webpki-roots = "1.0.6"
|
webpki-roots = "1.0.6"
|
||||||
|
|
||||||
|
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
|
||||||
|
axum = { version = "0.7", default-features = false, features = ["http1", "json", "tokio", "query"] }
|
||||||
|
tower = { version = "0.5", default-features = false }
|
||||||
|
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
|
||||||
|
http-body-util = "0.1"
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
opt-level = "z" # Optimize for size
|
opt-level = "z" # Optimize for size
|
||||||
lto = true # Link-time optimization
|
lto = true # Link-time optimization
|
||||||
|
|
|
||||||
|
|
@ -8,14 +8,17 @@ COPY src/ src/
|
||||||
RUN cargo build --release --locked && \
|
RUN cargo build --release --locked && \
|
||||||
strip target/release/zeroclaw
|
strip target/release/zeroclaw
|
||||||
|
|
||||||
# ── Stage 2: Runtime (distroless — no shell, no OS, tiny) ────
|
# ── Stage 2: Runtime (distroless nonroot — no shell, no OS, tiny, UID 65534) ──
|
||||||
FROM gcr.io/distroless/cc-debian12
|
FROM gcr.io/distroless/cc-debian12:nonroot
|
||||||
|
|
||||||
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
|
COPY --from=builder /app/target/release/zeroclaw /usr/local/bin/zeroclaw
|
||||||
|
|
||||||
# Default workspace
|
# Default workspace (owned by nonroot user)
|
||||||
VOLUME ["/workspace"]
|
VOLUME ["/workspace"]
|
||||||
ENV ZEROCLAW_WORKSPACE=/workspace
|
ENV ZEROCLAW_WORKSPACE=/workspace
|
||||||
|
|
||||||
|
# Explicitly set non-root user (distroless:nonroot defaults to 65534, but be explicit)
|
||||||
|
USER 65534:65534
|
||||||
|
|
||||||
ENTRYPOINT ["zeroclaw"]
|
ENTRYPOINT ["zeroclaw"]
|
||||||
CMD ["gateway"]
|
CMD ["gateway"]
|
||||||
|
|
|
||||||
212
README.md
212
README.md
|
|
@ -12,12 +12,19 @@
|
||||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
The fastest, smallest, fully autonomous AI assistant — deploy anywhere, swap anything.
|
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
|
||||||
|
|
||||||
```
|
```
|
||||||
~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything
|
~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Why teams pick ZeroClaw
|
||||||
|
|
||||||
|
- **Lean by default:** small Rust binary, fast startup, low memory footprint.
|
||||||
|
- **Secure by design:** pairing, strict sandboxing, explicit allowlists, workspace scoping.
|
||||||
|
- **Fully swappable:** core systems are traits (providers, channels, tools, memory, tunnels).
|
||||||
|
- **No lock-in:** OpenAI-compatible provider support + pluggable custom endpoints.
|
||||||
|
|
||||||
## Benchmark Snapshot (ZeroClaw vs OpenClaw)
|
## Benchmark Snapshot (ZeroClaw vs OpenClaw)
|
||||||
|
|
||||||
Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each.
|
Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each.
|
||||||
|
|
@ -30,7 +37,17 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each.
|
||||||
| `--help` max RSS observed | **~7.3 MB** | **~394 MB** |
|
| `--help` max RSS observed | **~7.3 MB** | **~394 MB** |
|
||||||
| `status` max RSS observed | **~7.8 MB** | **~1.52 GB** |
|
| `status` max RSS observed | **~7.8 MB** | **~1.52 GB** |
|
||||||
|
|
||||||
> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results include `pnpm install` + `pnpm build` before execution.
|
> Notes: measured with `/usr/bin/time -l`; first run includes cold-start effects. OpenClaw results were measured after `pnpm install` + `pnpm build`.
|
||||||
|
|
||||||
|
Reproduce ZeroClaw numbers locally:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
cargo build --release
|
||||||
|
ls -lh target/release/zeroclaw
|
||||||
|
|
||||||
|
/usr/bin/time -l target/release/zeroclaw --help
|
||||||
|
/usr/bin/time -l target/release/zeroclaw status
|
||||||
|
```
|
||||||
|
|
||||||
## Quick Start
|
## Quick Start
|
||||||
|
|
||||||
|
|
@ -38,34 +55,52 @@ Local machine quick benchmark (macOS arm64, Feb 2026), same host, 3 runs each.
|
||||||
git clone https://github.com/theonlyhennygod/zeroclaw.git
|
git clone https://github.com/theonlyhennygod/zeroclaw.git
|
||||||
cd zeroclaw
|
cd zeroclaw
|
||||||
cargo build --release
|
cargo build --release
|
||||||
|
cargo install --path . --force
|
||||||
|
|
||||||
# Quick setup (no prompts)
|
# Quick setup (no prompts)
|
||||||
cargo run --release -- onboard --api-key sk-... --provider openrouter
|
zeroclaw onboard --api-key sk-... --provider openrouter
|
||||||
|
|
||||||
# Or interactive wizard
|
# Or interactive wizard
|
||||||
cargo run --release -- onboard --interactive
|
zeroclaw onboard --interactive
|
||||||
|
|
||||||
|
# Or quickly repair channels/allowlists only
|
||||||
|
zeroclaw onboard --channels-only
|
||||||
|
|
||||||
# Chat
|
# Chat
|
||||||
cargo run --release -- agent -m "Hello, ZeroClaw!"
|
zeroclaw agent -m "Hello, ZeroClaw!"
|
||||||
|
|
||||||
# Interactive mode
|
# Interactive mode
|
||||||
cargo run --release -- agent
|
zeroclaw agent
|
||||||
|
|
||||||
# Start the gateway (webhook server)
|
# Start the gateway (webhook server)
|
||||||
cargo run --release -- gateway # default: 127.0.0.1:8080
|
zeroclaw gateway # default: 127.0.0.1:8080
|
||||||
cargo run --release -- gateway --port 0 # random port (security hardened)
|
zeroclaw gateway --port 0 # random port (security hardened)
|
||||||
|
|
||||||
|
# Start full autonomous runtime
|
||||||
|
zeroclaw daemon
|
||||||
|
|
||||||
# Check status
|
# Check status
|
||||||
cargo run --release -- status
|
zeroclaw status
|
||||||
|
|
||||||
|
# Run system diagnostics
|
||||||
|
zeroclaw doctor
|
||||||
|
|
||||||
# Check channel health
|
# Check channel health
|
||||||
cargo run --release -- channel doctor
|
zeroclaw channel doctor
|
||||||
|
|
||||||
# Get integration setup details
|
# Get integration setup details
|
||||||
cargo run --release -- integrations info Telegram
|
zeroclaw integrations info Telegram
|
||||||
|
|
||||||
|
# Manage background service
|
||||||
|
zeroclaw service install
|
||||||
|
zeroclaw service status
|
||||||
|
|
||||||
|
# Migrate memory from OpenClaw (safe preview first)
|
||||||
|
zeroclaw migrate openclaw --dry-run
|
||||||
|
zeroclaw migrate openclaw
|
||||||
```
|
```
|
||||||
|
|
||||||
> **Tip:** Run `cargo install --path .` to install `zeroclaw` globally, then use `zeroclaw` instead of `cargo run --release --`.
|
> **Dev fallback (no global install):** prefix commands with `cargo run --release --` (example: `cargo run --release -- status`).
|
||||||
|
|
||||||
## Architecture
|
## Architecture
|
||||||
|
|
||||||
|
|
@ -78,17 +113,25 @@ Every subsystem is a **trait** — swap implementations with a config change, ze
|
||||||
| Subsystem | Trait | Ships with | Extend |
|
| Subsystem | Trait | Ships with | Extend |
|
||||||
|-----------|-------|------------|--------|
|
|-----------|-------|------------|--------|
|
||||||
| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
|
| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
|
||||||
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, Webhook | Any messaging API |
|
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
|
||||||
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Markdown | Any persistence backend |
|
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Markdown | Any persistence backend |
|
||||||
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), composio (optional) | Any capability |
|
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), composio (optional) | Any capability |
|
||||||
| **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel |
|
| **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel |
|
||||||
| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM |
|
| **Runtime** | `RuntimeAdapter` | Native (Mac/Linux/Pi) | Docker, WASM (planned; unsupported kinds fail fast) |
|
||||||
| **Security** | `SecurityPolicy` | Gateway pairing, sandbox, allowlists, rate limits, filesystem scoping, encrypted secrets | — |
|
| **Security** | `SecurityPolicy` | Gateway pairing, sandbox, allowlists, rate limits, filesystem scoping, encrypted secrets | — |
|
||||||
|
| **Identity** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | Any identity format |
|
||||||
| **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Any tunnel binary |
|
| **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Any tunnel binary |
|
||||||
| **Heartbeat** | Engine | HEARTBEAT.md periodic tasks | — |
|
| **Heartbeat** | Engine | HEARTBEAT.md periodic tasks | — |
|
||||||
| **Skills** | Loader | TOML manifests + SKILL.md instructions | Community skill packs |
|
| **Skills** | Loader | TOML manifests + SKILL.md instructions | Community skill packs |
|
||||||
| **Integrations** | Registry | 50+ integrations across 9 categories | Plugin system |
|
| **Integrations** | Registry | 50+ integrations across 9 categories | Plugin system |
|
||||||
|
|
||||||
|
### Runtime support (current)
|
||||||
|
|
||||||
|
- ✅ Supported today: `runtime.kind = "native"`
|
||||||
|
- 🚧 Planned, not implemented yet: Docker / WASM / edge runtimes
|
||||||
|
|
||||||
|
When an unsupported `runtime.kind` is configured, ZeroClaw now exits with a clear error instead of silently falling back to native.
|
||||||
|
|
||||||
### Memory System (Full-Stack Search Engine)
|
### Memory System (Full-Stack Search Engine)
|
||||||
|
|
||||||
All custom, zero external dependencies — no Pinecone, no Elasticsearch, no LangChain:
|
All custom, zero external dependencies — no Pinecone, no Elasticsearch, no LangChain:
|
||||||
|
|
@ -124,7 +167,7 @@ ZeroClaw enforces security at **every layer** — not just the sandbox. It passe
|
||||||
|---|------|--------|-----|
|
|---|------|--------|-----|
|
||||||
| 1 | **Gateway not publicly exposed** | ✅ | Binds `127.0.0.1` by default. Refuses `0.0.0.0` without tunnel or explicit `allow_public_bind = true`. |
|
| 1 | **Gateway not publicly exposed** | ✅ | Binds `127.0.0.1` by default. Refuses `0.0.0.0` without tunnel or explicit `allow_public_bind = true`. |
|
||||||
| 2 | **Pairing required** | ✅ | 6-digit one-time code on startup. Exchange via `POST /pair` for bearer token. All `/webhook` requests require `Authorization: Bearer <token>`. |
|
| 2 | **Pairing required** | ✅ | 6-digit one-time code on startup. Exchange via `POST /pair` for bearer token. All `/webhook` requests require `Authorization: Bearer <token>`. |
|
||||||
| 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization. |
|
| 3 | **Filesystem scoped (no /)** | ✅ | `workspace_only = true` by default. 14 system dirs + 4 sensitive dotfiles blocked. Null byte injection blocked. Symlink escape detection via canonicalization + resolved-path workspace checks in file read/write tools. |
|
||||||
| 4 | **Access via tunnel only** | ✅ | Gateway refuses public bind without active tunnel. Supports Tailscale, Cloudflare, ngrok, or any custom tunnel. |
|
| 4 | **Access via tunnel only** | ✅ | Gateway refuses public bind without active tunnel. Supports Tailscale, Cloudflare, ngrok, or any custom tunnel. |
|
||||||
|
|
||||||
> **Run your own nmap:** `nmap -p 1-65535 <your-host>` — ZeroClaw binds to localhost only, so nothing is exposed unless you explicitly configure a tunnel.
|
> **Run your own nmap:** `nmap -p 1-65535 <your-host>` — ZeroClaw binds to localhost only, so nothing is exposed unless you explicitly configure a tunnel.
|
||||||
|
|
@ -139,6 +182,63 @@ Inbound sender policy is now consistent:
|
||||||
|
|
||||||
This keeps accidental exposure low by default.
|
This keeps accidental exposure low by default.
|
||||||
|
|
||||||
|
Recommended low-friction setup (secure + fast):
|
||||||
|
|
||||||
|
- **Telegram:** allowlist your own `@username` (without `@`) and/or your numeric Telegram user ID.
|
||||||
|
- **Discord:** allowlist your own Discord user ID.
|
||||||
|
- **Slack:** allowlist your own Slack member ID (usually starts with `U`).
|
||||||
|
- Use `"*"` only for temporary open testing.
|
||||||
|
|
||||||
|
If you're not sure which identity to use:
|
||||||
|
|
||||||
|
1. Start channels and send one message to your bot.
|
||||||
|
2. Read the warning log to see the exact sender identity.
|
||||||
|
3. Add that value to the allowlist and rerun channels-only setup.
|
||||||
|
|
||||||
|
If you hit authorization warnings in logs (for example: `ignoring message from unauthorized user`),
|
||||||
|
rerun channel setup only:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
zeroclaw onboard --channels-only
|
||||||
|
```
|
||||||
|
|
||||||
|
### WhatsApp Business Cloud API Setup
|
||||||
|
|
||||||
|
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
|
||||||
|
|
||||||
|
1. **Create a Meta Business App:**
|
||||||
|
- Go to [developers.facebook.com](https://developers.facebook.com)
|
||||||
|
- Create a new app → Select "Business" type
|
||||||
|
- Add the "WhatsApp" product
|
||||||
|
|
||||||
|
2. **Get your credentials:**
|
||||||
|
- **Access Token:** From WhatsApp → API Setup → Generate token (or create a System User for permanent tokens)
|
||||||
|
- **Phone Number ID:** From WhatsApp → API Setup → Phone number ID
|
||||||
|
- **Verify Token:** You define this (any random string) — Meta will send it back during webhook verification
|
||||||
|
|
||||||
|
3. **Configure ZeroClaw:**
|
||||||
|
```toml
|
||||||
|
[channels_config.whatsapp]
|
||||||
|
access_token = "EAABx..."
|
||||||
|
phone_number_id = "123456789012345"
|
||||||
|
verify_token = "my-secret-verify-token"
|
||||||
|
allowed_numbers = ["+1234567890"] # E.164 format, or ["*"] for all
|
||||||
|
```
|
||||||
|
|
||||||
|
4. **Start the gateway with a tunnel:**
|
||||||
|
```bash
|
||||||
|
zeroclaw gateway --port 8080
|
||||||
|
```
|
||||||
|
WhatsApp requires HTTPS, so use a tunnel (ngrok, Cloudflare, Tailscale Funnel).
|
||||||
|
|
||||||
|
5. **Configure Meta webhook:**
|
||||||
|
- In Meta Developer Console → WhatsApp → Configuration → Webhook
|
||||||
|
- **Callback URL:** `https://your-tunnel-url/whatsapp`
|
||||||
|
- **Verify Token:** Same as your `verify_token` in config
|
||||||
|
- Subscribe to `messages` field
|
||||||
|
|
||||||
|
6. **Test:** Send a message to your WhatsApp Business number — ZeroClaw will respond via the LLM.
|
||||||
|
|
||||||
## Configuration
|
## Configuration
|
||||||
|
|
||||||
Config: `~/.zeroclaw/config.toml` (created by `onboard`)
|
Config: `~/.zeroclaw/config.toml` (created by `onboard`)
|
||||||
|
|
@ -166,6 +266,9 @@ workspace_only = true # default: true — scoped to workspace
|
||||||
allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"]
|
allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"]
|
||||||
forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"]
|
forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"]
|
||||||
|
|
||||||
|
[runtime]
|
||||||
|
kind = "native" # only supported value right now; unsupported kinds fail fast
|
||||||
|
|
||||||
[heartbeat]
|
[heartbeat]
|
||||||
enabled = false
|
enabled = false
|
||||||
interval_minutes = 30
|
interval_minutes = 30
|
||||||
|
|
@ -182,8 +285,81 @@ allowed_domains = ["docs.rs"] # required when browser is enabled
|
||||||
|
|
||||||
[composio]
|
[composio]
|
||||||
enabled = false # opt-in: 1000+ OAuth apps via composio.dev
|
enabled = false # opt-in: 1000+ OAuth apps via composio.dev
|
||||||
|
|
||||||
|
[identity]
|
||||||
|
format = "openclaw" # "openclaw" (default, markdown files) or "aieos" (JSON)
|
||||||
|
# aieos_path = "identity.json" # path to AIEOS JSON file (relative to workspace or absolute)
|
||||||
|
# aieos_inline = '{"identity":{"names":{"first":"Nova"}}}' # inline AIEOS JSON
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Identity System (AIEOS Support)
|
||||||
|
|
||||||
|
ZeroClaw supports **identity-agnostic** AI personas through two formats:
|
||||||
|
|
||||||
|
### OpenClaw (Default)
|
||||||
|
|
||||||
|
Traditional markdown files in your workspace:
|
||||||
|
- `IDENTITY.md` — Who the agent is
|
||||||
|
- `SOUL.md` — Core personality and values
|
||||||
|
- `USER.md` — Who the agent is helping
|
||||||
|
- `AGENTS.md` — Behavior guidelines
|
||||||
|
|
||||||
|
### AIEOS (AI Entity Object Specification)
|
||||||
|
|
||||||
|
[AIEOS](https://aieos.org) is a standardization framework for portable AI identity. ZeroClaw supports AIEOS v1.1 JSON payloads, allowing you to:
|
||||||
|
|
||||||
|
- **Import identities** from the AIEOS ecosystem
|
||||||
|
- **Export identities** to other AIEOS-compatible systems
|
||||||
|
- **Maintain behavioral integrity** across different AI models
|
||||||
|
|
||||||
|
#### Enable AIEOS
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[identity]
|
||||||
|
format = "aieos"
|
||||||
|
aieos_path = "identity.json" # relative to workspace or absolute path
|
||||||
|
```
|
||||||
|
|
||||||
|
Or inline JSON:
|
||||||
|
|
||||||
|
```toml
|
||||||
|
[identity]
|
||||||
|
format = "aieos"
|
||||||
|
aieos_inline = '''
|
||||||
|
{
|
||||||
|
"identity": {
|
||||||
|
"names": { "first": "Nova", "nickname": "N" }
|
||||||
|
},
|
||||||
|
"psychology": {
|
||||||
|
"neural_matrix": { "creativity": 0.9, "logic": 0.8 },
|
||||||
|
"traits": { "mbti": "ENTP" },
|
||||||
|
"moral_compass": { "alignment": "Chaotic Good" }
|
||||||
|
},
|
||||||
|
"linguistics": {
|
||||||
|
"text_style": { "formality_level": 0.2, "slang_usage": true }
|
||||||
|
},
|
||||||
|
"motivations": {
|
||||||
|
"core_drive": "Push boundaries and explore possibilities"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
'''
|
||||||
|
```
|
||||||
|
|
||||||
|
#### AIEOS Schema Sections
|
||||||
|
|
||||||
|
| Section | Description |
|
||||||
|
|---------|-------------|
|
||||||
|
| `identity` | Names, bio, origin, residence |
|
||||||
|
| `psychology` | Neural matrix (cognitive weights), MBTI, OCEAN, moral compass |
|
||||||
|
| `linguistics` | Text style, formality, catchphrases, forbidden words |
|
||||||
|
| `motivations` | Core drive, short/long-term goals, fears |
|
||||||
|
| `capabilities` | Skills and tools the agent can access |
|
||||||
|
| `physicality` | Visual descriptors for image generation |
|
||||||
|
| `history` | Origin story, education, occupation |
|
||||||
|
| `interests` | Hobbies, favorites, lifestyle |
|
||||||
|
|
||||||
|
See [aieos.org](https://aieos.org) for the full schema and live examples.
|
||||||
|
|
||||||
## Gateway API
|
## Gateway API
|
||||||
|
|
||||||
| Endpoint | Method | Auth | Description |
|
| Endpoint | Method | Auth | Description |
|
||||||
|
|
@ -191,6 +367,8 @@ enabled = false # opt-in: 1000+ OAuth apps via composio.dev
|
||||||
| `/health` | GET | None | Health check (always public, no secrets leaked) |
|
| `/health` | GET | None | Health check (always public, no secrets leaked) |
|
||||||
| `/pair` | POST | `X-Pairing-Code` header | Exchange one-time code for bearer token |
|
| `/pair` | POST | `X-Pairing-Code` header | Exchange one-time code for bearer token |
|
||||||
| `/webhook` | POST | `Authorization: Bearer <token>` | Send message: `{"message": "your prompt"}` |
|
| `/webhook` | POST | `Authorization: Bearer <token>` | Send message: `{"message": "your prompt"}` |
|
||||||
|
| `/whatsapp` | GET | Query params | Meta webhook verification (hub.mode, hub.verify_token, hub.challenge) |
|
||||||
|
| `/whatsapp` | POST | None (Meta signature) | WhatsApp incoming message webhook |
|
||||||
|
|
||||||
## Commands
|
## Commands
|
||||||
|
|
||||||
|
|
@ -198,10 +376,14 @@ enabled = false # opt-in: 1000+ OAuth apps via composio.dev
|
||||||
|---------|-------------|
|
|---------|-------------|
|
||||||
| `onboard` | Quick setup (default) |
|
| `onboard` | Quick setup (default) |
|
||||||
| `onboard --interactive` | Full interactive 7-step wizard |
|
| `onboard --interactive` | Full interactive 7-step wizard |
|
||||||
|
| `onboard --channels-only` | Reconfigure channels/allowlists only (fast repair flow) |
|
||||||
| `agent -m "..."` | Single message mode |
|
| `agent -m "..."` | Single message mode |
|
||||||
| `agent` | Interactive chat mode |
|
| `agent` | Interactive chat mode |
|
||||||
| `gateway` | Start webhook server (default: `127.0.0.1:8080`) |
|
| `gateway` | Start webhook server (default: `127.0.0.1:8080`) |
|
||||||
| `gateway --port 0` | Random port mode |
|
| `gateway --port 0` | Random port mode |
|
||||||
|
| `daemon` | Start long-running autonomous runtime |
|
||||||
|
| `service install/start/stop/status/uninstall` | Manage user-level background service |
|
||||||
|
| `doctor` | Diagnose daemon/scheduler/channel freshness |
|
||||||
| `status` | Show full system status |
|
| `status` | Show full system status |
|
||||||
| `channel doctor` | Run health checks for configured channels |
|
| `channel doctor` | Run health checks for configured channels |
|
||||||
| `integrations info <name>` | Show setup/status details for one integration |
|
| `integrations info <name>` | Show setup/status details for one integration |
|
||||||
|
|
|
||||||
30
SECURITY.md
30
SECURITY.md
|
|
@ -61,3 +61,33 @@ cargo test -- tools::shell
|
||||||
cargo test -- tools::file_read
|
cargo test -- tools::file_read
|
||||||
cargo test -- tools::file_write
|
cargo test -- tools::file_write
|
||||||
```
|
```
|
||||||
|
|
||||||
|
## Container Security
|
||||||
|
|
||||||
|
ZeroClaw Docker images follow CIS Docker Benchmark best practices:
|
||||||
|
|
||||||
|
| Control | Implementation |
|
||||||
|
|---------|----------------|
|
||||||
|
| **4.1 Non-root user** | Container runs as UID 65534 (distroless nonroot) |
|
||||||
|
| **4.2 Minimal base image** | `gcr.io/distroless/cc-debian12:nonroot` — no shell, no package manager |
|
||||||
|
| **4.6 HEALTHCHECK** | Not applicable (stateless CLI/gateway) |
|
||||||
|
| **5.25 Read-only filesystem** | Supported via `docker run --read-only` with `/workspace` volume |
|
||||||
|
|
||||||
|
### Verifying Container Security
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Build and verify non-root user
|
||||||
|
docker build -t zeroclaw .
|
||||||
|
docker inspect --format='{{.Config.User}}' zeroclaw
|
||||||
|
# Expected: 65534:65534
|
||||||
|
|
||||||
|
# Run with read-only filesystem (production hardening)
|
||||||
|
docker run --read-only -v /path/to/workspace:/workspace zeroclaw gateway
|
||||||
|
```
|
||||||
|
|
||||||
|
### CI Enforcement
|
||||||
|
|
||||||
|
The `docker` job in `.github/workflows/ci.yml` automatically verifies:
|
||||||
|
1. Container does not run as root (UID 0)
|
||||||
|
2. Runtime stage uses `:nonroot` variant
|
||||||
|
3. Explicit `USER` directive with numeric UID exists
|
||||||
|
|
|
||||||
169
scripts/test_dockerignore.sh
Executable file
169
scripts/test_dockerignore.sh
Executable file
|
|
@ -0,0 +1,169 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
# Test script to verify .dockerignore excludes sensitive paths
|
||||||
|
# Run: ./scripts/test_dockerignore.sh
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
|
||||||
|
PROJECT_ROOT="$(dirname "$SCRIPT_DIR")"
|
||||||
|
DOCKERIGNORE="$PROJECT_ROOT/.dockerignore"
|
||||||
|
|
||||||
|
RED='\033[0;31m'
|
||||||
|
GREEN='\033[0;32m'
|
||||||
|
NC='\033[0m' # No Color
|
||||||
|
|
||||||
|
PASS=0
|
||||||
|
FAIL=0
|
||||||
|
|
||||||
|
log_pass() {
|
||||||
|
echo -e "${GREEN}✓${NC} $1"
|
||||||
|
PASS=$((PASS + 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
log_fail() {
|
||||||
|
echo -e "${RED}✗${NC} $1"
|
||||||
|
FAIL=$((FAIL + 1))
|
||||||
|
}
|
||||||
|
|
||||||
|
# Test 1: .dockerignore exists
|
||||||
|
echo "=== Testing .dockerignore ==="
|
||||||
|
if [[ -f "$DOCKERIGNORE" ]]; then
|
||||||
|
log_pass ".dockerignore file exists"
|
||||||
|
else
|
||||||
|
log_fail ".dockerignore file does not exist"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Test 2: Required exclusions are present
|
||||||
|
MUST_EXCLUDE=(
|
||||||
|
".git"
|
||||||
|
".githooks"
|
||||||
|
"target"
|
||||||
|
"docs"
|
||||||
|
"examples"
|
||||||
|
"tests"
|
||||||
|
"*.md"
|
||||||
|
"*.png"
|
||||||
|
"*.db"
|
||||||
|
"*.db-journal"
|
||||||
|
".DS_Store"
|
||||||
|
".github"
|
||||||
|
"deny.toml"
|
||||||
|
"LICENSE"
|
||||||
|
".env"
|
||||||
|
".tmp_*"
|
||||||
|
)
|
||||||
|
|
||||||
|
for pattern in "${MUST_EXCLUDE[@]}"; do
|
||||||
|
# Use fgrep for literal matching
|
||||||
|
if grep -Fq "$pattern" "$DOCKERIGNORE" 2>/dev/null; then
|
||||||
|
log_pass "Excludes: $pattern"
|
||||||
|
else
|
||||||
|
log_fail "Missing exclusion: $pattern"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Test 3: Build essentials are NOT excluded
|
||||||
|
MUST_NOT_EXCLUDE=(
|
||||||
|
"Cargo.toml"
|
||||||
|
"Cargo.lock"
|
||||||
|
"src"
|
||||||
|
)
|
||||||
|
|
||||||
|
for path in "${MUST_NOT_EXCLUDE[@]}"; do
|
||||||
|
if grep -qE "^${path}$" "$DOCKERIGNORE" 2>/dev/null; then
|
||||||
|
log_fail "Build essential '$path' is incorrectly excluded"
|
||||||
|
else
|
||||||
|
log_pass "Build essential NOT excluded: $path"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Test 4: No syntax errors (basic validation)
|
||||||
|
while IFS= read -r line; do
|
||||||
|
# Skip empty lines and comments
|
||||||
|
[[ -z "$line" || "$line" =~ ^# ]] && continue
|
||||||
|
|
||||||
|
# Check for common issues
|
||||||
|
if [[ "$line" =~ [[:space:]]$ ]]; then
|
||||||
|
log_fail "Trailing whitespace in pattern: '$line'"
|
||||||
|
fi
|
||||||
|
done < "$DOCKERIGNORE"
|
||||||
|
log_pass "No trailing whitespace in patterns"
|
||||||
|
|
||||||
|
# Test 5: Verify Docker build context would be small
|
||||||
|
echo ""
|
||||||
|
echo "=== Simulating Docker build context ==="
|
||||||
|
|
||||||
|
# Create temp dir and simulate what would be sent
|
||||||
|
TEMP_DIR=$(mktemp -d)
|
||||||
|
trap "rm -rf $TEMP_DIR" EXIT
|
||||||
|
|
||||||
|
# Use rsync with .dockerignore patterns to simulate Docker's behavior
|
||||||
|
cd "$PROJECT_ROOT"
|
||||||
|
|
||||||
|
# Count files that WOULD be sent (excluding .dockerignore patterns)
|
||||||
|
TOTAL_FILES=$(find . -type f | wc -l | tr -d ' ')
|
||||||
|
CONTEXT_FILES=$(find . -type f \
|
||||||
|
! -path './.git/*' \
|
||||||
|
! -path './target/*' \
|
||||||
|
! -path './docs/*' \
|
||||||
|
! -path './examples/*' \
|
||||||
|
! -path './tests/*' \
|
||||||
|
! -name '*.md' \
|
||||||
|
! -name '*.png' \
|
||||||
|
! -name '*.svg' \
|
||||||
|
! -name '*.db' \
|
||||||
|
! -name '*.db-journal' \
|
||||||
|
! -name '.DS_Store' \
|
||||||
|
! -path './.github/*' \
|
||||||
|
! -name 'deny.toml' \
|
||||||
|
! -name 'LICENSE' \
|
||||||
|
! -name '.env' \
|
||||||
|
! -name '.env.*' \
|
||||||
|
2>/dev/null | wc -l | tr -d ' ')
|
||||||
|
|
||||||
|
echo "Total files in repo: $TOTAL_FILES"
|
||||||
|
echo "Files in Docker context: $CONTEXT_FILES"
|
||||||
|
|
||||||
|
if [[ $CONTEXT_FILES -lt $TOTAL_FILES ]]; then
|
||||||
|
log_pass "Docker context is smaller than full repo ($CONTEXT_FILES < $TOTAL_FILES files)"
|
||||||
|
else
|
||||||
|
log_fail "Docker context is not being reduced"
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Test 6: Verify critical security files would be excluded
|
||||||
|
echo ""
|
||||||
|
echo "=== Security checks ==="
|
||||||
|
|
||||||
|
# Check if .git would be excluded
|
||||||
|
if [[ -d "$PROJECT_ROOT/.git" ]]; then
|
||||||
|
if grep -q "^\.git$" "$DOCKERIGNORE"; then
|
||||||
|
log_pass ".git directory will be excluded (security)"
|
||||||
|
else
|
||||||
|
log_fail ".git directory NOT excluded - SECURITY RISK"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Check if any .db files exist and would be excluded
|
||||||
|
DB_FILES=$(find "$PROJECT_ROOT" -name "*.db" -type f 2>/dev/null | head -5)
|
||||||
|
if [[ -n "$DB_FILES" ]]; then
|
||||||
|
if grep -q "^\*\.db$" "$DOCKERIGNORE"; then
|
||||||
|
log_pass "*.db files will be excluded (security)"
|
||||||
|
else
|
||||||
|
log_fail "*.db files NOT excluded - SECURITY RISK"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
|
# Summary
|
||||||
|
echo ""
|
||||||
|
echo "=== Summary ==="
|
||||||
|
echo -e "Passed: ${GREEN}$PASS${NC}"
|
||||||
|
echo -e "Failed: ${RED}$FAIL${NC}"
|
||||||
|
|
||||||
|
if [[ $FAIL -gt 0 ]]; then
|
||||||
|
echo -e "${RED}FAILED${NC}: $FAIL tests failed"
|
||||||
|
exit 1
|
||||||
|
else
|
||||||
|
echo -e "${GREEN}PASSED${NC}: All tests passed"
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
@ -39,7 +39,7 @@ pub async fn run(
|
||||||
// ── Wire up agnostic subsystems ──────────────────────────────
|
// ── Wire up agnostic subsystems ──────────────────────────────
|
||||||
let observer: Arc<dyn Observer> =
|
let observer: Arc<dyn Observer> =
|
||||||
Arc::from(observability::create_observer(&config.observability));
|
Arc::from(observability::create_observer(&config.observability));
|
||||||
let _runtime = runtime::create_runtime(&config.runtime);
|
let _runtime = runtime::create_runtime(&config.runtime)?;
|
||||||
let security = Arc::new(SecurityPolicy::from_config(
|
let security = Arc::new(SecurityPolicy::from_config(
|
||||||
&config.autonomy,
|
&config.autonomy,
|
||||||
&config.workspace_dir,
|
&config.workspace_dir,
|
||||||
|
|
@ -72,8 +72,11 @@ pub async fn run(
|
||||||
.or(config.default_model.as_deref())
|
.or(config.default_model.as_deref())
|
||||||
.unwrap_or("anthropic/claude-sonnet-4-20250514");
|
.unwrap_or("anthropic/claude-sonnet-4-20250514");
|
||||||
|
|
||||||
let provider: Box<dyn Provider> =
|
let provider: Box<dyn Provider> = providers::create_resilient_provider(
|
||||||
providers::create_provider(provider_name, config.api_key.as_deref())?;
|
provider_name,
|
||||||
|
config.api_key.as_deref(),
|
||||||
|
&config.reliability,
|
||||||
|
)?;
|
||||||
|
|
||||||
observer.record_event(&ObserverEvent::AgentStart {
|
observer.record_event(&ObserverEvent::AgentStart {
|
||||||
provider: provider_name.to_string(),
|
provider: provider_name.to_string(),
|
||||||
|
|
@ -83,12 +86,30 @@ pub async fn run(
|
||||||
// ── Build system prompt from workspace MD files (OpenClaw framework) ──
|
// ── Build system prompt from workspace MD files (OpenClaw framework) ──
|
||||||
let skills = crate::skills::load_skills(&config.workspace_dir);
|
let skills = crate::skills::load_skills(&config.workspace_dir);
|
||||||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||||||
("shell", "Execute terminal commands"),
|
(
|
||||||
("file_read", "Read file contents"),
|
"shell",
|
||||||
("file_write", "Write file contents"),
|
"Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.",
|
||||||
("memory_store", "Save to memory"),
|
),
|
||||||
("memory_recall", "Search memory"),
|
(
|
||||||
("memory_forget", "Delete a memory entry"),
|
"file_read",
|
||||||
|
"Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"file_write",
|
||||||
|
"Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"memory_store",
|
||||||
|
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"memory_recall",
|
||||||
|
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"memory_forget",
|
||||||
|
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
|
||||||
|
),
|
||||||
];
|
];
|
||||||
if config.browser.enabled {
|
if config.browser.enabled {
|
||||||
tool_descs.push((
|
tool_descs.push((
|
||||||
|
|
|
||||||
|
|
@ -29,6 +29,60 @@ impl IMessageChannel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Escape a string for safe interpolation into `AppleScript`.
|
||||||
|
///
|
||||||
|
/// This prevents injection attacks by escaping:
|
||||||
|
/// - Backslashes (`\` → `\\`)
|
||||||
|
/// - Double quotes (`"` → `\"`)
|
||||||
|
fn escape_applescript(s: &str) -> String {
|
||||||
|
s.replace('\\', "\\\\").replace('"', "\\\"")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that a target looks like a valid phone number or email address.
|
||||||
|
///
|
||||||
|
/// This is a defense-in-depth measure to reject obviously malicious targets
|
||||||
|
/// before they reach `AppleScript` interpolation.
|
||||||
|
///
|
||||||
|
/// Valid patterns:
|
||||||
|
/// - Phone: starts with `+` followed by digits (with optional spaces/dashes)
|
||||||
|
/// - Email: contains `@` with alphanumeric chars on both sides
|
||||||
|
fn is_valid_imessage_target(target: &str) -> bool {
|
||||||
|
let target = target.trim();
|
||||||
|
if target.is_empty() {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Phone number: +1234567890 or +1 234-567-8900
|
||||||
|
if target.starts_with('+') {
|
||||||
|
let digits_only: String = target.chars().filter(char::is_ascii_digit).collect();
|
||||||
|
// Must have at least 7 digits (shortest valid phone numbers)
|
||||||
|
return digits_only.len() >= 7 && digits_only.len() <= 15;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Email: simple validation (contains @ with chars on both sides)
|
||||||
|
if let Some(at_pos) = target.find('@') {
|
||||||
|
let local = &target[..at_pos];
|
||||||
|
let domain = &target[at_pos + 1..];
|
||||||
|
|
||||||
|
// Local part: non-empty, alphanumeric + common email chars
|
||||||
|
let local_valid = !local.is_empty()
|
||||||
|
&& local
|
||||||
|
.chars()
|
||||||
|
.all(|c| c.is_alphanumeric() || "._+-".contains(c));
|
||||||
|
|
||||||
|
// Domain: non-empty, contains a dot, alphanumeric + dots/hyphens
|
||||||
|
let domain_valid = !domain.is_empty()
|
||||||
|
&& domain.contains('.')
|
||||||
|
&& domain
|
||||||
|
.chars()
|
||||||
|
.all(|c| c.is_alphanumeric() || ".-".contains(c));
|
||||||
|
|
||||||
|
return local_valid && domain_valid;
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Channel for IMessageChannel {
|
impl Channel for IMessageChannel {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
|
|
@ -36,11 +90,22 @@ impl Channel for IMessageChannel {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> {
|
async fn send(&self, message: &str, target: &str) -> anyhow::Result<()> {
|
||||||
let escaped_msg = message.replace('\\', "\\\\").replace('"', "\\\"");
|
// Defense-in-depth: validate target format before any interpolation
|
||||||
|
if !is_valid_imessage_target(target) {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Invalid iMessage target: must be a phone number (+1234567890) or email (user@example.com)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
// SECURITY: Escape both message AND target to prevent AppleScript injection
|
||||||
|
// See: CWE-78 (OS Command Injection)
|
||||||
|
let escaped_msg = escape_applescript(message);
|
||||||
|
let escaped_target = escape_applescript(target);
|
||||||
|
|
||||||
let script = format!(
|
let script = format!(
|
||||||
r#"tell application "Messages"
|
r#"tell application "Messages"
|
||||||
set targetService to 1st account whose service type = iMessage
|
set targetService to 1st account whose service type = iMessage
|
||||||
set targetBuddy to participant "{target}" of targetService
|
set targetBuddy to participant "{escaped_target}" of targetService
|
||||||
send "{escaped_msg}" to targetBuddy
|
send "{escaped_msg}" to targetBuddy
|
||||||
end tell"#
|
end tell"#
|
||||||
);
|
);
|
||||||
|
|
@ -262,4 +327,204 @@ mod tests {
|
||||||
assert!(ch.is_contact_allowed(" spaced "));
|
assert!(ch.is_contact_allowed(" spaced "));
|
||||||
assert!(!ch.is_contact_allowed("spaced"));
|
assert!(!ch.is_contact_allowed("spaced"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ══════════════════════════════════════════════════════════
|
||||||
|
// AppleScript Escaping Tests (CWE-78 Prevention)
|
||||||
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_double_quotes() {
|
||||||
|
assert_eq!(escape_applescript(r#"hello "world""#), r#"hello \"world\""#);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_backslashes() {
|
||||||
|
assert_eq!(escape_applescript(r"path\to\file"), r"path\\to\\file");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_mixed() {
|
||||||
|
assert_eq!(
|
||||||
|
escape_applescript(r#"say "hello\" world"#),
|
||||||
|
r#"say \"hello\\\" world"#
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_injection_attempt() {
|
||||||
|
// This is the exact attack vector from the security report
|
||||||
|
let malicious = r#"" & do shell script "id" & ""#;
|
||||||
|
let escaped = escape_applescript(malicious);
|
||||||
|
// After escaping, the quotes should be escaped and not break out
|
||||||
|
assert_eq!(escaped, r#"\" & do shell script \"id\" & \""#);
|
||||||
|
// Verify all quotes are now escaped (preceded by backslash)
|
||||||
|
// The escaped string should not have any unescaped quotes (quote not preceded by backslash)
|
||||||
|
let chars: Vec<char> = escaped.chars().collect();
|
||||||
|
for (i, &c) in chars.iter().enumerate() {
|
||||||
|
if c == '"' {
|
||||||
|
// Every quote must be preceded by a backslash
|
||||||
|
assert!(
|
||||||
|
i > 0 && chars[i - 1] == '\\',
|
||||||
|
"Found unescaped quote at position {i}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_empty_string() {
|
||||||
|
assert_eq!(escape_applescript(""), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_no_special_chars() {
|
||||||
|
assert_eq!(escape_applescript("hello world"), "hello world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_unicode() {
|
||||||
|
assert_eq!(escape_applescript("hello 🦀 world"), "hello 🦀 world");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn escape_applescript_newlines_preserved() {
|
||||||
|
assert_eq!(escape_applescript("line1\nline2"), "line1\nline2");
|
||||||
|
}
|
||||||
|
|
||||||
|
// ══════════════════════════════════════════════════════════
|
||||||
|
// Target Validation Tests
|
||||||
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_phone_number_simple() {
|
||||||
|
assert!(is_valid_imessage_target("+1234567890"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_phone_number_with_country_code() {
|
||||||
|
assert!(is_valid_imessage_target("+14155551234"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_phone_number_with_spaces() {
|
||||||
|
assert!(is_valid_imessage_target("+1 415 555 1234"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_phone_number_with_dashes() {
|
||||||
|
assert!(is_valid_imessage_target("+1-415-555-1234"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_phone_number_international() {
|
||||||
|
assert!(is_valid_imessage_target("+447911123456")); // UK
|
||||||
|
assert!(is_valid_imessage_target("+81312345678")); // Japan
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_email_simple() {
|
||||||
|
assert!(is_valid_imessage_target("user@example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_email_with_subdomain() {
|
||||||
|
assert!(is_valid_imessage_target("user@mail.example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_email_with_plus() {
|
||||||
|
assert!(is_valid_imessage_target("user+tag@example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_email_with_dots() {
|
||||||
|
assert!(is_valid_imessage_target("first.last@example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn valid_email_icloud() {
|
||||||
|
assert!(is_valid_imessage_target("user@icloud.com"));
|
||||||
|
assert!(is_valid_imessage_target("user@me.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_empty() {
|
||||||
|
assert!(!is_valid_imessage_target(""));
|
||||||
|
assert!(!is_valid_imessage_target(" "));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_no_plus_prefix() {
|
||||||
|
// Phone numbers must start with +
|
||||||
|
assert!(!is_valid_imessage_target("1234567890"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_too_short_phone() {
|
||||||
|
// Less than 7 digits
|
||||||
|
assert!(!is_valid_imessage_target("+123456"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_too_long_phone() {
|
||||||
|
// More than 15 digits
|
||||||
|
assert!(!is_valid_imessage_target("+1234567890123456"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_email_no_at() {
|
||||||
|
assert!(!is_valid_imessage_target("userexample.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_email_no_domain() {
|
||||||
|
assert!(!is_valid_imessage_target("user@"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_email_no_local() {
|
||||||
|
assert!(!is_valid_imessage_target("@example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_email_no_dot_in_domain() {
|
||||||
|
assert!(!is_valid_imessage_target("user@localhost"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_injection_attempt() {
|
||||||
|
// The exact attack vector from the security report
|
||||||
|
assert!(!is_valid_imessage_target(r#"" & do shell script "id" & ""#));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_applescript_injection() {
|
||||||
|
// Various injection attempts
|
||||||
|
assert!(!is_valid_imessage_target(r#"test" & quit"#));
|
||||||
|
assert!(!is_valid_imessage_target(r#"test\ndo shell script"#));
|
||||||
|
assert!(!is_valid_imessage_target("test\"; malicious code; \""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_special_chars() {
|
||||||
|
assert!(!is_valid_imessage_target("user<script>@example.com"));
|
||||||
|
assert!(!is_valid_imessage_target("user@example.com; rm -rf /"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_null_byte() {
|
||||||
|
assert!(!is_valid_imessage_target("user\0@example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn invalid_target_newline() {
|
||||||
|
assert!(!is_valid_imessage_target("user\n@example.com"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn target_with_leading_trailing_whitespace_trimmed() {
|
||||||
|
// Should trim and validate
|
||||||
|
assert!(is_valid_imessage_target(" +1234567890 "));
|
||||||
|
assert!(is_valid_imessage_target(" user@example.com "));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ pub mod slack;
|
||||||
pub mod telegram;
|
pub mod telegram;
|
||||||
pub mod whatsapp;
|
pub mod whatsapp;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
pub mod whatsapp;
|
||||||
|
|
||||||
pub use cli::CliChannel;
|
pub use cli::CliChannel;
|
||||||
pub use discord::DiscordChannel;
|
pub use discord::DiscordChannel;
|
||||||
|
|
@ -17,6 +18,7 @@ pub use telegram::TelegramChannel;
|
||||||
#[allow(unused_imports)]
|
#[allow(unused_imports)]
|
||||||
pub use whatsapp::WhatsAppChannel;
|
pub use whatsapp::WhatsAppChannel;
|
||||||
pub use traits::Channel;
|
pub use traits::Channel;
|
||||||
|
pub use whatsapp::WhatsAppChannel;
|
||||||
|
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::memory::{self, Memory};
|
use crate::memory::{self, Memory};
|
||||||
|
|
@ -28,6 +30,46 @@ use std::time::Duration;
|
||||||
/// Maximum characters per injected workspace file (matches `OpenClaw` default).
|
/// Maximum characters per injected workspace file (matches `OpenClaw` default).
|
||||||
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
const BOOTSTRAP_MAX_CHARS: usize = 20_000;
|
||||||
|
|
||||||
|
const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2;
|
||||||
|
const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60;
|
||||||
|
|
||||||
|
fn spawn_supervised_listener(
|
||||||
|
ch: Arc<dyn Channel>,
|
||||||
|
tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||||||
|
initial_backoff_secs: u64,
|
||||||
|
max_backoff_secs: u64,
|
||||||
|
) -> tokio::task::JoinHandle<()> {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let component = format!("channel:{}", ch.name());
|
||||||
|
let mut backoff = initial_backoff_secs.max(1);
|
||||||
|
let max_backoff = max_backoff_secs.max(backoff);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
crate::health::mark_component_ok(&component);
|
||||||
|
let result = ch.listen(tx.clone()).await;
|
||||||
|
|
||||||
|
if tx.is_closed() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(()) => {
|
||||||
|
tracing::warn!("Channel {} exited unexpectedly; restarting", ch.name());
|
||||||
|
crate::health::mark_component_error(&component, "listener exited unexpectedly");
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("Channel {} error: {e}; restarting", ch.name());
|
||||||
|
crate::health::mark_component_error(&component, e.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
crate::health::bump_component_restart(&component);
|
||||||
|
tokio::time::sleep(Duration::from_secs(backoff)).await;
|
||||||
|
backoff = backoff.saturating_mul(2).min(max_backoff);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
/// Load workspace identity files and build a system prompt.
|
/// Load workspace identity files and build a system prompt.
|
||||||
///
|
///
|
||||||
/// Follows the `OpenClaw` framework structure:
|
/// Follows the `OpenClaw` framework structure:
|
||||||
|
|
@ -150,6 +192,38 @@ pub fn build_system_prompt(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Inject OpenClaw (markdown) identity files into the prompt
|
||||||
|
fn inject_openclaw_identity(prompt: &mut String, workspace_dir: &std::path::Path) {
|
||||||
|
#[allow(unused_imports)]
|
||||||
|
use std::fmt::Write;
|
||||||
|
|
||||||
|
prompt.push_str("## Project Context\n\n");
|
||||||
|
prompt
|
||||||
|
.push_str("The following workspace files define your identity, behavior, and context.\n\n");
|
||||||
|
|
||||||
|
let bootstrap_files = [
|
||||||
|
"AGENTS.md",
|
||||||
|
"SOUL.md",
|
||||||
|
"TOOLS.md",
|
||||||
|
"IDENTITY.md",
|
||||||
|
"USER.md",
|
||||||
|
"HEARTBEAT.md",
|
||||||
|
];
|
||||||
|
|
||||||
|
for filename in &bootstrap_files {
|
||||||
|
inject_workspace_file(prompt, workspace_dir, filename);
|
||||||
|
}
|
||||||
|
|
||||||
|
// BOOTSTRAP.md — only if it exists (first-run ritual)
|
||||||
|
let bootstrap_path = workspace_dir.join("BOOTSTRAP.md");
|
||||||
|
if bootstrap_path.exists() {
|
||||||
|
inject_workspace_file(prompt, workspace_dir, "BOOTSTRAP.md");
|
||||||
|
}
|
||||||
|
|
||||||
|
// MEMORY.md — curated long-term memory (main session only)
|
||||||
|
inject_workspace_file(prompt, workspace_dir, "MEMORY.md");
|
||||||
|
}
|
||||||
|
|
||||||
/// Inject a single workspace file into the prompt with truncation and missing-file markers.
|
/// Inject a single workspace file into the prompt with truncation and missing-file markers.
|
||||||
fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, filename: &str) {
|
fn inject_workspace_file(prompt: &mut String, workspace_dir: &std::path::Path, filename: &str) {
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
|
|
@ -200,6 +274,7 @@ pub fn handle_command(command: super::ChannelCommands, config: &Config) -> Resul
|
||||||
("Webhook", config.channels_config.webhook.is_some()),
|
("Webhook", config.channels_config.webhook.is_some()),
|
||||||
("iMessage", config.channels_config.imessage.is_some()),
|
("iMessage", config.channels_config.imessage.is_some()),
|
||||||
("Matrix", config.channels_config.matrix.is_some()),
|
("Matrix", config.channels_config.matrix.is_some()),
|
||||||
|
("WhatsApp", config.channels_config.whatsapp.is_some()),
|
||||||
] {
|
] {
|
||||||
println!(" {} {name}", if configured { "✅" } else { "❌" });
|
println!(" {} {name}", if configured { "✅" } else { "❌" });
|
||||||
}
|
}
|
||||||
|
|
@ -294,6 +369,18 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(ref wa) = config.channels_config.whatsapp {
|
||||||
|
channels.push((
|
||||||
|
"WhatsApp",
|
||||||
|
Arc::new(WhatsAppChannel::new(
|
||||||
|
wa.access_token.clone(),
|
||||||
|
wa.phone_number_id.clone(),
|
||||||
|
wa.verify_token.clone(),
|
||||||
|
wa.allowed_numbers.clone(),
|
||||||
|
)),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
if channels.is_empty() {
|
if channels.is_empty() {
|
||||||
println!("No real-time channels configured. Run `zeroclaw onboard` first.");
|
println!("No real-time channels configured. Run `zeroclaw onboard` first.");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|
@ -338,9 +425,10 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
||||||
/// Start all configured channels and route messages to the agent
|
/// Start all configured channels and route messages to the agent
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub async fn start_channels(config: Config) -> Result<()> {
|
pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider(
|
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
|
&config.reliability,
|
||||||
)?);
|
)?);
|
||||||
let model = config
|
let model = config
|
||||||
.default_model
|
.default_model
|
||||||
|
|
@ -359,12 +447,30 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
|
|
||||||
// Collect tool descriptions for the prompt
|
// Collect tool descriptions for the prompt
|
||||||
let mut tool_descs: Vec<(&str, &str)> = vec![
|
let mut tool_descs: Vec<(&str, &str)> = vec![
|
||||||
("shell", "Execute terminal commands"),
|
(
|
||||||
("file_read", "Read file contents"),
|
"shell",
|
||||||
("file_write", "Write file contents"),
|
"Execute terminal commands. Use when: running local checks, build/test commands, diagnostics. Don't use when: a safer dedicated tool exists, or command is destructive without approval.",
|
||||||
("memory_store", "Save to memory"),
|
),
|
||||||
("memory_recall", "Search memory"),
|
(
|
||||||
("memory_forget", "Delete a memory entry"),
|
"file_read",
|
||||||
|
"Read file contents. Use when: inspecting project files, configs, logs. Don't use when: a targeted search is enough.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"file_write",
|
||||||
|
"Write file contents. Use when: applying focused edits, scaffolding files, updating docs/code. Don't use when: side effects are unclear or file ownership is uncertain.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"memory_store",
|
||||||
|
"Save to memory. Use when: preserving durable preferences, decisions, key context. Don't use when: information is transient/noisy/sensitive without need.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"memory_recall",
|
||||||
|
"Search memory. Use when: retrieving prior decisions, user preferences, historical context. Don't use when: answer is already in current context.",
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"memory_forget",
|
||||||
|
"Delete a memory entry. Use when: memory is incorrect/stale or explicitly requested for removal. Don't use when: impact is uncertain.",
|
||||||
|
),
|
||||||
];
|
];
|
||||||
|
|
||||||
if config.browser.enabled {
|
if config.browser.enabled {
|
||||||
|
|
@ -426,6 +532,15 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(ref wa) = config.channels_config.whatsapp {
|
||||||
|
channels.push(Arc::new(WhatsAppChannel::new(
|
||||||
|
wa.access_token.clone(),
|
||||||
|
wa.phone_number_id.clone(),
|
||||||
|
wa.verify_token.clone(),
|
||||||
|
wa.allowed_numbers.clone(),
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
if channels.is_empty() {
|
if channels.is_empty() {
|
||||||
println!("No channels configured. Run `zeroclaw onboard` to set up channels.");
|
println!("No channels configured. Run `zeroclaw onboard` to set up channels.");
|
||||||
return Ok(());
|
return Ok(());
|
||||||
|
|
@ -450,19 +565,29 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
println!(" Listening for messages... (Ctrl+C to stop)");
|
println!(" Listening for messages... (Ctrl+C to stop)");
|
||||||
println!();
|
println!();
|
||||||
|
|
||||||
|
crate::health::mark_component_ok("channels");
|
||||||
|
|
||||||
|
let initial_backoff_secs = config
|
||||||
|
.reliability
|
||||||
|
.channel_initial_backoff_secs
|
||||||
|
.max(DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS);
|
||||||
|
let max_backoff_secs = config
|
||||||
|
.reliability
|
||||||
|
.channel_max_backoff_secs
|
||||||
|
.max(DEFAULT_CHANNEL_MAX_BACKOFF_SECS);
|
||||||
|
|
||||||
// Single message bus — all channels send messages here
|
// Single message bus — all channels send messages here
|
||||||
let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
|
let (tx, mut rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(100);
|
||||||
|
|
||||||
// Spawn a listener for each channel
|
// Spawn a listener for each channel
|
||||||
let mut handles = Vec::new();
|
let mut handles = Vec::new();
|
||||||
for ch in &channels {
|
for ch in &channels {
|
||||||
let ch = ch.clone();
|
handles.push(spawn_supervised_listener(
|
||||||
let tx = tx.clone();
|
ch.clone(),
|
||||||
handles.push(tokio::spawn(async move {
|
tx.clone(),
|
||||||
if let Err(e) = ch.listen(tx).await {
|
initial_backoff_secs,
|
||||||
tracing::error!("Channel {} error: {e}", ch.name());
|
max_backoff_secs,
|
||||||
}
|
));
|
||||||
}));
|
|
||||||
}
|
}
|
||||||
drop(tx); // Drop our copy so rx closes when all channels stop
|
drop(tx); // Drop our copy so rx closes when all channels stop
|
||||||
|
|
||||||
|
|
@ -537,6 +662,8 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
fn make_workspace() -> TempDir {
|
fn make_workspace() -> TempDir {
|
||||||
|
|
@ -781,4 +908,55 @@ mod tests {
|
||||||
let state = classify_health_result(&result);
|
let state = classify_health_result(&result);
|
||||||
assert_eq!(state, ChannelHealthState::Timeout);
|
assert_eq!(state, ChannelHealthState::Timeout);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
struct AlwaysFailChannel {
|
||||||
|
name: &'static str,
|
||||||
|
calls: Arc<AtomicUsize>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl Channel for AlwaysFailChannel {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send(&self, _message: &str, _recipient: &str) -> anyhow::Result<()> {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn listen(
|
||||||
|
&self,
|
||||||
|
_tx: tokio::sync::mpsc::Sender<traits::ChannelMessage>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
self.calls.fetch_add(1, Ordering::SeqCst);
|
||||||
|
anyhow::bail!("listen boom")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn supervised_listener_marks_error_and_restarts_on_failures() {
|
||||||
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let channel: Arc<dyn Channel> = Arc::new(AlwaysFailChannel {
|
||||||
|
name: "test-supervised-fail",
|
||||||
|
calls: Arc::clone(&calls),
|
||||||
|
});
|
||||||
|
|
||||||
|
let (_tx, rx) = tokio::sync::mpsc::channel::<traits::ChannelMessage>(1);
|
||||||
|
let handle = spawn_supervised_listener(channel, _tx, 1, 1);
|
||||||
|
|
||||||
|
tokio::time::sleep(Duration::from_millis(80)).await;
|
||||||
|
drop(rx);
|
||||||
|
handle.abort();
|
||||||
|
let _ = handle.await;
|
||||||
|
|
||||||
|
let snapshot = crate::health::snapshot_json();
|
||||||
|
let component = &snapshot["components"]["channel:test-supervised-fail"];
|
||||||
|
assert_eq!(component["status"], "error");
|
||||||
|
assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
|
||||||
|
assert!(component["last_error"]
|
||||||
|
.as_str()
|
||||||
|
.unwrap_or("")
|
||||||
|
.contains("listen boom"));
|
||||||
|
assert!(calls.load(Ordering::SeqCst) >= 1);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,13 @@ impl TelegramChannel {
|
||||||
fn is_user_allowed(&self, username: &str) -> bool {
|
fn is_user_allowed(&self, username: &str) -> bool {
|
||||||
self.allowed_users.iter().any(|u| u == "*" || u == username)
|
self.allowed_users.iter().any(|u| u == "*" || u == username)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_any_user_allowed<'a, I>(&self, identities: I) -> bool
|
||||||
|
where
|
||||||
|
I: IntoIterator<Item = &'a str>,
|
||||||
|
{
|
||||||
|
identities.into_iter().any(|id| self.is_user_allowed(id))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -95,15 +102,28 @@ impl Channel for TelegramChannel {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
let username = message
|
let username_opt = message
|
||||||
.get("from")
|
.get("from")
|
||||||
.and_then(|f| f.get("username"))
|
.and_then(|f| f.get("username"))
|
||||||
.and_then(|u| u.as_str())
|
.and_then(|u| u.as_str());
|
||||||
.unwrap_or("unknown");
|
let username = username_opt.unwrap_or("unknown");
|
||||||
|
|
||||||
if !self.is_user_allowed(username) {
|
let user_id = message
|
||||||
|
.get("from")
|
||||||
|
.and_then(|f| f.get("id"))
|
||||||
|
.and_then(serde_json::Value::as_i64);
|
||||||
|
let user_id_str = user_id.map(|id| id.to_string());
|
||||||
|
|
||||||
|
let mut identities = vec![username];
|
||||||
|
if let Some(ref id) = user_id_str {
|
||||||
|
identities.push(id.as_str());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.is_any_user_allowed(identities.iter().copied()) {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
"Telegram: ignoring message from unauthorized user: {username}"
|
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
|
||||||
|
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
|
||||||
|
user_id_str.as_deref().unwrap_or("unknown")
|
||||||
);
|
);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
@ -211,4 +231,16 @@ mod tests {
|
||||||
assert!(ch.is_user_allowed("bob"));
|
assert!(ch.is_user_allowed("bob"));
|
||||||
assert!(ch.is_user_allowed("anyone"));
|
assert!(ch.is_user_allowed("anyone"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_user_allowed_by_numeric_id_identity() {
|
||||||
|
let ch = TelegramChannel::new("t".into(), vec!["123456789".into()]);
|
||||||
|
assert!(ch.is_any_user_allowed(["unknown", "123456789"]));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn telegram_user_denied_when_none_of_identities_match() {
|
||||||
|
let ch = TelegramChannel::new("t".into(), vec!["alice".into(), "987654321".into()]);
|
||||||
|
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
File diff suppressed because it is too large
Load diff
|
|
@ -2,7 +2,7 @@ pub mod schema;
|
||||||
|
|
||||||
pub use schema::{
|
pub use schema::{
|
||||||
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
||||||
GatewayConfig, HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig,
|
GatewayConfig, HeartbeatConfig, IMessageConfig, IdentityConfig, MatrixConfig, MemoryConfig,
|
||||||
ObservabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig, TelegramConfig, TunnelConfig,
|
ObservabilityConfig, ReliabilityConfig, RuntimeConfig, SecretsConfig, SlackConfig,
|
||||||
WebhookConfig,
|
TelegramConfig, TunnelConfig, WebhookConfig,
|
||||||
};
|
};
|
||||||
|
|
|
||||||
|
|
@ -25,6 +25,9 @@ pub struct Config {
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub runtime: RuntimeConfig,
|
pub runtime: RuntimeConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub reliability: ReliabilityConfig,
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub heartbeat: HeartbeatConfig,
|
pub heartbeat: HeartbeatConfig,
|
||||||
|
|
||||||
|
|
@ -48,6 +51,38 @@ pub struct Config {
|
||||||
|
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub browser: BrowserConfig,
|
pub browser: BrowserConfig,
|
||||||
|
|
||||||
|
#[serde(default)]
|
||||||
|
pub identity: IdentityConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Identity (AIEOS / OpenClaw format) ──────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct IdentityConfig {
|
||||||
|
/// Identity format: "openclaw" (default) or "aieos"
|
||||||
|
#[serde(default = "default_identity_format")]
|
||||||
|
pub format: String,
|
||||||
|
/// Path to AIEOS JSON file (relative to workspace)
|
||||||
|
#[serde(default)]
|
||||||
|
pub aieos_path: Option<String>,
|
||||||
|
/// Inline AIEOS JSON (alternative to file path)
|
||||||
|
#[serde(default)]
|
||||||
|
pub aieos_inline: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_identity_format() -> String {
|
||||||
|
"openclaw".into()
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for IdentityConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
format: default_identity_format(),
|
||||||
|
aieos_path: None,
|
||||||
|
aieos_inline: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Gateway security ─────────────────────────────────────────────
|
// ── Gateway security ─────────────────────────────────────────────
|
||||||
|
|
@ -143,6 +178,18 @@ pub struct MemoryConfig {
|
||||||
pub backend: String,
|
pub backend: String,
|
||||||
/// Auto-save conversation context to memory
|
/// Auto-save conversation context to memory
|
||||||
pub auto_save: bool,
|
pub auto_save: bool,
|
||||||
|
/// Run memory/session hygiene (archiving + retention cleanup)
|
||||||
|
#[serde(default = "default_hygiene_enabled")]
|
||||||
|
pub hygiene_enabled: bool,
|
||||||
|
/// Archive daily/session files older than this many days
|
||||||
|
#[serde(default = "default_archive_after_days")]
|
||||||
|
pub archive_after_days: u32,
|
||||||
|
/// Purge archived files older than this many days
|
||||||
|
#[serde(default = "default_purge_after_days")]
|
||||||
|
pub purge_after_days: u32,
|
||||||
|
/// For sqlite backend: prune conversation rows older than this many days
|
||||||
|
#[serde(default = "default_conversation_retention_days")]
|
||||||
|
pub conversation_retention_days: u32,
|
||||||
/// Embedding provider: "none" | "openai" | "custom:URL"
|
/// Embedding provider: "none" | "openai" | "custom:URL"
|
||||||
#[serde(default = "default_embedding_provider")]
|
#[serde(default = "default_embedding_provider")]
|
||||||
pub embedding_provider: String,
|
pub embedding_provider: String,
|
||||||
|
|
@ -169,6 +216,18 @@ pub struct MemoryConfig {
|
||||||
fn default_embedding_provider() -> String {
|
fn default_embedding_provider() -> String {
|
||||||
"none".into()
|
"none".into()
|
||||||
}
|
}
|
||||||
|
fn default_hygiene_enabled() -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
fn default_archive_after_days() -> u32 {
|
||||||
|
7
|
||||||
|
}
|
||||||
|
fn default_purge_after_days() -> u32 {
|
||||||
|
30
|
||||||
|
}
|
||||||
|
fn default_conversation_retention_days() -> u32 {
|
||||||
|
30
|
||||||
|
}
|
||||||
fn default_embedding_model() -> String {
|
fn default_embedding_model() -> String {
|
||||||
"text-embedding-3-small".into()
|
"text-embedding-3-small".into()
|
||||||
}
|
}
|
||||||
|
|
@ -193,6 +252,10 @@ impl Default for MemoryConfig {
|
||||||
Self {
|
Self {
|
||||||
backend: "sqlite".into(),
|
backend: "sqlite".into(),
|
||||||
auto_save: true,
|
auto_save: true,
|
||||||
|
hygiene_enabled: default_hygiene_enabled(),
|
||||||
|
archive_after_days: default_archive_after_days(),
|
||||||
|
purge_after_days: default_purge_after_days(),
|
||||||
|
conversation_retention_days: default_conversation_retention_days(),
|
||||||
embedding_provider: default_embedding_provider(),
|
embedding_provider: default_embedding_provider(),
|
||||||
embedding_model: default_embedding_model(),
|
embedding_model: default_embedding_model(),
|
||||||
embedding_dimensions: default_embedding_dims(),
|
embedding_dimensions: default_embedding_dims(),
|
||||||
|
|
@ -281,7 +344,9 @@ impl Default for AutonomyConfig {
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct RuntimeConfig {
|
pub struct RuntimeConfig {
|
||||||
/// "native" | "docker" | "cloudflare"
|
/// Runtime kind (currently supported: "native").
|
||||||
|
///
|
||||||
|
/// Reserved values (not implemented yet): "docker", "cloudflare".
|
||||||
pub kind: String,
|
pub kind: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -293,6 +358,71 @@ impl Default for RuntimeConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Reliability / supervision ────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct ReliabilityConfig {
|
||||||
|
/// Retries per provider before failing over.
|
||||||
|
#[serde(default = "default_provider_retries")]
|
||||||
|
pub provider_retries: u32,
|
||||||
|
/// Base backoff (ms) for provider retry delay.
|
||||||
|
#[serde(default = "default_provider_backoff_ms")]
|
||||||
|
pub provider_backoff_ms: u64,
|
||||||
|
/// Fallback provider chain (e.g. `["anthropic", "openai"]`).
|
||||||
|
#[serde(default)]
|
||||||
|
pub fallback_providers: Vec<String>,
|
||||||
|
/// Initial backoff for channel/daemon restarts.
|
||||||
|
#[serde(default = "default_channel_backoff_secs")]
|
||||||
|
pub channel_initial_backoff_secs: u64,
|
||||||
|
/// Max backoff for channel/daemon restarts.
|
||||||
|
#[serde(default = "default_channel_backoff_max_secs")]
|
||||||
|
pub channel_max_backoff_secs: u64,
|
||||||
|
/// Scheduler polling cadence in seconds.
|
||||||
|
#[serde(default = "default_scheduler_poll_secs")]
|
||||||
|
pub scheduler_poll_secs: u64,
|
||||||
|
/// Max retries for cron job execution attempts.
|
||||||
|
#[serde(default = "default_scheduler_retries")]
|
||||||
|
pub scheduler_retries: u32,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_provider_retries() -> u32 {
|
||||||
|
2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_provider_backoff_ms() -> u64 {
|
||||||
|
500
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_channel_backoff_secs() -> u64 {
|
||||||
|
2
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_channel_backoff_max_secs() -> u64 {
|
||||||
|
60
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_scheduler_poll_secs() -> u64 {
|
||||||
|
15
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_scheduler_retries() -> u32 {
|
||||||
|
2
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for ReliabilityConfig {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
provider_retries: default_provider_retries(),
|
||||||
|
provider_backoff_ms: default_provider_backoff_ms(),
|
||||||
|
fallback_providers: Vec::new(),
|
||||||
|
channel_initial_backoff_secs: default_channel_backoff_secs(),
|
||||||
|
channel_max_backoff_secs: default_channel_backoff_max_secs(),
|
||||||
|
scheduler_poll_secs: default_scheduler_poll_secs(),
|
||||||
|
scheduler_retries: default_scheduler_retries(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Heartbeat ────────────────────────────────────────────────────
|
// ── Heartbeat ────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -387,6 +517,7 @@ pub struct ChannelsConfig {
|
||||||
pub webhook: Option<WebhookConfig>,
|
pub webhook: Option<WebhookConfig>,
|
||||||
pub imessage: Option<IMessageConfig>,
|
pub imessage: Option<IMessageConfig>,
|
||||||
pub matrix: Option<MatrixConfig>,
|
pub matrix: Option<MatrixConfig>,
|
||||||
|
pub whatsapp: Option<WhatsAppConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for ChannelsConfig {
|
impl Default for ChannelsConfig {
|
||||||
|
|
@ -399,6 +530,7 @@ impl Default for ChannelsConfig {
|
||||||
webhook: None,
|
webhook: None,
|
||||||
imessage: None,
|
imessage: None,
|
||||||
matrix: None,
|
matrix: None,
|
||||||
|
whatsapp: None,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -445,6 +577,19 @@ pub struct MatrixConfig {
|
||||||
pub allowed_users: Vec<String>,
|
pub allowed_users: Vec<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct WhatsAppConfig {
|
||||||
|
/// Access token from Meta Business Suite
|
||||||
|
pub access_token: String,
|
||||||
|
/// Phone number ID from Meta Business API
|
||||||
|
pub phone_number_id: String,
|
||||||
|
/// Webhook verify token (you define this, Meta sends it back for verification)
|
||||||
|
pub verify_token: String,
|
||||||
|
/// Allowed phone numbers (E.164 format: +1234567890) or "*" for all
|
||||||
|
#[serde(default)]
|
||||||
|
pub allowed_numbers: Vec<String>,
|
||||||
|
}
|
||||||
|
|
||||||
// ── Config impl ──────────────────────────────────────────────────
|
// ── Config impl ──────────────────────────────────────────────────
|
||||||
|
|
||||||
impl Default for Config {
|
impl Default for Config {
|
||||||
|
|
@ -463,6 +608,7 @@ impl Default for Config {
|
||||||
observability: ObservabilityConfig::default(),
|
observability: ObservabilityConfig::default(),
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
|
reliability: ReliabilityConfig::default(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
memory: MemoryConfig::default(),
|
memory: MemoryConfig::default(),
|
||||||
|
|
@ -471,6 +617,7 @@ impl Default for Config {
|
||||||
composio: ComposioConfig::default(),
|
composio: ComposioConfig::default(),
|
||||||
secrets: SecretsConfig::default(),
|
secrets: SecretsConfig::default(),
|
||||||
browser: BrowserConfig::default(),
|
browser: BrowserConfig::default(),
|
||||||
|
identity: IdentityConfig::default(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -558,6 +705,17 @@ mod tests {
|
||||||
assert_eq!(h.interval_minutes, 30);
|
assert_eq!(h.interval_minutes, 30);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn memory_config_default_hygiene_settings() {
|
||||||
|
let m = MemoryConfig::default();
|
||||||
|
assert_eq!(m.backend, "sqlite");
|
||||||
|
assert!(m.auto_save);
|
||||||
|
assert!(m.hygiene_enabled);
|
||||||
|
assert_eq!(m.archive_after_days, 7);
|
||||||
|
assert_eq!(m.purge_after_days, 30);
|
||||||
|
assert_eq!(m.conversation_retention_days, 30);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn channels_config_default() {
|
fn channels_config_default() {
|
||||||
let c = ChannelsConfig::default();
|
let c = ChannelsConfig::default();
|
||||||
|
|
@ -591,6 +749,7 @@ mod tests {
|
||||||
runtime: RuntimeConfig {
|
runtime: RuntimeConfig {
|
||||||
kind: "docker".into(),
|
kind: "docker".into(),
|
||||||
},
|
},
|
||||||
|
reliability: ReliabilityConfig::default(),
|
||||||
heartbeat: HeartbeatConfig {
|
heartbeat: HeartbeatConfig {
|
||||||
enabled: true,
|
enabled: true,
|
||||||
interval_minutes: 15,
|
interval_minutes: 15,
|
||||||
|
|
@ -606,6 +765,7 @@ mod tests {
|
||||||
webhook: None,
|
webhook: None,
|
||||||
imessage: None,
|
imessage: None,
|
||||||
matrix: None,
|
matrix: None,
|
||||||
|
whatsapp: None,
|
||||||
},
|
},
|
||||||
memory: MemoryConfig::default(),
|
memory: MemoryConfig::default(),
|
||||||
tunnel: TunnelConfig::default(),
|
tunnel: TunnelConfig::default(),
|
||||||
|
|
@ -613,6 +773,7 @@ mod tests {
|
||||||
composio: ComposioConfig::default(),
|
composio: ComposioConfig::default(),
|
||||||
secrets: SecretsConfig::default(),
|
secrets: SecretsConfig::default(),
|
||||||
browser: BrowserConfig::default(),
|
browser: BrowserConfig::default(),
|
||||||
|
identity: IdentityConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let toml_str = toml::to_string_pretty(&config).unwrap();
|
let toml_str = toml::to_string_pretty(&config).unwrap();
|
||||||
|
|
@ -650,6 +811,10 @@ default_temperature = 0.7
|
||||||
assert_eq!(parsed.runtime.kind, "native");
|
assert_eq!(parsed.runtime.kind, "native");
|
||||||
assert!(!parsed.heartbeat.enabled);
|
assert!(!parsed.heartbeat.enabled);
|
||||||
assert!(parsed.channels_config.cli);
|
assert!(parsed.channels_config.cli);
|
||||||
|
assert!(parsed.memory.hygiene_enabled);
|
||||||
|
assert_eq!(parsed.memory.archive_after_days, 7);
|
||||||
|
assert_eq!(parsed.memory.purge_after_days, 30);
|
||||||
|
assert_eq!(parsed.memory.conversation_retention_days, 30);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -669,6 +834,7 @@ default_temperature = 0.7
|
||||||
observability: ObservabilityConfig::default(),
|
observability: ObservabilityConfig::default(),
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
|
reliability: ReliabilityConfig::default(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
memory: MemoryConfig::default(),
|
memory: MemoryConfig::default(),
|
||||||
|
|
@ -677,6 +843,7 @@ default_temperature = 0.7
|
||||||
composio: ComposioConfig::default(),
|
composio: ComposioConfig::default(),
|
||||||
secrets: SecretsConfig::default(),
|
secrets: SecretsConfig::default(),
|
||||||
browser: BrowserConfig::default(),
|
browser: BrowserConfig::default(),
|
||||||
|
identity: IdentityConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
config.save().unwrap();
|
config.save().unwrap();
|
||||||
|
|
@ -810,6 +977,7 @@ default_temperature = 0.7
|
||||||
room_id: "!r:m".into(),
|
room_id: "!r:m".into(),
|
||||||
allowed_users: vec!["@u:m".into()],
|
allowed_users: vec!["@u:m".into()],
|
||||||
}),
|
}),
|
||||||
|
whatsapp: None,
|
||||||
};
|
};
|
||||||
let toml_str = toml::to_string_pretty(&c).unwrap();
|
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||||
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||||
|
|
@ -894,6 +1062,89 @@ channel_id = "C123"
|
||||||
assert_eq!(parsed.port, 8080);
|
assert_eq!(parsed.port, 8080);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── WhatsApp config ──────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn whatsapp_config_serde() {
|
||||||
|
let wc = WhatsAppConfig {
|
||||||
|
access_token: "EAABx...".into(),
|
||||||
|
phone_number_id: "123456789".into(),
|
||||||
|
verify_token: "my-verify-token".into(),
|
||||||
|
allowed_numbers: vec!["+1234567890".into(), "+9876543210".into()],
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&wc).unwrap();
|
||||||
|
let parsed: WhatsAppConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(parsed.access_token, "EAABx...");
|
||||||
|
assert_eq!(parsed.phone_number_id, "123456789");
|
||||||
|
assert_eq!(parsed.verify_token, "my-verify-token");
|
||||||
|
assert_eq!(parsed.allowed_numbers.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn whatsapp_config_toml_roundtrip() {
|
||||||
|
let wc = WhatsAppConfig {
|
||||||
|
access_token: "tok".into(),
|
||||||
|
phone_number_id: "12345".into(),
|
||||||
|
verify_token: "verify".into(),
|
||||||
|
allowed_numbers: vec!["+1".into()],
|
||||||
|
};
|
||||||
|
let toml_str = toml::to_string(&wc).unwrap();
|
||||||
|
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
|
||||||
|
assert_eq!(parsed.phone_number_id, "12345");
|
||||||
|
assert_eq!(parsed.allowed_numbers, vec!["+1"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn whatsapp_config_deserializes_without_allowed_numbers() {
|
||||||
|
let json = r#"{"access_token":"tok","phone_number_id":"123","verify_token":"ver"}"#;
|
||||||
|
let parsed: WhatsAppConfig = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(parsed.allowed_numbers.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn whatsapp_config_wildcard_allowed() {
|
||||||
|
let wc = WhatsAppConfig {
|
||||||
|
access_token: "tok".into(),
|
||||||
|
phone_number_id: "123".into(),
|
||||||
|
verify_token: "ver".into(),
|
||||||
|
allowed_numbers: vec!["*".into()],
|
||||||
|
};
|
||||||
|
let toml_str = toml::to_string(&wc).unwrap();
|
||||||
|
let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap();
|
||||||
|
assert_eq!(parsed.allowed_numbers, vec!["*"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn channels_config_with_whatsapp() {
|
||||||
|
let c = ChannelsConfig {
|
||||||
|
cli: true,
|
||||||
|
telegram: None,
|
||||||
|
discord: None,
|
||||||
|
slack: None,
|
||||||
|
webhook: None,
|
||||||
|
imessage: None,
|
||||||
|
matrix: None,
|
||||||
|
whatsapp: Some(WhatsAppConfig {
|
||||||
|
access_token: "tok".into(),
|
||||||
|
phone_number_id: "123".into(),
|
||||||
|
verify_token: "ver".into(),
|
||||||
|
allowed_numbers: vec!["+1".into()],
|
||||||
|
}),
|
||||||
|
};
|
||||||
|
let toml_str = toml::to_string_pretty(&c).unwrap();
|
||||||
|
let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap();
|
||||||
|
assert!(parsed.whatsapp.is_some());
|
||||||
|
let wa = parsed.whatsapp.unwrap();
|
||||||
|
assert_eq!(wa.phone_number_id, "123");
|
||||||
|
assert_eq!(wa.allowed_numbers, vec!["+1"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn channels_config_default_has_no_whatsapp() {
|
||||||
|
let c = ChannelsConfig::default();
|
||||||
|
assert!(c.whatsapp.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
// SECURITY CHECKLIST TESTS — Gateway config
|
// SECURITY CHECKLIST TESTS — Gateway config
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
|
||||||
350
src/cron/mod.rs
350
src/cron/mod.rs
|
|
@ -1,25 +1,353 @@
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use anyhow::Result;
|
use anyhow::{Context, Result};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use cron::Schedule;
|
||||||
|
use rusqlite::{params, Connection};
|
||||||
|
use std::str::FromStr;
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
pub fn handle_command(command: super::CronCommands, _config: Config) -> Result<()> {
|
pub mod scheduler;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct CronJob {
|
||||||
|
pub id: String,
|
||||||
|
pub expression: String,
|
||||||
|
pub command: String,
|
||||||
|
pub next_run: DateTime<Utc>,
|
||||||
|
pub last_run: Option<DateTime<Utc>>,
|
||||||
|
pub last_status: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn handle_command(command: super::CronCommands, config: Config) -> Result<()> {
|
||||||
match command {
|
match command {
|
||||||
super::CronCommands::List => {
|
super::CronCommands::List => {
|
||||||
println!("No scheduled tasks yet.");
|
let jobs = list_jobs(&config)?;
|
||||||
println!("\nUsage:");
|
if jobs.is_empty() {
|
||||||
println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'");
|
println!("No scheduled tasks yet.");
|
||||||
|
println!("\nUsage:");
|
||||||
|
println!(" zeroclaw cron add '0 9 * * *' 'agent -m \"Good morning!\"'");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("🕒 Scheduled jobs ({}):", jobs.len());
|
||||||
|
for job in jobs {
|
||||||
|
let last_run = job
|
||||||
|
.last_run
|
||||||
|
.map(|d| d.to_rfc3339())
|
||||||
|
.unwrap_or_else(|| "never".into());
|
||||||
|
let last_status = job.last_status.unwrap_or_else(|| "n/a".into());
|
||||||
|
println!(
|
||||||
|
"- {} | {} | next={} | last={} ({})\n cmd: {}",
|
||||||
|
job.id,
|
||||||
|
job.expression,
|
||||||
|
job.next_run.to_rfc3339(),
|
||||||
|
last_run,
|
||||||
|
last_status,
|
||||||
|
job.command
|
||||||
|
);
|
||||||
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
super::CronCommands::Add {
|
super::CronCommands::Add {
|
||||||
expression,
|
expression,
|
||||||
command,
|
command,
|
||||||
} => {
|
} => {
|
||||||
println!("Cron scheduling coming soon!");
|
let job = add_job(&config, &expression, &command)?;
|
||||||
println!(" Expression: {expression}");
|
println!("✅ Added cron job {}", job.id);
|
||||||
println!(" Command: {command}");
|
println!(" Expr: {}", job.expression);
|
||||||
|
println!(" Next: {}", job.next_run.to_rfc3339());
|
||||||
|
println!(" Cmd : {}", job.command);
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
super::CronCommands::Remove { id } => {
|
super::CronCommands::Remove { id } => remove_job(&config, &id),
|
||||||
anyhow::bail!("Remove task '{id}' not yet implemented");
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn add_job(config: &Config, expression: &str, command: &str) -> Result<CronJob> {
|
||||||
|
let now = Utc::now();
|
||||||
|
let next_run = next_run_for(expression, now)?;
|
||||||
|
let id = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
|
with_connection(config, |conn| {
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO cron_jobs (id, expression, command, created_at, next_run)
|
||||||
|
VALUES (?1, ?2, ?3, ?4, ?5)",
|
||||||
|
params![
|
||||||
|
id,
|
||||||
|
expression,
|
||||||
|
command,
|
||||||
|
now.to_rfc3339(),
|
||||||
|
next_run.to_rfc3339()
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.context("Failed to insert cron job")?;
|
||||||
|
Ok(())
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(CronJob {
|
||||||
|
id,
|
||||||
|
expression: expression.to_string(),
|
||||||
|
command: command.to_string(),
|
||||||
|
next_run,
|
||||||
|
last_run: None,
|
||||||
|
last_status: None,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn list_jobs(config: &Config) -> Result<Vec<CronJob>> {
|
||||||
|
with_connection(config, |conn| {
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT id, expression, command, next_run, last_run, last_status
|
||||||
|
FROM cron_jobs ORDER BY next_run ASC",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let rows = stmt.query_map([], |row| {
|
||||||
|
let next_run_raw: String = row.get(3)?;
|
||||||
|
let last_run_raw: Option<String> = row.get(4)?;
|
||||||
|
Ok((
|
||||||
|
row.get::<_, String>(0)?,
|
||||||
|
row.get::<_, String>(1)?,
|
||||||
|
row.get::<_, String>(2)?,
|
||||||
|
next_run_raw,
|
||||||
|
last_run_raw,
|
||||||
|
row.get::<_, Option<String>>(5)?,
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut jobs = Vec::new();
|
||||||
|
for row in rows {
|
||||||
|
let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?;
|
||||||
|
jobs.push(CronJob {
|
||||||
|
id,
|
||||||
|
expression,
|
||||||
|
command,
|
||||||
|
next_run: parse_rfc3339(&next_run_raw)?,
|
||||||
|
last_run: match last_run_raw {
|
||||||
|
Some(raw) => Some(parse_rfc3339(&raw)?),
|
||||||
|
None => None,
|
||||||
|
},
|
||||||
|
last_status,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(jobs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn remove_job(config: &Config, id: &str) -> Result<()> {
|
||||||
|
let changed = with_connection(config, |conn| {
|
||||||
|
conn.execute("DELETE FROM cron_jobs WHERE id = ?1", params![id])
|
||||||
|
.context("Failed to delete cron job")
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if changed == 0 {
|
||||||
|
anyhow::bail!("Cron job '{id}' not found");
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("✅ Removed cron job {id}");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn due_jobs(config: &Config, now: DateTime<Utc>) -> Result<Vec<CronJob>> {
|
||||||
|
with_connection(config, |conn| {
|
||||||
|
let mut stmt = conn.prepare(
|
||||||
|
"SELECT id, expression, command, next_run, last_run, last_status
|
||||||
|
FROM cron_jobs WHERE next_run <= ?1 ORDER BY next_run ASC",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let rows = stmt.query_map(params![now.to_rfc3339()], |row| {
|
||||||
|
let next_run_raw: String = row.get(3)?;
|
||||||
|
let last_run_raw: Option<String> = row.get(4)?;
|
||||||
|
Ok((
|
||||||
|
row.get::<_, String>(0)?,
|
||||||
|
row.get::<_, String>(1)?,
|
||||||
|
row.get::<_, String>(2)?,
|
||||||
|
next_run_raw,
|
||||||
|
last_run_raw,
|
||||||
|
row.get::<_, Option<String>>(5)?,
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut jobs = Vec::new();
|
||||||
|
for row in rows {
|
||||||
|
let (id, expression, command, next_run_raw, last_run_raw, last_status) = row?;
|
||||||
|
jobs.push(CronJob {
|
||||||
|
id,
|
||||||
|
expression,
|
||||||
|
command,
|
||||||
|
next_run: parse_rfc3339(&next_run_raw)?,
|
||||||
|
last_run: match last_run_raw {
|
||||||
|
Some(raw) => Some(parse_rfc3339(&raw)?),
|
||||||
|
None => None,
|
||||||
|
},
|
||||||
|
last_status,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
Ok(jobs)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn reschedule_after_run(
|
||||||
|
config: &Config,
|
||||||
|
job: &CronJob,
|
||||||
|
success: bool,
|
||||||
|
output: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
let now = Utc::now();
|
||||||
|
let next_run = next_run_for(&job.expression, now)?;
|
||||||
|
let status = if success { "ok" } else { "error" };
|
||||||
|
|
||||||
|
with_connection(config, |conn| {
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE cron_jobs
|
||||||
|
SET next_run = ?1, last_run = ?2, last_status = ?3, last_output = ?4
|
||||||
|
WHERE id = ?5",
|
||||||
|
params![
|
||||||
|
next_run.to_rfc3339(),
|
||||||
|
now.to_rfc3339(),
|
||||||
|
status,
|
||||||
|
output,
|
||||||
|
job.id
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.context("Failed to update cron job run state")?;
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn next_run_for(expression: &str, from: DateTime<Utc>) -> Result<DateTime<Utc>> {
|
||||||
|
let normalized = normalize_expression(expression)?;
|
||||||
|
let schedule = Schedule::from_str(&normalized)
|
||||||
|
.with_context(|| format!("Invalid cron expression: {expression}"))?;
|
||||||
|
schedule
|
||||||
|
.after(&from)
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No future occurrence for expression: {expression}"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_expression(expression: &str) -> Result<String> {
|
||||||
|
let expression = expression.trim();
|
||||||
|
let field_count = expression.split_whitespace().count();
|
||||||
|
|
||||||
|
match field_count {
|
||||||
|
// standard crontab syntax: minute hour day month weekday
|
||||||
|
5 => Ok(format!("0 {expression}")),
|
||||||
|
// crate-native syntax includes seconds (+ optional year)
|
||||||
|
6 | 7 => Ok(expression.to_string()),
|
||||||
|
_ => anyhow::bail!(
|
||||||
|
"Invalid cron expression: {expression} (expected 5, 6, or 7 fields, got {field_count})"
|
||||||
|
),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_rfc3339(raw: &str) -> Result<DateTime<Utc>> {
|
||||||
|
let parsed = DateTime::parse_from_rfc3339(raw)
|
||||||
|
.with_context(|| format!("Invalid RFC3339 timestamp in cron DB: {raw}"))?;
|
||||||
|
Ok(parsed.with_timezone(&Utc))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn with_connection<T>(config: &Config, f: impl FnOnce(&Connection) -> Result<T>) -> Result<T> {
|
||||||
|
let db_path = config.workspace_dir.join("cron").join("jobs.db");
|
||||||
|
if let Some(parent) = db_path.parent() {
|
||||||
|
std::fs::create_dir_all(parent)
|
||||||
|
.with_context(|| format!("Failed to create cron directory: {}", parent.display()))?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let conn = Connection::open(&db_path)
|
||||||
|
.with_context(|| format!("Failed to open cron DB: {}", db_path.display()))?;
|
||||||
|
|
||||||
|
conn.execute_batch(
|
||||||
|
"CREATE TABLE IF NOT EXISTS cron_jobs (
|
||||||
|
id TEXT PRIMARY KEY,
|
||||||
|
expression TEXT NOT NULL,
|
||||||
|
command TEXT NOT NULL,
|
||||||
|
created_at TEXT NOT NULL,
|
||||||
|
next_run TEXT NOT NULL,
|
||||||
|
last_run TEXT,
|
||||||
|
last_status TEXT,
|
||||||
|
last_output TEXT
|
||||||
|
);
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_cron_jobs_next_run ON cron_jobs(next_run);",
|
||||||
|
)
|
||||||
|
.context("Failed to initialize cron schema")?;
|
||||||
|
|
||||||
|
f(&conn)
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use chrono::Duration as ChronoDuration;
    use tempfile::TempDir;

    /// Build an isolated `Config` rooted in `tmp` so each test gets its own
    /// cron database under a fresh workspace directory.
    fn test_config(tmp: &TempDir) -> Config {
        let mut config = Config::default();
        config.workspace_dir = tmp.path().join("workspace");
        config.config_path = tmp.path().join("config.toml");
        std::fs::create_dir_all(&config.workspace_dir).unwrap();
        config
    }

    /// A standard 5-field crontab expression is accepted and stored verbatim
    /// (normalization to 6 fields happens only for scheduling, not storage).
    #[test]
    fn add_job_accepts_five_field_expression() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);

        let job = add_job(&config, "*/5 * * * *", "echo ok").unwrap();

        assert_eq!(job.expression, "*/5 * * * *");
        assert_eq!(job.command, "echo ok");
    }

    /// A 4-field expression must be rejected with the field-count message
    /// produced by `normalize_expression`.
    #[test]
    fn add_job_rejects_invalid_field_count() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);

        let err = add_job(&config, "* * * *", "echo bad").unwrap_err();
        assert!(err.to_string().contains("expected 5, 6, or 7 fields"));
    }

    /// add -> list -> remove leaves the database empty again.
    #[test]
    fn add_list_remove_roundtrip() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);

        let job = add_job(&config, "*/10 * * * *", "echo roundtrip").unwrap();
        let listed = list_jobs(&config).unwrap();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0].id, job.id);

        remove_job(&config, &job.id).unwrap();
        assert!(list_jobs(&config).unwrap().is_empty());
    }

    /// A freshly added job has `next_run` in the future, so it is not due
    /// "now" but is due when queried with a timestamp a year ahead.
    #[test]
    fn due_jobs_filters_by_timestamp() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);

        let _job = add_job(&config, "* * * * *", "echo due").unwrap();

        let due_now = due_jobs(&config, Utc::now()).unwrap();
        assert!(due_now.is_empty(), "new job should not be due immediately");

        let far_future = Utc::now() + ChronoDuration::days(365);
        let due_future = due_jobs(&config, far_future).unwrap();
        assert_eq!(due_future.len(), 1, "job should be due in far future");
    }

    /// A failed run (`success == false`) is persisted as `last_status ==
    /// "error"` with a recorded `last_run` timestamp.
    #[test]
    fn reschedule_after_run_persists_last_status_and_last_run() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);

        let job = add_job(&config, "*/15 * * * *", "echo run").unwrap();
        reschedule_after_run(&config, &job, false, "failed output").unwrap();

        let listed = list_jobs(&config).unwrap();
        let stored = listed.iter().find(|j| j.id == job.id).unwrap();
        assert_eq!(stored.last_status.as_deref(), Some("error"));
        assert!(stored.last_run.is_some());
    }
}
|
||||||
|
|
|
||||||
297
src/cron/scheduler.rs
Normal file
297
src/cron/scheduler.rs
Normal file
|
|
@ -0,0 +1,297 @@
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::cron::{due_jobs, reschedule_after_run, CronJob};
|
||||||
|
use crate::security::SecurityPolicy;
|
||||||
|
use anyhow::Result;
|
||||||
|
use chrono::Utc;
|
||||||
|
use tokio::process::Command;
|
||||||
|
use tokio::time::{self, Duration};
|
||||||
|
|
||||||
|
/// Lower bound for the scheduler poll interval; configured values below this
/// are clamped up in `run` to avoid busy-polling the cron database.
const MIN_POLL_SECONDS: u64 = 5;
|
||||||
|
|
||||||
|
/// Long-running scheduler loop: polls the cron DB for due jobs on a fixed
/// interval, executes them under the security policy, and records outcomes
/// both in the health registry and back into the database.
pub async fn run(config: Config) -> Result<()> {
    // Clamp the poll interval so a misconfigured (tiny/zero) value can't busy-loop.
    let poll_secs = config.reliability.scheduler_poll_secs.max(MIN_POLL_SECONDS);
    let mut interval = time::interval(Duration::from_secs(poll_secs));
    let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);

    crate::health::mark_component_ok("scheduler");

    loop {
        interval.tick().await;

        let jobs = match due_jobs(&config, Utc::now()) {
            Ok(jobs) => jobs,
            Err(e) => {
                // DB errors are treated as transient: record them and retry
                // on the next tick instead of tearing the loop down.
                crate::health::mark_component_error("scheduler", e.to_string());
                tracing::warn!("Scheduler query failed: {e}");
                continue;
            }
        };

        for job in jobs {
            // Re-mark ok before each job so a prior iteration's error state
            // doesn't linger while the scheduler is still making progress.
            crate::health::mark_component_ok("scheduler");
            let (success, output) = execute_job_with_retry(&config, &security, &job).await;

            if !success {
                crate::health::mark_component_error("scheduler", format!("job {} failed", job.id));
            }

            // Persist next_run/last_run/last_status even on failure so the
            // job advances to its next occurrence instead of being re-picked
            // immediately on the next poll.
            if let Err(e) = reschedule_after_run(&config, &job, success, &output) {
                crate::health::mark_component_error("scheduler", e.to_string());
                tracing::warn!("Failed to persist scheduler run result: {e}");
            }
        }
    }
}
|
||||||
|
|
||||||
|
async fn execute_job_with_retry(
|
||||||
|
config: &Config,
|
||||||
|
security: &SecurityPolicy,
|
||||||
|
job: &CronJob,
|
||||||
|
) -> (bool, String) {
|
||||||
|
let mut last_output = String::new();
|
||||||
|
let retries = config.reliability.scheduler_retries;
|
||||||
|
let mut backoff_ms = config.reliability.provider_backoff_ms.max(200);
|
||||||
|
|
||||||
|
for attempt in 0..=retries {
|
||||||
|
let (success, output) = run_job_command(config, security, job).await;
|
||||||
|
last_output = output;
|
||||||
|
|
||||||
|
if success {
|
||||||
|
return (true, last_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
if last_output.starts_with("blocked by security policy:") {
|
||||||
|
// Deterministic policy violations are not retryable.
|
||||||
|
return (false, last_output);
|
||||||
|
}
|
||||||
|
|
||||||
|
if attempt < retries {
|
||||||
|
let jitter_ms = (Utc::now().timestamp_subsec_millis() % 250) as u64;
|
||||||
|
time::sleep(Duration::from_millis(backoff_ms + jitter_ms)).await;
|
||||||
|
backoff_ms = (backoff_ms.saturating_mul(2)).min(30_000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
(false, last_output)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Heuristic: does this shell word look like a leading environment-variable
/// assignment (`NAME=value`)? Requires an `=` somewhere and a first
/// character that could start an identifier (ASCII letter or `_`).
fn is_env_assignment(word: &str) -> bool {
    if !word.contains('=') {
        return false;
    }
    match word.chars().next() {
        Some(first) => first.is_ascii_alphabetic() || first == '_',
        None => false,
    }
}
|
||||||
|
|
||||||
|
/// Remove any run of single or double quotes from both ends of a token,
/// leaving the interior untouched.
fn strip_wrapping_quotes(token: &str) -> &str {
    let is_quote = |c: char| c == '"' || c == '\'';
    token.trim_matches(is_quote)
}
|
||||||
|
|
||||||
|
/// Best-effort scan of a shell command line for path-like arguments that the
/// security policy forbids; returns the first offending path, if any.
///
/// This is a heuristic lexer, not a full shell parser: it splits the line on
/// common command separators, skips leading env-var assignments and the
/// executable word in each segment, then checks the remaining tokens that
/// look like filesystem paths against `SecurityPolicy::is_path_allowed`.
fn forbidden_path_argument(security: &SecurityPolicy, command: &str) -> Option<String> {
    // Replace separators with NUL so the line can be split into independent
    // sub-commands. Two-character separators ("&&", "||") must be handled
    // before single-character ones so "||" is not misread as two pipes.
    let mut normalized = command.to_string();
    for sep in ["&&", "||"] {
        normalized = normalized.replace(sep, "\x00");
    }
    for sep in ['\n', ';', '|'] {
        normalized = normalized.replace(sep, "\x00");
    }

    for segment in normalized.split('\x00') {
        let tokens: Vec<&str> = segment.split_whitespace().collect();
        if tokens.is_empty() {
            continue;
        }

        // Skip leading env assignments and executable token.
        let mut idx = 0;
        while idx < tokens.len() && is_env_assignment(tokens[idx]) {
            idx += 1;
        }
        if idx >= tokens.len() {
            // Segment was only env assignments; nothing to check.
            continue;
        }
        idx += 1;

        for token in &tokens[idx..] {
            let candidate = strip_wrapping_quotes(token);
            // Flags ("-v", "--long") and URLs ("https://...") are never
            // treated as filesystem paths.
            if candidate.is_empty() || candidate.starts_with('-') || candidate.contains("://") {
                continue;
            }

            // Anything with a slash (absolute, relative, or home-relative)
            // is considered a path candidate.
            let looks_like_path = candidate.starts_with('/')
                || candidate.starts_with("./")
                || candidate.starts_with("../")
                || candidate.starts_with("~/")
                || candidate.contains('/');

            if looks_like_path && !security.is_path_allowed(candidate) {
                return Some(candidate.to_string());
            }
        }
    }

    None
}
|
||||||
|
|
||||||
|
async fn run_job_command(
|
||||||
|
config: &Config,
|
||||||
|
security: &SecurityPolicy,
|
||||||
|
job: &CronJob,
|
||||||
|
) -> (bool, String) {
|
||||||
|
if !security.is_command_allowed(&job.command) {
|
||||||
|
return (
|
||||||
|
false,
|
||||||
|
format!(
|
||||||
|
"blocked by security policy: command not allowed: {}",
|
||||||
|
job.command
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(path) = forbidden_path_argument(security, &job.command) {
|
||||||
|
return (
|
||||||
|
false,
|
||||||
|
format!("blocked by security policy: forbidden path argument: {path}"),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let output = Command::new("sh")
|
||||||
|
.arg("-lc")
|
||||||
|
.arg(&job.command)
|
||||||
|
.current_dir(&config.workspace_dir)
|
||||||
|
.output()
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match output {
|
||||||
|
Ok(output) => {
|
||||||
|
let stdout = String::from_utf8_lossy(&output.stdout);
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
let combined = format!(
|
||||||
|
"status={}\nstdout:\n{}\nstderr:\n{}",
|
||||||
|
output.status,
|
||||||
|
stdout.trim(),
|
||||||
|
stderr.trim()
|
||||||
|
);
|
||||||
|
(output.status.success(), combined)
|
||||||
|
}
|
||||||
|
Err(e) => (false, format!("spawn error: {e}")),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::Config;
    use crate::security::SecurityPolicy;
    use tempfile::TempDir;

    /// Build an isolated `Config` rooted in `tmp` with a fresh workspace.
    fn test_config(tmp: &TempDir) -> Config {
        let mut config = Config::default();
        config.workspace_dir = tmp.path().join("workspace");
        config.config_path = tmp.path().join("config.toml");
        std::fs::create_dir_all(&config.workspace_dir).unwrap();
        config
    }

    /// Construct a minimal in-memory `CronJob` for exercising the runner.
    fn test_job(command: &str) -> CronJob {
        CronJob {
            id: "test-job".into(),
            expression: "* * * * *".into(),
            command: command.into(),
            next_run: Utc::now(),
            last_run: None,
            last_status: None,
        }
    }

    /// Successful commands report success and include stdout plus the
    /// formatted exit status in the combined output.
    #[tokio::test]
    async fn run_job_command_success() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);
        let job = test_job("echo scheduler-ok");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);

        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(success);
        assert!(output.contains("scheduler-ok"));
        assert!(output.contains("status=exit status: 0"));
    }

    /// A non-zero exit reports failure and surfaces stderr in the output.
    #[tokio::test]
    async fn run_job_command_failure() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);
        let job = test_job("ls definitely_missing_file_for_scheduler_test");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);

        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("definitely_missing_file_for_scheduler_test"));
        assert!(output.contains("status=exit status:"));
    }

    /// Commands outside the allowlist are rejected before spawning.
    #[tokio::test]
    async fn run_job_command_blocks_disallowed_command() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.autonomy.allowed_commands = vec!["echo".into()];
        let job = test_job("curl https://evil.example");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);

        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("blocked by security policy"));
        assert!(output.contains("command not allowed"));
    }

    /// Path arguments outside the allowed roots are rejected even when the
    /// command itself is allowlisted.
    #[tokio::test]
    async fn run_job_command_blocks_forbidden_path_argument() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.autonomy.allowed_commands = vec!["cat".into()];
        let job = test_job("cat /etc/passwd");
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);

        let (success, output) = run_job_command(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("blocked by security policy"));
        assert!(output.contains("forbidden path argument"));
        assert!(output.contains("/etc/passwd"));
    }

    /// A script that fails once then succeeds (via a flag file) should
    /// succeed overall when one retry is configured.
    #[tokio::test]
    async fn execute_job_with_retry_recovers_after_first_failure() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.reliability.scheduler_retries = 1;
        config.reliability.provider_backoff_ms = 1;
        config.autonomy.allowed_commands = vec!["sh".into()];
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);

        // First invocation creates the flag and exits 1; the retry sees the
        // flag and succeeds.
        std::fs::write(
            config.workspace_dir.join("retry-once.sh"),
            "#!/bin/sh\nif [ -f retry-ok.flag ]; then\n  echo recovered\n  exit 0\nfi\ntouch retry-ok.flag\nexit 1\n",
        )
        .unwrap();
        let job = test_job("sh ./retry-once.sh");

        let (success, output) = execute_job_with_retry(&config, &security, &job).await;
        assert!(success);
        assert!(output.contains("recovered"));
    }

    /// A command that always fails stays failed after retries are exhausted,
    /// and the last attempt's output is returned.
    #[tokio::test]
    async fn execute_job_with_retry_exhausts_attempts() {
        let tmp = TempDir::new().unwrap();
        let mut config = test_config(&tmp);
        config.reliability.scheduler_retries = 1;
        config.reliability.provider_backoff_ms = 1;
        let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);

        let job = test_job("ls always_missing_for_retry_test");

        let (success, output) = execute_job_with_retry(&config, &security, &job).await;
        assert!(!success);
        assert!(output.contains("always_missing_for_retry_test"));
    }
}
|
||||||
287
src/daemon/mod.rs
Normal file
287
src/daemon/mod.rs
Normal file
|
|
@ -0,0 +1,287 @@
|
||||||
|
use crate::config::Config;
|
||||||
|
use anyhow::Result;
|
||||||
|
use chrono::Utc;
|
||||||
|
use std::future::Future;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use tokio::task::JoinHandle;
|
||||||
|
use tokio::time::Duration;
|
||||||
|
|
||||||
|
/// How often (seconds) the daemon flushes its health snapshot to the state file.
const STATUS_FLUSH_SECONDS: u64 = 5;
|
||||||
|
|
||||||
|
/// Daemon entry point: supervises the gateway, channels, heartbeat worker,
/// and cron scheduler as independently restarted tasks until Ctrl+C.
pub async fn run(config: Config, host: String, port: u16) -> Result<()> {
    // Restart-backoff bounds; max is clamped to at least the initial value
    // so the supervisor's doubling always has a valid ceiling.
    let initial_backoff = config.reliability.channel_initial_backoff_secs.max(1);
    let max_backoff = config
        .reliability
        .channel_max_backoff_secs
        .max(initial_backoff);

    crate::health::mark_component_ok("daemon");

    if config.heartbeat.enabled {
        // Best-effort: failure to create the heartbeat file is not fatal.
        let _ =
            crate::heartbeat::engine::HeartbeatEngine::ensure_heartbeat_file(&config.workspace_dir)
                .await;
    }

    // The state writer always runs; supervised components are appended below.
    let mut handles: Vec<JoinHandle<()>> = vec![spawn_state_writer(config.clone())];

    {
        let gateway_cfg = config.clone();
        let gateway_host = host.clone();
        handles.push(spawn_component_supervisor(
            "gateway",
            initial_backoff,
            max_backoff,
            move || {
                let cfg = gateway_cfg.clone();
                let host = gateway_host.clone();
                async move { crate::gateway::run_gateway(&host, port, cfg).await }
            },
        ));
    }

    {
        if has_supervised_channels(&config) {
            let channels_cfg = config.clone();
            handles.push(spawn_component_supervisor(
                "channels",
                initial_backoff,
                max_backoff,
                move || {
                    let cfg = channels_cfg.clone();
                    async move { crate::channels::start_channels(cfg).await }
                },
            ));
        } else {
            // Mark ok so health/doctor don't flag a deliberately absent component.
            crate::health::mark_component_ok("channels");
            tracing::info!("No real-time channels configured; channel supervisor disabled");
        }
    }

    if config.heartbeat.enabled {
        let heartbeat_cfg = config.clone();
        handles.push(spawn_component_supervisor(
            "heartbeat",
            initial_backoff,
            max_backoff,
            move || {
                let cfg = heartbeat_cfg.clone();
                async move { run_heartbeat_worker(cfg).await }
            },
        ));
    }

    {
        let scheduler_cfg = config.clone();
        handles.push(spawn_component_supervisor(
            "scheduler",
            initial_backoff,
            max_backoff,
            move || {
                let cfg = scheduler_cfg.clone();
                async move { crate::cron::scheduler::run(cfg).await }
            },
        ));
    }

    println!("🧠 ZeroClaw daemon started");
    println!("   Gateway: http://{host}:{port}");
    println!("   Components: gateway, channels, heartbeat, scheduler");
    println!("   Ctrl+C to stop");

    // Block until Ctrl+C, then record shutdown and tear the tasks down.
    tokio::signal::ctrl_c().await?;
    crate::health::mark_component_error("daemon", "shutdown requested");

    // Abort every supervised task, then await each so the aborts are observed.
    for handle in &handles {
        handle.abort();
    }
    for handle in handles {
        let _ = handle.await;
    }

    Ok(())
}
|
||||||
|
|
||||||
|
pub fn state_file_path(config: &Config) -> PathBuf {
|
||||||
|
config
|
||||||
|
.config_path
|
||||||
|
.parent()
|
||||||
|
.map_or_else(|| PathBuf::from("."), PathBuf::from)
|
||||||
|
.join("daemon_state.json")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn spawn_state_writer(config: Config) -> JoinHandle<()> {
|
||||||
|
tokio::spawn(async move {
|
||||||
|
let path = state_file_path(&config);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
let _ = tokio::fs::create_dir_all(parent).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut interval = tokio::time::interval(Duration::from_secs(STATUS_FLUSH_SECONDS));
|
||||||
|
loop {
|
||||||
|
interval.tick().await;
|
||||||
|
let mut json = crate::health::snapshot_json();
|
||||||
|
if let Some(obj) = json.as_object_mut() {
|
||||||
|
obj.insert(
|
||||||
|
"written_at".into(),
|
||||||
|
serde_json::json!(Utc::now().to_rfc3339()),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let data = serde_json::to_vec_pretty(&json).unwrap_or_else(|_| b"{}".to_vec());
|
||||||
|
let _ = tokio::fs::write(&path, data).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Spawn a task that runs `run_component` forever, restarting it with
/// exponential backoff whenever it returns.
///
/// Supervised components are expected to run indefinitely, so a clean
/// `Ok(())` return is also treated as an unexpected exit and restarted.
/// Each restart bumps the component's restart counter in the health registry.
fn spawn_component_supervisor<F, Fut>(
    name: &'static str,
    initial_backoff_secs: u64,
    max_backoff_secs: u64,
    mut run_component: F,
) -> JoinHandle<()>
where
    F: FnMut() -> Fut + Send + 'static,
    Fut: Future<Output = Result<()>> + Send + 'static,
{
    tokio::spawn(async move {
        // Guard against zero/inverted configuration values.
        let mut backoff = initial_backoff_secs.max(1);
        let max_backoff = max_backoff_secs.max(backoff);

        loop {
            crate::health::mark_component_ok(name);
            match run_component().await {
                Ok(()) => {
                    crate::health::mark_component_error(name, "component exited unexpectedly");
                    tracing::warn!("Daemon component '{name}' exited unexpectedly");
                }
                Err(e) => {
                    crate::health::mark_component_error(name, e.to_string());
                    tracing::error!("Daemon component '{name}' failed: {e}");
                }
            }

            // NOTE(review): the backoff only ever grows (up to max_backoff);
            // it is never reset after a long healthy run — confirm intended.
            crate::health::bump_component_restart(name);
            tokio::time::sleep(Duration::from_secs(backoff)).await;
            backoff = backoff.saturating_mul(2).min(max_backoff);
        }
    })
}
|
||||||
|
|
||||||
|
/// Heartbeat worker loop: every `interval_minutes` (clamped to >= 5) it
/// collects pending heartbeat tasks and feeds each one to the agent.
///
/// An error from `collect_tasks` propagates via `?`, ending this future so
/// the component supervisor can restart the worker with backoff. A single
/// failing task, by contrast, only marks the component unhealthy and the
/// remaining tasks still run.
async fn run_heartbeat_worker(config: Config) -> Result<()> {
    let observer: std::sync::Arc<dyn crate::observability::Observer> =
        std::sync::Arc::from(crate::observability::create_observer(&config.observability));
    let engine = crate::heartbeat::engine::HeartbeatEngine::new(
        config.heartbeat.clone(),
        config.workspace_dir.clone(),
        observer,
    );

    // Clamp to at least 5 minutes between heartbeat sweeps.
    let interval_mins = config.heartbeat.interval_minutes.max(5);
    let mut interval = tokio::time::interval(Duration::from_secs(u64::from(interval_mins) * 60));

    loop {
        interval.tick().await;

        let tasks = engine.collect_tasks().await?;
        if tasks.is_empty() {
            continue;
        }

        for task in tasks {
            let prompt = format!("[Heartbeat Task] {task}");
            let temp = config.default_temperature;
            if let Err(e) = crate::agent::run(config.clone(), Some(prompt), None, None, temp).await
            {
                crate::health::mark_component_error("heartbeat", e.to_string());
                tracing::warn!("Heartbeat task failed: {e}");
            } else {
                crate::health::mark_component_ok("heartbeat");
            }
        }
    }
}
|
||||||
|
|
||||||
|
fn has_supervised_channels(config: &Config) -> bool {
|
||||||
|
config.channels_config.telegram.is_some()
|
||||||
|
|| config.channels_config.discord.is_some()
|
||||||
|
|| config.channels_config.slack.is_some()
|
||||||
|
|| config.channels_config.imessage.is_some()
|
||||||
|
|| config.channels_config.matrix.is_some()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Build an isolated `Config` rooted in `tmp` with a fresh workspace.
    fn test_config(tmp: &TempDir) -> Config {
        let mut config = Config::default();
        config.workspace_dir = tmp.path().join("workspace");
        config.config_path = tmp.path().join("config.toml");
        std::fs::create_dir_all(&config.workspace_dir).unwrap();
        config
    }

    /// The state file lives next to the config file.
    #[test]
    fn state_file_path_uses_config_directory() {
        let tmp = TempDir::new().unwrap();
        let config = test_config(&tmp);

        let path = state_file_path(&config);
        assert_eq!(path, tmp.path().join("daemon_state.json"));
    }

    /// A component that returns Err must be marked "error", its message
    /// recorded, and its restart counter bumped.
    #[tokio::test]
    async fn supervisor_marks_error_and_restart_on_failure() {
        let handle = spawn_component_supervisor("daemon-test-fail", 1, 1, || async {
            anyhow::bail!("boom")
        });

        // Give the supervisor time for at least one run/restart cycle.
        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.abort();
        let _ = handle.await;

        let snapshot = crate::health::snapshot_json();
        let component = &snapshot["components"]["daemon-test-fail"];
        assert_eq!(component["status"], "error");
        assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
        assert!(component["last_error"]
            .as_str()
            .unwrap_or("")
            .contains("boom"));
    }

    /// A component that returns Ok(()) is still treated as an unexpected
    /// exit: marked "error" and restarted.
    #[tokio::test]
    async fn supervisor_marks_unexpected_exit_as_error() {
        let handle = spawn_component_supervisor("daemon-test-exit", 1, 1, || async { Ok(()) });

        tokio::time::sleep(Duration::from_millis(50)).await;
        handle.abort();
        let _ = handle.await;

        let snapshot = crate::health::snapshot_json();
        let component = &snapshot["components"]["daemon-test-exit"];
        assert_eq!(component["status"], "error");
        assert!(component["restart_count"].as_u64().unwrap_or(0) >= 1);
        assert!(component["last_error"]
            .as_str()
            .unwrap_or("")
            .contains("component exited unexpectedly"));
    }

    /// A default config has no long-polling channels configured.
    #[test]
    fn detects_no_supervised_channels() {
        let config = Config::default();
        assert!(!has_supervised_channels(&config));
    }

    /// Configuring any single channel (Telegram here) is enough to require
    /// the channel supervisor.
    #[test]
    fn detects_supervised_channels_present() {
        let mut config = Config::default();
        config.channels_config.telegram = Some(crate::config::TelegramConfig {
            bot_token: "token".into(),
            allowed_users: vec![],
        });
        assert!(has_supervised_channels(&config));
    }
}
|
||||||
123
src/doctor/mod.rs
Normal file
123
src/doctor/mod.rs
Normal file
|
|
@ -0,0 +1,123 @@
|
||||||
|
use crate::config::Config;
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
|
||||||
|
/// Max age (seconds) of the daemon's state-file timestamp before the daemon
/// heartbeat is reported as stale.
const DAEMON_STALE_SECONDS: i64 = 30;
/// Max age (seconds) of the scheduler's `last_ok` before it is reported stale.
const SCHEDULER_STALE_SECONDS: i64 = 120;
/// Max age (seconds) of a channel component's `last_ok` before it is reported stale.
const CHANNEL_STALE_SECONDS: i64 = 300;
|
||||||
|
|
||||||
|
/// Diagnose daemon health by reading the persisted state file: checks the
/// daemon heartbeat timestamp, the scheduler component, and every
/// `channel:`-prefixed component for freshness, printing a human-readable
/// report. A missing state file is reported (not an error) with a hint to
/// start the daemon.
pub fn run(config: &Config) -> Result<()> {
    let state_file = crate::daemon::state_file_path(config);
    if !state_file.exists() {
        println!("🩺 ZeroClaw Doctor");
        println!("   ❌ daemon state file not found: {}", state_file.display());
        println!("   💡 Start daemon with: zeroclaw daemon");
        return Ok(());
    }

    let raw = std::fs::read_to_string(&state_file)
        .with_context(|| format!("Failed to read {}", state_file.display()))?;
    let snapshot: serde_json::Value = serde_json::from_str(&raw)
        .with_context(|| format!("Failed to parse {}", state_file.display()))?;

    println!("🩺 ZeroClaw Doctor");
    println!("   State file: {}", state_file.display());

    // NOTE(review): this reads "updated_at", while the daemon state writer
    // visibly inserts "written_at" — presumably health::snapshot_json()
    // itself supplies "updated_at"; confirm the key, otherwise the daemon
    // heartbeat will always report an invalid timestamp.
    let updated_at = snapshot
        .get("updated_at")
        .and_then(serde_json::Value::as_str)
        .unwrap_or("");

    if let Ok(ts) = DateTime::parse_from_rfc3339(updated_at) {
        let age = Utc::now()
            .signed_duration_since(ts.with_timezone(&Utc))
            .num_seconds();
        if age <= DAEMON_STALE_SECONDS {
            println!("   ✅ daemon heartbeat fresh ({age}s ago)");
        } else {
            println!("   ❌ daemon heartbeat stale ({age}s ago)");
        }
    } else {
        println!("   ❌ invalid daemon timestamp: {updated_at}");
    }

    let mut channel_count = 0_u32;
    let mut stale_channels = 0_u32;

    if let Some(components) = snapshot
        .get("components")
        .and_then(serde_json::Value::as_object)
    {
        // Scheduler: healthy only when status is "ok" AND last_ok is recent.
        if let Some(scheduler) = components.get("scheduler") {
            let scheduler_ok = scheduler
                .get("status")
                .and_then(serde_json::Value::as_str)
                .map(|s| s == "ok")
                .unwrap_or(false);

            // Missing/unparseable last_ok degrades to i64::MAX, i.e. stale.
            let scheduler_last_ok = scheduler
                .get("last_ok")
                .and_then(serde_json::Value::as_str)
                .and_then(parse_rfc3339)
                .map(|dt| Utc::now().signed_duration_since(dt).num_seconds())
                .unwrap_or(i64::MAX);

            if scheduler_ok && scheduler_last_ok <= SCHEDULER_STALE_SECONDS {
                println!(
                    "   ✅ scheduler healthy (last ok {}s ago)",
                    scheduler_last_ok
                );
            } else {
                println!(
                    "   ❌ scheduler unhealthy/stale (status_ok={}, age={}s)",
                    scheduler_ok, scheduler_last_ok
                );
            }
        } else {
            println!("   ❌ scheduler component missing");
        }

        // Channels: every component named "channel:<name>" gets the same
        // status + freshness treatment, with an aggregate tally.
        for (name, component) in components {
            if !name.starts_with("channel:") {
                continue;
            }

            channel_count += 1;
            let status_ok = component
                .get("status")
                .and_then(serde_json::Value::as_str)
                .map(|s| s == "ok")
                .unwrap_or(false);
            let age = component
                .get("last_ok")
                .and_then(serde_json::Value::as_str)
                .and_then(parse_rfc3339)
                .map(|dt| Utc::now().signed_duration_since(dt).num_seconds())
                .unwrap_or(i64::MAX);

            if status_ok && age <= CHANNEL_STALE_SECONDS {
                println!("   ✅ {name} fresh (last ok {age}s ago)");
            } else {
                stale_channels += 1;
                println!("   ❌ {name} stale/unhealthy (status_ok={status_ok}, age={age}s)");
            }
        }
    }

    if channel_count == 0 {
        println!("   ℹ️  no channel components tracked in state yet");
    } else {
        println!(
            "   Channel summary: {} total, {} stale",
            channel_count, stale_channels
        );
    }

    Ok(())
}
|
||||||
|
|
||||||
|
fn parse_rfc3339(raw: &str) -> Option<DateTime<Utc>> {
|
||||||
|
DateTime::parse_from_rfc3339(raw)
|
||||||
|
.ok()
|
||||||
|
.map(|dt| dt.with_timezone(&Utc))
|
||||||
|
}
|
||||||
|
|
@ -1,15 +1,49 @@
|
||||||
|
//! Axum-based HTTP gateway with proper HTTP/1.1 compliance, body limits, and timeouts.
|
||||||
|
//!
|
||||||
|
//! This module replaces the raw TCP implementation with axum for:
|
||||||
|
//! - Proper HTTP/1.1 parsing and compliance
|
||||||
|
//! - Content-Length validation (handled by hyper)
|
||||||
|
//! - Request body size limits (64KB max)
|
||||||
|
//! - Request timeouts (30s) to prevent slow-loris attacks
|
||||||
|
//! - Header sanitization (handled by axum/hyper)
|
||||||
|
|
||||||
|
use crate::channels::{Channel, WhatsAppChannel};
|
||||||
use crate::config::Config;
|
use crate::config::Config;
|
||||||
use crate::memory::{self, Memory, MemoryCategory};
|
use crate::memory::{self, Memory, MemoryCategory};
|
||||||
use crate::providers::{self, Provider};
|
use crate::providers::{self, Provider};
|
||||||
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
|
use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard};
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use axum::{
|
||||||
|
body::Bytes,
|
||||||
|
extract::{Query, State},
|
||||||
|
http::{header, HeaderMap, StatusCode},
|
||||||
|
response::{IntoResponse, Json},
|
||||||
|
routing::{get, post},
|
||||||
|
Router,
|
||||||
|
};
|
||||||
|
use std::net::SocketAddr;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use std::time::Duration;
|
use tower_http::limit::RequestBodyLimitLayer;
|
||||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
|
||||||
use tokio::net::TcpListener;
|
|
||||||
|
|
||||||
/// Run a minimal HTTP gateway (webhook + health check)
|
/// Maximum request body size (64KB) — prevents memory exhaustion
|
||||||
/// Zero new dependencies — uses raw TCP + tokio.
|
pub const MAX_BODY_SIZE: usize = 65_536;
|
||||||
|
/// Request timeout (30s) — prevents slow-loris attacks
|
||||||
|
pub const REQUEST_TIMEOUT_SECS: u64 = 30;
|
||||||
|
|
||||||
|
/// Shared state for all axum handlers
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub provider: Arc<dyn Provider>,
|
||||||
|
pub model: String,
|
||||||
|
pub temperature: f64,
|
||||||
|
pub mem: Arc<dyn Memory>,
|
||||||
|
pub auto_save: bool,
|
||||||
|
pub webhook_secret: Option<Arc<str>>,
|
||||||
|
pub pairing: Arc<PairingGuard>,
|
||||||
|
pub whatsapp: Option<Arc<WhatsAppChannel>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run the HTTP gateway using axum with proper HTTP/1.1 compliance.
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
// ── Security: refuse public bind without tunnel or explicit opt-in ──
|
// ── Security: refuse public bind without tunnel or explicit opt-in ──
|
||||||
|
|
@ -22,13 +56,15 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
let listener = TcpListener::bind(format!("{host}:{port}")).await?;
|
let addr: SocketAddr = format!("{host}:{port}").parse()?;
|
||||||
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
let actual_port = listener.local_addr()?.port();
|
let actual_port = listener.local_addr()?.port();
|
||||||
let addr = format!("{host}:{actual_port}");
|
let display_addr = format!("{host}:{actual_port}");
|
||||||
|
|
||||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_provider(
|
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
|
&config.reliability,
|
||||||
)?);
|
)?);
|
||||||
let model = config
|
let model = config
|
||||||
.default_model
|
.default_model
|
||||||
|
|
@ -49,6 +85,17 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
.and_then(|w| w.secret.as_deref())
|
.and_then(|w| w.secret.as_deref())
|
||||||
.map(Arc::from);
|
.map(Arc::from);
|
||||||
|
|
||||||
|
// WhatsApp channel (if configured)
|
||||||
|
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
|
||||||
|
config.channels_config.whatsapp.as_ref().map(|wa| {
|
||||||
|
Arc::new(WhatsAppChannel::new(
|
||||||
|
wa.access_token.clone(),
|
||||||
|
wa.phone_number_id.clone(),
|
||||||
|
wa.verify_token.clone(),
|
||||||
|
wa.allowed_numbers.clone(),
|
||||||
|
))
|
||||||
|
});
|
||||||
|
|
||||||
// ── Pairing guard ──────────────────────────────────────
|
// ── Pairing guard ──────────────────────────────────────
|
||||||
let pairing = Arc::new(PairingGuard::new(
|
let pairing = Arc::new(PairingGuard::new(
|
||||||
config.gateway.require_pairing,
|
config.gateway.require_pairing,
|
||||||
|
|
@ -73,16 +120,20 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
println!("🦀 ZeroClaw Gateway listening on http://{addr}");
|
println!("🦀 ZeroClaw Gateway listening on http://{display_addr}");
|
||||||
if let Some(ref url) = tunnel_url {
|
if let Some(ref url) = tunnel_url {
|
||||||
println!(" 🌐 Public URL: {url}");
|
println!(" 🌐 Public URL: {url}");
|
||||||
}
|
}
|
||||||
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
|
println!(" POST /pair — pair a new client (X-Pairing-Code header)");
|
||||||
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
println!(" POST /webhook — {{\"message\": \"your prompt\"}}");
|
||||||
println!(" GET /health — health check");
|
if whatsapp_channel.is_some() {
|
||||||
|
println!(" GET /whatsapp — Meta webhook verification");
|
||||||
|
println!(" POST /whatsapp — WhatsApp message webhook");
|
||||||
|
}
|
||||||
|
println!(" GET /health — health check");
|
||||||
if let Some(code) = pairing.pairing_code() {
|
if let Some(code) = pairing.pairing_code() {
|
||||||
println!();
|
println!();
|
||||||
println!(" <20> PAIRING REQUIRED — use this one-time code:");
|
println!(" 🔐 PAIRING REQUIRED — use this one-time code:");
|
||||||
println!(" ┌──────────────┐");
|
println!(" ┌──────────────┐");
|
||||||
println!(" │ {code} │");
|
println!(" │ {code} │");
|
||||||
println!(" └──────────────┘");
|
println!(" └──────────────┘");
|
||||||
|
|
@ -97,428 +148,312 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
}
|
}
|
||||||
println!(" Press Ctrl+C to stop.\n");
|
println!(" Press Ctrl+C to stop.\n");
|
||||||
|
|
||||||
loop {
|
crate::health::mark_component_ok("gateway");
|
||||||
let (mut stream, peer) = listener.accept().await?;
|
|
||||||
let provider = provider.clone();
|
|
||||||
let model = model.clone();
|
|
||||||
let mem = mem.clone();
|
|
||||||
let auto_save = config.memory.auto_save;
|
|
||||||
let secret = webhook_secret.clone();
|
|
||||||
let pairing = pairing.clone();
|
|
||||||
|
|
||||||
tokio::spawn(async move {
|
// Build shared state
|
||||||
// Read with 30s timeout to prevent slow-loris attacks
|
let state = AppState {
|
||||||
let mut buf = vec![0u8; 65_536]; // 64KB max request
|
provider,
|
||||||
let n = match tokio::time::timeout(Duration::from_secs(30), stream.read(&mut buf)).await
|
model,
|
||||||
{
|
temperature,
|
||||||
Ok(Ok(n)) if n > 0 => n,
|
mem,
|
||||||
_ => return,
|
auto_save: config.memory.auto_save,
|
||||||
};
|
webhook_secret,
|
||||||
|
pairing,
|
||||||
|
whatsapp: whatsapp_channel,
|
||||||
|
};
|
||||||
|
|
||||||
let request = String::from_utf8_lossy(&buf[..n]);
|
// Build router with middleware
|
||||||
let first_line = request.lines().next().unwrap_or("");
|
// Note: Body limit layer prevents memory exhaustion from oversized requests
|
||||||
let parts: Vec<&str> = first_line.split_whitespace().collect();
|
// Timeout is handled by tokio's TcpListener accept timeout and hyper's built-in timeouts
|
||||||
|
let app = Router::new()
|
||||||
|
.route("/health", get(handle_health))
|
||||||
|
.route("/pair", post(handle_pair))
|
||||||
|
.route("/webhook", post(handle_webhook))
|
||||||
|
.route("/whatsapp", get(handle_whatsapp_verify))
|
||||||
|
.route("/whatsapp", post(handle_whatsapp_message))
|
||||||
|
.with_state(state)
|
||||||
|
.layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE));
|
||||||
|
|
||||||
if let [method, path, ..] = parts.as_slice() {
|
// Run the server
|
||||||
tracing::info!("{peer} → {method} {path}");
|
axum::serve(listener, app).await?;
|
||||||
handle_request(
|
|
||||||
&mut stream,
|
Ok(())
|
||||||
method,
|
|
||||||
path,
|
|
||||||
&request,
|
|
||||||
&provider,
|
|
||||||
&model,
|
|
||||||
temperature,
|
|
||||||
&mem,
|
|
||||||
auto_save,
|
|
||||||
secret.as_ref(),
|
|
||||||
&pairing,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
} else {
|
|
||||||
let _ = send_response(&mut stream, 400, "Bad Request").await;
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Extract a header value from a raw HTTP request.
|
// ══════════════════════════════════════════════════════════════════════════════
|
||||||
fn extract_header<'a>(request: &'a str, header_name: &str) -> Option<&'a str> {
|
// AXUM HANDLERS
|
||||||
let lower_name = header_name.to_lowercase();
|
// ══════════════════════════════════════════════════════════════════════════════
|
||||||
for line in request.lines() {
|
|
||||||
if let Some((key, value)) = line.split_once(':') {
|
/// GET /health — always public (no secrets leaked)
|
||||||
if key.trim().to_lowercase() == lower_name {
|
async fn handle_health(State(state): State<AppState>) -> impl IntoResponse {
|
||||||
return Some(value.trim());
|
let body = serde_json::json!({
|
||||||
}
|
"status": "ok",
|
||||||
}
|
"paired": state.pairing.is_paired(),
|
||||||
}
|
"runtime": crate::health::snapshot_json(),
|
||||||
None
|
});
|
||||||
|
Json(body)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_arguments)]
|
/// POST /pair — exchange one-time code for bearer token
|
||||||
async fn handle_request(
|
async fn handle_pair(State(state): State<AppState>, headers: HeaderMap) -> impl IntoResponse {
|
||||||
stream: &mut tokio::net::TcpStream,
|
let code = headers
|
||||||
method: &str,
|
.get("X-Pairing-Code")
|
||||||
path: &str,
|
.and_then(|v| v.to_str().ok())
|
||||||
request: &str,
|
|
||||||
provider: &Arc<dyn Provider>,
|
|
||||||
model: &str,
|
|
||||||
temperature: f64,
|
|
||||||
mem: &Arc<dyn Memory>,
|
|
||||||
auto_save: bool,
|
|
||||||
webhook_secret: Option<&Arc<str>>,
|
|
||||||
pairing: &PairingGuard,
|
|
||||||
) {
|
|
||||||
match (method, path) {
|
|
||||||
// Health check — always public (no secrets leaked)
|
|
||||||
("GET", "/health") => {
|
|
||||||
let body = serde_json::json!({
|
|
||||||
"status": "ok",
|
|
||||||
"paired": pairing.is_paired(),
|
|
||||||
});
|
|
||||||
let _ = send_json(stream, 200, &body).await;
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pairing endpoint — exchange one-time code for bearer token
|
|
||||||
("POST", "/pair") => {
|
|
||||||
let code = extract_header(request, "X-Pairing-Code").unwrap_or("");
|
|
||||||
match pairing.try_pair(code) {
|
|
||||||
Ok(Some(token)) => {
|
|
||||||
tracing::info!("🔐 New client paired successfully");
|
|
||||||
let body = serde_json::json!({
|
|
||||||
"paired": true,
|
|
||||||
"token": token,
|
|
||||||
"message": "Save this token — use it as Authorization: Bearer <token>"
|
|
||||||
});
|
|
||||||
let _ = send_json(stream, 200, &body).await;
|
|
||||||
}
|
|
||||||
Ok(None) => {
|
|
||||||
tracing::warn!("🔐 Pairing attempt with invalid code");
|
|
||||||
let err = serde_json::json!({"error": "Invalid pairing code"});
|
|
||||||
let _ = send_json(stream, 403, &err).await;
|
|
||||||
}
|
|
||||||
Err(lockout_secs) => {
|
|
||||||
tracing::warn!(
|
|
||||||
"🔐 Pairing locked out — too many failed attempts ({lockout_secs}s remaining)"
|
|
||||||
);
|
|
||||||
let err = serde_json::json!({
|
|
||||||
"error": format!("Too many failed attempts. Try again in {lockout_secs}s."),
|
|
||||||
"retry_after": lockout_secs
|
|
||||||
});
|
|
||||||
let _ = send_json(stream, 429, &err).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
("POST", "/webhook") => {
|
|
||||||
// ── Bearer token auth (pairing) ──
|
|
||||||
if pairing.require_pairing() {
|
|
||||||
let auth = extract_header(request, "Authorization").unwrap_or("");
|
|
||||||
let token = auth.strip_prefix("Bearer ").unwrap_or("");
|
|
||||||
if !pairing.is_authenticated(token) {
|
|
||||||
tracing::warn!("Webhook: rejected — not paired / invalid bearer token");
|
|
||||||
let err = serde_json::json!({
|
|
||||||
"error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer <token>"
|
|
||||||
});
|
|
||||||
let _ = send_json(stream, 401, &err).await;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── Webhook secret auth (optional, additional layer) ──
|
|
||||||
if let Some(secret) = webhook_secret {
|
|
||||||
let header_val = extract_header(request, "X-Webhook-Secret");
|
|
||||||
match header_val {
|
|
||||||
Some(val) if constant_time_eq(val, secret.as_ref()) => {}
|
|
||||||
_ => {
|
|
||||||
tracing::warn!(
|
|
||||||
"Webhook: rejected request — invalid or missing X-Webhook-Secret"
|
|
||||||
);
|
|
||||||
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
|
||||||
let _ = send_json(stream, 401, &err).await;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
handle_webhook(
|
|
||||||
stream,
|
|
||||||
request,
|
|
||||||
provider,
|
|
||||||
model,
|
|
||||||
temperature,
|
|
||||||
mem,
|
|
||||||
auto_save,
|
|
||||||
)
|
|
||||||
.await;
|
|
||||||
}
|
|
||||||
|
|
||||||
_ => {
|
|
||||||
let body = serde_json::json!({
|
|
||||||
"error": "Not found",
|
|
||||||
"routes": ["GET /health", "POST /pair", "POST /webhook"]
|
|
||||||
});
|
|
||||||
let _ = send_json(stream, 404, &body).await;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
async fn handle_webhook(
|
|
||||||
stream: &mut tokio::net::TcpStream,
|
|
||||||
request: &str,
|
|
||||||
provider: &Arc<dyn Provider>,
|
|
||||||
model: &str,
|
|
||||||
temperature: f64,
|
|
||||||
mem: &Arc<dyn Memory>,
|
|
||||||
auto_save: bool,
|
|
||||||
) {
|
|
||||||
let body_str = request
|
|
||||||
.split("\r\n\r\n")
|
|
||||||
.nth(1)
|
|
||||||
.or_else(|| request.split("\n\n").nth(1))
|
|
||||||
.unwrap_or("");
|
.unwrap_or("");
|
||||||
|
|
||||||
let Ok(parsed) = serde_json::from_str::<serde_json::Value>(body_str) else {
|
match state.pairing.try_pair(code) {
|
||||||
let err = serde_json::json!({"error": "Invalid JSON. Expected: {\"message\": \"...\"}"});
|
Ok(Some(token)) => {
|
||||||
let _ = send_json(stream, 400, &err).await;
|
tracing::info!("🔐 New client paired successfully");
|
||||||
return;
|
let body = serde_json::json!({
|
||||||
|
"paired": true,
|
||||||
|
"token": token,
|
||||||
|
"message": "Save this token — use it as Authorization: Bearer <token>"
|
||||||
|
});
|
||||||
|
(StatusCode::OK, Json(body))
|
||||||
|
}
|
||||||
|
Ok(None) => {
|
||||||
|
tracing::warn!("🔐 Pairing attempt with invalid code");
|
||||||
|
let err = serde_json::json!({"error": "Invalid pairing code"});
|
||||||
|
(StatusCode::FORBIDDEN, Json(err))
|
||||||
|
}
|
||||||
|
Err(lockout_secs) => {
|
||||||
|
tracing::warn!(
|
||||||
|
"🔐 Pairing locked out — too many failed attempts ({lockout_secs}s remaining)"
|
||||||
|
);
|
||||||
|
let err = serde_json::json!({
|
||||||
|
"error": format!("Too many failed attempts. Try again in {lockout_secs}s."),
|
||||||
|
"retry_after": lockout_secs
|
||||||
|
});
|
||||||
|
(StatusCode::TOO_MANY_REQUESTS, Json(err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Webhook request body
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
pub struct WebhookBody {
|
||||||
|
pub message: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /webhook — main webhook endpoint
|
||||||
|
async fn handle_webhook(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
headers: HeaderMap,
|
||||||
|
body: Result<Json<WebhookBody>, axum::extract::rejection::JsonRejection>,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
// ── Bearer token auth (pairing) ──
|
||||||
|
if state.pairing.require_pairing() {
|
||||||
|
let auth = headers
|
||||||
|
.get(header::AUTHORIZATION)
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.unwrap_or("");
|
||||||
|
let token = auth.strip_prefix("Bearer ").unwrap_or("");
|
||||||
|
if !state.pairing.is_authenticated(token) {
|
||||||
|
tracing::warn!("Webhook: rejected — not paired / invalid bearer token");
|
||||||
|
let err = serde_json::json!({
|
||||||
|
"error": "Unauthorized — pair first via POST /pair, then send Authorization: Bearer <token>"
|
||||||
|
});
|
||||||
|
return (StatusCode::UNAUTHORIZED, Json(err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Webhook secret auth (optional, additional layer) ──
|
||||||
|
if let Some(ref secret) = state.webhook_secret {
|
||||||
|
let header_val = headers
|
||||||
|
.get("X-Webhook-Secret")
|
||||||
|
.and_then(|v| v.to_str().ok());
|
||||||
|
match header_val {
|
||||||
|
Some(val) if constant_time_eq(val, secret.as_ref()) => {}
|
||||||
|
_ => {
|
||||||
|
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
||||||
|
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
||||||
|
return (StatusCode::UNAUTHORIZED, Json(err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Parse body ──
|
||||||
|
let Json(webhook_body) = match body {
|
||||||
|
Ok(b) => b,
|
||||||
|
Err(e) => {
|
||||||
|
let err = serde_json::json!({
|
||||||
|
"error": format!("Invalid JSON: {e}. Expected: {{\"message\": \"...\"}}")
|
||||||
|
});
|
||||||
|
return (StatusCode::BAD_REQUEST, Json(err));
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(message) = parsed.get("message").and_then(|v| v.as_str()) else {
|
let message = &webhook_body.message;
|
||||||
let err = serde_json::json!({"error": "Missing 'message' field in JSON"});
|
|
||||||
let _ = send_json(stream, 400, &err).await;
|
|
||||||
return;
|
|
||||||
};
|
|
||||||
|
|
||||||
if auto_save {
|
if state.auto_save {
|
||||||
let _ = mem
|
let _ = state
|
||||||
|
.mem
|
||||||
.store("webhook_msg", message, MemoryCategory::Conversation)
|
.store("webhook_msg", message, MemoryCategory::Conversation)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
match provider.chat(message, model, temperature).await {
|
match state
|
||||||
|
.provider
|
||||||
|
.chat(message, &state.model, state.temperature)
|
||||||
|
.await
|
||||||
|
{
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
let body = serde_json::json!({"response": response, "model": model});
|
let body = serde_json::json!({"response": response, "model": state.model});
|
||||||
let _ = send_json(stream, 200, &body).await;
|
(StatusCode::OK, Json(body))
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
let err = serde_json::json!({"error": format!("LLM error: {e}")});
|
let err = serde_json::json!({"error": format!("LLM error: {e}")});
|
||||||
let _ = send_json(stream, 500, &err).await;
|
(StatusCode::INTERNAL_SERVER_ERROR, Json(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_response(
|
/// `WhatsApp` verification query params
|
||||||
stream: &mut tokio::net::TcpStream,
|
#[derive(serde::Deserialize)]
|
||||||
status: u16,
|
pub struct WhatsAppVerifyQuery {
|
||||||
body: &str,
|
#[serde(rename = "hub.mode")]
|
||||||
) -> std::io::Result<()> {
|
pub mode: Option<String>,
|
||||||
let reason = match status {
|
#[serde(rename = "hub.verify_token")]
|
||||||
200 => "OK",
|
pub verify_token: Option<String>,
|
||||||
400 => "Bad Request",
|
#[serde(rename = "hub.challenge")]
|
||||||
404 => "Not Found",
|
pub challenge: Option<String>,
|
||||||
500 => "Internal Server Error",
|
|
||||||
_ => "Unknown",
|
|
||||||
};
|
|
||||||
let response = format!(
|
|
||||||
"HTTP/1.1 {status} {reason}\r\nContent-Type: text/plain\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
|
|
||||||
body.len()
|
|
||||||
);
|
|
||||||
stream.write_all(response.as_bytes()).await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_json(
|
/// GET /whatsapp — Meta webhook verification
|
||||||
stream: &mut tokio::net::TcpStream,
|
async fn handle_whatsapp_verify(
|
||||||
status: u16,
|
State(state): State<AppState>,
|
||||||
body: &serde_json::Value,
|
Query(params): Query<WhatsAppVerifyQuery>,
|
||||||
) -> std::io::Result<()> {
|
) -> impl IntoResponse {
|
||||||
let reason = match status {
|
let Some(ref wa) = state.whatsapp else {
|
||||||
200 => "OK",
|
return (StatusCode::NOT_FOUND, "WhatsApp not configured".to_string());
|
||||||
400 => "Bad Request",
|
|
||||||
404 => "Not Found",
|
|
||||||
500 => "Internal Server Error",
|
|
||||||
_ => "Unknown",
|
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(body).unwrap_or_default();
|
|
||||||
let response = format!(
|
// Verify the token matches
|
||||||
"HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{json}",
|
if params.mode.as_deref() == Some("subscribe")
|
||||||
json.len()
|
&& params.verify_token.as_deref() == Some(wa.verify_token())
|
||||||
);
|
{
|
||||||
stream.write_all(response.as_bytes()).await
|
if let Some(ch) = params.challenge {
|
||||||
|
tracing::info!("WhatsApp webhook verified successfully");
|
||||||
|
return (StatusCode::OK, ch);
|
||||||
|
}
|
||||||
|
return (StatusCode::BAD_REQUEST, "Missing hub.challenge".to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!("WhatsApp webhook verification failed — token mismatch");
|
||||||
|
(StatusCode::FORBIDDEN, "Forbidden".to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /whatsapp — incoming message webhook
|
||||||
|
async fn handle_whatsapp_message(State(state): State<AppState>, body: Bytes) -> impl IntoResponse {
|
||||||
|
let Some(ref wa) = state.whatsapp else {
|
||||||
|
return (
|
||||||
|
StatusCode::NOT_FOUND,
|
||||||
|
Json(serde_json::json!({"error": "WhatsApp not configured"})),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse JSON body
|
||||||
|
let Ok(payload) = serde_json::from_slice::<serde_json::Value>(&body) else {
|
||||||
|
return (
|
||||||
|
StatusCode::BAD_REQUEST,
|
||||||
|
Json(serde_json::json!({"error": "Invalid JSON payload"})),
|
||||||
|
);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Parse messages from the webhook payload
|
||||||
|
let messages = wa.parse_webhook_payload(&payload);
|
||||||
|
|
||||||
|
if messages.is_empty() {
|
||||||
|
// Acknowledge the webhook even if no messages (could be status updates)
|
||||||
|
return (StatusCode::OK, Json(serde_json::json!({"status": "ok"})));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Process each message
|
||||||
|
for msg in &messages {
|
||||||
|
tracing::info!(
|
||||||
|
"WhatsApp message from {}: {}",
|
||||||
|
msg.sender,
|
||||||
|
if msg.content.len() > 50 {
|
||||||
|
format!("{}...", &msg.content[..50])
|
||||||
|
} else {
|
||||||
|
msg.content.clone()
|
||||||
|
}
|
||||||
|
);
|
||||||
|
|
||||||
|
// Auto-save to memory
|
||||||
|
if state.auto_save {
|
||||||
|
let _ = state
|
||||||
|
.mem
|
||||||
|
.store(
|
||||||
|
&format!("whatsapp_{}", msg.sender),
|
||||||
|
&msg.content,
|
||||||
|
MemoryCategory::Conversation,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Call the LLM
|
||||||
|
match state
|
||||||
|
.provider
|
||||||
|
.chat(&msg.content, &state.model, state.temperature)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(response) => {
|
||||||
|
// Send reply via WhatsApp
|
||||||
|
if let Err(e) = wa.send(&response, &msg.sender).await {
|
||||||
|
tracing::error!("Failed to send WhatsApp reply: {e}");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
tracing::error!("LLM error for WhatsApp message: {e}");
|
||||||
|
let _ = wa.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Acknowledge the webhook
|
||||||
|
(StatusCode::OK, Json(serde_json::json!({"status": "ok"})))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
use tokio::net::TcpListener as TokioListener;
|
|
||||||
|
|
||||||
// ── Port allocation tests ────────────────────────────────
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn port_zero_binds_to_random_port() {
|
|
||||||
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let actual = listener.local_addr().unwrap().port();
|
|
||||||
assert_ne!(actual, 0, "OS must assign a non-zero port");
|
|
||||||
assert!(actual > 0, "Actual port must be positive");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn port_zero_assigns_different_ports() {
|
|
||||||
let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let l2 = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let p1 = l1.local_addr().unwrap().port();
|
|
||||||
let p2 = l2.local_addr().unwrap().port();
|
|
||||||
assert_ne!(p1, p2, "Two port-0 binds should get different ports");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn port_zero_assigns_high_port() {
|
|
||||||
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let actual = listener.local_addr().unwrap().port();
|
|
||||||
// OS typically assigns ephemeral ports >= 1024
|
|
||||||
assert!(
|
|
||||||
actual >= 1024,
|
|
||||||
"Random port {actual} should be >= 1024 (unprivileged)"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn specific_port_binds_exactly() {
|
|
||||||
// Find a free port first via port 0, then rebind to it
|
|
||||||
let tmp = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let free_port = tmp.local_addr().unwrap().port();
|
|
||||||
drop(tmp);
|
|
||||||
|
|
||||||
let listener = TokioListener::bind(format!("127.0.0.1:{free_port}"))
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let actual = listener.local_addr().unwrap().port();
|
|
||||||
assert_eq!(actual, free_port, "Specific port bind must match exactly");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn actual_port_matches_addr_format() {
|
|
||||||
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let actual_port = listener.local_addr().unwrap().port();
|
|
||||||
let addr = format!("127.0.0.1:{actual_port}");
|
|
||||||
assert!(
|
|
||||||
addr.starts_with("127.0.0.1:"),
|
|
||||||
"Addr format must include host"
|
|
||||||
);
|
|
||||||
assert!(
|
|
||||||
!addr.ends_with(":0"),
|
|
||||||
"Addr must not contain port 0 after binding"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn port_zero_listener_accepts_connections() {
|
|
||||||
let listener = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let actual_port = listener.local_addr().unwrap().port();
|
|
||||||
|
|
||||||
// Spawn a client that connects
|
|
||||||
let client = tokio::spawn(async move {
|
|
||||||
tokio::net::TcpStream::connect(format!("127.0.0.1:{actual_port}"))
|
|
||||||
.await
|
|
||||||
.unwrap()
|
|
||||||
});
|
|
||||||
|
|
||||||
// Accept the connection
|
|
||||||
let (stream, _peer) = listener.accept().await.unwrap();
|
|
||||||
assert!(stream.peer_addr().is_ok());
|
|
||||||
client.await.unwrap();
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn duplicate_specific_port_fails() {
|
|
||||||
let l1 = TokioListener::bind("127.0.0.1:0").await.unwrap();
|
|
||||||
let port = l1.local_addr().unwrap().port();
|
|
||||||
// Try to bind the same port while l1 is still alive
|
|
||||||
let result = TokioListener::bind(format!("127.0.0.1:{port}")).await;
|
|
||||||
assert!(result.is_err(), "Binding an already-used port must fail");
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tokio::test]
|
|
||||||
async fn tunnel_gets_actual_port_not_zero() {
|
|
||||||
// Simulate what run_gateway does: bind port 0, extract actual port
|
|
||||||
let port: u16 = 0;
|
|
||||||
let host = "127.0.0.1";
|
|
||||||
let listener = TokioListener::bind(format!("{host}:{port}")).await.unwrap();
|
|
||||||
let actual_port = listener.local_addr().unwrap().port();
|
|
||||||
|
|
||||||
// This is the port that would be passed to tun.start(host, actual_port)
|
|
||||||
assert_ne!(actual_port, 0, "Tunnel must receive actual port, not 0");
|
|
||||||
assert!(
|
|
||||||
actual_port >= 1024,
|
|
||||||
"Tunnel port {actual_port} must be unprivileged"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
// ── extract_header tests ─────────────────────────────────
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_header_finds_value() {
|
fn security_body_limit_is_64kb() {
|
||||||
let req =
|
assert_eq!(MAX_BODY_SIZE, 65_536);
|
||||||
"POST /webhook HTTP/1.1\r\nHost: localhost\r\nX-Webhook-Secret: my-secret\r\n\r\n{}";
|
|
||||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("my-secret"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_header_case_insensitive() {
|
fn security_timeout_is_30_seconds() {
|
||||||
let req = "POST /webhook HTTP/1.1\r\nx-webhook-secret: abc123\r\n\r\n{}";
|
assert_eq!(REQUEST_TIMEOUT_SECS, 30);
|
||||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("abc123"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_header_missing_returns_none() {
|
fn webhook_body_requires_message_field() {
|
||||||
let req = "POST /webhook HTTP/1.1\r\nHost: localhost\r\n\r\n{}";
|
let valid = r#"{"message": "hello"}"#;
|
||||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), None);
|
let parsed: Result<WebhookBody, _> = serde_json::from_str(valid);
|
||||||
|
assert!(parsed.is_ok());
|
||||||
|
assert_eq!(parsed.unwrap().message, "hello");
|
||||||
|
|
||||||
|
let missing = r#"{"other": "field"}"#;
|
||||||
|
let parsed: Result<WebhookBody, _> = serde_json::from_str(missing);
|
||||||
|
assert!(parsed.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_header_trims_whitespace() {
|
fn whatsapp_query_fields_are_optional() {
|
||||||
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: spaced \r\n\r\n{}";
|
let q = WhatsAppVerifyQuery {
|
||||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("spaced"));
|
mode: None,
|
||||||
|
verify_token: None,
|
||||||
|
challenge: None,
|
||||||
|
};
|
||||||
|
assert!(q.mode.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn extract_header_first_match_wins() {
|
fn app_state_is_clone() {
|
||||||
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret: first\r\nX-Webhook-Secret: second\r\n\r\n{}";
|
fn assert_clone<T: Clone>() {}
|
||||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("first"));
|
assert_clone::<AppState>();
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn extract_header_empty_value() {
|
|
||||||
let req = "POST /webhook HTTP/1.1\r\nX-Webhook-Secret:\r\n\r\n{}";
|
|
||||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some(""));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn extract_header_colon_in_value() {
|
|
||||||
let req = "POST /webhook HTTP/1.1\r\nAuthorization: Bearer sk-abc:123\r\n\r\n{}";
|
|
||||||
// split_once on ':' means only the first colon splits key/value
|
|
||||||
assert_eq!(
|
|
||||||
extract_header(req, "Authorization"),
|
|
||||||
Some("Bearer sk-abc:123")
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn extract_header_different_header() {
|
|
||||||
let req = "POST /webhook HTTP/1.1\r\nContent-Type: application/json\r\nX-Webhook-Secret: mysecret\r\n\r\n{}";
|
|
||||||
assert_eq!(
|
|
||||||
extract_header(req, "Content-Type"),
|
|
||||||
Some("application/json")
|
|
||||||
);
|
|
||||||
assert_eq!(extract_header(req, "X-Webhook-Secret"), Some("mysecret"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn extract_header_from_empty_request() {
|
|
||||||
assert_eq!(extract_header("", "X-Webhook-Secret"), None);
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn extract_header_newline_only_request() {
|
|
||||||
assert_eq!(extract_header("\r\n\r\n", "X-Webhook-Secret"), None);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
105
src/health/mod.rs
Normal file
105
src/health/mod.rs
Normal file
|
|
@ -0,0 +1,105 @@
|
||||||
|
use chrono::Utc;
|
||||||
|
use serde::Serialize;
|
||||||
|
use std::collections::BTreeMap;
|
||||||
|
use std::sync::{Mutex, OnceLock};
|
||||||
|
use std::time::Instant;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct ComponentHealth {
|
||||||
|
pub status: String,
|
||||||
|
pub updated_at: String,
|
||||||
|
pub last_ok: Option<String>,
|
||||||
|
pub last_error: Option<String>,
|
||||||
|
pub restart_count: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize)]
|
||||||
|
pub struct HealthSnapshot {
|
||||||
|
pub pid: u32,
|
||||||
|
pub updated_at: String,
|
||||||
|
pub uptime_seconds: u64,
|
||||||
|
pub components: BTreeMap<String, ComponentHealth>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct HealthRegistry {
|
||||||
|
started_at: Instant,
|
||||||
|
components: Mutex<BTreeMap<String, ComponentHealth>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
static REGISTRY: OnceLock<HealthRegistry> = OnceLock::new();
|
||||||
|
|
||||||
|
fn registry() -> &'static HealthRegistry {
|
||||||
|
REGISTRY.get_or_init(|| HealthRegistry {
|
||||||
|
started_at: Instant::now(),
|
||||||
|
components: Mutex::new(BTreeMap::new()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn now_rfc3339() -> String {
|
||||||
|
Utc::now().to_rfc3339()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn upsert_component<F>(component: &str, update: F)
|
||||||
|
where
|
||||||
|
F: FnOnce(&mut ComponentHealth),
|
||||||
|
{
|
||||||
|
if let Ok(mut map) = registry().components.lock() {
|
||||||
|
let now = now_rfc3339();
|
||||||
|
let entry = map
|
||||||
|
.entry(component.to_string())
|
||||||
|
.or_insert_with(|| ComponentHealth {
|
||||||
|
status: "starting".into(),
|
||||||
|
updated_at: now.clone(),
|
||||||
|
last_ok: None,
|
||||||
|
last_error: None,
|
||||||
|
restart_count: 0,
|
||||||
|
});
|
||||||
|
update(entry);
|
||||||
|
entry.updated_at = now;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mark_component_ok(component: &str) {
|
||||||
|
upsert_component(component, |entry| {
|
||||||
|
entry.status = "ok".into();
|
||||||
|
entry.last_ok = Some(now_rfc3339());
|
||||||
|
entry.last_error = None;
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn mark_component_error(component: &str, error: impl ToString) {
|
||||||
|
let err = error.to_string();
|
||||||
|
upsert_component(component, move |entry| {
|
||||||
|
entry.status = "error".into();
|
||||||
|
entry.last_error = Some(err);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn bump_component_restart(component: &str) {
|
||||||
|
upsert_component(component, |entry| {
|
||||||
|
entry.restart_count = entry.restart_count.saturating_add(1);
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn snapshot() -> HealthSnapshot {
|
||||||
|
let components = registry()
|
||||||
|
.components
|
||||||
|
.lock()
|
||||||
|
.map_or_else(|_| BTreeMap::new(), |map| map.clone());
|
||||||
|
|
||||||
|
HealthSnapshot {
|
||||||
|
pid: std::process::id(),
|
||||||
|
updated_at: now_rfc3339(),
|
||||||
|
uptime_seconds: registry().started_at.elapsed().as_secs(),
|
||||||
|
components,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn snapshot_json() -> serde_json::Value {
|
||||||
|
serde_json::to_value(snapshot()).unwrap_or_else(|_| {
|
||||||
|
serde_json::json!({
|
||||||
|
"status": "error",
|
||||||
|
"message": "failed to serialize health snapshot"
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
@ -61,16 +61,17 @@ impl HeartbeatEngine {
|
||||||
|
|
||||||
/// Single heartbeat tick — read HEARTBEAT.md and return task count
|
/// Single heartbeat tick — read HEARTBEAT.md and return task count
|
||||||
async fn tick(&self) -> Result<usize> {
|
async fn tick(&self) -> Result<usize> {
|
||||||
|
Ok(self.collect_tasks().await?.len())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Read HEARTBEAT.md and return all parsed tasks.
|
||||||
|
pub async fn collect_tasks(&self) -> Result<Vec<String>> {
|
||||||
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md");
|
let heartbeat_path = self.workspace_dir.join("HEARTBEAT.md");
|
||||||
|
|
||||||
if !heartbeat_path.exists() {
|
if !heartbeat_path.exists() {
|
||||||
return Ok(0);
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
||||||
let content = tokio::fs::read_to_string(&heartbeat_path).await?;
|
let content = tokio::fs::read_to_string(&heartbeat_path).await?;
|
||||||
let tasks = Self::parse_tasks(&content);
|
Ok(Self::parse_tasks(&content))
|
||||||
|
|
||||||
Ok(tasks.len())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Parse tasks from HEARTBEAT.md (lines starting with `- `)
|
/// Parse tasks from HEARTBEAT.md (lines starting with `- `)
|
||||||
|
|
|
||||||
109
src/main.rs
109
src/main.rs
|
|
@ -8,7 +8,7 @@
|
||||||
dead_code
|
dead_code
|
||||||
)]
|
)]
|
||||||
|
|
||||||
use anyhow::Result;
|
use anyhow::{bail, Result};
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
use tracing::{info, Level};
|
use tracing::{info, Level};
|
||||||
use tracing_subscriber::FmtSubscriber;
|
use tracing_subscriber::FmtSubscriber;
|
||||||
|
|
@ -17,15 +17,20 @@ mod agent;
|
||||||
mod channels;
|
mod channels;
|
||||||
mod config;
|
mod config;
|
||||||
mod cron;
|
mod cron;
|
||||||
|
mod daemon;
|
||||||
|
mod doctor;
|
||||||
mod gateway;
|
mod gateway;
|
||||||
|
mod health;
|
||||||
mod heartbeat;
|
mod heartbeat;
|
||||||
mod integrations;
|
mod integrations;
|
||||||
mod memory;
|
mod memory;
|
||||||
|
mod migration;
|
||||||
mod observability;
|
mod observability;
|
||||||
mod onboard;
|
mod onboard;
|
||||||
mod providers;
|
mod providers;
|
||||||
mod runtime;
|
mod runtime;
|
||||||
mod security;
|
mod security;
|
||||||
|
mod service;
|
||||||
mod skills;
|
mod skills;
|
||||||
mod tools;
|
mod tools;
|
||||||
mod tunnel;
|
mod tunnel;
|
||||||
|
|
@ -43,6 +48,20 @@ struct Cli {
|
||||||
command: Commands,
|
command: Commands,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand, Debug)]
|
||||||
|
enum ServiceCommands {
|
||||||
|
/// Install daemon service unit for auto-start and restart
|
||||||
|
Install,
|
||||||
|
/// Start daemon service
|
||||||
|
Start,
|
||||||
|
/// Stop daemon service
|
||||||
|
Stop,
|
||||||
|
/// Check daemon service status
|
||||||
|
Status,
|
||||||
|
/// Uninstall daemon service unit
|
||||||
|
Uninstall,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Subcommand, Debug)]
|
#[derive(Subcommand, Debug)]
|
||||||
enum Commands {
|
enum Commands {
|
||||||
/// Initialize your workspace and configuration
|
/// Initialize your workspace and configuration
|
||||||
|
|
@ -51,6 +70,10 @@ enum Commands {
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
interactive: bool,
|
interactive: bool,
|
||||||
|
|
||||||
|
/// Reconfigure channels only (fast repair flow)
|
||||||
|
#[arg(long)]
|
||||||
|
channels_only: bool,
|
||||||
|
|
||||||
/// API key (used in quick mode, ignored with --interactive)
|
/// API key (used in quick mode, ignored with --interactive)
|
||||||
#[arg(long)]
|
#[arg(long)]
|
||||||
api_key: Option<String>,
|
api_key: Option<String>,
|
||||||
|
|
@ -71,7 +94,7 @@ enum Commands {
|
||||||
provider: Option<String>,
|
provider: Option<String>,
|
||||||
|
|
||||||
/// Model to use
|
/// Model to use
|
||||||
#[arg(short, long)]
|
#[arg(long)]
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
|
|
||||||
/// Temperature (0.0 - 2.0)
|
/// Temperature (0.0 - 2.0)
|
||||||
|
|
@ -86,10 +109,30 @@ enum Commands {
|
||||||
port: u16,
|
port: u16,
|
||||||
|
|
||||||
/// Host to bind to
|
/// Host to bind to
|
||||||
#[arg(short, long, default_value = "127.0.0.1")]
|
#[arg(long, default_value = "127.0.0.1")]
|
||||||
host: String,
|
host: String,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
|
||||||
|
Daemon {
|
||||||
|
/// Port to listen on (use 0 for random available port)
|
||||||
|
#[arg(short, long, default_value = "8080")]
|
||||||
|
port: u16,
|
||||||
|
|
||||||
|
/// Host to bind to
|
||||||
|
#[arg(long, default_value = "127.0.0.1")]
|
||||||
|
host: String,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Manage OS service lifecycle (launchd/systemd user service)
|
||||||
|
Service {
|
||||||
|
#[command(subcommand)]
|
||||||
|
service_command: ServiceCommands,
|
||||||
|
},
|
||||||
|
|
||||||
|
/// Run diagnostics for daemon/scheduler/channel freshness
|
||||||
|
Doctor,
|
||||||
|
|
||||||
/// Show system status (full details)
|
/// Show system status (full details)
|
||||||
Status,
|
Status,
|
||||||
|
|
||||||
|
|
@ -116,6 +159,26 @@ enum Commands {
|
||||||
#[command(subcommand)]
|
#[command(subcommand)]
|
||||||
skill_command: SkillCommands,
|
skill_command: SkillCommands,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
/// Migrate data from other agent runtimes
|
||||||
|
Migrate {
|
||||||
|
#[command(subcommand)]
|
||||||
|
migrate_command: MigrateCommands,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Subcommand, Debug)]
|
||||||
|
enum MigrateCommands {
|
||||||
|
/// Import memory from an OpenClaw workspace into this ZeroClaw workspace
|
||||||
|
Openclaw {
|
||||||
|
/// Optional path to OpenClaw workspace (defaults to ~/.openclaw/workspace)
|
||||||
|
#[arg(long)]
|
||||||
|
source: Option<std::path::PathBuf>,
|
||||||
|
|
||||||
|
/// Validate and preview migration without writing any data
|
||||||
|
#[arg(long)]
|
||||||
|
dry_run: bool,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Subcommand, Debug)]
|
#[derive(Subcommand, Debug)]
|
||||||
|
|
@ -198,11 +261,21 @@ async fn main() -> Result<()> {
|
||||||
// Onboard runs quick setup by default, or the interactive wizard with --interactive
|
// Onboard runs quick setup by default, or the interactive wizard with --interactive
|
||||||
if let Commands::Onboard {
|
if let Commands::Onboard {
|
||||||
interactive,
|
interactive,
|
||||||
|
channels_only,
|
||||||
api_key,
|
api_key,
|
||||||
provider,
|
provider,
|
||||||
} = &cli.command
|
} = &cli.command
|
||||||
{
|
{
|
||||||
let config = if *interactive {
|
if *interactive && *channels_only {
|
||||||
|
bail!("Use either --interactive or --channels-only, not both");
|
||||||
|
}
|
||||||
|
if *channels_only && (api_key.is_some() || provider.is_some()) {
|
||||||
|
bail!("--channels-only does not accept --api-key or --provider");
|
||||||
|
}
|
||||||
|
|
||||||
|
let config = if *channels_only {
|
||||||
|
onboard::run_channels_repair_wizard()?
|
||||||
|
} else if *interactive {
|
||||||
onboard::run_wizard()?
|
onboard::run_wizard()?
|
||||||
} else {
|
} else {
|
||||||
onboard::run_quick_setup(api_key.as_deref(), provider.as_deref())?
|
onboard::run_quick_setup(api_key.as_deref(), provider.as_deref())?
|
||||||
|
|
@ -236,6 +309,15 @@ async fn main() -> Result<()> {
|
||||||
gateway::run_gateway(&host, port, config).await
|
gateway::run_gateway(&host, port, config).await
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Commands::Daemon { port, host } => {
|
||||||
|
if port == 0 {
|
||||||
|
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
|
||||||
|
} else {
|
||||||
|
info!("🧠 Starting ZeroClaw Daemon on {host}:{port}");
|
||||||
|
}
|
||||||
|
daemon::run(config, host, port).await
|
||||||
|
}
|
||||||
|
|
||||||
Commands::Status => {
|
Commands::Status => {
|
||||||
println!("🦀 ZeroClaw Status");
|
println!("🦀 ZeroClaw Status");
|
||||||
println!();
|
println!();
|
||||||
|
|
@ -307,6 +389,10 @@ async fn main() -> Result<()> {
|
||||||
|
|
||||||
Commands::Cron { cron_command } => cron::handle_command(cron_command, config),
|
Commands::Cron { cron_command } => cron::handle_command(cron_command, config),
|
||||||
|
|
||||||
|
Commands::Service { service_command } => service::handle_command(service_command, &config),
|
||||||
|
|
||||||
|
Commands::Doctor => doctor::run(&config),
|
||||||
|
|
||||||
Commands::Channel { channel_command } => match channel_command {
|
Commands::Channel { channel_command } => match channel_command {
|
||||||
ChannelCommands::Start => channels::start_channels(config).await,
|
ChannelCommands::Start => channels::start_channels(config).await,
|
||||||
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
ChannelCommands::Doctor => channels::doctor_channels(config).await,
|
||||||
|
|
@ -320,5 +406,20 @@ async fn main() -> Result<()> {
|
||||||
Commands::Skills { skill_command } => {
|
Commands::Skills { skill_command } => {
|
||||||
skills::handle_command(skill_command, &config.workspace_dir)
|
skills::handle_command(skill_command, &config.workspace_dir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Commands::Migrate { migrate_command } => {
|
||||||
|
migration::handle_command(migrate_command, &config).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use clap::CommandFactory;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn cli_definition_has_no_flag_conflicts() {
|
||||||
|
Cli::command().debug_assert();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
538
src/memory/hygiene.rs
Normal file
538
src/memory/hygiene.rs
Normal file
|
|
@ -0,0 +1,538 @@
|
||||||
|
use crate::config::MemoryConfig;
|
||||||
|
use anyhow::Result;
|
||||||
|
use chrono::{DateTime, Duration, Local, NaiveDate, Utc};
|
||||||
|
use rusqlite::{params, Connection};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::time::{Duration as StdDuration, SystemTime};
|
||||||
|
|
||||||
|
const HYGIENE_INTERVAL_HOURS: i64 = 12;
|
||||||
|
const STATE_FILE: &str = "memory_hygiene_state.json";
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
struct HygieneReport {
|
||||||
|
archived_memory_files: u64,
|
||||||
|
archived_session_files: u64,
|
||||||
|
purged_memory_archives: u64,
|
||||||
|
purged_session_archives: u64,
|
||||||
|
pruned_conversation_rows: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl HygieneReport {
|
||||||
|
fn total_actions(&self) -> u64 {
|
||||||
|
self.archived_memory_files
|
||||||
|
+ self.archived_session_files
|
||||||
|
+ self.purged_memory_archives
|
||||||
|
+ self.purged_session_archives
|
||||||
|
+ self.pruned_conversation_rows
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||||
|
struct HygieneState {
|
||||||
|
last_run_at: Option<String>,
|
||||||
|
last_report: HygieneReport,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run memory/session hygiene if the cadence window has elapsed.
|
||||||
|
///
|
||||||
|
/// This function is intentionally best-effort: callers should log and continue on failure.
|
||||||
|
pub fn run_if_due(config: &MemoryConfig, workspace_dir: &Path) -> Result<()> {
|
||||||
|
if !config.hygiene_enabled {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if !should_run_now(workspace_dir)? {
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
let report = HygieneReport {
|
||||||
|
archived_memory_files: archive_daily_memory_files(
|
||||||
|
workspace_dir,
|
||||||
|
config.archive_after_days,
|
||||||
|
)?,
|
||||||
|
archived_session_files: archive_session_files(workspace_dir, config.archive_after_days)?,
|
||||||
|
purged_memory_archives: purge_memory_archives(workspace_dir, config.purge_after_days)?,
|
||||||
|
purged_session_archives: purge_session_archives(workspace_dir, config.purge_after_days)?,
|
||||||
|
pruned_conversation_rows: prune_conversation_rows(
|
||||||
|
workspace_dir,
|
||||||
|
config.conversation_retention_days,
|
||||||
|
)?,
|
||||||
|
};
|
||||||
|
|
||||||
|
write_state(workspace_dir, &report)?;
|
||||||
|
|
||||||
|
if report.total_actions() > 0 {
|
||||||
|
tracing::info!(
|
||||||
|
"memory hygiene complete: archived_memory={} archived_sessions={} purged_memory={} purged_sessions={} pruned_conversation_rows={}",
|
||||||
|
report.archived_memory_files,
|
||||||
|
report.archived_session_files,
|
||||||
|
report.purged_memory_archives,
|
||||||
|
report.purged_session_archives,
|
||||||
|
report.pruned_conversation_rows,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn should_run_now(workspace_dir: &Path) -> Result<bool> {
|
||||||
|
let path = state_path(workspace_dir);
|
||||||
|
if !path.exists() {
|
||||||
|
return Ok(true);
|
||||||
|
}
|
||||||
|
|
||||||
|
let raw = fs::read_to_string(&path)?;
|
||||||
|
let state: HygieneState = match serde_json::from_str(&raw) {
|
||||||
|
Ok(s) => s,
|
||||||
|
Err(_) => return Ok(true),
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(last_run_at) = state.last_run_at else {
|
||||||
|
return Ok(true);
|
||||||
|
};
|
||||||
|
|
||||||
|
let last = match DateTime::parse_from_rfc3339(&last_run_at) {
|
||||||
|
Ok(ts) => ts.with_timezone(&Utc),
|
||||||
|
Err(_) => return Ok(true),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Utc::now().signed_duration_since(last) >= Duration::hours(HYGIENE_INTERVAL_HOURS))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_state(workspace_dir: &Path, report: &HygieneReport) -> Result<()> {
|
||||||
|
let path = state_path(workspace_dir);
|
||||||
|
if let Some(parent) = path.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let state = HygieneState {
|
||||||
|
last_run_at: Some(Utc::now().to_rfc3339()),
|
||||||
|
last_report: report.clone(),
|
||||||
|
};
|
||||||
|
let json = serde_json::to_vec_pretty(&state)?;
|
||||||
|
fs::write(path, json)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn state_path(workspace_dir: &Path) -> PathBuf {
|
||||||
|
workspace_dir.join("state").join(STATE_FILE)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn archive_daily_memory_files(workspace_dir: &Path, archive_after_days: u32) -> Result<u64> {
|
||||||
|
if archive_after_days == 0 {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let memory_dir = workspace_dir.join("memory");
|
||||||
|
if !memory_dir.is_dir() {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let archive_dir = memory_dir.join("archive");
|
||||||
|
fs::create_dir_all(&archive_dir)?;
|
||||||
|
|
||||||
|
let cutoff = Local::now().date_naive() - Duration::days(i64::from(archive_after_days));
|
||||||
|
let mut moved = 0_u64;
|
||||||
|
|
||||||
|
for entry in fs::read_dir(&memory_dir)? {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.is_dir() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
if path.extension().and_then(|e| e.to_str()) != Some("md") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(filename) = path.file_name().and_then(|f| f.to_str()) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(file_date) = memory_date_from_filename(filename) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if file_date < cutoff {
|
||||||
|
move_to_archive(&path, &archive_dir)?;
|
||||||
|
moved += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(moved)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn archive_session_files(workspace_dir: &Path, archive_after_days: u32) -> Result<u64> {
|
||||||
|
if archive_after_days == 0 {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let sessions_dir = workspace_dir.join("sessions");
|
||||||
|
if !sessions_dir.is_dir() {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let archive_dir = sessions_dir.join("archive");
|
||||||
|
fs::create_dir_all(&archive_dir)?;
|
||||||
|
|
||||||
|
let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(archive_after_days));
|
||||||
|
let cutoff_time = SystemTime::now()
|
||||||
|
.checked_sub(StdDuration::from_secs(
|
||||||
|
u64::from(archive_after_days) * 24 * 60 * 60,
|
||||||
|
))
|
||||||
|
.unwrap_or(SystemTime::UNIX_EPOCH);
|
||||||
|
|
||||||
|
let mut moved = 0_u64;
|
||||||
|
for entry in fs::read_dir(&sessions_dir)? {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.is_dir() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(filename) = path.file_name().and_then(|f| f.to_str()) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let is_old = if let Some(date) = date_prefix(filename) {
|
||||||
|
date < cutoff_date
|
||||||
|
} else {
|
||||||
|
is_older_than(&path, cutoff_time)
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_old {
|
||||||
|
move_to_archive(&path, &archive_dir)?;
|
||||||
|
moved += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(moved)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn purge_memory_archives(workspace_dir: &Path, purge_after_days: u32) -> Result<u64> {
|
||||||
|
if purge_after_days == 0 {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let archive_dir = workspace_dir.join("memory").join("archive");
|
||||||
|
if !archive_dir.is_dir() {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cutoff = Local::now().date_naive() - Duration::days(i64::from(purge_after_days));
|
||||||
|
let mut removed = 0_u64;
|
||||||
|
|
||||||
|
for entry in fs::read_dir(&archive_dir)? {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.is_dir() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(filename) = path.file_name().and_then(|f| f.to_str()) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let Some(file_date) = memory_date_from_filename(filename) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
if file_date < cutoff {
|
||||||
|
fs::remove_file(&path)?;
|
||||||
|
removed += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn purge_session_archives(workspace_dir: &Path, purge_after_days: u32) -> Result<u64> {
|
||||||
|
if purge_after_days == 0 {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let archive_dir = workspace_dir.join("sessions").join("archive");
|
||||||
|
if !archive_dir.is_dir() {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let cutoff_date = Local::now().date_naive() - Duration::days(i64::from(purge_after_days));
|
||||||
|
let cutoff_time = SystemTime::now()
|
||||||
|
.checked_sub(StdDuration::from_secs(
|
||||||
|
u64::from(purge_after_days) * 24 * 60 * 60,
|
||||||
|
))
|
||||||
|
.unwrap_or(SystemTime::UNIX_EPOCH);
|
||||||
|
|
||||||
|
let mut removed = 0_u64;
|
||||||
|
for entry in fs::read_dir(&archive_dir)? {
|
||||||
|
let entry = entry?;
|
||||||
|
let path = entry.path();
|
||||||
|
|
||||||
|
if path.is_dir() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let Some(filename) = path.file_name().and_then(|f| f.to_str()) else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
|
||||||
|
let is_old = if let Some(date) = date_prefix(filename) {
|
||||||
|
date < cutoff_date
|
||||||
|
} else {
|
||||||
|
is_older_than(&path, cutoff_time)
|
||||||
|
};
|
||||||
|
|
||||||
|
if is_old {
|
||||||
|
fs::remove_file(&path)?;
|
||||||
|
removed += 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(removed)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn prune_conversation_rows(workspace_dir: &Path, retention_days: u32) -> Result<u64> {
|
||||||
|
if retention_days == 0 {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let db_path = workspace_dir.join("memory").join("brain.db");
|
||||||
|
if !db_path.exists() {
|
||||||
|
return Ok(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
let conn = Connection::open(db_path)?;
|
||||||
|
let cutoff = (Local::now() - Duration::days(i64::from(retention_days))).to_rfc3339();
|
||||||
|
|
||||||
|
let affected = conn.execute(
|
||||||
|
"DELETE FROM memories WHERE category = 'conversation' AND updated_at < ?1",
|
||||||
|
params![cutoff],
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(u64::try_from(affected).unwrap_or(0))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn memory_date_from_filename(filename: &str) -> Option<NaiveDate> {
|
||||||
|
let stem = filename.strip_suffix(".md")?;
|
||||||
|
let date_part = stem.split('_').next().unwrap_or(stem);
|
||||||
|
NaiveDate::parse_from_str(date_part, "%Y-%m-%d").ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn date_prefix(filename: &str) -> Option<NaiveDate> {
|
||||||
|
if filename.len() < 10 {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
NaiveDate::parse_from_str(&filename[..10], "%Y-%m-%d").ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_older_than(path: &Path, cutoff: SystemTime) -> bool {
|
||||||
|
fs::metadata(path)
|
||||||
|
.and_then(|meta| meta.modified())
|
||||||
|
.map(|modified| modified < cutoff)
|
||||||
|
.unwrap_or(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn move_to_archive(src: &Path, archive_dir: &Path) -> Result<()> {
|
||||||
|
let Some(filename) = src.file_name().and_then(|f| f.to_str()) else {
|
||||||
|
return Ok(());
|
||||||
|
};
|
||||||
|
|
||||||
|
let target = unique_archive_target(archive_dir, filename);
|
||||||
|
fs::rename(src, target)?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unique_archive_target(archive_dir: &Path, filename: &str) -> PathBuf {
|
||||||
|
let direct = archive_dir.join(filename);
|
||||||
|
if !direct.exists() {
|
||||||
|
return direct;
|
||||||
|
}
|
||||||
|
|
||||||
|
let (stem, ext) = split_name(filename);
|
||||||
|
for i in 1..10_000 {
|
||||||
|
let candidate = if ext.is_empty() {
|
||||||
|
archive_dir.join(format!("{stem}_{i}"))
|
||||||
|
} else {
|
||||||
|
archive_dir.join(format!("{stem}_{i}.{ext}"))
|
||||||
|
};
|
||||||
|
if !candidate.exists() {
|
||||||
|
return candidate;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
direct
|
||||||
|
}
|
||||||
|
|
||||||
|
fn split_name(filename: &str) -> (&str, &str) {
|
||||||
|
match filename.rsplit_once('.') {
|
||||||
|
Some((stem, ext)) => (stem, ext),
|
||||||
|
None => (filename, ""),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
fn default_cfg() -> MemoryConfig {
|
||||||
|
MemoryConfig::default()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn archives_old_daily_memory_files() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let workspace = tmp.path();
|
||||||
|
fs::create_dir_all(workspace.join("memory")).unwrap();
|
||||||
|
|
||||||
|
let old = (Local::now().date_naive() - Duration::days(10))
|
||||||
|
.format("%Y-%m-%d")
|
||||||
|
.to_string();
|
||||||
|
let today = Local::now().date_naive().format("%Y-%m-%d").to_string();
|
||||||
|
|
||||||
|
let old_file = workspace.join("memory").join(format!("{old}.md"));
|
||||||
|
let today_file = workspace.join("memory").join(format!("{today}.md"));
|
||||||
|
fs::write(&old_file, "old note").unwrap();
|
||||||
|
fs::write(&today_file, "fresh note").unwrap();
|
||||||
|
|
||||||
|
run_if_due(&default_cfg(), workspace).unwrap();
|
||||||
|
|
||||||
|
assert!(!old_file.exists(), "old daily file should be archived");
|
||||||
|
assert!(
|
||||||
|
workspace
|
||||||
|
.join("memory")
|
||||||
|
.join("archive")
|
||||||
|
.join(format!("{old}.md"))
|
||||||
|
.exists(),
|
||||||
|
"old daily file should exist in memory/archive"
|
||||||
|
);
|
||||||
|
assert!(today_file.exists(), "today file should remain in place");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn archives_old_session_files() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let workspace = tmp.path();
|
||||||
|
fs::create_dir_all(workspace.join("sessions")).unwrap();
|
||||||
|
|
||||||
|
let old = (Local::now().date_naive() - Duration::days(10))
|
||||||
|
.format("%Y-%m-%d")
|
||||||
|
.to_string();
|
||||||
|
let old_name = format!("{old}-agent.log");
|
||||||
|
let old_file = workspace.join("sessions").join(&old_name);
|
||||||
|
fs::write(&old_file, "old session").unwrap();
|
||||||
|
|
||||||
|
run_if_due(&default_cfg(), workspace).unwrap();
|
||||||
|
|
||||||
|
assert!(!old_file.exists(), "old session file should be archived");
|
||||||
|
assert!(
|
||||||
|
workspace
|
||||||
|
.join("sessions")
|
||||||
|
.join("archive")
|
||||||
|
.join(&old_name)
|
||||||
|
.exists(),
|
||||||
|
"archived session file should exist"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn skips_second_run_within_cadence_window() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let workspace = tmp.path();
|
||||||
|
fs::create_dir_all(workspace.join("memory")).unwrap();
|
||||||
|
|
||||||
|
let old_a = (Local::now().date_naive() - Duration::days(10))
|
||||||
|
.format("%Y-%m-%d")
|
||||||
|
.to_string();
|
||||||
|
let file_a = workspace.join("memory").join(format!("{old_a}.md"));
|
||||||
|
fs::write(&file_a, "first").unwrap();
|
||||||
|
|
||||||
|
run_if_due(&default_cfg(), workspace).unwrap();
|
||||||
|
assert!(!file_a.exists(), "first old file should be archived");
|
||||||
|
|
||||||
|
let old_b = (Local::now().date_naive() - Duration::days(9))
|
||||||
|
.format("%Y-%m-%d")
|
||||||
|
.to_string();
|
||||||
|
let file_b = workspace.join("memory").join(format!("{old_b}.md"));
|
||||||
|
fs::write(&file_b, "second").unwrap();
|
||||||
|
|
||||||
|
// Should skip because cadence gate prevents a second immediate run.
|
||||||
|
run_if_due(&default_cfg(), workspace).unwrap();
|
||||||
|
assert!(
|
||||||
|
file_b.exists(),
|
||||||
|
"second file should remain because run is throttled"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn purges_old_memory_archives() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let workspace = tmp.path();
|
||||||
|
let archive_dir = workspace.join("memory").join("archive");
|
||||||
|
fs::create_dir_all(&archive_dir).unwrap();
|
||||||
|
|
||||||
|
let old = (Local::now().date_naive() - Duration::days(40))
|
||||||
|
.format("%Y-%m-%d")
|
||||||
|
.to_string();
|
||||||
|
let keep = (Local::now().date_naive() - Duration::days(5))
|
||||||
|
.format("%Y-%m-%d")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let old_file = archive_dir.join(format!("{old}.md"));
|
||||||
|
let keep_file = archive_dir.join(format!("{keep}.md"));
|
||||||
|
fs::write(&old_file, "expired").unwrap();
|
||||||
|
fs::write(&keep_file, "recent").unwrap();
|
||||||
|
|
||||||
|
run_if_due(&default_cfg(), workspace).unwrap();
|
||||||
|
|
||||||
|
assert!(!old_file.exists(), "old archived file should be purged");
|
||||||
|
assert!(keep_file.exists(), "recent archived file should remain");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn prunes_old_conversation_rows_in_sqlite_backend() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let workspace = tmp.path();
|
||||||
|
|
||||||
|
let mem = SqliteMemory::new(workspace).unwrap();
|
||||||
|
mem.store("conv_old", "outdated", MemoryCategory::Conversation)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("core_keep", "durable", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
drop(mem);
|
||||||
|
|
||||||
|
let db_path = workspace.join("memory").join("brain.db");
|
||||||
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
|
let old_cutoff = (Local::now() - Duration::days(60)).to_rfc3339();
|
||||||
|
conn.execute(
|
||||||
|
"UPDATE memories SET created_at = ?1, updated_at = ?1 WHERE key = 'conv_old'",
|
||||||
|
params![old_cutoff],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
drop(conn);
|
||||||
|
|
||||||
|
let mut cfg = default_cfg();
|
||||||
|
cfg.archive_after_days = 0;
|
||||||
|
cfg.purge_after_days = 0;
|
||||||
|
cfg.conversation_retention_days = 30;
|
||||||
|
|
||||||
|
run_if_due(&cfg, workspace).unwrap();
|
||||||
|
|
||||||
|
let mem2 = SqliteMemory::new(workspace).unwrap();
|
||||||
|
assert!(
|
||||||
|
mem2.get("conv_old").await.unwrap().is_none(),
|
||||||
|
"old conversation rows should be pruned"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
mem2.get("core_keep").await.unwrap().is_some(),
|
||||||
|
"core memory should remain"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod chunker;
|
pub mod chunker;
|
||||||
pub mod embeddings;
|
pub mod embeddings;
|
||||||
|
pub mod hygiene;
|
||||||
pub mod markdown;
|
pub mod markdown;
|
||||||
pub mod sqlite;
|
pub mod sqlite;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
@ -21,6 +22,11 @@ pub fn create_memory(
|
||||||
workspace_dir: &Path,
|
workspace_dir: &Path,
|
||||||
api_key: Option<&str>,
|
api_key: Option<&str>,
|
||||||
) -> anyhow::Result<Box<dyn Memory>> {
|
) -> anyhow::Result<Box<dyn Memory>> {
|
||||||
|
// Best-effort memory hygiene/retention pass (throttled by state file).
|
||||||
|
if let Err(e) = hygiene::run_if_due(config, workspace_dir) {
|
||||||
|
tracing::warn!("memory hygiene skipped: {e}");
|
||||||
|
}
|
||||||
|
|
||||||
match config.backend.as_str() {
|
match config.backend.as_str() {
|
||||||
"sqlite" => {
|
"sqlite" => {
|
||||||
let embedder: Arc<dyn embeddings::EmbeddingProvider> =
|
let embedder: Arc<dyn embeddings::EmbeddingProvider> =
|
||||||
|
|
|
||||||
553
src/migration.rs
Normal file
553
src/migration.rs
Normal file
|
|
@ -0,0 +1,553 @@
|
||||||
|
use crate::config::Config;
|
||||||
|
use crate::memory::{MarkdownMemory, Memory, MemoryCategory, SqliteMemory};
|
||||||
|
use anyhow::{bail, Context, Result};
|
||||||
|
use directories::UserDirs;
|
||||||
|
use rusqlite::{Connection, OpenFlags, OptionalExtension};
|
||||||
|
use std::collections::HashSet;
|
||||||
|
use std::fs;
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
struct SourceEntry {
|
||||||
|
key: String,
|
||||||
|
content: String,
|
||||||
|
category: MemoryCategory,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
struct MigrationStats {
|
||||||
|
from_sqlite: usize,
|
||||||
|
from_markdown: usize,
|
||||||
|
imported: usize,
|
||||||
|
skipped_unchanged: usize,
|
||||||
|
renamed_conflicts: usize,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn handle_command(command: super::MigrateCommands, config: &Config) -> Result<()> {
|
||||||
|
match command {
|
||||||
|
super::MigrateCommands::Openclaw { source, dry_run } => {
|
||||||
|
migrate_openclaw_memory(config, source, dry_run).await
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn migrate_openclaw_memory(
|
||||||
|
config: &Config,
|
||||||
|
source_workspace: Option<PathBuf>,
|
||||||
|
dry_run: bool,
|
||||||
|
) -> Result<()> {
|
||||||
|
let source_workspace = resolve_openclaw_workspace(source_workspace)?;
|
||||||
|
if !source_workspace.exists() {
|
||||||
|
bail!(
|
||||||
|
"OpenClaw workspace not found at {}. Pass --source <path> if needed.",
|
||||||
|
source_workspace.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if paths_equal(&source_workspace, &config.workspace_dir) {
|
||||||
|
bail!("Source workspace matches current ZeroClaw workspace; refusing self-migration");
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut stats = MigrationStats::default();
|
||||||
|
let entries = collect_source_entries(&source_workspace, &mut stats)?;
|
||||||
|
|
||||||
|
if entries.is_empty() {
|
||||||
|
println!(
|
||||||
|
"No importable memory found in {}",
|
||||||
|
source_workspace.display()
|
||||||
|
);
|
||||||
|
println!("Checked for: memory/brain.db, MEMORY.md, memory/*.md");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if dry_run {
|
||||||
|
println!("🔎 Dry run: OpenClaw migration preview");
|
||||||
|
println!(" Source: {}", source_workspace.display());
|
||||||
|
println!(" Target: {}", config.workspace_dir.display());
|
||||||
|
println!(" Candidates: {}", entries.len());
|
||||||
|
println!(" - from sqlite: {}", stats.from_sqlite);
|
||||||
|
println!(" - from markdown: {}", stats.from_markdown);
|
||||||
|
println!();
|
||||||
|
println!("Run without --dry-run to import these entries.");
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(backup_dir) = backup_target_memory(&config.workspace_dir)? {
|
||||||
|
println!("🛟 Backup created: {}", backup_dir.display());
|
||||||
|
}
|
||||||
|
|
||||||
|
let memory = target_memory_backend(config)?;
|
||||||
|
|
||||||
|
for (idx, entry) in entries.into_iter().enumerate() {
|
||||||
|
let mut key = entry.key.trim().to_string();
|
||||||
|
if key.is_empty() {
|
||||||
|
key = format!("openclaw_{idx}");
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(existing) = memory.get(&key).await? {
|
||||||
|
if existing.content.trim() == entry.content.trim() {
|
||||||
|
stats.skipped_unchanged += 1;
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let renamed = next_available_key(memory.as_ref(), &key).await?;
|
||||||
|
key = renamed;
|
||||||
|
stats.renamed_conflicts += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
memory.store(&key, &entry.content, entry.category).await?;
|
||||||
|
stats.imported += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
println!("✅ OpenClaw memory migration complete");
|
||||||
|
println!(" Source: {}", source_workspace.display());
|
||||||
|
println!(" Target: {}", config.workspace_dir.display());
|
||||||
|
println!(" Imported: {}", stats.imported);
|
||||||
|
println!(" Skipped unchanged:{}", stats.skipped_unchanged);
|
||||||
|
println!(" Renamed conflicts:{}", stats.renamed_conflicts);
|
||||||
|
println!(" Source sqlite rows:{}", stats.from_sqlite);
|
||||||
|
println!(" Source markdown: {}", stats.from_markdown);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn target_memory_backend(config: &Config) -> Result<Box<dyn Memory>> {
|
||||||
|
match config.memory.backend.as_str() {
|
||||||
|
"sqlite" => Ok(Box::new(SqliteMemory::new(&config.workspace_dir)?)),
|
||||||
|
"markdown" | "none" => Ok(Box::new(MarkdownMemory::new(&config.workspace_dir))),
|
||||||
|
other => {
|
||||||
|
tracing::warn!(
|
||||||
|
"Unknown memory backend '{other}' during migration, defaulting to markdown"
|
||||||
|
);
|
||||||
|
Ok(Box::new(MarkdownMemory::new(&config.workspace_dir)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn collect_source_entries(
|
||||||
|
source_workspace: &Path,
|
||||||
|
stats: &mut MigrationStats,
|
||||||
|
) -> Result<Vec<SourceEntry>> {
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
|
||||||
|
let sqlite_path = source_workspace.join("memory").join("brain.db");
|
||||||
|
let sqlite_entries = read_openclaw_sqlite_entries(&sqlite_path)?;
|
||||||
|
stats.from_sqlite = sqlite_entries.len();
|
||||||
|
entries.extend(sqlite_entries);
|
||||||
|
|
||||||
|
let markdown_entries = read_openclaw_markdown_entries(source_workspace)?;
|
||||||
|
stats.from_markdown = markdown_entries.len();
|
||||||
|
entries.extend(markdown_entries);
|
||||||
|
|
||||||
|
// De-dup exact duplicates to make re-runs deterministic.
|
||||||
|
let mut seen = HashSet::new();
|
||||||
|
entries.retain(|entry| {
|
||||||
|
let sig = format!("{}\u{0}{}\u{0}{}", entry.key, entry.content, entry.category);
|
||||||
|
seen.insert(sig)
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_openclaw_sqlite_entries(db_path: &Path) -> Result<Vec<SourceEntry>> {
|
||||||
|
if !db_path.exists() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let conn = Connection::open_with_flags(db_path, OpenFlags::SQLITE_OPEN_READ_ONLY)
|
||||||
|
.with_context(|| format!("Failed to open source db {}", db_path.display()))?;
|
||||||
|
|
||||||
|
let table_exists: Option<String> = conn
|
||||||
|
.query_row(
|
||||||
|
"SELECT name FROM sqlite_master WHERE type='table' AND name='memories' LIMIT 1",
|
||||||
|
[],
|
||||||
|
|row| row.get(0),
|
||||||
|
)
|
||||||
|
.optional()?;
|
||||||
|
|
||||||
|
if table_exists.is_none() {
|
||||||
|
return Ok(Vec::new());
|
||||||
|
}
|
||||||
|
|
||||||
|
let columns = table_columns(&conn, "memories")?;
|
||||||
|
let key_expr = pick_column_expr(&columns, &["key", "id", "name"], "CAST(rowid AS TEXT)");
|
||||||
|
let Some(content_expr) =
|
||||||
|
pick_optional_column_expr(&columns, &["content", "value", "text", "memory"])
|
||||||
|
else {
|
||||||
|
bail!("OpenClaw memories table found but no content-like column was detected");
|
||||||
|
};
|
||||||
|
let category_expr = pick_column_expr(&columns, &["category", "kind", "type"], "'core'");
|
||||||
|
|
||||||
|
let sql = format!(
|
||||||
|
"SELECT {key_expr} AS key, {content_expr} AS content, {category_expr} AS category FROM memories"
|
||||||
|
);
|
||||||
|
|
||||||
|
let mut stmt = conn.prepare(&sql)?;
|
||||||
|
let mut rows = stmt.query([])?;
|
||||||
|
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
let mut idx = 0_usize;
|
||||||
|
|
||||||
|
while let Some(row) = rows.next()? {
|
||||||
|
let key: String = row
|
||||||
|
.get(0)
|
||||||
|
.unwrap_or_else(|_| format!("openclaw_sqlite_{idx}"));
|
||||||
|
let content: String = row.get(1).unwrap_or_default();
|
||||||
|
let category_raw: String = row.get(2).unwrap_or_else(|_| "core".to_string());
|
||||||
|
|
||||||
|
if content.trim().is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.push(SourceEntry {
|
||||||
|
key: normalize_key(&key, idx),
|
||||||
|
content: content.trim().to_string(),
|
||||||
|
category: parse_category(&category_raw),
|
||||||
|
});
|
||||||
|
|
||||||
|
idx += 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(entries)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn read_openclaw_markdown_entries(source_workspace: &Path) -> Result<Vec<SourceEntry>> {
|
||||||
|
let mut all = Vec::new();
|
||||||
|
|
||||||
|
let core_path = source_workspace.join("MEMORY.md");
|
||||||
|
if core_path.exists() {
|
||||||
|
let content = fs::read_to_string(&core_path)?;
|
||||||
|
all.extend(parse_markdown_file(
|
||||||
|
&core_path,
|
||||||
|
&content,
|
||||||
|
MemoryCategory::Core,
|
||||||
|
"openclaw_core",
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
let daily_dir = source_workspace.join("memory");
|
||||||
|
if daily_dir.exists() {
|
||||||
|
for file in fs::read_dir(&daily_dir)? {
|
||||||
|
let file = file?;
|
||||||
|
let path = file.path();
|
||||||
|
if path.extension().and_then(|ext| ext.to_str()) != Some("md") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let content = fs::read_to_string(&path)?;
|
||||||
|
let stem = path
|
||||||
|
.file_stem()
|
||||||
|
.and_then(|s| s.to_str())
|
||||||
|
.unwrap_or("openclaw_daily");
|
||||||
|
all.extend(parse_markdown_file(
|
||||||
|
&path,
|
||||||
|
&content,
|
||||||
|
MemoryCategory::Daily,
|
||||||
|
stem,
|
||||||
|
));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(all)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_markdown_file(
|
||||||
|
_path: &Path,
|
||||||
|
content: &str,
|
||||||
|
default_category: MemoryCategory,
|
||||||
|
stem: &str,
|
||||||
|
) -> Vec<SourceEntry> {
|
||||||
|
let mut entries = Vec::new();
|
||||||
|
|
||||||
|
for (idx, raw_line) in content.lines().enumerate() {
|
||||||
|
let trimmed = raw_line.trim();
|
||||||
|
if trimmed.is_empty() || trimmed.starts_with('#') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let line = trimmed.strip_prefix("- ").unwrap_or(trimmed);
|
||||||
|
let (key, text) = match parse_structured_memory_line(line) {
|
||||||
|
Some((k, v)) => (normalize_key(k, idx), v.trim().to_string()),
|
||||||
|
None => (
|
||||||
|
format!("openclaw_{stem}_{}", idx + 1),
|
||||||
|
line.trim().to_string(),
|
||||||
|
),
|
||||||
|
};
|
||||||
|
|
||||||
|
if text.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
entries.push(SourceEntry {
|
||||||
|
key,
|
||||||
|
content: text,
|
||||||
|
category: default_category.clone(),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
entries
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_structured_memory_line(line: &str) -> Option<(&str, &str)> {
|
||||||
|
if !line.starts_with("**") {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let rest = line.strip_prefix("**")?;
|
||||||
|
let key_end = rest.find("**:")?;
|
||||||
|
let key = rest.get(..key_end)?.trim();
|
||||||
|
let value = rest.get(key_end + 3..)?.trim();
|
||||||
|
|
||||||
|
if key.is_empty() || value.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some((key, value))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_category(raw: &str) -> MemoryCategory {
|
||||||
|
match raw.trim().to_ascii_lowercase().as_str() {
|
||||||
|
"core" => MemoryCategory::Core,
|
||||||
|
"daily" => MemoryCategory::Daily,
|
||||||
|
"conversation" => MemoryCategory::Conversation,
|
||||||
|
"" => MemoryCategory::Core,
|
||||||
|
other => MemoryCategory::Custom(other.to_string()),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn normalize_key(key: &str, fallback_idx: usize) -> String {
|
||||||
|
let trimmed = key.trim();
|
||||||
|
if trimmed.is_empty() {
|
||||||
|
return format!("openclaw_{fallback_idx}");
|
||||||
|
}
|
||||||
|
trimmed.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn next_available_key(memory: &dyn Memory, base: &str) -> Result<String> {
|
||||||
|
for i in 1..=10_000 {
|
||||||
|
let candidate = format!("{base}__openclaw_{i}");
|
||||||
|
if memory.get(&candidate).await?.is_none() {
|
||||||
|
return Ok(candidate);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bail!("Unable to allocate non-conflicting key for '{base}'")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn table_columns(conn: &Connection, table: &str) -> Result<Vec<String>> {
|
||||||
|
let pragma = format!("PRAGMA table_info({table})");
|
||||||
|
let mut stmt = conn.prepare(&pragma)?;
|
||||||
|
let rows = stmt.query_map([], |row| row.get::<_, String>(1))?;
|
||||||
|
|
||||||
|
let mut cols = Vec::new();
|
||||||
|
for col in rows {
|
||||||
|
cols.push(col?.to_ascii_lowercase());
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(cols)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pick_optional_column_expr(columns: &[String], candidates: &[&str]) -> Option<String> {
|
||||||
|
candidates
|
||||||
|
.iter()
|
||||||
|
.find(|candidate| columns.iter().any(|c| c == *candidate))
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn pick_column_expr(columns: &[String], candidates: &[&str], fallback: &str) -> String {
|
||||||
|
pick_optional_column_expr(columns, candidates).unwrap_or_else(|| fallback.to_string())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn resolve_openclaw_workspace(source: Option<PathBuf>) -> Result<PathBuf> {
|
||||||
|
if let Some(src) = source {
|
||||||
|
return Ok(src);
|
||||||
|
}
|
||||||
|
|
||||||
|
let home = UserDirs::new()
|
||||||
|
.map(|u| u.home_dir().to_path_buf())
|
||||||
|
.context("Could not find home directory")?;
|
||||||
|
|
||||||
|
Ok(home.join(".openclaw").join("workspace"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn paths_equal(a: &Path, b: &Path) -> bool {
|
||||||
|
match (fs::canonicalize(a), fs::canonicalize(b)) {
|
||||||
|
(Ok(a), Ok(b)) => a == b,
|
||||||
|
_ => a == b,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn backup_target_memory(workspace_dir: &Path) -> Result<Option<PathBuf>> {
|
||||||
|
let timestamp = chrono::Local::now().format("%Y%m%d-%H%M%S").to_string();
|
||||||
|
let backup_root = workspace_dir
|
||||||
|
.join("memory")
|
||||||
|
.join("migrations")
|
||||||
|
.join(format!("openclaw-{timestamp}"));
|
||||||
|
|
||||||
|
let mut copied_any = false;
|
||||||
|
fs::create_dir_all(&backup_root)?;
|
||||||
|
|
||||||
|
let files_to_copy = [
|
||||||
|
workspace_dir.join("memory").join("brain.db"),
|
||||||
|
workspace_dir.join("MEMORY.md"),
|
||||||
|
];
|
||||||
|
|
||||||
|
for source in files_to_copy {
|
||||||
|
if source.exists() {
|
||||||
|
let Some(name) = source.file_name() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
fs::copy(&source, backup_root.join(name))?;
|
||||||
|
copied_any = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let daily_dir = workspace_dir.join("memory");
|
||||||
|
if daily_dir.exists() {
|
||||||
|
let daily_backup = backup_root.join("daily");
|
||||||
|
for file in fs::read_dir(&daily_dir)? {
|
||||||
|
let file = file?;
|
||||||
|
let path = file.path();
|
||||||
|
if path.extension().and_then(|ext| ext.to_str()) != Some("md") {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
fs::create_dir_all(&daily_backup)?;
|
||||||
|
let Some(name) = path.file_name() else {
|
||||||
|
continue;
|
||||||
|
};
|
||||||
|
fs::copy(&path, daily_backup.join(name))?;
|
||||||
|
copied_any = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if copied_any {
|
||||||
|
Ok(Some(backup_root))
|
||||||
|
} else {
|
||||||
|
let _ = fs::remove_dir_all(&backup_root);
|
||||||
|
Ok(None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::config::{Config, MemoryConfig};
|
||||||
|
use rusqlite::params;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
fn test_config(workspace: &Path) -> Config {
|
||||||
|
Config {
|
||||||
|
workspace_dir: workspace.to_path_buf(),
|
||||||
|
config_path: workspace.join("config.toml"),
|
||||||
|
memory: MemoryConfig {
|
||||||
|
backend: "sqlite".to_string(),
|
||||||
|
..MemoryConfig::default()
|
||||||
|
},
|
||||||
|
..Config::default()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_structured_markdown_line() {
|
||||||
|
let line = "**user_pref**: likes Rust";
|
||||||
|
let parsed = parse_structured_memory_line(line).unwrap();
|
||||||
|
assert_eq!(parsed.0, "user_pref");
|
||||||
|
assert_eq!(parsed.1, "likes Rust");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_unstructured_markdown_generates_key() {
|
||||||
|
let entries = parse_markdown_file(
|
||||||
|
Path::new("/tmp/MEMORY.md"),
|
||||||
|
"- plain note",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
"core",
|
||||||
|
);
|
||||||
|
assert_eq!(entries.len(), 1);
|
||||||
|
assert!(entries[0].key.starts_with("openclaw_core_"));
|
||||||
|
assert_eq!(entries[0].content, "plain note");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sqlite_reader_supports_legacy_value_column() {
|
||||||
|
let dir = TempDir::new().unwrap();
|
||||||
|
let db_path = dir.path().join("brain.db");
|
||||||
|
let conn = Connection::open(&db_path).unwrap();
|
||||||
|
|
||||||
|
conn.execute_batch("CREATE TABLE memories (key TEXT, value TEXT, type TEXT);")
|
||||||
|
.unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO memories (key, value, type) VALUES (?1, ?2, ?3)",
|
||||||
|
params!["legacy_key", "legacy_value", "daily"],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let rows = read_openclaw_sqlite_entries(&db_path).unwrap();
|
||||||
|
assert_eq!(rows.len(), 1);
|
||||||
|
assert_eq!(rows[0].key, "legacy_key");
|
||||||
|
assert_eq!(rows[0].content, "legacy_value");
|
||||||
|
assert_eq!(rows[0].category, MemoryCategory::Daily);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn migration_renames_conflicting_key() {
|
||||||
|
let source = TempDir::new().unwrap();
|
||||||
|
let target = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// Existing target memory
|
||||||
|
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
||||||
|
target_mem
|
||||||
|
.store("k", "new value", MemoryCategory::Core)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Source sqlite with conflicting key + different content
|
||||||
|
let source_db_dir = source.path().join("memory");
|
||||||
|
fs::create_dir_all(&source_db_dir).unwrap();
|
||||||
|
let source_db = source_db_dir.join("brain.db");
|
||||||
|
let conn = Connection::open(&source_db).unwrap();
|
||||||
|
conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);")
|
||||||
|
.unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)",
|
||||||
|
params!["k", "old value", "core"],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = test_config(target.path());
|
||||||
|
migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), false)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let all = target_mem.list(None).await.unwrap();
|
||||||
|
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
|
||||||
|
assert!(all
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.key.starts_with("k__openclaw_") && e.content == "old value"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn dry_run_does_not_write() {
|
||||||
|
let source = TempDir::new().unwrap();
|
||||||
|
let target = TempDir::new().unwrap();
|
||||||
|
let source_db_dir = source.path().join("memory");
|
||||||
|
fs::create_dir_all(&source_db_dir).unwrap();
|
||||||
|
|
||||||
|
let source_db = source_db_dir.join("brain.db");
|
||||||
|
let conn = Connection::open(&source_db).unwrap();
|
||||||
|
conn.execute_batch("CREATE TABLE memories (key TEXT, content TEXT, category TEXT);")
|
||||||
|
.unwrap();
|
||||||
|
conn.execute(
|
||||||
|
"INSERT INTO memories (key, content, category) VALUES (?1, ?2, ?3)",
|
||||||
|
params!["dry", "run", "core"],
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let config = test_config(target.path());
|
||||||
|
migrate_openclaw_memory(&config, Some(source.path().to_path_buf()), true)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
||||||
|
assert_eq!(target_mem.count().await.unwrap(), 0);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,3 +1,3 @@
|
||||||
pub mod wizard;
|
pub mod wizard;
|
||||||
|
|
||||||
pub use wizard::{run_quick_setup, run_wizard};
|
pub use wizard::{run_channels_repair_wizard, run_quick_setup, run_wizard};
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
|
use crate::config::schema::WhatsAppConfig;
|
||||||
use crate::config::{
|
use crate::config::{
|
||||||
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig,
|
||||||
HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, ObservabilityConfig,
|
HeartbeatConfig, IMessageConfig, MatrixConfig, MemoryConfig, ObservabilityConfig,
|
||||||
|
|
@ -91,6 +92,7 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
observability: ObservabilityConfig::default(),
|
observability: ObservabilityConfig::default(),
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
|
reliability: crate::config::ReliabilityConfig::default(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config,
|
channels_config,
|
||||||
memory: MemoryConfig::default(), // SQLite + auto-save by default
|
memory: MemoryConfig::default(), // SQLite + auto-save by default
|
||||||
|
|
@ -99,6 +101,7 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
composio: composio_config,
|
composio: composio_config,
|
||||||
secrets: secrets_config,
|
secrets: secrets_config,
|
||||||
browser: BrowserConfig::default(),
|
browser: BrowserConfig::default(),
|
||||||
|
identity: crate::config::IdentityConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
|
|
@ -149,6 +152,61 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Interactive repair flow: rerun channel setup only without redoing full onboarding.
|
||||||
|
pub fn run_channels_repair_wizard() -> Result<Config> {
|
||||||
|
println!("{}", style(BANNER).cyan().bold());
|
||||||
|
println!(
|
||||||
|
" {}",
|
||||||
|
style("Channels Repair — update channel tokens and allowlists only")
|
||||||
|
.white()
|
||||||
|
.bold()
|
||||||
|
);
|
||||||
|
println!();
|
||||||
|
|
||||||
|
let mut config = Config::load_or_init()?;
|
||||||
|
|
||||||
|
print_step(1, 1, "Channels (How You Talk to ZeroClaw)");
|
||||||
|
config.channels_config = setup_channels()?;
|
||||||
|
config.save()?;
|
||||||
|
|
||||||
|
println!();
|
||||||
|
println!(
|
||||||
|
" {} Channel config saved: {}",
|
||||||
|
style("✓").green().bold(),
|
||||||
|
style(config.config_path.display()).green()
|
||||||
|
);
|
||||||
|
|
||||||
|
let has_channels = config.channels_config.telegram.is_some()
|
||||||
|
|| config.channels_config.discord.is_some()
|
||||||
|
|| config.channels_config.slack.is_some()
|
||||||
|
|| config.channels_config.imessage.is_some()
|
||||||
|
|| config.channels_config.matrix.is_some();
|
||||||
|
|
||||||
|
if has_channels && config.api_key.is_some() {
|
||||||
|
let launch: bool = Confirm::new()
|
||||||
|
.with_prompt(format!(
|
||||||
|
" {} Launch channels now? (connected channels → AI → reply)",
|
||||||
|
style("🚀").cyan()
|
||||||
|
))
|
||||||
|
.default(true)
|
||||||
|
.interact()?;
|
||||||
|
|
||||||
|
if launch {
|
||||||
|
println!();
|
||||||
|
println!(
|
||||||
|
" {} {}",
|
||||||
|
style("⚡").cyan(),
|
||||||
|
style("Starting channel server...").white().bold()
|
||||||
|
);
|
||||||
|
println!();
|
||||||
|
// Signal to main.rs to call start_channels after wizard returns
|
||||||
|
std::env::set_var("ZEROCLAW_AUTOSTART_CHANNELS", "1");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(config)
|
||||||
|
}
|
||||||
|
|
||||||
// ── Quick setup (zero prompts) ───────────────────────────────────
|
// ── Quick setup (zero prompts) ───────────────────────────────────
|
||||||
|
|
||||||
/// Non-interactive setup: generates a sensible default config instantly.
|
/// Non-interactive setup: generates a sensible default config instantly.
|
||||||
|
|
@ -187,6 +245,7 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result<
|
||||||
observability: ObservabilityConfig::default(),
|
observability: ObservabilityConfig::default(),
|
||||||
autonomy: AutonomyConfig::default(),
|
autonomy: AutonomyConfig::default(),
|
||||||
runtime: RuntimeConfig::default(),
|
runtime: RuntimeConfig::default(),
|
||||||
|
reliability: crate::config::ReliabilityConfig::default(),
|
||||||
heartbeat: HeartbeatConfig::default(),
|
heartbeat: HeartbeatConfig::default(),
|
||||||
channels_config: ChannelsConfig::default(),
|
channels_config: ChannelsConfig::default(),
|
||||||
memory: MemoryConfig::default(),
|
memory: MemoryConfig::default(),
|
||||||
|
|
@ -195,6 +254,7 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result<
|
||||||
composio: ComposioConfig::default(),
|
composio: ComposioConfig::default(),
|
||||||
secrets: SecretsConfig::default(),
|
secrets: SecretsConfig::default(),
|
||||||
browser: BrowserConfig::default(),
|
browser: BrowserConfig::default(),
|
||||||
|
identity: crate::config::IdentityConfig::default(),
|
||||||
};
|
};
|
||||||
|
|
||||||
config.save()?;
|
config.save()?;
|
||||||
|
|
@ -204,7 +264,9 @@ pub fn run_quick_setup(api_key: Option<&str>, provider: Option<&str>) -> Result<
|
||||||
user_name: std::env::var("USER").unwrap_or_else(|_| "User".into()),
|
user_name: std::env::var("USER").unwrap_or_else(|_| "User".into()),
|
||||||
timezone: "UTC".into(),
|
timezone: "UTC".into(),
|
||||||
agent_name: "ZeroClaw".into(),
|
agent_name: "ZeroClaw".into(),
|
||||||
communication_style: "Direct and concise".into(),
|
communication_style:
|
||||||
|
"Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing."
|
||||||
|
.into(),
|
||||||
};
|
};
|
||||||
scaffold_workspace(&workspace_dir, &default_ctx)?;
|
scaffold_workspace(&workspace_dir, &default_ctx)?;
|
||||||
|
|
||||||
|
|
@ -878,24 +940,33 @@ fn setup_project_context() -> Result<ProjectContext> {
|
||||||
|
|
||||||
let style_options = vec![
|
let style_options = vec![
|
||||||
"Direct & concise — skip pleasantries, get to the point",
|
"Direct & concise — skip pleasantries, get to the point",
|
||||||
"Friendly & casual — warm but efficient",
|
"Friendly & casual — warm, human, and helpful",
|
||||||
|
"Professional & polished — calm, confident, and clear",
|
||||||
|
"Expressive & playful — more personality + natural emojis",
|
||||||
"Technical & detailed — thorough explanations, code-first",
|
"Technical & detailed — thorough explanations, code-first",
|
||||||
"Balanced — adapt to the situation",
|
"Balanced — adapt to the situation",
|
||||||
|
"Custom — write your own style guide",
|
||||||
];
|
];
|
||||||
|
|
||||||
let style_idx = Select::new()
|
let style_idx = Select::new()
|
||||||
.with_prompt(" Communication style")
|
.with_prompt(" Communication style")
|
||||||
.items(&style_options)
|
.items(&style_options)
|
||||||
.default(0)
|
.default(1)
|
||||||
.interact()?;
|
.interact()?;
|
||||||
|
|
||||||
let communication_style = match style_idx {
|
let communication_style = match style_idx {
|
||||||
0 => "Be direct and concise. Skip pleasantries. Get to the point.".to_string(),
|
0 => "Be direct and concise. Skip pleasantries. Get to the point.".to_string(),
|
||||||
1 => "Be friendly and casual. Warm but efficient.".to_string(),
|
1 => "Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions.".to_string(),
|
||||||
2 => "Be technical and detailed. Thorough explanations, code-first.".to_string(),
|
2 => "Be professional and polished. Stay calm, structured, and respectful. Use occasional tone-setting emojis only when appropriate.".to_string(),
|
||||||
_ => {
|
3 => "Be expressive and playful when appropriate. Use relevant emojis naturally (0-2 max), and keep serious topics emoji-light.".to_string(),
|
||||||
"Adapt to the situation. Be concise when needed, thorough when it matters.".to_string()
|
4 => "Be technical and detailed. Thorough explanations, code-first.".to_string(),
|
||||||
}
|
5 => "Adapt to the situation. Default to warm and clear communication; be concise when needed, thorough when it matters.".to_string(),
|
||||||
|
_ => Input::new()
|
||||||
|
.with_prompt(" Custom communication style")
|
||||||
|
.default(
|
||||||
|
"Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing.".into(),
|
||||||
|
)
|
||||||
|
.interact_text()?,
|
||||||
};
|
};
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
|
|
@ -931,6 +1002,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
webhook: None,
|
webhook: None,
|
||||||
imessage: None,
|
imessage: None,
|
||||||
matrix: None,
|
matrix: None,
|
||||||
|
whatsapp: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
|
|
@ -975,6 +1047,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
"— self-hosted chat"
|
"— self-hosted chat"
|
||||||
}
|
}
|
||||||
),
|
),
|
||||||
|
format!(
|
||||||
|
"WhatsApp {}",
|
||||||
|
if config.whatsapp.is_some() {
|
||||||
|
"✅ connected"
|
||||||
|
} else {
|
||||||
|
"— Business Cloud API"
|
||||||
|
}
|
||||||
|
),
|
||||||
format!(
|
format!(
|
||||||
"Webhook {}",
|
"Webhook {}",
|
||||||
if config.webhook.is_some() {
|
if config.webhook.is_some() {
|
||||||
|
|
@ -989,7 +1069,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
let choice = Select::new()
|
let choice = Select::new()
|
||||||
.with_prompt(" Connect a channel (or Done to continue)")
|
.with_prompt(" Connect a channel (or Done to continue)")
|
||||||
.items(&options)
|
.items(&options)
|
||||||
.default(6)
|
.default(7)
|
||||||
.interact()?;
|
.interact()?;
|
||||||
|
|
||||||
match choice {
|
match choice {
|
||||||
|
|
@ -1041,17 +1121,38 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
print_bullet(
|
||||||
|
"Allowlist your own Telegram identity first (recommended for secure + fast setup).",
|
||||||
|
);
|
||||||
|
print_bullet(
|
||||||
|
"Use your @username without '@' (example: argenis), or your numeric Telegram user ID.",
|
||||||
|
);
|
||||||
|
print_bullet("Use '*' only for temporary open testing.");
|
||||||
|
|
||||||
let users_str: String = Input::new()
|
let users_str: String = Input::new()
|
||||||
.with_prompt(" Allowed usernames (comma-separated, or * for all)")
|
.with_prompt(
|
||||||
.default("*".into())
|
" Allowed Telegram identities (comma-separated: username without '@' and/or numeric user ID, '*' for all)",
|
||||||
|
)
|
||||||
|
.allow_empty(true)
|
||||||
.interact_text()?;
|
.interact_text()?;
|
||||||
|
|
||||||
let allowed_users = if users_str.trim() == "*" {
|
let allowed_users = if users_str.trim() == "*" {
|
||||||
vec!["*".into()]
|
vec!["*".into()]
|
||||||
} else {
|
} else {
|
||||||
users_str.split(',').map(|s| s.trim().to_string()).collect()
|
users_str
|
||||||
|
.split(',')
|
||||||
|
.map(|s| s.trim().to_string())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.collect()
|
||||||
};
|
};
|
||||||
|
|
||||||
|
if allowed_users.is_empty() {
|
||||||
|
println!(
|
||||||
|
" {} No users allowlisted — Telegram inbound messages will be denied until you add your username/user ID or '*'.",
|
||||||
|
style("⚠").yellow().bold()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
config.telegram = Some(TelegramConfig {
|
config.telegram = Some(TelegramConfig {
|
||||||
bot_token: token,
|
bot_token: token,
|
||||||
allowed_users,
|
allowed_users,
|
||||||
|
|
@ -1111,9 +1212,15 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
.allow_empty(true)
|
.allow_empty(true)
|
||||||
.interact_text()?;
|
.interact_text()?;
|
||||||
|
|
||||||
|
print_bullet("Allowlist your own Discord user ID first (recommended).");
|
||||||
|
print_bullet(
|
||||||
|
"Get it in Discord: Settings -> Advanced -> Developer Mode (ON), then right-click your profile -> Copy User ID.",
|
||||||
|
);
|
||||||
|
print_bullet("Use '*' only for temporary open testing.");
|
||||||
|
|
||||||
let allowed_users_str: String = Input::new()
|
let allowed_users_str: String = Input::new()
|
||||||
.with_prompt(
|
.with_prompt(
|
||||||
" Allowed Discord user IDs (comma-separated, '*' for all, Enter to deny all)",
|
" Allowed Discord user IDs (comma-separated, recommended: your own ID, '*' for all)",
|
||||||
)
|
)
|
||||||
.allow_empty(true)
|
.allow_empty(true)
|
||||||
.interact_text()?;
|
.interact_text()?;
|
||||||
|
|
@ -1214,9 +1321,15 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
.allow_empty(true)
|
.allow_empty(true)
|
||||||
.interact_text()?;
|
.interact_text()?;
|
||||||
|
|
||||||
|
print_bullet("Allowlist your own Slack member ID first (recommended).");
|
||||||
|
print_bullet(
|
||||||
|
"Member IDs usually start with 'U' (open your Slack profile -> More -> Copy member ID).",
|
||||||
|
);
|
||||||
|
print_bullet("Use '*' only for temporary open testing.");
|
||||||
|
|
||||||
let allowed_users_str: String = Input::new()
|
let allowed_users_str: String = Input::new()
|
||||||
.with_prompt(
|
.with_prompt(
|
||||||
" Allowed Slack user IDs (comma-separated, '*' for all, Enter to deny all)",
|
" Allowed Slack user IDs (comma-separated, recommended: your own member ID, '*' for all)",
|
||||||
)
|
)
|
||||||
.allow_empty(true)
|
.allow_empty(true)
|
||||||
.interact_text()?;
|
.interact_text()?;
|
||||||
|
|
@ -1378,6 +1491,90 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
5 => {
|
5 => {
|
||||||
|
// ── WhatsApp ──
|
||||||
|
println!();
|
||||||
|
println!(
|
||||||
|
" {} {}",
|
||||||
|
style("WhatsApp Setup").white().bold(),
|
||||||
|
style("— Business Cloud API").dim()
|
||||||
|
);
|
||||||
|
print_bullet("1. Go to developers.facebook.com and create a WhatsApp app");
|
||||||
|
print_bullet("2. Add the WhatsApp product and get your phone number ID");
|
||||||
|
print_bullet("3. Generate a temporary access token (System User)");
|
||||||
|
print_bullet("4. Configure webhook URL to: https://your-domain/whatsapp");
|
||||||
|
println!();
|
||||||
|
|
||||||
|
let access_token: String = Input::new()
|
||||||
|
.with_prompt(" Access token (from Meta Developers)")
|
||||||
|
.interact_text()?;
|
||||||
|
|
||||||
|
if access_token.trim().is_empty() {
|
||||||
|
println!(" {} Skipped", style("→").dim());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let phone_number_id: String = Input::new()
|
||||||
|
.with_prompt(" Phone number ID (from WhatsApp app settings)")
|
||||||
|
.interact_text()?;
|
||||||
|
|
||||||
|
if phone_number_id.trim().is_empty() {
|
||||||
|
println!(" {} Skipped — phone number ID required", style("→").dim());
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let verify_token: String = Input::new()
|
||||||
|
.with_prompt(" Webhook verify token (create your own)")
|
||||||
|
.default("zeroclaw-whatsapp-verify".into())
|
||||||
|
.interact_text()?;
|
||||||
|
|
||||||
|
// Test connection
|
||||||
|
print!(" {} Testing connection... ", style("⏳").dim());
|
||||||
|
let client = reqwest::blocking::Client::new();
|
||||||
|
let url = format!(
|
||||||
|
"https://graph.facebook.com/v18.0/{}",
|
||||||
|
phone_number_id.trim()
|
||||||
|
);
|
||||||
|
match client
|
||||||
|
.get(&url)
|
||||||
|
.header("Authorization", format!("Bearer {}", access_token.trim()))
|
||||||
|
.send()
|
||||||
|
{
|
||||||
|
Ok(resp) if resp.status().is_success() => {
|
||||||
|
println!(
|
||||||
|
"\r {} Connected to WhatsApp API ",
|
||||||
|
style("✅").green().bold()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
_ => {
|
||||||
|
println!(
|
||||||
|
"\r {} Connection failed — check access token and phone number ID",
|
||||||
|
style("❌").red().bold()
|
||||||
|
);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let users_str: String = Input::new()
|
||||||
|
.with_prompt(
|
||||||
|
" Allowed phone numbers (comma-separated +1234567890, or * for all)",
|
||||||
|
)
|
||||||
|
.default("*".into())
|
||||||
|
.interact_text()?;
|
||||||
|
|
||||||
|
let allowed_numbers = if users_str.trim() == "*" {
|
||||||
|
vec!["*".into()]
|
||||||
|
} else {
|
||||||
|
users_str.split(',').map(|s| s.trim().to_string()).collect()
|
||||||
|
};
|
||||||
|
|
||||||
|
config.whatsapp = Some(WhatsAppConfig {
|
||||||
|
access_token: access_token.trim().to_string(),
|
||||||
|
phone_number_id: phone_number_id.trim().to_string(),
|
||||||
|
verify_token: verify_token.trim().to_string(),
|
||||||
|
allowed_numbers,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
6 => {
|
||||||
// ── Webhook ──
|
// ── Webhook ──
|
||||||
println!();
|
println!();
|
||||||
println!(
|
println!(
|
||||||
|
|
@ -1432,6 +1629,9 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
if config.matrix.is_some() {
|
if config.matrix.is_some() {
|
||||||
active.push("Matrix");
|
active.push("Matrix");
|
||||||
}
|
}
|
||||||
|
if config.whatsapp.is_some() {
|
||||||
|
active.push("WhatsApp");
|
||||||
|
}
|
||||||
if config.webhook.is_some() {
|
if config.webhook.is_some() {
|
||||||
active.push("Webhook");
|
active.push("Webhook");
|
||||||
}
|
}
|
||||||
|
|
@ -1618,7 +1818,7 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
|
||||||
&ctx.timezone
|
&ctx.timezone
|
||||||
};
|
};
|
||||||
let comm_style = if ctx.communication_style.is_empty() {
|
let comm_style = if ctx.communication_style.is_empty() {
|
||||||
"Adapt to the situation. Be concise when needed, thorough when it matters."
|
"Be warm, natural, and clear. Use occasional relevant emojis (1-2 max) and avoid robotic phrasing."
|
||||||
} else {
|
} else {
|
||||||
&ctx.communication_style
|
&ctx.communication_style
|
||||||
};
|
};
|
||||||
|
|
@ -1667,6 +1867,14 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
|
||||||
## Tools & Skills\n\n\
|
## Tools & Skills\n\n\
|
||||||
Skills are listed in the system prompt. Use `read` on a skill's SKILL.md for details.\n\
|
Skills are listed in the system prompt. Use `read` on a skill's SKILL.md for details.\n\
|
||||||
Keep local notes (SSH hosts, device names, etc.) in `TOOLS.md`.\n\n\
|
Keep local notes (SSH hosts, device names, etc.) in `TOOLS.md`.\n\n\
|
||||||
|
## Crash Recovery\n\n\
|
||||||
|
- If a run stops unexpectedly, recover context before acting.\n\
|
||||||
|
- Check `MEMORY.md` + latest `memory/*.md` notes to avoid duplicate work.\n\
|
||||||
|
- Resume from the last confirmed step, not from scratch.\n\n\
|
||||||
|
## Sub-task Scoping\n\n\
|
||||||
|
- Break complex work into focused sub-tasks with clear success criteria.\n\
|
||||||
|
- Keep sub-tasks small, verify each output, then merge results.\n\
|
||||||
|
- Prefer one clear objective per sub-task over broad \"do everything\" asks.\n\n\
|
||||||
## Make It Yours\n\n\
|
## Make It Yours\n\n\
|
||||||
This is a starting point. Add your own conventions, style, and rules.\n"
|
This is a starting point. Add your own conventions, style, and rules.\n"
|
||||||
);
|
);
|
||||||
|
|
@ -1704,6 +1912,11 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
|
||||||
- Always introduce yourself as {agent} if asked\n\n\
|
- Always introduce yourself as {agent} if asked\n\n\
|
||||||
## Communication\n\n\
|
## Communication\n\n\
|
||||||
{comm_style}\n\n\
|
{comm_style}\n\n\
|
||||||
|
- Sound like a real person, not a support script.\n\
|
||||||
|
- Mirror the user's energy: calm when serious, upbeat when casual.\n\
|
||||||
|
- Use emojis naturally (0-2 max when they help tone, not every sentence).\n\
|
||||||
|
- Match emoji density to the user. Formal user => minimal/no emojis.\n\
|
||||||
|
- Prefer specific, grounded phrasing over generic filler.\n\n\
|
||||||
## Boundaries\n\n\
|
## Boundaries\n\n\
|
||||||
- Private things stay private. Period.\n\
|
- Private things stay private. Period.\n\
|
||||||
- When in doubt, ask before acting externally.\n\
|
- When in doubt, ask before acting externally.\n\
|
||||||
|
|
@ -1744,11 +1957,23 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()>
|
||||||
- Anything environment-specific\n\n\
|
- Anything environment-specific\n\n\
|
||||||
## Built-in Tools\n\n\
|
## Built-in Tools\n\n\
|
||||||
- **shell** — Execute terminal commands\n\
|
- **shell** — Execute terminal commands\n\
|
||||||
|
- Use when: running local checks, build/test commands, or diagnostics.\n\
|
||||||
|
- Don't use when: a safer dedicated tool exists, or command is destructive without approval.\n\
|
||||||
- **file_read** — Read file contents\n\
|
- **file_read** — Read file contents\n\
|
||||||
|
- Use when: inspecting project files, configs, or logs.\n\
|
||||||
|
- Don't use when: you only need a quick string search (prefer targeted search first).\n\
|
||||||
- **file_write** — Write file contents\n\
|
- **file_write** — Write file contents\n\
|
||||||
|
- Use when: applying focused edits, scaffolding files, or updating docs/code.\n\
|
||||||
|
- Don't use when: unsure about side effects or when the file should remain user-owned.\n\
|
||||||
- **memory_store** — Save to memory\n\
|
- **memory_store** — Save to memory\n\
|
||||||
|
- Use when: preserving durable preferences, decisions, or key context.\n\
|
||||||
|
- Don't use when: info is transient, noisy, or sensitive without explicit need.\n\
|
||||||
- **memory_recall** — Search memory\n\
|
- **memory_recall** — Search memory\n\
|
||||||
- **memory_forget** — Delete a memory entry\n\n\
|
- Use when: you need prior decisions, user preferences, or historical context.\n\
|
||||||
|
- Don't use when: the answer is already in current files/conversation.\n\
|
||||||
|
- **memory_forget** — Delete a memory entry\n\
|
||||||
|
- Use when: memory is incorrect, stale, or explicitly requested to be removed.\n\
|
||||||
|
- Don't use when: uncertain about impact; verify before deleting.\n\n\
|
||||||
---\n\
|
---\n\
|
||||||
*Add whatever helps you do your job. This is your cheat sheet.*\n";
|
*Add whatever helps you do your job. This is your cheat sheet.*\n";
|
||||||
|
|
||||||
|
|
@ -2242,7 +2467,7 @@ mod tests {
|
||||||
|
|
||||||
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
|
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
soul.contains("Adapt to the situation"),
|
soul.contains("Be warm, natural, and clear."),
|
||||||
"should default communication style"
|
"should default communication style"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
@ -2383,6 +2608,31 @@ mod tests {
|
||||||
"TOOLS.md should list built-in tool: {tool}"
|
"TOOLS.md should list built-in tool: {tool}"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
assert!(
|
||||||
|
tools.contains("Use when:"),
|
||||||
|
"TOOLS.md should include 'Use when' guidance"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
tools.contains("Don't use when:"),
|
||||||
|
"TOOLS.md should include 'Don't use when' guidance"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn soul_md_includes_emoji_awareness_guidance() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let ctx = ProjectContext::default();
|
||||||
|
scaffold_workspace(tmp.path(), &ctx).unwrap();
|
||||||
|
|
||||||
|
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
|
||||||
|
assert!(
|
||||||
|
soul.contains("Use emojis naturally (0-2 max"),
|
||||||
|
"SOUL.md should include emoji usage guidance"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
soul.contains("Match emoji density to the user"),
|
||||||
|
"SOUL.md should include emoji-awareness guidance"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── scaffold_workspace: special characters in names ─────────
|
// ── scaffold_workspace: special characters in names ─────────
|
||||||
|
|
@ -2414,7 +2664,9 @@ mod tests {
|
||||||
user_name: "Argenis".into(),
|
user_name: "Argenis".into(),
|
||||||
timezone: "US/Eastern".into(),
|
timezone: "US/Eastern".into(),
|
||||||
agent_name: "Claw".into(),
|
agent_name: "Claw".into(),
|
||||||
communication_style: "Be friendly and casual. Warm but efficient.".into(),
|
communication_style:
|
||||||
|
"Be friendly, human, and conversational. Show warmth and empathy while staying efficient. Use natural contractions."
|
||||||
|
.into(),
|
||||||
};
|
};
|
||||||
scaffold_workspace(tmp.path(), &ctx).unwrap();
|
scaffold_workspace(tmp.path(), &ctx).unwrap();
|
||||||
|
|
||||||
|
|
@ -2424,12 +2676,12 @@ mod tests {
|
||||||
|
|
||||||
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
|
let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap();
|
||||||
assert!(soul.contains("You are **Claw**"));
|
assert!(soul.contains("You are **Claw**"));
|
||||||
assert!(soul.contains("Be friendly and casual"));
|
assert!(soul.contains("Be friendly, human, and conversational"));
|
||||||
|
|
||||||
let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap();
|
let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap();
|
||||||
assert!(user_md.contains("**Name:** Argenis"));
|
assert!(user_md.contains("**Name:** Argenis"));
|
||||||
assert!(user_md.contains("**Timezone:** US/Eastern"));
|
assert!(user_md.contains("**Timezone:** US/Eastern"));
|
||||||
assert!(user_md.contains("Be friendly and casual"));
|
assert!(user_md.contains("Be friendly, human, and conversational"));
|
||||||
|
|
||||||
let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap();
|
let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap();
|
||||||
assert!(agents.contains("Claw Personal Assistant"));
|
assert!(agents.contains("Claw Personal Assistant"));
|
||||||
|
|
|
||||||
|
|
@ -4,11 +4,13 @@ pub mod gemini;
|
||||||
pub mod ollama;
|
pub mod ollama;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
pub mod openrouter;
|
pub mod openrouter;
|
||||||
|
pub mod reliable;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
||||||
pub use traits::Provider;
|
pub use traits::Provider;
|
||||||
|
|
||||||
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
use compatible::{AuthStyle, OpenAiCompatibleProvider};
|
||||||
|
use reliable::ReliableProvider;
|
||||||
|
|
||||||
/// Factory: create the right provider from config
|
/// Factory: create the right provider from config
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
|
|
@ -114,6 +116,42 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Create provider chain with retry and fallback behavior.
|
||||||
|
pub fn create_resilient_provider(
|
||||||
|
primary_name: &str,
|
||||||
|
api_key: Option<&str>,
|
||||||
|
reliability: &crate::config::ReliabilityConfig,
|
||||||
|
) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
|
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||||
|
|
||||||
|
providers.push((
|
||||||
|
primary_name.to_string(),
|
||||||
|
create_provider(primary_name, api_key)?,
|
||||||
|
));
|
||||||
|
|
||||||
|
for fallback in &reliability.fallback_providers {
|
||||||
|
if fallback == primary_name || providers.iter().any(|(name, _)| name == fallback) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
match create_provider(fallback, api_key) {
|
||||||
|
Ok(provider) => providers.push((fallback.clone(), provider)),
|
||||||
|
Err(e) => {
|
||||||
|
tracing::warn!(
|
||||||
|
fallback_provider = fallback,
|
||||||
|
"Ignoring invalid fallback provider: {e}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Box::new(ReliableProvider::new(
|
||||||
|
providers,
|
||||||
|
reliability.provider_retries,
|
||||||
|
reliability.provider_backoff_ms,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -307,6 +345,34 @@ mod tests {
|
||||||
assert!(create_provider("", None).is_err());
|
assert!(create_provider("", None).is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resilient_provider_ignores_duplicate_and_invalid_fallbacks() {
|
||||||
|
let reliability = crate::config::ReliabilityConfig {
|
||||||
|
provider_retries: 1,
|
||||||
|
provider_backoff_ms: 100,
|
||||||
|
fallback_providers: vec![
|
||||||
|
"openrouter".into(),
|
||||||
|
"nonexistent-provider".into(),
|
||||||
|
"openai".into(),
|
||||||
|
"openai".into(),
|
||||||
|
],
|
||||||
|
channel_initial_backoff_secs: 2,
|
||||||
|
channel_max_backoff_secs: 60,
|
||||||
|
scheduler_poll_secs: 15,
|
||||||
|
scheduler_retries: 2,
|
||||||
|
};
|
||||||
|
|
||||||
|
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
|
||||||
|
assert!(provider.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn resilient_provider_errors_for_invalid_primary() {
|
||||||
|
let reliability = crate::config::ReliabilityConfig::default();
|
||||||
|
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
|
||||||
|
assert!(provider.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_all_providers_create_successfully() {
|
fn factory_all_providers_create_successfully() {
|
||||||
let providers = [
|
let providers = [
|
||||||
|
|
|
||||||
229
src/providers/reliable.rs
Normal file
229
src/providers/reliable.rs
Normal file
|
|
@ -0,0 +1,229 @@
|
||||||
|
use super::Provider;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
/// Provider wrapper with retry + fallback behavior.
|
||||||
|
pub struct ReliableProvider {
|
||||||
|
providers: Vec<(String, Box<dyn Provider>)>,
|
||||||
|
max_retries: u32,
|
||||||
|
base_backoff_ms: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ReliableProvider {
|
||||||
|
pub fn new(
|
||||||
|
providers: Vec<(String, Box<dyn Provider>)>,
|
||||||
|
max_retries: u32,
|
||||||
|
base_backoff_ms: u64,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
providers,
|
||||||
|
max_retries,
|
||||||
|
base_backoff_ms: base_backoff_ms.max(50),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for ReliableProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
message: &str,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let mut failures = Vec::new();
|
||||||
|
|
||||||
|
for (provider_name, provider) in &self.providers {
|
||||||
|
let mut backoff_ms = self.base_backoff_ms;
|
||||||
|
|
||||||
|
for attempt in 0..=self.max_retries {
|
||||||
|
match provider
|
||||||
|
.chat_with_system(system_prompt, message, model, temperature)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(resp) => {
|
||||||
|
if attempt > 0 {
|
||||||
|
tracing::info!(
|
||||||
|
provider = provider_name,
|
||||||
|
attempt,
|
||||||
|
"Provider recovered after retries"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
return Ok(resp);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
failures.push(format!(
|
||||||
|
"{provider_name} attempt {}/{}: {e}",
|
||||||
|
attempt + 1,
|
||||||
|
self.max_retries + 1
|
||||||
|
));
|
||||||
|
|
||||||
|
if attempt < self.max_retries {
|
||||||
|
tracing::warn!(
|
||||||
|
provider = provider_name,
|
||||||
|
attempt = attempt + 1,
|
||||||
|
max_retries = self.max_retries,
|
||||||
|
"Provider call failed, retrying"
|
||||||
|
);
|
||||||
|
tokio::time::sleep(Duration::from_millis(backoff_ms)).await;
|
||||||
|
backoff_ms = (backoff_ms.saturating_mul(2)).min(10_000);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::warn!(provider = provider_name, "Switching to fallback provider");
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("All providers failed. Attempts:\n{}", failures.join("\n"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use std::sync::atomic::{AtomicUsize, Ordering};
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
struct MockProvider {
|
||||||
|
calls: Arc<AtomicUsize>,
|
||||||
|
fail_until_attempt: usize,
|
||||||
|
response: &'static str,
|
||||||
|
error: &'static str,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for MockProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let attempt = self.calls.fetch_add(1, Ordering::SeqCst) + 1;
|
||||||
|
if attempt <= self.fail_until_attempt {
|
||||||
|
anyhow::bail!(self.error);
|
||||||
|
}
|
||||||
|
Ok(self.response.to_string())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn succeeds_without_retry() {
|
||||||
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let provider = ReliableProvider::new(
|
||||||
|
vec![(
|
||||||
|
"primary".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&calls),
|
||||||
|
fail_until_attempt: 0,
|
||||||
|
response: "ok",
|
||||||
|
error: "boom",
|
||||||
|
}),
|
||||||
|
)],
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
||||||
|
assert_eq!(result, "ok");
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn retries_then_recovers() {
|
||||||
|
let calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let provider = ReliableProvider::new(
|
||||||
|
vec![(
|
||||||
|
"primary".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&calls),
|
||||||
|
fail_until_attempt: 1,
|
||||||
|
response: "recovered",
|
||||||
|
error: "temporary",
|
||||||
|
}),
|
||||||
|
)],
|
||||||
|
2,
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
||||||
|
assert_eq!(result, "recovered");
|
||||||
|
assert_eq!(calls.load(Ordering::SeqCst), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn falls_back_after_retries_exhausted() {
|
||||||
|
let primary_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
let fallback_calls = Arc::new(AtomicUsize::new(0));
|
||||||
|
|
||||||
|
let provider = ReliableProvider::new(
|
||||||
|
vec![
|
||||||
|
(
|
||||||
|
"primary".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&primary_calls),
|
||||||
|
fail_until_attempt: usize::MAX,
|
||||||
|
response: "never",
|
||||||
|
error: "primary down",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"fallback".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::clone(&fallback_calls),
|
||||||
|
fail_until_attempt: 0,
|
||||||
|
response: "from fallback",
|
||||||
|
error: "fallback down",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
1,
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = provider.chat("hello", "test", 0.0).await.unwrap();
|
||||||
|
assert_eq!(result, "from fallback");
|
||||||
|
assert_eq!(primary_calls.load(Ordering::SeqCst), 2);
|
||||||
|
assert_eq!(fallback_calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn returns_aggregated_error_when_all_providers_fail() {
|
||||||
|
let provider = ReliableProvider::new(
|
||||||
|
vec![
|
||||||
|
(
|
||||||
|
"p1".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::new(AtomicUsize::new(0)),
|
||||||
|
fail_until_attempt: usize::MAX,
|
||||||
|
response: "never",
|
||||||
|
error: "p1 error",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"p2".into(),
|
||||||
|
Box::new(MockProvider {
|
||||||
|
calls: Arc::new(AtomicUsize::new(0)),
|
||||||
|
fail_until_attempt: usize::MAX,
|
||||||
|
response: "never",
|
||||||
|
error: "p2 error",
|
||||||
|
}),
|
||||||
|
),
|
||||||
|
],
|
||||||
|
0,
|
||||||
|
1,
|
||||||
|
);
|
||||||
|
|
||||||
|
let err = provider
|
||||||
|
.chat("hello", "test", 0.0)
|
||||||
|
.await
|
||||||
|
.expect_err("all providers should fail");
|
||||||
|
let msg = err.to_string();
|
||||||
|
assert!(msg.contains("All providers failed"));
|
||||||
|
assert!(msg.contains("p1 attempt 1/1"));
|
||||||
|
assert!(msg.contains("p2 attempt 1/1"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -7,17 +7,21 @@ pub use traits::RuntimeAdapter;
|
||||||
use crate::config::RuntimeConfig;
|
use crate::config::RuntimeConfig;
|
||||||
|
|
||||||
/// Factory: create the right runtime from config
|
/// Factory: create the right runtime from config
|
||||||
pub fn create_runtime(config: &RuntimeConfig) -> Box<dyn RuntimeAdapter> {
|
pub fn create_runtime(config: &RuntimeConfig) -> anyhow::Result<Box<dyn RuntimeAdapter>> {
|
||||||
match config.kind.as_str() {
|
match config.kind.as_str() {
|
||||||
"native" | "docker" => Box::new(NativeRuntime::new()),
|
"native" => Ok(Box::new(NativeRuntime::new())),
|
||||||
"cloudflare" => {
|
"docker" => anyhow::bail!(
|
||||||
tracing::warn!("Cloudflare runtime not yet implemented, falling back to native");
|
"runtime.kind='docker' is not implemented yet. Use runtime.kind='native' until container runtime support lands."
|
||||||
Box::new(NativeRuntime::new())
|
),
|
||||||
}
|
"cloudflare" => anyhow::bail!(
|
||||||
_ => {
|
"runtime.kind='cloudflare' is not implemented yet. Use runtime.kind='native' for now."
|
||||||
tracing::warn!("Unknown runtime '{}', falling back to native", config.kind);
|
),
|
||||||
Box::new(NativeRuntime::new())
|
other if other.trim().is_empty() => anyhow::bail!(
|
||||||
}
|
"runtime.kind cannot be empty. Supported values: native"
|
||||||
|
),
|
||||||
|
other => anyhow::bail!(
|
||||||
|
"Unknown runtime kind '{other}'. Supported values: native"
|
||||||
|
),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -30,44 +34,52 @@ mod tests {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "native".into(),
|
kind: "native".into(),
|
||||||
};
|
};
|
||||||
let rt = create_runtime(&cfg);
|
let rt = create_runtime(&cfg).unwrap();
|
||||||
assert_eq!(rt.name(), "native");
|
assert_eq!(rt.name(), "native");
|
||||||
assert!(rt.has_shell_access());
|
assert!(rt.has_shell_access());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_docker_returns_native() {
|
fn factory_docker_errors() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "docker".into(),
|
kind: "docker".into(),
|
||||||
};
|
};
|
||||||
let rt = create_runtime(&cfg);
|
match create_runtime(&cfg) {
|
||||||
assert_eq!(rt.name(), "native");
|
Err(err) => assert!(err.to_string().contains("not implemented")),
|
||||||
|
Ok(_) => panic!("docker runtime should error"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_cloudflare_falls_back() {
|
fn factory_cloudflare_errors() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "cloudflare".into(),
|
kind: "cloudflare".into(),
|
||||||
};
|
};
|
||||||
let rt = create_runtime(&cfg);
|
match create_runtime(&cfg) {
|
||||||
assert_eq!(rt.name(), "native");
|
Err(err) => assert!(err.to_string().contains("not implemented")),
|
||||||
|
Ok(_) => panic!("cloudflare runtime should error"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_unknown_falls_back() {
|
fn factory_unknown_errors() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: "wasm-edge-unknown".into(),
|
kind: "wasm-edge-unknown".into(),
|
||||||
};
|
};
|
||||||
let rt = create_runtime(&cfg);
|
match create_runtime(&cfg) {
|
||||||
assert_eq!(rt.name(), "native");
|
Err(err) => assert!(err.to_string().contains("Unknown runtime kind")),
|
||||||
|
Ok(_) => panic!("unknown runtime should error"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_empty_falls_back() {
|
fn factory_empty_errors() {
|
||||||
let cfg = RuntimeConfig {
|
let cfg = RuntimeConfig {
|
||||||
kind: String::new(),
|
kind: String::new(),
|
||||||
};
|
};
|
||||||
let rt = create_runtime(&cfg);
|
match create_runtime(&cfg) {
|
||||||
assert_eq!(rt.name(), "native");
|
Err(err) => assert!(err.to_string().contains("cannot be empty")),
|
||||||
|
Ok(_) => panic!("empty runtime should error"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -258,8 +258,14 @@ impl SecurityPolicy {
|
||||||
/// Validate that a resolved path is still inside the workspace.
|
/// Validate that a resolved path is still inside the workspace.
|
||||||
/// Call this AFTER joining `workspace_dir` + relative path and canonicalizing.
|
/// Call this AFTER joining `workspace_dir` + relative path and canonicalizing.
|
||||||
pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool {
|
pub fn is_resolved_path_allowed(&self, resolved: &Path) -> bool {
|
||||||
// Must be under workspace_dir (prevents symlink escapes)
|
// Must be under workspace_dir (prevents symlink escapes).
|
||||||
resolved.starts_with(&self.workspace_dir)
|
// Prefer canonical workspace root so `/a/../b` style config paths don't
|
||||||
|
// cause false positives or negatives.
|
||||||
|
let workspace_root = self
|
||||||
|
.workspace_dir
|
||||||
|
.canonicalize()
|
||||||
|
.unwrap_or_else(|_| self.workspace_dir.clone());
|
||||||
|
resolved.starts_with(workspace_root)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Check if autonomy level permits any action at all
|
/// Check if autonomy level permits any action at all
|
||||||
|
|
|
||||||
|
|
@ -79,45 +79,94 @@ impl SecretStore {
|
||||||
/// - `enc2:` prefix → ChaCha20-Poly1305 (current format)
|
/// - `enc2:` prefix → ChaCha20-Poly1305 (current format)
|
||||||
/// - `enc:` prefix → legacy XOR cipher (backward compatibility for migration)
|
/// - `enc:` prefix → legacy XOR cipher (backward compatibility for migration)
|
||||||
/// - No prefix → returned as-is (plaintext config)
|
/// - No prefix → returned as-is (plaintext config)
|
||||||
|
///
|
||||||
|
/// **Warning**: Legacy `enc:` values are insecure. Use `decrypt_and_migrate` to
|
||||||
|
/// automatically upgrade them to the secure `enc2:` format.
|
||||||
pub fn decrypt(&self, value: &str) -> Result<String> {
|
pub fn decrypt(&self, value: &str) -> Result<String> {
|
||||||
if let Some(hex_str) = value.strip_prefix("enc2:") {
|
if let Some(hex_str) = value.strip_prefix("enc2:") {
|
||||||
let blob =
|
self.decrypt_chacha20(hex_str)
|
||||||
hex_decode(hex_str).context("Failed to decode encrypted secret (corrupt hex)")?;
|
|
||||||
anyhow::ensure!(
|
|
||||||
blob.len() > NONCE_LEN,
|
|
||||||
"Encrypted value too short (missing nonce)"
|
|
||||||
);
|
|
||||||
|
|
||||||
let (nonce_bytes, ciphertext) = blob.split_at(NONCE_LEN);
|
|
||||||
let nonce = Nonce::from_slice(nonce_bytes);
|
|
||||||
let key_bytes = self.load_or_create_key()?;
|
|
||||||
let key = Key::from_slice(&key_bytes);
|
|
||||||
let cipher = ChaCha20Poly1305::new(key);
|
|
||||||
|
|
||||||
let plaintext_bytes = cipher
|
|
||||||
.decrypt(nonce, ciphertext)
|
|
||||||
.map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or tampered data"))?;
|
|
||||||
|
|
||||||
String::from_utf8(plaintext_bytes)
|
|
||||||
.context("Decrypted secret is not valid UTF-8 — corrupt data")
|
|
||||||
} else if let Some(hex_str) = value.strip_prefix("enc:") {
|
} else if let Some(hex_str) = value.strip_prefix("enc:") {
|
||||||
// Legacy XOR cipher — decrypt for backward compatibility
|
self.decrypt_legacy_xor(hex_str)
|
||||||
let ciphertext = hex_decode(hex_str)
|
|
||||||
.context("Failed to decode legacy encrypted secret (corrupt hex)")?;
|
|
||||||
let key = self.load_or_create_key()?;
|
|
||||||
let plaintext_bytes = xor_cipher(&ciphertext, &key);
|
|
||||||
String::from_utf8(plaintext_bytes)
|
|
||||||
.context("Decrypted legacy secret is not valid UTF-8 — wrong key or corrupt data")
|
|
||||||
} else {
|
} else {
|
||||||
Ok(value.to_string())
|
Ok(value.to_string())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Decrypt a secret and return a migrated `enc2:` value if the input used legacy `enc:` format.
|
||||||
|
///
|
||||||
|
/// Returns `(plaintext, Some(new_enc2_value))` if migration occurred, or
|
||||||
|
/// `(plaintext, None)` if no migration was needed.
|
||||||
|
///
|
||||||
|
/// This allows callers to persist the upgraded value back to config.
|
||||||
|
pub fn decrypt_and_migrate(&self, value: &str) -> Result<(String, Option<String>)> {
|
||||||
|
if let Some(hex_str) = value.strip_prefix("enc2:") {
|
||||||
|
// Already using secure format — no migration needed
|
||||||
|
let plaintext = self.decrypt_chacha20(hex_str)?;
|
||||||
|
Ok((plaintext, None))
|
||||||
|
} else if let Some(hex_str) = value.strip_prefix("enc:") {
|
||||||
|
// Legacy XOR cipher — decrypt and re-encrypt with ChaCha20-Poly1305
|
||||||
|
tracing::warn!(
|
||||||
|
"Decrypting legacy XOR-encrypted secret (enc: prefix). \
|
||||||
|
This format is insecure and will be removed in a future release. \
|
||||||
|
The secret will be automatically migrated to enc2: (ChaCha20-Poly1305)."
|
||||||
|
);
|
||||||
|
let plaintext = self.decrypt_legacy_xor(hex_str)?;
|
||||||
|
let migrated = self.encrypt(&plaintext)?;
|
||||||
|
Ok((plaintext, Some(migrated)))
|
||||||
|
} else {
|
||||||
|
// Plaintext — no migration needed
|
||||||
|
Ok((value.to_string(), None))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a value uses the legacy `enc:` format that should be migrated.
|
||||||
|
pub fn needs_migration(value: &str) -> bool {
|
||||||
|
value.starts_with("enc:")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt using ChaCha20-Poly1305 (current secure format).
|
||||||
|
fn decrypt_chacha20(&self, hex_str: &str) -> Result<String> {
|
||||||
|
let blob =
|
||||||
|
hex_decode(hex_str).context("Failed to decode encrypted secret (corrupt hex)")?;
|
||||||
|
anyhow::ensure!(
|
||||||
|
blob.len() > NONCE_LEN,
|
||||||
|
"Encrypted value too short (missing nonce)"
|
||||||
|
);
|
||||||
|
|
||||||
|
let (nonce_bytes, ciphertext) = blob.split_at(NONCE_LEN);
|
||||||
|
let nonce = Nonce::from_slice(nonce_bytes);
|
||||||
|
let key_bytes = self.load_or_create_key()?;
|
||||||
|
let key = Key::from_slice(&key_bytes);
|
||||||
|
let cipher = ChaCha20Poly1305::new(key);
|
||||||
|
|
||||||
|
let plaintext_bytes = cipher
|
||||||
|
.decrypt(nonce, ciphertext)
|
||||||
|
.map_err(|_| anyhow::anyhow!("Decryption failed — wrong key or tampered data"))?;
|
||||||
|
|
||||||
|
String::from_utf8(plaintext_bytes)
|
||||||
|
.context("Decrypted secret is not valid UTF-8 — corrupt data")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decrypt using legacy XOR cipher (insecure, for backward compatibility only).
|
||||||
|
fn decrypt_legacy_xor(&self, hex_str: &str) -> Result<String> {
|
||||||
|
let ciphertext = hex_decode(hex_str)
|
||||||
|
.context("Failed to decode legacy encrypted secret (corrupt hex)")?;
|
||||||
|
let key = self.load_or_create_key()?;
|
||||||
|
let plaintext_bytes = xor_cipher(&ciphertext, &key);
|
||||||
|
String::from_utf8(plaintext_bytes)
|
||||||
|
.context("Decrypted legacy secret is not valid UTF-8 — wrong key or corrupt data")
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if a value is already encrypted (current or legacy format).
|
/// Check if a value is already encrypted (current or legacy format).
|
||||||
pub fn is_encrypted(value: &str) -> bool {
|
pub fn is_encrypted(value: &str) -> bool {
|
||||||
value.starts_with("enc2:") || value.starts_with("enc:")
|
value.starts_with("enc2:") || value.starts_with("enc:")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check if a value uses the secure `enc2:` format.
|
||||||
|
pub fn is_secure_encrypted(value: &str) -> bool {
|
||||||
|
value.starts_with("enc2:")
|
||||||
|
}
|
||||||
|
|
||||||
/// Load the encryption key from disk, or create one if it doesn't exist.
|
/// Load the encryption key from disk, or create one if it doesn't exist.
|
||||||
fn load_or_create_key(&self) -> Result<Vec<u8>> {
|
fn load_or_create_key(&self) -> Result<Vec<u8>> {
|
||||||
if self.key_path.exists() {
|
if self.key_path.exists() {
|
||||||
|
|
@ -132,13 +181,22 @@ impl SecretStore {
|
||||||
fs::write(&self.key_path, hex_encode(&key))
|
fs::write(&self.key_path, hex_encode(&key))
|
||||||
.context("Failed to write secret key file")?;
|
.context("Failed to write secret key file")?;
|
||||||
|
|
||||||
// Set restrictive permissions (Unix only)
|
// Set restrictive permissions
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
{
|
{
|
||||||
use std::os::unix::fs::PermissionsExt;
|
use std::os::unix::fs::PermissionsExt;
|
||||||
fs::set_permissions(&self.key_path, fs::Permissions::from_mode(0o600))
|
fs::set_permissions(&self.key_path, fs::Permissions::from_mode(0o600))
|
||||||
.context("Failed to set key file permissions")?;
|
.context("Failed to set key file permissions")?;
|
||||||
}
|
}
|
||||||
|
#[cfg(windows)]
|
||||||
|
{
|
||||||
|
// On Windows, use icacls to restrict permissions to current user only
|
||||||
|
let _ = std::process::Command::new("icacls")
|
||||||
|
.arg(&self.key_path)
|
||||||
|
.args(["/inheritance:r", "/grant:r"])
|
||||||
|
.arg(format!("{}:F", std::env::var("USERNAME").unwrap_or_default()))
|
||||||
|
.output();
|
||||||
|
}
|
||||||
|
|
||||||
Ok(key)
|
Ok(key)
|
||||||
}
|
}
|
||||||
|
|
@ -382,6 +440,258 @@ mod tests {
|
||||||
assert_eq!(decrypted, plaintext, "Legacy XOR values must still decrypt");
|
assert_eq!(decrypted, plaintext, "Legacy XOR values must still decrypt");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Migration tests ─────────────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn needs_migration_detects_legacy_prefix() {
|
||||||
|
assert!(SecretStore::needs_migration("enc:aabbcc"));
|
||||||
|
assert!(!SecretStore::needs_migration("enc2:aabbcc"));
|
||||||
|
assert!(!SecretStore::needs_migration("sk-plaintext"));
|
||||||
|
assert!(!SecretStore::needs_migration(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn is_secure_encrypted_detects_enc2_only() {
|
||||||
|
assert!(SecretStore::is_secure_encrypted("enc2:aabbcc"));
|
||||||
|
assert!(!SecretStore::is_secure_encrypted("enc:aabbcc"));
|
||||||
|
assert!(!SecretStore::is_secure_encrypted("sk-plaintext"));
|
||||||
|
assert!(!SecretStore::is_secure_encrypted(""));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_returns_none_for_enc2() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
let encrypted = store.encrypt("my-secret").unwrap();
|
||||||
|
assert!(encrypted.starts_with("enc2:"));
|
||||||
|
|
||||||
|
let (plaintext, migrated) = store.decrypt_and_migrate(&encrypted).unwrap();
|
||||||
|
assert_eq!(plaintext, "my-secret");
|
||||||
|
assert!(
|
||||||
|
migrated.is_none(),
|
||||||
|
"enc2: values should not trigger migration"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_returns_none_for_plaintext() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
let (plaintext, migrated) = store.decrypt_and_migrate("sk-plaintext-key").unwrap();
|
||||||
|
assert_eq!(plaintext, "sk-plaintext-key");
|
||||||
|
assert!(
|
||||||
|
migrated.is_none(),
|
||||||
|
"Plaintext values should not trigger migration"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_upgrades_legacy_xor() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
// Create key first
|
||||||
|
let _ = store.encrypt("setup").unwrap();
|
||||||
|
let key = store.load_or_create_key().unwrap();
|
||||||
|
|
||||||
|
// Manually create a legacy XOR-encrypted value
|
||||||
|
let plaintext = "sk-legacy-secret-to-migrate";
|
||||||
|
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
|
||||||
|
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
|
||||||
|
|
||||||
|
// Verify it needs migration
|
||||||
|
assert!(SecretStore::needs_migration(&legacy_value));
|
||||||
|
|
||||||
|
// Decrypt and migrate
|
||||||
|
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
|
||||||
|
assert_eq!(decrypted, plaintext, "Plaintext must match original");
|
||||||
|
assert!(migrated.is_some(), "Legacy value should trigger migration");
|
||||||
|
|
||||||
|
let new_value = migrated.unwrap();
|
||||||
|
assert!(
|
||||||
|
new_value.starts_with("enc2:"),
|
||||||
|
"Migrated value must use enc2: prefix"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
!SecretStore::needs_migration(&new_value),
|
||||||
|
"Migrated value should not need migration"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Verify the migrated value decrypts correctly
|
||||||
|
let (decrypted2, migrated2) = store.decrypt_and_migrate(&new_value).unwrap();
|
||||||
|
assert_eq!(
|
||||||
|
decrypted2, plaintext,
|
||||||
|
"Migrated value must decrypt to same plaintext"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
migrated2.is_none(),
|
||||||
|
"Migrated value should not trigger another migration"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_handles_unicode() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
let _ = store.encrypt("setup").unwrap();
|
||||||
|
let key = store.load_or_create_key().unwrap();
|
||||||
|
|
||||||
|
let plaintext = "sk-日本語-émojis-🦀-тест";
|
||||||
|
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
|
||||||
|
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
|
||||||
|
|
||||||
|
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
|
||||||
|
assert_eq!(decrypted, plaintext);
|
||||||
|
assert!(migrated.is_some());
|
||||||
|
|
||||||
|
// Verify migrated value works
|
||||||
|
let new_value = migrated.unwrap();
|
||||||
|
let (decrypted2, _) = store.decrypt_and_migrate(&new_value).unwrap();
|
||||||
|
assert_eq!(decrypted2, plaintext);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_handles_empty_secret() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
let _ = store.encrypt("setup").unwrap();
|
||||||
|
let key = store.load_or_create_key().unwrap();
|
||||||
|
|
||||||
|
// Empty plaintext XOR-encrypted
|
||||||
|
let plaintext = "";
|
||||||
|
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
|
||||||
|
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
|
||||||
|
|
||||||
|
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
|
||||||
|
assert_eq!(decrypted, plaintext);
|
||||||
|
// Empty string encryption returns empty string (not enc2:)
|
||||||
|
assert!(migrated.is_some());
|
||||||
|
assert_eq!(migrated.unwrap(), "");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_handles_long_secret() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
let _ = store.encrypt("setup").unwrap();
|
||||||
|
let key = store.load_or_create_key().unwrap();
|
||||||
|
|
||||||
|
let plaintext = "a".repeat(10_000);
|
||||||
|
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
|
||||||
|
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
|
||||||
|
|
||||||
|
let (decrypted, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
|
||||||
|
assert_eq!(decrypted, plaintext);
|
||||||
|
assert!(migrated.is_some());
|
||||||
|
|
||||||
|
let new_value = migrated.unwrap();
|
||||||
|
let (decrypted2, _) = store.decrypt_and_migrate(&new_value).unwrap();
|
||||||
|
assert_eq!(decrypted2, plaintext);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_fails_on_corrupt_legacy_hex() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
let _ = store.encrypt("setup").unwrap();
|
||||||
|
|
||||||
|
let result = store.decrypt_and_migrate("enc:not-valid-hex!!");
|
||||||
|
assert!(result.is_err(), "Corrupt hex should fail");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn decrypt_and_migrate_wrong_key_produces_garbage_or_fails() {
|
||||||
|
let tmp1 = TempDir::new().unwrap();
|
||||||
|
let tmp2 = TempDir::new().unwrap();
|
||||||
|
let store1 = SecretStore::new(tmp1.path(), true);
|
||||||
|
let store2 = SecretStore::new(tmp2.path(), true);
|
||||||
|
|
||||||
|
// Create keys for both stores
|
||||||
|
let _ = store1.encrypt("setup").unwrap();
|
||||||
|
let _ = store2.encrypt("setup").unwrap();
|
||||||
|
let key1 = store1.load_or_create_key().unwrap();
|
||||||
|
|
||||||
|
// Encrypt with store1's key
|
||||||
|
let plaintext = "secret-for-store1";
|
||||||
|
let ciphertext = xor_cipher(plaintext.as_bytes(), &key1);
|
||||||
|
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
|
||||||
|
|
||||||
|
// Decrypt with store2 — XOR will produce garbage bytes
|
||||||
|
// This may fail with UTF-8 error or succeed with garbage plaintext
|
||||||
|
match store2.decrypt_and_migrate(&legacy_value) {
|
||||||
|
Ok((decrypted, _)) => {
|
||||||
|
// If it succeeds, the plaintext should be garbage (not the original)
|
||||||
|
assert_ne!(
|
||||||
|
decrypted, plaintext,
|
||||||
|
"Wrong key should produce garbage plaintext"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Expected: UTF-8 decoding failure from garbage bytes
|
||||||
|
assert!(
|
||||||
|
e.to_string().contains("UTF-8"),
|
||||||
|
"Error should be UTF-8 related: {e}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn migration_produces_different_ciphertext_each_time() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
let _ = store.encrypt("setup").unwrap();
|
||||||
|
let key = store.load_or_create_key().unwrap();
|
||||||
|
|
||||||
|
let plaintext = "sk-same-secret";
|
||||||
|
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
|
||||||
|
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
|
||||||
|
|
||||||
|
let (_, migrated1) = store.decrypt_and_migrate(&legacy_value).unwrap();
|
||||||
|
let (_, migrated2) = store.decrypt_and_migrate(&legacy_value).unwrap();
|
||||||
|
|
||||||
|
assert!(migrated1.is_some());
|
||||||
|
assert!(migrated2.is_some());
|
||||||
|
assert_ne!(
|
||||||
|
migrated1.unwrap(),
|
||||||
|
migrated2.unwrap(),
|
||||||
|
"Each migration should produce different ciphertext (random nonce)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn migrated_value_is_tamper_resistant() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let store = SecretStore::new(tmp.path(), true);
|
||||||
|
|
||||||
|
let _ = store.encrypt("setup").unwrap();
|
||||||
|
let key = store.load_or_create_key().unwrap();
|
||||||
|
|
||||||
|
let plaintext = "sk-sensitive-data";
|
||||||
|
let ciphertext = xor_cipher(plaintext.as_bytes(), &key);
|
||||||
|
let legacy_value = format!("enc:{}", hex_encode(&ciphertext));
|
||||||
|
|
||||||
|
let (_, migrated) = store.decrypt_and_migrate(&legacy_value).unwrap();
|
||||||
|
let new_value = migrated.unwrap();
|
||||||
|
|
||||||
|
// Tamper with the migrated value
|
||||||
|
let hex_str = &new_value[5..];
|
||||||
|
let mut blob = hex_decode(hex_str).unwrap();
|
||||||
|
if blob.len() > NONCE_LEN {
|
||||||
|
blob[NONCE_LEN] ^= 0xff;
|
||||||
|
}
|
||||||
|
let tampered = format!("enc2:{}", hex_encode(&blob));
|
||||||
|
|
||||||
|
let result = store.decrypt_and_migrate(&tampered);
|
||||||
|
assert!(result.is_err(), "Tampered migrated value must be rejected");
|
||||||
|
}
|
||||||
|
|
||||||
// ── Low-level helpers ───────────────────────────────────────
|
// ── Low-level helpers ───────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
284
src/service/mod.rs
Normal file
284
src/service/mod.rs
Normal file
|
|
@ -0,0 +1,284 @@
|
||||||
|
use crate::config::Config;
|
||||||
|
use anyhow::{Context, Result};
|
||||||
|
use std::fs;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::process::Command;
|
||||||
|
|
||||||
|
const SERVICE_LABEL: &str = "com.zeroclaw.daemon";
|
||||||
|
|
||||||
|
pub fn handle_command(command: super::ServiceCommands, config: &Config) -> Result<()> {
|
||||||
|
match command {
|
||||||
|
super::ServiceCommands::Install => install(config),
|
||||||
|
super::ServiceCommands::Start => start(config),
|
||||||
|
super::ServiceCommands::Stop => stop(config),
|
||||||
|
super::ServiceCommands::Status => status(config),
|
||||||
|
super::ServiceCommands::Uninstall => uninstall(config),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn install(config: &Config) -> Result<()> {
|
||||||
|
if cfg!(target_os = "macos") {
|
||||||
|
install_macos(config)
|
||||||
|
} else if cfg!(target_os = "linux") {
|
||||||
|
install_linux(config)
|
||||||
|
} else {
|
||||||
|
anyhow::bail!("Service management is supported on macOS and Linux only");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn start(config: &Config) -> Result<()> {
|
||||||
|
if cfg!(target_os = "macos") {
|
||||||
|
let plist = macos_service_file()?;
|
||||||
|
run_checked(Command::new("launchctl").arg("load").arg("-w").arg(&plist))?;
|
||||||
|
run_checked(Command::new("launchctl").arg("start").arg(SERVICE_LABEL))?;
|
||||||
|
println!("✅ Service started");
|
||||||
|
Ok(())
|
||||||
|
} else if cfg!(target_os = "linux") {
|
||||||
|
run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]))?;
|
||||||
|
run_checked(Command::new("systemctl").args(["--user", "start", "zeroclaw.service"]))?;
|
||||||
|
println!("✅ Service started");
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
let _ = config;
|
||||||
|
anyhow::bail!("Service management is supported on macOS and Linux only")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn stop(config: &Config) -> Result<()> {
|
||||||
|
if cfg!(target_os = "macos") {
|
||||||
|
let plist = macos_service_file()?;
|
||||||
|
let _ = run_checked(Command::new("launchctl").arg("stop").arg(SERVICE_LABEL));
|
||||||
|
let _ = run_checked(
|
||||||
|
Command::new("launchctl")
|
||||||
|
.arg("unload")
|
||||||
|
.arg("-w")
|
||||||
|
.arg(&plist),
|
||||||
|
);
|
||||||
|
println!("✅ Service stopped");
|
||||||
|
Ok(())
|
||||||
|
} else if cfg!(target_os = "linux") {
|
||||||
|
let _ = run_checked(Command::new("systemctl").args(["--user", "stop", "zeroclaw.service"]));
|
||||||
|
println!("✅ Service stopped");
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
let _ = config;
|
||||||
|
anyhow::bail!("Service management is supported on macOS and Linux only")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn status(config: &Config) -> Result<()> {
|
||||||
|
if cfg!(target_os = "macos") {
|
||||||
|
let out = run_capture(Command::new("launchctl").arg("list"))?;
|
||||||
|
let running = out.lines().any(|line| line.contains(SERVICE_LABEL));
|
||||||
|
println!(
|
||||||
|
"Service: {}",
|
||||||
|
if running {
|
||||||
|
"✅ running/loaded"
|
||||||
|
} else {
|
||||||
|
"❌ not loaded"
|
||||||
|
}
|
||||||
|
);
|
||||||
|
println!("Unit: {}", macos_service_file()?.display());
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg!(target_os = "linux") {
|
||||||
|
let out = run_capture(Command::new("systemctl").args([
|
||||||
|
"--user",
|
||||||
|
"is-active",
|
||||||
|
"zeroclaw.service",
|
||||||
|
]))
|
||||||
|
.unwrap_or_else(|_| "unknown".into());
|
||||||
|
println!("Service state: {}", out.trim());
|
||||||
|
println!("Unit: {}", linux_service_file(config)?.display());
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Service management is supported on macOS and Linux only")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn uninstall(config: &Config) -> Result<()> {
|
||||||
|
stop(config)?;
|
||||||
|
|
||||||
|
if cfg!(target_os = "macos") {
|
||||||
|
let file = macos_service_file()?;
|
||||||
|
if file.exists() {
|
||||||
|
fs::remove_file(&file)
|
||||||
|
.with_context(|| format!("Failed to remove {}", file.display()))?;
|
||||||
|
}
|
||||||
|
println!("✅ Service uninstalled ({})", file.display());
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
if cfg!(target_os = "linux") {
|
||||||
|
let file = linux_service_file(config)?;
|
||||||
|
if file.exists() {
|
||||||
|
fs::remove_file(&file)
|
||||||
|
.with_context(|| format!("Failed to remove {}", file.display()))?;
|
||||||
|
}
|
||||||
|
let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]));
|
||||||
|
println!("✅ Service uninstalled ({})", file.display());
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Service management is supported on macOS and Linux only")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn install_macos(config: &Config) -> Result<()> {
|
||||||
|
let file = macos_service_file()?;
|
||||||
|
if let Some(parent) = file.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let exe = std::env::current_exe().context("Failed to resolve current executable")?;
|
||||||
|
let logs_dir = config
|
||||||
|
.config_path
|
||||||
|
.parent()
|
||||||
|
.map_or_else(|| PathBuf::from("."), PathBuf::from)
|
||||||
|
.join("logs");
|
||||||
|
fs::create_dir_all(&logs_dir)?;
|
||||||
|
|
||||||
|
let stdout = logs_dir.join("daemon.stdout.log");
|
||||||
|
let stderr = logs_dir.join("daemon.stderr.log");
|
||||||
|
|
||||||
|
let plist = format!(
|
||||||
|
r#"<?xml version=\"1.0\" encoding=\"UTF-8\"?>
|
||||||
|
<!DOCTYPE plist PUBLIC \"-//Apple//DTD PLIST 1.0//EN\" \"http://www.apple.com/DTDs/PropertyList-1.0.dtd\">
|
||||||
|
<plist version=\"1.0\">
|
||||||
|
<dict>
|
||||||
|
<key>Label</key>
|
||||||
|
<string>{label}</string>
|
||||||
|
<key>ProgramArguments</key>
|
||||||
|
<array>
|
||||||
|
<string>{exe}</string>
|
||||||
|
<string>daemon</string>
|
||||||
|
</array>
|
||||||
|
<key>RunAtLoad</key>
|
||||||
|
<true/>
|
||||||
|
<key>KeepAlive</key>
|
||||||
|
<true/>
|
||||||
|
<key>StandardOutPath</key>
|
||||||
|
<string>{stdout}</string>
|
||||||
|
<key>StandardErrorPath</key>
|
||||||
|
<string>{stderr}</string>
|
||||||
|
</dict>
|
||||||
|
</plist>
|
||||||
|
"#,
|
||||||
|
label = SERVICE_LABEL,
|
||||||
|
exe = xml_escape(&exe.display().to_string()),
|
||||||
|
stdout = xml_escape(&stdout.display().to_string()),
|
||||||
|
stderr = xml_escape(&stderr.display().to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
fs::write(&file, plist)?;
|
||||||
|
println!("✅ Installed launchd service: {}", file.display());
|
||||||
|
println!(" Start with: zeroclaw service start");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn install_linux(config: &Config) -> Result<()> {
|
||||||
|
let file = linux_service_file(config)?;
|
||||||
|
if let Some(parent) = file.parent() {
|
||||||
|
fs::create_dir_all(parent)?;
|
||||||
|
}
|
||||||
|
|
||||||
|
let exe = std::env::current_exe().context("Failed to resolve current executable")?;
|
||||||
|
let unit = format!(
|
||||||
|
"[Unit]\nDescription=ZeroClaw daemon\nAfter=network.target\n\n[Service]\nType=simple\nExecStart={} daemon\nRestart=always\nRestartSec=3\n\n[Install]\nWantedBy=default.target\n",
|
||||||
|
exe.display()
|
||||||
|
);
|
||||||
|
|
||||||
|
fs::write(&file, unit)?;
|
||||||
|
let _ = run_checked(Command::new("systemctl").args(["--user", "daemon-reload"]));
|
||||||
|
let _ = run_checked(Command::new("systemctl").args(["--user", "enable", "zeroclaw.service"]));
|
||||||
|
println!("✅ Installed systemd user service: {}", file.display());
|
||||||
|
println!(" Start with: zeroclaw service start");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn macos_service_file() -> Result<PathBuf> {
|
||||||
|
let home = directories::UserDirs::new()
|
||||||
|
.map(|u| u.home_dir().to_path_buf())
|
||||||
|
.context("Could not find home directory")?;
|
||||||
|
Ok(home
|
||||||
|
.join("Library")
|
||||||
|
.join("LaunchAgents")
|
||||||
|
.join(format!("{SERVICE_LABEL}.plist")))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn linux_service_file(config: &Config) -> Result<PathBuf> {
|
||||||
|
let home = directories::UserDirs::new()
|
||||||
|
.map(|u| u.home_dir().to_path_buf())
|
||||||
|
.context("Could not find home directory")?;
|
||||||
|
let _ = config;
|
||||||
|
Ok(home
|
||||||
|
.join(".config")
|
||||||
|
.join("systemd")
|
||||||
|
.join("user")
|
||||||
|
.join("zeroclaw.service"))
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_checked(command: &mut Command) -> Result<()> {
|
||||||
|
let output = command.output().context("Failed to spawn command")?;
|
||||||
|
if !output.status.success() {
|
||||||
|
let stderr = String::from_utf8_lossy(&output.stderr);
|
||||||
|
anyhow::bail!("Command failed: {}", stderr.trim());
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn run_capture(command: &mut Command) -> Result<String> {
|
||||||
|
let output = command.output().context("Failed to spawn command")?;
|
||||||
|
let mut text = String::from_utf8_lossy(&output.stdout).to_string();
|
||||||
|
if text.trim().is_empty() {
|
||||||
|
text = String::from_utf8_lossy(&output.stderr).to_string();
|
||||||
|
}
|
||||||
|
Ok(text)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn xml_escape(raw: &str) -> String {
|
||||||
|
raw.replace('&', "&")
|
||||||
|
.replace('<', "<")
|
||||||
|
.replace('>', ">")
|
||||||
|
.replace('"', """)
|
||||||
|
.replace('\'', "'")
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn xml_escape_escapes_reserved_chars() {
|
||||||
|
let escaped = xml_escape("<&>\"' and text");
|
||||||
|
assert_eq!(escaped, "<&>"' and text");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn run_capture_reads_stdout() {
|
||||||
|
let out = run_capture(Command::new("sh").args(["-lc", "echo hello"]))
|
||||||
|
.expect("stdout capture should succeed");
|
||||||
|
assert_eq!(out.trim(), "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn run_capture_falls_back_to_stderr() {
|
||||||
|
let out = run_capture(Command::new("sh").args(["-lc", "echo warn 1>&2"]))
|
||||||
|
.expect("stderr capture should succeed");
|
||||||
|
assert_eq!(out.trim(), "warn");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn run_checked_errors_on_non_zero_status() {
|
||||||
|
let err = run_checked(Command::new("sh").args(["-lc", "exit 17"]))
|
||||||
|
.expect_err("non-zero exit should error");
|
||||||
|
assert!(err.to_string().contains("Command failed"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn linux_service_file_has_expected_suffix() {
|
||||||
|
let file = linux_service_file(&Config::default()).unwrap();
|
||||||
|
let path = file.to_string_lossy();
|
||||||
|
assert!(path.ends_with(".config/systemd/user/zeroclaw.service"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -221,6 +221,23 @@ pub fn init_skills_dir(workspace_dir: &Path) -> Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Recursively copy a directory (used as fallback when symlinks aren't available)
|
||||||
|
#[cfg(any(windows, not(unix)))]
|
||||||
|
fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<()> {
|
||||||
|
std::fs::create_dir_all(dest)?;
|
||||||
|
for entry in std::fs::read_dir(src)? {
|
||||||
|
let entry = entry?;
|
||||||
|
let src_path = entry.path();
|
||||||
|
let dest_path = dest.join(entry.file_name());
|
||||||
|
if src_path.is_dir() {
|
||||||
|
copy_dir_recursive(&src_path, &dest_path)?;
|
||||||
|
} else {
|
||||||
|
std::fs::copy(&src_path, &dest_path)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Handle the `skills` CLI command
|
/// Handle the `skills` CLI command
|
||||||
pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Result<()> {
|
pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Result<()> {
|
||||||
match command {
|
match command {
|
||||||
|
|
@ -295,18 +312,60 @@ pub fn handle_command(command: super::SkillCommands, workspace_dir: &Path) -> Re
|
||||||
let dest = skills_path.join(name);
|
let dest = skills_path.join(name);
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
std::os::unix::fs::symlink(&src, &dest)?;
|
|
||||||
#[cfg(not(unix))]
|
|
||||||
{
|
{
|
||||||
// On non-unix, copy the directory
|
std::os::unix::fs::symlink(&src, &dest)?;
|
||||||
anyhow::bail!("Symlink not supported on this platform. Copy the skill directory manually.");
|
println!(
|
||||||
|
" {} Skill linked: {}",
|
||||||
|
console::style("✓").green().bold(),
|
||||||
|
dest.display()
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
#[cfg(windows)]
|
||||||
|
{
|
||||||
|
// On Windows, try symlink first (requires admin or developer mode),
|
||||||
|
// fall back to directory junction, then copy
|
||||||
|
use std::os::windows::fs::symlink_dir;
|
||||||
|
if symlink_dir(&src, &dest).is_ok() {
|
||||||
|
println!(
|
||||||
|
" {} Skill linked: {}",
|
||||||
|
console::style("✓").green().bold(),
|
||||||
|
dest.display()
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Try junction as fallback (works without admin)
|
||||||
|
let junction_result = std::process::Command::new("cmd")
|
||||||
|
.args(["/C", "mklink", "/J"])
|
||||||
|
.arg(&dest)
|
||||||
|
.arg(&src)
|
||||||
|
.output();
|
||||||
|
|
||||||
println!(
|
if junction_result.is_ok() && junction_result.unwrap().status.success() {
|
||||||
" {} Skill linked: {}",
|
println!(
|
||||||
console::style("✓").green().bold(),
|
" {} Skill linked (junction): {}",
|
||||||
dest.display()
|
console::style("✓").green().bold(),
|
||||||
);
|
dest.display()
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
// Final fallback: copy the directory
|
||||||
|
copy_dir_recursive(&src, &dest)?;
|
||||||
|
println!(
|
||||||
|
" {} Skill copied: {}",
|
||||||
|
console::style("✓").green().bold(),
|
||||||
|
dest.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
#[cfg(not(any(unix, windows)))]
|
||||||
|
{
|
||||||
|
// On other platforms, copy the directory
|
||||||
|
copy_dir_recursive(&src, &dest)?;
|
||||||
|
println!(
|
||||||
|
" {} Skill copied: {}",
|
||||||
|
console::style("✓").green().bold(),
|
||||||
|
dest.display()
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -631,3 +690,6 @@ description = "Bare minimum"
|
||||||
assert_eq!(skills[0].name, "from-toml"); // TOML takes priority
|
assert_eq!(skills[0].name, "from-toml"); // TOML takes priority
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod symlink_tests;
|
||||||
|
|
|
||||||
109
src/skills/symlink_tests.rs
Normal file
109
src/skills/symlink_tests.rs
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
#[cfg(test)]
|
||||||
|
mod symlink_tests {
|
||||||
|
use crate::skills::skills_dir;
|
||||||
|
use std::path::Path;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_skills_symlink_unix_edge_cases() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let workspace_dir = tmp.path().join("workspace");
|
||||||
|
std::fs::create_dir_all(&workspace_dir).unwrap();
|
||||||
|
|
||||||
|
let skills_path = skills_dir(&workspace_dir);
|
||||||
|
std::fs::create_dir_all(&skills_path).unwrap();
|
||||||
|
|
||||||
|
// Test case 1: Valid symlink creation on Unix
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
let source_dir = tmp.path().join("source_skill");
|
||||||
|
std::fs::create_dir_all(&source_dir).unwrap();
|
||||||
|
std::fs::write(source_dir.join("SKILL.md"), "# Test Skill\nContent").unwrap();
|
||||||
|
|
||||||
|
let dest_link = skills_path.join("linked_skill");
|
||||||
|
|
||||||
|
// Create symlink
|
||||||
|
let result = std::os::unix::fs::symlink(&source_dir, &dest_link);
|
||||||
|
assert!(result.is_ok(), "Symlink creation should succeed");
|
||||||
|
|
||||||
|
// Verify symlink works
|
||||||
|
assert!(dest_link.exists());
|
||||||
|
assert!(dest_link.is_symlink());
|
||||||
|
|
||||||
|
// Verify we can read through symlink
|
||||||
|
let content = std::fs::read_to_string(dest_link.join("SKILL.md"));
|
||||||
|
assert!(content.is_ok());
|
||||||
|
assert!(content.unwrap().contains("Test Skill"));
|
||||||
|
|
||||||
|
// Test case 2: Symlink to non-existent target should fail gracefully
|
||||||
|
let broken_link = skills_path.join("broken_skill");
|
||||||
|
let non_existent = tmp.path().join("non_existent");
|
||||||
|
let result = std::os::unix::fs::symlink(&non_existent, &broken_link);
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"Symlink creation should succeed even if target doesn't exist"
|
||||||
|
);
|
||||||
|
|
||||||
|
// But reading through it should fail
|
||||||
|
let content = std::fs::read_to_string(broken_link.join("SKILL.md"));
|
||||||
|
assert!(content.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 3: Non-Unix platforms should handle symlink errors gracefully
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
{
|
||||||
|
let source_dir = tmp.path().join("source_skill");
|
||||||
|
std::fs::create_dir_all(&source_dir).unwrap();
|
||||||
|
|
||||||
|
let dest_link = skills_path.join("linked_skill");
|
||||||
|
|
||||||
|
// Symlink should fail on non-Unix
|
||||||
|
let result = std::os::unix::fs::symlink(&source_dir, &dest_link);
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
// Directory should not exist
|
||||||
|
assert!(!dest_link.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test case 4: skills_dir function edge cases
|
||||||
|
let workspace_with_trailing_slash = format!("{}/", workspace_dir.display());
|
||||||
|
let path_from_str = skills_dir(Path::new(&workspace_with_trailing_slash));
|
||||||
|
assert_eq!(path_from_str, skills_path);
|
||||||
|
|
||||||
|
// Test case 5: Empty workspace directory
|
||||||
|
let empty_workspace = tmp.path().join("empty");
|
||||||
|
let empty_skills_path = skills_dir(&empty_workspace);
|
||||||
|
assert_eq!(empty_skills_path, empty_workspace.join("skills"));
|
||||||
|
assert!(!empty_skills_path.exists());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_skills_symlink_permissions_and_safety() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let workspace_dir = tmp.path().join("workspace");
|
||||||
|
std::fs::create_dir_all(&workspace_dir).unwrap();
|
||||||
|
|
||||||
|
let skills_path = skills_dir(&workspace_dir);
|
||||||
|
std::fs::create_dir_all(&skills_path).unwrap();
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
// Test case: Symlink outside workspace should be allowed (user responsibility)
|
||||||
|
let outside_dir = tmp.path().join("outside_skill");
|
||||||
|
std::fs::create_dir_all(&outside_dir).unwrap();
|
||||||
|
std::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent").unwrap();
|
||||||
|
|
||||||
|
let dest_link = skills_path.join("outside_skill");
|
||||||
|
let result = std::os::unix::fs::symlink(&outside_dir, &dest_link);
|
||||||
|
assert!(
|
||||||
|
result.is_ok(),
|
||||||
|
"Should allow symlinking to directories outside workspace"
|
||||||
|
);
|
||||||
|
|
||||||
|
// Should still be readable
|
||||||
|
let content = std::fs::read_to_string(dest_link.join("SKILL.md"));
|
||||||
|
assert!(content.is_ok());
|
||||||
|
assert!(content.unwrap().contains("Outside Skill"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -55,7 +55,30 @@ impl Tool for FileReadTool {
|
||||||
|
|
||||||
let full_path = self.security.workspace_dir.join(path);
|
let full_path = self.security.workspace_dir.join(path);
|
||||||
|
|
||||||
match tokio::fs::read_to_string(&full_path).await {
|
// Resolve path before reading to block symlink escapes.
|
||||||
|
let resolved_path = match tokio::fs::canonicalize(&full_path).await {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to resolve file path: {e}")),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !self.security.is_resolved_path_allowed(&resolved_path) {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Resolved path escapes workspace: {}",
|
||||||
|
resolved_path.display()
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
match tokio::fs::read_to_string(&resolved_path).await {
|
||||||
Ok(contents) => Ok(ToolResult {
|
Ok(contents) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: contents,
|
output: contents,
|
||||||
|
|
@ -127,7 +150,7 @@ mod tests {
|
||||||
let tool = FileReadTool::new(test_security(dir.clone()));
|
let tool = FileReadTool::new(test_security(dir.clone()));
|
||||||
let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap();
|
let result = tool.execute(json!({"path": "nope.txt"})).await.unwrap();
|
||||||
assert!(!result.success);
|
assert!(!result.success);
|
||||||
assert!(result.error.as_ref().unwrap().contains("Failed to read"));
|
assert!(result.error.as_ref().unwrap().contains("Failed to resolve"));
|
||||||
|
|
||||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||||
}
|
}
|
||||||
|
|
@ -200,4 +223,36 @@ mod tests {
|
||||||
|
|
||||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn file_read_blocks_symlink_escape() {
|
||||||
|
use std::os::unix::fs::symlink;
|
||||||
|
|
||||||
|
let root = std::env::temp_dir().join("zeroclaw_test_file_read_symlink_escape");
|
||||||
|
let workspace = root.join("workspace");
|
||||||
|
let outside = root.join("outside");
|
||||||
|
|
||||||
|
let _ = tokio::fs::remove_dir_all(&root).await;
|
||||||
|
tokio::fs::create_dir_all(&workspace).await.unwrap();
|
||||||
|
tokio::fs::create_dir_all(&outside).await.unwrap();
|
||||||
|
|
||||||
|
tokio::fs::write(outside.join("secret.txt"), "outside workspace")
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
symlink(outside.join("secret.txt"), workspace.join("escape.txt")).unwrap();
|
||||||
|
|
||||||
|
let tool = FileReadTool::new(test_security(workspace.clone()));
|
||||||
|
let result = tool.execute(json!({"path": "escape.txt"})).await.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result
|
||||||
|
.error
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("")
|
||||||
|
.contains("escapes workspace"));
|
||||||
|
|
||||||
|
let _ = tokio::fs::remove_dir_all(&root).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -69,7 +69,54 @@ impl Tool for FileWriteTool {
|
||||||
tokio::fs::create_dir_all(parent).await?;
|
tokio::fs::create_dir_all(parent).await?;
|
||||||
}
|
}
|
||||||
|
|
||||||
match tokio::fs::write(&full_path, content).await {
|
let parent = match full_path.parent() {
|
||||||
|
Some(p) => p,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Invalid path: missing parent directory".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
// Resolve parent before writing to block symlink escapes.
|
||||||
|
let resolved_parent = match tokio::fs::canonicalize(parent).await {
|
||||||
|
Ok(p) => p,
|
||||||
|
Err(e) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!("Failed to resolve file path: {e}")),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
if !self.security.is_resolved_path_allowed(&resolved_parent) {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Resolved path escapes workspace: {}",
|
||||||
|
resolved_parent.display()
|
||||||
|
)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let file_name = match full_path.file_name() {
|
||||||
|
Some(name) => name,
|
||||||
|
None => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Invalid path: missing file name".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
let resolved_target = resolved_parent.join(file_name);
|
||||||
|
|
||||||
|
match tokio::fs::write(&resolved_target, content).await {
|
||||||
Ok(()) => Ok(ToolResult {
|
Ok(()) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Written {} bytes to {path}", content.len()),
|
output: format!("Written {} bytes to {path}", content.len()),
|
||||||
|
|
@ -239,4 +286,36 @@ mod tests {
|
||||||
|
|
||||||
let _ = tokio::fs::remove_dir_all(&dir).await;
|
let _ = tokio::fs::remove_dir_all(&dir).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[tokio::test]
|
||||||
|
async fn file_write_blocks_symlink_escape() {
|
||||||
|
use std::os::unix::fs::symlink;
|
||||||
|
|
||||||
|
let root = std::env::temp_dir().join("zeroclaw_test_file_write_symlink_escape");
|
||||||
|
let workspace = root.join("workspace");
|
||||||
|
let outside = root.join("outside");
|
||||||
|
|
||||||
|
let _ = tokio::fs::remove_dir_all(&root).await;
|
||||||
|
tokio::fs::create_dir_all(&workspace).await.unwrap();
|
||||||
|
tokio::fs::create_dir_all(&outside).await.unwrap();
|
||||||
|
|
||||||
|
symlink(&outside, workspace.join("escape_dir")).unwrap();
|
||||||
|
|
||||||
|
let tool = FileWriteTool::new(test_security(workspace.clone()));
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({"path": "escape_dir/hijack.txt", "content": "bad"}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result
|
||||||
|
.error
|
||||||
|
.as_deref()
|
||||||
|
.unwrap_or("")
|
||||||
|
.contains("escapes workspace"));
|
||||||
|
assert!(!outside.join("hijack.txt").exists());
|
||||||
|
|
||||||
|
let _ = tokio::fs::remove_dir_all(&root).await;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
327
tests/dockerignore_test.rs
Normal file
327
tests/dockerignore_test.rs
Normal file
|
|
@ -0,0 +1,327 @@
|
||||||
|
//! Tests to verify .dockerignore excludes sensitive paths from Docker build context.
|
||||||
|
//!
|
||||||
|
//! These tests validate that:
|
||||||
|
//! 1. The .dockerignore file exists
|
||||||
|
//! 2. All security-critical paths are excluded
|
||||||
|
//! 3. All build-essential paths are NOT excluded
|
||||||
|
//! 4. Pattern syntax is valid
|
||||||
|
|
||||||
|
use std::fs;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
/// Paths that MUST be excluded from Docker build context (security/performance)
|
||||||
|
const MUST_EXCLUDE: &[&str] = &[
|
||||||
|
".git",
|
||||||
|
".githooks",
|
||||||
|
"target",
|
||||||
|
"docs",
|
||||||
|
"examples",
|
||||||
|
"tests",
|
||||||
|
"*.md",
|
||||||
|
"*.png",
|
||||||
|
"*.db",
|
||||||
|
"*.db-journal",
|
||||||
|
".DS_Store",
|
||||||
|
".github",
|
||||||
|
"deny.toml",
|
||||||
|
"LICENSE",
|
||||||
|
".env",
|
||||||
|
".tmp_*",
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Paths that MUST NOT be excluded (required for build)
|
||||||
|
const MUST_INCLUDE: &[&str] = &["Cargo.toml", "Cargo.lock", "src/"];
|
||||||
|
|
||||||
|
/// Parse .dockerignore and return all non-comment, non-empty lines
|
||||||
|
fn parse_dockerignore(content: &str) -> Vec<String> {
|
||||||
|
content
|
||||||
|
.lines()
|
||||||
|
.map(|line| line.trim())
|
||||||
|
.filter(|line| !line.is_empty() && !line.starts_with('#'))
|
||||||
|
.map(|line| line.to_string())
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a pattern would match a given path
|
||||||
|
fn pattern_matches(pattern: &str, path: &str) -> bool {
|
||||||
|
// Handle negation patterns
|
||||||
|
if pattern.starts_with('!') {
|
||||||
|
return false; // Negation re-includes, so it doesn't "exclude"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle glob patterns
|
||||||
|
if pattern.starts_with("*.") {
|
||||||
|
let ext = &pattern[1..]; // e.g., ".md"
|
||||||
|
return path.ends_with(ext);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle directory patterns (with or without trailing slash)
|
||||||
|
let pattern_normalized = pattern.trim_end_matches('/');
|
||||||
|
let path_normalized = path.trim_end_matches('/');
|
||||||
|
|
||||||
|
// Exact match
|
||||||
|
if path_normalized == pattern_normalized {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern is a prefix (directory match)
|
||||||
|
if path_normalized.starts_with(&format!("{}/", pattern_normalized)) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wildcard prefix patterns like ".tmp_*"
|
||||||
|
if pattern.contains('*') && !pattern.starts_with("*.") {
|
||||||
|
let prefix = pattern.split('*').next().unwrap_or("");
|
||||||
|
if !prefix.is_empty() && path.starts_with(prefix) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if any pattern in the list would exclude the given path
|
||||||
|
fn is_excluded(patterns: &[String], path: &str) -> bool {
|
||||||
|
let mut excluded = false;
|
||||||
|
for pattern in patterns {
|
||||||
|
if pattern.starts_with('!') {
|
||||||
|
// Negation pattern - re-include
|
||||||
|
let negated = &pattern[1..];
|
||||||
|
if pattern_matches(negated, path) {
|
||||||
|
excluded = false;
|
||||||
|
}
|
||||||
|
} else if pattern_matches(pattern, path) {
|
||||||
|
excluded = true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
excluded
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_file_exists() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
assert!(
|
||||||
|
path.exists(),
|
||||||
|
".dockerignore file must exist at project root"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_security_critical_paths() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
for must_exclude in MUST_EXCLUDE {
|
||||||
|
// For glob patterns, test with a sample file
|
||||||
|
let test_path = if must_exclude.starts_with("*.") {
|
||||||
|
format!("sample{}", &must_exclude[1..])
|
||||||
|
} else {
|
||||||
|
must_exclude.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, &test_path),
|
||||||
|
"Path '{}' (tested as '{}') MUST be excluded by .dockerignore but is not. \
|
||||||
|
This is a security/performance issue.",
|
||||||
|
must_exclude,
|
||||||
|
test_path
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_does_not_exclude_build_essentials() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
for must_include in MUST_INCLUDE {
|
||||||
|
assert!(
|
||||||
|
!is_excluded(&patterns, must_include),
|
||||||
|
"Path '{}' MUST NOT be excluded by .dockerignore (required for build)",
|
||||||
|
must_include
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_git_directory() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
// .git directory and its contents must be excluded
|
||||||
|
assert!(is_excluded(&patterns, ".git"), ".git must be excluded");
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, ".git/config"),
|
||||||
|
".git/config must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, ".git/objects/pack/pack-abc123.pack"),
|
||||||
|
".git subdirectories must be excluded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_target_directory() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
assert!(is_excluded(&patterns, "target"), "target must be excluded");
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "target/debug/zeroclaw"),
|
||||||
|
"target/debug must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "target/release/zeroclaw"),
|
||||||
|
"target/release must be excluded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_database_files() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "brain.db"),
|
||||||
|
"*.db files must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "memory.db"),
|
||||||
|
"*.db files must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "brain.db-journal"),
|
||||||
|
"*.db-journal files must be excluded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_markdown_files() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "README.md"),
|
||||||
|
"*.md files must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "CHANGELOG.md"),
|
||||||
|
"*.md files must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "CONTRIBUTING.md"),
|
||||||
|
"*.md files must be excluded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_image_files() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "zeroclaw.png"),
|
||||||
|
"*.png files must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, "logo.png"),
|
||||||
|
"*.png files must be excluded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_env_files() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, ".env"),
|
||||||
|
".env must be excluded (contains secrets)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_excludes_ci_configs() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
let patterns = parse_dockerignore(&content);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, ".github"),
|
||||||
|
".github must be excluded"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
is_excluded(&patterns, ".github/workflows/ci.yml"),
|
||||||
|
".github/workflows must be excluded"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_has_valid_syntax() {
|
||||||
|
let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore");
|
||||||
|
let content = fs::read_to_string(&path).expect("Failed to read .dockerignore");
|
||||||
|
|
||||||
|
for (line_num, line) in content.lines().enumerate() {
|
||||||
|
let trimmed = line.trim();
|
||||||
|
|
||||||
|
// Skip empty lines and comments
|
||||||
|
if trimmed.is_empty() || trimmed.starts_with('#') {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for invalid patterns
|
||||||
|
assert!(
|
||||||
|
!trimmed.contains("**") || trimmed.matches("**").count() <= 2,
|
||||||
|
"Line {}: Too many ** in pattern '{}'",
|
||||||
|
line_num + 1,
|
||||||
|
trimmed
|
||||||
|
);
|
||||||
|
|
||||||
|
// Check for trailing spaces (can cause issues)
|
||||||
|
assert!(
|
||||||
|
line.trim_end() == line.trim_start().trim_end(),
|
||||||
|
"Line {}: Pattern '{}' has leading whitespace which may cause issues",
|
||||||
|
line_num + 1,
|
||||||
|
line
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn dockerignore_pattern_matching_edge_cases() {
|
||||||
|
// Test the pattern matching logic itself
|
||||||
|
let patterns = vec![
|
||||||
|
".git".to_string(),
|
||||||
|
".githooks".to_string(),
|
||||||
|
"target".to_string(),
|
||||||
|
"*.md".to_string(),
|
||||||
|
"*.db".to_string(),
|
||||||
|
".tmp_*".to_string(),
|
||||||
|
".env".to_string(),
|
||||||
|
];
|
||||||
|
|
||||||
|
// Should match
|
||||||
|
assert!(is_excluded(&patterns, ".git"));
|
||||||
|
assert!(is_excluded(&patterns, ".git/config"));
|
||||||
|
assert!(is_excluded(&patterns, ".githooks"));
|
||||||
|
assert!(is_excluded(&patterns, "target"));
|
||||||
|
assert!(is_excluded(&patterns, "target/debug/build"));
|
||||||
|
assert!(is_excluded(&patterns, "README.md"));
|
||||||
|
assert!(is_excluded(&patterns, "brain.db"));
|
||||||
|
assert!(is_excluded(&patterns, ".tmp_todo_probe"));
|
||||||
|
assert!(is_excluded(&patterns, ".env"));
|
||||||
|
|
||||||
|
// Should NOT match
|
||||||
|
assert!(!is_excluded(&patterns, "src"));
|
||||||
|
assert!(!is_excluded(&patterns, "src/main.rs"));
|
||||||
|
assert!(!is_excluded(&patterns, "Cargo.toml"));
|
||||||
|
assert!(!is_excluded(&patterns, "Cargo.lock"));
|
||||||
|
}
|
||||||
Loading…
Add table
Add a link
Reference in a new issue