diff --git a/.env.example b/.env.example index 42a0b37..a6ba55e 100644 --- a/.env.example +++ b/.env.example @@ -21,6 +21,10 @@ PROVIDER=openrouter # Workspace directory override # ZEROCLAW_WORKSPACE=/path/to/workspace +# Reasoning mode (enables extended thinking for supported models) +# ZEROCLAW_REASONING_ENABLED=false +# REASONING_ENABLED=false + # ── Provider-Specific API Keys ──────────────────────────────── # OpenRouter # OPENROUTER_API_KEY=sk-or-v1-... @@ -63,6 +67,22 @@ PROVIDER=openrouter # ZEROCLAW_GATEWAY_HOST=127.0.0.1 # ZEROCLAW_ALLOW_PUBLIC_BIND=false +# ── Storage ───────────────────────────────────────────────── +# Backend override for persistent storage (default: sqlite) +# ZEROCLAW_STORAGE_PROVIDER=sqlite +# ZEROCLAW_STORAGE_DB_URL=postgres://localhost/zeroclaw +# ZEROCLAW_STORAGE_CONNECT_TIMEOUT_SECS=5 + +# ── Proxy ────────────────────────────────────────────────── +# Forward provider/service traffic through an HTTP(S) proxy. +# ZEROCLAW_PROXY_ENABLED=false +# ZEROCLAW_HTTP_PROXY=http://proxy.example.com:8080 +# ZEROCLAW_HTTPS_PROXY=http://proxy.example.com:8080 +# ZEROCLAW_ALL_PROXY=socks5://proxy.example.com:1080 +# ZEROCLAW_NO_PROXY=localhost,127.0.0.1 +# ZEROCLAW_PROXY_SCOPE=zeroclaw # environment|zeroclaw|services +# ZEROCLAW_PROXY_SERVICES=openai,anthropic + # ── Optional Integrations ──────────────────────────────────── # Pushover notifications (`pushover` tool) # PUSHOVER_TOKEN=your-pushover-app-token diff --git a/.envrc b/.envrc new file mode 100644 index 0000000..3550a30 --- /dev/null +++ b/.envrc @@ -0,0 +1 @@ +use flake diff --git a/.github/dependabot.yml b/.github/dependabot.yml index 2f88c8e..eb81c96 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -4,13 +4,13 @@ updates: - package-ecosystem: cargo directory: "/" schedule: - interval: weekly + interval: daily target-branch: main - open-pull-requests-limit: 5 + open-pull-requests-limit: 3 labels: - "dependencies" groups: - rust-minor-patch: + rust-all: patterns: - 
"*" update-types: @@ -20,14 +20,14 @@ updates: - package-ecosystem: github-actions directory: "/" schedule: - interval: weekly + interval: daily target-branch: main - open-pull-requests-limit: 3 + open-pull-requests-limit: 1 labels: - "ci" - "dependencies" groups: - actions-minor-patch: + actions-all: patterns: - "*" update-types: @@ -37,16 +37,16 @@ updates: - package-ecosystem: docker directory: "/" schedule: - interval: weekly + interval: daily target-branch: main - open-pull-requests-limit: 3 + open-pull-requests-limit: 1 labels: - "ci" - "dependencies" groups: - docker-minor-patch: + docker-all: patterns: - "*" update-types: - minor - - patch \ No newline at end of file + - patch diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 7c9e601..7990431 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -12,11 +12,7 @@ Describe this PR in 2-5 bullets: - Risk label (`risk: low|medium|high`): - Size label (`size: XS|S|M|L|XL`, auto-managed/read-only): - Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated): -<<<<<<< chore/labeler-spacing-trusted-tier - Module labels (`: `, for example `channel: telegram`, `provider: kimi`, `tool: shell`): -======= -- Module labels (`:`, for example `channel:telegram`, `provider:kimi`, `tool:shell`): ->>>>>>> main - Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50): - If any auto-label is incorrect, note requested correction: diff --git a/.github/workflows/ci-run.yml b/.github/workflows/ci-run.yml index 373b879..dea6208 100644 --- a/.github/workflows/ci-run.yml +++ b/.github/workflows/ci-run.yml @@ -41,25 +41,7 @@ jobs: run: 
./scripts/ci/detect_change_scope.sh lint: - name: Lint Gate (Format + Clippy) - needs: [changes] - if: needs.changes.outputs.rust_changed == 'true' && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'ci:full')) - runs-on: blacksmith-2vcpu-ubuntu-2404 - timeout-minutes: 20 - steps: - - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - with: - fetch-depth: 0 - - uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable - with: - toolchain: 1.92.0 - components: rustfmt, clippy - - uses: useblacksmith/rust-cache@f53e7f127245d2a269b3d90879ccf259876842d5 # v3 - - name: Run rust quality gate - run: ./scripts/ci/rust_quality_gate.sh - - lint-strict-delta: - name: Lint Gate (Strict Delta) + name: Lint Gate (Format + Clippy + Strict Delta) needs: [changes] if: needs.changes.outputs.rust_changed == 'true' && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'ci:full')) runs-on: blacksmith-2vcpu-ubuntu-2404 @@ -71,8 +53,10 @@ jobs: - uses: dtolnay/rust-toolchain@631a55b12751854ce901bb631d5902ceb48146f7 # stable with: toolchain: 1.92.0 - components: clippy + components: rustfmt, clippy - uses: useblacksmith/rust-cache@f53e7f127245d2a269b3d90879ccf259876842d5 # v3 + - name: Run rust quality gate + run: ./scripts/ci/rust_quality_gate.sh - name: Run strict lint delta gate env: BASE_SHA: ${{ needs.changes.outputs.base_sha }} @@ -80,8 +64,8 @@ jobs: test: name: Test - needs: [changes, lint, lint-strict-delta] - if: needs.changes.outputs.rust_changed == 'true' && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'ci:full')) && needs.lint.result == 'success' && needs.lint-strict-delta.result == 'success' + needs: [changes, lint] + if: needs.changes.outputs.rust_changed == 'true' && (github.event_name != 'pull_request' || contains(github.event.pull_request.labels.*.name, 'ci:full')) && needs.lint.result == 'success' runs-on: 
blacksmith-2vcpu-ubuntu-2404 timeout-minutes: 30 steps: @@ -106,8 +90,8 @@ jobs: with: toolchain: 1.92.0 - uses: useblacksmith/rust-cache@f53e7f127245d2a269b3d90879ccf259876842d5 # v3 - - name: Build release binary - run: cargo build --release --locked --verbose + - name: Build binary (smoke check) + run: cargo build --locked --verbose docs-only: name: Docs-Only Fast Path @@ -185,7 +169,7 @@ jobs: lint-feedback: name: Lint Feedback if: github.event_name == 'pull_request' - needs: [changes, lint, lint-strict-delta, docs-quality] + needs: [changes, lint, docs-quality] runs-on: blacksmith-2vcpu-ubuntu-2404 permissions: contents: read @@ -201,7 +185,7 @@ jobs: RUST_CHANGED: ${{ needs.changes.outputs.rust_changed }} DOCS_CHANGED: ${{ needs.changes.outputs.docs_changed }} LINT_RESULT: ${{ needs.lint.result }} - LINT_DELTA_RESULT: ${{ needs.lint-strict-delta.result }} + LINT_DELTA_RESULT: ${{ needs.lint.result }} DOCS_RESULT: ${{ needs.docs-quality.result }} with: script: | @@ -231,7 +215,7 @@ jobs: ci-required: name: CI Required Gate if: always() - needs: [changes, lint, lint-strict-delta, test, build, docs-only, non-rust, docs-quality, lint-feedback, workflow-owner-approval] + needs: [changes, lint, test, build, docs-only, non-rust, docs-quality, lint-feedback, workflow-owner-approval] runs-on: blacksmith-2vcpu-ubuntu-2404 steps: - name: Enforce required status @@ -276,7 +260,7 @@ jobs: fi lint_result="${{ needs.lint.result }}" - lint_strict_delta_result="${{ needs.lint-strict-delta.result }}" + lint_strict_delta_result="${{ needs.lint.result }}" test_result="${{ needs.test.result }}" build_result="${{ needs.build.result }}" diff --git a/.github/workflows/feature-matrix.yml b/.github/workflows/feature-matrix.yml index 875b0c5..18953e1 100644 --- a/.github/workflows/feature-matrix.yml +++ b/.github/workflows/feature-matrix.yml @@ -1,12 +1,6 @@ name: Feature Matrix on: - push: - branches: [main] - paths: - - "Cargo.toml" - - "Cargo.lock" - - "src/**" schedule: - cron: "30 
4 * * 1" # Weekly Monday 4:30am UTC workflow_dispatch: @@ -61,6 +55,3 @@ jobs: - name: Check feature combination run: cargo check --locked ${{ matrix.args }} - - - name: Test feature combination - run: cargo test --locked ${{ matrix.args }} diff --git a/.github/workflows/main-branch-flow.md b/.github/workflows/main-branch-flow.md index 3a26ed1..6490e97 100644 --- a/.github/workflows/main-branch-flow.md +++ b/.github/workflows/main-branch-flow.md @@ -143,7 +143,7 @@ Workflow: `.github/workflows/pub-docker-img.yml` - `latest` + SHA tag (`sha-<12 chars>`) for `main` - semantic tag from pushed git tag (`vX.Y.Z`) + SHA tag for tag pushes - branch name + SHA tag for non-`main` manual dispatch refs -5. Multi-platform publish is used for tag pushes (`linux/amd64,linux/arm64`), while `main` publish stays `linux/amd64`. +5. Multi-platform publish is used for both `main` and tag pushes (`linux/amd64,linux/arm64`). 6. Typical runtime in recent sample: ~139.9s. 7. Result: pushed image tags under `ghcr.io//`. 
diff --git a/.github/workflows/pr-auto-response.yml b/.github/workflows/pr-auto-response.yml index ee6e100..e5f068e 100644 --- a/.github/workflows/pr-auto-response.yml +++ b/.github/workflows/pr-auto-response.yml @@ -15,7 +15,7 @@ jobs: (github.event.action == 'opened' || github.event.action == 'reopened' || github.event.action == 'labeled' || github.event.action == 'unlabeled')) || (github.event_name == 'pull_request_target' && (github.event.action == 'labeled' || github.event.action == 'unlabeled')) - runs-on: blacksmith-2vcpu-ubuntu-2404 + runs-on: ubuntu-latest permissions: contents: read issues: write @@ -34,7 +34,7 @@ jobs: await script({ github, context, core }); first-interaction: if: github.event.action == 'opened' - runs-on: blacksmith-2vcpu-ubuntu-2404 + runs-on: ubuntu-latest permissions: issues: write pull-requests: write @@ -65,7 +65,7 @@ jobs: labeled-routes: if: github.event.action == 'labeled' - runs-on: blacksmith-2vcpu-ubuntu-2404 + runs-on: ubuntu-latest permissions: contents: read issues: write diff --git a/.github/workflows/pr-check-stale.yml b/.github/workflows/pr-check-stale.yml index 0120547..a2cf24c 100644 --- a/.github/workflows/pr-check-stale.yml +++ b/.github/workflows/pr-check-stale.yml @@ -12,7 +12,7 @@ jobs: permissions: issues: write pull-requests: write - runs-on: blacksmith-2vcpu-ubuntu-2404 + runs-on: ubuntu-latest steps: - name: Mark stale issues and pull requests uses: actions/stale@b5d41d4e1d5dceea10e7104786b73624c18a190f # v10.2.0 diff --git a/.github/workflows/pr-check-status.yml b/.github/workflows/pr-check-status.yml index 390a285..b057e88 100644 --- a/.github/workflows/pr-check-status.yml +++ b/.github/workflows/pr-check-status.yml @@ -2,7 +2,7 @@ name: PR Check Status on: schedule: - - cron: "15 */12 * * *" + - cron: "15 8 * * *" # Once daily at 8:15am UTC workflow_dispatch: permissions: {} @@ -13,13 +13,13 @@ concurrency: jobs: nudge-stale-prs: - runs-on: blacksmith-2vcpu-ubuntu-2404 + runs-on: ubuntu-latest 
permissions: contents: read pull-requests: write issues: write env: - STALE_HOURS: "4" + STALE_HOURS: "48" steps: - name: Checkout repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 diff --git a/.github/workflows/pr-intake-checks.yml b/.github/workflows/pr-intake-checks.yml index 0cacf88..e703387 100644 --- a/.github/workflows/pr-intake-checks.yml +++ b/.github/workflows/pr-intake-checks.yml @@ -16,7 +16,7 @@ permissions: jobs: intake: name: Intake Checks - runs-on: blacksmith-2vcpu-ubuntu-2404 + runs-on: ubuntu-latest timeout-minutes: 10 steps: - name: Checkout repository diff --git a/.github/workflows/pr-labeler.yml b/.github/workflows/pr-labeler.yml index 8349352..38cf054 100644 --- a/.github/workflows/pr-labeler.yml +++ b/.github/workflows/pr-labeler.yml @@ -25,8 +25,7 @@ permissions: jobs: label: - runs-on: blacksmith-2vcpu-ubuntu-2404 - timeout-minutes: 10 + runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 diff --git a/.github/workflows/pub-docker-img.yml b/.github/workflows/pub-docker-img.yml index 15ea8aa..05d83e5 100644 --- a/.github/workflows/pub-docker-img.yml +++ b/.github/workflows/pub-docker-img.yml @@ -21,13 +21,8 @@ on: paths: - "Dockerfile" - ".dockerignore" - - "Cargo.toml" - - "Cargo.lock" + - "docker-compose.yml" - "rust-toolchain.toml" - - "src/**" - - "crates/**" - - "benches/**" - - "firmware/**" - "dev/config.template.toml" - ".github/workflows/pub-docker-img.yml" workflow_dispatch: @@ -75,6 +70,8 @@ jobs: tags: zeroclaw-pr-smoke:latest labels: ${{ steps.meta.outputs.labels || '' }} platforms: linux/amd64 + cache-from: type=gha + cache-to: type=gha,mode=max - name: Verify image run: docker run --rm zeroclaw-pr-smoke:latest --version @@ -83,7 +80,7 @@ jobs: name: Build and Push Docker Image if: (github.event_name == 'workflow_dispatch' || (github.event_name == 'push' && (github.ref == 'refs/heads/main' || startsWith(github.ref, 
'refs/tags/v')))) && github.repository == 'zeroclaw-labs/zeroclaw' runs-on: blacksmith-2vcpu-ubuntu-2404 - timeout-minutes: 25 + timeout-minutes: 45 permissions: contents: read packages: write @@ -128,7 +125,9 @@ jobs: context: . push: true tags: ${{ steps.meta.outputs.tags }} - platforms: ${{ startsWith(github.ref, 'refs/tags/v') && 'linux/amd64,linux/arm64' || 'linux/amd64' }} + platforms: linux/amd64,linux/arm64 + cache-from: type=gha + cache-to: type=gha,mode=max - name: Set GHCR package visibility to public shell: bash diff --git a/.github/workflows/pub-release.yml b/.github/workflows/pub-release.yml index 7cdb853..14677b1 100644 --- a/.github/workflows/pub-release.yml +++ b/.github/workflows/pub-release.yml @@ -27,15 +27,45 @@ jobs: - os: ubuntu-latest target: x86_64-unknown-linux-gnu artifact: zeroclaw - - os: macos-latest + archive_ext: tar.gz + cross_compiler: "" + linker_env: "" + linker: "" + - os: ubuntu-latest + target: aarch64-unknown-linux-gnu + artifact: zeroclaw + archive_ext: tar.gz + cross_compiler: gcc-aarch64-linux-gnu + linker_env: CARGO_TARGET_AARCH64_UNKNOWN_LINUX_GNU_LINKER + linker: aarch64-linux-gnu-gcc + - os: ubuntu-latest + target: armv7-unknown-linux-gnueabihf + artifact: zeroclaw + archive_ext: tar.gz + cross_compiler: gcc-arm-linux-gnueabihf + linker_env: CARGO_TARGET_ARMV7_UNKNOWN_LINUX_GNUEABIHF_LINKER + linker: arm-linux-gnueabihf-gcc + - os: macos-15-intel target: x86_64-apple-darwin artifact: zeroclaw - - os: macos-latest + archive_ext: tar.gz + cross_compiler: "" + linker_env: "" + linker: "" + - os: macos-14 target: aarch64-apple-darwin artifact: zeroclaw + archive_ext: tar.gz + cross_compiler: "" + linker_env: "" + linker: "" - os: windows-latest target: x86_64-pc-windows-msvc artifact: zeroclaw.exe + archive_ext: zip + cross_compiler: "" + linker_env: "" + linker: "" steps: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 @@ -46,20 +76,41 @@ jobs: - uses: 
useblacksmith/rust-cache@f53e7f127245d2a269b3d90879ccf259876842d5 # v3 + - name: Install cross-compilation toolchain (Linux) + if: runner.os == 'Linux' && matrix.cross_compiler != '' + run: | + sudo apt-get update -qq + sudo apt-get install -y ${{ matrix.cross_compiler }} + - name: Build release - run: cargo build --release --locked --target ${{ matrix.target }} + env: + LINKER_ENV: ${{ matrix.linker_env }} + LINKER: ${{ matrix.linker }} + run: | + if [ -n "$LINKER_ENV" ] && [ -n "$LINKER" ]; then + echo "Using linker override: $LINKER_ENV=$LINKER" + export "$LINKER_ENV=$LINKER" + fi + cargo build --release --locked --target ${{ matrix.target }} - name: Check binary size (Unix) if: runner.os != 'Windows' run: | - SIZE=$(stat -f%z target/${{ matrix.target }}/release/${{ matrix.artifact }} 2>/dev/null || stat -c%s target/${{ matrix.target }}/release/${{ matrix.artifact }}) + BIN="target/${{ matrix.target }}/release/${{ matrix.artifact }}" + if [ ! -f "$BIN" ]; then + echo "::error::Expected binary not found: $BIN" + exit 1 + fi + SIZE=$(stat -f%z "$BIN" 2>/dev/null || stat -c%s "$BIN") SIZE_MB=$((SIZE / 1024 / 1024)) echo "Binary size: ${SIZE_MB}MB ($SIZE bytes)" echo "### Binary Size: ${{ matrix.target }}" >> "$GITHUB_STEP_SUMMARY" echo "- Size: ${SIZE_MB}MB ($SIZE bytes)" >> "$GITHUB_STEP_SUMMARY" - if [ "$SIZE" -gt 15728640 ]; then - echo "::error::Binary exceeds 15MB hard limit (${SIZE_MB}MB)" + if [ "$SIZE" -gt 41943040 ]; then + echo "::error::Binary exceeds 40MB safeguard (${SIZE_MB}MB)" exit 1 + elif [ "$SIZE" -gt 15728640 ]; then + echo "::warning::Binary exceeds 15MB advisory target (${SIZE_MB}MB)" elif [ "$SIZE" -gt 5242880 ]; then echo "::warning::Binary exceeds 5MB target (${SIZE_MB}MB)" else @@ -70,19 +121,19 @@ jobs: if: runner.os != 'Windows' run: | cd target/${{ matrix.target }}/release - tar czf ../../../zeroclaw-${{ matrix.target }}.tar.gz ${{ matrix.artifact }} + tar czf ../../../zeroclaw-${{ matrix.target }}.${{ matrix.archive_ext }} ${{ 
matrix.artifact }} - name: Package (Windows) if: runner.os == 'Windows' run: | cd target/${{ matrix.target }}/release - 7z a ../../../zeroclaw-${{ matrix.target }}.zip ${{ matrix.artifact }} + 7z a ../../../zeroclaw-${{ matrix.target }}.${{ matrix.archive_ext }} ${{ matrix.artifact }} - name: Upload artifact uses: actions/upload-artifact@b7c566a772e6b6bfb58ed0dc250532a479d7789f # v6 with: name: zeroclaw-${{ matrix.target }} - path: zeroclaw-${{ matrix.target }}.* + path: zeroclaw-${{ matrix.target }}.${{ matrix.archive_ext }} retention-days: 7 publish: @@ -94,7 +145,7 @@ jobs: - uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Download all artifacts - uses: actions/download-artifact@d3f86a106a0bac45b974a628896c90dbdf5c8093 # v4 + uses: actions/download-artifact@37930b1c2abaa49bbe596cd826c3c89aef350131 # v7.0.0 with: path: artifacts @@ -119,7 +170,7 @@ jobs: cat SHA256SUMS - name: Install cosign - uses: sigstore/cosign-installer@3454372f43399081ed03b604cb2d021dabca52bb # v3.8.2 + uses: sigstore/cosign-installer@faadad0cce49287aee09b3a48701e75088a2c6ad # v4.0.0 - name: Sign artifacts with cosign (keyless) run: | diff --git a/.github/workflows/sec-audit.yml b/.github/workflows/sec-audit.yml index 3667725..89b4a32 100644 --- a/.github/workflows/sec-audit.yml +++ b/.github/workflows/sec-audit.yml @@ -3,8 +3,20 @@ name: Sec Audit on: push: branches: [main] + paths: + - "Cargo.toml" + - "Cargo.lock" + - "src/**" + - "crates/**" + - "deny.toml" pull_request: branches: [main] + paths: + - "Cargo.toml" + - "Cargo.lock" + - "src/**" + - "crates/**" + - "deny.toml" schedule: - cron: "0 6 * * 1" # Weekly on Monday 6am UTC diff --git a/.github/workflows/sec-codeql.yml b/.github/workflows/sec-codeql.yml index f5c6c35..300e1ef 100644 --- a/.github/workflows/sec-codeql.yml +++ b/.github/workflows/sec-codeql.yml @@ -2,7 +2,7 @@ name: Sec CodeQL on: schedule: - - cron: "0 6,18 * * *" # Twice daily at 6am and 6pm UTC + - cron: "0 6 * * 1" # Weekly Monday 
6am UTC workflow_dispatch: concurrency: diff --git a/.github/workflows/sync-contributors.yml b/.github/workflows/sync-contributors.yml index a5fb2ec..50c7955 100644 --- a/.github/workflows/sync-contributors.yml +++ b/.github/workflows/sync-contributors.yml @@ -17,7 +17,7 @@ permissions: jobs: update-notice: name: Update NOTICE with new contributors - runs-on: blacksmith-2vcpu-ubuntu-2404 + runs-on: ubuntu-latest steps: - name: Checkout repository uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 diff --git a/.github/workflows/test-benchmarks.yml b/.github/workflows/test-benchmarks.yml index 329f530..036904a 100644 --- a/.github/workflows/test-benchmarks.yml +++ b/.github/workflows/test-benchmarks.yml @@ -1,8 +1,8 @@ name: Test Benchmarks on: - push: - branches: [main] + schedule: + - cron: "0 3 * * 1" # Weekly Monday 3am UTC workflow_dispatch: concurrency: @@ -39,7 +39,7 @@ jobs: path: | target/criterion/ benchmark_output.txt - retention-days: 30 + retention-days: 7 - name: Post benchmark summary on PR if: github.event_name == 'pull_request' diff --git a/.gitignore b/.gitignore index 9846ea4..89a1f8b 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,7 @@ firmware/*/target *.db *.db-journal .DS_Store +._* .wt-pr37/ __pycache__/ *.pyc diff --git a/CHANGELOG.md b/CHANGELOG.md index 4944885..013eb10 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,6 +26,13 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - `enc:` prefix for encrypted secrets — Use `enc2:` (ChaCha20-Poly1305) instead. Legacy values are still decrypted for backward compatibility but should be migrated. +### Fixed +- **Onboarding channel menu dispatch** now uses an enum-backed selector instead of hard-coded + numeric match arms, preventing duplicated pattern arms and related `unreachable pattern` + compiler warnings in `src/onboard/wizard.rs`. 
+- **OpenAI native tool spec parsing** now uses owned serializable/deserializable structs, + fixing a compile-time type mismatch when validating tool schemas before API calls. + ## [0.1.0] - 2026-02-13 ### Added diff --git a/CLA.md b/CLA.md new file mode 100644 index 0000000..1333c48 --- /dev/null +++ b/CLA.md @@ -0,0 +1,132 @@ +# ZeroClaw Contributor License Agreement (CLA) + +**Version 1.0 — February 2026** +**ZeroClaw Labs** + +--- + +## Purpose + +This Contributor License Agreement ("CLA") clarifies the intellectual +property rights granted by contributors to ZeroClaw Labs. This agreement +protects both contributors and users of the ZeroClaw project. + +By submitting a contribution (pull request, patch, issue with code, or any +other form of code submission) to the ZeroClaw repository, you agree to the +terms of this CLA. + +--- + +## 1. Definitions + +- **"Contribution"** means any original work of authorship, including any + modifications or additions to existing work, submitted to ZeroClaw Labs + for inclusion in the ZeroClaw project. + +- **"You"** means the individual or legal entity submitting a Contribution. + +- **"ZeroClaw Labs"** means the maintainers and organization responsible + for the ZeroClaw project at https://github.com/zeroclaw-labs/zeroclaw. + +--- + +## 2. Grant of Copyright License + +You grant ZeroClaw Labs and recipients of software distributed by ZeroClaw +Labs a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable copyright license to: + +- Reproduce, prepare derivative works of, publicly display, publicly + perform, sublicense, and distribute your Contributions and derivative + works under **both the MIT License and the Apache License 2.0**. + +--- + +## 3. 
Grant of Patent License + +You grant ZeroClaw Labs and recipients of software distributed by ZeroClaw +Labs a perpetual, worldwide, non-exclusive, no-charge, royalty-free, +irrevocable patent license to make, have made, use, offer to sell, sell, +import, and otherwise transfer your Contributions. + +This patent license applies only to patent claims licensable by you that +are necessarily infringed by your Contribution alone or in combination with +the ZeroClaw project. + +**This protects you:** if a third party files a patent claim against +ZeroClaw that covers your Contribution, your patent license to the project +is not revoked. + +--- + +## 4. You Retain Your Rights + +This CLA does **not** transfer ownership of your Contribution to ZeroClaw +Labs. You retain full copyright ownership of your Contribution. You are +free to use your Contribution in any other project under any license. + +--- + +## 5. Original Work + +You represent that: + +1. Each Contribution is your original creation, or you have sufficient + rights to submit it under this CLA. +2. Your Contribution does not knowingly infringe any third-party patent, + copyright, trademark, or other intellectual property right. +3. If your employer has rights to intellectual property you create, you + have received permission to submit the Contribution, or your employer + has signed a corporate CLA with ZeroClaw Labs. + +--- + +## 6. No Trademark Rights + +This CLA does not grant you any rights to use the ZeroClaw name, +trademarks, service marks, or logos. See TRADEMARK.md for trademark policy. + +--- + +## 7. Attribution + +ZeroClaw Labs will maintain attribution to contributors in the repository +commit history and NOTICE file. Your contributions are permanently and +publicly recorded. + +--- + +## 8. 
Dual-License Commitment + +All Contributions accepted into the ZeroClaw project are licensed under +both: + +- **MIT License** — permissive open-source use +- **Apache License 2.0** — patent protection and stronger IP guarantees + +This dual-license model ensures maximum compatibility and protection for +the entire contributor community. + +--- + +## 9. How to Agree + +By opening a pull request or submitting a patch to the ZeroClaw repository, +you indicate your agreement to this CLA. No separate signature is required +for individual contributors. + +For **corporate contributors** (submitting on behalf of a company or +organization), please open an issue titled "Corporate CLA — [Company Name]" +and a maintainer will follow up. + +--- + +## 10. Questions + +If you have questions about this CLA, open an issue at: +https://github.com/zeroclaw-labs/zeroclaw/issues + +--- + +*This CLA is based on the Apache Individual Contributor License Agreement +v2.0, adapted for the ZeroClaw dual-license model.* diff --git a/Cargo.lock b/Cargo.lock index d058410..72f07ed 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -47,7 +47,21 @@ checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", ] [[package]] @@ -71,6 +85,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "alloca" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5a7d05ea6aea7e9e64d25b9156ba2fee3fdd659e34e41063cd2fc7cd020d7f4" +dependencies = [ + "cc", +] + [[package]] name = "allocator-api2" version = "0.2.21" @@ -144,9 +167,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.101" 
+version = "1.0.102" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f0e0fee31ef5ed1ba1316088939cea399010ed7731dba877ed44aeb407a75ea" +checksum = "7f202df86484c868dbad7eaa557ef785d5c66295e41b460ef922eca0723b842c" [[package]] name = "anymap2" @@ -235,9 +258,9 @@ dependencies = [ [[package]] name = "async-compression" -version = "0.4.39" +version = "0.4.40" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "68650b7df54f0293fd061972a0fb05aaf4fc0879d3b3d21a638a182c5c543b9f" +checksum = "7d67d43201f4d20c78bcda740c142ca52482d81da80681533d33bf3f0596c8e2" dependencies = [ "compression-codecs", "compression-core", @@ -281,11 +304,22 @@ dependencies = [ "futures-lite", "parking", "polling", - "rustix 1.1.3", + "rustix", "slab", "windows-sys 0.61.2", ] +[[package]] +name = "async-lock" +version = "3.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f7f2596bd5b78a9fec8088ccd89180d7f9f55b94b0576823bbbdc72ee8311" +dependencies = [ + "event-listener 5.4.1", + "event-listener-strategy", + "pin-project-lite", +] + [[package]] name = "async-stream" version = "0.3.6" @@ -360,6 +394,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", + "axum-macros", "base64", "bytes", "form_urlencoded", @@ -382,7 +417,7 @@ dependencies = [ "sha1", "sync_wrapper", "tokio", - "tokio-tungstenite 0.28.0", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -406,6 +441,17 @@ dependencies = [ "tower-service", ] +[[package]] +name = "axum-macros" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "604fde5e028fea851ce1d8570bbdc034bec850d157f7569d10f347d06808c05c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + [[package]] name = "backon" version = "1.6.0" @@ -429,6 +475,12 @@ version = "1.8.3" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "basic-udev" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a45f9771ced8a774de5e5ebffbe520f52e3943bf5a9a6baa3a5d14a5de1afe6" + [[package]] name = "bincode" version = "2.0.1" @@ -513,7 +565,7 @@ dependencies = [ "cc", "cfg-if", "constant_time_eq", - "cpufeatures", + "cpufeatures 0.2.17", ] [[package]] @@ -586,6 +638,9 @@ name = "bytes" version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +dependencies = [ + "serde", +] [[package]] name = "bytesize" @@ -646,7 +701,18 @@ checksum = "c3613f74bd2eac03dad61bd53dbe620703d4371614fe0bc3b9f04dd36fe4e818" dependencies = [ "cfg-if", "cipher", - "cpufeatures", + "cpufeatures 0.2.17", +] + +[[package]] +name = "chacha20" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f8d983286843e49675a4b7a2d174efe136dc93a18d69130dd18198a6c167601" +dependencies = [ + "cfg-if", + "cpufeatures 0.3.0", + "rand_core 0.10.0", ] [[package]] @@ -656,7 +722,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "10cd79432192d1c0f4e1a0fef9527696cc039165d729fb41b3f4f4f354c2dc35" dependencies = [ "aead", - "chacha20", + "chacha20 0.9.1", "cipher", "poly1305", "zeroize", @@ -736,9 +802,9 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.59" +version = "4.5.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c5caf74d17c3aec5495110c34cc3f78644bfa89af6c8993ed4de2790e49b6499" +checksum = "2797f34da339ce31042b27d23607e051786132987f595b02ba4f6a6dffb7030a" dependencies = [ "clap_builder", "clap_derive", @@ -746,9 +812,9 @@ dependencies = [ [[package]] name = "clap_builder" -version = "4.5.59" +version = "4.5.60" source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "370daa45065b80218950227371916a1633217ae42b2715b2287b606dcd618e24" +checksum = "24a241312cea5059b13574bb9b3861cabf758b879c15190b37b6d6fd63ab6876" dependencies = [ "anstream", "anstyle", @@ -801,9 +867,9 @@ checksum = "b05b61dc5112cbb17e4b6cd61790d9845d13888356391624cbe7e41efeac1e75" [[package]] name = "compression-codecs" -version = "0.4.36" +version = "0.4.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "00828ba6fd27b45a448e57dbfe84f1029d4c9f26b368157e9a448a5f49a2ec2a" +checksum = "eb7b51a7d9c967fc26773061ba86150f19c50c0d65c887cb1fbe295fd16619b7" dependencies = [ "compression-core", "flate2", @@ -881,13 +947,21 @@ dependencies = [ ] [[package]] -name = "core-foundation" -version = "0.9.4" +name = "cookie_store" +version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91e195e091a93c46f7102ec7818a2aa394e1e1771c3ab4825963fa03e45afb8f" +checksum = "15b2c103cf610ec6cae3da84a766285b42fd16aad564758459e6ecf128c75206" dependencies = [ - "core-foundation-sys", - "libc", + "cookie 0.18.1", + "document-features", + "idna", + "indexmap", + "log", + "serde", + "serde_derive", + "serde_json", + "time", + "url", ] [[package]] @@ -924,6 +998,15 @@ dependencies = [ "libc", ] +[[package]] +name = "cpufeatures" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" +dependencies = [ + "libc", +] + [[package]] name = "crc32fast" version = "1.5.0" @@ -935,26 +1018,24 @@ dependencies = [ [[package]] name = "criterion" -version = "0.5.1" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" +checksum = "950046b2aa2492f9a536f5f4f9a3de7b9e2476e575e05bd6c333371add4d98f3" dependencies = [ + "alloca", "anes", "cast", "ciborium", "clap", 
"criterion-plot", - "futures", - "is-terminal", - "itertools 0.10.5", + "itertools 0.13.0", "num-traits", - "once_cell", "oorandom", + "page_size", "plotters", "rayon", "regex", "serde", - "serde_derive", "serde_json", "tinytemplate", "tokio", @@ -963,12 +1044,12 @@ dependencies = [ [[package]] name = "criterion-plot" -version = "0.5.0" +version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" +checksum = "d8d80a2f4f5b554395e47b5d8305bc3d27813bacb73493eb1001e8f76dae29ea" dependencies = [ "cast", - "itertools 0.10.5", + "itertools 0.13.0", ] [[package]] @@ -982,6 +1063,15 @@ dependencies = [ "winnow 0.6.26", ] +[[package]] +name = "crossbeam-channel" +version = "0.5.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "82b8f8f868b36967f9606790d1903570de9ceaf870a7bf9fbbd3016d636a2cb2" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "crossbeam-deque" version = "0.8.6" @@ -1061,7 +1151,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fb8b7c4503de7d6ae7b42ab72a5a59857b4c937ec27a3d4539dba95b5ab2be" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "curve25519-dalek-derive", "digest", "fiat-crypto", @@ -1117,6 +1207,20 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "dashmap" +version = "6.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +dependencies = [ + "cfg-if", + "crossbeam-utils", + "hashbrown 0.14.5", + "lock_api", + "once_cell", + "parking_lot_core", +] + [[package]] name = "data-encoding" version = "2.10.0" @@ -1304,22 +1408,13 @@ dependencies = [ "subtle", ] -[[package]] -name = "directories" -version = "5.0.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9a49173b84e034382284f27f1af4dcbbd231ffa358c0fe316541a7337f376a35" 
-dependencies = [ - "dirs-sys 0.4.1", -] - [[package]] name = "directories" version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "16f5094c54661b38d03bd7e50df373292118db60b585c08a411c6d840017fe7d" dependencies = [ - "dirs-sys 0.5.0", + "dirs-sys", ] [[package]] @@ -1328,19 +1423,7 @@ version = "6.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c3e8aa94d75141228480295a7d0e7feb620b1a5ad9f12bc40be62411e38cce4e" dependencies = [ - "dirs-sys 0.5.0", -] - -[[package]] -name = "dirs-sys" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "520f05a5cbd335fae5a99ff7a6ab8627577660ee5cfd6a94a6a929b52ff0321c" -dependencies = [ - "libc", - "option-ext", - "redox_users 0.4.6", - "windows-sys 0.48.0", + "dirs-sys", ] [[package]] @@ -1351,7 +1434,7 @@ checksum = "e01a3366d27ee9890022452ee61b2b63a67e6f13f58900b651ff5665f0bb1fab" dependencies = [ "libc", "option-ext", - "redox_users 0.5.2", + "redox_users", "windows-sys 0.61.2", ] @@ -1386,12 +1469,27 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "document-features" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d4b8a88685455ed29a21542a33abd9cb6510b6b129abadabdcef0f4c55bc8f61" +dependencies = [ + "litrs", +] + [[package]] name = "dunce" version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" +[[package]] +name = "dyn-clone" +version = "1.0.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" + [[package]] name = "ecb" version = "0.1.2" @@ -1484,6 +1582,25 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "env_filter" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7a1c3cc8e57274ec99de65301228b537f1e4eedc1b8e0f9411c6caac8ae7308f" +dependencies = [ + "log", +] + +[[package]] +name = "env_logger" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b2daee4ea451f429a58296525ddf28b45a3b64f1acf6587e2067437bb11e218d" +dependencies = [ + "env_filter", + "log", +] + [[package]] name = "equivalent" version = "1.0.2" @@ -1675,6 +1792,12 @@ version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" +[[package]] +name = "fixedbitset" +version = "0.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d674e81391d1e1ab681a28d99df07927c6d4aa5b027d7da16ba32d1d21ecd99" + [[package]] name = "flate2" version = "1.1.9" @@ -1683,6 +1806,7 @@ checksum = "843fba2746e448b37e26a819579957415c8cef339bf08564fe8b7ddbd959573c" dependencies = [ "crc32fast", "miniz_oxide", + "zlib-rs", ] [[package]] @@ -1881,10 +2005,21 @@ dependencies = [ "cfg-if", "libc", "r-efi", + "rand_core 0.10.0", "wasip2", "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "gimli" version = "0.32.3" @@ -2063,12 +2198,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.3.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" - [[package]] name = "hermit-abi" version = "0.5.2" @@ -2087,12 +2216,12 @@ version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"565dd4c730b8f8b2c0fb36df6be12e5470ae10895ddcc4e9dcfbfb495de202b0" dependencies = [ + "basic-udev", "cc", "cfg-if", "libc", "nix 0.27.1", "pkg-config", - "udev", "windows-sys 0.48.0", ] @@ -2576,18 +2705,17 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "617ee6cf8e3f66f3b4ea67a4058564628cde41901316e19f559e14c7c72c5e7b" dependencies = [ "core-foundation-sys", - "mach2", + "mach2 0.4.3", ] [[package]] -name = "io-lifetimes" -version = "1.0.11" +name = "io-kit-sys" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" +checksum = "06d3a048d09fbb6597dbf7c69f40d14df4a49487db1487191618c893fc3b1c26" dependencies = [ - "hermit-abi 0.3.9", - "libc", - "windows-sys 0.48.0", + "core-foundation-sys", + "mach2 0.5.0", ] [[package]] @@ -2606,17 +2734,6 @@ dependencies = [ "serde", ] -[[package]] -name = "is-terminal" -version = "0.4.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3640c1c38b8e4e43584d8df18be5fc6b0aa314ce6ebf51b53313d4306cca8e46" -dependencies = [ - "hermit-abi 0.5.2", - "libc", - "windows-sys 0.61.2", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -2632,6 +2749,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itertools" version = "0.14.0" @@ -2800,28 +2926,6 @@ dependencies = [ "vcpkg", ] -[[package]] -name = "libudev-sys" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c8469b4a23b962c1396b9b451dda50ef5b283e8dd309d69033475fa9b334324" -dependencies = [ - "libc", - "pkg-config", -] - -[[package]] -name = "linux-raw-sys" -version = "0.4.15" -source = 
"registry+https://github.com/rust-lang/crates.io-index" -checksum = "d26c52dbd32dccf2d10cac7725f8eae5296885fb5703b261f7d0a0739ec807ab" - -[[package]] -name = "linux-raw-sys" -version = "0.9.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd945864f07fe9f5371a27ad7b52a172b4b499999f1d97574c9fa68373937e12" - [[package]] name = "linux-raw-sys" version = "0.11.0" @@ -2840,6 +2944,12 @@ version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6373607a59f0be73a39b6fe456b8192fcc3585f602af20751600e974dd455e77" +[[package]] +name = "litrs" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11d3d7f243d5c5a8b9bb5d6dd2b1602c0cb0b9db1621bafc7ed66e35ff9fe092" + [[package]] name = "lock_api" version = "0.4.14" @@ -2904,6 +3014,15 @@ dependencies = [ "libc", ] +[[package]] +name = "mach2" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a1b95cd5421ec55b445b5ae102f5ea0e768de1f82bd3001e11f426c269c3aea" +dependencies = [ + "libc", +] + [[package]] name = "macroific" version = "2.0.0" @@ -3306,6 +3425,12 @@ dependencies = [ "digest", ] +[[package]] +name = "md5" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ae960838283323069879657ca3de837e9f7bbb4c7bf6ea7f1b290d5e9476d2e0" + [[package]] name = "memchr" version = "2.8.0" @@ -3397,6 +3522,32 @@ dependencies = [ "winapi", ] +[[package]] +name = "moka" +version = "0.12.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac832c50ced444ef6be0767a008b02c106a909ba79d1d830501e94b96f6b7e" +dependencies = [ + "async-lock", + "crossbeam-channel", + "crossbeam-epoch", + "crossbeam-utils", + "equivalent", + "event-listener 5.4.1", + "futures-util", + "parking_lot", + "portable-atomic", + "smallvec", + "tagptr", + "uuid", +] + +[[package]] +name = "multimap" +version = "0.10.1" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d87ecb2933e8aeadb3e3a02b828fed80a7528047e68b4f424523a0981a3a084" + [[package]] name = "new_debug_unreachable" version = "1.0.6" @@ -3518,45 +3669,26 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" dependencies = [ - "hermit-abi 0.5.2", + "hermit-abi", "libc", ] [[package]] name = "nusb" -version = "0.1.14" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f861541f15de120eae5982923d073bfc0c1a65466561988c82d6e197734c19e" +checksum = "5750d884c774a2862b0049b0318aea27cecc9e873485540af5ed8ab8841247da" dependencies = [ - "atomic-waker", - "core-foundation 0.9.4", + "core-foundation", "core-foundation-sys", "futures-core", - "io-kit-sys", - "libc", + "io-kit-sys 0.5.0", + "linux-raw-sys", "log", "once_cell", - "rustix 0.38.44", + "rustix", "slab", - "windows-sys 0.48.0", -] - -[[package]] -name = "nusb" -version = "0.2.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d0226f4db3ee78f820747cf713767722877f6449d7a0fcfbf2ec3b840969763f" -dependencies = [ - "core-foundation 0.10.1", - "core-foundation-sys", - "futures-core", - "io-kit-sys", - "linux-raw-sys 0.9.4", - "log", - "once_cell", - "rustix 1.1.3", - "slab", - "windows-sys 0.60.2", + "windows-sys 0.61.2", ] [[package]] @@ -3723,6 +3855,16 @@ version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "04744f49eae99ab78e0d5c0b603ab218f515ea8cfe5a456d7629ad883a3b6e7d" +[[package]] +name = "page_size" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30d5b2194ed13191c1999ae0704b7839fb18384fa22e49b57eeaa97d79ce40da" +dependencies = [ + "libc", + "winapi", +] + [[package]] name = "parking" version = "2.2.1" @@ -3794,6 +3936,17 @@ version = "2.3.2" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "9b4f627cb1b25917193a259e49bdad08f671f8d9708acfd5fe0a8c1455d87220" +[[package]] +name = "petgraph" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8701b58ea97060d5e5b155d383a69952a60943f0e6dfe30b04c287beb0b27455" +dependencies = [ + "fixedbitset", + "hashbrown 0.15.5", + "indexmap", +] + [[package]] name = "phf" version = "0.11.3" @@ -3828,10 +3981,20 @@ version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aef8048c789fa5e851558d709946d6d79a8ff88c0440c587967f8e94bfb1216a" dependencies = [ - "phf_generator", + "phf_generator 0.11.3", "phf_shared 0.11.3", ] +[[package]] +name = "phf_codegen" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49aa7f9d80421bca176ca8dbfebe668cc7a2684708594ec9f3c0db0805d5d6e1" +dependencies = [ + "phf_generator 0.13.1", + "phf_shared 0.13.1", +] + [[package]] name = "phf_generator" version = "0.11.3" @@ -3842,6 +4005,16 @@ dependencies = [ "rand 0.8.5", ] +[[package]] +name = "phf_generator" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "135ace3a761e564ec88c03a77317a7c6b80bb7f7135ef2544dbe054243b89737" +dependencies = [ + "fastrand", + "phf_shared 0.13.1", +] + [[package]] name = "phf_shared" version = "0.11.3" @@ -3953,9 +4126,9 @@ checksum = "5d0e4f59085d47d8241c88ead0f274e8a0cb551f3625263c05eb8dd897c34218" dependencies = [ "cfg-if", "concurrent-queue", - "hermit-abi 0.5.2", + "hermit-abi", "pin-project-lite", - "rustix 1.1.3", + "rustix", "windows-sys 0.61.2", ] @@ -3965,7 +4138,19 @@ version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8159bd90725d2df49889a078b54f4f79e87f1f8a8444194cdca81d38f5393abf" dependencies = [ - "cpufeatures", + "cpufeatures 0.2.17", + "opaque-debug", + "universal-hash", +] + +[[package]] +name = "polyval" +version 
= "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures 0.2.17", "opaque-debug", "universal-hash", ] @@ -3976,6 +4161,12 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60f6ce597ecdcc9a098e7fddacb1065093a3d66446fa16c675e7e71d1b5c28e6" +[[package]] +name = "portable-atomic" +version = "1.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c33a9471896f1c69cecef8d20cbe2f7accd12527ce60845ff44c153bb2a21b49" + [[package]] name = "postgres" version = "0.19.12" @@ -4068,9 +4259,9 @@ dependencies = [ [[package]] name = "probe-rs" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ee27329ac37fa02b194c62a4e3c1aa053739884ea7bcf861249866d3bf7de00" +checksum = "ee50102aaa214117fc4fbe1311077835f0f4faa71e4a769bf65f955cc020ee34" dependencies = [ "anyhow", "async-io", @@ -4087,8 +4278,8 @@ dependencies = [ "ihex", "itertools 0.14.0", "jep106", - "nusb 0.1.14", - "object 0.37.3", + "nusb", + "object 0.38.1", "parking_lot", "probe-rs-target", "rmp-serde", @@ -4104,9 +4295,9 @@ dependencies = [ [[package]] name = "probe-rs-target" -version = "0.30.0" +version = "0.31.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2239aca5dc62c68ca6d8ff0051fe617cb8363b803380fbc60567e67c82b474df" +checksum = "031bed1313b45d93dae4ca8f0fee098530c6632e4ebd9e2769d5a49cdef273d3" dependencies = [ "base64", "indexmap", @@ -4122,7 +4313,7 @@ version = "3.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "219cb19e96be00ab2e37d6e299658a0cfa83e52429179969b0f0121b4ac46983" dependencies = [ - "toml_edit 0.23.10+spec-1.0.0", + "toml_edit", ] [[package]] @@ -4189,6 +4380,23 @@ dependencies = [ "prost-derive 0.14.3", ] +[[package]] +name = "prost-build" +version = 
"0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "343d3bd7056eda839b03204e68deff7d1b13aba7af2b2fd16890697274262ee7" +dependencies = [ + "heck", + "itertools 0.14.0", + "log", + "multimap", + "petgraph", + "prost 0.14.3", + "prost-types", + "regex", + "tempfile", +] + [[package]] name = "prost-derive" version = "0.13.5" @@ -4215,6 +4423,35 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "prost-types" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8991c4cbdb8bc5b11f0b074ffe286c30e523de90fee5ba8132f1399f23cb3dd7" +dependencies = [ + "prost 0.14.3", +] + +[[package]] +name = "protobuf" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d65a1d4ddae7d8b5de68153b48f6aa3bba8cb002b243dbdbc55a5afbc98f99f4" +dependencies = [ + "once_cell", + "protobuf-support", + "thiserror 1.0.69", +] + +[[package]] +name = "protobuf-support" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e36c2f31e0a47f9280fb347ef5e461ffcd2c52dd520d8e216b52f93b0b0d7d6" +dependencies = [ + "thiserror 1.0.69", +] + [[package]] name = "psm" version = "0.1.30" @@ -4346,6 +4583,17 @@ dependencies = [ "rand_core 0.9.5", ] +[[package]] +name = "rand" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc266eb313df6c5c09c1c7b1fbe2510961e5bcd3add930c1e31f7ed9da0feff8" +dependencies = [ + "chacha20 0.10.0", + "getrandom 0.4.1", + "rand_core 0.10.0", +] + [[package]] name = "rand_chacha" version = "0.3.1" @@ -4384,6 +4632,12 @@ dependencies = [ "getrandom 0.3.4", ] +[[package]] +name = "rand_core" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c8d0fd677905edcbeedbf2edb6494d676f0e98d54d5cf9bda0b061cb8fb8aba" + [[package]] name = "rand_xoshiro" version = "0.7.0" @@ -4443,17 +4697,6 @@ dependencies = [ "bitflags 2.11.0", ] 
-[[package]] -name = "redox_users" -version = "0.4.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba009ff324d1fc1b900bd1fdb31564febe58a8ccc8a6fdbb93b543d33b13ca43" -dependencies = [ - "getrandom 0.2.17", - "libredox", - "thiserror 1.0.69", -] - [[package]] name = "redox_users" version = "0.5.2" @@ -4465,6 +4708,26 @@ dependencies = [ "thiserror 2.0.18", ] +[[package]] +name = "ref-cast" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f354300ae66f76f1c85c5f84693f0ce81d747e2c3f21a45fef496d89c960bf7d" +dependencies = [ + "ref-cast-impl", +] + +[[package]] +name = "ref-cast-impl" +version = "1.0.25" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b7186006dcb21920990093f30e3dea63b7d6e977bf1256be20c3563a5db070da" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + [[package]] name = "regex" version = "1.12.3" @@ -4794,19 +5057,6 @@ dependencies = [ "semver", ] -[[package]] -name = "rustix" -version = "0.38.44" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fdb5bc1ae2baa591800df16c9ca78619bf65c0488b41b96ccec5d11220d8c154" -dependencies = [ - "bitflags 2.11.0", - "errno", - "libc", - "linux-raw-sys 0.4.15", - "windows-sys 0.59.0", -] - [[package]] name = "rustix" version = "1.1.3" @@ -4816,7 +5066,7 @@ dependencies = [ "bitflags 2.11.0", "errno", "libc", - "linux-raw-sys 0.11.0", + "linux-raw-sys", "windows-sys 0.61.2", ] @@ -4909,6 +5159,31 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "schemars" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" +dependencies = [ + "dyn-clone", + "ref-cast", + "schemars_derive", + "serde", + "serde_json", +] + +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.116", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -4939,7 +5214,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d17b898a6d6948c3a8ee4372c17cb384f90d2e6e912ef00895b14fd7ab54ec38" dependencies = [ "bitflags 2.11.0", - "core-foundation 0.10.1", + "core-foundation", "core-foundation-sys", "libc", "security-framework-sys", @@ -4977,6 +5252,15 @@ dependencies = [ "serde_derive", ] +[[package]] +name = "serde-big-array" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "11fc7cc2c76d73e0f27ee52abbd64eec84d46f370c88371120433196934e4b7f" +dependencies = [ + "serde", +] + [[package]] name = "serde-wasm-bindgen" version = "0.6.5" @@ -5018,6 +5302,17 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + [[package]] name = "serde_html_form" version = "0.2.8" @@ -5064,15 +5359,6 @@ dependencies = [ "serde", ] -[[package]] -name = "serde_spanned" -version = "0.6.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf41e0cfaf7226dca15e8197172c295a782857fcb97fad1808a166870dee75a3" -dependencies = [ - "serde", -] - [[package]] name = "serde_spanned" version = "1.0.4" @@ -5130,10 +5416,10 @@ checksum = "2acaf3f973e8616d7ceac415f53fc60e190b2a686fbcf8d27d0256c741c5007b" dependencies = [ "bitflags 2.11.0", "cfg-if", - "core-foundation 0.10.1", + "core-foundation", "core-foundation-sys", - "io-kit-sys", - "mach2", + "io-kit-sys 0.4.1", + "mach2 0.4.3", "nix 0.26.4", "scopeguard", "unescaper", @@ -5147,7 +5433,7 @@ source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -5158,7 +5444,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", - "cpufeatures", + "cpufeatures 0.2.17", "digest", ] @@ -5217,6 +5503,12 @@ version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2" +[[package]] +name = "simdutf8" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" + [[package]] name = "siphasher" version = "1.0.2" @@ -5308,7 +5600,7 @@ version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c711928715f1fe0fe509c53b43e993a9a557babc2d0a3567d0a3006f1ac931a0" dependencies = [ - "phf_generator", + "phf_generator 0.11.3", "phf_shared 0.11.3", "proc-macro2", "quote", @@ -5400,6 +5692,12 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "tagptr" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b2093cf4c8eb1e67749a6762251bc9cd836b6fc171623bd0a9d324d37af2417" + [[package]] name = "tap" version = "1.0.1" @@ -5415,7 +5713,7 @@ dependencies = [ "fastrand", "getrandom 0.4.1", "once_cell", - "rustix 1.1.3", + "rustix", "windows-sys 0.61.2", ] @@ -5656,9 +5954,9 @@ dependencies = [ [[package]] name = "tokio-tungstenite" -version = "0.24.0" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" dependencies = [ "futures-util", "log", @@ 
-5666,22 +5964,10 @@ dependencies = [ "rustls-pki-types", "tokio", "tokio-rustls", - "tungstenite 0.24.0", + "tungstenite", "webpki-roots 0.26.11", ] -[[package]] -name = "tokio-tungstenite" -version = "0.28.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite 0.28.0", -] - [[package]] name = "tokio-util" version = "0.7.18" @@ -5696,15 +5982,24 @@ dependencies = [ ] [[package]] -name = "toml" -version = "0.8.23" +name = "tokio-websockets" +version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc1beb996b9d83529a9e75c17a1686767d148d70663143c7854d8b4a09ced362" +checksum = "8b6aa6c8b5a31e06fd3760eb5c1b8d9072e30731f0467ee3795617fe768e7449" dependencies = [ - "serde", - "serde_spanned 0.6.9", - "toml_datetime 0.6.11", - "toml_edit 0.22.27", + "base64", + "bytes", + "futures-core", + "futures-sink", + "http 1.4.0", + "httparse", + "rand 0.9.2", + "ring", + "rustls-pki-types", + "simdutf8", + "tokio", + "tokio-rustls", + "tokio-util", ] [[package]] @@ -5714,7 +6009,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cf92845e79fc2e2def6a5d828f0801e29a2f8acc037becc5ab08595c7d5e9863" dependencies = [ "serde_core", - "serde_spanned 1.0.4", + "serde_spanned", "toml_datetime 0.7.5+spec-1.1.0", "toml_parser", "winnow 0.7.14", @@ -5722,28 +6017,19 @@ dependencies = [ [[package]] name = "toml" -version = "1.0.2+spec-1.1.0" +version = "1.0.1+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1dfefef6a142e93f346b64c160934eb13b5594b84ab378133ac6815cb2bd57f" +checksum = "bbe30f93627849fa362d4a602212d41bb237dc2bd0f8ba0b2ce785012e124220" dependencies = [ "indexmap", "serde_core", - "serde_spanned 1.0.4", + "serde_spanned", "toml_datetime 1.0.0+spec-1.1.0", "toml_parser", "toml_writer", "winnow 0.7.14", ] 
-[[package]] -name = "toml_datetime" -version = "0.6.11" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22cddaf88f4fbc13c51aebbf5f8eceb5c7c5a9da2ac40a13519eb5b0a0e8f11c" -dependencies = [ - "serde", -] - [[package]] name = "toml_datetime" version = "0.7.5+spec-1.1.0" @@ -5762,20 +6048,6 @@ dependencies = [ "serde_core", ] -[[package]] -name = "toml_edit" -version = "0.22.27" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41fe8c660ae4257887cf66394862d21dbca4a6ddd26f04a3560410406a2f819a" -dependencies = [ - "indexmap", - "serde", - "serde_spanned 0.6.9", - "toml_datetime 0.6.11", - "toml_write", - "winnow 0.7.14", -] - [[package]] name = "toml_edit" version = "0.23.10+spec-1.0.0" @@ -5790,19 +6062,13 @@ dependencies = [ [[package]] name = "toml_parser" -version = "1.0.9+spec-1.1.0" +version = "1.0.8+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "702d4415e08923e7e1ef96cd5727c0dfed80b4d2fa25db9647fe5eb6f7c5a4c4" +checksum = "0742ff5ff03ea7e67c8ae6c93cac239e0d9784833362da3f9a9c1da8dfefcbdc" dependencies = [ "winnow 0.7.14", ] -[[package]] -name = "toml_write" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d99f8c9a7727884afe522e9bd5edbfc91a3312b36a77b5fb8926e4c31a41801" - [[package]] name = "toml_writer" version = "1.0.6+spec-1.1.0" @@ -5964,26 +6230,6 @@ version = "0.25.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2df906b07856748fa3f6e0ad0cbaa047052d4a7dd609e231c4f72cee8c36f31" -[[package]] -name = "tungstenite" -version = "0.24.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" -dependencies = [ - "byteorder", - "bytes", - "data-encoding", - "http 1.4.0", - "httparse", - "log", - "rand 0.8.5", - "rustls", - "rustls-pki-types", - "sha1", - "thiserror 1.0.69", - "utf-8", -] - 
[[package]] name = "tungstenite" version = "0.28.0" @@ -5996,6 +6242,8 @@ dependencies = [ "httparse", "log", "rand 0.9.2", + "rustls", + "rustls-pki-types", "sha1", "thiserror 2.0.18", "utf-8", @@ -6016,6 +6264,26 @@ dependencies = [ "pom", ] +[[package]] +name = "typed-builder" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31aa81521b70f94402501d848ccc0ecaa8f93c8eb6999eb9747e72287757ffda" +dependencies = [ + "typed-builder-macro", +] + +[[package]] +name = "typed-builder-macro" +version = "0.23.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "076a02dc54dd46795c2e9c8282ed40bcfb1e22747e955de9389a1de28190fb26" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + [[package]] name = "typenum" version = "1.19.0" @@ -6037,18 +6305,6 @@ version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e36a83ea2b3c704935a01b4642946aadd445cea40b10935e3f8bd8052b8193d6" -[[package]] -name = "udev" -version = "0.8.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50051c6e22be28ee6f217d50014f3bc29e81c20dc66ff7ca0d5c5226e1dcc5a1" -dependencies = [ - "io-lifetimes", - "libc", - "libudev-sys", - "pkg-config", -] - [[package]] name = "uf2-decode" version = "0.2.0" @@ -6088,9 +6344,9 @@ checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5" [[package]] name = "unicode-ident" -version = "1.0.24" +version = "1.0.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6e4313cd5fcd3dad5cafa179702e2b244f760991f45397d14d4ebf38247da75" +checksum = "537dd038a89878be9b64dd4bd1b260315c1bb94f4d784956b81e27a088d9a09e" [[package]] name = "unicode-normalization" @@ -6153,6 +6409,37 @@ version = "0.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6d49784317cd0d1ee7ec5c716dd598ec5b4483ea832a2dced265471cc0f690ae" +[[package]] +name = "ureq" +version = "3.2.0" 
+source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fdc97a28575b85cfedf2a7e7d3cc64b3e11bd8ac766666318003abbacc7a21fc" +dependencies = [ + "base64", + "cookie_store", + "log", + "percent-encoding", + "rustls", + "rustls-pki-types", + "serde", + "serde_json", + "ureq-proto", + "utf-8", + "webpki-roots 1.0.6", +] + +[[package]] +name = "ureq-proto" +version = "0.5.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d81f9efa9df032be5934a46a068815a10a042b494b6a58cb0a1a97bb5467ed6f" +dependencies = [ + "base64", + "http 1.4.0", + "httparse", + "log", +] + [[package]] name = "url" version = "2.5.8" @@ -6256,6 +6543,223 @@ dependencies = [ "zeroize", ] +[[package]] +name = "wa-rs" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fecb468bdfe1e7d4c06a1bd12908c66edaca59024862cb64757ad11c3b948b1" +dependencies = [ + "anyhow", + "async-channel 2.5.0", + "async-trait", + "base64", + "bytes", + "chrono", + "dashmap", + "env_logger", + "hex", + "log", + "moka", + "prost 0.14.3", + "rand 0.9.2", + "rand_core 0.10.0", + "scopeguard", + "serde", + "serde_json", + "thiserror 2.0.18", + "tokio", + "wa-rs-binary", + "wa-rs-core", + "wa-rs-proto", +] + +[[package]] +name = "wa-rs-appstate" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3845137b3aead2d99de7c6744784bf2f5a908be9dc97a3dbd7585dc40296925c" +dependencies = [ + "anyhow", + "bytemuck", + "hex", + "hkdf", + "log", + "prost 0.14.3", + "serde", + "serde-big-array", + "serde_json", + "sha2", + "thiserror 2.0.18", + "wa-rs-binary", + "wa-rs-libsignal", + "wa-rs-proto", +] + +[[package]] +name = "wa-rs-binary" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3b30a6e11aebb39c07392675256ead5e2570c31382bd4835d6ddc877284b6be" +dependencies = [ + "flate2", + "phf 0.13.1", + "phf_codegen 0.13.1", + "serde", + "serde_json", +] + 
+[[package]] +name = "wa-rs-core" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed13bb2aff2de43fc4dd821955f03ea48a1d31eda3c80efe6f905898e304d11f" +dependencies = [ + "aes", + "aes-gcm", + "anyhow", + "async-channel 2.5.0", + "async-trait", + "base64", + "bytes", + "chrono", + "ctr", + "flate2", + "hex", + "hkdf", + "hmac", + "log", + "md5", + "once_cell", + "pbkdf2", + "prost 0.14.3", + "protobuf", + "rand 0.9.2", + "rand_core 0.10.0", + "serde", + "serde-big-array", + "serde_json", + "sha2", + "thiserror 2.0.18", + "typed-builder", + "wa-rs-appstate", + "wa-rs-binary", + "wa-rs-derive", + "wa-rs-libsignal", + "wa-rs-noise", + "wa-rs-proto", +] + +[[package]] +name = "wa-rs-derive" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75c03f610c9bc960e653d5d6d2a4cced9013bedbe5e6e8948787bbd418e4137c" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.116", +] + +[[package]] +name = "wa-rs-libsignal" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3471be8ff079ae4959fcddf2e7341281e5c6756bdc6a66454ea1a8e474d14576" +dependencies = [ + "aes", + "aes-gcm", + "arrayref", + "async-trait", + "cbc", + "chrono", + "ctr", + "curve25519-dalek", + "derive_more 2.1.1", + "displaydoc", + "ghash", + "hex", + "hkdf", + "hmac", + "itertools 0.14.0", + "log", + "prost 0.14.3", + "rand 0.9.2", + "serde", + "sha1", + "sha2", + "subtle", + "thiserror 2.0.18", + "uuid", + "wa-rs-proto", + "x25519-dalek", +] + +[[package]] +name = "wa-rs-noise" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3efb3891c1e22ce54646dc581e34e79377dc402ed8afb11a7671c5ef629b3ae" +dependencies = [ + "aes-gcm", + "anyhow", + "bytes", + "hkdf", + "log", + "prost 0.14.3", + "rand 0.9.2", + "rand_core 0.10.0", + "sha2", + "thiserror 2.0.18", + "wa-rs-binary", + "wa-rs-libsignal", + "wa-rs-proto", +] + +[[package]] +name 
= "wa-rs-proto" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59ada50ee03752f0e66ada8cf415ed5f90d572d34039b058ce23d8b13493e510" +dependencies = [ + "prost 0.14.3", + "prost-build", + "serde", +] + +[[package]] +name = "wa-rs-tokio-transport" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cfc638c168949dc99cbb756a776869898d4ae654b36b90d5f7ce2d32bf92a404" +dependencies = [ + "anyhow", + "async-channel 2.5.0", + "async-trait", + "bytes", + "futures-util", + "http 1.4.0", + "log", + "rustls", + "tokio", + "tokio-rustls", + "tokio-websockets", + "wa-rs-core", + "webpki-roots 1.0.6", +] + +[[package]] +name = "wa-rs-ureq-http" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88d0c7fff8a7bd93d0c17af8d797a3934144fa269fe47a615635f3bf04238806" +dependencies = [ + "anyhow", + "async-trait", + "tokio", + "ureq", + "wa-rs-core", +] + [[package]] name = "walkdir" version = "2.5.0" @@ -6469,7 +6973,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57ffde1dc01240bdf9992e3205668b235e59421fd085e8a317ed98da0178d414" dependencies = [ "phf 0.11.3", - "phf_codegen", + "phf_codegen 0.11.3", "string_cache", "string_cache_codegen", ] @@ -7067,7 +7571,7 @@ dependencies = [ "criterion", "cron", "dialoguer", - "directories 6.0.0", + "directories", "fantoccini", "futures", "futures-util", @@ -7080,7 +7584,7 @@ dependencies = [ "lettre", "mail-parser", "matrix-sdk", - "nusb 0.2.1", + "nusb", "opentelemetry", "opentelemetry-otlp", "opentelemetry_sdk", @@ -7090,7 +7594,7 @@ dependencies = [ "probe-rs", "prometheus", "prost 0.14.3", - "rand 0.9.2", + "rand 0.10.0", "regex", "reqwest", "ring", @@ -7098,7 +7602,9 @@ dependencies = [ "rusqlite", "rustls", "rustls-pki-types", + "schemars", "serde", + "serde-big-array", "serde_json", "sha2", "shellexpand", @@ -7107,15 +7613,22 @@ dependencies = [ "tokio", "tokio-rustls", 
"tokio-serial", - "tokio-tungstenite 0.24.0", + "tokio-stream", + "tokio-tungstenite", "tokio-util", - "toml 1.0.2+spec-1.1.0", + "toml 1.0.1+spec-1.1.0", "tower", "tower-http", "tracing", "tracing-subscriber", "urlencoding", "uuid", + "wa-rs", + "wa-rs-binary", + "wa-rs-core", + "wa-rs-proto", + "wa-rs-tokio-transport", + "wa-rs-ureq-http", "webpki-roots 1.0.6", ] @@ -7127,7 +7640,7 @@ dependencies = [ "async-trait", "base64", "chrono", - "directories 5.0.1", + "directories", "reqwest", "rppal 0.19.0", "serde", @@ -7136,7 +7649,7 @@ dependencies = [ "thiserror 2.0.18", "tokio", "tokio-test", - "toml 0.8.23", + "toml 1.0.1+spec-1.1.0", "tracing", ] @@ -7256,6 +7769,12 @@ dependencies = [ "syn 2.0.116", ] +[[package]] +name = "zlib-rs" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c745c48e1007337ed136dc99df34128b9faa6ed542d80a1c673cf55a6d7236c8" + [[package]] name = "zmij" version = "1.0.21" diff --git a/Cargo.toml b/Cargo.toml index 498f2b7..31b5632 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ tokio-util = { version = "0.7", default-features = false } reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls", "blocking", "multipart", "stream", "socks"] } # Matrix client + E2EE decryption -matrix-sdk = { version = "0.16", default-features = false, features = ["e2e-encryption", "rustls-tls", "markdown"] } +matrix-sdk = { version = "0.16", optional = true, default-features = false, features = ["e2e-encryption", "rustls-tls", "markdown"] } # Serialization serde = { version = "1.0", default-features = false, features = ["derive"] } @@ -37,6 +37,9 @@ directories = "6.0" toml = "1.0" shellexpand = "3.1" +# JSON Schema generation for config export +schemars = "1.2" + # Logging - minimal tracing = { version = "0.1", default-features = false } tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] } @@ -69,7 +72,10 @@ sha2 = 
"0.10" hex = "0.4" # CSPRNG for secure token generation -rand = "0.9" +rand = "0.10" + +# serde-big-array for wa-rs storage (large array serialization) +serde-big-array = { version = "0.5", optional = true } # Fast mutexes that don't poison on panic parking_lot = "0.12" @@ -97,8 +103,8 @@ console = "0.16" # Hardware discovery (device path globbing) glob = "0.3" -# Discord WebSocket gateway -tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] } +# WebSocket client channels (Discord/Lark/DingTalk) +tokio-tungstenite = { version = "0.28", features = ["rustls-tls-webpki-roots"] } futures-util = { version = "0.3", default-features = false, features = ["sink"] } futures = "0.3" regex = "1.10" @@ -114,27 +120,42 @@ mail-parser = "0.11.2" async-imap = { version = "0.11",features = ["runtime-tokio"], default-features = false } # HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance -axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] } +axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws", "macros"] } tower = { version = "0.5", default-features = false } tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] } http-body-util = "0.1" -# OpenTelemetry — OTLP trace + metrics export +# OpenTelemetry — OTLP trace + metrics export. +# Use the blocking HTTP exporter client to avoid Tokio-reactor panics in +# OpenTelemetry background batch threads when ZeroClaw emits spans/metrics from +# non-Tokio contexts. 
opentelemetry = { version = "0.31", default-features = false, features = ["trace", "metrics"] } opentelemetry_sdk = { version = "0.31", default-features = false, features = ["trace", "metrics"] } -opentelemetry-otlp = { version = "0.31", default-features = false, features = ["trace", "metrics", "http-proto", "reqwest-client", "reqwest-rustls-webpki-roots"] } - -# USB device enumeration (hardware discovery) -nusb = { version = "0.2", default-features = false, optional = true } +opentelemetry-otlp = { version = "0.31", default-features = false, features = ["trace", "metrics", "http-proto", "reqwest-blocking-client", "reqwest-rustls-webpki-roots"] } # Serial port for peripheral communication (STM32, etc.) tokio-serial = { version = "5", default-features = false, optional = true } +# USB device enumeration (hardware discovery) — only on platforms nusb supports +# (Linux, macOS, Windows). Android/Termux uses target_os="android" and is excluded. +[target.'cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))'.dependencies] +nusb = { version = "0.2", default-features = false, optional = true } + # probe-rs for STM32/Nucleo memory read (Phase B) -probe-rs = { version = "0.30", optional = true } +probe-rs = { version = "0.31", optional = true } # PDF extraction for datasheet RAG (optional, enable with --features rag-pdf) pdf-extract = { version = "0.10", optional = true } +tokio-stream = { version = "0.1.18", features = ["full"] } + +# WhatsApp Web client (wa-rs) — optional, enable with --features whatsapp-web +# Uses wa-rs for Bot and Client, wa-rs-core for storage traits, custom rusqlite backend avoids Diesel conflict. 
+wa-rs = { version = "0.2", optional = true, default-features = false } +wa-rs-core = { version = "0.2", optional = true, default-features = false } +wa-rs-binary = { version = "0.2", optional = true, default-features = false } +wa-rs-proto = { version = "0.2", optional = true, default-features = false } +wa-rs-ureq-http = { version = "0.2", optional = true } +wa-rs-tokio-transport = { version = "0.2", optional = true, default-features = false } # Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS [target.'cfg(target_os = "linux")'.dependencies] @@ -142,8 +163,9 @@ rppal = { version = "0.22", optional = true } landlock = { version = "0.4", optional = true } [features] -default = ["hardware"] +default = ["hardware", "channel-matrix"] hardware = ["nusb", "tokio-serial"] +channel-matrix = ["dep:matrix-sdk"] peripheral-rpi = ["rppal"] # Browser backend feature alias used by cfg(feature = "browser-native") browser-native = ["dep:fantoccini"] @@ -158,6 +180,9 @@ landlock = ["sandbox-landlock"] probe = ["dep:probe-rs"] # rag-pdf = PDF ingestion for datasheet RAG rag-pdf = ["dep:pdf-extract"] +# whatsapp-web = Native WhatsApp Web client with custom rusqlite storage backend +whatsapp-web = ["dep:wa-rs", "dep:wa-rs-core", "dep:wa-rs-binary", "dep:wa-rs-proto", "dep:wa-rs-ureq-http", "dep:wa-rs-tokio-transport", "serde-big-array"] + [profile.release] opt-level = "z" # Optimize for size lto = "thin" # Lower memory use during release builds @@ -181,7 +206,7 @@ panic = "abort" [dev-dependencies] tempfile = "3.14" -criterion = { version = "0.5", features = ["async_tokio"] } +criterion = { version = "0.8", features = ["async_tokio"] } [[bench]] name = "agent_benchmarks" diff --git a/LICENSE b/LICENSE index 349c342..981b87b 100644 --- a/LICENSE +++ b/LICENSE @@ -22,7 +22,34 @@ SOFTWARE. 
================================================================================ +TRADEMARK NOTICE + +This license does not grant permission to use the trade names, trademarks, +service marks, or product names of ZeroClaw Labs, including "ZeroClaw", +"zeroclaw-labs", or associated logos, except as required for reasonable and +customary use in describing the origin of the Software. + +Unauthorized use of the ZeroClaw name or branding to imply endorsement, +affiliation, or origin is strictly prohibited. See TRADEMARK.md for details. + +================================================================================ + +DUAL LICENSE NOTICE + +This software is available under a dual-license model: + + 1. MIT License (this file) — for open-source, research, academic, and + personal use. See LICENSE (this file). + + 2. Apache License 2.0 — for contributors and deployments requiring explicit + patent grants and stronger IP protection. See LICENSE-APACHE. + +You may choose either license for your use. Contributors submitting patches +grant rights under both licenses. See CLA.md for the contributor agreement. + +================================================================================ + This product includes software developed by ZeroClaw Labs and contributors: https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors -See NOTICE file for full contributor attribution. +See NOTICE for full contributor attribution. diff --git a/LICENSE-APACHE b/LICENSE-APACHE new file mode 100644 index 0000000..8ef8850 --- /dev/null +++ b/LICENSE-APACHE @@ -0,0 +1,186 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. 
+ + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship made available under + the License, as indicated by a copyright notice that is included in + or attached to the work (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. 
+ + "Contribution" shall mean, as defined in Section 5, any work of + authorship, including the original version of the Work and any + modifications or additions to that Work or Derivative Works of the + Work, that is intentionally submitted to the Licensor for inclusion + in the Work by the copyright owner or by an individual or Legal Entity + authorized to submit on behalf of the copyright owner. For the purposes + of this definition, "submitted" means any form of electronic, verbal, + or written communication sent to the Licensor or its representatives, + including but not limited to communication on electronic mailing lists, + source code control systems, and issue tracking systems that are managed + by, or on behalf of, the Licensor for the purpose of discussing and + improving the Work, but excluding communication that is conspicuously + marked or designated in writing by the copyright owner as "Not a + Contribution." + + "Contributor" shall mean Licensor and any Legal Entity on behalf of + whom a Contribution has been received by the Licensor and subsequently + incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a cross-claim + or counterclaim in a lawsuit) alleging that the Work or any Contribution + incorporated within the Work constitutes direct or contributory patent + infringement, then any patent licenses granted to You under this License + for that Work shall terminate as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or Derivative + Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, You must include a readable copy of the + attribution notices contained within such NOTICE file, in + at least one of the following places: within a NOTICE text + file distributed as part of the Derivative Works; within + the Source form or 
documentation, if provided along with the + Derivative Works; or, within a display generated by the + Derivative Works, if and wherever such third-party notices + normally appear. The contents of the NOTICE file are for + informational purposes only and do not modify the License. + You may add Your own attribution notices within Derivative + Works that You distribute, alongside or as an addendum to + the NOTICE text from the Work, provided that such additional + attribution notices cannot be construed as modifying the License. + + You may add Your own license statement for Your modifications and + may provide additional grant of rights to use, copy, modify, merge, + publish, distribute, sublicense, and/or sell copies of the + Contribution, either on its own or as part of the Work. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + including "ZeroClaw", "zeroclaw-labs", or associated logos, except + as required for reasonable and customary use in describing the origin + of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. 
Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or exemplary damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or all other + commercial damages or losses), even if such Contributor has been + advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may offer such + obligations only on Your own behalf and on Your sole responsibility, + not on behalf of any other Contributor, and only if You agree to + indemnify, defend, and hold each Contributor harmless for any + liability incurred by, or claims asserted against, such Contributor + by reason of your accepting any warranty or additional liability. 
+ + END OF TERMS AND CONDITIONS + + Copyright 2025 ZeroClaw Labs + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/NOTICE b/NOTICE index f2e824c..70fc196 100644 --- a/NOTICE +++ b/NOTICE @@ -3,6 +3,26 @@ Copyright 2025 ZeroClaw Labs This product includes software developed at ZeroClaw Labs (https://github.com/zeroclaw-labs). +Official Repository +=================== + +The only official ZeroClaw repository is: +https://github.com/zeroclaw-labs/zeroclaw + +Any other repository claiming to be ZeroClaw is unauthorized. +See TRADEMARK.md for the full trademark policy. + +License +======= + +This software is available under a dual-license model: + + 1. MIT License — see LICENSE + 2. Apache License 2.0 — see LICENSE-APACHE + +You may use either license. Contributors grant rights under both. +See CLA.md for the contributor license agreement. + Contributors ============ @@ -10,6 +30,10 @@ This NOTICE file is maintained by repository automation. For the latest contributor list, see the repository contributors page: https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors +All contributors retain copyright ownership of their contributions. +Contributions are permanently attributed in the repository commit history. +Patent rights are protected for all contributors under Apache License 2.0. 
+ Third-Party Dependencies ======================== diff --git a/README.ja.md b/README.ja.md index 957144f..b719f77 100644 --- a/README.ja.md +++ b/README.ja.md @@ -8,6 +8,15 @@ Zero overhead. Zero compromise. 100% Rust. 100% Agnostic.

+

+ X: @zeroclawlabs + Xiaohongshu: Official + Telegram: @zeroclawlabs + Telegram CN: @zeroclawlabs_cn + Telegram RU: @zeroclawlabs_ru + Reddit: r/zeroclawlabs +

+

🌐 言語: English · 简体中文 · 日本語 · Русский

@@ -33,7 +42,17 @@ > > コマンド名、設定キー、API パス、Trait 名などの技術識別子は英語のまま維持しています。 > -> 最終同期日: **2026-02-18**。 +> 最終同期日: **2026-02-19**。 + +## 📢 お知らせボード + +重要なお知らせ(互換性破壊変更、セキュリティ告知、メンテナンス時間、リリース阻害事項など)をここに掲載します。 + +| 日付 (UTC) | レベル | お知らせ | 対応 | +|---|---|---|---| +| 2026-02-19 | _緊急_ | 私たちは `openagen/zeroclaw` および `zeroclaw.org` とは**一切関係ありません**。`zeroclaw.org` は現在 `openagen/zeroclaw` の fork を指しており、そのドメイン/リポジトリは当プロジェクトの公式サイト・公式プロジェクトを装っています。 | これらの情報源による案内、バイナリ、資金調達情報、公式発表は信頼しないでください。必ず本リポジトリと認証済み公式SNSのみを参照してください。 | +| 2026-02-19 | _重要_ | 公式サイトは**まだ公開しておらず**、なりすましの試みを確認しています。ZeroClaw 名義の投資・資金調達などの活動には参加しないでください。 | 情報は本リポジトリを最優先で確認し、[X(@zeroclawlabs)](https://x.com/zeroclawlabs?s=21)、[Reddit(r/zeroclawlabs)](https://www.reddit.com/r/zeroclawlabs/)、[Telegram(@zeroclawlabs)](https://t.me/zeroclawlabs)、[Telegram CN(@zeroclawlabs_cn)](https://t.me/zeroclawlabs_cn)、[Telegram RU(@zeroclawlabs_ru)](https://t.me/zeroclawlabs_ru) と [小紅書アカウント](https://www.xiaohongshu.com/user/profile/67cbfc43000000000d008307?xsec_token=AB73VnYnGNx5y36EtnnZfGmAmS-6Wzv8WMuGpfwfkg6Yc%3D&xsec_source=pc_search) で公式更新を確認してください。 | +| 2026-02-19 | _重要_ | Anthropic は 2026-02-19 に Authentication and Credential Use を更新しました。条文では、OAuth authentication(Free/Pro/Max)は Claude Code と Claude.ai 専用であり、Claude Free/Pro/Max で取得した OAuth トークンを他の製品・ツール・サービス(Agent SDK を含む)で使用することは許可されず、Consumer Terms of Service 違反に該当すると明記されています。 | 損失回避のため、当面は Claude Code OAuth 連携を試さないでください。原文: [Authentication and Credential Use](https://code.claude.com/docs/en/legal-and-compliance#authentication-and-credential-use)。 | ## 概要 @@ -100,6 +119,12 @@ cd zeroclaw ## クイックスタート +### Homebrew(macOS/Linuxbrew) + +```bash +brew install zeroclaw +``` + ```bash git clone https://github.com/zeroclaw-labs/zeroclaw.git cd zeroclaw @@ -117,6 +142,106 @@ zeroclaw gateway zeroclaw daemon ``` +## Subscription Auth(OpenAI Codex / Claude Code) + +ZeroClaw はサブスクリプションベースのネイティブ認証プロファイルをサポートしています(マルチアカウント対応、保存時暗号化)。 + +- 保存先: `~/.zeroclaw/auth-profiles.json` +- 暗号化キー: 
`~/.zeroclaw/.secret_key` +- Profile ID 形式: `:`(例: `openai-codex:work`) + +OpenAI Codex OAuth(ChatGPT サブスクリプション): + +```bash +# サーバー/ヘッドレス環境向け推奨 +zeroclaw auth login --provider openai-codex --device-code + +# ブラウザ/コールバックフロー(ペーストフォールバック付き) +zeroclaw auth login --provider openai-codex --profile default +zeroclaw auth paste-redirect --provider openai-codex --profile default + +# 確認 / リフレッシュ / プロファイル切替 +zeroclaw auth status +zeroclaw auth refresh --provider openai-codex --profile default +zeroclaw auth use --provider openai-codex --profile work +``` + +Claude Code / Anthropic setup-token: + +```bash +# サブスクリプション/setup token の貼り付け(Authorization header モード) +zeroclaw auth paste-token --provider anthropic --profile default --auth-kind authorization + +# エイリアスコマンド +zeroclaw auth setup-token --provider anthropic --profile default +``` + +Subscription auth で agent を実行: + +```bash +zeroclaw agent --provider openai-codex -m "hello" +zeroclaw agent --provider openai-codex --auth-profile openai-codex:work -m "hello" + +# Anthropic は API key と auth token の両方の環境変数をサポート: +# ANTHROPIC_AUTH_TOKEN, ANTHROPIC_OAUTH_TOKEN, ANTHROPIC_API_KEY +zeroclaw agent --provider anthropic -m "hello" +``` + +## アーキテクチャ + +すべてのサブシステムは **Trait** — 設定変更だけで実装を差し替え可能、コード変更不要。 + +

+ ZeroClaw アーキテクチャ +

+ +| サブシステム | Trait | 内蔵実装 | 拡張方法 | +|-------------|-------|----------|----------| +| **AI モデル** | `Provider` | `zeroclaw providers` で確認(現在 28 個の組み込み + エイリアス、カスタムエンドポイント対応) | `custom:https://your-api.com`(OpenAI 互換)または `anthropic-custom:https://your-api.com` | +| **チャネル** | `Channel` | CLI, Telegram, Discord, Slack, Mattermost, iMessage, Matrix, Signal, WhatsApp, Email, IRC, Lark, DingTalk, QQ, Webhook | 任意のメッセージ API | +| **メモリ** | `Memory` | SQLite ハイブリッド検索, PostgreSQL バックエンド, Lucid ブリッジ, Markdown ファイル, 明示的 `none` バックエンド, スナップショット/復元, オプション応答キャッシュ | 任意の永続化バックエンド | +| **ツール** | `Tool` | shell/file/memory, cron/schedule, git, pushover, browser, http_request, screenshot/image_info, composio (opt-in), delegate, ハードウェアツール | 任意の機能 | +| **オブザーバビリティ** | `Observer` | Noop, Log, Multi | Prometheus, OTel | +| **ランタイム** | `RuntimeAdapter` | Native, Docker(サンドボックス) | adapter 経由で追加可能;未対応の kind は即座にエラー | +| **セキュリティ** | `SecurityPolicy` | Gateway ペアリング, サンドボックス, allowlist, レート制限, ファイルシステムスコープ, 暗号化シークレット | — | +| **アイデンティティ** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | 任意の ID フォーマット | +| **トンネル** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | 任意のトンネルバイナリ | +| **ハートビート** | Engine | HEARTBEAT.md 定期タスク | — | +| **スキル** | Loader | TOML マニフェスト + SKILL.md インストラクション | コミュニティスキルパック | +| **インテグレーション** | Registry | 9 カテゴリ、70 件以上の連携 | プラグインシステム | + +### ランタイムサポート(現状) + +- ✅ 現在サポート: `runtime.kind = "native"` または `runtime.kind = "docker"` +- 🚧 計画中(未実装): WASM / エッジランタイム + +未対応の `runtime.kind` が設定された場合、ZeroClaw は native へのサイレントフォールバックではなく、明確なエラーで終了します。 + +### メモリシステム(フルスタック検索エンジン) + +すべて自社実装、外部依存ゼロ — Pinecone、Elasticsearch、LangChain 不要: + +| レイヤー | 実装 | +|---------|------| +| **ベクトル DB** | Embeddings を SQLite に BLOB として保存、コサイン類似度検索 | +| **キーワード検索** | FTS5 仮想テーブル、BM25 スコアリング | +| **ハイブリッドマージ** | カスタム重み付きマージ関数(`vector.rs`) | +| **Embeddings** | `EmbeddingProvider` trait — OpenAI、カスタム URL、または noop | +| **チャンキング** | 行ベースの Markdown チャンカー(見出し構造保持) | +| **キャッシュ** | SQLite 
`embedding_cache` テーブル、LRU エビクション | +| **安全な再インデックス** | FTS5 再構築 + 欠落ベクトルの再埋め込みをアトミックに実行 | + +Agent はツール経由でメモリの呼び出し・保存・管理を自動的に行います。 + +```toml +[memory] +backend = "sqlite" # "sqlite", "lucid", "postgres", "markdown", "none" +auto_save = true +embedding_provider = "none" # "none", "openai", "custom:https://..." +vector_weight = 0.7 +keyword_weight = 0.3 +``` + ## セキュリティのデフォルト - Gateway の既定バインド: `127.0.0.1:3000` diff --git a/README.md b/README.md index 03ed554..acd307c 100644 --- a/README.md +++ b/README.md @@ -13,13 +13,19 @@ License: MIT Contributors Buy Me a Coffee + X: @zeroclawlabs + Xiaohongshu: Official + Telegram: @zeroclawlabs + Telegram CN: @zeroclawlabs_cn + Telegram RU: @zeroclawlabs_ru + Reddit: r/zeroclawlabs

Built by students and members of the Harvard, MIT, and Sundai.Club communities.

- 🌐 Languages: English · 简体中文 · 日本語 · Русский + 🌐 Languages: English · 简体中文 · 日本語 · Русский · Tiếng Việt

@@ -46,6 +52,16 @@ Built by students and members of the Harvard, MIT, and Sundai.Club communities.

Trait-driven architecture · secure-by-default runtime · provider/channel/tool swappable · pluggable everything

+### 📢 Announcements + +Use this board for important notices (breaking changes, security advisories, maintenance windows, and release blockers). + +| Date (UTC) | Level | Notice | Action | +|---|---|---|---| +| 2026-02-19 | _Critical_ | We are **not affiliated** with `openagen/zeroclaw` or `zeroclaw.org`. The `zeroclaw.org` domain currently points to the `openagen/zeroclaw` fork, and that domain/repository are impersonating our official website/project. | Do not trust information, binaries, fundraising, or announcements from those sources. Use only this repository and our verified social accounts. | +| 2026-02-19 | _Important_ | We have **not** launched an official website yet, and we are seeing impersonation attempts. Do **not** join any investment or fundraising activity claiming the ZeroClaw name. | Use this repository as the single source of truth. Follow [X (@zeroclawlabs)](https://x.com/zeroclawlabs?s=21), [Reddit (r/zeroclawlabs)](https://www.reddit.com/r/zeroclawlabs/), [Telegram (@zeroclawlabs)](https://t.me/zeroclawlabs), [Telegram CN (@zeroclawlabs_cn)](https://t.me/zeroclawlabs_cn), [Telegram RU (@zeroclawlabs_ru)](https://t.me/zeroclawlabs_ru), and [Xiaohongshu](https://www.xiaohongshu.com/user/profile/67cbfc43000000000d008307?xsec_token=AB73VnYnGNx5y36EtnnZfGmAmS-6Wzv8WMuGpfwfkg6Yc%3D&xsec_source=pc_search) for official updates. | +| 2026-02-19 | _Important_ | Anthropic updated the Authentication and Credential Use terms on 2026-02-19. OAuth authentication (Free, Pro, Max) is intended exclusively for Claude Code and Claude.ai; using OAuth tokens from Claude Free/Pro/Max in any other product, tool, or service (including Agent SDK) is not permitted and may violate the Consumer Terms of Service. | Please temporarily avoid Claude Code OAuth integrations to prevent potential loss. Original clause: [Authentication and Credential Use](https://code.claude.com/docs/en/legal-and-compliance#authentication-and-credential-use). 
| + ### ✨ Features - 🏎️ **Lean Runtime by Default:** Common CLI and status workflows run in a few-megabyte memory envelope on release builds. @@ -72,7 +88,7 @@ Local machine quick benchmark (macOS arm64, Feb 2026) normalized for 0.8GHz edge | **Binary Size** | ~28MB (dist) | N/A (Scripts) | ~8MB | **3.4 MB** | | **Cost** | Mac Mini $599 | Linux SBC ~$50 | Linux Board $10 | **Any hardware $10** | -> Notes: ZeroClaw results are measured on release builds using `/usr/bin/time -l`. OpenClaw requires Node.js runtime (typically ~390MB additional memory overhead), while NanoBot requires Python runtime. PicoClaw and ZeroClaw are static binaries. +> Notes: ZeroClaw results are measured on release builds using `/usr/bin/time -l`. OpenClaw requires Node.js runtime (typically ~390MB additional memory overhead), while NanoBot requires Python runtime. PicoClaw and ZeroClaw are static binaries. The RAM figures above are runtime memory; build-time compilation requirements are higher.

ZeroClaw vs OpenClaw Comparison @@ -157,17 +173,44 @@ Or skip the steps above and install everything (system deps, Rust, ZeroClaw) in curl -LsSf https://raw.githubusercontent.com/zeroclaw-labs/zeroclaw/main/scripts/install.sh | bash ``` +#### Compilation resource requirements + +Building from source needs more resources than running the resulting binary: + +| Resource | Minimum | Recommended | +|---|---|---| +| **RAM + swap** | 2 GB | 4 GB+ | +| **Free disk** | 6 GB | 10 GB+ | + +If your host is below the minimum, use pre-built binaries: + +```bash +./bootstrap.sh --prefer-prebuilt +``` + +To require binary-only install with no source fallback: + +```bash +./bootstrap.sh --prebuilt-only +``` + #### Optional - **Docker** — required only if using the [Docker sandboxed runtime](#runtime-support-current) (`runtime.kind = "docker"`). Install via your package manager or [docker.com](https://docs.docker.com/engine/install/). -> **Note:** The default `cargo build --release` uses `codegen-units=1` for compatibility with low-memory devices (e.g., Raspberry Pi 3 with 1GB RAM). For faster builds on powerful machines, use `cargo build --profile release-fast`. +> **Note:** The default `cargo build --release` uses `codegen-units=1` to lower peak compile pressure. For faster builds on powerful machines, use `cargo build --profile release-fast`. ## Quick Start +### Homebrew (macOS/Linuxbrew) + +```bash +brew install zeroclaw +``` + ### One-click bootstrap ```bash @@ -179,8 +222,17 @@ cd zeroclaw # Optional: bootstrap dependencies + Rust on fresh machines ./bootstrap.sh --install-system-deps --install-rust +# Optional: pre-built binary first (recommended on low-RAM/low-disk hosts) +./bootstrap.sh --prefer-prebuilt + +# Optional: binary-only install (no source build fallback) +./bootstrap.sh --prebuilt-only + # Optional: run onboarding in the same flow -./bootstrap.sh --onboard --api-key "sk-..." --provider openrouter +./bootstrap.sh --onboard --api-key "sk-..." 
--provider openrouter [--model "openrouter/auto"] + +# Optional: run bootstrap + onboarding fully in Docker +./bootstrap.sh --docker ``` Remote one-liner (review first in security-sensitive environments): @@ -191,6 +243,25 @@ curl -fsSL https://raw.githubusercontent.com/zeroclaw-labs/zeroclaw/main/scripts Details: [`docs/one-click-bootstrap.md`](docs/one-click-bootstrap.md) (toolchain mode may request `sudo` for system packages). +### Pre-built binaries + +Release assets are published for: + +- Linux: `x86_64`, `aarch64`, `armv7` +- macOS: `x86_64`, `aarch64` +- Windows: `x86_64` + +Download the latest assets from: + + +Example (ARM64 Linux): + +```bash +curl -fsSLO https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/zeroclaw-aarch64-unknown-linux-gnu.tar.gz +tar xzf zeroclaw-aarch64-unknown-linux-gnu.tar.gz +install -m 0755 zeroclaw "$HOME/.cargo/bin/zeroclaw" +``` + ```bash git clone https://github.com/zeroclaw-labs/zeroclaw.git cd zeroclaw @@ -200,8 +271,8 @@ cargo install --path . --force --locked # Ensure ~/.cargo/bin is in your PATH export PATH="$HOME/.cargo/bin:$PATH" -# Quick setup (no prompts) -zeroclaw onboard --api-key sk-... --provider openrouter +# Quick setup (no prompts, optional model specification) +zeroclaw onboard --api-key sk-... --provider openrouter [--model "openrouter/auto"] # Or interactive wizard zeroclaw onboard --interactive @@ -244,6 +315,7 @@ zeroclaw integrations info Telegram # Manage background service zeroclaw service install zeroclaw service status +zeroclaw service restart # Migrate memory from OpenClaw (safe preview first) zeroclaw migrate openclaw --dry-run @@ -452,7 +524,37 @@ For non-text replies, ZeroClaw can send Telegram attachments when the assistant Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs. 
-### WhatsApp Business Cloud API Setup +### WhatsApp Setup + +ZeroClaw supports two WhatsApp backends: + +- **WhatsApp Web mode** (QR / pair code, no Meta Business API required) +- **WhatsApp Business Cloud API mode** (official Meta webhook flow) + +#### WhatsApp Web mode (recommended for personal/self-hosted use) + +1. **Build with WhatsApp Web support:** + ```bash + cargo build --features whatsapp-web + ``` + +2. **Configure ZeroClaw:** + ```toml + [channels_config.whatsapp] + session_path = "~/.zeroclaw/state/whatsapp-web/session.db" + pair_phone = "15551234567" # optional; omit to use QR flow + pair_code = "" # optional custom pair code + allowed_numbers = ["+1234567890"] # E.164 format, or ["*"] for all + ``` + +3. **Start channels/daemon and link device:** + - Run `zeroclaw channel start` (or `zeroclaw daemon`). + - Follow terminal pairing output (QR or pair code). + - In WhatsApp on phone: **Settings → Linked Devices**. + +4. **Test:** Send a message from an allowed number and verify the agent replies. + +#### WhatsApp Business Cloud API mode WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling): @@ -493,6 +595,10 @@ WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling): Config: `~/.zeroclaw/config.toml` (created by `onboard`) +When `zeroclaw channel start` is already running, changes to `default_provider`, +`default_model`, `default_temperature`, `api_key`, `api_url`, and `reliability.*` +are hot-applied on the next inbound channel message. + ```toml api_key = "sk-..." default_provider = "openrouter" @@ -591,6 +697,8 @@ window_allowlist = [] # optional window title/process allowlist hints enabled = false # opt-in: 1000+ OAuth apps via composio.dev # api_key = "cmp_..." 
# optional: stored encrypted when [secrets].encrypt = true entity_id = "default" # default user_id for Composio tool calls +# Runtime tip: if execute asks for connected_account_id, run composio with +# action='list_accounts' and app='gmail' (or your toolkit) to retrieve account IDs. [identity] format = "openclaw" # "openclaw" (default, markdown files) or "aieos" (JSON) @@ -767,7 +875,7 @@ See [aieos.org](https://aieos.org) for the full schema and live examples. | `service` | Manage user-level background service | | `doctor` | Diagnose daemon/scheduler/channel freshness | | `status` | Show full system status | -| `cron` | Manage scheduled tasks (`list/add/add-at/add-every/once/remove/pause/resume`) | +| `cron` | Manage scheduled tasks (`list/add/add-at/add-every/once/remove/update/pause/resume`) | | `models` | Refresh provider model catalogs (`models refresh`) | | `providers` | List supported providers and aliases | | `channel` | List/start/doctor channels and bind Telegram identities | @@ -779,6 +887,18 @@ See [aieos.org](https://aieos.org) for the full schema and live examples. For a task-oriented command guide, see [`docs/commands-reference.md`](docs/commands-reference.md). +### Open-Skills Opt-In + +Community `open-skills` sync is disabled by default. Enable it explicitly in `config.toml`: + +```toml +[skills] +open_skills_enabled = true +# open_skills_dir = "/path/to/open-skills" # optional +``` + +You can also override at runtime with `ZEROCLAW_OPEN_SKILLS_ENABLED` and `ZEROCLAW_OPEN_SKILLS_DIR`. + ## Development ```bash @@ -869,13 +989,42 @@ A heartfelt thank you to the communities and institutions that inspire and fuel We're building in the open because the best ideas come from everywhere. If you're reading this, you're part of it. Welcome. 
🦀❤️ +## ⚠️ Official Repository & Impersonation Warning + +**This is the only official ZeroClaw repository:** +> https://github.com/zeroclaw-labs/zeroclaw + +Any other repository, organization, domain, or package claiming to be "ZeroClaw" or implying affiliation with ZeroClaw Labs is **unauthorized and not affiliated with this project**. Known unauthorized forks will be listed in [TRADEMARK.md](TRADEMARK.md). + +If you encounter impersonation or trademark misuse, please [open an issue](https://github.com/zeroclaw-labs/zeroclaw/issues). + +--- + ## License -MIT — see [LICENSE](LICENSE) for license terms and attribution baseline +ZeroClaw is dual-licensed for maximum openness and contributor protection: + +| License | Use case | +|---|---| +| [MIT](LICENSE) | Open-source, research, academic, personal use | +| [Apache 2.0](LICENSE-APACHE) | Patent protection, institutional, commercial deployment | + +You may choose either license. **Contributors automatically grant rights under both** — see [CLA.md](CLA.md) for the full contributor agreement. + +### Trademark + +The **ZeroClaw** name and logo are trademarks of ZeroClaw Labs. This license does not grant permission to use them to imply endorsement or affiliation. See [TRADEMARK.md](TRADEMARK.md) for permitted and prohibited uses. + +### Contributor Protections + +- You **retain copyright** of your contributions +- **Patent grant** (Apache 2.0) shields you from patent claims by other contributors +- Your contributions are **permanently attributed** in commit history and [NOTICE](NOTICE) +- No trademark rights are transferred by contributing ## Contributing -See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR: +See [CONTRIBUTING.md](CONTRIBUTING.md) and [CLA.md](CLA.md). 
Implement a trait, submit a PR: - CI workflow guide: [docs/ci-map.md](docs/ci-map.md) - New `Provider` → `src/providers/` - New `Channel` → `src/channels/` diff --git a/README.ru.md b/README.ru.md index 98532f7..8ab5578 100644 --- a/README.ru.md +++ b/README.ru.md @@ -8,6 +8,15 @@ Zero overhead. Zero compromise. 100% Rust. 100% Agnostic.

+

+ X: [@zeroclawlabs](https://x.com/zeroclawlabs) + Xiaohongshu: [Official](https://www.xiaohongshu.com/user/profile/67cbfc43000000000d008307) + Telegram: [@zeroclawlabs](https://t.me/zeroclawlabs) + Telegram CN: [@zeroclawlabs_cn](https://t.me/zeroclawlabs_cn) + Telegram RU: [@zeroclawlabs_ru](https://t.me/zeroclawlabs_ru) + Reddit: [r/zeroclawlabs](https://www.reddit.com/r/zeroclawlabs/) +

+

🌐 Языки: English · 简体中文 · 日本語 · Русский

@@ -33,7 +42,17 @@ > > Технические идентификаторы (команды, ключи конфигурации, API-пути, имена Trait) сохранены на английском. > -> Последняя синхронизация: **2026-02-18**. +> Последняя синхронизация: **2026-02-19**. + +## 📢 Доска объявлений + +Публикуйте здесь важные уведомления (breaking changes, security advisories, окна обслуживания и блокеры релиза). + +| Дата (UTC) | Уровень | Объявление | Действие | +|---|---|---|---| +| 2026-02-19 | _Срочно_ | Мы **не аффилированы** с `openagen/zeroclaw` и `zeroclaw.org`. Домен `zeroclaw.org` сейчас указывает на fork `openagen/zeroclaw`, и этот домен/репозиторий выдают себя за наш официальный сайт и проект. | Не доверяйте информации, бинарникам, сборам средств и «официальным» объявлениям из этих источников. Используйте только этот репозиторий и наши верифицированные соцсети. | +| 2026-02-19 | _Важно_ | Официальный сайт пока **не запущен**, и мы уже видим попытки выдавать себя за ZeroClaw. Пожалуйста, не участвуйте в инвестициях, сборах средств или похожих активностях от имени ZeroClaw. | Ориентируйтесь только на этот репозиторий; также следите за [X (@zeroclawlabs)](https://x.com/zeroclawlabs?s=21), [Reddit (r/zeroclawlabs)](https://www.reddit.com/r/zeroclawlabs/), [Telegram (@zeroclawlabs)](https://t.me/zeroclawlabs), [Telegram CN (@zeroclawlabs_cn)](https://t.me/zeroclawlabs_cn), [Telegram RU (@zeroclawlabs_ru)](https://t.me/zeroclawlabs_ru) и [Xiaohongshu](https://www.xiaohongshu.com/user/profile/67cbfc43000000000d008307?xsec_token=AB73VnYnGNx5y36EtnnZfGmAmS-6Wzv8WMuGpfwfkg6Yc%3D&xsec_source=pc_search) для официальных обновлений. | +| 2026-02-19 | _Важно_ | Anthropic обновил раздел Authentication and Credential Use 2026-02-19. 
В нем указано, что OAuth authentication (Free/Pro/Max) предназначена только для Claude Code и Claude.ai; использование OAuth-токенов, полученных через Claude Free/Pro/Max, в любых других продуктах, инструментах или сервисах (включая Agent SDK), не допускается и может считаться нарушением Consumer Terms of Service. | Чтобы избежать потерь, временно не используйте Claude Code OAuth-интеграции. Оригинал: [Authentication and Credential Use](https://code.claude.com/docs/en/legal-and-compliance#authentication-and-credential-use). | ## О проекте @@ -100,6 +119,12 @@ cd zeroclaw ## Быстрый старт +### Homebrew (macOS/Linuxbrew) + +```bash +brew install zeroclaw +``` + ```bash git clone https://github.com/zeroclaw-labs/zeroclaw.git cd zeroclaw @@ -117,6 +142,106 @@ zeroclaw gateway zeroclaw daemon ``` +## Subscription Auth (OpenAI Codex / Claude Code) + +ZeroClaw поддерживает нативные профили авторизации на основе подписки (мультиаккаунт, шифрование при хранении). + +- Файл хранения: `~/.zeroclaw/auth-profiles.json` +- Ключ шифрования: `~/.zeroclaw/.secret_key` +- Формат Profile ID: `<provider>:<profile>` (пример: `openai-codex:work`) + +OpenAI Codex OAuth (подписка ChatGPT): + +```bash +# Рекомендуется для серверов/headless-окружений +zeroclaw auth login --provider openai-codex --device-code + +# Браузерный/callback-поток с paste-фолбэком +zeroclaw auth login --provider openai-codex --profile default +zeroclaw auth paste-redirect --provider openai-codex --profile default + +# Проверка / обновление / переключение профиля +zeroclaw auth status +zeroclaw auth refresh --provider openai-codex --profile default +zeroclaw auth use --provider openai-codex --profile work +``` + +Claude Code / Anthropic setup-token: + +```bash +# Вставка subscription/setup token (режим Authorization header) +zeroclaw auth paste-token --provider anthropic --profile default --auth-kind authorization + +# Команда-алиас +zeroclaw auth setup-token --provider anthropic --profile default +``` + +Запуск agent с subscription
auth: + +```bash +zeroclaw agent --provider openai-codex -m "hello" +zeroclaw agent --provider openai-codex --auth-profile openai-codex:work -m "hello" + +# Anthropic поддерживает и API key, и auth token через переменные окружения: +# ANTHROPIC_AUTH_TOKEN, ANTHROPIC_OAUTH_TOKEN, ANTHROPIC_API_KEY +zeroclaw agent --provider anthropic -m "hello" +``` + +## Архитектура + +Каждая подсистема — это **Trait**: меняйте реализации через конфигурацию, без изменения кода. + +

+ Архитектура ZeroClaw +

+ +| Подсистема | Trait | Встроенные реализации | Расширение | +|-----------|-------|---------------------|------------| +| **AI-модели** | `Provider` | Каталог через `zeroclaw providers` (сейчас 28 встроенных + алиасы, плюс пользовательские endpoint) | `custom:https://your-api.com` (OpenAI-совместимый) или `anthropic-custom:https://your-api.com` | +| **Каналы** | `Channel` | CLI, Telegram, Discord, Slack, Mattermost, iMessage, Matrix, Signal, WhatsApp, Email, IRC, Lark, DingTalk, QQ, Webhook | Любой messaging API | +| **Память** | `Memory` | SQLite гибридный поиск, PostgreSQL-бэкенд, Lucid-мост, Markdown-файлы, явный `none`-бэкенд, snapshot/hydrate, опциональный кэш ответов | Любой persistence-бэкенд | +| **Инструменты** | `Tool` | shell/file/memory, cron/schedule, git, pushover, browser, http_request, screenshot/image_info, composio (opt-in), delegate, аппаратные инструменты | Любая функциональность | +| **Наблюдаемость** | `Observer` | Noop, Log, Multi | Prometheus, OTel | +| **Runtime** | `RuntimeAdapter` | Native, Docker (sandbox) | Через adapter; неподдерживаемые kind завершаются с ошибкой | +| **Безопасность** | `SecurityPolicy` | Gateway pairing, sandbox, allowlist, rate limits, scoping файловой системы, шифрование секретов | — | +| **Идентификация** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | Любой формат идентификации | +| **Туннели** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Любой tunnel-бинарник | +| **Heartbeat** | Engine | HEARTBEAT.md — периодические задачи | — | +| **Навыки** | Loader | TOML-манифесты + SKILL.md-инструкции | Пакеты навыков сообщества | +| **Интеграции** | Registry | 70+ интеграций в 9 категориях | Плагинная система | + +### Поддержка runtime (текущая) + +- ✅ Поддерживается сейчас: `runtime.kind = "native"` или `runtime.kind = "docker"` +- 🚧 Запланировано, но ещё не реализовано: WASM / edge-runtime + +При указании неподдерживаемого `runtime.kind` ZeroClaw завершается с явной ошибкой, а не молча 
откатывается к native. + +### Система памяти (полнофункциональный поисковый движок) + +Полностью собственная реализация, ноль внешних зависимостей — без Pinecone, Elasticsearch, LangChain: + +| Уровень | Реализация | +|---------|-----------| +| **Векторная БД** | Embeddings хранятся как BLOB в SQLite, поиск по косинусному сходству | +| **Поиск по ключевым словам** | Виртуальные таблицы FTS5 со скорингом BM25 | +| **Гибридное слияние** | Пользовательская взвешенная функция слияния (`vector.rs`) | +| **Embeddings** | Trait `EmbeddingProvider` — OpenAI, пользовательский URL или noop | +| **Чанкинг** | Построчный Markdown-чанкер с сохранением заголовков | +| **Кэширование** | Таблица `embedding_cache` в SQLite с LRU-вытеснением | +| **Безопасная переиндексация** | Атомарная перестройка FTS5 + повторное встраивание отсутствующих векторов | + +Agent автоматически вспоминает, сохраняет и управляет памятью через инструменты. + +```toml +[memory] +backend = "sqlite" # "sqlite", "lucid", "postgres", "markdown", "none" +auto_save = true +embedding_provider = "none" # "none", "openai", "custom:https://..." +vector_weight = 0.7 +keyword_weight = 0.3 +``` + ## Важные security-дефолты - Gateway по умолчанию: `127.0.0.1:3000` diff --git a/README.vi.md b/README.vi.md new file mode 100644 index 0000000..17465b1 --- /dev/null +++ b/README.vi.md @@ -0,0 +1,1051 @@ +

+ ZeroClaw +

+ +

ZeroClaw 🦀

+ +

+ Không tốn thêm tài nguyên. Không đánh đổi. 100% Rust. 100% Đa nền tảng.
+ ⚡️ Chạy trên phần cứng $10 với RAM dưới 5MB — ít hơn 99% bộ nhớ so với OpenClaw, rẻ hơn 98% so với Mac mini! +

+ +

+ License: MIT + Contributors + Buy Me a Coffee + X: [@zeroclawlabs](https://x.com/zeroclawlabs) + Xiaohongshu: [Official](https://www.xiaohongshu.com/user/profile/67cbfc43000000000d008307) + Telegram: [@zeroclawlabs](https://t.me/zeroclawlabs) + Telegram CN: [@zeroclawlabs_cn](https://t.me/zeroclawlabs_cn) + Telegram RU: [@zeroclawlabs_ru](https://t.me/zeroclawlabs_ru) + Reddit: [r/zeroclawlabs](https://www.reddit.com/r/zeroclawlabs/) +

+

+Được xây dựng bởi sinh viên và thành viên của các cộng đồng Harvard, MIT và Sundai.Club. +

+ +

+ 🌐 Ngôn ngữ: English · 简体中文 · 日本語 · Русский · Tiếng Việt +

+ +

+ Bắt đầu | + Cài đặt một lần bấm | + Trung tâm tài liệu | + Mục lục tài liệu +

+ +

+ Truy cập nhanh: + Tài liệu tham khảo · + Vận hành · + Khắc phục sự cố · + Bảo mật · + Phần cứng · + Đóng góp +

+ +

+ Hạ tầng trợ lý AI tự chủ — nhanh, nhỏ gọn
+ Triển khai ở đâu cũng được. Thay thế gì cũng được. +

+ +

Kiến trúc trait-driven · mặc định bảo mật · provider/channel/tool hoán đổi tự do · mọi thứ đều dễ mở rộng

+ +### 📢 Thông báo + +Bảng này dành cho các thông báo quan trọng (thay đổi không tương thích, cảnh báo bảo mật, lịch bảo trì, vấn đề chặn release). + +| Ngày (UTC) | Mức độ | Thông báo | Hành động | +|---|---|---|---| +| 2026-02-19 | _Nghiêm trọng_ | Chúng tôi **không có liên kết** với `openagen/zeroclaw` hoặc `zeroclaw.org`. Tên miền `zeroclaw.org` hiện đang trỏ đến fork `openagen/zeroclaw`, và tên miền/repository đó đang mạo danh website/dự án chính thức của chúng tôi. | Không tin tưởng thông tin, binary, gây quỹ, hay thông báo từ các nguồn đó. Chỉ sử dụng repository này và các tài khoản mạng xã hội đã được xác minh của chúng tôi. | +| 2026-02-19 | _Quan trọng_ | Chúng tôi **chưa** ra mắt website chính thức, và chúng tôi đang ghi nhận các nỗ lực mạo danh. **Không** tham gia bất kỳ hoạt động đầu tư hoặc gây quỹ nào tuyên bố mang tên ZeroClaw. | Sử dụng repository này làm nguồn thông tin duy nhất đáng tin cậy. Theo dõi [X (@zeroclawlabs)](https://x.com/zeroclawlabs?s=21), [Reddit (r/zeroclawlabs)](https://www.reddit.com/r/zeroclawlabs/), [Telegram (@zeroclawlabs)](https://t.me/zeroclawlabs), [Telegram CN (@zeroclawlabs_cn)](https://t.me/zeroclawlabs_cn), [Telegram RU (@zeroclawlabs_ru)](https://t.me/zeroclawlabs_ru), và [Xiaohongshu](https://www.xiaohongshu.com/user/profile/67cbfc43000000000d008307?xsec_token=AB73VnYnGNx5y36EtnnZfGmAmS-6Wzv8WMuGpfwfkg6Yc%3D&xsec_source=pc_search) để nhận cập nhật chính thức. | +| 2026-02-19 | _Quan trọng_ | Anthropic đã cập nhật điều khoản Xác thực và Sử dụng Thông tin xác thực vào ngày 2026-02-19. Xác thực OAuth (Free, Pro, Max) được dành riêng cho Claude Code và Claude.ai; việc sử dụng OAuth token từ Claude Free/Pro/Max trong bất kỳ sản phẩm, công cụ hay dịch vụ nào khác (bao gồm Agent SDK) đều không được phép và có thể vi phạm Điều khoản Dịch vụ cho Người tiêu dùng. | Vui lòng tạm thời tránh tích hợp Claude Code OAuth để ngăn ngừa khả năng mất mát. 
Điều khoản gốc: [Authentication and Credential Use](https://code.claude.com/docs/en/legal-and-compliance#authentication-and-credential-use). | + +### ✨ Tính năng + +- 🏎️ **Mặc định tinh gọn:** Các tác vụ CLI và kiểm tra trạng thái chỉ tốn vài MB bộ nhớ trên bản release. +- 💰 **Triển khai rẻ:** Chạy tốt trên board giá rẻ và instance cloud nhỏ, không cần runtime nặng. +- ⚡ **Khởi động lạnh nhanh:** Một binary Rust duy nhất — lệnh và daemon khởi động gần như tức thì. +- 🌍 **Chạy ở đâu cũng được:** Một binary chạy trên ARM, x86 và RISC-V — provider/channel/tool hoán đổi tự do. + +### Vì sao các team chọn ZeroClaw + +- **Mặc định tinh gọn:** binary Rust nhỏ, khởi động nhanh, tốn ít bộ nhớ. +- **Bảo mật từ gốc:** xác thực ghép cặp, sandbox nghiêm ngặt, allowlist rõ ràng, giới hạn workspace. +- **Hoán đổi tự do:** mọi hệ thống cốt lõi đều là trait (provider, channel, tool, memory, tunnel). +- **Không khoá vendor:** hỗ trợ provider tương thích OpenAI + endpoint tùy chỉnh dễ dàng mở rộng. + +## So sánh hiệu suất (ZeroClaw vs OpenClaw, có thể tái tạo) + +Đo nhanh trên máy cục bộ (macOS arm64, tháng 2/2026), quy đổi cho phần cứng edge 0.8GHz. + +| | OpenClaw | NanoBot | PicoClaw | ZeroClaw 🦀 | +|---|---|---|---|---| +| **Ngôn ngữ** | TypeScript | Python | Go | **Rust** | +| **RAM** | > 1GB | > 100MB | < 10MB | **< 5MB** | +| **Khởi động (lõi 0.8GHz)** | > 500s | > 30s | < 1s | **< 10ms** | +| **Kích thước binary** | ~28MB (dist) | N/A (Scripts) | ~8MB | **3.4 MB** | +| **Chi phí** | Mac Mini $599 | Linux SBC ~$50 | Linux Board $10 | **Phần cứng bất kỳ $10** | + +> Ghi chú: Kết quả ZeroClaw được đo trên release build sử dụng `/usr/bin/time -l`. OpenClaw yêu cầu runtime Node.js (thường thêm ~390MB bộ nhớ overhead), còn NanoBot yêu cầu runtime Python. PicoClaw và ZeroClaw là các static binary. Số RAM ở trên là bộ nhớ runtime; yêu cầu biên dịch lúc build-time sẽ cao hơn. + +

+ ZeroClaw vs OpenClaw Comparison +

+ +### Tự đo trên máy bạn + +Kết quả benchmark thay đổi theo code và toolchain, nên hãy tự đo bản build hiện tại: + +```bash +cargo build --release +ls -lh target/release/zeroclaw + +/usr/bin/time -l target/release/zeroclaw --help +/usr/bin/time -l target/release/zeroclaw status +``` + +Ví dụ mẫu (macOS arm64, đo ngày 18 tháng 2 năm 2026): + +- Kích thước binary release: `8.8M` +- `zeroclaw --help`: khoảng `0.02s`, bộ nhớ đỉnh ~`3.9MB` +- `zeroclaw status`: khoảng `0.01s`, bộ nhớ đỉnh ~`4.1MB` + +## Yêu cầu hệ thống + +
+Windows + +#### Bắt buộc + +1. **Visual Studio Build Tools** (cung cấp MSVC linker và Windows SDK): + ```powershell + winget install Microsoft.VisualStudio.2022.BuildTools + ``` + Trong quá trình cài đặt (hoặc qua Visual Studio Installer), chọn workload **"Desktop development with C++"**. + +2. **Rust toolchain:** + ```powershell + winget install Rustlang.Rustup + ``` + Sau khi cài đặt, mở terminal mới và chạy `rustup default stable` để đảm bảo toolchain stable đang hoạt động. + +3. **Xác minh** cả hai đang hoạt động: + ```powershell + rustc --version + cargo --version + ``` + +#### Tùy chọn + +- **Docker Desktop** — chỉ cần thiết nếu dùng [Docker sandboxed runtime](#runtime-support-current) (`runtime.kind = "docker"`). Cài đặt qua `winget install Docker.DockerDesktop`. + +
+ +
+Linux / macOS + +#### Bắt buộc + +1. **Công cụ build cơ bản:** + - **Linux (Debian/Ubuntu):** `sudo apt install build-essential pkg-config` + - **Linux (Fedora/RHEL):** `sudo dnf group install development-tools && sudo dnf install pkg-config` + - **macOS:** Cài đặt Xcode Command Line Tools: `xcode-select --install` + +2. **Rust toolchain:** + ```bash + curl --proto '=https' --tlsv1.2 -sSf https://sh.rustup.rs | sh + ``` + Xem [rustup.rs](https://rustup.rs) để biết thêm chi tiết. + +3. **Xác minh** cả hai đang hoạt động: + ```bash + rustc --version + cargo --version + ``` + +#### Cài bằng một lệnh + +Hoặc bỏ qua các bước trên, cài hết mọi thứ (system deps, Rust, ZeroClaw) chỉ bằng một lệnh: + +```bash +curl -LsSf https://raw.githubusercontent.com/zeroclaw-labs/zeroclaw/main/scripts/install.sh | bash +``` + +#### Yêu cầu tài nguyên biên dịch + +Việc build từ source đòi hỏi nhiều tài nguyên hơn so với chạy binary kết quả: + +| Tài nguyên | Tối thiểu | Khuyến nghị | +|---|---|---| +| **RAM + swap** | 2 GB | 4 GB+ | +| **Dung lượng đĩa trống** | 6 GB | 10 GB+ | + +Nếu cấu hình máy thấp hơn mức tối thiểu, dùng binary có sẵn: + +```bash +./bootstrap.sh --prefer-prebuilt +``` + +Chỉ cài từ binary, không quay lại build từ source: + +```bash +./bootstrap.sh --prebuilt-only +``` + +#### Tùy chọn + +- **Docker** — chỉ cần thiết nếu dùng [Docker sandboxed runtime](#runtime-support-current) (`runtime.kind = "docker"`). Cài đặt qua package manager hoặc [docker.com](https://docs.docker.com/engine/install/). + +> **Lưu ý:** Lệnh `cargo build --release` mặc định dùng `codegen-units=1` để giảm áp lực biên dịch đỉnh. Để build nhanh hơn trên máy mạnh, dùng `cargo build --profile release-fast`. + +
+ + +## Bắt đầu nhanh + +### Homebrew (macOS/Linuxbrew) + +```bash +brew install zeroclaw +``` + +### Bootstrap một lần bấm + +```bash +# Khuyến nghị: clone rồi chạy script bootstrap cục bộ +git clone https://github.com/zeroclaw-labs/zeroclaw.git +cd zeroclaw +./bootstrap.sh + +# Tùy chọn: cài đặt system dependencies + Rust trên máy mới +./bootstrap.sh --install-system-deps --install-rust + +# Tùy chọn: ưu tiên binary dựng sẵn (khuyến nghị cho máy ít RAM/ít dung lượng đĩa) +./bootstrap.sh --prefer-prebuilt + +# Tùy chọn: cài đặt chỉ từ binary (không fallback sang build source) +./bootstrap.sh --prebuilt-only + +# Tùy chọn: chạy onboarding trong cùng luồng +./bootstrap.sh --onboard --api-key "sk-..." --provider openrouter [--model "openrouter/auto"] + +# Tùy chọn: chạy bootstrap + onboarding hoàn toàn trong Docker +./bootstrap.sh --docker +``` + +Cài từ xa bằng một lệnh (nên xem trước nếu môi trường nhạy cảm về bảo mật): + +```bash +curl -fsSL https://raw.githubusercontent.com/zeroclaw-labs/zeroclaw/main/scripts/bootstrap.sh | bash +``` + +Chi tiết: [`docs/one-click-bootstrap.md`](docs/one-click-bootstrap.md) (chế độ toolchain có thể yêu cầu `sudo` cho các gói hệ thống). + +### Binary có sẵn + +Release asset được phát hành cho: + +- Linux: `x86_64`, `aarch64`, `armv7` +- macOS: `x86_64`, `aarch64` +- Windows: `x86_64` + +Tải asset mới nhất tại: + + +Ví dụ (ARM64 Linux): + +```bash +curl -fsSLO https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/zeroclaw-aarch64-unknown-linux-gnu.tar.gz +tar xzf zeroclaw-aarch64-unknown-linux-gnu.tar.gz +install -m 0755 zeroclaw "$HOME/.cargo/bin/zeroclaw" +``` + +```bash +git clone https://github.com/zeroclaw-labs/zeroclaw.git +cd zeroclaw +cargo build --release --locked +cargo install --path . --force --locked + +# Đảm bảo ~/.cargo/bin có trong PATH của bạn +export PATH="$HOME/.cargo/bin:$PATH" + +# Cài nhanh (không cần tương tác, có thể chỉ định model) +zeroclaw onboard --api-key sk-... 
--provider openrouter [--model "openrouter/auto"] + +# Hoặc dùng trình hướng dẫn tương tác +zeroclaw onboard --interactive + +# Hoặc chỉ sửa nhanh channel/allowlist +zeroclaw onboard --channels-only + +# Chat +zeroclaw agent -m "Hello, ZeroClaw!" + +# Chế độ tương tác +zeroclaw agent + +# Khởi động gateway (webhook server) +zeroclaw gateway # mặc định: 127.0.0.1:3000 +zeroclaw gateway --port 0 # cổng ngẫu nhiên (tăng cường bảo mật) + +# Khởi động runtime tự trị đầy đủ +zeroclaw daemon + +# Kiểm tra trạng thái +zeroclaw status +zeroclaw auth status + +# Chạy chẩn đoán hệ thống +zeroclaw doctor + +# Kiểm tra sức khỏe channel +zeroclaw channel doctor + +# Gắn định danh Telegram vào allowlist +zeroclaw channel bind-telegram 123456789 + +# Lấy thông tin cài đặt tích hợp +zeroclaw integrations info Telegram + +# Lưu ý: Channel (Telegram, Discord, Slack) yêu cầu daemon đang chạy +# zeroclaw daemon + +# Quản lý dịch vụ nền +zeroclaw service install +zeroclaw service status +zeroclaw service restart + +# Chuyển dữ liệu từ OpenClaw (chạy thử trước) +zeroclaw migrate openclaw --dry-run +zeroclaw migrate openclaw +``` + +> **Chạy trực tiếp khi phát triển (không cần cài toàn cục):** thêm `cargo run --release --` trước lệnh (ví dụ: `cargo run --release -- status`). + +## Xác thực theo gói đăng ký (OpenAI Codex / Claude Code) + +ZeroClaw hỗ trợ profile xác thực theo gói đăng ký (đa tài khoản, mã hóa khi lưu). 
+ +- File lưu trữ: `~/.zeroclaw/auth-profiles.json` +- Khóa mã hóa: `~/.zeroclaw/.secret_key` +- Định dạng profile id: `<provider>:<profile>` (ví dụ: `openai-codex:work`) + +OpenAI Codex OAuth (đăng ký ChatGPT): + +```bash +# Khuyến nghị trên server/headless +zeroclaw auth login --provider openai-codex --device-code + +# Luồng Browser/callback với fallback paste +zeroclaw auth login --provider openai-codex --profile default +zeroclaw auth paste-redirect --provider openai-codex --profile default + +# Kiểm tra / làm mới / chuyển profile +zeroclaw auth status +zeroclaw auth refresh --provider openai-codex --profile default +zeroclaw auth use --provider openai-codex --profile work +``` + +Claude Code / Anthropic setup-token: + +```bash +# Dán token đăng ký/setup (chế độ Authorization header) +zeroclaw auth paste-token --provider anthropic --profile default --auth-kind authorization + +# Lệnh alias +zeroclaw auth setup-token --provider anthropic --profile default +``` + +Chạy agent với xác thực đăng ký: + +```bash +zeroclaw agent --provider openai-codex -m "hello" +zeroclaw agent --provider openai-codex --auth-profile openai-codex:work -m "hello" + +# Anthropic hỗ trợ cả API key và biến môi trường auth token: +# ANTHROPIC_AUTH_TOKEN, ANTHROPIC_OAUTH_TOKEN, ANTHROPIC_API_KEY +zeroclaw agent --provider anthropic -m "hello" +``` + +## Kiến trúc + +Mọi hệ thống con đều là **trait** — chỉ cần đổi cấu hình, không cần sửa code. + +

+ ZeroClaw Architecture +

+ +| Hệ thống con | Trait | Đi kèm sẵn | Mở rộng | +|-----------|-------|------------|--------| +| **Mô hình AI** | `Provider` | Danh mục provider qua `zeroclaw providers` (hiện có 28 built-in + alias, cộng endpoint tùy chỉnh) | `custom:https://your-api.com` (tương thích OpenAI) hoặc `anthropic-custom:https://your-api.com` | +| **Channel** | `Channel` | CLI, Telegram, Discord, Slack, Mattermost, iMessage, Matrix, Signal, WhatsApp, Email, IRC, Lark, DingTalk, QQ, Webhook | Bất kỳ messaging API nào | +| **Memory** | `Memory` | SQLite hybrid search, PostgreSQL backend (storage provider có thể cấu hình), Lucid bridge, Markdown files, backend `none` tường minh, snapshot/hydrate, response cache tùy chọn | Bất kỳ persistence backend nào | +| **Tool** | `Tool` | shell/file/memory, cron/schedule, git, pushover, browser, http_request, screenshot/image_info, composio (opt-in), delegate, hardware tools | Bất kỳ khả năng nào | +| **Observability** | `Observer` | Noop, Log, Multi | Prometheus, OTel | +| **Runtime** | `RuntimeAdapter` | Native, Docker (sandboxed) | Có thể thêm runtime bổ sung qua adapter; các kind không được hỗ trợ sẽ fail nhanh | +| **Bảo mật** | `SecurityPolicy` | Ghép cặp gateway, sandbox, allowlist, giới hạn tốc độ, phân vùng filesystem, secret mã hóa | — | +| **Định danh** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | Bất kỳ định dạng định danh nào | +| **Tunnel** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | Bất kỳ tunnel binary nào | +| **Heartbeat** | Engine | Tác vụ định kỳ HEARTBEAT.md | — | +| **Skill** | Loader | TOML manifest + hướng dẫn SKILL.md | Community skill pack | +| **Tích hợp** | Registry | 70+ tích hợp trong 9 danh mục | Plugin system | + +### Hỗ trợ runtime (hiện tại) + +- ✅ Được hỗ trợ hiện nay: `runtime.kind = "native"` hoặc `runtime.kind = "docker"` +- 🚧 Đã lên kế hoạch, chưa triển khai: WASM / edge runtime + +Khi cấu hình `runtime.kind` không được hỗ trợ, ZeroClaw sẽ thoát với thông báo lỗi rõ ràng thay vì 
âm thầm fallback về native. + +### Hệ thống Memory (Search Engine toàn diện) + +Tự phát triển hoàn toàn, không phụ thuộc bên ngoài — không Pinecone, không Elasticsearch, không LangChain: + +| Lớp | Triển khai | +|-------|---------------| +| **Vector DB** | Embeddings lưu dưới dạng BLOB trong SQLite, tìm kiếm cosine similarity | +| **Keyword Search** | Bảng ảo FTS5 với BM25 scoring | +| **Hybrid Merge** | Hàm merge có trọng số tùy chỉnh (`vector.rs`) | +| **Embeddings** | Trait `EmbeddingProvider` — OpenAI, URL tùy chỉnh, hoặc noop | +| **Chunking** | Bộ chia đoạn markdown theo dòng, giữ nguyên heading | +| **Caching** | Bảng SQLite `embedding_cache` với LRU eviction | +| **Safe Reindex** | Rebuild FTS5 + re-embed các vector bị thiếu theo cách nguyên tử | + +Agent tự động ghi nhớ, lưu trữ và quản lý memory qua các tool. + +```toml +[memory] +backend = "sqlite" # "sqlite", "lucid", "postgres", "markdown", "none" +auto_save = true +embedding_provider = "none" # "none", "openai", "custom:https://..." +vector_weight = 0.7 +keyword_weight = 0.3 + +# backend = "none" sử dụng no-op memory backend tường minh (không có persistence) + +# Tùy chọn: ghi đè storage-provider cho remote memory backend. +# Khi provider = "postgres", ZeroClaw dùng PostgreSQL để lưu memory. +# Khóa db_url cũng chấp nhận alias `dbURL` để tương thích ngược. +# +# [storage.provider.config] +# provider = "postgres" +# db_url = "postgres://user:password@host:5432/zeroclaw" +# schema = "public" +# table = "memories" +# connect_timeout_secs = 15 + +# Tùy chọn cho backend = "sqlite": số giây tối đa chờ khi mở DB (ví dụ: file bị khóa). Bỏ qua hoặc để trống để không có timeout. 
+# sqlite_open_timeout_secs = 30 + +# Tùy chọn cho backend = "lucid" +# ZEROCLAW_LUCID_CMD=/usr/local/bin/lucid # mặc định: lucid +# ZEROCLAW_LUCID_BUDGET=200 # mặc định: 200 +# ZEROCLAW_LUCID_LOCAL_HIT_THRESHOLD=3 # số lần hit cục bộ để bỏ qua external recall +# ZEROCLAW_LUCID_RECALL_TIMEOUT_MS=120 # giới hạn thời gian cho lucid context recall +# ZEROCLAW_LUCID_STORE_TIMEOUT_MS=800 # timeout đồng bộ async cho lucid store +# ZEROCLAW_LUCID_FAILURE_COOLDOWN_MS=15000 # thời gian nghỉ sau lỗi lucid, tránh thử lại liên tục +``` + +## Bảo mật + +ZeroClaw thực thi bảo mật ở **mọi lớp** — không chỉ sandbox. Đáp ứng tất cả các hạng mục trong danh sách kiểm tra bảo mật của cộng đồng. + +### Danh sách kiểm tra bảo mật + +| # | Hạng mục | Trạng thái | Cách thực hiện | +|---|------|--------|-----| +| 1 | **Gateway không công khai ra ngoài** | ✅ | Bind vào `127.0.0.1` theo mặc định. Từ chối `0.0.0.0` nếu không có tunnel hoặc `allow_public_bind = true` tường minh. | +| 2 | **Yêu cầu ghép cặp** | ✅ | Mã một lần 6 chữ số khi khởi động. Trao đổi qua `POST /pair` để lấy bearer token. Mọi yêu cầu `/webhook` đều cần `Authorization: Bearer `. | +| 3 | **Phân vùng filesystem (không phải /)** | ✅ | `workspace_only = true` theo mặc định. Chặn 14 thư mục hệ thống + 4 dotfile nhạy cảm. Chặn null byte injection. Phát hiện symlink escape qua canonicalization + kiểm tra resolved-path trong các tool đọc/ghi file. | +| 4 | **Chỉ truy cập qua tunnel** | ✅ | Gateway từ chối bind công khai khi không có tunnel đang hoạt động. Hỗ trợ Tailscale, Cloudflare, ngrok, hoặc tunnel tùy chỉnh. | + +> **Tự chạy nmap:** `nmap -p 1-65535 ` — ZeroClaw chỉ bind vào localhost, nên không có gì bị lộ ra ngoài trừ khi bạn cấu hình tunnel tường minh. 
+
+### Allowlist channel (từ chối theo mặc định)
+
+Chính sách kiểm soát người gửi đã được thống nhất:
+
+- Allowlist rỗng = **từ chối tất cả tin nhắn đến**
+- `"*"` = **cho phép tất cả** (phải opt-in tường minh)
+- Nếu khác = allowlist khớp chính xác
+
+Mặc định an toàn, hạn chế tối đa rủi ro lộ thông tin.
+
+Tài liệu tham khảo đầy đủ về cấu hình channel: [docs/channels-reference.md](docs/channels-reference.md).
+
+Cài đặt được khuyến nghị (bảo mật + nhanh):
+
+- **Telegram:** thêm `@username` của bạn (không có `@`) và/hoặc Telegram user ID số vào allowlist.
+- **Discord:** thêm Discord user ID của bạn vào allowlist.
+- **Slack:** thêm Slack member ID của bạn (thường bắt đầu bằng `U`) vào allowlist.
+- **Mattermost:** dùng API v4 tiêu chuẩn. Allowlist dùng Mattermost user ID.
+- Chỉ dùng `"*"` cho kiểm thử mở tạm thời.
+
+Luồng phê duyệt của operator qua Telegram:
+
+1. Để `[channels_config.telegram].allowed_users = []` để từ chối theo mặc định khi khởi động.
+2. Người dùng không được phép sẽ nhận được gợi ý kèm lệnh operator có thể copy:
+   `zeroclaw channel bind-telegram <telegram-user-id>`.
+3. Operator chạy lệnh đó tại máy cục bộ, sau đó người dùng thử gửi tin nhắn lại.
+
+Nếu cần phê duyệt thủ công một lần, chạy:
+
+```bash
+zeroclaw channel bind-telegram 123456789
+```
+
+Nếu bạn không chắc định danh nào cần dùng:
+
+1. Khởi động channel và gửi một tin nhắn đến bot của bạn.
+2. Đọc log cảnh báo để thấy định danh người gửi chính xác.
+3. Thêm giá trị đó vào allowlist và chạy lại channel-only setup.
+
+Nếu bạn thấy cảnh báo ủy quyền trong log (ví dụ: `ignoring message from unauthorized user`),
+chạy lại channel setup:
+
+```bash
+zeroclaw onboard --channels-only
+```
+
+### Phản hồi media Telegram
+
+Telegram định tuyến phản hồi theo **chat ID nguồn** (thay vì username),
+tránh lỗi `Bad Request: chat not found`.
+ +Với các phản hồi không phải văn bản, ZeroClaw có thể gửi file đính kèm Telegram khi assistant bao gồm các marker: + +- `[IMAGE:]` +- `[DOCUMENT:]` +- `[VIDEO:]` +- `[AUDIO:]` +- `[VOICE:]` + +Path có thể là file cục bộ (ví dụ `/tmp/screenshot.png`) hoặc URL HTTPS. + +### Cài đặt WhatsApp + +ZeroClaw hỗ trợ hai backend WhatsApp: + +- **Chế độ WhatsApp Web** (QR / pair code, không cần Meta Business API) +- **Chế độ WhatsApp Business Cloud API** (luồng webhook chính thức của Meta) + +#### Chế độ WhatsApp Web (khuyến nghị cho dùng cá nhân/self-hosted) + +1. **Build với hỗ trợ WhatsApp Web:** + ```bash + cargo build --features whatsapp-web + ``` + +2. **Cấu hình ZeroClaw:** + ```toml + [channels_config.whatsapp] + session_path = "~/.zeroclaw/state/whatsapp-web/session.db" + pair_phone = "15551234567" # tùy chọn; bỏ qua để dùng luồng QR + pair_code = "" # tùy chọn mã pair tùy chỉnh + allowed_numbers = ["+1234567890"] # định dạng E.164, hoặc ["*"] cho tất cả + ``` + +3. **Khởi động channel/daemon và liên kết thiết bị:** + - Chạy `zeroclaw channel start` (hoặc `zeroclaw daemon`). + - Làm theo hướng dẫn ghép cặp trên terminal (QR hoặc pair code). + - Trên WhatsApp điện thoại: **Cài đặt → Thiết bị đã liên kết**. + +4. **Kiểm tra:** Gửi tin nhắn từ số được phép và xác nhận agent trả lời. + +#### Chế độ WhatsApp Business Cloud API + +WhatsApp dùng Cloud API của Meta với webhook (push-based, không phải polling): + +1. **Tạo Meta Business App:** + - Truy cập [developers.facebook.com](https://developers.facebook.com) + - Tạo app mới → Chọn loại "Business" + - Thêm sản phẩm "WhatsApp" + +2. **Lấy thông tin xác thực:** + - **Access Token:** Từ WhatsApp → API Setup → Generate token (hoặc tạo System User cho token vĩnh viễn) + - **Phone Number ID:** Từ WhatsApp → API Setup → Phone number ID + - **Verify Token:** Bạn tự định nghĩa (bất kỳ chuỗi ngẫu nhiên nào) — Meta sẽ gửi lại trong quá trình xác minh webhook + +3. 
**Cấu hình ZeroClaw:** + ```toml + [channels_config.whatsapp] + access_token = "EAABx..." + phone_number_id = "123456789012345" + verify_token = "my-secret-verify-token" + allowed_numbers = ["+1234567890"] # định dạng E.164, hoặc ["*"] cho tất cả + ``` + +4. **Khởi động gateway với tunnel:** + ```bash + zeroclaw gateway --port 3000 + ``` + WhatsApp yêu cầu HTTPS, vì vậy hãy dùng tunnel (ngrok, Cloudflare, Tailscale Funnel). + +5. **Cấu hình Meta webhook:** + - Trong Meta Developer Console → WhatsApp → Configuration → Webhook + - **Callback URL:** `https://your-tunnel-url/whatsapp` + - **Verify Token:** Giống với `verify_token` trong config của bạn + - Đăng ký nhận trường `messages` + +6. **Kiểm tra:** Gửi tin nhắn đến số WhatsApp Business của bạn — ZeroClaw sẽ phản hồi qua LLM. + +## Cấu hình + +Config: `~/.zeroclaw/config.toml` (được tạo bởi `onboard`) + +Khi `zeroclaw channel start` đang chạy, các thay đổi với `default_provider`, +`default_model`, `default_temperature`, `api_key`, `api_url`, và `reliability.*` +sẽ được áp dụng nóng vào lần có tin nhắn channel đến tiếp theo. + +```toml +api_key = "sk-..." +default_provider = "openrouter" +default_model = "anthropic/claude-sonnet-4-6" +default_temperature = 0.7 + +# Endpoint tùy chỉnh tương thích OpenAI +# default_provider = "custom:https://your-api.com" + +# Endpoint tùy chỉnh tương thích Anthropic +# default_provider = "anthropic-custom:https://your-api.com" + +[memory] +backend = "sqlite" # "sqlite", "lucid", "postgres", "markdown", "none" +auto_save = true +embedding_provider = "none" # "none", "openai", "custom:https://..." 
+vector_weight = 0.7 +keyword_weight = 0.3 + +# backend = "none" vô hiệu hóa persistent memory qua no-op backend + +# Tùy chọn ghi đè storage-provider từ xa (ví dụ PostgreSQL) +# [storage.provider.config] +# provider = "postgres" +# db_url = "postgres://user:password@host:5432/zeroclaw" +# schema = "public" +# table = "memories" +# connect_timeout_secs = 15 + +[gateway] +port = 3000 # mặc định +host = "127.0.0.1" # mặc định +require_pairing = true # yêu cầu pairing code khi kết nối lần đầu +allow_public_bind = false # từ chối 0.0.0.0 nếu không có tunnel + +[autonomy] +level = "supervised" # "readonly", "supervised", "full" (mặc định: supervised) +workspace_only = true # mặc định: true — phân vùng vào workspace +allowed_commands = ["git", "npm", "cargo", "ls", "cat", "grep"] +forbidden_paths = ["/etc", "/root", "/proc", "/sys", "~/.ssh", "~/.gnupg", "~/.aws"] + +[runtime] +kind = "native" # "native" hoặc "docker" + +[runtime.docker] +image = "alpine:3.20" # container image cho thực thi shell +network = "none" # chế độ docker network ("none", "bridge", v.v.) 
+memory_limit_mb = 512 # giới hạn bộ nhớ tùy chọn tính bằng MB +cpu_limit = 1.0 # giới hạn CPU tùy chọn +read_only_rootfs = true # mount root filesystem ở chế độ read-only +mount_workspace = true # mount workspace vào /workspace +allowed_workspace_roots = [] # allowlist tùy chọn để xác thực workspace mount + +[heartbeat] +enabled = false +interval_minutes = 30 + +[tunnel] +provider = "none" # "none", "cloudflare", "tailscale", "ngrok", "custom" + +[secrets] +encrypt = true # API key được mã hóa bằng file key cục bộ + +[browser] +enabled = false # opt-in browser_open + browser tool +allowed_domains = ["docs.rs"] # bắt buộc khi browser được bật +backend = "agent_browser" # "agent_browser" (mặc định), "rust_native", "computer_use", "auto" +native_headless = true # áp dụng khi backend dùng rust-native +native_webdriver_url = "http://127.0.0.1:9515" # WebDriver endpoint (chromedriver/selenium) +# native_chrome_path = "/usr/bin/chromium" # tùy chọn chỉ định rõ browser binary cho driver + +[browser.computer_use] +endpoint = "http://127.0.0.1:8787/v1/actions" # HTTP endpoint của computer-use sidecar +timeout_ms = 15000 # timeout mỗi action +allow_remote_endpoint = false # mặc định bảo mật: chỉ endpoint private/localhost +window_allowlist = [] # gợi ý allowlist tên cửa sổ/process tùy chọn +# api_key = "..." 
# bearer token tùy chọn cho sidecar +# max_coordinate_x = 3840 # guardrail tọa độ tùy chọn +# max_coordinate_y = 2160 # guardrail tọa độ tùy chọn + +# Flag build Rust-native backend: +# cargo build --release --features browser-native +# Đảm bảo WebDriver server đang chạy, ví dụ: chromedriver --port=9515 + +# Hợp đồng computer-use sidecar (MVP) +# POST browser.computer_use.endpoint +# Request: { +# "action": "mouse_click", +# "params": {"x": 640, "y": 360, "button": "left"}, +# "policy": {"allowed_domains": [...], "window_allowlist": [...], "max_coordinate_x": 3840, "max_coordinate_y": 2160}, +# "metadata": {"session_name": "...", "source": "zeroclaw.browser", "version": "..."} +# } +# Response: {"success": true, "data": {...}} hoặc {"success": false, "error": "..."} + +[composio] +enabled = false # opt-in: hơn 1000 OAuth app qua composio.dev +# api_key = "cmp_..." # tùy chọn: được lưu mã hóa khi [secrets].encrypt = true +entity_id = "default" # user_id mặc định cho Composio tool call +# Gợi ý runtime: nếu execute yêu cầu connected_account_id, chạy composio với +# action='list_accounts' và app='gmail' (hoặc toolkit của bạn) để lấy account ID. + +[identity] +format = "openclaw" # "openclaw" (mặc định, markdown files) hoặc "aieos" (JSON) +# aieos_path = "identity.json" # đường dẫn đến file AIEOS JSON (tương đối với workspace hoặc tuyệt đối) +# aieos_inline = '{"identity":{"names":{"first":"Nova"}}}' # inline AIEOS JSON +``` + +### Ollama cục bộ và endpoint từ xa + +ZeroClaw dùng một khóa provider (`ollama`) cho cả triển khai Ollama cục bộ và từ xa: + +- Ollama cục bộ: để `api_url` trống, chạy `ollama serve`, và dùng các model như `llama3.2`. +- Endpoint Ollama từ xa (bao gồm Ollama Cloud): đặt `api_url` thành endpoint từ xa và đặt `api_key` (hoặc `OLLAMA_API_KEY`) khi cần. +- Tùy chọn suffix `:cloud`: ID model như `qwen3:cloud` được chuẩn hóa thành `qwen3` trước khi gửi request. 
+ +Ví dụ cấu hình từ xa: + +```toml +default_provider = "ollama" +default_model = "qwen3:cloud" +api_url = "https://ollama.com" +api_key = "ollama_api_key_here" +``` + +### Endpoint provider tùy chỉnh + +Cấu hình chi tiết cho endpoint tùy chỉnh tương thích OpenAI và Anthropic, xem [docs/custom-providers.md](docs/custom-providers.md). + +## Gói Python đi kèm (`zeroclaw-tools`) + +Với các LLM provider có tool calling native không ổn định (ví dụ: GLM-5/Zhipu), ZeroClaw đi kèm gói Python dùng **LangGraph để gọi tool** nhằm đảm bảo tính nhất quán: + +```bash +pip install zeroclaw-tools +``` + +```python +from zeroclaw_tools import create_agent, shell, file_read +from langchain_core.messages import HumanMessage + +# Hoạt động với mọi provider tương thích OpenAI +agent = create_agent( + tools=[shell, file_read], + model="glm-5", + api_key="your-key", + base_url="https://api.z.ai/api/coding/paas/v4" +) + +result = await agent.ainvoke({ + "messages": [HumanMessage(content="List files in /tmp")] +}) +print(result["messages"][-1].content) +``` + +**Lý do nên dùng:** +- **Tool calling nhất quán** trên mọi provider (kể cả những provider hỗ trợ native kém) +- **Vòng lặp tool tự động** — tiếp tục gọi tool cho đến khi hoàn thành tác vụ +- **Dễ mở rộng** — thêm tool tùy chỉnh với decorator `@tool` +- **Tích hợp Discord bot** đi kèm (Telegram đang lên kế hoạch) + +Xem [`python/README.md`](python/README.md) để có tài liệu đầy đủ. + +## Hệ thống định danh (Hỗ trợ AIEOS) + +ZeroClaw hỗ trợ persona AI **không phụ thuộc nền tảng** qua hai định dạng: + +### OpenClaw (Mặc định) + +Các file markdown truyền thống trong workspace của bạn: +- `IDENTITY.md` — Agent là ai +- `SOUL.md` — Tính cách và giá trị cốt lõi +- `USER.md` — Agent đang hỗ trợ ai +- `AGENTS.md` — Hướng dẫn hành vi + +### AIEOS (AI Entity Object Specification) + +[AIEOS](https://aieos.org) là framework chuẩn hóa cho định danh AI di động. 
ZeroClaw hỗ trợ payload AIEOS v1.1 JSON, cho phép bạn: + +- **Import định danh** từ hệ sinh thái AIEOS +- **Export định danh** sang các hệ thống tương thích AIEOS khác +- **Duy trì tính toàn vẹn hành vi** trên các mô hình AI khác nhau + +#### Bật AIEOS + +```toml +[identity] +format = "aieos" +aieos_path = "identity.json" # tương đối với workspace hoặc đường dẫn tuyệt đối +``` + +Hoặc JSON inline: + +```toml +[identity] +format = "aieos" +aieos_inline = ''' +{ + "identity": { + "names": { "first": "Nova", "nickname": "N" }, + "bio": { "gender": "Non-binary", "age_biological": 3 }, + "origin": { "nationality": "Digital", "birthplace": { "city": "Cloud" } } + }, + "psychology": { + "neural_matrix": { "creativity": 0.9, "logic": 0.8 }, + "traits": { + "mbti": "ENTP", + "ocean": { "openness": 0.8, "conscientiousness": 0.6 } + }, + "moral_compass": { + "alignment": "Chaotic Good", + "core_values": ["Curiosity", "Autonomy"] + } + }, + "linguistics": { + "text_style": { + "formality_level": 0.2, + "style_descriptors": ["curious", "energetic"] + }, + "idiolect": { + "catchphrases": ["Let's test this"], + "forbidden_words": ["never"] + } + }, + "motivations": { + "core_drive": "Push boundaries and explore possibilities", + "goals": { + "short_term": ["Prototype quickly"], + "long_term": ["Build reliable systems"] + } + }, + "capabilities": { + "skills": [{ "name": "Rust engineering" }, { "name": "Prompt design" }], + "tools": ["shell", "file_read"] + } +} +''' +``` + +ZeroClaw chấp nhận cả payload AIEOS đầy đủ lẫn dạng rút gọn, rồi chuẩn hóa về một định dạng system prompt thống nhất. 
+ +#### Các phần trong Schema AIEOS + +| Phần | Mô tả | +|---------|-------------| +| `identity` | Tên, tiểu sử, xuất xứ, nơi cư trú | +| `psychology` | Neural matrix (trọng số nhận thức), MBTI, OCEAN, la bàn đạo đức | +| `linguistics` | Phong cách văn bản, mức độ trang trọng, câu cửa miệng, từ bị cấm | +| `motivations` | Động lực cốt lõi, mục tiêu ngắn/dài hạn, nỗi sợ hãi | +| `capabilities` | Kỹ năng và tool mà agent có thể truy cập | +| `physicality` | Mô tả hình ảnh cho việc tạo ảnh | +| `history` | Câu chuyện xuất xứ, học vấn, nghề nghiệp | +| `interests` | Sở thích, điều yêu thích, lối sống | + +Xem [aieos.org](https://aieos.org) để có schema đầy đủ và ví dụ trực tiếp. + +## Gateway API + +| Endpoint | Phương thức | Xác thực | Mô tả | +|----------|--------|------|-------------| +| `/health` | GET | Không | Kiểm tra sức khỏe (luôn công khai, không lộ bí mật) | +| `/pair` | POST | Header `X-Pairing-Code` | Đổi mã một lần lấy bearer token | +| `/webhook` | POST | `Authorization: Bearer ` | Gửi tin nhắn: `{"message": "your prompt"}`; tùy chọn `X-Idempotency-Key` | +| `/whatsapp` | GET | Query params | Xác minh webhook Meta (hub.mode, hub.verify_token, hub.challenge) | +| `/whatsapp` | POST | Chữ ký Meta (`X-Hub-Signature-256`) khi app secret được cấu hình | Webhook tin nhắn đến WhatsApp | + +## Lệnh + +| Lệnh | Mô tả | +|---------|-------------| +| `onboard` | Cài đặt nhanh (mặc định) | +| `agent` | Chế độ chat tương tác hoặc một tin nhắn | +| `gateway` | Khởi động webhook server (mặc định: `127.0.0.1:3000`) | +| `daemon` | Khởi động runtime tự trị chạy lâu dài | +| `service` | Quản lý dịch vụ nền cấp người dùng | +| `doctor` | Chẩn đoán trạng thái hoạt động daemon/scheduler/channel | +| `status` | Hiển thị trạng thái hệ thống đầy đủ | +| `cron` | Quản lý tác vụ lên lịch (`list/add/add-at/add-every/once/remove/update/pause/resume`) | +| `models` | Làm mới danh mục model của provider (`models refresh`) | +| `providers` | Liệt kê provider và alias được hỗ trợ | +| 
`channel` | Liệt kê/khởi động/chẩn đoán channel và gắn định danh Telegram | +| `integrations` | Kiểm tra thông tin cài đặt tích hợp | +| `skills` | Liệt kê/cài đặt/gỡ bỏ skill | +| `migrate` | Import dữ liệu từ runtime khác (`migrate openclaw`) | +| `hardware` | Lệnh khám phá/kiểm tra/thông tin USB | +| `peripheral` | Quản lý và flash thiết bị ngoại vi phần cứng | + +Để có hướng dẫn lệnh theo tác vụ, xem [`docs/commands-reference.md`](docs/commands-reference.md). + +### Opt-In Open-Skills + +Đồng bộ `open-skills` của cộng đồng bị tắt theo mặc định. Bật tường minh trong `config.toml`: + +```toml +[skills] +open_skills_enabled = true +# open_skills_dir = "/path/to/open-skills" # tùy chọn +``` + +Bạn cũng có thể ghi đè lúc runtime với `ZEROCLAW_OPEN_SKILLS_ENABLED` và `ZEROCLAW_OPEN_SKILLS_DIR`. + +## Phát triển + +```bash +cargo build # Build phát triển +cargo build --release # Build release (codegen-units=1, hoạt động trên mọi thiết bị kể cả Raspberry Pi) +cargo build --profile release-fast # Build nhanh hơn (codegen-units=8, yêu cầu RAM 16GB+) +cargo test # Chạy toàn bộ test suite +cargo clippy --locked --all-targets -- -D clippy::correctness +cargo fmt # Định dạng code + +# Chạy benchmark SQLite vs Markdown +cargo test --test memory_comparison -- --nocapture +``` + +### Hook pre-push + +Một git hook chạy `cargo fmt --check`, `cargo clippy -- -D warnings`, và `cargo test` trước mỗi lần push. Bật một lần: + +```bash +git config core.hooksPath .githooks +``` + +### Khắc phục sự cố build (lỗi OpenSSL trên Linux) + +Nếu bạn gặp lỗi build `openssl-sys`, đồng bộ dependencies và rebuild với lockfile của repository: + +```bash +git pull +cargo build --release --locked +cargo install --path . --force --locked +``` + +ZeroClaw được cấu hình để dùng `rustls` cho các dependencies HTTP/TLS; `--locked` giữ cho dependency graph nhất quán trên các môi trường mới. 
+ +Để bỏ qua hook khi cần push nhanh trong quá trình phát triển: + +```bash +git push --no-verify +``` + +## Cộng tác & Tài liệu + +Bắt đầu từ trung tâm tài liệu để có bản đồ theo tác vụ: + +- Trung tâm tài liệu: [`docs/README.md`](docs/README.md) +- Mục lục tài liệu thống nhất: [`docs/SUMMARY.md`](docs/SUMMARY.md) +- Tài liệu tham khảo lệnh: [`docs/commands-reference.md`](docs/commands-reference.md) +- Tài liệu tham khảo cấu hình: [`docs/config-reference.md`](docs/config-reference.md) +- Tài liệu tham khảo provider: [`docs/providers-reference.md`](docs/providers-reference.md) +- Tài liệu tham khảo channel: [`docs/channels-reference.md`](docs/channels-reference.md) +- Sổ tay vận hành: [`docs/operations-runbook.md`](docs/operations-runbook.md) +- Khắc phục sự cố: [`docs/troubleshooting.md`](docs/troubleshooting.md) +- Kiểm kê/phân loại tài liệu: [`docs/docs-inventory.md`](docs/docs-inventory.md) +- Tổng hợp phân loại PR/Issue (tính đến 18/2/2026): [`docs/project-triage-snapshot-2026-02-18.md`](docs/project-triage-snapshot-2026-02-18.md) + +Tài liệu tham khảo cộng tác cốt lõi: + +- Trung tâm tài liệu: [docs/README.md](docs/README.md) +- Template tài liệu: [docs/doc-template.md](docs/doc-template.md) +- Danh sách kiểm tra thay đổi tài liệu: [docs/README.md#4-documentation-change-checklist](docs/README.md#4-documentation-change-checklist) +- Tài liệu tham khảo cấu hình channel: [docs/channels-reference.md](docs/channels-reference.md) +- Vận hành phòng mã hóa Matrix: [docs/matrix-e2ee-guide.md](docs/matrix-e2ee-guide.md) +- Hướng dẫn đóng góp: [CONTRIBUTING.md](CONTRIBUTING.md) +- Chính sách quy trình PR: [docs/pr-workflow.md](docs/pr-workflow.md) +- Sổ tay người review (phân loại + review sâu): [docs/reviewer-playbook.md](docs/reviewer-playbook.md) +- Bản đồ sở hữu và phân loại CI: [docs/ci-map.md](docs/ci-map.md) +- Chính sách tiết lộ bảo mật: [SECURITY.md](SECURITY.md) + +Cho triển khai và vận hành runtime: + +- Hướng dẫn triển khai mạng: 
[docs/network-deployment.md](docs/network-deployment.md) +- Sổ tay proxy agent: [docs/proxy-agent-playbook.md](docs/proxy-agent-playbook.md) + +## Ủng hộ ZeroClaw + +Nếu ZeroClaw giúp ích cho công việc của bạn và bạn muốn hỗ trợ phát triển liên tục, bạn có thể quyên góp tại đây: + +Buy Me a Coffee + +### 🙏 Lời cảm ơn đặc biệt + +Chân thành cảm ơn các cộng đồng và tổ chức đã truyền cảm hứng và thúc đẩy công việc mã nguồn mở này: + +- **Harvard University** — vì đã nuôi dưỡng sự tò mò trí tuệ và không ngừng mở rộng ranh giới của những điều có thể. +- **MIT** — vì đã đề cao tri thức mở, mã nguồn mở, và niềm tin rằng công nghệ phải có thể tiếp cận với tất cả mọi người. +- **Sundai Club** — vì cộng đồng, năng lượng, và động lực không mệt mỏi để xây dựng những thứ có ý nghĩa. +- **Thế giới & Xa hơn** 🌍✨ — gửi đến mọi người đóng góp, người dám mơ và người dám làm đang biến mã nguồn mở thành sức mạnh tích cực. Tất cả là dành cho các bạn. + +Chúng tôi xây dựng công khai vì ý tưởng hay đến từ khắp nơi. Nếu bạn đang đọc đến đây, bạn đã là một phần của chúng tôi. Chào mừng. 🦀❤️ + +## ⚠️ Repository Chính thức & Cảnh báo Mạo danh + +**Đây là repository ZeroClaw chính thức duy nhất:** +> https://github.com/zeroclaw-labs/zeroclaw + +Bất kỳ repository, tổ chức, tên miền hay gói nào khác tuyên bố là "ZeroClaw" hoặc ngụ ý liên kết với ZeroClaw Labs đều là **không được ủy quyền và không liên kết với dự án này**. Các fork không được ủy quyền đã biết sẽ được liệt kê trong [TRADEMARK.md](TRADEMARK.md). + +Nếu bạn phát hiện hành vi mạo danh hoặc lạm dụng nhãn hiệu, vui lòng [mở một issue](https://github.com/zeroclaw-labs/zeroclaw/issues). + +--- + +## Giấy phép + +ZeroClaw được cấp phép kép để tối đa hóa tính mở và bảo vệ người đóng góp: + +| Giấy phép | Trường hợp sử dụng | +|---|---| +| [MIT](LICENSE) | Mã nguồn mở, nghiên cứu, học thuật, sử dụng cá nhân | +| [Apache 2.0](LICENSE-APACHE) | Bảo hộ bằng sáng chế, triển khai tổ chức, thương mại | + +Bạn có thể chọn một trong hai giấy phép. 
**Người đóng góp tự động cấp quyền theo cả hai** — xem [CLA.md](CLA.md) để biết thỏa thuận đóng góp đầy đủ. + +### Nhãn hiệu + +Tên **ZeroClaw** và logo là nhãn hiệu của ZeroClaw Labs. Giấy phép này không cấp phép sử dụng chúng để ngụ ý chứng thực hoặc liên kết. Xem [TRADEMARK.md](TRADEMARK.md) để biết các sử dụng được phép và bị cấm. + +### Bảo vệ người đóng góp + +- Bạn **giữ bản quyền** đối với đóng góp của mình +- **Cấp bằng sáng chế** (Apache 2.0) bảo vệ bạn khỏi các khiếu nại bằng sáng chế từ người đóng góp khác +- Đóng góp của bạn được **ghi nhận vĩnh viễn** trong lịch sử commit và [NOTICE](NOTICE) +- Không có quyền nhãn hiệu nào được chuyển giao khi đóng góp + +## Đóng góp + +Xem [CONTRIBUTING.md](CONTRIBUTING.md) và [CLA.md](CLA.md). Triển khai một trait, gửi PR: +- Hướng dẫn quy trình CI: [docs/ci-map.md](docs/ci-map.md) +- `Provider` mới → `src/providers/` +- `Channel` mới → `src/channels/` +- `Observer` mới → `src/observability/` +- `Tool` mới → `src/tools/` +- `Memory` mới → `src/memory/` +- `Tunnel` mới → `src/tunnel/` +- `Skill` mới → `~/.zeroclaw/workspace/skills//` + +--- + +**ZeroClaw** — Không tốn thêm tài nguyên. Không đánh đổi. Triển khai ở đâu cũng được. Thay thế gì cũng được. 🦀 + +## Lịch sử Star + +

+ + + + + Star History Chart + + +

diff --git a/README.zh-CN.md b/README.zh-CN.md index 357b8f1..ab918d3 100644 --- a/README.zh-CN.md +++ b/README.zh-CN.md @@ -8,6 +8,15 @@ 零开销、零妥协;随处部署、万物可换。

+

+ X: @zeroclawlabs + Xiaohongshu: Official + Telegram: @zeroclawlabs + Telegram CN: @zeroclawlabs_cn + Telegram RU: @zeroclawlabs_ru + Reddit: r/zeroclawlabs +

+

🌐 语言:English · 简体中文 · 日本語 · Русский

@@ -33,7 +42,17 @@ > > 技术标识(命令、配置键、API 路径、Trait 名称)保持英文,避免语义漂移。 > -> 最后对齐时间:**2026-02-18**。 +> 最后对齐时间:**2026-02-19**。 + +## 📢 公告板 + +用于发布重要通知(破坏性变更、安全通告、维护窗口、版本阻塞问题等)。 + +| 日期(UTC) | 级别 | 通知 | 处理建议 | +|---|---|---|---| +| 2026-02-19 | _紧急_ | 我们与 `openagen/zeroclaw` 及 `zeroclaw.org` **没有任何关系**。`zeroclaw.org` 当前会指向 `openagen/zeroclaw` 这个 fork,并且该域名/仓库正在冒充我们的官网与官方项目。 | 请不要相信上述来源发布的任何信息、二进制、募资活动或官方声明。请仅以本仓库和已验证官方社媒为准。 | +| 2026-02-19 | _重要_ | 我们目前**尚未发布官方正式网站**,且已发现有人尝试冒充我们。请勿参与任何打着 ZeroClaw 名义进行的投资、募资或类似活动。 | 一切信息请以本仓库为准;也可关注 [X(@zeroclawlabs)](https://x.com/zeroclawlabs?s=21)、[Reddit(r/zeroclawlabs)](https://www.reddit.com/r/zeroclawlabs/)、[Telegram(@zeroclawlabs)](https://t.me/zeroclawlabs)、[Telegram 中文频道(@zeroclawlabs_cn)](https://t.me/zeroclawlabs_cn)、[Telegram 俄语频道(@zeroclawlabs_ru)](https://t.me/zeroclawlabs_ru) 与 [小红书账号](https://www.xiaohongshu.com/user/profile/67cbfc43000000000d008307?xsec_token=AB73VnYnGNx5y36EtnnZfGmAmS-6Wzv8WMuGpfwfkg6Yc%3D&xsec_source=pc_search) 获取官方最新动态。 | +| 2026-02-19 | _重要_ | Anthropic 于 2026-02-19 更新了 Authentication and Credential Use 条款。条款明确:OAuth authentication(用于 Free、Pro、Max)仅适用于 Claude Code 与 Claude.ai;将 Claude Free/Pro/Max 账号获得的 OAuth token 用于其他任何产品、工具或服务(包括 Agent SDK)不被允许,并可能构成对 Consumer Terms of Service 的违规。 | 为避免损失,请暂时不要尝试 Claude Code OAuth 集成;原文见:[Authentication and Credential Use](https://code.claude.com/docs/en/legal-and-compliance#authentication-and-credential-use)。 | ## 项目简介 @@ -100,6 +119,12 @@ cd zeroclaw ## 快速开始 +### Homebrew(macOS/Linuxbrew) + +```bash +brew install zeroclaw +``` + ```bash git clone https://github.com/zeroclaw-labs/zeroclaw.git cd zeroclaw @@ -122,6 +147,106 @@ zeroclaw gateway zeroclaw daemon ``` +## Subscription Auth(OpenAI Codex / Claude Code) + +ZeroClaw 现已支持基于订阅的原生鉴权配置(多账号、静态加密存储)。 + +- 配置文件:`~/.zeroclaw/auth-profiles.json` +- 加密密钥:`~/.zeroclaw/.secret_key` +- Profile ID 格式:`:`(例:`openai-codex:work`) + +OpenAI Codex OAuth(ChatGPT 订阅): + +```bash +# 推荐用于服务器/无显示器环境 +zeroclaw auth login --provider 
openai-codex --device-code + +# 浏览器/回调流程,支持粘贴回退 +zeroclaw auth login --provider openai-codex --profile default +zeroclaw auth paste-redirect --provider openai-codex --profile default + +# 检查 / 刷新 / 切换 profile +zeroclaw auth status +zeroclaw auth refresh --provider openai-codex --profile default +zeroclaw auth use --provider openai-codex --profile work +``` + +Claude Code / Anthropic setup-token: + +```bash +# 粘贴订阅/setup token(Authorization header 模式) +zeroclaw auth paste-token --provider anthropic --profile default --auth-kind authorization + +# 别名命令 +zeroclaw auth setup-token --provider anthropic --profile default +``` + +使用 subscription auth 运行 agent: + +```bash +zeroclaw agent --provider openai-codex -m "hello" +zeroclaw agent --provider openai-codex --auth-profile openai-codex:work -m "hello" + +# Anthropic 同时支持 API key 和 auth token 环境变量: +# ANTHROPIC_AUTH_TOKEN, ANTHROPIC_OAUTH_TOKEN, ANTHROPIC_API_KEY +zeroclaw agent --provider anthropic -m "hello" +``` + +## 架构 + +每个子系统都是一个 **Trait** — 通过配置切换即可更换实现,无需修改代码。 + +

+ ZeroClaw 架构图 +

+ +| 子系统 | Trait | 内置实现 | 扩展方式 | +|--------|-------|----------|----------| +| **AI 模型** | `Provider` | 通过 `zeroclaw providers` 查看(当前 28 个内置 + 别名,以及自定义端点) | `custom:https://your-api.com`(OpenAI 兼容)或 `anthropic-custom:https://your-api.com` | +| **通道** | `Channel` | CLI, Telegram, Discord, Slack, Mattermost, iMessage, Matrix, Signal, WhatsApp, Email, IRC, Lark, DingTalk, QQ, Webhook | 任意消息 API | +| **记忆** | `Memory` | SQLite 混合搜索, PostgreSQL 后端, Lucid 桥接, Markdown 文件, 显式 `none` 后端, 快照/恢复, 可选响应缓存 | 任意持久化后端 | +| **工具** | `Tool` | shell/file/memory, cron/schedule, git, pushover, browser, http_request, screenshot/image_info, composio (opt-in), delegate, 硬件工具 | 任意能力 | +| **可观测性** | `Observer` | Noop, Log, Multi | Prometheus, OTel | +| **运行时** | `RuntimeAdapter` | Native, Docker(沙箱) | 通过 adapter 添加;不支持的类型会快速失败 | +| **安全** | `SecurityPolicy` | Gateway 配对, 沙箱, allowlist, 速率限制, 文件系统作用域, 加密密钥 | — | +| **身份** | `IdentityConfig` | OpenClaw (markdown), AIEOS v1.1 (JSON) | 任意身份格式 | +| **隧道** | `Tunnel` | None, Cloudflare, Tailscale, ngrok, Custom | 任意隧道工具 | +| **心跳** | Engine | HEARTBEAT.md 定期任务 | — | +| **技能** | Loader | TOML 清单 + SKILL.md 指令 | 社区技能包 | +| **集成** | Registry | 9 个分类下 70+ 集成 | 插件系统 | + +### 运行时支持(当前) + +- ✅ 当前支持:`runtime.kind = "native"` 或 `runtime.kind = "docker"` +- 🚧 计划中,尚未实现:WASM / 边缘运行时 + +配置了不支持的 `runtime.kind` 时,ZeroClaw 会以明确的错误退出,而非静默回退到 native。 + +### 记忆系统(全栈搜索引擎) + +全部自研,零外部依赖 — 无需 Pinecone、Elasticsearch、LangChain: + +| 层级 | 实现 | +|------|------| +| **向量数据库** | Embeddings 以 BLOB 存储于 SQLite,余弦相似度搜索 | +| **关键词搜索** | FTS5 虚拟表,BM25 评分 | +| **混合合并** | 自定义加权合并函数(`vector.rs`) | +| **Embeddings** | `EmbeddingProvider` trait — OpenAI、自定义 URL 或 noop | +| **分块** | 基于行的 Markdown 分块器,保留标题结构 | +| **缓存** | SQLite `embedding_cache` 表,LRU 淘汰策略 | +| **安全重索引** | 原子化重建 FTS5 + 重新嵌入缺失向量 | + +Agent 通过工具自动进行记忆的回忆、保存和管理。 + +```toml +[memory] +backend = "sqlite" # "sqlite", "lucid", "postgres", "markdown", "none" +auto_save = true +embedding_provider = "none" # "none", "openai", 
"custom:https://..." +vector_weight = 0.7 +keyword_weight = 0.3 +``` + ## 安全默认行为(关键) - Gateway 默认绑定:`127.0.0.1:3000` diff --git a/TRADEMARK.md b/TRADEMARK.md new file mode 100644 index 0000000..ac70fb5 --- /dev/null +++ b/TRADEMARK.md @@ -0,0 +1,129 @@ +# ZeroClaw Trademark Policy + +**Effective date:** February 2026 +**Maintained by:** ZeroClaw Labs + +--- + +## Our Trademarks + +The following are trademarks of ZeroClaw Labs: + +- **ZeroClaw** (word mark) +- **zeroclaw-labs** (organization name) +- The ZeroClaw logo and associated visual identity + +These marks identify the official ZeroClaw project and distinguish it from +unauthorized forks, derivatives, or impersonators. + +--- + +## Official Repository + +The **only** official ZeroClaw repository is: + +> https://github.com/zeroclaw-labs/zeroclaw + +Any other repository, organization, domain, or product claiming to be +"ZeroClaw" or implying affiliation with ZeroClaw Labs is unauthorized and +may constitute trademark infringement. + +**Known unauthorized forks:** +- `openagen/zeroclaw` — not affiliated with ZeroClaw Labs + +If you encounter an unauthorized use, please report it by opening an issue +at https://github.com/zeroclaw-labs/zeroclaw/issues. + +--- + +## Permitted Uses + +You **may** use the ZeroClaw name and marks in the following ways without +prior written permission: + +1. **Attribution** — stating that your software is based on or derived from + ZeroClaw, provided it is clear your project is not the official ZeroClaw. + +2. **Descriptive reference** — referring to ZeroClaw in documentation, + articles, blog posts, or presentations to accurately describe the software. + +3. **Community discussion** — using the name in forums, issues, or social + media to discuss the project. + +4. **Fork identification** — identifying your fork as "a fork of ZeroClaw" + with a clear link to the official repository. + +--- + +## Prohibited Uses + +You **may not** use the ZeroClaw name or marks in ways that: + +1. 
**Imply official endorsement** — suggest your project, product, or + organization is officially affiliated with or endorsed by ZeroClaw Labs. + +2. **Cause brand confusion** — use "ZeroClaw" as the primary name of a + competing or derivative product in a way that could confuse users about + the source. + +3. **Impersonate the project** — create repositories, domains, packages, + or accounts that could be mistaken for the official ZeroClaw project. + +4. **Misrepresent origin** — remove or obscure attribution to ZeroClaw Labs + while distributing the software or derivatives. + +5. **Commercial trademark use** — use the marks in commercial products, + services, or marketing without prior written permission from ZeroClaw Labs. + +--- + +## Fork Guidelines + +Forks are welcome under the terms of the MIT and Apache 2.0 licenses. If +you fork ZeroClaw, you must: + +- Clearly state your project is a fork of ZeroClaw +- Link back to the official repository +- Not use "ZeroClaw" as the primary name of your fork +- Not imply your fork is the official or original project +- Retain all copyright, license, and attribution notices + +--- + +## Contributor Protections + +Contributors to the official ZeroClaw repository are protected under the +dual MIT + Apache 2.0 license model: + +- **Patent grant** (Apache 2.0) — your contributions are protected from + patent claims by other contributors. +- **Attribution** — your contributions are permanently recorded in the + repository history and NOTICE file. +- **No trademark transfer** — contributing code does not transfer any + trademark rights to third parties. + +--- + +## Reporting Infringement + +If you believe someone is infringing ZeroClaw trademarks: + +1. Open an issue at https://github.com/zeroclaw-labs/zeroclaw/issues +2. Include the URL of the infringing content +3. Describe how it violates this policy + +For serious or commercial infringement, contact the maintainers directly +through the repository. 
+ +--- + +## Changes to This Policy + +ZeroClaw Labs reserves the right to update this policy at any time. Changes +will be committed to the official repository with a clear commit message. + +--- + +*This trademark policy is separate from and in addition to the MIT and +Apache 2.0 software licenses. The licenses govern use of the source code; +this policy governs use of the ZeroClaw name and brand.* diff --git a/bootstrap.sh b/bootstrap.sh index 32a5574..2c8984d 100755 --- a/bootstrap.sh +++ b/bootstrap.sh @@ -1,5 +1,5 @@ #!/usr/bin/env bash set -euo pipefail -ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" -exec "$ROOT_DIR/scripts/bootstrap.sh" "$@" +ROOT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" >/dev/null 2>&1 && pwd || pwd)" +exec "$ROOT_DIR/zeroclaw_install.sh" "$@" diff --git a/crates/robot-kit/Cargo.toml b/crates/robot-kit/Cargo.toml index 76b2863..69eddd6 100644 --- a/crates/robot-kit/Cargo.toml +++ b/crates/robot-kit/Cargo.toml @@ -30,7 +30,7 @@ tokio = { version = "1.42", features = ["rt-multi-thread", "macros", "time", "sy # Serialization serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" -toml = "0.8" +toml = "1.0" # HTTP client (for Ollama vision) reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] } @@ -52,7 +52,7 @@ tracing = "0.1" chrono = { version = "0.4", features = ["clock", "std"] } # User directories -directories = "5.0" +directories = "6.0" [target.'cfg(target_os = "linux")'.dependencies] diff --git a/dev/cli.sh b/dev/cli.sh index ec9aad5..f25ac27 100755 --- a/dev/cli.sh +++ b/dev/cli.sh @@ -14,6 +14,11 @@ else fi COMPOSE_FILE="$BASE_DIR/docker-compose.yml" +if [ "$BASE_DIR" = "dev" ]; then + ENV_FILE=".env" +else + ENV_FILE="../.env" +fi # Colors GREEN='\033[0;32m' @@ -21,6 +26,15 @@ YELLOW='\033[1;33m' RED='\033[0;31m' NC='\033[0m' # No Color +function load_env { + if [ -f "$ENV_FILE" ]; then + # Auto-export variables from .env for docker compose passthrough. 
+ set -a + source "$ENV_FILE" + set +a + fi +} + function ensure_config { CONFIG_DIR="$HOST_TARGET_DIR/.zeroclaw" CONFIG_FILE="$CONFIG_DIR/config.toml" @@ -55,6 +69,8 @@ if [ -z "$1" ]; then exit 1 fi +load_env + case "$1" in up) ensure_config diff --git a/dev/docker-compose.yml b/dev/docker-compose.yml index 93de91a..ca45084 100644 --- a/dev/docker-compose.yml +++ b/dev/docker-compose.yml @@ -20,11 +20,20 @@ services: container_name: zeroclaw-dev restart: unless-stopped environment: - - API_KEY - - PROVIDER - - ZEROCLAW_MODEL - ZEROCLAW_GATEWAY_PORT=3000 - SANDBOX_HOST=zeroclaw-sandbox + secrets: + - source: zeroclaw_env + target: zeroclaw_env + entrypoint: ["/bin/bash", "-lc"] + command: + - | + if [ -f /run/secrets/zeroclaw_env ]; then + set -a + . /run/secrets/zeroclaw_env + set +a + fi + exec zeroclaw gateway --port "${ZEROCLAW_GATEWAY_PORT:-3000}" --host "[::]" volumes: # Mount single config file (avoids shadowing other files in .zeroclaw) - ../target/.zeroclaw/config.toml:/zeroclaw-data/.zeroclaw/config.toml @@ -57,3 +66,7 @@ networks: dev-net: driver: bridge + +secrets: + zeroclaw_env: + file: ../.env diff --git a/docs/channels-reference.md b/docs/channels-reference.md index 2ab904e..9c99b28 100644 --- a/docs/channels-reference.md +++ b/docs/channels-reference.md @@ -51,8 +51,43 @@ Notes: - Model cache previews come from `zeroclaw models refresh --provider <provider>`. - These are runtime chat commands, not CLI subcommands. +## Inbound Image Marker Protocol + +ZeroClaw supports multimodal input through inline message markers: + +- Syntax: ``[IMAGE:<source>]`` +- `<source>` can be: + - Local file path + - Data URI (`data:image/...;base64,...`) + - Remote URL only when `[multimodal].allow_remote_fetch = true` + +Operational notes: + +- Marker parsing applies to user-role messages before provider calls. +- Provider capability is enforced at runtime: if the selected provider does not support vision, the request fails with a structured capability error (`capability=vision`).
+- Linq webhook `media` parts with `image/*` MIME type are automatically converted to this marker format. + ## Channel Matrix +### Build Feature Toggle (`channel-matrix`) + +Matrix support is controlled at compile time by the `channel-matrix` Cargo feature. + +- Default builds include Matrix support (`default = ["hardware", "channel-matrix"]`). +- For faster local iteration when Matrix is not needed: + +```bash +cargo check --no-default-features --features hardware +``` + +- To explicitly enable Matrix support in custom feature sets: + +```bash +cargo check --no-default-features --features hardware,channel-matrix +``` + +If `[channels_config.matrix]` is present but the binary was built without `channel-matrix`, `zeroclaw channel list`, `zeroclaw channel doctor`, and `zeroclaw channel start` will log that Matrix is intentionally skipped for this build. + --- ## 2. Delivery Modes at a Glance @@ -66,7 +101,7 @@ Notes: | Mattermost | polling | No | | Matrix | sync API (supports E2EE) | No | | Signal | signal-cli HTTP bridge | No (local bridge endpoint) | -| WhatsApp | webhook | Yes (public HTTPS callback) | +| WhatsApp | webhook (Cloud API) or websocket (Web mode) | Cloud API: Yes (public HTTPS callback), Web mode: No | | Webhook | gateway endpoint (`/webhook`) | Usually yes | | Email | IMAP polling + SMTP send | No | | IRC | IRC socket | No | @@ -103,8 +138,17 @@ Field names differ by channel: [channels_config.telegram] bot_token = "123456:telegram-token" allowed_users = ["*"] +stream_mode = "off" # optional: off | partial +draft_update_interval_ms = 1000 # optional: edit throttle for partial streaming +mention_only = false # optional: require @mention in groups +interrupt_on_new_message = false # optional: cancel in-flight same-sender same-chat request ``` +Telegram notes: + +- `interrupt_on_new_message = true` preserves interrupted user turns in conversation history, then restarts generation on the newest message. 
+- Interruption scope is strict: same sender in the same chat. Messages from different chats are processed independently. + ### 4.2 Discord ```toml @@ -164,6 +208,13 @@ ignore_stories = true ### 4.7 WhatsApp +ZeroClaw supports two WhatsApp backends: + +- **Cloud API mode** (`phone_number_id` + `access_token` + `verify_token`) +- **WhatsApp Web mode** (`session_path`, requires build flag `--features whatsapp-web`) + +Cloud API mode: + ```toml [channels_config.whatsapp] access_token = "EAAB..." @@ -173,6 +224,22 @@ app_secret = "your-app-secret" # optional but recommended allowed_numbers = ["*"] ``` +WhatsApp Web mode: + +```toml +[channels_config.whatsapp] +session_path = "~/.zeroclaw/state/whatsapp-web/session.db" +pair_phone = "15551234567" # optional; omit to use QR flow +pair_code = "" # optional custom pair code +allowed_numbers = ["*"] +``` + +Notes: + +- Build with `cargo build --features whatsapp-web` (or equivalent run command). +- Keep `session_path` on persistent storage to avoid relinking after restart. +- Reply routing uses the originating chat JID, so direct and group replies work correctly. + ### 4.8 Webhook Channel Config (Gateway) `channels_config.webhook` enables webhook-specific gateway behavior. @@ -331,7 +398,7 @@ rg -n "Matrix|Telegram|Discord|Slack|Mattermost|Signal|WhatsApp|Email|IRC|Lark|D | Mattermost | `Mattermost channel listening on` | `Mattermost: ignoring message from unauthorized user:` | `Mattermost poll error:` / `Mattermost parse error:` | | Matrix | `Matrix channel listening on room` / `Matrix room ... is encrypted; E2EE decryption is enabled via matrix-sdk.` | `Matrix whoami failed; falling back to configured session hints for E2EE session restore:` / `Matrix whoami failed while resolving listener user_id; using configured user_id hint:` | `Matrix sync error: ... 
retrying...` | | Signal | `Signal channel listening via SSE on` | (allowlist checks are enforced by `allowed_from`) | `Signal SSE returned ...` / `Signal SSE connect error:` | -| WhatsApp (channel) | `WhatsApp channel active (webhook mode).` | `WhatsApp: ignoring message from unauthorized number:` | `WhatsApp send failed:` | +| WhatsApp (channel) | `WhatsApp channel active (webhook mode).` / `WhatsApp Web connected successfully` | `WhatsApp: ignoring message from unauthorized number:` / `WhatsApp Web: message from ... not in allowed list` | `WhatsApp send failed:` / `WhatsApp Web stream error:` | | Webhook / WhatsApp (gateway) | `WhatsApp webhook verified successfully` | `Webhook: rejected — not paired / invalid bearer token` / `Webhook: rejected request — invalid or missing X-Webhook-Secret` / `WhatsApp webhook verification failed — token mismatch` | `Webhook JSON parse error:` | | Email | `Email polling every ...` / `Email sent to ...` | `Blocked email from ...` | `Email poll failed:` / `Email poll task panicked:` | | IRC | `IRC channel connecting to ...` / `IRC registered as ...` | (allowlist checks are enforced by `allowed_users`) | `IRC SASL authentication failed (...)` / `IRC server does not support SASL...` / `IRC nickname ... is in use, trying ...` | @@ -349,4 +416,3 @@ If a specific channel task crashes or exits, the channel supervisor in `channels - `Channel message worker crashed:` These messages indicate automatic restart behavior is active, and you should inspect preceding logs for root cause. - diff --git a/docs/commands-reference.md b/docs/commands-reference.md index 8c0d3ae..da9d52c 100644 --- a/docs/commands-reference.md +++ b/docs/commands-reference.md @@ -2,7 +2,7 @@ This reference is derived from the current CLI surface (`zeroclaw --help`). -Last verified: **February 18, 2026**. +Last verified: **February 19, 2026**. ## Top-Level Commands @@ -22,6 +22,7 @@ Last verified: **February 18, 2026**. 
| `integrations` | Inspect integration details | | `skills` | List/install/remove skills | | `migrate` | Import from external runtimes (currently OpenClaw) | +| `config` | Export machine-readable config schema | | `hardware` | Discover and introspect USB hardware | | `peripheral` | Configure and flash peripherals | @@ -33,6 +34,7 @@ Last verified: **February 18, 2026**. - `zeroclaw onboard --interactive` - `zeroclaw onboard --channels-only` - `zeroclaw onboard --api-key <key> --provider <provider> --memory <backend>` +- `zeroclaw onboard --api-key <key> --provider <provider> --model <model> --memory <backend>` ### `agent` @@ -51,6 +53,7 @@ Last verified: **February 18, 2026**. - `zeroclaw service install` - `zeroclaw service start` - `zeroclaw service stop` +- `zeroclaw service restart` - `zeroclaw service status` - `zeroclaw service uninstall` @@ -89,6 +92,13 @@ Runtime in-chat commands (Telegram/Discord while channel server is running): - `/model` - `/model <name>` +Channel runtime also watches `config.toml` and hot-applies updates to: +- `default_provider` +- `default_model` +- `default_temperature` +- `api_key` / `api_url` (for the default provider) +- `reliability.*` provider retry settings + `add/remove` currently route you back to managed setup/manual config paths (not full declarative mutators yet). ### `integrations` @@ -101,10 +111,20 @@ Runtime in-chat commands (Telegram/Discord while channel server is running): - `zeroclaw skills install <source>` - `zeroclaw skills remove <name>` +`<source>` accepts git remotes (`https://...`, `http://...`, `ssh://...`, and `git@host:owner/repo.git`) or a local filesystem path. + +Skill manifests (`SKILL.toml`) support `prompts` and `[[tools]]`; both are injected into the agent system prompt at runtime, so the model can follow skill instructions without manually reading skill files. + ### `migrate` - `zeroclaw migrate openclaw [--source <path>] [--dry-run]` +### `config` + +- `zeroclaw config schema` + +`config schema` prints a JSON Schema (draft 2020-12) for the full `config.toml` contract to stdout.
+ ### `hardware` - `zeroclaw hardware discover` diff --git a/docs/config-reference.md b/docs/config-reference.md index dbc5221..8291a3c 100644 --- a/docs/config-reference.md +++ b/docs/config-reference.md @@ -2,11 +2,21 @@ This is a high-signal reference for common config sections and defaults. -Last verified: **February 18, 2026**. +Last verified: **February 19, 2026**. -Config file path: +Config path resolution at startup: -- `~/.zeroclaw/config.toml` +1. `ZEROCLAW_WORKSPACE` override (if set) +2. persisted `~/.zeroclaw/active_workspace.toml` marker (if present) +3. default `~/.zeroclaw/config.toml` + +ZeroClaw logs the resolved config on startup at `INFO` level: + +- `Config loaded` with fields: `path`, `workspace`, `source`, `initialized` + +Schema export command: + +- `zeroclaw config schema` (prints JSON Schema draft 2020-12 to stdout) ## Core Keys @@ -16,17 +26,216 @@ Config file path: | `default_model` | `anthropic/claude-sonnet-4-6` | model routed through selected provider | | `default_temperature` | `0.7` | model temperature | +## `[observability]` + +| Key | Default | Purpose | +|---|---|---| +| `backend` | `none` | Observability backend: `none`, `noop`, `log`, `prometheus`, `otel`, `opentelemetry`, or `otlp` | +| `otel_endpoint` | `http://localhost:4318` | OTLP HTTP endpoint used when backend is `otel` | +| `otel_service_name` | `zeroclaw` | Service name emitted to OTLP collector | + +Notes: + +- `backend = "otel"` uses OTLP HTTP export with a blocking exporter client so spans and metrics can be emitted safely from non-Tokio contexts. +- Alias values `opentelemetry` and `otlp` map to the same OTel backend. + +Example: + +```toml +[observability] +backend = "otel" +otel_endpoint = "http://localhost:4318" +otel_service_name = "zeroclaw" +``` + +## Environment Provider Overrides + +Provider selection can also be controlled by environment variables. Precedence is: + +1. `ZEROCLAW_PROVIDER` (explicit override, always wins when non-empty) +2. 
`PROVIDER` (legacy fallback, only applied when config provider is unset or still `openrouter`) +3. `default_provider` in `config.toml` + +Operational note for container users: + +- If your `config.toml` sets an explicit custom provider like `custom:https://.../v1`, a default `PROVIDER=openrouter` from Docker/container env will no longer replace it. +- Use `ZEROCLAW_PROVIDER` when you intentionally want runtime env to override a non-default configured provider. + ## `[agent]` | Key | Default | Purpose | |---|---|---| +| `compact_context` | `false` | When true: bootstrap_max_chars=6000, rag_chunk_limit=2. Use for 13B or smaller models | | `max_tool_iterations` | `10` | Maximum tool-call loop turns per user message across CLI, gateway, and channels | +| `max_history_messages` | `50` | Maximum conversation history messages retained per session | +| `parallel_tools` | `false` | Enable parallel tool execution within a single iteration | +| `tool_dispatcher` | `auto` | Tool dispatch strategy | Notes: - Setting `max_tool_iterations = 0` falls back to safe default `10`. - If a channel message exceeds this value, the runtime returns: `Agent exceeded maximum tool iterations (<n>)`. +## `[agents.<name>]` + +Delegate sub-agent configurations. Each key under `[agents]` defines a named sub-agent that the primary agent can delegate to. + +| Key | Default | Purpose | +|---|---|---| +| `provider` | _required_ | Provider name (e.g. `"ollama"`, `"openrouter"`, `"anthropic"`) | +| `model` | _required_ | Model name for the sub-agent | +| `system_prompt` | unset | Optional system prompt override for the sub-agent | +| `api_key` | unset | Optional API key override (stored encrypted when `secrets.encrypt = true`) | +| `temperature` | unset | Temperature override for the sub-agent | +| `max_depth` | `3` | Max recursion depth for nested delegation | + +```toml +[agents.researcher] +provider = "openrouter" +model = "anthropic/claude-sonnet-4-6" +system_prompt = "You are a research assistant."
+max_depth = 2 + +[agents.coder] +provider = "ollama" +model = "qwen2.5-coder:32b" +temperature = 0.2 +``` + +## `[runtime]` + +| Key | Default | Purpose | +|---|---|---| +| `reasoning_enabled` | unset (`None`) | Global reasoning/thinking override for providers that support explicit controls | + +Notes: + +- `reasoning_enabled = false` explicitly disables provider-side reasoning for supported providers (currently `ollama`, via request field `think: false`). +- `reasoning_enabled = true` explicitly requests reasoning for supported providers (`think: true` on `ollama`). +- Unset keeps provider defaults. + +## `[skills]` + +| Key | Default | Purpose | +|---|---|---| +| `open_skills_enabled` | `false` | Opt-in loading/sync of community `open-skills` repository | +| `open_skills_dir` | unset | Optional local path for `open-skills` (defaults to `$HOME/open-skills` when enabled) | + +Notes: + +- Security-first default: ZeroClaw does **not** clone or sync `open-skills` unless `open_skills_enabled = true`. +- Environment overrides: + - `ZEROCLAW_OPEN_SKILLS_ENABLED` accepts `1/0`, `true/false`, `yes/no`, `on/off`. + - `ZEROCLAW_OPEN_SKILLS_DIR` overrides the repository path when non-empty. +- Precedence for enable flag: `ZEROCLAW_OPEN_SKILLS_ENABLED` → `skills.open_skills_enabled` in `config.toml` → default `false`. + +## `[composio]` + +| Key | Default | Purpose | +|---|---|---| +| `enabled` | `false` | Enable Composio managed OAuth tools | +| `api_key` | unset | Composio API key used by the `composio` tool | +| `entity_id` | `default` | Default `user_id` sent on connect/execute calls | + +Notes: + +- Backward compatibility: legacy `enable = true` is accepted as an alias for `enabled = true`. +- If `enabled = false` or `api_key` is missing, the `composio` tool is not registered. +- ZeroClaw requests Composio v3 tools with `toolkit_versions=latest` and executes tools with `version="latest"` to avoid stale default tool revisions. 
+- Typical flow: call `connect`, complete browser OAuth, then run `execute` for the desired tool action. +- If Composio returns a missing connected-account reference error, call `list_accounts` (optionally with `app`) and pass the returned `connected_account_id` to `execute`. + +## `[cost]` + +| Key | Default | Purpose | +|---|---|---| +| `enabled` | `false` | Enable cost tracking | +| `daily_limit_usd` | `10.00` | Daily spending limit in USD | +| `monthly_limit_usd` | `100.00` | Monthly spending limit in USD | +| `warn_at_percent` | `80` | Warn when spending reaches this percentage of limit | +| `allow_override` | `false` | Allow requests to exceed budget with `--override` flag | + +Notes: + +- When `enabled = true`, the runtime tracks per-request cost estimates and enforces daily/monthly limits. +- At `warn_at_percent` threshold, a warning is emitted but requests continue. +- When a limit is reached, requests are rejected unless `allow_override = true` and the `--override` flag is passed. + +## `[identity]` + +| Key | Default | Purpose | +|---|---|---| +| `format` | `openclaw` | Identity format: `"openclaw"` (default) or `"aieos"` | +| `aieos_path` | unset | Path to AIEOS JSON file (relative to workspace) | +| `aieos_inline` | unset | Inline AIEOS JSON (alternative to file path) | + +Notes: + +- Use `format = "aieos"` with either `aieos_path` or `aieos_inline` to load an AIEOS / OpenClaw identity document. +- Only one of `aieos_path` or `aieos_inline` should be set; `aieos_path` takes precedence. + +## `[multimodal]` + +| Key | Default | Purpose | +|---|---|---| +| `max_images` | `4` | Maximum image markers accepted per request | +| `max_image_size_mb` | `5` | Per-image size limit before base64 encoding | +| `allow_remote_fetch` | `false` | Allow fetching `http(s)` image URLs from markers | + +Notes: + +- Runtime accepts image markers in user messages with syntax: ``[IMAGE:<source>]``.
+- Supported sources: + - Local file path (for example ``[IMAGE:/tmp/screenshot.png]``) +- Data URI (for example ``[IMAGE:data:image/png;base64,...]``) +- Remote URL only when `allow_remote_fetch = true` +- Allowed MIME types: `image/png`, `image/jpeg`, `image/webp`, `image/gif`, `image/bmp`. +- When the active provider does not support vision, requests fail with a structured capability error (`capability=vision`) instead of silently dropping images. + +## `[browser]` + +| Key | Default | Purpose | +|---|---|---| +| `enabled` | `false` | Enable `browser_open` tool (opens URLs without scraping) | +| `allowed_domains` | `[]` | Allowed domains for `browser_open` (exact or subdomain match) | +| `session_name` | unset | Browser session name (for agent-browser automation) | +| `backend` | `agent_browser` | Browser automation backend: `"agent_browser"`, `"rust_native"`, `"computer_use"`, or `"auto"` | +| `native_headless` | `true` | Headless mode for rust-native backend | +| `native_webdriver_url` | `http://127.0.0.1:9515` | WebDriver endpoint URL for rust-native backend | +| `native_chrome_path` | unset | Optional Chrome/Chromium executable path for rust-native backend | + +### `[browser.computer_use]` + +| Key | Default | Purpose | +|---|---|---| +| `endpoint` | `http://127.0.0.1:8787/v1/actions` | Sidecar endpoint for computer-use actions (OS-level mouse/keyboard/screenshot) | +| `api_key` | unset | Optional bearer token for computer-use sidecar (stored encrypted) | +| `timeout_ms` | `15000` | Per-action request timeout in milliseconds | +| `allow_remote_endpoint` | `false` | Allow remote/public endpoint for computer-use sidecar | +| `window_allowlist` | `[]` | Optional window title/process allowlist forwarded to sidecar policy | +| `max_coordinate_x` | unset | Optional X-axis boundary for coordinate-based actions | +| `max_coordinate_y` | unset | Optional Y-axis boundary for coordinate-based actions | + +Notes: + +- When `backend = "computer_use"`, the agent delegates 
browser actions to the sidecar at `computer_use.endpoint`. +- `allow_remote_endpoint = false` (default) rejects any non-loopback endpoint to prevent accidental public exposure. +- Use `window_allowlist` to restrict which OS windows the sidecar can interact with. + +## `[http_request]` + +| Key | Default | Purpose | +|---|---|---| +| `enabled` | `false` | Enable `http_request` tool for API interactions | +| `allowed_domains` | `[]` | Allowed domains for HTTP requests (exact or subdomain match) | +| `max_response_size` | `1000000` | Maximum response size in bytes (default: 1 MB) | +| `timeout_secs` | `30` | Request timeout in seconds | + +Notes: + +- Deny-by-default: if `allowed_domains` is empty, all HTTP requests are rejected. +- Use exact domain or subdomain matching (e.g. `"api.example.com"`, `"example.com"`). + ## `[gateway]` | Key | Default | Purpose | @@ -36,20 +245,133 @@ Notes: | `require_pairing` | `true` | require pairing before bearer auth | | `allow_public_bind` | `false` | block accidental public exposure | +## `[autonomy]` + +| Key | Default | Purpose | +|---|---|---| +| `level` | `supervised` | `read_only`, `supervised`, or `full` | +| `workspace_only` | `true` | restrict writes/command paths to workspace scope | +| `allowed_commands` | _required for shell execution_ | allowlist of executable names | +| `forbidden_paths` | `[]` | explicit path denylist | +| `max_actions_per_hour` | `100` | per-policy action budget | +| `max_cost_per_day_cents` | `1000` | per-policy spend guardrail | +| `require_approval_for_medium_risk` | `true` | approval gate for medium-risk commands | +| `block_high_risk_commands` | `true` | hard block for high-risk commands | +| `auto_approve` | `[]` | tool operations always auto-approved | +| `always_ask` | `[]` | tool operations that always require approval | + +Notes: + +- `level = "full"` skips medium-risk approval gating for shell execution, while still enforcing configured guardrails. 
+- Shell separator/operator parsing is quote-aware. Characters like `;` inside quoted arguments are treated as literals, not command separators. +- Unquoted shell chaining/operators are still enforced by policy checks (`;`, `|`, `&&`, `||`, background chaining, and redirects). + ## `[memory]` | Key | Default | Purpose | |---|---|---| | `backend` | `sqlite` | `sqlite`, `lucid`, `markdown`, `none` | -| `auto_save` | `true` | automatic persistence | +| `auto_save` | `true` | persist user-stated inputs only (assistant outputs are excluded) | | `embedding_provider` | `none` | `none`, `openai`, or custom endpoint | +| `embedding_model` | `text-embedding-3-small` | embedding model ID, or `hint:` route | +| `embedding_dimensions` | `1536` | expected vector size for selected embedding model | | `vector_weight` | `0.7` | hybrid ranking vector weight | | `keyword_weight` | `0.3` | hybrid ranking keyword weight | +Notes: + +- Memory context injection ignores legacy `assistant_resp*` auto-save keys to prevent old model-authored summaries from being treated as facts. + +## `[[model_routes]]` and `[[embedding_routes]]` + +Use route hints so integrations can keep stable names while model IDs evolve. + +### `[[model_routes]]` + +| Key | Default | Purpose | +|---|---|---| +| `hint` | _required_ | Task hint name (e.g. `"reasoning"`, `"fast"`, `"code"`, `"summarize"`) | +| `provider` | _required_ | Provider to route to (must match a known provider name) | +| `model` | _required_ | Model to use with that provider | +| `api_key` | unset | Optional API key override for this route's provider | + +### `[[embedding_routes]]` + +| Key | Default | Purpose | +|---|---|---| +| `hint` | _required_ | Route hint name (e.g. 
`"semantic"`, `"archive"`, `"faq"`) | +| `provider` | _required_ | Embedding provider (`"none"`, `"openai"`, or `"custom:"`) | +| `model` | _required_ | Embedding model to use with that provider | +| `dimensions` | unset | Optional embedding dimension override for this route | +| `api_key` | unset | Optional API key override for this route's provider | + +```toml +[memory] +embedding_model = "hint:semantic" + +[[model_routes]] +hint = "reasoning" +provider = "openrouter" +model = "provider/model-id" + +[[embedding_routes]] +hint = "semantic" +provider = "openai" +model = "text-embedding-3-small" +dimensions = 1536 +``` + +Upgrade strategy: + +1. Keep hints stable (`hint:reasoning`, `hint:semantic`). +2. Update only `model = "...new-version..."` in the route entries. +3. Validate with `zeroclaw doctor` before restart/rollout. + +## `[query_classification]` + +Automatic model hint routing — maps user messages to `[[model_routes]]` hints based on content patterns. + +| Key | Default | Purpose | +|---|---|---| +| `enabled` | `false` | Enable automatic query classification | +| `rules` | `[]` | Classification rules (evaluated in priority order) | + +Each rule in `rules`: + +| Key | Default | Purpose | +|---|---|---| +| `hint` | _required_ | Must match a `[[model_routes]]` hint value | +| `keywords` | `[]` | Case-insensitive substring matches | +| `patterns` | `[]` | Case-sensitive literal matches (for code fences, keywords like `"fn "`) | +| `min_length` | unset | Only match if message length ≥ N chars | +| `max_length` | unset | Only match if message length ≤ N chars | +| `priority` | `0` | Higher priority rules are checked first | + +```toml +[query_classification] +enabled = true + +[[query_classification.rules]] +hint = "reasoning" +keywords = ["explain", "analyze", "why"] +min_length = 200 +priority = 10 + +[[query_classification.rules]] +hint = "fast" +keywords = ["hi", "hello", "thanks"] +max_length = 50 +priority = 5 +``` + ## `[channels_config]` Top-level 
channel options are configured under `channels_config`. +| Key | Default | Purpose | +|---|---|---| +| `message_timeout_secs` | `300` | Base timeout in seconds for channel message processing; runtime scales this with tool-loop depth (up to 4x) | + Examples: - `[channels_config.telegram]` @@ -57,8 +379,107 @@ Examples: - `[channels_config.whatsapp]` - `[channels_config.email]` +Notes: + +- Default `300s` is optimized for on-device LLMs (Ollama) which are slower than cloud APIs. +- Runtime timeout budget is `message_timeout_secs * scale`, where `scale = min(max_tool_iterations, 4)` and a minimum of `1`. +- This scaling avoids false timeouts when the first LLM turn is slow/retried but later tool-loop turns still need to complete. +- If using cloud APIs (OpenAI, Anthropic, etc.), you can reduce this to `60` or lower. +- Values below `30` are clamped to `30` to avoid immediate timeout churn. +- When a timeout occurs, users receive: `⚠️ Request timed out while waiting for the model. Please try again.` +- Telegram-only interruption behavior is controlled with `channels_config.telegram.interrupt_on_new_message` (default `false`). + When enabled, a newer message from the same sender in the same chat cancels the in-flight request and preserves interrupted user context. +- While `zeroclaw channel start` is running, updates to `default_provider`, `default_model`, `default_temperature`, `api_key`, `api_url`, and `reliability.*` are hot-applied from `config.toml` on the next inbound message. + See detailed channel matrix and allowlist behavior in [channels-reference.md](channels-reference.md). +### `[channels_config.whatsapp]` + +WhatsApp supports two backends under one config table. 
+ +Cloud API mode (Meta webhook): + +| Key | Required | Purpose | +|---|---|---| +| `access_token` | Yes | Meta Cloud API bearer token | +| `phone_number_id` | Yes | Meta phone number ID | +| `verify_token` | Yes | Webhook verification token | +| `app_secret` | Optional | Enables webhook signature verification (`X-Hub-Signature-256`) | +| `allowed_numbers` | Recommended | Allowed inbound numbers (`[]` = deny all, `"*"` = allow all) | + +WhatsApp Web mode (native client): + +| Key | Required | Purpose | +|---|---|---| +| `session_path` | Yes | Persistent SQLite session path | +| `pair_phone` | Optional | Pair-code flow phone number (digits only) | +| `pair_code` | Optional | Custom pair code (otherwise auto-generated) | +| `allowed_numbers` | Recommended | Allowed inbound numbers (`[]` = deny all, `"*"` = allow all) | + +Notes: + +- WhatsApp Web requires build flag `whatsapp-web`. +- If both Cloud and Web fields are present, Cloud mode wins for backward compatibility. + +## `[hardware]` + +Hardware wizard configuration for physical-world access (STM32, probe, serial). + +| Key | Default | Purpose | +|---|---|---| +| `enabled` | `false` | Whether hardware access is enabled | +| `transport` | `none` | Transport mode: `"none"`, `"native"`, `"serial"`, or `"probe"` | +| `serial_port` | unset | Serial port path (e.g. `"/dev/ttyACM0"`) | +| `baud_rate` | `115200` | Serial baud rate | +| `probe_target` | unset | Probe target chip (e.g. `"STM32F401RE"`) | +| `workspace_datasheets` | `false` | Enable workspace datasheet RAG (index PDF schematics for AI pin lookups) | + +Notes: + +- Use `transport = "serial"` with `serial_port` for USB-serial connections. +- Use `transport = "probe"` with `probe_target` for debug-probe flashing (e.g. ST-Link). +- See [hardware-peripherals-design.md](hardware-peripherals-design.md) for protocol details. + +## `[peripherals]` + +Higher-level peripheral board configuration. Boards become agent tools when enabled. 
+ +| Key | Default | Purpose | +|---|---|---| +| `enabled` | `false` | Enable peripheral support (boards become agent tools) | +| `boards` | `[]` | Board configurations | +| `datasheet_dir` | unset | Path to datasheet docs (relative to workspace) for RAG retrieval | + +Each entry in `boards`: + +| Key | Default | Purpose | +|---|---|---| +| `board` | _required_ | Board type: `"nucleo-f401re"`, `"rpi-gpio"`, `"esp32"`, etc. | +| `transport` | `serial` | Transport: `"serial"`, `"native"`, `"websocket"` | +| `path` | unset | Path for serial: `"/dev/ttyACM0"`, `"/dev/ttyUSB0"` | +| `baud` | `115200` | Baud rate for serial | + +```toml +[peripherals] +enabled = true +datasheet_dir = "docs/datasheets" + +[[peripherals.boards]] +board = "nucleo-f401re" +transport = "serial" +path = "/dev/ttyACM0" +baud = 115200 + +[[peripherals.boards]] +board = "rpi-gpio" +transport = "native" +``` + +Notes: + +- Place `.md`/`.txt` datasheet files named by board (e.g. `nucleo-f401re.md`, `rpi-gpio.md`) in `datasheet_dir` for RAG retrieval. +- See [hardware-peripherals-design.md](hardware-peripherals-design.md) for board protocol and firmware notes. + ## Security-Relevant Defaults - deny-by-default channel allowlists (`[]` means deny all) @@ -73,6 +494,7 @@ After editing config: zeroclaw status zeroclaw doctor zeroclaw channel doctor +zeroclaw service restart ``` ## Related Docs diff --git a/docs/frictionless-security.md b/docs/frictionless-security.md index 2f5fde6..f62046d 100644 --- a/docs/frictionless-security.md +++ b/docs/frictionless-security.md @@ -26,7 +26,7 @@ pub fn run_wizard() -> Result { security: SecurityConfig::autodetect(), // Silent! }; - config.save()?; + config.save().await?; Ok(config) } ``` diff --git a/docs/getting-started/README.md b/docs/getting-started/README.md index e462641..3c7e91c 100644 --- a/docs/getting-started/README.md +++ b/docs/getting-started/README.md @@ -8,6 +8,15 @@ For first-time setup and quick orientation. 2. 
One-click setup and dual bootstrap mode: [../one-click-bootstrap.md](../one-click-bootstrap.md) 3. Find commands by tasks: [../commands-reference.md](../commands-reference.md) +## Choose Your Path + +| Scenario | Command | +|----------|---------| +| I have an API key, want fastest setup | `zeroclaw onboard --api-key sk-... --provider openrouter` | +| I want guided prompts | `zeroclaw onboard --interactive` | +| Config exists, just fix channels | `zeroclaw onboard --channels-only` | +| Using subscription auth | See [Subscription Auth](../../README.md#subscription-auth-openai-codex--claude-code) | + ## Onboarding and Validation - Quick onboarding: `zeroclaw onboard --api-key "sk-..." --provider openrouter` diff --git a/docs/hardware/README.md b/docs/hardware/README.md index e2158ec..ca0a62a 100644 --- a/docs/hardware/README.md +++ b/docs/hardware/README.md @@ -2,6 +2,8 @@ For board integration, firmware flow, and peripheral architecture. +ZeroClaw's hardware subsystem enables direct control of microcontrollers and peripherals via the `Peripheral` trait. Each board exposes tools for GPIO, ADC, and sensor operations, allowing agent-driven hardware interaction on boards like STM32 Nucleo, Raspberry Pi, and ESP32. See [hardware-peripherals-design.md](../hardware-peripherals-design.md) for the full architecture. + ## Entry Points - Architecture and peripheral model: [../hardware-peripherals-design.md](../hardware-peripherals-design.md) diff --git a/docs/one-click-bootstrap.md b/docs/one-click-bootstrap.md index 0cc8b7c..c9001f7 100644 --- a/docs/one-click-bootstrap.md +++ b/docs/one-click-bootstrap.md @@ -2,7 +2,13 @@ This page defines the fastest supported path to install and initialize ZeroClaw. -Last verified: **February 18, 2026**. +Last verified: **February 20, 2026**. + +## Option 0: Homebrew (macOS/Linuxbrew) + +```bash +brew install zeroclaw +``` ## Option A (Recommended): Clone + local script @@ -17,6 +23,31 @@ What it does by default: 1. 
`cargo build --release --locked` 2. `cargo install --path . --force --locked` +### Resource preflight and pre-built flow + +Source builds typically require at least: + +- **2 GB RAM + swap** +- **6 GB free disk** + +When resources are constrained, bootstrap now attempts a pre-built binary first. + +```bash +./bootstrap.sh --prefer-prebuilt +``` + +To require binary-only installation and fail if no compatible release asset exists: + +```bash +./bootstrap.sh --prebuilt-only +``` + +To bypass pre-built flow and force source compilation: + +```bash +./bootstrap.sh --force-source-build +``` + ## Dual-mode bootstrap Default behavior is **app-only** (build/install ZeroClaw) and expects existing Rust toolchain. @@ -31,6 +62,9 @@ Notes: - `--install-system-deps` installs compiler/build prerequisites (may require `sudo`). - `--install-rust` installs Rust via `rustup` when missing. +- `--prefer-prebuilt` tries release binary download first, then falls back to source build. +- `--prebuilt-only` disables source fallback. +- `--force-source-build` disables pre-built flow entirely. ## Option B: Remote one-liner @@ -52,6 +86,15 @@ If you run Option B outside a repository checkout, the bootstrap script automati ## Optional onboarding modes +### Containerized onboarding (Docker) + +```bash +./bootstrap.sh --docker +``` + +This builds a local ZeroClaw image and launches onboarding inside a container while +persisting config/workspace to `./.zeroclaw-docker`. + ### Quick onboarding (non-interactive) ```bash diff --git a/docs/project/README.md b/docs/project/README.md index 392a1d0..478200c 100644 --- a/docs/project/README.md +++ b/docs/project/README.md @@ -8,6 +8,10 @@ Time-bound project status snapshots for planning documentation and operations wo ## Scope -Use snapshots to understand changing PR/issue pressure and prioritize doc maintenance. +Project snapshots are time-bound assessments of open PRs, issues, and documentation health. 
Use these to: -For stable classification of docs intent, use [../docs-inventory.md](../docs-inventory.md). +- Identify documentation gaps driven by feature work +- Prioritize docs maintenance alongside code changes +- Track evolving PR/issue pressure over time + +For stable documentation classification (not time-bound), use [docs-inventory.md](../docs-inventory.md). diff --git a/docs/providers-reference.md b/docs/providers-reference.md index ddefb8c..f9c7726 100644 --- a/docs/providers-reference.md +++ b/docs/providers-reference.md @@ -2,7 +2,7 @@ This document maps provider IDs, aliases, and credential environment variables. -Last verified: **February 18, 2026**. +Last verified: **February 19, 2026**. ## How to List Providers @@ -18,6 +18,10 @@ Runtime resolution order is: 2. Provider-specific env var(s) 3. Generic fallback env vars: `ZEROCLAW_API_KEY` then `API_KEY` +For resilient fallback chains (`reliability.fallback_providers`), each fallback +provider resolves credentials independently. The primary provider's explicit +credential is not reused for fallback providers. 
+ ## Provider Catalog | Canonical ID | Aliases | Local | Provider-specific env var(s) | @@ -37,9 +41,9 @@ Runtime resolution order is: | `zai` | `z.ai` | No | `ZAI_API_KEY` | | `glm` | `zhipu` | No | `GLM_API_KEY` | | `minimax` | `minimax-intl`, `minimax-io`, `minimax-global`, `minimax-cn`, `minimaxi`, `minimax-oauth`, `minimax-oauth-cn`, `minimax-portal`, `minimax-portal-cn` | No | `MINIMAX_OAUTH_TOKEN`, `MINIMAX_API_KEY` | -| `bedrock` | `aws-bedrock` | No | (use config/`API_KEY` fallback) | +| `bedrock` | `aws-bedrock` | No | `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY` (optional: `AWS_REGION`) | | `qianfan` | `baidu` | No | `QIANFAN_API_KEY` | -| `qwen` | `dashscope`, `qwen-intl`, `dashscope-intl`, `qwen-us`, `dashscope-us` | No | `DASHSCOPE_API_KEY` | +| `qwen` | `dashscope`, `qwen-intl`, `dashscope-intl`, `qwen-us`, `dashscope-us`, `qwen-code`, `qwen-oauth`, `qwen_oauth` | No | `QWEN_OAUTH_TOKEN`, `DASHSCOPE_API_KEY` | | `groq` | — | No | `GROQ_API_KEY` | | `mistral` | — | No | `MISTRAL_API_KEY` | | `xai` | `grok` | No | `XAI_API_KEY` | @@ -52,6 +56,46 @@ Runtime resolution order is: | `lmstudio` | `lm-studio` | Yes | (optional; local by default) | | `nvidia` | `nvidia-nim`, `build.nvidia.com` | No | `NVIDIA_API_KEY` | +### Gemini Notes + +- Provider ID: `gemini` (aliases: `google`, `google-gemini`) +- Auth can come from `GEMINI_API_KEY`, `GOOGLE_API_KEY`, or Gemini CLI OAuth cache (`~/.gemini/oauth_creds.json`) +- API key requests use `generativelanguage.googleapis.com/v1beta` +- Gemini CLI OAuth requests use `cloudcode-pa.googleapis.com/v1internal` with Code Assist request envelope semantics + +### Ollama Vision Notes + +- Provider ID: `ollama` +- Vision input is supported through user message image markers: ``[IMAGE:]``. +- After multimodal normalization, ZeroClaw sends image payloads through Ollama's native `messages[].images` field. 
+- If a non-vision provider is selected, ZeroClaw returns a structured capability error instead of silently ignoring images. + +### Bedrock Notes + +- Provider ID: `bedrock` (alias: `aws-bedrock`) +- API: [Converse API](https://docs.aws.amazon.com/bedrock/latest/APIReference/API_runtime_Converse.html) +- Authentication: AWS AKSK (not a single API key). Set `AWS_ACCESS_KEY_ID` + `AWS_SECRET_ACCESS_KEY` environment variables. +- Optional: `AWS_SESSION_TOKEN` for temporary/STS credentials, `AWS_REGION` or `AWS_DEFAULT_REGION` (default: `us-east-1`). +- Default onboarding model: `anthropic.claude-sonnet-4-5-20250929-v1:0` +- Supports native tool calling and prompt caching (`cachePoint`). +- Cross-region inference profiles supported (e.g., `us.anthropic.claude-*`). +- Model IDs use Bedrock format: `anthropic.claude-sonnet-4-6`, `anthropic.claude-opus-4-6-v1`, etc. + +### Ollama Reasoning Toggle + +You can control Ollama reasoning/thinking behavior from `config.toml`: + +```toml +[runtime] +reasoning_enabled = false +``` + +Behavior: + +- `false`: sends `think: false` to Ollama `/api/chat` requests. +- `true`: sends `think: true`. +- Unset: omits `think` and keeps Ollama/model defaults. + ### Kimi Code Notes - Provider ID: `kimi-code` @@ -107,6 +151,33 @@ Optional: - `MINIMAX_OAUTH_REGION=global` or `cn` (defaults by provider alias) - `MINIMAX_OAUTH_CLIENT_ID` to override the default OAuth client id +Channel compatibility note: + +- For MiniMax-backed channel conversations, runtime history is normalized to keep valid `user`/`assistant` turn order. +- Channel-specific delivery guidance (for example Telegram attachment markers) is merged into the leading system prompt instead of being appended as a trailing `system` turn. + +## Qwen Code OAuth Setup (config.toml) + +Set Qwen Code OAuth mode in config: + +```toml +default_provider = "qwen-code" +api_key = "qwen-oauth" +``` + +Credential resolution for `qwen-code`: + +1. 
Explicit `api_key` value (if not the placeholder `qwen-oauth`) +2. `QWEN_OAUTH_TOKEN` +3. `~/.qwen/oauth_creds.json` (reuses Qwen Code cached OAuth credentials) +4. Optional refresh via `QWEN_OAUTH_REFRESH_TOKEN` (or cached refresh token) +5. If no OAuth placeholder is used, `DASHSCOPE_API_KEY` can still be used as fallback + +Optional endpoint override: + +- `QWEN_OAUTH_RESOURCE_URL` (normalized to `https://.../v1` if needed) +- If unset, `resource_url` from cached OAuth credentials is used when available + ## Model Routing (`hint:`) You can route model calls by hint using `[[model_routes]]`: @@ -128,3 +199,56 @@ Then call with a hint model name (for example from tool or integration paths): ```text hint:reasoning ``` + +## Embedding Routing (`hint:`) + +You can route embedding calls with the same hint pattern using `[[embedding_routes]]`. +Set `[memory].embedding_model` to a `hint:` value to activate routing. + +```toml +[memory] +embedding_model = "hint:semantic" + +[[embedding_routes]] +hint = "semantic" +provider = "openai" +model = "text-embedding-3-small" +dimensions = 1536 + +[[embedding_routes]] +hint = "archive" +provider = "custom:https://embed.example.com/v1" +model = "your-embedding-model-id" +dimensions = 1024 +``` + +Supported embedding providers: + +- `none` +- `openai` +- `custom:` (OpenAI-compatible embeddings endpoint) + +Optional per-route key override: + +```toml +[[embedding_routes]] +hint = "semantic" +provider = "openai" +model = "text-embedding-3-small" +api_key = "sk-route-specific" +``` + +## Upgrading Models Safely + +Use stable hints and update only route targets when providers deprecate model IDs. + +Recommended workflow: + +1. Keep call sites stable (`hint:reasoning`, `hint:semantic`). +2. Change only the target model under `[[model_routes]]` or `[[embedding_routes]]`. +3. Run: + - `zeroclaw doctor` + - `zeroclaw status` +4. Smoke test one representative flow (chat + memory retrieval) before rollout. 
+ +This minimizes breakage because integrations and prompts do not need to change when model IDs are upgraded. diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index e06e74a..7fd02aa 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -2,7 +2,7 @@ This guide focuses on common setup/runtime failures and fast resolution paths. -Last verified: **February 18, 2026**. +Last verified: **February 20, 2026**. ## Installation / Bootstrap @@ -32,6 +32,93 @@ Fix: ./bootstrap.sh --install-system-deps ``` +### Build fails on low-RAM / low-disk hosts + +Symptoms: + +- `cargo build --release` is killed (`signal: 9`, OOM killer, or `cannot allocate memory`) +- Build crashes after adding swap because disk space runs out + +Why this happens: + +- Runtime memory (<5MB for common operations) is not the same as compile-time memory. +- Full source build can require **2 GB RAM + swap** and **6+ GB free disk**. +- Enabling swap on a tiny disk can avoid RAM OOM but still fail due to disk exhaustion. + +Preferred path for constrained machines: + +```bash +./bootstrap.sh --prefer-prebuilt +``` + +Binary-only mode (no source fallback): + +```bash +./bootstrap.sh --prebuilt-only +``` + +If you must compile from source on constrained hosts: + +1. Add swap only if you also have enough free disk for both swap + build output. +1. Limit cargo parallelism: + +```bash +CARGO_BUILD_JOBS=1 cargo build --release --locked +``` + +1. Reduce heavy features when Matrix is not required: + +```bash +cargo build --release --locked --no-default-features --features hardware +``` + +1. Cross-compile on a stronger machine and copy the binary to the target host. 
+ +### Build is very slow or appears stuck + +Symptoms: + +- `cargo check` / `cargo build` appears stuck at `Checking zeroclaw` for a long time +- repeated `Blocking waiting for file lock on package cache` or `build directory` + +Why this happens in ZeroClaw: + +- Matrix E2EE stack (`matrix-sdk`, `ruma`, `vodozemac`) is large and expensive to type-check. +- TLS + crypto native build scripts (`aws-lc-sys`, `ring`) add noticeable compile time. +- `rusqlite` with bundled SQLite compiles C code locally. +- Running multiple cargo jobs/worktrees in parallel causes lock contention. + +Fast checks: + +```bash +cargo check --timings +cargo tree -d +``` + +The timing report is written to `target/cargo-timings/cargo-timing.html`. + +Faster local iteration (when Matrix channel is not needed): + +```bash +cargo check --no-default-features --features hardware +``` + +This skips `channel-matrix` and can significantly reduce compile time. + +To build with Matrix support explicitly enabled: + +```bash +cargo check --no-default-features --features hardware,channel-matrix +``` + +Lock-contention mitigation: + +```bash +pgrep -af "cargo (check|build|test)|cargo check|cargo build|cargo test" +``` + +Stop unrelated cargo jobs before running your own build. 
+ ### `zeroclaw` command not found after install Symptom: diff --git a/flake.lock b/flake.lock new file mode 100644 index 0000000..b591ed4 --- /dev/null +++ b/flake.lock @@ -0,0 +1,99 @@ +{ + "nodes": { + "fenix": { + "inputs": { + "nixpkgs": [ + "nixpkgs" + ], + "rust-analyzer-src": "rust-analyzer-src" + }, + "locked": { + "lastModified": 1771398736, + "narHash": "sha256-pjV3C7VJHN0o2SvE3O6xiwraLt7bnlWIF3o7Q0BC1jk=", + "owner": "nix-community", + "repo": "fenix", + "rev": "0f608091816de13d92e1f4058b501028b782dddd", + "type": "github" + }, + "original": { + "owner": "nix-community", + "repo": "fenix", + "type": "github" + } + }, + "flake-utils": { + "inputs": { + "systems": "systems" + }, + "locked": { + "lastModified": 1731533236, + "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=", + "owner": "numtide", + "repo": "flake-utils", + "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b", + "type": "github" + }, + "original": { + "owner": "numtide", + "repo": "flake-utils", + "type": "github" + } + }, + "nixpkgs": { + "locked": { + "lastModified": 1771369470, + "narHash": "sha256-0NBlEBKkN3lufyvFegY4TYv5mCNHbi5OmBDrzihbBMQ=", + "owner": "NixOS", + "repo": "nixpkgs", + "rev": "0182a361324364ae3f436a63005877674cf45efb", + "type": "github" + }, + "original": { + "id": "nixpkgs", + "ref": "nixos-unstable", + "type": "indirect" + } + }, + "root": { + "inputs": { + "fenix": "fenix", + "flake-utils": "flake-utils", + "nixpkgs": "nixpkgs" + } + }, + "rust-analyzer-src": { + "flake": false, + "locked": { + "lastModified": 1771353660, + "narHash": "sha256-yp1y55kXgaa08g/gR3CNiUdkg1JRjPYfkKtEIRNE6S8=", + "owner": "rust-lang", + "repo": "rust-analyzer", + "rev": "09f2d468eda25a5f06ae70046357c70ae5cd77c7", + "type": "github" + }, + "original": { + "owner": "rust-lang", + "ref": "nightly", + "repo": "rust-analyzer", + "type": "github" + } + }, + "systems": { + "locked": { + "lastModified": 1681028828, + "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=", + 
"owner": "nix-systems", + "repo": "default", + "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e", + "type": "github" + }, + "original": { + "owner": "nix-systems", + "repo": "default", + "type": "github" + } + } + }, + "root": "root", + "version": 7 +} diff --git a/flake.nix b/flake.nix new file mode 100644 index 0000000..9bafa47 --- /dev/null +++ b/flake.nix @@ -0,0 +1,61 @@ +{ + inputs = { + flake-utils.url = "github:numtide/flake-utils"; + fenix = { + url = "github:nix-community/fenix"; + inputs.nixpkgs.follows = "nixpkgs"; + }; + nixpkgs.url = "nixpkgs/nixos-unstable"; + }; + + outputs = { flake-utils, fenix, nixpkgs, ... }: + let + nixosModule = { pkgs, ... }: { + nixpkgs.overlays = [ fenix.overlays.default ]; + environment.systemPackages = [ + (pkgs.fenix.stable.withComponents [ + "cargo" + "clippy" + "rust-src" + "rustc" + "rustfmt" + ]) + pkgs.rust-analyzer + ]; + }; + in + flake-utils.lib.eachDefaultSystem (system: + let + pkgs = import nixpkgs { + inherit system; + overlays = [ fenix.overlays.default ]; + }; + rustToolchain = pkgs.fenix.stable.withComponents [ + "cargo" + "clippy" + "rust-src" + "rustc" + "rustfmt" + ]; + in { + packages.default = fenix.packages.${system}.stable.toolchain; + devShells.default = pkgs.mkShell { + packages = [ + rustToolchain + pkgs.rust-analyzer + ]; + }; + }) // { + nixosConfigurations = { + nixos = nixpkgs.lib.nixosSystem { + system = "x86_64-linux"; + modules = [ nixosModule ]; + }; + + nixos-aarch64 = nixpkgs.lib.nixosSystem { + system = "aarch64-linux"; + modules = [ nixosModule ]; + }; + }; + }; +} diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index b9d2bbe..e55d4da 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -24,3 +24,21 @@ name = "fuzz_tool_params" path = "fuzz_targets/fuzz_tool_params.rs" test = false doc = false + +[[bin]] +name = "fuzz_webhook_payload" +path = "fuzz_targets/fuzz_webhook_payload.rs" +test = false +doc = false + +[[bin]] +name = "fuzz_provider_response" +path = 
"fuzz_targets/fuzz_provider_response.rs" +test = false +doc = false + +[[bin]] +name = "fuzz_command_validation" +path = "fuzz_targets/fuzz_command_validation.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/fuzz_command_validation.rs b/fuzz/fuzz_targets/fuzz_command_validation.rs new file mode 100644 index 0000000..13cce01 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_command_validation.rs @@ -0,0 +1,10 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; +use zeroclaw::security::SecurityPolicy; + +fuzz_target!(|data: &[u8]| { + if let Ok(s) = std::str::from_utf8(data) { + let policy = SecurityPolicy::default(); + let _ = policy.validate_command_execution(s, false); + } +}); diff --git a/fuzz/fuzz_targets/fuzz_provider_response.rs b/fuzz/fuzz_targets/fuzz_provider_response.rs new file mode 100644 index 0000000..73f895d --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_provider_response.rs @@ -0,0 +1,9 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + if let Ok(s) = std::str::from_utf8(data) { + // Fuzz provider API response deserialization + let _ = serde_json::from_str::(s); + } +}); diff --git a/fuzz/fuzz_targets/fuzz_webhook_payload.rs b/fuzz/fuzz_targets/fuzz_webhook_payload.rs new file mode 100644 index 0000000..1f5b813 --- /dev/null +++ b/fuzz/fuzz_targets/fuzz_webhook_payload.rs @@ -0,0 +1,9 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; + +fuzz_target!(|data: &[u8]| { + if let Ok(s) = std::str::from_utf8(data) { + // Fuzz webhook body deserialization + let _ = serde_json::from_str::(s); + } +}); diff --git a/scripts/bootstrap.sh b/scripts/bootstrap.sh index b734124..b6732a7 100755 --- a/scripts/bootstrap.sh +++ b/scripts/bootstrap.sh @@ -15,38 +15,61 @@ error() { usage() { cat <<'USAGE' -ZeroClaw one-click bootstrap +ZeroClaw installer bootstrap engine Usage: - ./bootstrap.sh [options] + ./zeroclaw_install.sh [options] + ./bootstrap.sh [options] # compatibility entrypoint Modes: Default mode installs/builds ZeroClaw only 
(requires existing Rust toolchain). + Guided mode asks setup questions and configures options interactively. Optional bootstrap mode can also install system dependencies and Rust. Options: + --guided Run interactive guided installer + --no-guided Disable guided installer + --docker Run bootstrap in Docker and launch onboarding inside the container --install-system-deps Install build dependencies (Linux/macOS) --install-rust Install Rust via rustup if missing + --prefer-prebuilt Try latest release binary first; fallback to source build on miss + --prebuilt-only Install only from latest release binary (no source build fallback) + --force-source-build Disable prebuilt flow and always build from source --onboard Run onboarding after install --interactive-onboard Run interactive onboarding (implies --onboard) --api-key API key for non-interactive onboarding --provider Provider for non-interactive onboarding (default: openrouter) + --model Model for non-interactive onboarding (optional) + --build-first Alias for explicitly enabling separate `cargo build --release --locked` --skip-build Skip `cargo build --release --locked` --skip-install Skip `cargo install --path . --force --locked` -h, --help Show help Examples: - ./bootstrap.sh - ./bootstrap.sh --install-system-deps --install-rust - ./bootstrap.sh --onboard --api-key "sk-..." --provider openrouter - ./bootstrap.sh --interactive-onboard + ./zeroclaw_install.sh + ./zeroclaw_install.sh --guided + ./zeroclaw_install.sh --install-system-deps --install-rust + ./zeroclaw_install.sh --prefer-prebuilt + ./zeroclaw_install.sh --prebuilt-only + ./zeroclaw_install.sh --onboard --api-key "sk-..." 
--provider openrouter [--model "openrouter/auto"] + ./zeroclaw_install.sh --interactive-onboard + + # Compatibility entrypoint: + ./bootstrap.sh --docker # Remote one-liner curl -fsSL https://raw.githubusercontent.com/zeroclaw-labs/zeroclaw/main/scripts/bootstrap.sh | bash Environment: + ZEROCLAW_DOCKER_DATA_DIR Host path for Docker config/workspace persistence + ZEROCLAW_DOCKER_IMAGE Docker image tag to build/run (default: zeroclaw-bootstrap:local) ZEROCLAW_API_KEY Used when --api-key is not provided ZEROCLAW_PROVIDER Used when --provider is not provided (default: openrouter) + ZEROCLAW_MODEL Used when --model is not provided + ZEROCLAW_BOOTSTRAP_MIN_RAM_MB Minimum RAM threshold for source build preflight (default: 2048) + ZEROCLAW_BOOTSTRAP_MIN_DISK_MB Minimum free disk threshold for source build preflight (default: 6144) + ZEROCLAW_DISABLE_ALPINE_AUTO_DEPS + Set to 1 to disable Alpine auto-install of missing prerequisites USAGE } @@ -54,6 +77,155 @@ have_cmd() { command -v "$1" >/dev/null 2>&1 } +get_total_memory_mb() { + case "$(uname -s)" in + Linux) + if [[ -r /proc/meminfo ]]; then + awk '/MemTotal:/ {printf "%d\n", $2 / 1024}' /proc/meminfo + fi + ;; + Darwin) + if have_cmd sysctl; then + local bytes + bytes="$(sysctl -n hw.memsize 2>/dev/null || true)" + if [[ "$bytes" =~ ^[0-9]+$ ]]; then + echo $((bytes / 1024 / 1024)) + fi + fi + ;; + esac +} + +get_available_disk_mb() { + local path="${1:-.}" + local free_kb + free_kb="$(df -Pk "$path" 2>/dev/null | awk 'NR==2 {print $4}')" + if [[ "$free_kb" =~ ^[0-9]+$ ]]; then + echo $((free_kb / 1024)) + fi +} + +detect_release_target() { + local os arch + os="$(uname -s)" + arch="$(uname -m)" + + case "$os:$arch" in + Linux:x86_64) + echo "x86_64-unknown-linux-gnu" + ;; + Linux:aarch64|Linux:arm64) + echo "aarch64-unknown-linux-gnu" + ;; + Linux:armv7l|Linux:armv6l) + echo "armv7-unknown-linux-gnueabihf" + ;; + Darwin:x86_64) + echo "x86_64-apple-darwin" + ;; + Darwin:arm64|Darwin:aarch64) + echo 
"aarch64-apple-darwin" + ;; + *) + return 1 + ;; + esac +} + +should_attempt_prebuilt_for_resources() { + local workspace="${1:-.}" + local min_ram_mb min_disk_mb total_ram_mb free_disk_mb low_resource + + min_ram_mb="${ZEROCLAW_BOOTSTRAP_MIN_RAM_MB:-2048}" + min_disk_mb="${ZEROCLAW_BOOTSTRAP_MIN_DISK_MB:-6144}" + total_ram_mb="$(get_total_memory_mb || true)" + free_disk_mb="$(get_available_disk_mb "$workspace" || true)" + low_resource=false + + if [[ "$total_ram_mb" =~ ^[0-9]+$ && "$total_ram_mb" -lt "$min_ram_mb" ]]; then + low_resource=true + fi + if [[ "$free_disk_mb" =~ ^[0-9]+$ && "$free_disk_mb" -lt "$min_disk_mb" ]]; then + low_resource=true + fi + + if [[ "$low_resource" == true ]]; then + warn "Source build preflight indicates constrained resources." + if [[ "$total_ram_mb" =~ ^[0-9]+$ ]]; then + warn "Detected RAM: ${total_ram_mb}MB (recommended >= ${min_ram_mb}MB for local source builds)." + else + warn "Unable to detect total RAM automatically." + fi + if [[ "$free_disk_mb" =~ ^[0-9]+$ ]]; then + warn "Detected free disk: ${free_disk_mb}MB (recommended >= ${min_disk_mb}MB)." + else + warn "Unable to detect free disk space automatically." + fi + return 0 + fi + + return 1 +} + +install_prebuilt_binary() { + local target archive_url temp_dir archive_path extracted_bin install_dir + + if ! have_cmd curl; then + warn "curl is required for pre-built binary installation." + return 1 + fi + if ! have_cmd tar; then + warn "tar is required for pre-built binary installation." + return 1 + fi + + target="$(detect_release_target || true)" + if [[ -z "$target" ]]; then + warn "No pre-built binary target mapping for $(uname -s)/$(uname -m)." + return 1 + fi + + archive_url="https://github.com/zeroclaw-labs/zeroclaw/releases/latest/download/zeroclaw-${target}.tar.gz" + temp_dir="$(mktemp -d -t zeroclaw-prebuilt-XXXXXX)" + archive_path="$temp_dir/zeroclaw-${target}.tar.gz" + + info "Attempting pre-built binary install for target: $target" + if ! 
curl -fsSL "$archive_url" -o "$archive_path"; then + warn "Could not download release asset: $archive_url" + rm -rf "$temp_dir" + return 1 + fi + + if ! tar -xzf "$archive_path" -C "$temp_dir"; then + warn "Failed to extract pre-built archive." + rm -rf "$temp_dir" + return 1 + fi + + extracted_bin="$temp_dir/zeroclaw" + if [[ ! -x "$extracted_bin" ]]; then + extracted_bin="$(find "$temp_dir" -maxdepth 2 -type f -name zeroclaw -perm -u+x | head -n 1 || true)" + fi + if [[ -z "$extracted_bin" || ! -x "$extracted_bin" ]]; then + warn "Archive did not contain an executable zeroclaw binary." + rm -rf "$temp_dir" + return 1 + fi + + install_dir="$HOME/.cargo/bin" + mkdir -p "$install_dir" + install -m 0755 "$extracted_bin" "$install_dir/zeroclaw" + rm -rf "$temp_dir" + + info "Installed pre-built binary to $install_dir/zeroclaw" + if [[ ":$PATH:" != *":$install_dir:"* ]]; then + warn "$install_dir is not in PATH for this shell." + warn "Run: export PATH=\"$install_dir:\$PATH\"" + fi + + return 0 +} + run_privileged() { if [[ "$(id -u)" -eq 0 ]]; then "$@" @@ -65,19 +237,152 @@ run_privileged() { fi } +is_container_runtime() { + if [[ -f /.dockerenv || -f /run/.containerenv ]]; then + return 0 + fi + + if [[ -r /proc/1/cgroup ]] && grep -Eq '(docker|containerd|kubepods|podman|lxc)' /proc/1/cgroup; then + return 0 + fi + + return 1 +} + +run_pacman() { + if ! have_cmd pacman; then + error "pacman is not available." + return 1 + fi + + if ! is_container_runtime; then + run_privileged pacman "$@" + return $? + fi + + local pacman_cfg_tmp="" + local pacman_rc=0 + pacman_cfg_tmp="$(mktemp /tmp/zeroclaw-pacman.XXXXXX.conf)" + cp /etc/pacman.conf "$pacman_cfg_tmp" + if ! grep -Eq '^[[:space:]]*DisableSandboxSyscalls([[:space:]]|$)' "$pacman_cfg_tmp"; then + printf '\nDisableSandboxSyscalls\n' >> "$pacman_cfg_tmp" + fi + + if run_privileged pacman --config "$pacman_cfg_tmp" "$@"; then + pacman_rc=0 + else + pacman_rc=$? 
+ fi + + rm -f "$pacman_cfg_tmp" + return "$pacman_rc" +} + +ALPINE_PREREQ_PACKAGES=( + bash + build-base + pkgconf + git + curl + openssl-dev + perl + ca-certificates +) +ALPINE_MISSING_PKGS=() + +find_missing_alpine_prereqs() { + ALPINE_MISSING_PKGS=() + if ! have_cmd apk; then + return 0 + fi + + local pkg="" + for pkg in "${ALPINE_PREREQ_PACKAGES[@]}"; do + if ! apk info -e "$pkg" >/dev/null 2>&1; then + ALPINE_MISSING_PKGS+=("$pkg") + fi + done +} + +bool_to_word() { + if [[ "$1" == true ]]; then + echo "yes" + else + echo "no" + fi +} + +prompt_yes_no() { + local question="$1" + local default_answer="$2" + local prompt="" + local answer="" + + if [[ "$default_answer" == "yes" ]]; then + prompt="[Y/n]" + else + prompt="[y/N]" + fi + + while true; do + if ! read -r -p "$question $prompt " answer; then + error "guided installer input was interrupted." + exit 1 + fi + answer="${answer:-$default_answer}" + case "$(printf '%s' "$answer" | tr '[:upper:]' '[:lower:]')" in + y|yes) + return 0 + ;; + n|no) + return 1 + ;; + *) + echo "Please answer yes or no." 
+ ;; + esac + done +} + install_system_deps() { info "Installing system dependencies" case "$(uname -s)" in Linux) - if have_cmd apt-get; then + if have_cmd apk; then + find_missing_alpine_prereqs + if [[ ${#ALPINE_MISSING_PKGS[@]} -eq 0 ]]; then + info "Alpine prerequisites already installed" + else + info "Installing Alpine prerequisites: ${ALPINE_MISSING_PKGS[*]}" + run_privileged apk add --no-cache "${ALPINE_MISSING_PKGS[@]}" + fi + elif have_cmd apt-get; then run_privileged apt-get update -qq run_privileged apt-get install -y build-essential pkg-config git curl elif have_cmd dnf; then - run_privileged dnf group install -y development-tools - run_privileged dnf install -y pkg-config git curl + run_privileged dnf install -y \ + gcc \ + gcc-c++ \ + make \ + pkgconf-pkg-config \ + git \ + curl \ + openssl-devel \ + perl + elif have_cmd pacman; then + run_pacman -Sy --noconfirm + run_pacman -S --noconfirm --needed \ + gcc \ + make \ + pkgconf \ + git \ + curl \ + openssl \ + perl \ + ca-certificates else - warn "Unsupported Linux distribution. Install compiler toolchain + pkg-config + git + curl manually." + warn "Unsupported Linux distribution. Install compiler toolchain + pkg-config + git + curl + OpenSSL headers + perl manually." fi ;; Darwin) @@ -126,22 +431,236 @@ install_rust_toolchain() { fi } +run_guided_installer() { + local os_name="$1" + local provider_input="" + local model_input="" + local api_key_input="" + + echo + echo "ZeroClaw guided installer" + echo "Answer a few questions, then the installer will run automatically." + echo + + if [[ "$os_name" == "Linux" ]]; then + if prompt_yes_no "Install Linux build dependencies (toolchain/pkg-config/git/curl)?" "yes"; then + INSTALL_SYSTEM_DEPS=true + fi + else + if prompt_yes_no "Install system dependencies for $os_name?" 
"no"; then + INSTALL_SYSTEM_DEPS=true + fi + fi + + if have_cmd cargo && have_cmd rustc; then + info "Detected Rust toolchain: $(rustc --version)" + else + if prompt_yes_no "Rust toolchain not found. Install Rust via rustup now?" "yes"; then + INSTALL_RUST=true + fi + fi + + if prompt_yes_no "Run a separate prebuild before install?" "yes"; then + SKIP_BUILD=false + else + SKIP_BUILD=true + fi + + if prompt_yes_no "Install zeroclaw into cargo bin now?" "yes"; then + SKIP_INSTALL=false + else + SKIP_INSTALL=true + fi + + if prompt_yes_no "Run onboarding after install?" "no"; then + RUN_ONBOARD=true + if prompt_yes_no "Use interactive onboarding?" "yes"; then + INTERACTIVE_ONBOARD=true + else + INTERACTIVE_ONBOARD=false + if ! read -r -p "Provider [$PROVIDER]: " provider_input; then + error "guided installer input was interrupted." + exit 1 + fi + if [[ -n "$provider_input" ]]; then + PROVIDER="$provider_input" + fi + + if ! read -r -p "Model [${MODEL:-leave empty}]: " model_input; then + error "guided installer input was interrupted." + exit 1 + fi + if [[ -n "$model_input" ]]; then + MODEL="$model_input" + fi + + if [[ -z "$API_KEY" ]]; then + if ! read -r -s -p "API key (hidden, leave empty to switch to interactive onboarding): " api_key_input; then + echo + error "guided installer input was interrupted." + exit 1 + fi + echo + if [[ -n "$api_key_input" ]]; then + API_KEY="$api_key_input" + else + warn "No API key entered. Using interactive onboarding instead." 
+ INTERACTIVE_ONBOARD=true + fi + fi + fi + fi + + echo + info "Installer plan" + local install_binary=true + local build_first=false + if [[ "$SKIP_INSTALL" == true ]]; then + install_binary=false + fi + if [[ "$SKIP_BUILD" == false ]]; then + build_first=true + fi + echo " docker-mode: $(bool_to_word "$DOCKER_MODE")" + echo " install-system-deps: $(bool_to_word "$INSTALL_SYSTEM_DEPS")" + echo " install-rust: $(bool_to_word "$INSTALL_RUST")" + echo " build-first: $(bool_to_word "$build_first")" + echo " install-binary: $(bool_to_word "$install_binary")" + echo " onboard: $(bool_to_word "$RUN_ONBOARD")" + if [[ "$RUN_ONBOARD" == true ]]; then + echo " interactive-onboard: $(bool_to_word "$INTERACTIVE_ONBOARD")" + if [[ "$INTERACTIVE_ONBOARD" == false ]]; then + echo " provider: $PROVIDER" + if [[ -n "$MODEL" ]]; then + echo " model: $MODEL" + fi + fi + fi + + echo + if ! prompt_yes_no "Proceed with this install plan?" "yes"; then + info "Installation canceled by user." + exit 0 + fi +} + +ensure_docker_ready() { + if ! have_cmd docker; then + error "docker is not installed." + cat <<'MSG' >&2 +Install Docker first, then re-run with: + ./zeroclaw_install.sh --docker +MSG + exit 1 + fi + + if ! docker info >/dev/null 2>&1; then + error "Docker daemon is not reachable." + error "Start Docker and re-run bootstrap." + exit 1 + fi +} + +run_docker_bootstrap() { + local docker_image docker_data_dir default_data_dir + docker_image="${ZEROCLAW_DOCKER_IMAGE:-zeroclaw-bootstrap:local}" + if [[ "$TEMP_CLONE" == true ]]; then + default_data_dir="$HOME/.zeroclaw-docker" + else + default_data_dir="$WORK_DIR/.zeroclaw-docker" + fi + docker_data_dir="${ZEROCLAW_DOCKER_DATA_DIR:-$default_data_dir}" + DOCKER_DATA_DIR="$docker_data_dir" + + mkdir -p "$docker_data_dir/.zeroclaw" "$docker_data_dir/workspace" + + if [[ "$SKIP_INSTALL" == true ]]; then + warn "--skip-install has no effect with --docker." 
+ fi + + if [[ "$SKIP_BUILD" == false ]]; then + info "Building Docker image ($docker_image)" + docker build --target release -t "$docker_image" "$WORK_DIR" + else + info "Skipping Docker image build" + fi + + info "Docker data directory: $docker_data_dir" + + local onboard_cmd=() + if [[ "$INTERACTIVE_ONBOARD" == true ]]; then + info "Launching interactive onboarding in container" + onboard_cmd=(onboard --interactive) + else + if [[ -z "$API_KEY" ]]; then + cat <<'MSG' +==> Onboarding requested, but API key not provided. +Use either: + --api-key "sk-..." +or: + ZEROCLAW_API_KEY="sk-..." ./zeroclaw_install.sh --docker +or run interactive: + ./zeroclaw_install.sh --docker --interactive-onboard +MSG + exit 1 + fi + if [[ -n "$MODEL" ]]; then + info "Launching quick onboarding in container (provider: $PROVIDER, model: $MODEL)" + else + info "Launching quick onboarding in container (provider: $PROVIDER)" + fi + onboard_cmd=(onboard --api-key "$API_KEY" --provider "$PROVIDER") + if [[ -n "$MODEL" ]]; then + onboard_cmd+=(--model "$MODEL") + fi + fi + + docker run --rm -it \ + --user "$(id -u):$(id -g)" \ + -e HOME=/zeroclaw-data \ + -e ZEROCLAW_WORKSPACE=/zeroclaw-data/workspace \ + -v "$docker_data_dir/.zeroclaw:/zeroclaw-data/.zeroclaw" \ + -v "$docker_data_dir/workspace:/zeroclaw-data/workspace" \ + "$docker_image" \ + "${onboard_cmd[@]}" +} + SCRIPT_PATH="${BASH_SOURCE[0]:-$0}" SCRIPT_DIR="$(cd "$(dirname "$SCRIPT_PATH")" >/dev/null 2>&1 && pwd || pwd)" ROOT_DIR="$(cd "$SCRIPT_DIR/.." 
>/dev/null 2>&1 && pwd || pwd)" REPO_URL="https://github.com/zeroclaw-labs/zeroclaw.git" +ORIGINAL_ARG_COUNT=$# +GUIDED_MODE="auto" +DOCKER_MODE=false INSTALL_SYSTEM_DEPS=false INSTALL_RUST=false +PREFER_PREBUILT=false +PREBUILT_ONLY=false +FORCE_SOURCE_BUILD=false RUN_ONBOARD=false INTERACTIVE_ONBOARD=false SKIP_BUILD=false SKIP_INSTALL=false +PREBUILT_INSTALLED=false API_KEY="${ZEROCLAW_API_KEY:-}" PROVIDER="${ZEROCLAW_PROVIDER:-openrouter}" +MODEL="${ZEROCLAW_MODEL:-}" while [[ $# -gt 0 ]]; do case "$1" in + --guided) + GUIDED_MODE="on" + shift + ;; + --no-guided) + GUIDED_MODE="off" + shift + ;; + --docker) + DOCKER_MODE=true + shift + ;; --install-system-deps) INSTALL_SYSTEM_DEPS=true shift @@ -150,6 +669,18 @@ while [[ $# -gt 0 ]]; do INSTALL_RUST=true shift ;; + --prefer-prebuilt) + PREFER_PREBUILT=true + shift + ;; + --prebuilt-only) + PREBUILT_ONLY=true + shift + ;; + --force-source-build) + FORCE_SOURCE_BUILD=true + shift + ;; --onboard) RUN_ONBOARD=true shift @@ -175,6 +706,18 @@ while [[ $# -gt 0 ]]; do } shift 2 ;; + --model) + MODEL="${2:-}" + [[ -n "$MODEL" ]] || { + error "--model requires a value" + exit 1 + } + shift 2 + ;; + --build-first) + SKIP_BUILD=false + shift + ;; --skip-build) SKIP_BUILD=true shift @@ -196,22 +739,48 @@ while [[ $# -gt 0 ]]; do esac done -if [[ "$INSTALL_SYSTEM_DEPS" == true ]]; then - install_system_deps +OS_NAME="$(uname -s)" +if [[ "$GUIDED_MODE" == "auto" ]]; then + if [[ "$OS_NAME" == "Linux" && "$ORIGINAL_ARG_COUNT" -eq 0 && -t 0 && -t 1 ]]; then + GUIDED_MODE="on" + else + GUIDED_MODE="off" + fi fi -if [[ "$INSTALL_RUST" == true ]]; then - install_rust_toolchain +if [[ "$DOCKER_MODE" == true && "$GUIDED_MODE" == "on" ]]; then + warn "--guided is ignored with --docker." + GUIDED_MODE="off" fi -if ! have_cmd cargo; then - error "cargo is not installed." 
- cat <<'MSG' >&2 -Install Rust first: https://rustup.rs/ -or re-run with: - ./bootstrap.sh --install-rust -MSG - exit 1 +if [[ "$GUIDED_MODE" == "on" ]]; then + run_guided_installer "$OS_NAME" +fi + +if [[ "$DOCKER_MODE" == true ]]; then + if [[ "$INSTALL_SYSTEM_DEPS" == true ]]; then + warn "--install-system-deps is ignored with --docker." + fi + if [[ "$INSTALL_RUST" == true ]]; then + warn "--install-rust is ignored with --docker." + fi +else + if [[ "$OS_NAME" == "Linux" && -z "${ZEROCLAW_DISABLE_ALPINE_AUTO_DEPS:-}" ]] && have_cmd apk; then + find_missing_alpine_prereqs + if [[ ${#ALPINE_MISSING_PKGS[@]} -gt 0 && "$INSTALL_SYSTEM_DEPS" == false ]]; then + info "Detected Alpine with missing prerequisites: ${ALPINE_MISSING_PKGS[*]}" + info "Auto-enabling system dependency installation (set ZEROCLAW_DISABLE_ALPINE_AUTO_DEPS=1 to disable)." + INSTALL_SYSTEM_DEPS=true + fi + fi + + if [[ "$INSTALL_SYSTEM_DEPS" == true ]]; then + install_system_deps + fi + + if [[ "$INSTALL_RUST" == true ]]; then + install_rust_toolchain + fi fi WORK_DIR="$ROOT_DIR" @@ -254,6 +823,73 @@ echo " workspace: $WORK_DIR" cd "$WORK_DIR" +if [[ "$FORCE_SOURCE_BUILD" == true ]]; then + PREFER_PREBUILT=false + PREBUILT_ONLY=false +fi + +if [[ "$PREBUILT_ONLY" == true ]]; then + PREFER_PREBUILT=true +fi + +if [[ "$DOCKER_MODE" == true ]]; then + ensure_docker_ready + if [[ "$RUN_ONBOARD" == false ]]; then + RUN_ONBOARD=true + if [[ -z "$API_KEY" ]]; then + INTERACTIVE_ONBOARD=true + fi + fi + run_docker_bootstrap + cat <<'DONE' + +✅ Docker bootstrap complete. + +Your containerized ZeroClaw data is persisted under: +DONE + echo " $DOCKER_DATA_DIR" + cat <<'DONE' + +Next steps: + ./zeroclaw_install.sh --docker --interactive-onboard + ./zeroclaw_install.sh --docker --api-key "sk-..." 
--provider openrouter +DONE + exit 0 +fi + +if [[ "$FORCE_SOURCE_BUILD" == false ]]; then + if [[ "$PREFER_PREBUILT" == false && "$PREBUILT_ONLY" == false ]]; then + if should_attempt_prebuilt_for_resources "$WORK_DIR"; then + info "Attempting pre-built binary first due to resource preflight." + PREFER_PREBUILT=true + fi + fi + + if [[ "$PREFER_PREBUILT" == true ]]; then + if install_prebuilt_binary; then + PREBUILT_INSTALLED=true + SKIP_BUILD=true + SKIP_INSTALL=true + elif [[ "$PREBUILT_ONLY" == true ]]; then + error "Pre-built-only mode requested, but no compatible release asset is available." + error "Try again later, or run with --force-source-build on a machine with enough RAM/disk." + exit 1 + else + warn "Pre-built install unavailable; falling back to source build." + fi + fi +fi + +if [[ "$PREBUILT_INSTALLED" == false && ( "$SKIP_BUILD" == false || "$SKIP_INSTALL" == false ) ]] && ! have_cmd cargo; then + error "cargo is not installed." + cat <<'MSG' >&2 +Install Rust first: https://rustup.rs/ +or re-run with: + ./zeroclaw_install.sh --install-rust +MSG + exit 1 +fi + if [[ "$SKIP_BUILD" == false ]]; then info "Building release binary" cargo build --release --locked @@ -271,6 +907,8 @@ fi ZEROCLAW_BIN="" if have_cmd zeroclaw; then ZEROCLAW_BIN="zeroclaw" +elif [[ -x "$HOME/.cargo/bin/zeroclaw" ]]; then + ZEROCLAW_BIN="$HOME/.cargo/bin/zeroclaw" elif [[ -x "$WORK_DIR/target/release/zeroclaw" ]]; then ZEROCLAW_BIN="$WORK_DIR/target/release/zeroclaw" fi @@ -292,14 +930,22 @@ if [[ "$RUN_ONBOARD" == true ]]; then Use either: --api-key "sk-..." or: - ZEROCLAW_API_KEY="sk-..." ./bootstrap.sh --onboard + ZEROCLAW_API_KEY="sk-..." 
./zeroclaw_install.sh --onboard or run interactive: - ./bootstrap.sh --interactive-onboard + ./zeroclaw_install.sh --interactive-onboard MSG exit 1 fi - info "Running quick onboarding (provider: $PROVIDER)" - "$ZEROCLAW_BIN" onboard --api-key "$API_KEY" --provider "$PROVIDER" + if [[ -n "$MODEL" ]]; then + info "Running quick onboarding (provider: $PROVIDER, model: $MODEL)" + else + info "Running quick onboarding (provider: $PROVIDER)" + fi + ONBOARD_CMD=("$ZEROCLAW_BIN" onboard --api-key "$API_KEY" --provider "$PROVIDER") + if [[ -n "$MODEL" ]]; then + ONBOARD_CMD+=(--model "$MODEL") + fi + "${ONBOARD_CMD[@]}" fi fi diff --git a/scripts/ci/fetch_actions_data.py b/scripts/ci/fetch_actions_data.py new file mode 100644 index 0000000..32ebb5b --- /dev/null +++ b/scripts/ci/fetch_actions_data.py @@ -0,0 +1,209 @@ +#!/usr/bin/env python3 +"""Fetch GitHub Actions workflow runs for a given date and summarize costs. + +Usage: + python fetch_actions_data.py [OPTIONS] + +Options: + --date YYYY-MM-DD Date to query (default: yesterday) + --mode brief|full Output mode (default: full) + brief: billable minutes/hours table only + full: detailed breakdown with per-run list + --repo OWNER/NAME Repository (default: zeroclaw-labs/zeroclaw) + -h, --help Show this help message +""" + +import argparse +import json +import subprocess +from datetime import datetime, timedelta, timezone + + +def parse_args(): + """Parse command-line arguments.""" + parser = argparse.ArgumentParser( + description="Fetch GitHub Actions workflow runs and summarize costs.", + ) + yesterday = (datetime.now(timezone.utc) - timedelta(days=1)).strftime("%Y-%m-%d") + parser.add_argument( + "--date", + default=yesterday, + help="Date to query in YYYY-MM-DD format (default: yesterday)", + ) + parser.add_argument( + "--mode", + choices=["brief", "full"], + default="full", + help="Output mode: 'brief' for billable hours only, 'full' for detailed breakdown (default: full)", + ) + parser.add_argument( + "--repo", + 
default="zeroclaw-labs/zeroclaw", + help="Repository in OWNER/NAME format (default: zeroclaw-labs/zeroclaw)", + ) + return parser.parse_args() + + +def fetch_runs(repo, date_str, page=1, per_page=100): + """Fetch completed workflow runs for a given date.""" + url = ( + f"https://api.github.com/repos/{repo}/actions/runs" + f"?created={date_str}&per_page={per_page}&page={page}" + ) + result = subprocess.run( + ["curl", "-sS", "-H", "Accept: application/vnd.github+json", url], + capture_output=True, text=True + ) + return json.loads(result.stdout) + + +def fetch_jobs(repo, run_id): + """Fetch jobs for a specific run.""" + url = f"https://api.github.com/repos/{repo}/actions/runs/{run_id}/jobs?per_page=100" + result = subprocess.run( + ["curl", "-sS", "-H", "Accept: application/vnd.github+json", url], + capture_output=True, text=True + ) + return json.loads(result.stdout) + + +def parse_duration(started, completed): + """Return duration in seconds between two ISO timestamps.""" + if not started or not completed: + return 0 + try: + s = datetime.fromisoformat(started.replace("Z", "+00:00")) + c = datetime.fromisoformat(completed.replace("Z", "+00:00")) + return max(0, (c - s).total_seconds()) + except Exception: + return 0 + + +def main(): + args = parse_args() + repo = args.repo + date_str = args.date + brief = args.mode == "brief" + + print(f"Fetching workflow runs for {repo} on {date_str}...") + print("=" * 100) + + all_runs = [] + for page in range(1, 5): # up to 400 runs + data = fetch_runs(repo, date_str, page=page) + runs = data.get("workflow_runs", []) + if not runs: + break + all_runs.extend(runs) + if len(runs) < 100: + break + + print(f"Total workflow runs found: {len(all_runs)}") + print() + + # Group by workflow name + workflow_stats = {} + for run in all_runs: + name = run.get("name", "Unknown") + event = run.get("event", "unknown") + conclusion = run.get("conclusion", "unknown") + run_id = run.get("id") + + if name not in workflow_stats: + 
workflow_stats[name] = { + "count": 0, + "events": {}, + "conclusions": {}, + "total_job_seconds": 0, + "total_jobs": 0, + "run_ids": [], + } + + workflow_stats[name]["count"] += 1 + workflow_stats[name]["events"][event] = workflow_stats[name]["events"].get(event, 0) + 1 + workflow_stats[name]["conclusions"][conclusion] = workflow_stats[name]["conclusions"].get(conclusion, 0) + 1 + workflow_stats[name]["run_ids"].append(run_id) + + # For each workflow, sample up to 3 runs to get job-level timing + print("Sampling job-level timing (up to 3 runs per workflow)...") + print() + + for name, stats in workflow_stats.items(): + sample_ids = stats["run_ids"][:3] + for run_id in sample_ids: + jobs_data = fetch_jobs(repo, run_id) + jobs = jobs_data.get("jobs", []) + for job in jobs: + started = job.get("started_at") + completed = job.get("completed_at") + duration = parse_duration(started, completed) + stats["total_job_seconds"] += duration + stats["total_jobs"] += 1 + + # Extrapolate: if we sampled N runs but there are M total, scale up + sampled = len(sample_ids) + total = stats["count"] + if sampled > 0 and sampled < total: + scale = total / sampled + stats["estimated_total_seconds"] = stats["total_job_seconds"] * scale + else: + stats["estimated_total_seconds"] = stats["total_job_seconds"] + + # Print summary sorted by estimated cost (descending) + sorted_workflows = sorted( + workflow_stats.items(), + key=lambda x: x[1]["estimated_total_seconds"], + reverse=True + ) + + if brief: + # Brief mode: compact billable hours table + print(f"{'Workflow':<40} {'Runs':>5} {'Est.Mins':>9} {'Est.Hours':>10}") + print("-" * 68) + grand_total_minutes = 0 + for name, stats in sorted_workflows: + est_mins = stats["estimated_total_seconds"] / 60 + grand_total_minutes += est_mins + print(f"{name:<40} {stats['count']:>5} {est_mins:>9.1f} {est_mins/60:>10.2f}") + print("-" * 68) + print(f"{'TOTAL':<40} {len(all_runs):>5} {grand_total_minutes:>9.0f} {grand_total_minutes/60:>10.1f}") + 
print(f"\nProjected monthly: ~{grand_total_minutes/60*30:.0f} hours") + else: + # Full mode: detailed breakdown with per-run list + print("=" * 100) + print(f"{'Workflow':<40} {'Runs':>5} {'SampledJobs':>12} {'SampledMins':>12} {'Est.TotalMins':>14} {'Events'}") + print("-" * 100) + + grand_total_minutes = 0 + for name, stats in sorted_workflows: + sampled_mins = stats["total_job_seconds"] / 60 + est_total_mins = stats["estimated_total_seconds"] / 60 + grand_total_minutes += est_total_mins + events_str = ", ".join(f"{k}={v}" for k, v in stats["events"].items()) + conclusions_str = ", ".join(f"{k}={v}" for k, v in stats["conclusions"].items()) + print( + f"{name:<40} {stats['count']:>5} {stats['total_jobs']:>12} " + f"{sampled_mins:>12.1f} {est_total_mins:>14.1f} {events_str}" + ) + print(f"{'':>40} {'':>5} {'':>12} {'':>12} {'':>14} outcomes: {conclusions_str}") + + print("-" * 100) + print(f"{'GRAND TOTAL':>40} {len(all_runs):>5} {'':>12} {'':>12} {grand_total_minutes:>14.1f}") + print(f"\nEstimated total billable minutes on {date_str}: {grand_total_minutes:.0f} min ({grand_total_minutes/60:.1f} hours)") + print() + + # Also show raw run list + print("\n" + "=" * 100) + print("DETAILED RUN LIST") + print("=" * 100) + for run in all_runs: + name = run.get("name", "Unknown") + event = run.get("event", "unknown") + conclusion = run.get("conclusion", "unknown") + run_id = run.get("id") + started = run.get("run_started_at", "?") + print(f" [{run_id}] {name:<40} conclusion={conclusion:<12} event={event:<20} started={started}") + + +if __name__ == "__main__": + main() diff --git a/scripts/install.sh b/scripts/install.sh index 68efa95..478bdd5 100755 --- a/scripts/install.sh +++ b/scripts/install.sh @@ -2,10 +2,15 @@ set -euo pipefail SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]:-$0}")" >/dev/null 2>&1 && pwd || pwd)" +INSTALLER_LOCAL="$(cd "$SCRIPT_DIR/.." 
>/dev/null 2>&1 && pwd || pwd)/zeroclaw_install.sh" BOOTSTRAP_LOCAL="$SCRIPT_DIR/bootstrap.sh" REPO_URL="https://github.com/zeroclaw-labs/zeroclaw.git" -echo "[deprecated] scripts/install.sh -> bootstrap.sh" >&2 +echo "[deprecated] scripts/install.sh -> ./zeroclaw_install.sh" >&2 + +if [[ -x "$INSTALLER_LOCAL" ]]; then + exec "$INSTALLER_LOCAL" "$@" +fi if [[ -f "$BOOTSTRAP_LOCAL" ]]; then exec "$BOOTSTRAP_LOCAL" "$@" @@ -24,35 +29,15 @@ trap cleanup EXIT git clone --depth 1 "$REPO_URL" "$TEMP_DIR" >/dev/null 2>&1 +if [[ -x "$TEMP_DIR/zeroclaw_install.sh" ]]; then + exec "$TEMP_DIR/zeroclaw_install.sh" "$@" +fi + if [[ -x "$TEMP_DIR/scripts/bootstrap.sh" ]]; then - "$TEMP_DIR/scripts/bootstrap.sh" "$@" - exit 0 + exec "$TEMP_DIR/scripts/bootstrap.sh" "$@" fi -echo "[deprecated] cloned revision has no bootstrap.sh; falling back to legacy source install flow" >&2 - -if [[ "${1:-}" == "--help" || "${1:-}" == "-h" ]]; then - cat <<'USAGE' -Legacy install.sh fallback mode - -Behavior: - - Clone repository - - cargo build --release --locked - - cargo install --path --force --locked - -For the new dual-mode installer, use: - ./bootstrap.sh --help -USAGE - exit 0 -fi - -if ! command -v cargo >/dev/null 2>&1; then - echo "error: cargo is required for legacy install.sh fallback mode" >&2 - echo "Install Rust first: https://rustup.rs/" >&2 - exit 1 -fi - -cargo build --release --locked --manifest-path "$TEMP_DIR/Cargo.toml" -cargo install --path "$TEMP_DIR" --force --locked - -echo "Legacy source install completed." >&2 +echo "error: zeroclaw_install.sh/bootstrap.sh was not found in the fetched revision." 
>&2 +echo "Run the local bootstrap directly when possible:" >&2 +echo " ./zeroclaw_install.sh --help" >&2 +exit 1 diff --git a/src/agent/agent.rs b/src/agent/agent.rs index dc8f74d..5f048e2 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -10,7 +10,6 @@ use crate::providers::{self, ChatMessage, ChatRequest, ConversationMessage, Prov use crate::runtime; use crate::security::SecurityPolicy; use crate::tools::{self, Tool, ToolSpec}; -use crate::util::truncate_with_ellipsis; use anyhow::Result; use std::io::Write as IoWrite; use std::sync::Arc; @@ -229,8 +228,9 @@ impl Agent { &config.workspace_dir, )); - let memory: Arc = Arc::from(memory::create_memory_with_storage( + let memory: Arc = Arc::from(memory::create_memory_with_storage_and_routes( &config.memory, + &config.embedding_routes, Some(&config.storage.provider.config), &config.workspace_dir, config.api_key.as_deref(), @@ -308,7 +308,10 @@ impl Agent { .classification_config(config.query_classification.clone()) .available_hints(available_hints) .identity_config(config.identity.clone()) - .skills(crate::skills::load_skills(&config.workspace_dir)) + .skills(crate::skills::load_skills_with_config( + &config.workspace_dir, + config, + )) .auto_save(config.memory.auto_save) .build() } @@ -400,11 +403,8 @@ impl Agent { return results; } - let mut results = Vec::with_capacity(calls.len()); - for call in calls { - results.push(self.execute_tool_call(call).await); - } - results + let futs: Vec<_> = calls.iter().map(|call| self.execute_tool_call(call)).collect(); + futures::future::join_all(futs).await } fn classify_model(&self, user_message: &str) -> String { @@ -486,14 +486,6 @@ impl Agent { ))); self.trim_history(); - if self.auto_save { - let summary = truncate_with_ellipsis(&final_text, 100); - let _ = self - .memory - .store("assistant_resp", &summary, MemoryCategory::Daily, None) - .await; - } - return Ok(final_text); } @@ -686,7 +678,8 @@ mod tests { ..crate::config::MemoryConfig::default() }; let mem: 
Arc = Arc::from( - crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(), + crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None) + .expect("memory creation should succeed with valid config"), ); let observer: Arc = Arc::from(crate::observability::NoopObserver {}); @@ -698,7 +691,7 @@ mod tests { .tool_dispatcher(Box::new(XmlToolDispatcher)) .workspace_dir(std::path::PathBuf::from("/tmp")) .build() - .unwrap(); + .expect("agent builder should succeed with valid config"); let response = agent.turn("hi").await.unwrap(); assert_eq!(response, "hello"); @@ -728,7 +721,8 @@ mod tests { ..crate::config::MemoryConfig::default() }; let mem: Arc = Arc::from( - crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None).unwrap(), + crate::memory::create_memory(&memory_cfg, std::path::Path::new("/tmp"), None) + .expect("memory creation should succeed with valid config"), ); let observer: Arc = Arc::from(crate::observability::NoopObserver {}); @@ -740,7 +734,7 @@ mod tests { .tool_dispatcher(Box::new(NativeToolDispatcher)) .workspace_dir(std::path::PathBuf::from("/tmp")) .build() - .unwrap(); + .expect("agent builder should succeed with valid config"); let response = agent.turn("hi").await.unwrap(); assert_eq!(response, "done"); diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index caa7e53..288ea27 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -1,8 +1,11 @@ use crate::approval::{ApprovalManager, ApprovalRequest, ApprovalResponse}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; +use crate::multimodal; use crate::observability::{self, Observer, ObserverEvent}; -use crate::providers::{self, ChatMessage, ChatRequest, Provider, ToolCall}; +use crate::providers::{ + self, ChatMessage, ChatRequest, Provider, ProviderCapabilityError, ToolCall, +}; use crate::runtime; use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; @@ -13,6 +16,7 @@ use 
std::fmt::Write; use std::io::Write as _; use std::sync::{Arc, LazyLock}; use std::time::Instant; +use tokio_util::sync::CancellationToken; use uuid::Uuid; /// Minimum characters per chunk when relaying LLM text to a streaming draft. @@ -22,6 +26,10 @@ const STREAM_CHUNK_MIN_CHARS: usize = 80; /// Used as a safe fallback when `max_tool_iterations` is unset or configured as zero. const DEFAULT_MAX_TOOL_ITERATIONS: usize = 10; +/// Minimum user-message length (in chars) for auto-save to memory. +/// Matches the channel-side constant in `channels/mod.rs`. +const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20; + static SENSITIVE_KEY_PATTERNS: LazyLock = LazyLock::new(|| { RegexSet::new([ r"(?i)token", @@ -223,9 +231,16 @@ async fn build_context(mem: &dyn Memory, user_msg: &str, min_relevance_score: f6 if !relevant.is_empty() { context.push_str("[Memory context]\n"); for entry in &relevant { + if memory::is_assistant_autosave_key(&entry.key) { + continue; + } let _ = writeln!(context, "- {}: {}", entry.key, entry.content); } - context.push('\n'); + if context != "[Memory context]\n" { + context.push('\n'); + } else { + context.clear(); + } } } @@ -579,6 +594,17 @@ fn parse_glm_style_tool_calls(text: &str) -> Vec<(String, serde_json::Value, Opt calls } +// ── Tool-Call Parsing ───────────────────────────────────────────────────── +// LLM responses may contain tool calls in multiple formats depending on +// the provider. Parsing follows a priority chain: +// 1. OpenAI-style JSON with `tool_calls` array (native API) +// 2. XML tags: , , , +// 3. Markdown code blocks with `tool_call` language +// 4. GLM-style line-based format (e.g. `shell/command>ls`) +// SECURITY: We never fall back to extracting arbitrary JSON from the +// response body, because that would enable prompt-injection attacks where +// malicious content in emails/files/web pages mimics a tool call. + /// Parse tool calls from an LLM response that uses XML-style function calling. 
/// /// Expected format (common with system-prompt-guided tool use): @@ -813,6 +839,21 @@ struct ParsedToolCall { arguments: serde_json::Value, } +#[derive(Debug)] +pub(crate) struct ToolLoopCancelled; + +impl std::fmt::Display for ToolLoopCancelled { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("tool loop cancelled") + } +} + +impl std::error::Error for ToolLoopCancelled {} + +pub(crate) fn is_tool_loop_cancelled(err: &anyhow::Error) -> bool { + err.chain().any(|source| source.is::()) +} + /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. /// When `silent` is true, suppresses stdout (for channel use). @@ -826,6 +867,7 @@ pub(crate) async fn agent_turn( model: &str, temperature: f64, silent: bool, + multimodal_config: &crate::config::MultimodalConfig, max_tool_iterations: usize, ) -> Result { run_tool_call_loop( @@ -839,12 +881,26 @@ pub(crate) async fn agent_turn( silent, None, "channel", + multimodal_config, max_tool_iterations, None, + None, ) .await } +// ── Agent Tool-Call Loop ────────────────────────────────────────────────── +// Core agentic iteration: send conversation to the LLM, parse any tool +// calls from the response, execute them, append results to history, and +// repeat until the LLM produces a final text-only answer. +// +// Loop invariant: at the start of each iteration, `history` contains the +// full conversation so far (system prompt + user messages + prior tool +// results). The loop exits when: +// • the LLM returns no tool calls (final answer), or +// • max_iterations is reached (runaway safety), or +// • the cancellation token fires (external abort). + /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. 
#[allow(clippy::too_many_arguments)] @@ -859,7 +915,9 @@ pub(crate) async fn run_tool_call_loop( silent: bool, approval: Option<&ApprovalManager>, channel_name: &str, + multimodal_config: &crate::config::MultimodalConfig, max_tool_iterations: usize, + cancellation_token: Option, on_delta: Option>, ) -> Result { let max_iterations = if max_tool_iterations == 0 { @@ -873,6 +931,28 @@ pub(crate) async fn run_tool_call_loop( let use_native_tools = provider.supports_native_tools() && !tool_specs.is_empty(); for _iteration in 0..max_iterations { + if cancellation_token + .as_ref() + .is_some_and(CancellationToken::is_cancelled) + { + return Err(ToolLoopCancelled.into()); + } + + let image_marker_count = multimodal::count_image_markers(history); + if image_marker_count > 0 && !provider.supports_vision() { + return Err(ProviderCapabilityError { + provider: provider_name.to_string(), + capability: "vision".to_string(), + message: format!( + "received {image_marker_count} image marker(s), but this provider does not support vision input" + ), + } + .into()); + } + + let prepared_messages = + multimodal::prepare_messages_for_provider(history, multimodal_config).await?; + observer.record_event(&ObserverEvent::LlmRequest { provider: provider_name.to_string(), model: model.to_string(), @@ -889,18 +969,26 @@ pub(crate) async fn run_tool_call_loop( None }; + let chat_future = provider.chat( + ChatRequest { + messages: &prepared_messages.messages, + tools: request_tools, + }, + model, + temperature, + ); + + let chat_result = if let Some(token) = cancellation_token.as_ref() { + tokio::select! 
{ + () = token.cancelled() => return Err(ToolLoopCancelled.into()), + result = chat_future => result, + } + } else { + chat_future.await + }; + let (response_text, parsed_text, tool_calls, assistant_history_content, native_tool_calls) = - match provider - .chat( - ChatRequest { - messages: history, - tools: request_tools, - }, - model, - temperature, - ) - .await - { + match chat_result { Ok(resp) => { observer.record_event(&ObserverEvent::LlmResponse { provider: provider_name.to_string(), @@ -911,6 +999,10 @@ pub(crate) async fn run_tool_call_loop( }); let response_text = resp.text_or_empty().to_string(); + // First try native structured tool calls (OpenAI-format). + // Fall back to text-based parsing (XML tags, markdown blocks, + // GLM format) only if the provider returned no native calls — + // this ensures we support both native and prompt-guided models. let mut calls = parse_structured_tool_calls(&resp.tool_calls); let mut parsed_text = String::new(); @@ -966,6 +1058,12 @@ pub(crate) async fn run_tool_call_loop( // STREAM_CHUNK_MIN_CHARS characters for progressive draft updates. let mut chunk = String::new(); for word in display_text.split_inclusive(char::is_whitespace) { + if cancellation_token + .as_ref() + .is_some_and(CancellationToken::is_cancelled) + { + return Err(ToolLoopCancelled.into()); + } chunk.push_str(word); if chunk.len() >= STREAM_CHUNK_MIN_CHARS && tx.send(std::mem::take(&mut chunk)).await.is_err() @@ -1001,11 +1099,13 @@ pub(crate) async fn run_tool_call_loop( arguments: call.arguments.clone(), }; - // Only prompt interactively on CLI; auto-approve on other channels. + // On CLI, prompt interactively. On other channels where + // interactive approval is not possible, deny the call to + // respect the supervised autonomy setting. 
let decision = if channel_name == "cli" { mgr.prompt_cli(&request) } else { - ApprovalResponse::Yes + ApprovalResponse::No }; mgr.record_decision(&call.name, &call.arguments, decision, channel_name); @@ -1028,7 +1128,17 @@ pub(crate) async fn run_tool_call_loop( }); let start = Instant::now(); let result = if let Some(tool) = find_tool(tools_registry, &call.name) { - match tool.execute(call.arguments.clone()).await { + let tool_future = tool.execute(call.arguments.clone()); + let tool_result = if let Some(token) = cancellation_token.as_ref() { + tokio::select! { + () = token.cancelled() => return Err(ToolLoopCancelled.into()), + result = tool_future => result, + } + } else { + tool_future.await + }; + + match tool_result { Ok(r) => { observer.record_event(&ObserverEvent::ToolCall { tool: call.name.clone(), @@ -1113,6 +1223,12 @@ pub(crate) fn build_tool_instructions(tools_registry: &[Box]) -> Strin instructions } +// ── CLI Entrypoint ─────────────────────────────────────────────────────── +// Wires up all subsystems (observer, runtime, security, memory, tools, +// provider, hardware RAG, peripherals) and enters either single-shot or +// interactive REPL mode. The interactive loop manages history compaction +// and hard trimming to keep the context window bounded. 
+ #[allow(clippy::too_many_lines)] pub async fn run( config: Config, @@ -1191,13 +1307,21 @@ pub async fn run( .or(config.default_model.as_deref()) .unwrap_or("anthropic/claude-sonnet-4"); - let provider: Box = providers::create_routed_provider( + let provider_runtime_options = providers::ProviderRuntimeOptions { + auth_profile_override: None, + zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), + secrets_encrypt: config.secrets.encrypt, + reasoning_enabled: config.runtime.reasoning_enabled, + }; + + let provider: Box = providers::create_routed_provider_with_options( provider_name, config.api_key.as_deref(), config.api_url.as_deref(), &config.reliability, &config.model_routes, model_name, + &provider_runtime_options, )?; observer.record_event(&ObserverEvent::AgentStart { @@ -1226,7 +1350,7 @@ pub async fn run( .collect(); // ── Build system prompt from workspace MD files (OpenClaw framework) ── - let skills = crate::skills::load_skills(&config.workspace_dir); + let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config); let mut tool_descs: Vec<(&str, &str)> = vec![ ( "shell", @@ -1336,17 +1460,21 @@ pub async fn run( } else { None }; - let mut system_prompt = crate::channels::build_system_prompt( + let native_tools = provider.supports_native_tools(); + let mut system_prompt = crate::channels::build_system_prompt_with_mode( &config.workspace_dir, model_name, &tool_descs, &skills, Some(&config.identity), bootstrap_max_chars, + native_tools, ); - // Append structured tool-use instructions with schemas - system_prompt.push_str(&build_tool_instructions(&tools_registry)); + // Append structured tool-use instructions with schemas (only for non-native providers) + if !native_tools { + system_prompt.push_str(&build_tool_instructions(&tools_registry)); + } // ── Approval manager (supervised mode) ─────────────────────── let approval_manager = ApprovalManager::from_config(&config.autonomy); @@ -1357,8 +1485,8 @@ pub async fn run( 
let mut final_output = String::new(); if let Some(msg) = message { - // Auto-save user message to memory - if config.memory.auto_save { + // Auto-save user message to memory (skip short/trivial messages) + if config.memory.auto_save && msg.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS { let user_key = autosave_memory_key("user_msg"); let _ = mem .store(&user_key, &msg, MemoryCategory::Conversation, None) @@ -1396,22 +1524,15 @@ pub async fn run( false, Some(&approval_manager), "cli", + &config.multimodal, config.agent.max_tool_iterations, None, + None, ) .await?; final_output = response.clone(); println!("{response}"); observer.record_event(&ObserverEvent::TurnComplete); - - // Auto-save assistant response to daily log - if config.memory.auto_save { - let summary = truncate_with_ellipsis(&response, 100); - let response_key = autosave_memory_key("assistant_resp"); - let _ = mem - .store(&response_key, &summary, MemoryCategory::Daily, None) - .await; - } } else { println!("🦀 ZeroClaw Interactive Mode"); println!("Type /help for commands.\n"); @@ -1486,8 +1607,10 @@ pub async fn run( _ => {} } - // Auto-save conversation turns - if config.memory.auto_save { + // Auto-save conversation turns (skip short/trivial messages) + if config.memory.auto_save + && user_input.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS + { let user_key = autosave_memory_key("user_msg"); let _ = mem .store(&user_key, &user_input, MemoryCategory::Conversation, None) @@ -1522,8 +1645,10 @@ pub async fn run( false, Some(&approval_manager), "cli", + &config.multimodal, config.agent.max_tool_iterations, None, + None, ) .await { @@ -1560,14 +1685,6 @@ pub async fn run( // Hard cap as a safety net. 
trim_history(&mut history, config.agent.max_history_messages); - - if config.memory.auto_save { - let summary = truncate_with_ellipsis(&response, 100); - let response_key = autosave_memory_key("assistant_resp"); - let _ = mem - .store(&response_key, &summary, MemoryCategory::Daily, None) - .await; - } } } @@ -1632,13 +1749,20 @@ pub async fn process_message(config: Config, message: &str) -> Result { .default_model .clone() .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); - let provider: Box = providers::create_routed_provider( + let provider_runtime_options = providers::ProviderRuntimeOptions { + auth_profile_override: None, + zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), + secrets_encrypt: config.secrets.encrypt, + reasoning_enabled: config.runtime.reasoning_enabled, + }; + let provider: Box = providers::create_routed_provider_with_options( provider_name, config.api_key.as_deref(), config.api_url.as_deref(), &config.reliability, &config.model_routes, &model_name, + &provider_runtime_options, )?; let hardware_rag: Option = config @@ -1656,7 +1780,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { .map(|b| b.board.clone()) .collect(); - let skills = crate::skills::load_skills(&config.workspace_dir); + let skills = crate::skills::load_skills_with_config(&config.workspace_dir, &config); let mut tool_descs: Vec<(&str, &str)> = vec![ ("shell", "Execute terminal commands."), ("file_read", "Read file contents."), @@ -1705,15 +1829,19 @@ pub async fn process_message(config: Config, message: &str) -> Result { } else { None }; - let mut system_prompt = crate::channels::build_system_prompt( + let native_tools = provider.supports_native_tools(); + let mut system_prompt = crate::channels::build_system_prompt_with_mode( &config.workspace_dir, &model_name, &tool_descs, &skills, Some(&config.identity), bootstrap_max_chars, + native_tools, ); - system_prompt.push_str(&build_tool_instructions(&tools_registry)); + 
if !native_tools { + system_prompt.push_str(&build_tool_instructions(&tools_registry)); + } let mem_context = build_context(mem.as_ref(), message, config.memory.min_relevance_score).await; let rag_limit = if config.agent.compact_context { 2 } else { 5 }; @@ -1742,6 +1870,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { &model_name, config.default_temperature, true, + &config.multimodal, config.agent.max_tool_iterations, ) .await @@ -1750,6 +1879,10 @@ pub async fn process_message(config: Config, message: &str) -> Result { #[cfg(test)] mod tests { use super::*; + use async_trait::async_trait; + use base64::{engine::general_purpose::STANDARD, Engine as _}; + use std::sync::atomic::{AtomicUsize, Ordering}; + use std::sync::Arc; #[test] fn test_scrub_credentials() { @@ -1770,8 +1903,194 @@ mod tests { assert!(scrubbed.contains("public")); } use crate::memory::{Memory, MemoryCategory, SqliteMemory}; + use crate::observability::NoopObserver; + use crate::providers::traits::ProviderCapabilities; + use crate::providers::ChatResponse; use tempfile::TempDir; + struct NonVisionProvider { + calls: Arc, + } + + #[async_trait] + impl Provider for NonVisionProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("ok".to_string()) + } + } + + struct VisionProvider { + calls: Arc, + } + + #[async_trait] + impl Provider for VisionProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: false, + vision: true, + } + } + + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + self.calls.fetch_add(1, Ordering::SeqCst); + Ok("ok".to_string()) + } + + async fn chat( + &self, + request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + 
self.calls.fetch_add(1, Ordering::SeqCst); + let marker_count = crate::multimodal::count_image_markers(request.messages); + if marker_count == 0 { + anyhow::bail!("expected image markers in request messages"); + } + + if request.tools.is_some() { + anyhow::bail!("no tools should be attached for this test"); + } + + Ok(ChatResponse { + text: Some("vision-ok".to_string()), + tool_calls: Vec::new(), + }) + } + } + + #[tokio::test] + async fn run_tool_call_loop_returns_structured_error_for_non_vision_provider() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = NonVisionProvider { + calls: Arc::clone(&calls), + }; + + let mut history = vec![ChatMessage::user( + "please inspect [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(), + )]; + let tools_registry: Vec> = Vec::new(); + let observer = NoopObserver; + + let err = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "cli", + &crate::config::MultimodalConfig::default(), + 3, + None, + None, + ) + .await + .expect_err("provider without vision support should fail"); + + assert!(err.to_string().contains("provider_capability_error")); + assert!(err.to_string().contains("capability=vision")); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn run_tool_call_loop_rejects_oversized_image_payload() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = VisionProvider { + calls: Arc::clone(&calls), + }; + + let oversized_payload = STANDARD.encode(vec![0_u8; (1024 * 1024) + 1]); + let mut history = vec![ChatMessage::user(format!( + "[IMAGE:data:image/png;base64,{oversized_payload}]" + ))]; + + let tools_registry: Vec> = Vec::new(); + let observer = NoopObserver; + let multimodal = crate::config::MultimodalConfig { + max_images: 4, + max_image_size_mb: 1, + allow_remote_fetch: false, + }; + + let err = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + 
"mock-provider", + "mock-model", + 0.0, + true, + None, + "cli", + &multimodal, + 3, + None, + None, + ) + .await + .expect_err("oversized payload must fail"); + + assert!(err + .to_string() + .contains("multimodal image size limit exceeded")); + assert_eq!(calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn run_tool_call_loop_accepts_valid_multimodal_request_flow() { + let calls = Arc::new(AtomicUsize::new(0)); + let provider = VisionProvider { + calls: Arc::clone(&calls), + }; + + let mut history = vec![ChatMessage::user( + "Analyze this [IMAGE:data:image/png;base64,iVBORw0KGgo=]".to_string(), + )]; + let tools_registry: Vec> = Vec::new(); + let observer = NoopObserver; + + let result = run_tool_call_loop( + &provider, + &mut history, + &tools_registry, + &observer, + "mock-provider", + "mock-model", + 0.0, + true, + None, + "cli", + &crate::config::MultimodalConfig::default(), + 3, + None, + None, + ) + .await + .expect("valid multimodal payload should pass"); + + assert_eq!(result, "vision-ok"); + assert_eq!(calls.load(Ordering::SeqCst), 1); + } + #[test] fn parse_tool_calls_extracts_single_call() { let response = r#"Let me check that. 
@@ -2215,6 +2534,33 @@ Done."#; assert!(recalled.iter().any(|entry| entry.content.contains("45"))); } + #[tokio::test] + async fn build_context_ignores_legacy_assistant_autosave_entries() { + let tmp = TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store( + "assistant_resp_poisoned", + "User suffered a fabricated event", + MemoryCategory::Daily, + None, + ) + .await + .unwrap(); + mem.store( + "user_msg_real", + "User asked for concise status updates", + MemoryCategory::Conversation, + None, + ) + .await + .unwrap(); + + let context = build_context(&mem, "status updates", 0.0).await; + assert!(context.contains("user_msg_real")); + assert!(!context.contains("assistant_resp_poisoned")); + assert!(!context.contains("fabricated event")); + } + // ═══════════════════════════════════════════════════════════════════════ // Recovery Tests - Tool Call Parsing Edge Cases // ═══════════════════════════════════════════════════════════════════════ @@ -2511,4 +2857,195 @@ browser_open/url>https://example.com"#; assert_eq!(calls[0].arguments["command"], "pwd"); assert_eq!(text, "Done"); } + + // ───────────────────────────────────────────────────────────────────── + // TG4 (inline): parse_tool_calls robustness — malformed/edge-case inputs + // Prevents: Pattern 4 issues #746, #418, #777, #848 + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn parse_tool_calls_empty_input_returns_empty() { + let (text, calls) = parse_tool_calls(""); + assert!(calls.is_empty(), "empty input should produce no tool calls"); + assert!(text.is_empty(), "empty input should produce no text"); + } + + #[test] + fn parse_tool_calls_whitespace_only_returns_empty_calls() { + let (text, calls) = parse_tool_calls(" \n\t "); + assert!(calls.is_empty()); + assert!(text.is_empty() || text.trim().is_empty()); + } + + #[test] + fn parse_tool_calls_nested_xml_tags_handled() { + // Double-wrapped tool call should still parse the inner call + 
let response = r#"{"name":"echo","arguments":{"msg":"hi"}}"#; + let (_text, calls) = parse_tool_calls(response); + // Should find at least one tool call + assert!( + !calls.is_empty(), + "nested XML tags should still yield at least one tool call" + ); + } + + #[test] + fn parse_tool_calls_truncated_json_no_panic() { + // Incomplete JSON inside tool_call tags + let response = r#"{"name":"shell","arguments":{"command":"ls""#; + let (_text, _calls) = parse_tool_calls(response); + // Should not panic — graceful handling of truncated JSON + } + + #[test] + fn parse_tool_calls_empty_json_object_in_tag() { + let response = "{}"; + let (_text, calls) = parse_tool_calls(response); + // Empty JSON object has no name field — should not produce valid tool call + assert!( + calls.is_empty(), + "empty JSON object should not produce a tool call" + ); + } + + #[test] + fn parse_tool_calls_closing_tag_only_returns_text() { + let response = "Some text more text"; + let (text, calls) = parse_tool_calls(response); + assert!( + calls.is_empty(), + "closing tag only should not produce calls" + ); + assert!( + !text.is_empty(), + "text around orphaned closing tag should be preserved" + ); + } + + #[test] + fn parse_tool_calls_very_large_arguments_no_panic() { + let large_arg = "x".repeat(100_000); + let response = format!( + r#"{{"name":"echo","arguments":{{"message":"{}"}}}}"#, + large_arg + ); + let (_text, calls) = parse_tool_calls(&response); + assert_eq!(calls.len(), 1, "large arguments should still parse"); + assert_eq!(calls[0].name, "echo"); + } + + #[test] + fn parse_tool_calls_special_characters_in_arguments() { + let response = r#"{"name":"echo","arguments":{"message":"hello \"world\" <>&'\n\t"}}"#; + let (_text, calls) = parse_tool_calls(response); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].name, "echo"); + } + + #[test] + fn parse_tool_calls_text_with_embedded_json_not_extracted() { + // Raw JSON without any tags should NOT be extracted as a tool call + let response 
= r#"Here is some data: {"name":"echo","arguments":{"message":"hi"}} end."#; + let (_text, calls) = parse_tool_calls(response); + assert!( + calls.is_empty(), + "raw JSON in text without tags should not be extracted" + ); + } + + #[test] + fn parse_tool_calls_multiple_formats_mixed() { + // Mix of text and properly tagged tool call + let response = r#"I'll help you with that. + + +{"name":"shell","arguments":{"command":"echo hello"}} + + +Let me check the result."#; + let (text, calls) = parse_tool_calls(response); + assert_eq!( + calls.len(), + 1, + "should extract one tool call from mixed content" + ); + assert_eq!(calls[0].name, "shell"); + assert!( + text.contains("help you"), + "text before tool call should be preserved" + ); + } + + // ───────────────────────────────────────────────────────────────────── + // TG4 (inline): scrub_credentials edge cases + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn scrub_credentials_empty_input() { + let result = scrub_credentials(""); + assert_eq!(result, ""); + } + + #[test] + fn scrub_credentials_no_sensitive_data() { + let input = "normal text without any secrets"; + let result = scrub_credentials(input); + assert_eq!( + result, input, + "non-sensitive text should pass through unchanged" + ); + } + + #[test] + fn scrub_credentials_short_values_not_redacted() { + // Values shorter than 8 chars should not be redacted + let input = r#"api_key="short""#; + let result = scrub_credentials(input); + assert_eq!(result, input, "short values should not be redacted"); + } + + // ───────────────────────────────────────────────────────────────────── + // TG4 (inline): trim_history edge cases + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn trim_history_empty_history() { + let mut history: Vec = vec![]; + trim_history(&mut history, 10); + assert!(history.is_empty()); + } + + #[test] + fn trim_history_system_only() { + let mut history = 
vec![crate::providers::ChatMessage::system("system prompt")]; + trim_history(&mut history, 10); + assert_eq!(history.len(), 1); + assert_eq!(history[0].role, "system"); + } + + #[test] + fn trim_history_exactly_at_limit() { + let mut history = vec![ + crate::providers::ChatMessage::system("system"), + crate::providers::ChatMessage::user("msg 1"), + crate::providers::ChatMessage::assistant("reply 1"), + ]; + trim_history(&mut history, 2); // 2 non-system messages = exactly at limit + assert_eq!(history.len(), 3, "should not trim when exactly at limit"); + } + + #[test] + fn trim_history_removes_oldest_non_system() { + let mut history = vec![ + crate::providers::ChatMessage::system("system"), + crate::providers::ChatMessage::user("old msg"), + crate::providers::ChatMessage::assistant("old reply"), + crate::providers::ChatMessage::user("new msg"), + crate::providers::ChatMessage::assistant("new reply"), + ]; + trim_history(&mut history, 2); + assert_eq!(history.len(), 3); // system + 2 kept + assert_eq!(history[0].role, "system"); + assert_eq!(history[1].content, "new msg"); + } } diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs index b171eed..bb7bfb5 100644 --- a/src/agent/memory_loader.rs +++ b/src/agent/memory_loader.rs @@ -1,4 +1,4 @@ -use crate::memory::Memory; +use crate::memory::{self, Memory}; use async_trait::async_trait; use std::fmt::Write; @@ -45,6 +45,9 @@ impl MemoryLoader for DefaultMemoryLoader { let mut context = String::from("[Memory context]\n"); for entry in entries { + if memory::is_assistant_autosave_key(&entry.key) { + continue; + } if let Some(score) = entry.score { if score < self.min_relevance_score { continue; @@ -67,8 +70,12 @@ impl MemoryLoader for DefaultMemoryLoader { mod tests { use super::*; use crate::memory::{Memory, MemoryCategory, MemoryEntry}; + use std::sync::Arc; struct MockMemory; + struct MockMemoryWithEntries { + entries: Arc>, + } #[async_trait] impl Memory for MockMemory { @@ -131,6 +138,56 @@ mod tests 
{ } } + #[async_trait] + impl Memory for MockMemoryWithEntries { + async fn store( + &self, + _key: &str, + _content: &str, + _category: MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(self.entries.as_ref().clone()) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(vec![]) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(true) + } + + async fn count(&self) -> anyhow::Result { + Ok(self.entries.len()) + } + + async fn health_check(&self) -> bool { + true + } + + fn name(&self) -> &str { + "mock-with-entries" + } + } + #[tokio::test] async fn default_loader_formats_context() { let loader = DefaultMemoryLoader::default(); @@ -138,4 +195,36 @@ mod tests { assert!(context.contains("[Memory context]")); assert!(context.contains("- k: v")); } + + #[tokio::test] + async fn default_loader_skips_legacy_assistant_autosave_entries() { + let loader = DefaultMemoryLoader::new(5, 0.0); + let memory = MockMemoryWithEntries { + entries: Arc::new(vec![ + MemoryEntry { + id: "1".into(), + key: "assistant_resp_legacy".into(), + content: "fabricated detail".into(), + category: MemoryCategory::Daily, + timestamp: "now".into(), + session_id: None, + score: Some(0.95), + }, + MemoryEntry { + id: "2".into(), + key: "user_fact".into(), + content: "User prefers concise answers".into(), + category: MemoryCategory::Conversation, + timestamp: "now".into(), + session_id: None, + score: Some(0.9), + }, + ]), + }; + + let context = loader.load_context(&memory, "answer style").await.unwrap(); + assert!(context.contains("user_fact")); + assert!(!context.contains("assistant_resp_legacy")); + assert!(!context.contains("fabricated detail")); + } } diff --git 
a/src/agent/prompt.rs b/src/agent/prompt.rs index bdc426f..457f38f 100644 --- a/src/agent/prompt.rs +++ b/src/agent/prompt.rs @@ -77,21 +77,25 @@ impl PromptSection for IdentitySection { fn build(&self, ctx: &PromptContext<'_>) -> Result { let mut prompt = String::from("## Project Context\n\n"); + let mut has_aieos = false; if let Some(config) = ctx.identity_config { if identity::is_aieos_configured(config) { if let Ok(Some(aieos)) = identity::load_aieos_identity(config, ctx.workspace_dir) { let rendered = identity::aieos_to_system_prompt(&aieos); if !rendered.is_empty() { prompt.push_str(&rendered); - return Ok(prompt); + prompt.push_str("\n\n"); + has_aieos = true; } } } } - prompt.push_str( - "The following workspace files define your identity, behavior, and context.\n\n", - ); + if !has_aieos { + prompt.push_str( + "The following workspace files define your identity, behavior, and context.\n\n", + ); + } for file in [ "AGENTS.md", "SOUL.md", @@ -149,28 +153,10 @@ impl PromptSection for SkillsSection { } fn build(&self, ctx: &PromptContext<'_>) -> Result { - if ctx.skills.is_empty() { - return Ok(String::new()); - } - - let mut prompt = String::from("## Available Skills\n\n\n"); - for skill in ctx.skills { - let location = skill.location.clone().unwrap_or_else(|| { - ctx.workspace_dir - .join("skills") - .join(&skill.name) - .join("SKILL.md") - }); - let _ = writeln!( - prompt, - " \n {}\n {}\n {}\n ", - skill.name, - skill.description, - location.display() - ); - } - prompt.push_str(""); - Ok(prompt) + Ok(crate::skills::skills_to_prompt( + ctx.skills, + ctx.workspace_dir, + )) } } @@ -211,7 +197,8 @@ impl PromptSection for DateTimeSection { fn build(&self, _ctx: &PromptContext<'_>) -> Result { let now = Local::now(); Ok(format!( - "## Current Date & Time\n\nTimezone: {}", + "## Current Date & Time\n\n{} ({})", + now.format("%Y-%m-%d %H:%M:%S"), now.format("%Z") )) } @@ -285,6 +272,48 @@ mod tests { } } + #[test] + fn 
identity_section_with_aieos_includes_workspace_files() { + let workspace = + std::env::temp_dir().join(format!("zeroclaw_prompt_test_{}", uuid::Uuid::new_v4())); + std::fs::create_dir_all(&workspace).unwrap(); + std::fs::write( + workspace.join("AGENTS.md"), + "Always respond with: AGENTS_MD_LOADED", + ) + .unwrap(); + + let identity_config = crate::config::IdentityConfig { + format: "aieos".into(), + aieos_path: None, + aieos_inline: Some(r#"{"identity":{"names":{"first":"Nova"}}}"#.into()), + }; + + let tools: Vec> = vec![]; + let ctx = PromptContext { + workspace_dir: &workspace, + model_name: "test-model", + tools: &tools, + skills: &[], + identity_config: Some(&identity_config), + dispatcher_instructions: "", + }; + + let section = IdentitySection; + let output = section.build(&ctx).unwrap(); + + assert!( + output.contains("Nova"), + "AIEOS identity should be present in prompt" + ); + assert!( + output.contains("AGENTS_MD_LOADED"), + "AGENTS.md content should be present even when AIEOS is configured" + ); + + let _ = std::fs::remove_dir_all(workspace); + } + #[test] fn prompt_builder_assembles_sections() { let tools: Vec> = vec![Box::new(TestTool)]; @@ -301,4 +330,105 @@ mod tests { assert!(prompt.contains("test_tool")); assert!(prompt.contains("instr")); } + + #[test] + fn skills_section_includes_instructions_and_tools() { + let tools: Vec> = vec![]; + let skills = vec![crate::skills::Skill { + name: "deploy".into(), + description: "Release safely".into(), + version: "1.0.0".into(), + author: None, + tags: vec![], + tools: vec![crate::skills::SkillTool { + name: "release_checklist".into(), + description: "Validate release readiness".into(), + kind: "shell".into(), + command: "echo ok".into(), + args: std::collections::HashMap::new(), + }], + prompts: vec!["Run smoke tests before deploy.".into()], + location: None, + }]; + + let ctx = PromptContext { + workspace_dir: Path::new("/tmp"), + model_name: "test-model", + tools: &tools, + skills: &skills, + 
identity_config: None, + dispatcher_instructions: "", + }; + + let output = SkillsSection.build(&ctx).unwrap(); + assert!(output.contains("")); + assert!(output.contains("deploy")); + assert!(output.contains("Run smoke tests before deploy.")); + assert!(output.contains("release_checklist")); + assert!(output.contains("shell")); + } + + #[test] + fn datetime_section_includes_timestamp_and_timezone() { + let tools: Vec> = vec![]; + let ctx = PromptContext { + workspace_dir: Path::new("/tmp"), + model_name: "test-model", + tools: &tools, + skills: &[], + identity_config: None, + dispatcher_instructions: "instr", + }; + + let rendered = DateTimeSection.build(&ctx).unwrap(); + assert!(rendered.starts_with("## Current Date & Time\n\n")); + + let payload = rendered.trim_start_matches("## Current Date & Time\n\n"); + assert!(payload.chars().any(|c| c.is_ascii_digit())); + assert!(payload.contains(" (")); + assert!(payload.ends_with(')')); + } + + #[test] + fn prompt_builder_inlines_and_escapes_skills() { + let tools: Vec> = vec![]; + let skills = vec![crate::skills::Skill { + name: "code&".into(), + description: "Review \"unsafe\" and 'risky' bits".into(), + version: "1.0.0".into(), + author: None, + tags: vec![], + tools: vec![crate::skills::SkillTool { + name: "run\"linter\"".into(), + description: "Run & report".into(), + kind: "shell&exec".into(), + command: "cargo clippy".into(), + args: std::collections::HashMap::new(), + }], + prompts: vec!["Use and & keep output \"safe\"".into()], + location: None, + }]; + let ctx = PromptContext { + workspace_dir: Path::new("/tmp/workspace"), + model_name: "test-model", + tools: &tools, + skills: &skills, + identity_config: None, + dispatcher_instructions: "", + }; + + let prompt = SystemPromptBuilder::with_defaults().build(&ctx).unwrap(); + + assert!(prompt.contains("")); + assert!(prompt.contains("code<review>&")); + assert!(prompt.contains( + "Review "unsafe" and 'risky' bits" + )); + assert!(prompt.contains("run"linter"")); + 
assert!(prompt.contains("Run <lint> & report")); + assert!(prompt.contains("shell&exec")); + assert!(prompt.contains( + "Use <tool_call> and & keep output "safe"" + )); + } } diff --git a/src/agent/tests.rs b/src/agent/tests.rs index fd73eb1..356987e 100644 --- a/src/agent/tests.rs +++ b/src/agent/tests.rs @@ -624,7 +624,7 @@ async fn history_trims_after_max_messages() { // ═══════════════════════════════════════════════════════════════════════════ #[tokio::test] -async fn auto_save_stores_messages_in_memory() { +async fn auto_save_stores_only_user_messages_in_memory() { let (mem, _tmp) = make_sqlite_memory(); let provider = Box::new(ScriptedProvider::new(vec![text_response( "I remember everything", @@ -639,11 +639,25 @@ async fn auto_save_stores_messages_in_memory() { let _ = agent.turn("Remember this fact").await.unwrap(); - // Both user message and assistant response should be saved + // Auto-save only persists user-stated input, never assistant-generated summaries. let count = mem.count().await.unwrap(); + assert_eq!( + count, 1, + "Expected exactly 1 user memory entry, got {count}" + ); + + let stored = mem.get("user_msg").await.unwrap(); + assert!(stored.is_some(), "Expected user_msg key to be present"); + assert_eq!( + stored.unwrap().content, + "Remember this fact", + "Stored memory should match the original user message" + ); + + let assistant = mem.get("assistant_resp").await.unwrap(); assert!( - count >= 2, - "Expected at least 2 memory entries, got {count}" + assistant.is_none(), + "assistant_resp should not be auto-saved anymore" ); } diff --git a/src/auth/mod.rs b/src/auth/mod.rs index a49e702..1d88361 100644 --- a/src/auth/mod.rs +++ b/src/auth/mod.rs @@ -121,12 +121,12 @@ impl AuthService { return Ok(None); }; - let token = match profile.kind { + let credential = match profile.kind { AuthProfileKind::Token => profile.token, AuthProfileKind::OAuth => profile.token_set.map(|t| t.access_token), }; - Ok(token.filter(|t| !t.trim().is_empty())) + 
Ok(credential.filter(|t| !t.trim().is_empty())) } pub async fn get_valid_openai_access_token( diff --git a/src/auth/profiles.rs b/src/auth/profiles.rs index 48ba6ce..39d39ee 100644 --- a/src/auth/profiles.rs +++ b/src/auth/profiles.rs @@ -626,8 +626,8 @@ mod tests { assert!(!token_set.is_expiring_within(Duration::from_secs(1))); } - #[test] - fn store_roundtrip_with_encryption() { + #[tokio::test] + async fn store_roundtrip_with_encryption() { let tmp = TempDir::new().unwrap(); let store = AuthProfilesStore::new(tmp.path(), true); @@ -661,14 +661,14 @@ mod tests { Some("refresh-123") ); - let raw = fs::read_to_string(store.path()).unwrap(); + let raw = tokio::fs::read_to_string(store.path()).await.unwrap(); assert!(raw.contains("enc2:")); assert!(!raw.contains("refresh-123")); assert!(!raw.contains("access-123")); } - #[test] - fn atomic_write_replaces_file() { + #[tokio::test] + async fn atomic_write_replaces_file() { let tmp = TempDir::new().unwrap(); let store = AuthProfilesStore::new(tmp.path(), false); @@ -678,7 +678,7 @@ mod tests { let path = store.path().to_path_buf(); assert!(path.exists()); - let contents = fs::read_to_string(path).unwrap(); + let contents = tokio::fs::read_to_string(path).await.unwrap(); assert!(contents.contains("\"schema_version\": 1")); } } diff --git a/src/channels/cli.rs b/src/channels/cli.rs index ae49548..11c09eb 100644 --- a/src/channels/cli.rs +++ b/src/channels/cli.rs @@ -47,6 +47,7 @@ impl Channel for CliChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; if tx.send(msg).await.is_err() { @@ -74,6 +75,7 @@ mod tests { content: "hello".into(), recipient: "user".into(), subject: None, + thread_ts: None, }) .await; assert!(result.is_ok()); @@ -87,6 +89,7 @@ mod tests { content: String::new(), recipient: String::new(), subject: None, + thread_ts: None, }) .await; assert!(result.is_ok()); @@ -107,6 +110,7 @@ mod tests { content: "hello".into(), channel: "cli".into(), timestamp: 
1_234_567_890, + thread_ts: None, }; assert_eq!(msg.id, "test-id"); assert_eq!(msg.sender, "user"); @@ -125,6 +129,7 @@ mod tests { content: "c".into(), channel: "ch".into(), timestamp: 0, + thread_ts: None, }; let cloned = msg.clone(); assert_eq!(cloned.id, msg.id); diff --git a/src/channels/dingtalk.rs b/src/channels/dingtalk.rs index ed9c9aa..44fd49c 100644 --- a/src/channels/dingtalk.rs +++ b/src/channels/dingtalk.rs @@ -169,7 +169,7 @@ impl Channel for DingTalkChannel { _ => continue, }; - let frame: serde_json::Value = match serde_json::from_str(&msg) { + let frame: serde_json::Value = match serde_json::from_str(msg.as_ref()) { Ok(v) => v, Err(_) => continue, }; @@ -195,7 +195,7 @@ impl Channel for DingTalkChannel { "data": "", }); - if let Err(e) = write.send(Message::Text(pong.to_string())).await { + if let Err(e) = write.send(Message::Text(pong.to_string().into())).await { tracing::warn!("DingTalk: failed to send pong: {e}"); break; } @@ -262,7 +262,7 @@ impl Channel for DingTalkChannel { "message": "OK", "data": "", }); - let _ = write.send(Message::Text(ack.to_string())).await; + let _ = write.send(Message::Text(ack.to_string().into())).await; let channel_msg = ChannelMessage { id: Uuid::new_v4().to_string(), @@ -274,6 +274,7 @@ impl Channel for DingTalkChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; if tx.send(channel_msg).await.is_err() { diff --git a/src/channels/discord.rs b/src/channels/discord.rs index d7a4d20..bcb447d 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -3,6 +3,7 @@ use async_trait::async_trait; use futures_util::{SinkExt, StreamExt}; use parking_lot::Mutex; use serde_json::json; +use std::collections::HashMap; use tokio_tungstenite::tungstenite::Message; use uuid::Uuid; @@ -13,7 +14,7 @@ pub struct DiscordChannel { allowed_users: Vec, listen_to_bots: bool, mention_only: bool, - typing_handle: Mutex>>, + typing_handles: Mutex>>, } impl DiscordChannel { @@ 
-30,7 +31,7 @@ impl DiscordChannel { allowed_users, listen_to_bots, mention_only, - typing_handle: Mutex::new(None), + typing_handles: Mutex::new(HashMap::new()), } } @@ -272,7 +273,9 @@ impl Channel for DiscordChannel { } } }); - write.send(Message::Text(identify.to_string())).await?; + write + .send(Message::Text(identify.to_string().into())) + .await?; tracing::info!("Discord: connected and identified"); @@ -301,7 +304,7 @@ impl Channel for DiscordChannel { _ = hb_rx.recv() => { let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; let hb = json!({"op": 1, "d": d}); - if write.send(Message::Text(hb.to_string())).await.is_err() { + if write.send(Message::Text(hb.to_string().into())).await.is_err() { break; } } @@ -312,7 +315,7 @@ impl Channel for DiscordChannel { _ => continue, }; - let event: serde_json::Value = match serde_json::from_str(&msg) { + let event: serde_json::Value = match serde_json::from_str(msg.as_ref()) { Ok(e) => e, Err(_) => continue, }; @@ -329,7 +332,7 @@ impl Channel for DiscordChannel { 1 => { let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; let hb = json!({"op": 1, "d": d}); - if write.send(Message::Text(hb.to_string())).await.is_err() { + if write.send(Message::Text(hb.to_string().into())).await.is_err() { break; } continue; @@ -413,6 +416,7 @@ impl Channel for DiscordChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; if tx.send(channel_msg).await.is_err() { @@ -454,15 +458,15 @@ impl Channel for DiscordChannel { } }); - let mut guard = self.typing_handle.lock(); - *guard = Some(handle); + let mut guard = self.typing_handles.lock(); + guard.insert(recipient.to_string(), handle); Ok(()) } - async fn stop_typing(&self, _recipient: &str) -> anyhow::Result<()> { - let mut guard = self.typing_handle.lock(); - if let Some(handle) = guard.take() { + async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> { + let mut guard = self.typing_handles.lock(); 
+ if let Some(handle) = guard.remove(recipient) { handle.abort(); } Ok(()) @@ -751,18 +755,18 @@ mod tests { } #[test] - fn typing_handle_starts_as_none() { + fn typing_handles_start_empty() { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); - let guard = ch.typing_handle.lock(); - assert!(guard.is_none()); + let guard = ch.typing_handles.lock(); + assert!(guard.is_empty()); } #[tokio::test] async fn start_typing_sets_handle() { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("123456").await; - let guard = ch.typing_handle.lock(); - assert!(guard.is_some()); + let guard = ch.typing_handles.lock(); + assert!(guard.contains_key("123456")); } #[tokio::test] @@ -770,8 +774,8 @@ mod tests { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("123456").await; let _ = ch.stop_typing("123456").await; - let guard = ch.typing_handle.lock(); - assert!(guard.is_none()); + let guard = ch.typing_handles.lock(); + assert!(!guard.contains_key("123456")); } #[tokio::test] @@ -782,12 +786,21 @@ mod tests { } #[tokio::test] - async fn start_typing_replaces_existing_task() { + async fn concurrent_typing_handles_are_independent() { let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("111").await; let _ = ch.start_typing("222").await; - let guard = ch.typing_handle.lock(); - assert!(guard.is_some()); + { + let guard = ch.typing_handles.lock(); + assert_eq!(guard.len(), 2); + assert!(guard.contains_key("111")); + assert!(guard.contains_key("222")); + } + // Stopping one does not affect the other + let _ = ch.stop_typing("111").await; + let guard = ch.typing_handles.lock(); + assert_eq!(guard.len(), 1); + assert!(guard.contains_key("222")); } // ── Message ID edge cases ───────────────────────────────────── @@ -840,4 +853,113 @@ mod tests { // Should have UUID dashes assert!(id.contains('-')); } + + // 
───────────────────────────────────────────────────────────────────── + // TG6: Channel platform limit edge cases for Discord (2000 char limit) + // Prevents: Pattern 6 — issues #574, #499 + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn split_message_code_block_at_boundary() { + // Code block that spans the split boundary + let mut msg = String::new(); + msg.push_str("```rust\n"); + msg.push_str(&"x".repeat(1990)); + msg.push_str("\n```\nMore text after code block"); + let parts = split_message_for_discord(&msg); + assert!( + parts.len() >= 2, + "code block spanning boundary should split" + ); + for part in &parts { + assert!( + part.len() <= DISCORD_MAX_MESSAGE_LENGTH, + "each part must be <= {DISCORD_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + } + + #[test] + fn split_message_single_long_word_exceeds_limit() { + // A single word longer than 2000 chars must be hard-split + let long_word = "a".repeat(2500); + let parts = split_message_for_discord(&long_word); + assert!(parts.len() >= 2, "word exceeding limit must be split"); + for part in &parts { + assert!( + part.len() <= DISCORD_MAX_MESSAGE_LENGTH, + "hard-split part must be <= {DISCORD_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + // Reassembled content should match original + let reassembled: String = parts.join(""); + assert_eq!(reassembled, long_word); + } + + #[test] + fn split_message_exactly_at_limit_no_split() { + let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH); + let parts = split_message_for_discord(&msg); + assert_eq!(parts.len(), 1, "message exactly at limit should not split"); + assert_eq!(parts[0].len(), DISCORD_MAX_MESSAGE_LENGTH); + } + + #[test] + fn split_message_one_over_limit_splits() { + let msg = "a".repeat(DISCORD_MAX_MESSAGE_LENGTH + 1); + let parts = split_message_for_discord(&msg); + assert!(parts.len() >= 2, "message 1 char over limit must split"); + } + + #[test] + fn split_message_many_short_lines() { + // Many short 
lines should be batched into chunks under the limit + let msg: String = (0..500).map(|i| format!("line {i}\n")).collect(); + let parts = split_message_for_discord(&msg); + for part in &parts { + assert!( + part.len() <= DISCORD_MAX_MESSAGE_LENGTH, + "short-line batch must be <= limit" + ); + } + // All content should be preserved + let reassembled: String = parts.join(""); + assert_eq!(reassembled.trim(), msg.trim()); + } + + #[test] + fn split_message_only_whitespace() { + let msg = " \n\n\t "; + let parts = split_message_for_discord(msg); + // Should handle gracefully without panic + assert!(parts.len() <= 1); + } + + #[test] + fn split_message_emoji_at_boundary() { + // Emoji are multi-byte; ensure we don't split mid-emoji + let mut msg = "a".repeat(1998); + msg.push_str("🎉🎊"); // 2 emoji at the boundary (2000 chars total) + let parts = split_message_for_discord(&msg); + for part in &parts { + // The function splits on character count, not byte count + assert!( + part.chars().count() <= DISCORD_MAX_MESSAGE_LENGTH, + "emoji boundary split must respect limit" + ); + } + } + + #[test] + fn split_message_consecutive_newlines_at_boundary() { + let mut msg = "a".repeat(1995); + msg.push_str("\n\n\n\n\n"); + msg.push_str(&"b".repeat(100)); + let parts = split_message_for_discord(&msg); + for part in &parts { + assert!(part.len() <= DISCORD_MAX_MESSAGE_LENGTH); + } + } } diff --git a/src/channels/email_channel.rs b/src/channels/email_channel.rs index 410e9dd..0b8b376 100644 --- a/src/channels/email_channel.rs +++ b/src/channels/email_channel.rs @@ -20,6 +20,7 @@ use lettre::{Message, SmtpTransport, Transport}; use mail_parser::{MessageParser, MimeHeaders}; use rustls::{ClientConfig, RootCertStore}; use rustls_pki_types::DnsName; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::collections::HashSet; use std::sync::Arc; @@ -35,7 +36,7 @@ use uuid::Uuid; use super::traits::{Channel, ChannelMessage, SendMessage}; /// Email channel configuration 
-#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct EmailConfig { /// IMAP server hostname pub imap_host: String, @@ -153,7 +154,14 @@ impl EmailChannel { _ => {} } } - result.split_whitespace().collect::>().join(" ") + let mut normalized = String::with_capacity(result.len()); + for word in result.split_whitespace() { + if !normalized.is_empty() { + normalized.push(' '); + } + normalized.push_str(word); + } + normalized } /// Extract the sender address from a parsed email @@ -442,6 +450,7 @@ impl EmailChannel { content: email.content, channel: "email".to_string(), timestamp: email.timestamp, + thread_ts: None, }; if tx.send(msg).await.is_err() { diff --git a/src/channels/imessage.rs b/src/channels/imessage.rs index 9675d15..4e51786 100644 --- a/src/channels/imessage.rs +++ b/src/channels/imessage.rs @@ -231,6 +231,7 @@ end tell"# .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; if tx.send(msg).await.is_err() { diff --git a/src/channels/irc.rs b/src/channels/irc.rs index 8bdd633..f942692 100644 --- a/src/channels/irc.rs +++ b/src/channels/irc.rs @@ -163,12 +163,17 @@ fn split_message(message: &str, max_bytes: usize) -> Vec { // Guard against max_bytes == 0 to prevent infinite loop if max_bytes == 0 { - let full: String = message + let mut full = String::new(); + for l in message .lines() .map(|l| l.trim_end_matches('\r')) .filter(|l| !l.is_empty()) - .collect::>() - .join(" "); + { + if !full.is_empty() { + full.push(' '); + } + full.push_str(l); + } if full.is_empty() { chunks.push(String::new()); } else { @@ -455,6 +460,7 @@ impl Channel for IrcChannel { "AUTHENTICATE" => { // Server sends "AUTHENTICATE +" to request credentials if sasl_pending && msg.params.first().is_some_and(|p| p == "+") { + // sasl_password is loaded from runtime config, not hard-coded if let Some(password) = self.sasl_password.as_deref() { let encoded = encode_sasl_plain(¤t_nick, 
password); let mut guard = self.writer.lock().await; @@ -573,6 +579,7 @@ impl Channel for IrcChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; if tx.send(channel_msg).await.is_err() { diff --git a/src/channels/lark.rs b/src/channels/lark.rs index e071a0c..c899097 100644 --- a/src/channels/lark.rs +++ b/src/channels/lark.rs @@ -127,6 +127,12 @@ struct LarkMessage { /// If no binary frame (pong or event) is received within this window, reconnect. const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300); +/// Returns true when the WebSocket frame indicates live traffic that should +/// refresh the heartbeat watchdog. +fn should_refresh_last_recv(msg: &WsMsg) -> bool { + matches!(msg, WsMsg::Binary(_) | WsMsg::Ping(_) | WsMsg::Pong(_)) +} + /// Lark/Feishu channel. /// /// Supports two receive modes (configured via `receive_mode` in config): @@ -282,7 +288,7 @@ impl LarkChannel { payload: None, }; if write - .send(WsMsg::Binary(initial_ping.encode_to_vec())) + .send(WsMsg::Binary(initial_ping.encode_to_vec().into())) .await .is_err() { @@ -303,7 +309,7 @@ impl LarkChannel { headers: vec![PbHeader { key: "type".into(), value: "ping".into() }], payload: None, }; - if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() { + if write.send(WsMsg::Binary(ping.encode_to_vec().into())).await.is_err() { tracing::warn!("Lark: ping failed, reconnecting"); break; } @@ -321,11 +327,20 @@ impl LarkChannel { msg = read.next() => { let raw = match msg { - Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b } - Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; } - Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; } + Some(Ok(ws_msg)) => { + if should_refresh_last_recv(&ws_msg) { + last_recv = Instant::now(); + } + match ws_msg { + WsMsg::Binary(b) => b, + WsMsg::Ping(d) => { let _ = write.send(WsMsg::Pong(d)).await; continue; } + 
WsMsg::Pong(_) => continue, + WsMsg::Close(_) => { tracing::info!("Lark: WS closed — reconnecting"); break; } + _ => continue, + } + } + None => { tracing::info!("Lark: WS closed — reconnecting"); break; } Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; } - _ => continue, }; let frame = match PbFrame::decode(&raw[..]) { @@ -363,7 +378,7 @@ impl LarkChannel { let mut ack = frame.clone(); ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec()); ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() }); - let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await; + let _ = write.send(WsMsg::Binary(ack.encode_to_vec().into())).await; } // Fragment reassembly @@ -459,6 +474,7 @@ impl LarkChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; tracing::debug!("Lark WS: message in {}", lark_msg.chat_id); @@ -620,6 +636,7 @@ impl LarkChannel { content: text, channel: "lark".to_string(), timestamp, + thread_ts: None, }); messages @@ -898,6 +915,21 @@ mod tests { assert_eq!(ch.name(), "lark"); } + #[test] + fn lark_ws_activity_refreshes_heartbeat_watchdog() { + assert!(should_refresh_last_recv(&WsMsg::Binary( + vec![1, 2, 3].into() + ))); + assert!(should_refresh_last_recv(&WsMsg::Ping(vec![9, 9].into()))); + assert!(should_refresh_last_recv(&WsMsg::Pong(vec![8, 8].into()))); + } + + #[test] + fn lark_ws_non_activity_frames_do_not_refresh_heartbeat_watchdog() { + assert!(!should_refresh_last_recv(&WsMsg::Text("hello".into()))); + assert!(!should_refresh_last_recv(&WsMsg::Close(None))); + } + #[test] fn lark_user_allowed_exact() { let ch = make_channel(); diff --git a/src/channels/linq.rs b/src/channels/linq.rs new file mode 100644 index 0000000..123322f --- /dev/null +++ b/src/channels/linq.rs @@ -0,0 +1,793 @@ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use async_trait::async_trait; +use uuid::Uuid; + +/// Linq channel — uses the Linq Partner V3 API for 
iMessage, RCS, and SMS. +/// +/// This channel operates in webhook mode (push-based) rather than polling. +/// Messages are received via the gateway's `/linq` webhook endpoint. +/// The `listen` method here is a keepalive placeholder; actual message handling +/// happens in the gateway when Linq sends webhook events. +pub struct LinqChannel { + api_token: String, + from_phone: String, + allowed_senders: Vec, + client: reqwest::Client, +} + +const LINQ_API_BASE: &str = "https://api.linqapp.com/api/partner/v3"; + +impl LinqChannel { + pub fn new(api_token: String, from_phone: String, allowed_senders: Vec) -> Self { + Self { + api_token, + from_phone, + allowed_senders, + client: reqwest::Client::new(), + } + } + + /// Check if a sender phone number is allowed (E.164 format: +1234567890) + fn is_sender_allowed(&self, phone: &str) -> bool { + self.allowed_senders.iter().any(|n| n == "*" || n == phone) + } + + /// Get the bot's phone number + pub fn phone_number(&self) -> &str { + &self.from_phone + } + + fn media_part_to_image_marker(part: &serde_json::Value) -> Option { + let source = part + .get("url") + .or_else(|| part.get("value")) + .and_then(|value| value.as_str()) + .map(str::trim) + .filter(|value| !value.is_empty())?; + + let mime_type = part + .get("mime_type") + .and_then(|value| value.as_str()) + .map(str::trim) + .unwrap_or_default() + .to_ascii_lowercase(); + + if !mime_type.starts_with("image/") { + return None; + } + + Some(format!("[IMAGE:{source}]")) + } + + /// Parse an incoming webhook payload from Linq and extract messages. 
+ /// + /// Linq webhook envelope: + /// ```json + /// { + /// "api_version": "v3", + /// "event_type": "message.received", + /// "event_id": "...", + /// "created_at": "...", + /// "trace_id": "...", + /// "data": { + /// "chat_id": "...", + /// "from": "+1...", + /// "recipient_phone": "+1...", + /// "is_from_me": false, + /// "service": "iMessage", + /// "message": { + /// "id": "...", + /// "parts": [{ "type": "text", "value": "..." }] + /// } + /// } + /// } + /// ``` + pub fn parse_webhook_payload(&self, payload: &serde_json::Value) -> Vec { + let mut messages = Vec::new(); + + // Only handle message.received events + let event_type = payload + .get("event_type") + .and_then(|e| e.as_str()) + .unwrap_or(""); + if event_type != "message.received" { + tracing::debug!("Linq: skipping non-message event: {event_type}"); + return messages; + } + + let Some(data) = payload.get("data") else { + return messages; + }; + + // Skip messages sent by the bot itself + if data + .get("is_from_me") + .and_then(|v| v.as_bool()) + .unwrap_or(false) + { + tracing::debug!("Linq: skipping is_from_me message"); + return messages; + } + + // Get sender phone number + let Some(from) = data.get("from").and_then(|f| f.as_str()) else { + return messages; + }; + + // Normalize to E.164 format + let normalized_from = if from.starts_with('+') { + from.to_string() + } else { + format!("+{from}") + }; + + // Check allowlist + if !self.is_sender_allowed(&normalized_from) { + tracing::warn!( + "Linq: ignoring message from unauthorized sender: {normalized_from}. \ + Add to channels.linq.allowed_senders in config.toml, \ + or run `zeroclaw onboard --channels-only` to configure interactively." 
+ ); + return messages; + } + + // Get chat_id for reply routing + let chat_id = data + .get("chat_id") + .and_then(|c| c.as_str()) + .unwrap_or("") + .to_string(); + + // Extract text from message parts + let Some(message) = data.get("message") else { + return messages; + }; + + let Some(parts) = message.get("parts").and_then(|p| p.as_array()) else { + return messages; + }; + + let content_parts: Vec = parts + .iter() + .filter_map(|part| { + let part_type = part.get("type").and_then(|t| t.as_str())?; + match part_type { + "text" => part + .get("value") + .and_then(|v| v.as_str()) + .map(ToString::to_string), + "media" | "image" => { + if let Some(marker) = Self::media_part_to_image_marker(part) { + Some(marker) + } else { + tracing::debug!("Linq: skipping unsupported {part_type} part"); + None + } + } + _ => { + tracing::debug!("Linq: skipping {part_type} part"); + None + } + } + }) + .collect(); + + if content_parts.is_empty() { + return messages; + } + + let content = content_parts.join("\n").trim().to_string(); + + if content.is_empty() { + return messages; + } + + // Get timestamp from created_at or use current time + let timestamp = payload + .get("created_at") + .and_then(|t| t.as_str()) + .and_then(|t| { + chrono::DateTime::parse_from_rfc3339(t) + .ok() + .map(|dt| dt.timestamp().cast_unsigned()) + }) + .unwrap_or_else(|| { + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs() + }); + + // Use chat_id as reply_target so replies go to the right conversation + let reply_target = if chat_id.is_empty() { + normalized_from.clone() + } else { + chat_id + }; + + messages.push(ChannelMessage { + id: Uuid::new_v4().to_string(), + reply_target, + sender: normalized_from, + content, + channel: "linq".to_string(), + timestamp, + thread_ts: None, + }); + + messages + } +} + +#[async_trait] +impl Channel for LinqChannel { + fn name(&self) -> &str { + "linq" + } + + async fn send(&self, message: &SendMessage) -> 
anyhow::Result<()> { + // If reply_target looks like a chat_id, send to existing chat. + // Otherwise create a new chat with the recipient phone number. + let recipient = &message.recipient; + + let body = serde_json::json!({ + "message": { + "parts": [{ + "type": "text", + "value": message.content + }] + } + }); + + // Try sending to existing chat (recipient is chat_id) + let url = format!("{LINQ_API_BASE}/chats/{recipient}/messages"); + + let resp = self + .client + .post(&url) + .bearer_auth(&self.api_token) + .header("Content-Type", "application/json") + .json(&body) + .send() + .await?; + + if resp.status().is_success() { + return Ok(()); + } + + // If the chat_id-based send failed with 404, try creating a new chat + if resp.status() == reqwest::StatusCode::NOT_FOUND { + let new_chat_body = serde_json::json!({ + "from": self.from_phone, + "to": [recipient], + "message": { + "parts": [{ + "type": "text", + "value": message.content + }] + } + }); + + let create_resp = self + .client + .post(format!("{LINQ_API_BASE}/chats")) + .bearer_auth(&self.api_token) + .header("Content-Type", "application/json") + .json(&new_chat_body) + .send() + .await?; + + if !create_resp.status().is_success() { + let status = create_resp.status(); + let error_body = create_resp.text().await.unwrap_or_default(); + tracing::error!("Linq create chat failed: {status} — {error_body}"); + anyhow::bail!("Linq API error: {status}"); + } + + return Ok(()); + } + + let status = resp.status(); + let error_body = resp.text().await.unwrap_or_default(); + tracing::error!("Linq send failed: {status} — {error_body}"); + anyhow::bail!("Linq API error: {status}"); + } + + async fn listen(&self, _tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + // Linq uses webhooks (push-based), not polling. + // Messages are received via the gateway's /linq endpoint. + tracing::info!( + "Linq channel active (webhook mode). \ + Configure Linq webhook to POST to your gateway's /linq endpoint." 
+ ); + + // Keep the task alive — it will be cancelled when the channel shuts down + loop { + tokio::time::sleep(std::time::Duration::from_secs(3600)).await; + } + } + + async fn health_check(&self) -> bool { + // Check if we can reach the Linq API + let url = format!("{LINQ_API_BASE}/phonenumbers"); + + self.client + .get(&url) + .bearer_auth(&self.api_token) + .send() + .await + .map(|r| r.status().is_success()) + .unwrap_or(false) + } + + async fn start_typing(&self, recipient: &str) -> anyhow::Result<()> { + let url = format!("{LINQ_API_BASE}/chats/{recipient}/typing"); + + let resp = self + .client + .post(&url) + .bearer_auth(&self.api_token) + .send() + .await?; + + if !resp.status().is_success() { + tracing::debug!("Linq start_typing failed: {}", resp.status()); + } + + Ok(()) + } + + async fn stop_typing(&self, recipient: &str) -> anyhow::Result<()> { + let url = format!("{LINQ_API_BASE}/chats/{recipient}/typing"); + + let resp = self + .client + .delete(&url) + .bearer_auth(&self.api_token) + .send() + .await?; + + if !resp.status().is_success() { + tracing::debug!("Linq stop_typing failed: {}", resp.status()); + } + + Ok(()) + } +} + +/// Verify a Linq webhook signature. +/// +/// Linq signs webhooks with HMAC-SHA256 over `"{timestamp}.{body}"`. +/// The signature is sent in `X-Webhook-Signature` (hex-encoded) and the +/// timestamp in `X-Webhook-Timestamp`. Reject timestamps older than 300s. 
+pub fn verify_linq_signature(secret: &str, body: &str, timestamp: &str, signature: &str) -> bool { + use hmac::{Hmac, Mac}; + use sha2::Sha256; + + // Reject stale timestamps (>300s old) + if let Ok(ts) = timestamp.parse::() { + let now = chrono::Utc::now().timestamp(); + if (now - ts).unsigned_abs() > 300 { + tracing::warn!("Linq: rejecting stale webhook timestamp ({ts}, now={now})"); + return false; + } + } else { + tracing::warn!("Linq: invalid webhook timestamp: {timestamp}"); + return false; + } + + // Compute HMAC-SHA256 over "{timestamp}.{body}" + let message = format!("{timestamp}.{body}"); + let Ok(mut mac) = Hmac::::new_from_slice(secret.as_bytes()) else { + return false; + }; + mac.update(message.as_bytes()); + let signature_hex = signature + .trim() + .strip_prefix("sha256=") + .unwrap_or(signature); + let Ok(provided) = hex::decode(signature_hex.trim()) else { + tracing::warn!("Linq: invalid webhook signature format"); + return false; + }; + + // Constant-time comparison via HMAC verify. 
+ mac.verify_slice(&provided).is_ok() +} + +#[cfg(test)] +mod tests { + use super::*; + + fn make_channel() -> LinqChannel { + LinqChannel::new( + "test-token".into(), + "+15551234567".into(), + vec!["+1234567890".into()], + ) + } + + #[test] + fn linq_channel_name() { + let ch = make_channel(); + assert_eq!(ch.name(), "linq"); + } + + #[test] + fn linq_sender_allowed_exact() { + let ch = make_channel(); + assert!(ch.is_sender_allowed("+1234567890")); + assert!(!ch.is_sender_allowed("+9876543210")); + } + + #[test] + fn linq_sender_allowed_wildcard() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + assert!(ch.is_sender_allowed("+1234567890")); + assert!(ch.is_sender_allowed("+9999999999")); + } + + #[test] + fn linq_sender_allowed_empty() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec![]); + assert!(!ch.is_sender_allowed("+1234567890")); + } + + #[test] + fn linq_parse_valid_text_message() { + let ch = make_channel(); + let payload = serde_json::json!({ + "api_version": "v3", + "event_type": "message.received", + "event_id": "evt-123", + "created_at": "2025-01-15T12:00:00Z", + "trace_id": "trace-456", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "recipient_phone": "+15551234567", + "is_from_me": false, + "service": "iMessage", + "message": { + "id": "msg-abc", + "parts": [{ + "type": "text", + "value": "Hello ZeroClaw!" 
+ }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].sender, "+1234567890"); + assert_eq!(msgs[0].content, "Hello ZeroClaw!"); + assert_eq!(msgs[0].channel, "linq"); + assert_eq!(msgs[0].reply_target, "chat-789"); + } + + #[test] + fn linq_parse_skip_is_from_me() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "is_from_me": true, + "message": { + "id": "msg-abc", + "parts": [{ "type": "text", "value": "My own message" }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty(), "is_from_me messages should be skipped"); + } + + #[test] + fn linq_parse_skip_non_message_event() { + let ch = make_channel(); + let payload = serde_json::json!({ + "event_type": "message.delivered", + "data": { + "chat_id": "chat-789", + "message_id": "msg-abc" + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty(), "Non-message events should be skipped"); + } + + #[test] + fn linq_parse_unauthorized_sender() { + let ch = make_channel(); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+9999999999", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [{ "type": "text", "value": "Spam" }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty(), "Unauthorized senders should be filtered"); + } + + #[test] + fn linq_parse_empty_payload() { + let ch = make_channel(); + let payload = serde_json::json!({}); + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn linq_parse_media_only_translated_to_image_marker() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + 
"event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [{ + "type": "media", + "url": "https://example.com/image.jpg", + "mime_type": "image/jpeg" + }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "[IMAGE:https://example.com/image.jpg]"); + } + + #[test] + fn linq_parse_media_non_image_still_skipped() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [{ + "type": "media", + "url": "https://example.com/sound.mp3", + "mime_type": "audio/mpeg" + }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty(), "Non-image media should still be skipped"); + } + + #[test] + fn linq_parse_multiple_text_parts() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [ + { "type": "text", "value": "First part" }, + { "type": "text", "value": "Second part" } + ] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].content, "First part\nSecond part"); + } + + /// Fixture secret used exclusively in signature-verification unit tests (not a real credential). 
+ const TEST_WEBHOOK_SECRET: &str = "test_webhook_secret"; + + #[test] + fn linq_signature_verification_valid() { + let secret = TEST_WEBHOOK_SECRET; + let body = r#"{"event_type":"message.received"}"#; + let now = chrono::Utc::now().timestamp().to_string(); + + // Compute expected signature + use hmac::{Hmac, Mac}; + use sha2::Sha256; + let message = format!("{now}.{body}"); + let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); + mac.update(message.as_bytes()); + let signature = hex::encode(mac.finalize().into_bytes()); + + assert!(verify_linq_signature(secret, body, &now, &signature)); + } + + #[test] + fn linq_signature_verification_invalid() { + let secret = TEST_WEBHOOK_SECRET; + let body = r#"{"event_type":"message.received"}"#; + let now = chrono::Utc::now().timestamp().to_string(); + + assert!(!verify_linq_signature( + secret, + body, + &now, + "deadbeefdeadbeefdeadbeef" + )); + } + + #[test] + fn linq_signature_verification_stale_timestamp() { + let secret = TEST_WEBHOOK_SECRET; + let body = r#"{"event_type":"message.received"}"#; + // 10 minutes ago — stale + let stale_ts = (chrono::Utc::now().timestamp() - 600).to_string(); + + // Even with correct signature, stale timestamp should fail + use hmac::{Hmac, Mac}; + use sha2::Sha256; + let message = format!("{stale_ts}.{body}"); + let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); + mac.update(message.as_bytes()); + let signature = hex::encode(mac.finalize().into_bytes()); + + assert!( + !verify_linq_signature(secret, body, &stale_ts, &signature), + "Stale timestamps (>300s) should be rejected" + ); + } + + #[test] + fn linq_signature_verification_accepts_sha256_prefix() { + let secret = TEST_WEBHOOK_SECRET; + let body = r#"{"event_type":"message.received"}"#; + let now = chrono::Utc::now().timestamp().to_string(); + + use hmac::{Hmac, Mac}; + use sha2::Sha256; + let message = format!("{now}.{body}"); + let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); + 
mac.update(message.as_bytes()); + let signature = format!("sha256={}", hex::encode(mac.finalize().into_bytes())); + + assert!(verify_linq_signature(secret, body, &now, &signature)); + } + + #[test] + fn linq_signature_verification_accepts_uppercase_hex() { + let secret = TEST_WEBHOOK_SECRET; + let body = r#"{"event_type":"message.received"}"#; + let now = chrono::Utc::now().timestamp().to_string(); + + use hmac::{Hmac, Mac}; + use sha2::Sha256; + let message = format!("{now}.{body}"); + let mut mac = Hmac::::new_from_slice(secret.as_bytes()).unwrap(); + mac.update(message.as_bytes()); + let signature = hex::encode(mac.finalize().into_bytes()).to_ascii_uppercase(); + + assert!(verify_linq_signature(secret, body, &now, &signature)); + } + + #[test] + fn linq_parse_normalizes_phone_with_plus() { + let ch = LinqChannel::new( + "tok".into(), + "+15551234567".into(), + vec!["+1234567890".into()], + ); + // API sends without +, normalize to + + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [{ "type": "text", "value": "Hi" }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].sender, "+1234567890"); + } + + #[test] + fn linq_parse_missing_data() { + let ch = make_channel(); + let payload = serde_json::json!({ + "event_type": "message.received" + }); + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + fn linq_parse_missing_message_parts() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc" + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty()); + } + + #[test] + 
fn linq_parse_empty_text_value() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "chat_id": "chat-789", + "from": "+1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [{ "type": "text", "value": "" }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert!(msgs.is_empty(), "Empty text should be skipped"); + } + + #[test] + fn linq_parse_fallback_reply_target_when_no_chat_id() { + let ch = LinqChannel::new("tok".into(), "+15551234567".into(), vec!["*".into()]); + let payload = serde_json::json!({ + "event_type": "message.received", + "data": { + "from": "+1234567890", + "is_from_me": false, + "message": { + "id": "msg-abc", + "parts": [{ "type": "text", "value": "Hi" }] + } + } + }); + + let msgs = ch.parse_webhook_payload(&payload); + assert_eq!(msgs.len(), 1); + // Falls back to sender phone number when no chat_id + assert_eq!(msgs[0].reply_target, "+1234567890"); + } + + #[test] + fn linq_phone_number_accessor() { + let ch = make_channel(); + assert_eq!(ch.phone_number(), "+15551234567"); + } +} diff --git a/src/channels/matrix.rs b/src/channels/matrix.rs index 0b063c5..9c18e3a 100644 --- a/src/channels/matrix.rs +++ b/src/channels/matrix.rs @@ -24,7 +24,7 @@ pub struct MatrixChannel { access_token: String, room_id: String, allowed_users: Vec, - session_user_id_hint: Option, + session_owner_hint: Option, session_device_id_hint: Option, resolved_room_id_cache: Arc>>, sdk_client: Arc>, @@ -108,7 +108,7 @@ impl MatrixChannel { access_token: String, room_id: String, allowed_users: Vec, - user_id_hint: Option, + owner_hint: Option, device_id_hint: Option, ) -> Self { let homeserver = homeserver.trim_end_matches('/').to_string(); @@ -125,7 +125,7 @@ impl MatrixChannel { access_token, room_id, allowed_users, - session_user_id_hint: Self::normalize_optional_field(user_id_hint), + session_owner_hint: 
Self::normalize_optional_field(owner_hint), session_device_id_hint: Self::normalize_optional_field(device_id_hint), resolved_room_id_cache: Arc::new(RwLock::new(None)), sdk_client: Arc::new(OnceCell::new()), @@ -245,7 +245,7 @@ impl MatrixChannel { let whoami = match identity { Ok(whoami) => Some(whoami), Err(error) => { - if self.session_user_id_hint.is_some() && self.session_device_id_hint.is_some() + if self.session_owner_hint.is_some() && self.session_device_id_hint.is_some() { tracing::warn!( "Matrix whoami failed; falling back to configured session hints for E2EE session restore: {error}" @@ -258,18 +258,18 @@ impl MatrixChannel { }; let resolved_user_id = if let Some(whoami) = whoami.as_ref() { - if let Some(hinted) = self.session_user_id_hint.as_ref() { + if let Some(hinted) = self.session_owner_hint.as_ref() { if hinted != &whoami.user_id { tracing::warn!( "Matrix configured user_id '{}' does not match whoami '{}'; using whoami.", - hinted, - whoami.user_id + crate::security::redact(hinted), + crate::security::redact(&whoami.user_id) ); } } whoami.user_id.clone() } else { - self.session_user_id_hint.clone().ok_or_else(|| { + self.session_owner_hint.clone().ok_or_else(|| { anyhow::anyhow!( "Matrix session restore requires user_id when whoami is unavailable" ) @@ -282,8 +282,8 @@ impl MatrixChannel { if whoami_device_id != hinted { tracing::warn!( "Matrix configured device_id '{}' does not match whoami '{}'; using whoami.", - hinted, - whoami_device_id + crate::security::redact(hinted), + crate::security::redact(whoami_device_id) ); } whoami_device_id.clone() @@ -513,7 +513,7 @@ impl Channel for MatrixChannel { let my_user_id: OwnedUserId = match self.get_my_user_id().await { Ok(user_id) => user_id.parse()?, Err(error) => { - if let Some(hinted) = self.session_user_id_hint.as_ref() { + if let Some(hinted) = self.session_owner_hint.as_ref() { tracing::warn!( "Matrix whoami failed while resolving listener user_id; using configured user_id hint: {error}" ); @@ 
-596,6 +596,7 @@ impl Channel for MatrixChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; let _ = tx.send(msg).await; @@ -714,7 +715,7 @@ mod tests { Some(" DEVICE123 ".to_string()), ); - assert_eq!(ch.session_user_id_hint.as_deref(), Some("@bot:matrix.org")); + assert_eq!(ch.session_owner_hint.as_deref(), Some("@bot:matrix.org")); assert_eq!(ch.session_device_id_hint.as_deref(), Some("DEVICE123")); } @@ -726,10 +727,10 @@ mod tests { "!r:m".to_string(), vec![], Some(" ".to_string()), - Some("".to_string()), + Some(String::new()), ); - assert!(ch.session_user_id_hint.is_none()); + assert!(ch.session_owner_hint.is_none()); assert!(ch.session_device_id_hint.is_none()); } diff --git a/src/channels/mattermost.rs b/src/channels/mattermost.rs index 95461de..55ecdbb 100644 --- a/src/channels/mattermost.rs +++ b/src/channels/mattermost.rs @@ -321,6 +321,7 @@ impl MattermostChannel { channel: "mattermost".to_string(), #[allow(clippy::cast_sign_loss)] timestamp: (create_at / 1000) as u64, + thread_ts: None, }) } } diff --git a/src/channels/mod.rs b/src/channels/mod.rs index 0fff1ec..3d48c52 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -1,3 +1,19 @@ +//! Channel subsystem for messaging platform integrations. +//! +//! This module provides the multi-channel messaging infrastructure that connects +//! ZeroClaw to external platforms. Each channel implements the [`Channel`] trait +//! defined in [`traits`], which provides a uniform interface for sending messages, +//! listening for incoming messages, health checking, and typing indicators. +//! +//! Channels are instantiated by [`start_channels`] based on the runtime configuration. +//! The subsystem manages per-sender conversation history, concurrent message processing +//! with configurable parallelism, and exponential-backoff reconnection for resilience. +//! +//! # Extension +//! +//! 
To add a new channel, implement [`Channel`] in a new submodule and wire it into +//! [`start_channels`]. See `AGENTS.md` §7.2 for the full change playbook. + pub mod cli; pub mod dingtalk; pub mod discord; @@ -5,6 +21,8 @@ pub mod email_channel; pub mod imessage; pub mod irc; pub mod lark; +pub mod linq; +#[cfg(feature = "channel-matrix")] pub mod matrix; pub mod mattermost; pub mod qq; @@ -13,6 +31,10 @@ pub mod slack; pub mod telegram; pub mod traits; pub mod whatsapp; +#[cfg(feature = "whatsapp-web")] +pub mod whatsapp_storage; +#[cfg(feature = "whatsapp-web")] +pub mod whatsapp_web; pub use cli::CliChannel; pub use dingtalk::DingTalkChannel; @@ -21,6 +43,8 @@ pub use email_channel::EmailChannel; pub use imessage::IMessageChannel; pub use irc::IrcChannel; pub use lark::LarkChannel; +pub use linq::LinqChannel; +#[cfg(feature = "channel-matrix")] pub use matrix::MatrixChannel; pub use mattermost::MattermostChannel; pub use qq::QQChannel; @@ -29,6 +53,8 @@ pub use slack::SlackChannel; pub use telegram::TelegramChannel; pub use traits::{Channel, SendMessage}; pub use whatsapp::WhatsAppChannel; +#[cfg(feature = "whatsapp-web")] +pub use whatsapp_web::WhatsAppWebChannel; use crate::agent::loop_::{build_tool_instructions, run_tool_call_loop}; use crate::config::Config; @@ -46,33 +72,59 @@ use std::collections::HashMap; use std::fmt::Write; use std::path::{Path, PathBuf}; use std::process::Command; -use std::sync::{Arc, Mutex}; -use std::time::{Duration, Instant}; +use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; +use std::sync::{Arc, Mutex, OnceLock}; +use std::time::{Duration, Instant, SystemTime}; use tokio_util::sync::CancellationToken; /// Per-sender conversation history for channel messages. type ConversationHistoryMap = Arc>>>; /// Maximum history messages to keep per sender. const MAX_CHANNEL_HISTORY: usize = 50; +/// Minimum user-message length (in chars) for auto-save to memory. +/// Messages shorter than this (e.g. 
"ok", "thanks") are not stored, +/// reducing noise in memory recall. +const AUTOSAVE_MIN_MESSAGE_CHARS: usize = 20; /// Maximum characters per injected workspace file (matches `OpenClaw` default). const BOOTSTRAP_MAX_CHARS: usize = 20_000; const DEFAULT_CHANNEL_INITIAL_BACKOFF_SECS: u64 = 2; const DEFAULT_CHANNEL_MAX_BACKOFF_SECS: u64 = 60; -/// Timeout for processing a single channel message (LLM + tools). -/// 300s for on-device LLMs (Ollama) which are slower than cloud APIs. +const MIN_CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 30; +/// Default timeout for processing a single channel message (LLM + tools). +/// Used as fallback when not configured in channels_config.message_timeout_secs. const CHANNEL_MESSAGE_TIMEOUT_SECS: u64 = 300; +/// Cap timeout scaling so large max_tool_iterations values do not create unbounded waits. +const CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP: u64 = 4; const CHANNEL_PARALLELISM_PER_CHANNEL: usize = 4; const CHANNEL_MIN_IN_FLIGHT_MESSAGES: usize = 8; const CHANNEL_MAX_IN_FLIGHT_MESSAGES: usize = 64; const CHANNEL_TYPING_REFRESH_INTERVAL_SECS: u64 = 4; const MODEL_CACHE_FILE: &str = "models_cache.json"; const MODEL_CACHE_PREVIEW_LIMIT: usize = 10; +const MEMORY_CONTEXT_MAX_ENTRIES: usize = 4; +const MEMORY_CONTEXT_ENTRY_MAX_CHARS: usize = 800; +const MEMORY_CONTEXT_MAX_CHARS: usize = 4_000; +const CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES: usize = 12; +const CHANNEL_HISTORY_COMPACT_CONTENT_CHARS: usize = 600; type ProviderCacheMap = Arc>>>; type RouteSelectionMap = Arc>>; +fn effective_channel_message_timeout_secs(configured: u64) -> u64 { + configured.max(MIN_CHANNEL_MESSAGE_TIMEOUT_SECS) +} + +fn channel_message_timeout_budget_secs( + message_timeout_secs: u64, + max_tool_iterations: usize, +) -> u64 { + let iterations = max_tool_iterations.max(1) as u64; + let scale = iterations.min(CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP); + message_timeout_secs.saturating_mul(scale) +} + #[derive(Debug, Clone, PartialEq, Eq)] struct ChannelRouteSelection { provider: 
String, @@ -98,6 +150,33 @@ struct ModelCacheEntry { models: Vec, } +#[derive(Debug, Clone)] +struct ChannelRuntimeDefaults { + default_provider: String, + model: String, + temperature: f64, + api_key: Option, + api_url: Option, + reliability: crate::config::ReliabilityConfig, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +struct ConfigFileStamp { + modified: SystemTime, + len: u64, +} + +#[derive(Debug, Clone)] +struct RuntimeConfigState { + defaults: ChannelRuntimeDefaults, + last_applied_stamp: Option, +} + +fn runtime_config_store() -> &'static Mutex> { + static STORE: OnceLock>> = OnceLock::new(); + STORE.get_or_init(|| Mutex::new(HashMap::new())) +} + #[derive(Clone)] struct ChannelRuntimeContext { channels_by_name: Arc>>, @@ -120,6 +199,42 @@ struct ChannelRuntimeContext { reliability: Arc, provider_runtime_options: providers::ProviderRuntimeOptions, workspace_dir: Arc, + message_timeout_secs: u64, + interrupt_on_new_message: bool, + multimodal: crate::config::MultimodalConfig, +} + +#[derive(Clone)] +struct InFlightSenderTaskState { + task_id: u64, + cancellation: CancellationToken, + completion: Arc, +} + +struct InFlightTaskCompletion { + done: AtomicBool, + notify: tokio::sync::Notify, +} + +impl InFlightTaskCompletion { + fn new() -> Self { + Self { + done: AtomicBool::new(false), + notify: tokio::sync::Notify::new(), + } + } + + fn mark_done(&self) { + self.done.store(true, Ordering::Release); + self.notify.notify_waiters(); + } + + async fn wait(&self) { + if self.done.load(Ordering::Acquire) { + return; + } + self.notify.notified().await; + } } fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { @@ -130,6 +245,10 @@ fn conversation_history_key(msg: &traits::ChannelMessage) -> String { format!("{}_{}", msg.channel, msg.sender) } +fn interruption_scope_key(msg: &traits::ChannelMessage) -> String { + format!("{}_{}_{}", msg.channel, msg.reply_target, msg.sender) +} + fn channel_delivery_instructions(channel_name: &str) -> 
Option<&'static str> { match channel_name { "telegram" => Some( @@ -139,6 +258,51 @@ fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { } } +fn build_channel_system_prompt(base_prompt: &str, channel_name: &str) -> String { + if let Some(instructions) = channel_delivery_instructions(channel_name) { + if base_prompt.is_empty() { + instructions.to_string() + } else { + format!("{base_prompt}\n\n{instructions}") + } + } else { + base_prompt.to_string() + } +} + +fn normalize_cached_channel_turns(turns: Vec) -> Vec { + let mut normalized = Vec::with_capacity(turns.len()); + let mut expecting_user = true; + + for turn in turns { + match (expecting_user, turn.role.as_str()) { + (true, "user") => { + normalized.push(turn); + expecting_user = false; + } + (false, "assistant") => { + normalized.push(turn); + expecting_user = true; + } + // Interrupted channel turns can produce consecutive user messages + // (no assistant persisted yet). Merge instead of dropping. + (false, "user") | (true, "assistant") => { + if let Some(last_turn) = normalized.last_mut() { + if !turn.content.is_empty() { + if !last_turn.content.is_empty() { + last_turn.content.push_str("\n\n"); + } + last_turn.content.push_str(&turn.content); + } + } + } + _ => {} + } + } + + normalized +} + fn supports_runtime_model_switch(channel_name: &str) -> bool { matches!(channel_name, "telegram" | "discord") } @@ -204,10 +368,176 @@ fn resolve_provider_alias(name: &str) -> Option { None } -fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection { - ChannelRouteSelection { - provider: ctx.default_provider.as_str().to_string(), +fn resolved_default_provider(config: &Config) -> String { + config + .default_provider + .clone() + .unwrap_or_else(|| "openrouter".to_string()) +} + +fn resolved_default_model(config: &Config) -> String { + config + .default_model + .clone() + .unwrap_or_else(|| "anthropic/claude-sonnet-4.6".to_string()) +} + +fn 
runtime_defaults_from_config(config: &Config) -> ChannelRuntimeDefaults { + ChannelRuntimeDefaults { + default_provider: resolved_default_provider(config), + model: resolved_default_model(config), + temperature: config.default_temperature, + api_key: config.api_key.clone(), + api_url: config.api_url.clone(), + reliability: config.reliability.clone(), + } +} + +fn runtime_config_path(ctx: &ChannelRuntimeContext) -> Option { + ctx.provider_runtime_options + .zeroclaw_dir + .as_ref() + .map(|dir| dir.join("config.toml")) +} + +fn runtime_defaults_snapshot(ctx: &ChannelRuntimeContext) -> ChannelRuntimeDefaults { + if let Some(config_path) = runtime_config_path(ctx) { + let store = runtime_config_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + if let Some(state) = store.get(&config_path) { + return state.defaults.clone(); + } + } + + ChannelRuntimeDefaults { + default_provider: ctx.default_provider.as_str().to_string(), model: ctx.model.as_str().to_string(), + temperature: ctx.temperature, + api_key: ctx.api_key.clone(), + api_url: ctx.api_url.clone(), + reliability: (*ctx.reliability).clone(), + } +} + +async fn config_file_stamp(path: &Path) -> Option { + let metadata = tokio::fs::metadata(path).await.ok()?; + let modified = metadata.modified().ok()?; + Some(ConfigFileStamp { + modified, + len: metadata.len(), + }) +} + +fn decrypt_optional_secret_for_runtime_reload( + store: &crate::security::SecretStore, + value: &mut Option, + field_name: &str, +) -> Result<()> { + if let Some(raw) = value.clone() { + if crate::security::SecretStore::is_encrypted(&raw) { + *value = Some( + store + .decrypt(&raw) + .with_context(|| format!("Failed to decrypt {field_name}"))?, + ); + } + } + Ok(()) +} + +async fn load_runtime_defaults_from_config_file(path: &Path) -> Result { + let contents = tokio::fs::read_to_string(path) + .await + .with_context(|| format!("Failed to read {}", path.display()))?; + let mut parsed: Config = + toml::from_str(&contents).with_context(|| 
format!("Failed to parse {}", path.display()))?; + parsed.config_path = path.to_path_buf(); + + if let Some(zeroclaw_dir) = path.parent() { + let store = crate::security::SecretStore::new(zeroclaw_dir, parsed.secrets.encrypt); + decrypt_optional_secret_for_runtime_reload(&store, &mut parsed.api_key, "config.api_key")?; + } + + parsed.apply_env_overrides(); + Ok(runtime_defaults_from_config(&parsed)) +} + +async fn maybe_apply_runtime_config_update(ctx: &ChannelRuntimeContext) -> Result<()> { + let Some(config_path) = runtime_config_path(ctx) else { + return Ok(()); + }; + + let Some(stamp) = config_file_stamp(&config_path).await else { + return Ok(()); + }; + + { + let store = runtime_config_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + if let Some(state) = store.get(&config_path) { + if state.last_applied_stamp == Some(stamp) { + return Ok(()); + } + } + } + + let next_defaults = load_runtime_defaults_from_config_file(&config_path).await?; + let next_default_provider = providers::create_resilient_provider_with_options( + &next_defaults.default_provider, + next_defaults.api_key.as_deref(), + next_defaults.api_url.as_deref(), + &next_defaults.reliability, + &ctx.provider_runtime_options, + )?; + let next_default_provider: Arc = Arc::from(next_default_provider); + + if let Err(err) = next_default_provider.warmup().await { + tracing::warn!( + provider = %next_defaults.default_provider, + "Provider warmup failed after config reload: {err}" + ); + } + + { + let mut cache = ctx.provider_cache.lock().unwrap_or_else(|e| e.into_inner()); + cache.clear(); + cache.insert( + next_defaults.default_provider.clone(), + Arc::clone(&next_default_provider), + ); + } + + { + let mut store = runtime_config_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + store.insert( + config_path.clone(), + RuntimeConfigState { + defaults: next_defaults.clone(), + last_applied_stamp: Some(stamp), + }, + ); + } + + tracing::info!( + path = %config_path.display(), + provider = 
%next_defaults.default_provider, + model = %next_defaults.model, + temperature = next_defaults.temperature, + "Applied updated channel runtime config from disk" + ); + + Ok(()) +} + +fn default_route_selection(ctx: &ChannelRuntimeContext) -> ChannelRouteSelection { + let defaults = runtime_defaults_snapshot(ctx); + ChannelRouteSelection { + provider: defaults.default_provider, + model: defaults.model, } } @@ -240,6 +570,81 @@ fn clear_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) { .remove(sender_key); } +fn compact_sender_history(ctx: &ChannelRuntimeContext, sender_key: &str) -> bool { + let mut histories = ctx + .conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()); + + let Some(turns) = histories.get_mut(sender_key) else { + return false; + }; + + if turns.is_empty() { + return false; + } + + let keep_from = turns + .len() + .saturating_sub(CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES); + let mut compacted = normalize_cached_channel_turns(turns[keep_from..].to_vec()); + + for turn in &mut compacted { + if turn.content.chars().count() > CHANNEL_HISTORY_COMPACT_CONTENT_CHARS { + turn.content = + truncate_with_ellipsis(&turn.content, CHANNEL_HISTORY_COMPACT_CONTENT_CHARS); + } + } + + if compacted.is_empty() { + turns.clear(); + return false; + } + + *turns = compacted; + true +} + +fn append_sender_turn(ctx: &ChannelRuntimeContext, sender_key: &str, turn: ChatMessage) { + let mut histories = ctx + .conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()); + let turns = histories.entry(sender_key.to_string()).or_default(); + turns.push(turn); + while turns.len() > MAX_CHANNEL_HISTORY { + turns.remove(0); + } +} + +fn should_skip_memory_context_entry(key: &str, content: &str) -> bool { + if memory::is_assistant_autosave_key(key) { + return true; + } + + if key.trim().to_ascii_lowercase().ends_with("_history") { + return true; + } + + content.chars().count() > MEMORY_CONTEXT_MAX_CHARS +} + +fn 
is_context_window_overflow_error(err: &anyhow::Error) -> bool { + let lower = err.to_string().to_lowercase(); + [ + "exceeds the context window", + "context window of this model", + "maximum context length", + "context length exceeded", + "too many tokens", + "token limit exceeded", + "prompt is too long", + "input is too long", + ] + .iter() + .any(|hint| lower.contains(hint)) +} + fn load_cached_model_preview(workspace_dir: &Path, provider_name: &str) -> Vec { let cache_path = workspace_dir.join("state").join(MODEL_CACHE_FILE); let Ok(raw) = std::fs::read_to_string(cache_path) else { @@ -267,10 +672,6 @@ async fn get_or_create_provider( ctx: &ChannelRuntimeContext, provider_name: &str, ) -> anyhow::Result> { - if provider_name == ctx.default_provider.as_str() { - return Ok(Arc::clone(&ctx.provider)); - } - if let Some(existing) = ctx .provider_cache .lock() @@ -281,17 +682,22 @@ async fn get_or_create_provider( return Ok(existing); } - let api_url = if provider_name == ctx.default_provider.as_str() { - ctx.api_url.as_deref() + if provider_name == ctx.default_provider.as_str() { + return Ok(Arc::clone(&ctx.provider)); + } + + let defaults = runtime_defaults_snapshot(ctx); + let api_url = if provider_name == defaults.default_provider.as_str() { + defaults.api_url.as_deref() } else { None }; let provider = providers::create_resilient_provider_with_options( provider_name, - ctx.api_key.as_deref(), + defaults.api_key.as_deref(), api_url, - &ctx.reliability, + &defaults.reliability, &ctx.provider_runtime_options, )?; let provider: Arc = Arc::from(provider); @@ -428,7 +834,7 @@ async fn handle_runtime_command_if_needed( }; if let Err(err) = channel - .send(&SendMessage::new(response, &msg.reply_target)) + .send(&SendMessage::new(response, &msg.reply_target).in_thread(msg.thread_ts.clone())) .await { tracing::warn!( @@ -448,19 +854,43 @@ async fn build_memory_context( let mut context = String::new(); if let Ok(entries) = mem.recall(user_msg, 5, None).await { - let 
relevant: Vec<_> = entries - .iter() - .filter(|e| match e.score { - Some(score) => score >= min_relevance_score, - None => true, // keep entries without a score (e.g. non-vector backends) - }) - .collect(); + let mut included = 0usize; + let mut used_chars = 0usize; - if !relevant.is_empty() { - context.push_str("[Memory context]\n"); - for entry in &relevant { - let _ = writeln!(context, "- {}: {}", entry.key, entry.content); + for entry in entries.iter().filter(|e| match e.score { + Some(score) => score >= min_relevance_score, + None => true, // keep entries without a score (e.g. non-vector backends) + }) { + if included >= MEMORY_CONTEXT_MAX_ENTRIES { + break; } + + if should_skip_memory_context_entry(&entry.key, &entry.content) { + continue; + } + + let content = if entry.content.chars().count() > MEMORY_CONTEXT_ENTRY_MAX_CHARS { + truncate_with_ellipsis(&entry.content, MEMORY_CONTEXT_ENTRY_MAX_CHARS) + } else { + entry.content.clone() + }; + + let line = format!("- {}: {}\n", entry.key, content); + let line_chars = line.chars().count(); + if used_chars + line_chars > MEMORY_CONTEXT_MAX_CHARS { + break; + } + + if included == 0 { + context.push_str("[Memory context]\n"); + } + + context.push_str(&line); + used_chars += line_chars; + included += 1; + } + + if included > 0 { context.push('\n'); } } @@ -468,6 +898,100 @@ async fn build_memory_context( context } +/// Extract a compact summary of tool interactions from history messages added +/// during `run_tool_call_loop`. Scans assistant messages for `` tags +/// or native tool-call JSON to collect tool names used. +/// Returns an empty string when no tools were invoked. 
+fn extract_tool_context_summary(history: &[ChatMessage], start_index: usize) -> String { + fn push_unique_tool_name(tool_names: &mut Vec, name: &str) { + let candidate = name.trim(); + if candidate.is_empty() { + return; + } + if !tool_names.iter().any(|existing| existing == candidate) { + tool_names.push(candidate.to_string()); + } + } + + fn collect_tool_names_from_tool_call_tags(content: &str, tool_names: &mut Vec) { + const TAG_PAIRS: [(&str, &str); 4] = [ + ("", ""), + ("", ""), + ("", ""), + ("", ""), + ]; + + for (open_tag, close_tag) in TAG_PAIRS { + for segment in content.split(open_tag) { + if let Some(json_end) = segment.find(close_tag) { + let json_str = segment[..json_end].trim(); + if let Ok(val) = serde_json::from_str::(json_str) { + if let Some(name) = val.get("name").and_then(|n| n.as_str()) { + push_unique_tool_name(tool_names, name); + } + } + } + } + } + } + + fn collect_tool_names_from_native_json(content: &str, tool_names: &mut Vec) { + if let Ok(val) = serde_json::from_str::(content) { + if let Some(calls) = val.get("tool_calls").and_then(|c| c.as_array()) { + for call in calls { + let name = call + .get("function") + .and_then(|f| f.get("name")) + .and_then(|n| n.as_str()) + .or_else(|| call.get("name").and_then(|n| n.as_str())); + if let Some(name) = name { + push_unique_tool_name(tool_names, name); + } + } + } + } + } + + fn collect_tool_names_from_tool_results(content: &str, tool_names: &mut Vec) { + let marker = " = Vec::new(); + + for msg in history.iter().skip(start_index) { + match msg.role.as_str() { + "assistant" => { + collect_tool_names_from_tool_call_tags(&msg.content, &mut tool_names); + collect_tool_names_from_native_json(&msg.content, &mut tool_names); + } + "user" => { + // Prompt-mode tool calls are always followed by [Tool results] entries + // containing `` tags with canonical tool names. 
+ collect_tool_names_from_tool_results(&msg.content, &mut tool_names); + } + _ => {} + } + } + + if tool_names.is_empty() { + return String::new(); + } + + format!("[Used tools: {}]", tool_names.join(", ")) +} + fn spawn_supervised_listener( ch: Arc, tx: tokio::sync::mpsc::Sender, @@ -553,7 +1077,15 @@ fn spawn_scoped_typing_task( handle } -async fn process_channel_message(ctx: Arc, msg: traits::ChannelMessage) { +async fn process_channel_message( + ctx: Arc, + msg: traits::ChannelMessage, + cancellation_token: CancellationToken, +) { + if cancellation_token.is_cancelled() { + return; + } + println!( " 💬 [{}] from {}: {}", msg.channel, @@ -562,12 +1094,16 @@ async fn process_channel_message(ctx: Arc, msg: traits::C ); let target_channel = ctx.channels_by_name.get(&msg.channel).cloned(); + if let Err(err) = maybe_apply_runtime_config_update(ctx.as_ref()).await { + tracing::warn!("Failed to apply runtime config update: {err}"); + } if handle_runtime_command_if_needed(ctx.as_ref(), &msg, target_channel.as_ref()).await { return; } let history_key = conversation_history_key(&msg); let route = get_route_selection(ctx.as_ref(), &history_key); + let runtime_defaults = runtime_defaults_snapshot(ctx.as_ref()); let active_provider = match get_or_create_provider(ctx.as_ref(), &route.provider).await { Ok(provider) => provider, Err(err) => { @@ -578,17 +1114,16 @@ async fn process_channel_message(ctx: Arc, msg: traits::C ); if let Some(channel) = target_channel.as_ref() { let _ = channel - .send(&SendMessage::new(message, &msg.reply_target)) + .send( + &SendMessage::new(message, &msg.reply_target) + .in_thread(msg.thread_ts.clone()), + ) .await; } return; } }; - - let memory_context = - build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await; - - if ctx.auto_save_memory { + if ctx.auto_save_memory && msg.content.chars().count() >= AUTOSAVE_MIN_MESSAGE_CHARS { let autosave_key = conversation_memory_key(&msg); let _ = ctx .memory @@ -601,38 +1136,48 
@@ async fn process_channel_message(ctx: Arc, msg: traits::C .await; } - let enriched_message = if memory_context.is_empty() { - msg.content.clone() - } else { - format!("{memory_context}{}", msg.content) - }; - println!(" ⏳ Processing message..."); let started_at = Instant::now(); - // Build history from per-sender conversation cache - let mut prior_turns = ctx + let had_prior_history = ctx + .conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()) + .get(&history_key) + .is_some_and(|turns| !turns.is_empty()); + + // Preserve user turn before the LLM call so interrupted requests keep context. + append_sender_turn(ctx.as_ref(), &history_key, ChatMessage::user(&msg.content)); + + // Build history from per-sender conversation cache. + let prior_turns_raw = ctx .conversation_histories .lock() .unwrap_or_else(|e| e.into_inner()) .get(&history_key) .cloned() .unwrap_or_default(); + let mut prior_turns = normalize_cached_channel_turns(prior_turns_raw); - let mut history = vec![ChatMessage::system(ctx.system_prompt.as_str())]; - history.append(&mut prior_turns); - history.push(ChatMessage::user(&enriched_message)); - - if let Some(instructions) = channel_delivery_instructions(&msg.channel) { - history.push(ChatMessage::system(instructions)); + // Only enrich with memory context when there is no prior conversation + // history. Follow-up turns already include context from previous messages. 
+ if !had_prior_history { + let memory_context = + build_memory_context(ctx.memory.as_ref(), &msg.content, ctx.min_relevance_score).await; + if let Some(last_turn) = prior_turns.last_mut() { + if last_turn.role == "user" && !memory_context.is_empty() { + last_turn.content = format!("{memory_context}{}", msg.content); + } + } } - // Determine if this channel supports streaming draft updates + let system_prompt = build_channel_system_prompt(ctx.system_prompt.as_str(), &msg.channel); + let mut history = vec![ChatMessage::system(system_prompt)]; + history.extend(prior_turns); let use_streaming = target_channel .as_ref() - .map_or(false, |ch| ch.supports_draft_updates()); + .is_some_and(|ch| ch.supports_draft_updates()); - // Set up streaming channel if supported let (delta_tx, delta_rx) = if use_streaming { let (tx, rx) = tokio::sync::mpsc::channel::(64); (Some(tx), Some(rx)) @@ -640,11 +1185,12 @@ async fn process_channel_message(ctx: Arc, msg: traits::C (None, None) }; - // Send initial draft message if streaming let draft_message_id = if use_streaming { if let Some(channel) = target_channel.as_ref() { match channel - .send_draft(&SendMessage::new("...", &msg.reply_target)) + .send_draft( + &SendMessage::new("...", &msg.reply_target).in_thread(msg.thread_ts.clone()), + ) .await { Ok(id) => id, @@ -660,7 +1206,6 @@ async fn process_channel_message(ctx: Arc, msg: traits::C None }; - // Spawn a task to forward streaming deltas to draft updates let draft_updater = if let (Some(mut rx), Some(draft_id_ref), Some(channel_ref)) = ( delta_rx, draft_message_id.as_deref(), @@ -695,26 +1240,39 @@ async fn process_channel_message(ctx: Arc, msg: traits::C _ => None, }; - let llm_result = tokio::time::timeout( - Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), - run_tool_call_loop( - active_provider.as_ref(), - &mut history, - ctx.tools_registry.as_ref(), - ctx.observer.as_ref(), - route.provider.as_str(), - route.model.as_str(), - ctx.temperature, - true, - None, - 
msg.channel.as_str(), - ctx.max_tool_iterations, - delta_tx, - ), - ) - .await; + // Record history length before tool loop so we can extract tool context after. + let history_len_before_tools = history.len(); + + enum LlmExecutionResult { + Completed(Result, tokio::time::error::Elapsed>), + Cancelled, + } + + let timeout_budget_secs = + channel_message_timeout_budget_secs(ctx.message_timeout_secs, ctx.max_tool_iterations); + let llm_result = tokio::select! { + () = cancellation_token.cancelled() => LlmExecutionResult::Cancelled, + result = tokio::time::timeout( + Duration::from_secs(timeout_budget_secs), + run_tool_call_loop( + active_provider.as_ref(), + &mut history, + ctx.tools_registry.as_ref(), + ctx.observer.as_ref(), + route.provider.as_str(), + route.model.as_str(), + runtime_defaults.temperature, + true, + None, + msg.channel.as_str(), + &ctx.multimodal, + ctx.max_tool_iterations, + Some(cancellation_token.clone()), + delta_tx, + ), + ) => LlmExecutionResult::Completed(result), + }; - // Wait for draft updater to finish if let Some(handle) = draft_updater { let _ = handle.await; } @@ -727,21 +1285,36 @@ async fn process_channel_message(ctx: Arc, msg: traits::C } match llm_result { - Ok(Ok(response)) => { - // Save user + assistant turn to per-sender history + LlmExecutionResult::Cancelled => { + tracing::info!( + channel = %msg.channel, + sender = %msg.sender, + "Cancelled in-flight channel request due to newer message" + ); + if let (Some(channel), Some(draft_id)) = + (target_channel.as_ref(), draft_message_id.as_deref()) { - let mut histories = ctx - .conversation_histories - .lock() - .unwrap_or_else(|e| e.into_inner()); - let turns = histories.entry(history_key).or_default(); - turns.push(ChatMessage::user(&enriched_message)); - turns.push(ChatMessage::assistant(&response)); - // Trim to MAX_CHANNEL_HISTORY (keep recent turns) - while turns.len() > MAX_CHANNEL_HISTORY { - turns.remove(0); + if let Err(err) = channel.cancel_draft(&msg.reply_target, 
draft_id).await { + tracing::debug!("Failed to cancel draft on {}: {err}", channel.name()); } } + } + LlmExecutionResult::Completed(Ok(Ok(response))) => { + // Extract condensed tool-use context from the history messages + // added during run_tool_call_loop, so the LLM retains awareness + // of what it did on subsequent turns. + let tool_summary = extract_tool_context_summary(&history, history_len_before_tools); + let history_response = if tool_summary.is_empty() { + response.clone() + } else { + format!("{tool_summary}\n{response}") + }; + + append_sender_turn( + ctx.as_ref(), + &history_key, + ChatMessage::assistant(&history_response), + ); println!( " 🤖 Reply ({}ms): {}", started_at.elapsed().as_millis(), @@ -755,18 +1328,70 @@ async fn process_channel_message(ctx: Arc, msg: traits::C { tracing::warn!("Failed to finalize draft: {e}; sending as new message"); let _ = channel - .send(&SendMessage::new(&response, &msg.reply_target)) + .send( + &SendMessage::new(&response, &msg.reply_target) + .in_thread(msg.thread_ts.clone()), + ) .await; } } else if let Err(e) = channel - .send(&SendMessage::new(response, &msg.reply_target)) + .send( + &SendMessage::new(response, &msg.reply_target) + .in_thread(msg.thread_ts.clone()), + ) .await { eprintln!(" ❌ Failed to reply on {}: {e}", channel.name()); } } } - Ok(Err(e)) => { + LlmExecutionResult::Completed(Ok(Err(e))) => { + if crate::agent::loop_::is_tool_loop_cancelled(&e) || cancellation_token.is_cancelled() + { + tracing::info!( + channel = %msg.channel, + sender = %msg.sender, + "Cancelled in-flight channel request due to newer message" + ); + if let (Some(channel), Some(draft_id)) = + (target_channel.as_ref(), draft_message_id.as_deref()) + { + if let Err(err) = channel.cancel_draft(&msg.reply_target, draft_id).await { + tracing::debug!("Failed to cancel draft on {}: {err}", channel.name()); + } + } + return; + } + + if is_context_window_overflow_error(&e) { + let compacted = compact_sender_history(ctx.as_ref(), 
&history_key); + let error_text = if compacted { + "⚠️ Context window exceeded for this conversation. I compacted recent history and kept the latest context. Please resend your last message." + } else { + "⚠️ Context window exceeded for this conversation. Please resend your last message." + }; + eprintln!( + " ⚠️ Context window exceeded after {}ms; sender history compacted={}", + started_at.elapsed().as_millis(), + compacted + ); + if let Some(channel) = target_channel.as_ref() { + if let Some(ref draft_id) = draft_message_id { + let _ = channel + .finalize_draft(&msg.reply_target, draft_id, error_text) + .await; + } else { + let _ = channel + .send( + &SendMessage::new(error_text, &msg.reply_target) + .in_thread(msg.thread_ts.clone()), + ) + .await; + } + } + return; + } + eprintln!( " ❌ LLM error after {}ms: {e}", started_at.elapsed().as_millis() @@ -778,18 +1403,18 @@ async fn process_channel_message(ctx: Arc, msg: traits::C .await; } else { let _ = channel - .send(&SendMessage::new( - format!("⚠️ Error: {e}"), - &msg.reply_target, - )) + .send( + &SendMessage::new(format!("⚠️ Error: {e}"), &msg.reply_target) + .in_thread(msg.thread_ts.clone()), + ) .await; } } } - Err(_) => { + LlmExecutionResult::Completed(Err(_)) => { let timeout_msg = format!( - "LLM response timed out after {}s", - CHANNEL_MESSAGE_TIMEOUT_SECS + "LLM response timed out after {}s (base={}s, max_tool_iterations={})", + timeout_budget_secs, ctx.message_timeout_secs, ctx.max_tool_iterations ); eprintln!( " ❌ {} (elapsed: {}ms)", @@ -805,7 +1430,10 @@ async fn process_channel_message(ctx: Arc, msg: traits::C .await; } else { let _ = channel - .send(&SendMessage::new(error_text, &msg.reply_target)) + .send( + &SendMessage::new(error_text, &msg.reply_target) + .in_thread(msg.thread_ts.clone()), + ) .await; } } @@ -820,6 +1448,11 @@ async fn run_message_dispatch_loop( ) { let semaphore = Arc::new(tokio::sync::Semaphore::new(max_in_flight_messages)); let mut workers = tokio::task::JoinSet::new(); + 
let in_flight_by_sender = Arc::new(tokio::sync::Mutex::new(HashMap::< + String, + InFlightSenderTaskState, + >::new())); + let task_sequence = Arc::new(AtomicU64::new(1)); while let Some(msg) = rx.recv().await { let permit = match Arc::clone(&semaphore).acquire_owned().await { @@ -828,9 +1461,54 @@ async fn run_message_dispatch_loop( }; let worker_ctx = Arc::clone(&ctx); + let in_flight = Arc::clone(&in_flight_by_sender); + let task_sequence = Arc::clone(&task_sequence); workers.spawn(async move { let _permit = permit; - process_channel_message(worker_ctx, msg).await; + let interrupt_enabled = + worker_ctx.interrupt_on_new_message && msg.channel == "telegram"; + let sender_scope_key = interruption_scope_key(&msg); + let cancellation_token = CancellationToken::new(); + let completion = Arc::new(InFlightTaskCompletion::new()); + let task_id = task_sequence.fetch_add(1, Ordering::Relaxed); + + if interrupt_enabled { + let previous = { + let mut active = in_flight.lock().await; + active.insert( + sender_scope_key.clone(), + InFlightSenderTaskState { + task_id, + cancellation: cancellation_token.clone(), + completion: Arc::clone(&completion), + }, + ) + }; + + if let Some(previous) = previous { + tracing::info!( + channel = %msg.channel, + sender = %msg.sender, + "Interrupting previous in-flight request for sender" + ); + previous.cancellation.cancel(); + previous.completion.wait().await; + } + } + + process_channel_message(worker_ctx, msg, cancellation_token).await; + + if interrupt_enabled { + let mut active = in_flight.lock().await; + if active + .get(&sender_scope_key) + .is_some_and(|state| state.task_id == task_id) + { + active.remove(&sender_scope_key); + } + } + + completion.mark_done(); }); while let Some(result) = workers.try_join_next() { @@ -874,7 +1552,7 @@ fn load_openclaw_bootstrap_files( /// Follows the `OpenClaw` framework structure by default: /// 1. Tooling — tool list + descriptions /// 2. Safety — guardrail reminder -/// 3. 
Skills — compact list with paths (loaded on-demand) +/// 3. Skills — full skill instructions and tool metadata /// 4. Workspace — working directory /// 5. Bootstrap files — AGENTS, SOUL, TOOLS, IDENTITY, USER, BOOTSTRAP, MEMORY /// 6. Date & Time — timezone for cache stability @@ -892,6 +1570,26 @@ pub fn build_system_prompt( skills: &[crate::skills::Skill], identity_config: Option<&crate::config::IdentityConfig>, bootstrap_max_chars: Option, +) -> String { + build_system_prompt_with_mode( + workspace_dir, + model_name, + tools, + skills, + identity_config, + bootstrap_max_chars, + false, + ) +} + +pub fn build_system_prompt_with_mode( + workspace_dir: &std::path::Path, + model_name: &str, + tools: &[(&str, &str)], + skills: &[crate::skills::Skill], + identity_config: Option<&crate::config::IdentityConfig>, + bootstrap_max_chars: Option, + native_tools: bool, ) -> String { use std::fmt::Write; let mut prompt = String::with_capacity(8192); @@ -903,13 +1601,7 @@ pub fn build_system_prompt( for (name, desc) in tools { let _ = writeln!(prompt, "- **{name}**: {desc}"); } - prompt.push_str("\n## Tool Use Protocol\n\n"); - prompt.push_str("To use a tool, wrap a JSON object in tags:\n\n"); - prompt.push_str("```\n\n{\"name\": \"tool_name\", \"arguments\": {\"param\": \"value\"}}\n\n```\n\n"); - prompt.push_str("You may use multiple tool calls in a single response. "); - prompt.push_str("After tool execution, results appear in tags. "); - prompt - .push_str("Continue reasoning with the results until you can give a final answer.\n\n"); + prompt.push('\n'); } // ── 1b. Hardware (when gpio/arduino tools present) ─────────── @@ -934,12 +1626,21 @@ pub fn build_system_prompt( } // ── 1c. Action instruction (avoid meta-summary) ─────────────── - prompt.push_str( - "## Your Task\n\n\ - When the user sends a message, ACT on it. 
Use the tools to fulfill their request.\n\ - Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions (e.g. \"1. First... 2. Next...\").\n\ - Instead: emit actual tags when you need to act. Just do what they ask.\n\n", - ); + if native_tools { + prompt.push_str( + "## Your Task\n\n\ + When the user sends a message, respond naturally. Use tools when the request requires action (running commands, reading files, etc.).\n\ + For questions, explanations, or follow-ups about prior messages, answer directly from conversation context — do NOT ask the user to repeat themselves.\n\ + Do NOT: summarize this configuration, describe your capabilities, or output step-by-step meta-commentary.\n\n", + ); + } else { + prompt.push_str( + "## Your Task\n\n\ + When the user sends a message, ACT on it. Use the tools to fulfill their request.\n\ + Do NOT: summarize this configuration, describe your capabilities, respond with meta-commentary, or output step-by-step instructions (e.g. \"1. First... 2. Next...\").\n\ + Instead: emit actual tags when you need to act. Just do what they ask.\n\n", + ); + } // ── 2. Safety ─────────────────────────────────────────────── prompt.push_str("## Safety\n\n"); @@ -951,31 +1652,10 @@ pub fn build_system_prompt( - When in doubt, ask before acting externally.\n\n", ); - // ── 3. Skills (compact list — load on-demand) ─────────────── + // ── 3. Skills (full instructions + tool metadata) ─────────── if !skills.is_empty() { - prompt.push_str("## Available Skills\n\n"); - prompt.push_str( - "Skills are loaded on demand. 
Use `read` on the skill path to get full instructions.\n\n", - ); - prompt.push_str("\n"); - for skill in skills { - let _ = writeln!(prompt, " "); - let _ = writeln!(prompt, " {}", skill.name); - let _ = writeln!( - prompt, - " {}", - skill.description - ); - let location = skill.location.clone().unwrap_or_else(|| { - workspace_dir - .join("skills") - .join(&skill.name) - .join("SKILL.md") - }); - let _ = writeln!(prompt, " {}", location.display()); - let _ = writeln!(prompt, " "); - } - prompt.push_str("\n\n"); + prompt.push_str(&crate::skills::skills_to_prompt(skills, workspace_dir)); + prompt.push_str("\n\n"); } // ── 4. Workspace ──────────────────────────────────────────── @@ -1042,16 +1722,14 @@ pub fn build_system_prompt( // ── 8. Channel Capabilities ───────────────────────────────────── prompt.push_str("## Channel Capabilities\n\n"); - prompt.push_str( - "- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n", - ); - prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n"); + prompt.push_str("- You are running as a messaging bot. Your response is automatically sent back to the user's channel.\n"); prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n"); prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n"); prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n"); if prompt.is_empty() { - "You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string() + "You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct." 
+ .to_string() } else { prompt } @@ -1106,7 +1784,7 @@ fn normalize_telegram_identity(value: &str) -> String { value.trim().trim_start_matches('@').to_string() } -fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { +async fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { let normalized = normalize_telegram_identity(identity); if normalized.is_empty() { anyhow::bail!("Telegram identity cannot be empty"); @@ -1136,7 +1814,7 @@ fn bind_telegram_identity(config: &Config, identity: &str) -> Result<()> { } telegram.allowed_users.push(normalized.clone()); - updated.save()?; + updated.save().await?; println!("✅ Bound Telegram identity: {normalized}"); println!(" Saved to {}", updated.config_path.display()); match maybe_restart_managed_daemon_service() { @@ -1232,7 +1910,7 @@ fn maybe_restart_managed_daemon_service() -> Result { Ok(false) } -pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> { +pub async fn handle_command(command: crate::ChannelCommands, config: &Config) -> Result<()> { match command { crate::ChannelCommands::Start => { anyhow::bail!("Start must be handled in main.rs (requires async runtime)") @@ -1247,11 +1925,16 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul ("Telegram", config.channels_config.telegram.is_some()), ("Discord", config.channels_config.discord.is_some()), ("Slack", config.channels_config.slack.is_some()), + ("Mattermost", config.channels_config.mattermost.is_some()), ("Webhook", config.channels_config.webhook.is_some()), ("iMessage", config.channels_config.imessage.is_some()), - ("Matrix", config.channels_config.matrix.is_some()), + ( + "Matrix", + cfg!(feature = "channel-matrix") && config.channels_config.matrix.is_some(), + ), ("Signal", config.channels_config.signal.is_some()), ("WhatsApp", config.channels_config.whatsapp.is_some()), + ("Linq", config.channels_config.linq.is_some()), ("Email", 
config.channels_config.email.is_some()), ("IRC", config.channels_config.irc.is_some()), ("Lark", config.channels_config.lark.is_some()), @@ -1260,6 +1943,11 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul ] { println!(" {} {name}", if configured { "✅" } else { "❌" }); } + if !cfg!(feature = "channel-matrix") { + println!( + " ℹ️ Matrix channel support is disabled in this build (enable `channel-matrix`)." + ); + } println!("\nTo start channels: zeroclaw channel start"); println!("To check health: zeroclaw channel doctor"); println!("To configure: zeroclaw onboard"); @@ -1277,7 +1965,7 @@ pub fn handle_command(command: crate::ChannelCommands, config: &Config) -> Resul anyhow::bail!("Remove channel '{name}' — edit ~/.zeroclaw/config.toml directly"); } crate::ChannelCommands::BindTelegram { identity } => { - bind_telegram_identity(config, &identity) + bind_telegram_identity(config, &identity).await } } } @@ -1348,6 +2036,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> { )); } + #[cfg(feature = "channel-matrix")] if let Some(ref mx) = config.channels_config.matrix { channels.push(( "Matrix", @@ -1362,6 +2051,13 @@ pub async fn doctor_channels(config: Config) -> Result<()> { )); } + #[cfg(not(feature = "channel-matrix"))] + if config.channels_config.matrix.is_some() { + tracing::warn!( + "Matrix channel is configured but this build was compiled without `channel-matrix`; skipping Matrix health check." + ); + } + if let Some(ref sig) = config.channels_config.signal { channels.push(( "Signal", @@ -1377,13 +2073,63 @@ pub async fn doctor_channels(config: Config) -> Result<()> { } if let Some(ref wa) = config.channels_config.whatsapp { + if wa.is_ambiguous_config() { + tracing::warn!( + "WhatsApp config has both phone_number_id and session_path set; preferring Cloud API mode. Remove one selector to avoid ambiguity." 
+ ); + } + // Runtime negotiation: detect backend type from config + match wa.backend_type() { + "cloud" => { + // Cloud API mode: requires phone_number_id, access_token, verify_token + if wa.is_cloud_config() { + channels.push(( + "WhatsApp", + Arc::new(WhatsAppChannel::new( + wa.access_token.clone().unwrap_or_default(), + wa.phone_number_id.clone().unwrap_or_default(), + wa.verify_token.clone().unwrap_or_default(), + wa.allowed_numbers.clone(), + )), + )); + } else { + tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)"); + } + } + "web" => { + // Web mode: requires session_path + #[cfg(feature = "whatsapp-web")] + if wa.is_web_config() { + channels.push(( + "WhatsApp", + Arc::new(WhatsAppWebChannel::new( + wa.session_path.clone().unwrap_or_default(), + wa.pair_phone.clone(), + wa.pair_code.clone(), + wa.allowed_numbers.clone(), + )), + )); + } else { + tracing::warn!("WhatsApp Web configured but session_path not set"); + } + #[cfg(not(feature = "whatsapp-web"))] + { + tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. 
Enable with: cargo build --features whatsapp-web"); + } + } + _ => { + tracing::warn!("WhatsApp config invalid: neither phone_number_id (Cloud API) nor session_path (Web) is set"); + } + } + } + + if let Some(ref lq) = config.channels_config.linq { channels.push(( - "WhatsApp", - Arc::new(WhatsAppChannel::new( - wa.access_token.clone(), - wa.phone_number_id.clone(), - wa.verify_token.clone(), - wa.allowed_numbers.clone(), + "Linq", + Arc::new(LinqChannel::new( + lq.api_token.clone(), + lq.from_phone.clone(), + lq.allowed_senders.clone(), )), )); } @@ -1480,14 +2226,12 @@ pub async fn doctor_channels(config: Config) -> Result<()> { /// Start all configured channels and route messages to the agent #[allow(clippy::too_many_lines)] pub async fn start_channels(config: Config) -> Result<()> { - let provider_name = config - .default_provider - .clone() - .unwrap_or_else(|| "openrouter".into()); + let provider_name = resolved_default_provider(&config); let provider_runtime_options = providers::ProviderRuntimeOptions { auth_profile_override: None, zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), secrets_encrypt: config.secrets.encrypt, + reasoning_enabled: config.runtime.reasoning_enabled, }; let provider: Arc = Arc::from(providers::create_resilient_provider_with_options( &provider_name, @@ -1503,6 +2247,20 @@ pub async fn start_channels(config: Config) -> Result<()> { tracing::warn!("Provider warmup failed (non-fatal): {e}"); } + let initial_stamp = config_file_stamp(&config.config_path).await; + { + let mut store = runtime_config_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + store.insert( + config.config_path.clone(), + RuntimeConfigState { + defaults: runtime_defaults_from_config(&config), + last_applied_stamp: initial_stamp, + }, + ); + } + let observer: Arc = Arc::from(observability::create_observer(&config.observability)); let runtime: Arc = @@ -1511,10 +2269,7 @@ pub async fn start_channels(config: Config) -> Result<()> { 
&config.autonomy, &config.workspace_dir, )); - let model = config - .default_model - .clone() - .unwrap_or_else(|| "anthropic/claude-sonnet-4-20250514".into()); + let model = resolved_default_model(&config); let temperature = config.default_temperature; let mem: Arc = Arc::from(memory::create_memory_with_storage( &config.memory, @@ -1547,7 +2302,7 @@ pub async fn start_channels(config: Config) -> Result<()> { &config, )); - let skills = crate::skills::load_skills(&workspace); + let skills = crate::skills::load_skills_with_config(&workspace, &config); // Collect tool descriptions for the prompt let mut tool_descs: Vec<(&str, &str)> = vec![ @@ -1586,7 +2341,7 @@ pub async fn start_channels(config: Config) -> Result<()> { if config.composio.enabled { tool_descs.push(( "composio", - "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). Use action='list' to discover, 'execute' to run (optionally with connected_account_id), 'connect' to OAuth.", + "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). 
Use action='list' to discover actions, 'list_accounts' to retrieve connected account IDs, 'execute' to run (optionally with connected_account_id), and 'connect' for OAuth.", )); } tool_descs.push(( @@ -1609,15 +2364,19 @@ pub async fn start_channels(config: Config) -> Result<()> { } else { None }; - let mut system_prompt = build_system_prompt( + let native_tools = provider.supports_native_tools(); + let mut system_prompt = build_system_prompt_with_mode( &workspace, &model, &tool_descs, &skills, Some(&config.identity), bootstrap_max_chars, + native_tools, ); - system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref())); + if !native_tools { + system_prompt.push_str(&build_tool_instructions(tools_registry.as_ref())); + } if !skills.is_empty() { println!( @@ -1677,6 +2436,7 @@ pub async fn start_channels(config: Config) -> Result<()> { channels.push(Arc::new(IMessageChannel::new(im.allowed_contacts.clone()))); } + #[cfg(feature = "channel-matrix")] if let Some(ref mx) = config.channels_config.matrix { channels.push(Arc::new(MatrixChannel::new_with_session_hint( mx.homeserver.clone(), @@ -1688,6 +2448,13 @@ pub async fn start_channels(config: Config) -> Result<()> { ))); } + #[cfg(not(feature = "channel-matrix"))] + if config.channels_config.matrix.is_some() { + tracing::warn!( + "Matrix channel is configured but this build was compiled without `channel-matrix`; skipping Matrix runtime startup." 
+ ); + } + if let Some(ref sig) = config.channels_config.signal { channels.push(Arc::new(SignalChannel::new( sig.http_url.clone(), @@ -1700,11 +2467,55 @@ pub async fn start_channels(config: Config) -> Result<()> { } if let Some(ref wa) = config.channels_config.whatsapp { - channels.push(Arc::new(WhatsAppChannel::new( - wa.access_token.clone(), - wa.phone_number_id.clone(), - wa.verify_token.clone(), - wa.allowed_numbers.clone(), + if wa.is_ambiguous_config() { + tracing::warn!( + "WhatsApp config has both phone_number_id and session_path set; preferring Cloud API mode. Remove one selector to avoid ambiguity." + ); + } + // Runtime negotiation: detect backend type from config + match wa.backend_type() { + "cloud" => { + // Cloud API mode: requires phone_number_id, access_token, verify_token + if wa.is_cloud_config() { + channels.push(Arc::new(WhatsAppChannel::new( + wa.access_token.clone().unwrap_or_default(), + wa.phone_number_id.clone().unwrap_or_default(), + wa.verify_token.clone().unwrap_or_default(), + wa.allowed_numbers.clone(), + ))); + } else { + tracing::warn!("WhatsApp Cloud API configured but missing required fields (phone_number_id, access_token, verify_token)"); + } + } + "web" => { + // Web mode: requires session_path + #[cfg(feature = "whatsapp-web")] + if wa.is_web_config() { + channels.push(Arc::new(WhatsAppWebChannel::new( + wa.session_path.clone().unwrap_or_default(), + wa.pair_phone.clone(), + wa.pair_code.clone(), + wa.allowed_numbers.clone(), + ))); + } else { + tracing::warn!("WhatsApp Web configured but session_path not set"); + } + #[cfg(not(feature = "whatsapp-web"))] + { + tracing::warn!("WhatsApp Web backend requires 'whatsapp-web' feature. 
Enable with: cargo build --features whatsapp-web"); + } + } + _ => { + tracing::warn!("WhatsApp config invalid: neither phone_number_id (Cloud API) nor session_path (Web) is set"); + } + } + } + + if let Some(ref lq) = config.channels_config.linq { + channels.push(Arc::new(LinqChannel::new( + lq.api_token.clone(), + lq.from_phone.clone(), + lq.allowed_senders.clone(), ))); } @@ -1813,6 +2624,13 @@ pub async fn start_channels(config: Config) -> Result<()> { let mut provider_cache_seed: HashMap> = HashMap::new(); provider_cache_seed.insert(provider_name.clone(), Arc::clone(&provider)); + let message_timeout_secs = + effective_channel_message_timeout_secs(config.channels_config.message_timeout_secs); + let interrupt_on_new_message = config + .channels_config + .telegram + .as_ref() + .is_some_and(|tg| tg.interrupt_on_new_message); let runtime_ctx = Arc::new(ChannelRuntimeContext { channels_by_name, @@ -1835,6 +2653,9 @@ pub async fn start_channels(config: Config) -> Result<()> { reliability: Arc::new(config.reliability.clone()), provider_runtime_options, workspace_dir: Arc::new(config.workspace_dir.clone()), + message_timeout_secs, + interrupt_on_new_message, + multimodal: config.multimodal.clone(), }); run_message_dispatch_loop(rx, runtime_ctx, max_in_flight_messages).await; @@ -1880,6 +2701,171 @@ mod tests { tmp } + #[test] + fn effective_channel_message_timeout_secs_clamps_to_minimum() { + assert_eq!( + effective_channel_message_timeout_secs(0), + MIN_CHANNEL_MESSAGE_TIMEOUT_SECS + ); + assert_eq!( + effective_channel_message_timeout_secs(15), + MIN_CHANNEL_MESSAGE_TIMEOUT_SECS + ); + assert_eq!(effective_channel_message_timeout_secs(300), 300); + } + + #[test] + fn channel_message_timeout_budget_scales_with_tool_iterations() { + assert_eq!(channel_message_timeout_budget_secs(300, 1), 300); + assert_eq!(channel_message_timeout_budget_secs(300, 2), 600); + assert_eq!(channel_message_timeout_budget_secs(300, 3), 900); + } + + #[test] + fn 
channel_message_timeout_budget_uses_safe_defaults_and_cap() { + // 0 iterations falls back to 1x timeout budget. + assert_eq!(channel_message_timeout_budget_secs(300, 0), 300); + // Large iteration counts are capped to avoid runaway waits. + assert_eq!( + channel_message_timeout_budget_secs(300, 10), + 300 * CHANNEL_MESSAGE_TIMEOUT_SCALE_CAP + ); + } + + #[test] + fn context_window_overflow_error_detector_matches_known_messages() { + let overflow_err = anyhow::anyhow!( + "OpenAI Codex stream error: Your input exceeds the context window of this model." + ); + assert!(is_context_window_overflow_error(&overflow_err)); + + let other_err = + anyhow::anyhow!("OpenAI Codex API error (502 Bad Gateway): error code: 502"); + assert!(!is_context_window_overflow_error(&other_err)); + } + + #[test] + fn memory_context_skip_rules_exclude_history_blobs() { + assert!(should_skip_memory_context_entry( + "telegram_123_history", + r#"[{"role":"user"}]"# + )); + assert!(should_skip_memory_context_entry( + "assistant_resp_legacy", + "fabricated memory" + )); + assert!(!should_skip_memory_context_entry("telegram_123_45", "hi")); + } + + #[test] + fn normalize_cached_channel_turns_merges_consecutive_user_turns() { + let turns = vec![ + ChatMessage::user("forwarded content"), + ChatMessage::user("summarize this"), + ]; + + let normalized = normalize_cached_channel_turns(turns); + assert_eq!(normalized.len(), 1); + assert_eq!(normalized[0].role, "user"); + assert!(normalized[0].content.contains("forwarded content")); + assert!(normalized[0].content.contains("summarize this")); + } + + #[test] + fn normalize_cached_channel_turns_merges_consecutive_assistant_turns() { + let turns = vec![ + ChatMessage::user("first user"), + ChatMessage::assistant("assistant part 1"), + ChatMessage::assistant("assistant part 2"), + ChatMessage::user("next user"), + ]; + + let normalized = normalize_cached_channel_turns(turns); + assert_eq!(normalized.len(), 3); + assert_eq!(normalized[0].role, "user"); + 
assert_eq!(normalized[1].role, "assistant"); + assert_eq!(normalized[2].role, "user"); + assert!(normalized[1].content.contains("assistant part 1")); + assert!(normalized[1].content.contains("assistant part 2")); + } + + #[test] + fn compact_sender_history_keeps_recent_truncated_messages() { + let mut histories = HashMap::new(); + let sender = "telegram_u1".to_string(); + histories.insert( + sender.clone(), + (0..20) + .map(|idx| { + let content = format!("msg-{idx}-{}", "x".repeat(700)); + if idx % 2 == 0 { + ChatMessage::user(content) + } else { + ChatMessage::assistant(content) + } + }) + .collect::>(), + ); + + let ctx = ChannelRuntimeContext { + channels_by_name: Arc::new(HashMap::new()), + provider: Arc::new(DummyProvider), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("system".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(histories)), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + }; + + assert!(compact_sender_history(&ctx, &sender)); + + let histories = ctx + .conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()); + let kept = histories + .get(&sender) + .expect("sender history should remain"); + assert_eq!(kept.len(), CHANNEL_HISTORY_COMPACT_KEEP_MESSAGES); + assert!(kept.iter().all(|turn| { + let len = 
turn.content.chars().count(); + len <= CHANNEL_HISTORY_COMPACT_CONTENT_CHARS + || (len <= CHANNEL_HISTORY_COMPACT_CONTENT_CHARS + 3 + && turn.content.ends_with("...")) + })); + } + + struct DummyProvider; + + #[async_trait::async_trait] + impl Provider for DummyProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("ok".to_string()) + } + } + #[derive(Default)] struct RecordingChannel { sent_messages: tokio::sync::Mutex>, @@ -2123,6 +3109,43 @@ mod tests { } } + struct DelayedHistoryCaptureProvider { + delay: Duration, + calls: std::sync::Mutex>>, + } + + #[async_trait::async_trait] + impl Provider for DelayedHistoryCaptureProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("fallback".to_string()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + let snapshot = messages + .iter() + .map(|m| (m.role.clone(), m.content.clone())) + .collect::>(); + let call_index = { + let mut calls = self.calls.lock().unwrap_or_else(|e| e.into_inner()); + calls.push(snapshot); + calls.len() + }; + tokio::time::sleep(self.delay).await; + Ok(format!("response-{call_index}")) + } + } + struct MockPriceTool; #[derive(Default)] @@ -2225,6 +3248,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2236,7 +3262,9 @@ mod tests { content: "What is the BTC price now?".to_string(), channel: "test-channel".to_string(), timestamp: 1, + thread_ts: None, }, + 
CancellationToken::new(), ) .await; @@ -2277,6 +3305,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2288,7 +3319,9 @@ mod tests { content: "What is the BTC price now?".to_string(), channel: "test-channel".to_string(), timestamp: 2, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2338,6 +3371,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2349,7 +3385,9 @@ mod tests { content: "/models openrouter".to_string(), channel: "telegram".to_string(), timestamp: 1, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2420,6 +3458,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2431,7 +3472,9 @@ mod tests { content: "hello routed provider".to_string(), channel: "telegram".to_string(), timestamp: 2, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2447,6 +3490,165 @@ mod tests { ); } + #[tokio::test] + async fn process_channel_message_prefers_cached_default_provider_instance() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); 
+ let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let startup_provider_impl = Arc::new(ModelCaptureProvider::default()); + let startup_provider: Arc = startup_provider_impl.clone(); + let reloaded_provider_impl = Arc::new(ModelCaptureProvider::default()); + let reloaded_provider: Arc = reloaded_provider_impl.clone(); + + let mut provider_cache_seed: HashMap> = HashMap::new(); + provider_cache_seed.insert("test-provider".to_string(), reloaded_provider); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::clone(&startup_provider), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("default-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(provider_cache_seed)), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + }); + + process_channel_message( + runtime_ctx, + traits::ChannelMessage { + id: "msg-default-provider-cache".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "hello cached default provider".to_string(), + channel: "telegram".to_string(), + timestamp: 3, + thread_ts: None, + }, + CancellationToken::new(), + ) + 
.await; + + assert_eq!(startup_provider_impl.call_count.load(Ordering::SeqCst), 0); + assert_eq!(reloaded_provider_impl.call_count.load(Ordering::SeqCst), 1); + } + + #[tokio::test] + async fn process_channel_message_uses_runtime_default_model_from_store() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let provider_impl = Arc::new(ModelCaptureProvider::default()); + let provider: Arc = provider_impl.clone(); + let mut provider_cache_seed: HashMap> = HashMap::new(); + provider_cache_seed.insert("test-provider".to_string(), Arc::clone(&provider)); + + let temp = tempfile::TempDir::new().expect("temp dir"); + let config_path = temp.path().join("config.toml"); + + { + let mut store = runtime_config_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + store.insert( + config_path.clone(), + RuntimeConfigState { + defaults: ChannelRuntimeDefaults { + default_provider: "test-provider".to_string(), + model: "hot-reloaded-model".to_string(), + temperature: 0.5, + api_key: None, + api_url: None, + reliability: crate::config::ReliabilityConfig::default(), + }, + last_applied_stamp: None, + }, + ); + } + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::clone(&provider), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("startup-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(provider_cache_seed)), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: 
None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions { + zeroclaw_dir: Some(temp.path().to_path_buf()), + ..providers::ProviderRuntimeOptions::default() + }, + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + }); + + process_channel_message( + runtime_ctx, + traits::ChannelMessage { + id: "msg-runtime-store-model".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "hello runtime defaults".to_string(), + channel: "telegram".to_string(), + timestamp: 4, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + + { + let mut store = runtime_config_store() + .lock() + .unwrap_or_else(|e| e.into_inner()); + store.remove(&config_path); + } + + assert_eq!(provider_impl.call_count.load(Ordering::SeqCst), 1); + assert_eq!( + provider_impl + .models + .lock() + .unwrap_or_else(|e| e.into_inner()) + .as_slice(), + &["hot-reloaded-model".to_string()] + ); + } + #[tokio::test] async fn process_channel_message_respects_configured_max_tool_iterations_above_default() { let channel_impl = Arc::new(RecordingChannel::default()); @@ -2478,6 +3680,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2489,7 +3694,9 @@ mod tests { content: "Loop until done".to_string(), channel: "test-channel".to_string(), timestamp: 1, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2531,6 +3738,9 @@ mod tests { reliability: 
Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2542,7 +3752,9 @@ mod tests { content: "Loop forever".to_string(), channel: "test-channel".to_string(), timestamp: 2, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2604,6 +3816,66 @@ mod tests { } } + struct RecallMemory; + + #[async_trait::async_trait] + impl Memory for RecallMemory { + fn name(&self) -> &str { + "recall-memory" + } + + async fn store( + &self, + _key: &str, + _content: &str, + _category: crate::memory::MemoryCategory, + _session_id: Option<&str>, + ) -> anyhow::Result<()> { + Ok(()) + } + + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(vec![crate::memory::MemoryEntry { + id: "entry-1".to_string(), + key: "memory_key_1".to_string(), + content: "Age is 45".to_string(), + category: crate::memory::MemoryCategory::Conversation, + timestamp: "2026-02-20T00:00:00Z".to_string(), + session_id: None, + score: Some(0.9), + }]) + } + + async fn get(&self, _key: &str) -> anyhow::Result> { + Ok(None) + } + + async fn list( + &self, + _category: Option<&crate::memory::MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { + Ok(Vec::new()) + } + + async fn forget(&self, _key: &str) -> anyhow::Result { + Ok(false) + } + + async fn count(&self) -> anyhow::Result { + Ok(1) + } + + async fn health_check(&self) -> bool { + true + } + } + #[tokio::test] async fn message_dispatch_processes_messages_in_parallel() { let channel_impl = Arc::new(RecordingChannel::default()); @@ -2635,6 +3907,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: 
providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); let (tx, rx) = tokio::sync::mpsc::channel::(4); @@ -2645,6 +3920,7 @@ mod tests { content: "hello".to_string(), channel: "test-channel".to_string(), timestamp: 1, + thread_ts: None, }) .await .unwrap(); @@ -2655,6 +3931,7 @@ mod tests { content: "world".to_string(), channel: "test-channel".to_string(), timestamp: 2, + thread_ts: None, }) .await .unwrap(); @@ -2674,6 +3951,171 @@ mod tests { assert_eq!(sent_messages.len(), 2); } + #[tokio::test] + async fn message_dispatch_interrupts_in_flight_telegram_request_and_preserves_context() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let provider_impl = Arc::new(DelayedHistoryCaptureProvider { + delay: Duration::from_millis(250), + calls: std::sync::Mutex::new(Vec::new()), + }); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: provider_impl.clone(), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: 
providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: true, + multimodal: crate::config::MultimodalConfig::default(), + }); + + let (tx, rx) = tokio::sync::mpsc::channel::(8); + let send_task = tokio::spawn(async move { + tx.send(traits::ChannelMessage { + id: "msg-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "forwarded content".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + thread_ts: None, + }) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(40)).await; + tx.send(traits::ChannelMessage { + id: "msg-2".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "summarize this".to_string(), + channel: "telegram".to_string(), + timestamp: 2, + thread_ts: None, + }) + .await + .unwrap(); + }); + + run_message_dispatch_loop(rx, runtime_ctx, 4).await; + send_task.await.unwrap(); + + let sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 1); + assert!(sent_messages[0].starts_with("chat-1:")); + assert!(sent_messages[0].contains("response-2")); + drop(sent_messages); + + let calls = provider_impl + .calls + .lock() + .unwrap_or_else(|e| e.into_inner()); + assert_eq!(calls.len(), 2); + let second_call = &calls[1]; + assert!(second_call + .iter() + .any(|(role, content)| { role == "user" && content.contains("forwarded content") })); + assert!(second_call + .iter() + .any(|(role, content)| { role == "user" && content.contains("summarize this") })); + assert!( + !second_call.iter().any(|(role, _)| role == "assistant"), + "cancelled turn should not persist an assistant response" + ); + } + + #[tokio::test] + async fn message_dispatch_interrupt_scope_is_same_sender_same_chat() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let 
mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: Arc::new(SlowProvider { + delay: Duration::from_millis(180), + }), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 10, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: true, + multimodal: crate::config::MultimodalConfig::default(), + }); + + let (tx, rx) = tokio::sync::mpsc::channel::(8); + let send_task = tokio::spawn(async move { + tx.send(traits::ChannelMessage { + id: "msg-a".to_string(), + sender: "alice".to_string(), + reply_target: "chat-1".to_string(), + content: "first chat".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + thread_ts: None, + }) + .await + .unwrap(); + tokio::time::sleep(Duration::from_millis(30)).await; + tx.send(traits::ChannelMessage { + id: "msg-b".to_string(), + sender: "alice".to_string(), + reply_target: "chat-2".to_string(), + content: "second chat".to_string(), + channel: "telegram".to_string(), + timestamp: 2, + thread_ts: None, + }) + .await + .unwrap(); + }); + + run_message_dispatch_loop(rx, runtime_ctx, 4).await; + send_task.await.unwrap(); + + let 
sent_messages = channel_impl.sent_messages.lock().await; + assert_eq!(sent_messages.len(), 2); + assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-1:"))); + assert!(sent_messages.iter().any(|msg| msg.starts_with("chat-2:"))); + } + #[tokio::test] async fn process_channel_message_cancels_scoped_typing_task() { let channel_impl = Arc::new(RecordingChannel::default()); @@ -2705,6 +4147,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -2716,7 +4161,9 @@ mod tests { content: "hello".to_string(), channel: "test-channel".to_string(), timestamp: 1, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -2761,6 +4208,26 @@ mod tests { assert!(prompt.contains("**memory_recall**")); } + #[test] + fn prompt_includes_single_tool_protocol_block_after_append() { + let ws = make_workspace(); + let tools = vec![("shell", "Run commands")]; + let mut prompt = build_system_prompt(ws.path(), "gpt-4o", &tools, &[], None, None); + + assert!( + !prompt.contains("## Tool Use Protocol"), + "build_system_prompt should not emit protocol block directly" + ); + + prompt.push_str(&build_tool_instructions(&[])); + + assert_eq!( + prompt.matches("## Tool Use Protocol").count(), + 1, + "protocol block should appear exactly once in the final prompt" + ); + } + #[test] fn prompt_injects_safety() { let ws = make_workspace(); @@ -2864,7 +4331,7 @@ mod tests { } #[test] - fn prompt_skills_compact_list() { + fn prompt_skills_include_instructions_and_tools() { let ws = make_workspace(); let skills = vec![crate::skills::Skill { name: "code-review".into(), @@ -2872,8 +4339,14 @@ mod tests { version: "1.0.0".into(), author: None, tags: vec![], - tools: vec![], - 
prompts: vec!["Long prompt content that should NOT appear in system prompt".into()], + tools: vec![crate::skills::SkillTool { + name: "lint".into(), + description: "Run static checks".into(), + kind: "shell".into(), + command: "cargo clippy".into(), + args: HashMap::new(), + }], + prompts: vec!["Always run cargo test before final response.".into()], location: None, }]; @@ -2883,12 +4356,47 @@ mod tests { assert!(prompt.contains("code-review")); assert!(prompt.contains("Review code for bugs")); assert!(prompt.contains("SKILL.md")); - assert!( - prompt.contains("loaded on demand"), - "should mention on-demand loading" - ); - // Full prompt content should NOT be dumped - assert!(!prompt.contains("Long prompt content that should NOT appear")); + assert!(prompt.contains("")); + assert!(prompt + .contains("Always run cargo test before final response.")); + assert!(prompt.contains("")); + assert!(prompt.contains("lint")); + assert!(prompt.contains("shell")); + assert!(!prompt.contains("loaded on demand")); + } + + #[test] + fn prompt_skills_escape_reserved_xml_chars() { + let ws = make_workspace(); + let skills = vec![crate::skills::Skill { + name: "code&".into(), + description: "Review \"unsafe\" and 'risky' bits".into(), + version: "1.0.0".into(), + author: None, + tags: vec![], + tools: vec![crate::skills::SkillTool { + name: "run\"linter\"".into(), + description: "Run & report".into(), + kind: "shell&exec".into(), + command: "cargo clippy".into(), + args: HashMap::new(), + }], + prompts: vec!["Use and & keep output \"safe\"".into()], + location: None, + }]; + + let prompt = build_system_prompt(ws.path(), "model", &[], &skills, None, None); + + assert!(prompt.contains("code<review>&")); + assert!(prompt.contains( + "Review "unsafe" and 'risky' bits" + )); + assert!(prompt.contains("run"linter"")); + assert!(prompt.contains("Run <lint> & report")); + assert!(prompt.contains("shell&exec")); + assert!(prompt.contains( + "Use <tool_call> and & keep output "safe"" + )); } 
#[test] @@ -2950,8 +4458,8 @@ mod tests { "missing Channel Capabilities section" ); assert!( - prompt.contains("running as a Discord bot"), - "missing Discord context" + prompt.contains("running as a messaging bot"), + "missing channel context" ); assert!( prompt.contains("NEVER repeat, describe, or echo credentials"), @@ -2976,6 +4484,7 @@ mod tests { content: "hello".into(), channel: "slack".into(), timestamp: 1, + thread_ts: None, }; assert_eq!(conversation_memory_key(&msg), "slack_U123_msg_abc123"); @@ -2990,6 +4499,7 @@ mod tests { content: "first".into(), channel: "slack".into(), timestamp: 1, + thread_ts: None, }; let msg2 = traits::ChannelMessage { id: "msg_2".into(), @@ -2998,6 +4508,7 @@ mod tests { content: "second".into(), channel: "slack".into(), timestamp: 2, + thread_ts: None, }; assert_ne!( @@ -3018,6 +4529,7 @@ mod tests { content: "I'm Paul".into(), channel: "slack".into(), timestamp: 1, + thread_ts: None, }; let msg2 = traits::ChannelMessage { id: "msg_2".into(), @@ -3026,6 +4538,7 @@ mod tests { content: "I'm 45".into(), channel: "slack".into(), timestamp: 2, + thread_ts: None, }; mem.store( @@ -3095,6 +4608,9 @@ mod tests { reliability: Arc::new(crate::config::ReliabilityConfig::default()), provider_runtime_options: providers::ProviderRuntimeOptions::default(), workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), }); process_channel_message( @@ -3106,7 +4622,9 @@ mod tests { content: "hello".to_string(), channel: "test-channel".to_string(), timestamp: 1, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -3119,7 +4637,9 @@ mod tests { content: "follow up".to_string(), channel: "test-channel".to_string(), timestamp: 2, + thread_ts: None, }, + CancellationToken::new(), ) .await; @@ -3141,6 +4661,217 @@ mod tests { assert!(calls[1][3].1.contains("follow up")); } + #[tokio::test] + async fn 
process_channel_message_enriches_current_turn_without_persisting_context() { + let channel_impl = Arc::new(RecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let provider_impl = Arc::new(HistoryCaptureProvider::default()); + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: provider_impl.clone(), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(RecallMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(HashMap::new())), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + }); + + process_channel_message( + runtime_ctx.clone(), + traits::ChannelMessage { + id: "msg-ctx-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-ctx".to_string(), + content: "hello".to_string(), + channel: "test-channel".to_string(), + timestamp: 1, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + + let calls = provider_impl + .calls + .lock() + .unwrap_or_else(|e| e.into_inner()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].len(), 2); + assert_eq!(calls[0][1].0, "user"); + assert!(calls[0][1].1.contains("[Memory 
context]")); + assert!(calls[0][1].1.contains("Age is 45")); + assert!(calls[0][1].1.contains("hello")); + + let histories = runtime_ctx + .conversation_histories + .lock() + .unwrap_or_else(|e| e.into_inner()); + let turns = histories + .get("test-channel_alice") + .expect("history should be stored for sender"); + assert_eq!(turns[0].role, "user"); + assert_eq!(turns[0].content, "hello"); + assert!(!turns[0].content.contains("[Memory context]")); + } + + #[tokio::test] + async fn process_channel_message_telegram_keeps_system_instruction_at_top_only() { + let channel_impl = Arc::new(TelegramRecordingChannel::default()); + let channel: Arc = channel_impl.clone(); + + let mut channels_by_name = HashMap::new(); + channels_by_name.insert(channel.name().to_string(), channel); + + let provider_impl = Arc::new(HistoryCaptureProvider::default()); + let mut histories = HashMap::new(); + histories.insert( + "telegram_alice".to_string(), + vec![ + ChatMessage::assistant("stale assistant"), + ChatMessage::user("earlier user question"), + ChatMessage::assistant("earlier assistant reply"), + ], + ); + + let runtime_ctx = Arc::new(ChannelRuntimeContext { + channels_by_name: Arc::new(channels_by_name), + provider: provider_impl.clone(), + default_provider: Arc::new("test-provider".to_string()), + memory: Arc::new(NoopMemory), + tools_registry: Arc::new(vec![]), + observer: Arc::new(NoopObserver), + system_prompt: Arc::new("test-system-prompt".to_string()), + model: Arc::new("test-model".to_string()), + temperature: 0.0, + auto_save_memory: false, + max_tool_iterations: 5, + min_relevance_score: 0.0, + conversation_histories: Arc::new(Mutex::new(histories)), + provider_cache: Arc::new(Mutex::new(HashMap::new())), + route_overrides: Arc::new(Mutex::new(HashMap::new())), + api_key: None, + api_url: None, + reliability: Arc::new(crate::config::ReliabilityConfig::default()), + provider_runtime_options: providers::ProviderRuntimeOptions::default(), + workspace_dir: 
Arc::new(std::env::temp_dir()), + message_timeout_secs: CHANNEL_MESSAGE_TIMEOUT_SECS, + interrupt_on_new_message: false, + multimodal: crate::config::MultimodalConfig::default(), + }); + + process_channel_message( + runtime_ctx.clone(), + traits::ChannelMessage { + id: "tg-msg-1".to_string(), + sender: "alice".to_string(), + reply_target: "chat-telegram".to_string(), + content: "hello".to_string(), + channel: "telegram".to_string(), + timestamp: 1, + thread_ts: None, + }, + CancellationToken::new(), + ) + .await; + + let calls = provider_impl + .calls + .lock() + .unwrap_or_else(|e| e.into_inner()); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].len(), 4); + + let roles = calls[0] + .iter() + .map(|(role, _)| role.as_str()) + .collect::>(); + assert_eq!(roles, vec!["system", "user", "assistant", "user"]); + assert!( + calls[0][0] + .1 + .contains("When responding on Telegram, include media markers"), + "telegram delivery instruction should live in the system prompt" + ); + assert!(!calls[0].iter().skip(1).any(|(role, _)| role == "system")); + } + + #[test] + fn extract_tool_context_summary_collects_alias_and_native_tool_calls() { + let history = vec![ + ChatMessage::system("sys"), + ChatMessage::assistant( + r#" +{"name":"shell","arguments":{"command":"date"}} +"#, + ), + ChatMessage::assistant( + r#"{"content":null,"tool_calls":[{"id":"1","name":"web_search","arguments":"{}"}]}"#, + ), + ]; + + let summary = extract_tool_context_summary(&history, 1); + assert_eq!(summary, "[Used tools: shell, web_search]"); + } + + #[test] + fn extract_tool_context_summary_collects_prompt_mode_tool_result_names() { + let history = vec![ + ChatMessage::system("sys"), + ChatMessage::assistant("Using markdown tool call fence"), + ChatMessage::user( + r#"[Tool results] + +{"status":200} + + +Mon Feb 20 +"#, + ), + ]; + + let summary = extract_tool_context_summary(&history, 1); + assert_eq!(summary, "[Used tools: http_request, shell]"); + } + + #[test] + fn 
extract_tool_context_summary_respects_start_index() { + let history = vec![ + ChatMessage::assistant( + r#" +{"name":"stale_tool","arguments":{}} +"#, + ), + ChatMessage::assistant( + r#" +{"name":"fresh_tool","arguments":{}} +"#, + ), + ]; + + let summary = extract_tool_context_summary(&history, 1); + assert_eq!(summary, "[Used tools: fresh_tool]"); + } + // ── AIEOS Identity Tests (Issue #168) ───────────────────────── #[test] diff --git a/src/channels/qq.rs b/src/channels/qq.rs index 70dc20d..18117ef 100644 --- a/src/channels/qq.rs +++ b/src/channels/qq.rs @@ -11,6 +11,15 @@ use uuid::Uuid; const QQ_API_BASE: &str = "https://api.sgroup.qq.com"; const QQ_AUTH_URL: &str = "https://bots.qq.com/app/getAppAccessToken"; +fn ensure_https(url: &str) -> anyhow::Result<()> { + if !url.starts_with("https://") { + anyhow::bail!( + "Refusing to transmit sensitive data over non-HTTPS URL: URL scheme must be https" + ); + } + Ok(()) +} + /// Deduplication set capacity — evict half of entries when full. 
const DEDUP_CAPACITY: usize = 10_000; @@ -196,6 +205,8 @@ impl Channel for QQChannel { ) }; + ensure_https(&url)?; + let resp = self .http_client() .post(&url) @@ -252,7 +263,9 @@ impl Channel for QQChannel { } } }); - write.send(Message::Text(identify.to_string())).await?; + write + .send(Message::Text(identify.to_string().into())) + .await?; tracing::info!("QQ: connected and identified"); @@ -276,7 +289,11 @@ impl Channel for QQChannel { _ = hb_rx.recv() => { let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; let hb = json!({"op": 1, "d": d}); - if write.send(Message::Text(hb.to_string())).await.is_err() { + if write + .send(Message::Text(hb.to_string().into())) + .await + .is_err() + { break; } } @@ -287,7 +304,7 @@ impl Channel for QQChannel { _ => continue, }; - let event: serde_json::Value = match serde_json::from_str(&msg) { + let event: serde_json::Value = match serde_json::from_str(msg.as_ref()) { Ok(e) => e, Err(_) => continue, }; @@ -304,7 +321,11 @@ impl Channel for QQChannel { 1 => { let d = if sequence >= 0 { json!(sequence) } else { json!(null) }; let hb = json!({"op": 1, "d": d}); - if write.send(Message::Text(hb.to_string())).await.is_err() { + if write + .send(Message::Text(hb.to_string().into())) + .await + .is_err() + { break; } continue; @@ -366,6 +387,7 @@ impl Channel for QQChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; if tx.send(channel_msg).await.is_err() { @@ -404,6 +426,7 @@ impl Channel for QQChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }; if tx.send(channel_msg).await.is_err() { diff --git a/src/channels/signal.rs b/src/channels/signal.rs index e759a1a..20cacfc 100644 --- a/src/channels/signal.rs +++ b/src/channels/signal.rs @@ -119,12 +119,18 @@ impl SignalChannel { (2..=15).contains(&number.len()) && number.chars().all(|c| c.is_ascii_digit()) } + /// Check whether a string is a valid UUID (signal-cli uses 
these for + /// privacy-enabled users who have opted out of sharing their phone number). + fn is_uuid(s: &str) -> bool { + Uuid::parse_str(s).is_ok() + } + fn parse_recipient_target(recipient: &str) -> RecipientTarget { if let Some(group_id) = recipient.strip_prefix(GROUP_TARGET_PREFIX) { return RecipientTarget::Group(group_id.to_string()); } - if Self::is_e164(recipient) { + if Self::is_e164(recipient) || Self::is_uuid(recipient) { RecipientTarget::Direct(recipient.to_string()) } else { RecipientTarget::Group(recipient.to_string()) @@ -259,6 +265,7 @@ impl SignalChannel { content: text.to_string(), channel: "signal".to_string(), timestamp: timestamp / 1000, // millis → secs + thread_ts: None, }) } } @@ -653,6 +660,15 @@ mod tests { ); } + #[test] + fn parse_recipient_target_uuid_is_direct() { + let uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"; + assert_eq!( + SignalChannel::parse_recipient_target(uuid), + RecipientTarget::Direct(uuid.to_string()) + ); + } + #[test] fn parse_recipient_target_non_e164_plus_is_group() { assert_eq!( @@ -661,6 +677,24 @@ mod tests { ); } + #[test] + fn is_uuid_valid() { + assert!(SignalChannel::is_uuid( + "a1b2c3d4-e5f6-7890-abcd-ef1234567890" + )); + assert!(SignalChannel::is_uuid( + "00000000-0000-0000-0000-000000000000" + )); + } + + #[test] + fn is_uuid_invalid() { + assert!(!SignalChannel::is_uuid("+1234567890")); + assert!(!SignalChannel::is_uuid("not-a-uuid")); + assert!(!SignalChannel::is_uuid("group:abc123")); + assert!(!SignalChannel::is_uuid("")); + } + #[test] fn sender_prefers_source_number() { let env = Envelope { @@ -685,6 +719,73 @@ mod tests { assert_eq!(SignalChannel::sender(&env), Some("uuid-123".to_string())); } + #[test] + fn process_envelope_uuid_sender_dm() { + let uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"; + let ch = SignalChannel::new( + "http://127.0.0.1:8686".to_string(), + "+1234567890".to_string(), + None, + vec!["*".to_string()], + false, + false, + ); + let env = Envelope { + source: 
Some(uuid.to_string()), + source_number: None, + data_message: Some(DataMessage { + message: Some("Hello from privacy user".to_string()), + timestamp: Some(1_700_000_000_000), + group_info: None, + attachments: None, + }), + story_message: None, + timestamp: Some(1_700_000_000_000), + }; + let msg = ch.process_envelope(&env).unwrap(); + assert_eq!(msg.sender, uuid); + assert_eq!(msg.reply_target, uuid); + assert_eq!(msg.content, "Hello from privacy user"); + + // Verify reply routing: UUID sender in DM should route as Direct + let target = SignalChannel::parse_recipient_target(&msg.reply_target); + assert_eq!(target, RecipientTarget::Direct(uuid.to_string())); + } + + #[test] + fn process_envelope_uuid_sender_in_group() { + let uuid = "a1b2c3d4-e5f6-7890-abcd-ef1234567890"; + let ch = SignalChannel::new( + "http://127.0.0.1:8686".to_string(), + "+1234567890".to_string(), + Some("testgroup".to_string()), + vec!["*".to_string()], + false, + false, + ); + let env = Envelope { + source: Some(uuid.to_string()), + source_number: None, + data_message: Some(DataMessage { + message: Some("Group msg from privacy user".to_string()), + timestamp: Some(1_700_000_000_000), + group_info: Some(GroupInfo { + group_id: Some("testgroup".to_string()), + }), + attachments: None, + }), + story_message: None, + timestamp: Some(1_700_000_000_000), + }; + let msg = ch.process_envelope(&env).unwrap(); + assert_eq!(msg.sender, uuid); + assert_eq!(msg.reply_target, "group:testgroup"); + + // Verify reply routing: group message should still route as Group + let target = SignalChannel::parse_recipient_target(&msg.reply_target); + assert_eq!(target, RecipientTarget::Group("testgroup".to_string())); + } + #[test] fn sender_none_when_both_missing() { let env = Envelope { diff --git a/src/channels/slack.rs b/src/channels/slack.rs index 13d1273..559af15 100644 --- a/src/channels/slack.rs +++ b/src/channels/slack.rs @@ -45,6 +45,15 @@ impl SlackChannel { .and_then(|u| u.as_str()) .map(String::from) } 
+ + /// Resolve the thread identifier for inbound Slack messages. + /// Replies carry `thread_ts` (root thread id); top-level messages only have `ts`. + fn inbound_thread_ts(msg: &serde_json::Value, ts: &str) -> Option { + msg.get("thread_ts") + .and_then(|t| t.as_str()) + .or(if ts.is_empty() { None } else { Some(ts) }) + .map(str::to_string) + } } #[async_trait] @@ -54,11 +63,15 @@ impl Channel for SlackChannel { } async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { - let body = serde_json::json!({ + let mut body = serde_json::json!({ "channel": message.recipient, "text": message.content }); + if let Some(ref ts) = message.thread_ts { + body["thread_ts"] = serde_json::json!(ts); + } + let resp = self .http_client() .post("https://slack.com/api/chat.postMessage") @@ -170,6 +183,7 @@ impl Channel for SlackChannel { .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: Self::inbound_thread_ts(msg, ts), }; if tx.send(channel_msg).await.is_err() { @@ -303,4 +317,33 @@ mod tests { assert!(!id.contains('-')); // No UUID dashes assert!(id.starts_with("slack_")); } + + #[test] + fn inbound_thread_ts_prefers_explicit_thread_ts() { + let msg = serde_json::json!({ + "ts": "123.002", + "thread_ts": "123.001" + }); + + let thread_ts = SlackChannel::inbound_thread_ts(&msg, "123.002"); + assert_eq!(thread_ts.as_deref(), Some("123.001")); + } + + #[test] + fn inbound_thread_ts_falls_back_to_ts() { + let msg = serde_json::json!({ + "ts": "123.001" + }); + + let thread_ts = SlackChannel::inbound_thread_ts(&msg, "123.001"); + assert_eq!(thread_ts.as_deref(), Some("123.001")); + } + + #[test] + fn inbound_thread_ts_none_when_ts_missing() { + let msg = serde_json::json!({}); + + let thread_ts = SlackChannel::inbound_thread_ts(&msg, ""); + assert_eq!(thread_ts, None); + } } diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index ca0e03b..1503e57 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -6,10 
+6,10 @@ use async_trait::async_trait; use directories::UserDirs; use parking_lot::Mutex; use reqwest::multipart::{Form, Part}; -use std::fs; use std::path::Path; use std::sync::{Arc, RwLock}; use std::time::Duration; +use tokio::fs; /// Telegram's maximum message length for text messages const TELEGRAM_MAX_MESSAGE_LENGTH: usize = 4096; @@ -18,7 +18,7 @@ const TELEGRAM_BIND_COMMAND: &str = "/bind"; /// Split a message into chunks that respect Telegram's 4096 character limit. /// Tries to split at word boundaries when possible, and handles continuation. fn split_message_for_telegram(message: &str) -> Vec { - if message.len() <= TELEGRAM_MAX_MESSAGE_LENGTH { + if message.chars().count() <= TELEGRAM_MAX_MESSAGE_LENGTH { return vec![message.to_string()]; } @@ -26,29 +26,32 @@ fn split_message_for_telegram(message: &str) -> Vec { let mut remaining = message; while !remaining.is_empty() { - let chunk_end = if remaining.len() <= TELEGRAM_MAX_MESSAGE_LENGTH { - remaining.len() + // Find the byte offset for the Nth character boundary. 
+ let hard_split = remaining + .char_indices() + .nth(TELEGRAM_MAX_MESSAGE_LENGTH) + .map_or(remaining.len(), |(idx, _)| idx); + + let chunk_end = if hard_split == remaining.len() { + hard_split } else { // Try to find a good break point (newline, then space) - let search_area = &remaining[..TELEGRAM_MAX_MESSAGE_LENGTH]; + let search_area = &remaining[..hard_split]; // Prefer splitting at newline if let Some(pos) = search_area.rfind('\n') { // Don't split if the newline is too close to the start - if pos >= TELEGRAM_MAX_MESSAGE_LENGTH / 2 { + if search_area[..pos].chars().count() >= TELEGRAM_MAX_MESSAGE_LENGTH / 2 { pos + 1 } else { // Try space as fallback - search_area - .rfind(' ') - .unwrap_or(TELEGRAM_MAX_MESSAGE_LENGTH) - + 1 + search_area.rfind(' ').unwrap_or(hard_split) + 1 } } else if let Some(pos) = search_area.rfind(' ') { pos + 1 } else { - // Hard split at the limit - TELEGRAM_MAX_MESSAGE_LENGTH + // Hard split at character boundary + hard_split } }; @@ -373,7 +376,7 @@ impl TelegramChannel { .collect() } - fn load_config_without_env() -> anyhow::Result { + async fn load_config_without_env() -> anyhow::Result { let home = UserDirs::new() .map(|u| u.home_dir().to_path_buf()) .context("Could not find home directory")?; @@ -381,18 +384,23 @@ impl TelegramChannel { let config_path = zeroclaw_dir.join("config.toml"); let contents = fs::read_to_string(&config_path) + .await .with_context(|| format!("Failed to read config file: {}", config_path.display()))?; let mut config: Config = toml::from_str(&contents) - .context("Failed to parse config file for Telegram binding")?; + .context("Failed to parse config.toml — check [channels.telegram] section for syntax errors")?; config.config_path = config_path; config.workspace_dir = zeroclaw_dir.join("workspace"); Ok(config) } - fn persist_allowed_identity_blocking(identity: &str) -> anyhow::Result<()> { - let mut config = Self::load_config_without_env()?; + async fn persist_allowed_identity(&self, identity: &str) -> 
anyhow::Result<()> { + let mut config = Self::load_config_without_env().await?; let Some(telegram) = config.channels_config.telegram.as_mut() else { - anyhow::bail!("Telegram channel config is missing in config.toml"); + anyhow::bail!( + "Missing [channels.telegram] section in config.toml. \ + Add bot_token and allowed_users under [channels.telegram], \ + or run `zeroclaw onboard --channels-only` to configure interactively" + ); }; let normalized = Self::normalize_identity(identity); @@ -404,20 +412,13 @@ impl TelegramChannel { telegram.allowed_users.push(normalized); config .save() + .await .context("Failed to persist Telegram allowlist to config.toml")?; } Ok(()) } - async fn persist_allowed_identity(&self, identity: &str) -> anyhow::Result<()> { - let identity = identity.to_string(); - tokio::task::spawn_blocking(move || Self::persist_allowed_identity_blocking(&identity)) - .await - .map_err(|e| anyhow::anyhow!("Failed to join Telegram bind save task: {e}"))??; - Ok(()) - } - fn add_allowed_identity_runtime(&self, identity: &str) { let normalized = Self::normalize_identity(identity); if normalized.is_empty() { @@ -600,12 +601,12 @@ impl TelegramChannel { let username = username_opt.unwrap_or("unknown"); let normalized_username = Self::normalize_identity(username); - let user_id = message + let sender_id = message .get("from") .and_then(|from| from.get("id")) .and_then(serde_json::Value::as_i64); - let user_id_str = user_id.map(|id| id.to_string()); - let normalized_user_id = user_id_str.as_deref().map(Self::normalize_identity); + let sender_id_str = sender_id.map(|id| id.to_string()); + let normalized_sender_id = sender_id_str.as_deref().map(Self::normalize_identity); let chat_id = message .get("chat") @@ -619,7 +620,7 @@ impl TelegramChannel { }; let mut identities = vec![normalized_username.as_str()]; - if let Some(ref id) = normalized_user_id { + if let Some(ref id) = normalized_sender_id { identities.push(id.as_str()); } @@ -629,9 +630,9 @@ impl 
TelegramChannel { if let Some(code) = Self::extract_bind_code(text) { if let Some(pairing) = self.pairing.as_ref() { - match pairing.try_pair(code) { + match pairing.try_pair(code, &chat_id).await { Ok(Some(_token)) => { - let bind_identity = normalized_user_id.clone().or_else(|| { + let bind_identity = normalized_sender_id.clone().or_else(|| { if normalized_username.is_empty() || normalized_username == "unknown" { None } else { @@ -694,7 +695,7 @@ impl TelegramChannel { } else { let _ = self .send(&SendMessage::new( - "ℹ️ Telegram pairing is not active. Ask operator to update allowlist in config.toml.", + "ℹ️ Telegram pairing is not active. Ask operator to add your user ID to channels.telegram.allowed_users in config.toml.", &chat_id, )) .await; @@ -703,12 +704,12 @@ impl TelegramChannel { } tracing::warn!( - "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \ + "Telegram: ignoring message from unauthorized user: username={username}, sender_id={}. \ Allowlist Telegram username (without '@') or numeric user ID.", - user_id_str.as_deref().unwrap_or("unknown") + sender_id_str.as_deref().unwrap_or("unknown") ); - let suggested_identity = normalized_user_id + let suggested_identity = normalized_sender_id .clone() .or_else(|| { if normalized_username.is_empty() || normalized_username == "unknown" { @@ -750,20 +751,20 @@ Allowlist Telegram username (without '@') or numeric user ID.", .unwrap_or("unknown") .to_string(); - let user_id = message + let sender_id = message .get("from") .and_then(|from| from.get("id")) .and_then(serde_json::Value::as_i64) .map(|id| id.to_string()); let sender_identity = if username == "unknown" { - user_id.clone().unwrap_or_else(|| "unknown".to_string()) + sender_id.clone().unwrap_or_else(|| "unknown".to_string()) } else { username.clone() }; let mut identities = vec![username.as_str()]; - if let Some(id) = user_id.as_deref() { + if let Some(id) = sender_id.as_deref() { identities.push(id); } @@ -825,6 +826,7 
@@ Allowlist Telegram username (without '@') or numeric user ID.", .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs(), + thread_ts: None, }) } @@ -1631,6 +1633,37 @@ impl Channel for TelegramChannel { .await } + async fn cancel_draft(&self, recipient: &str, message_id: &str) -> anyhow::Result<()> { + let (chat_id, _) = Self::parse_reply_target(recipient); + self.last_draft_edit.lock().remove(&chat_id); + + let message_id = match message_id.parse::() { + Ok(id) => id, + Err(e) => { + tracing::debug!("Invalid Telegram draft message_id '{message_id}': {e}"); + return Ok(()); + } + }; + + let response = self + .client + .post(self.api_url("deleteMessage")) + .json(&serde_json::json!({ + "chat_id": chat_id, + "message_id": message_id, + })) + .send() + .await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + tracing::debug!("Telegram deleteMessage failed ({status}): {body}"); + } + + Ok(()) + } + async fn send(&self, message: &SendMessage) -> anyhow::Result<()> { // Strip tool_call tags before processing to prevent Markdown parsing failures let content = strip_tool_call_tags(&message.content); @@ -2830,4 +2863,103 @@ mod tests { let ch_disabled = TelegramChannel::new("token".into(), vec!["*".into()], false); assert!(!ch_disabled.mention_only); } + + // ───────────────────────────────────────────────────────────────────── + // TG6: Channel platform limit edge cases for Telegram (4096 char limit) + // Prevents: Pattern 6 — issues #574, #499 + // ───────────────────────────────────────────────────────────────────── + + #[test] + fn telegram_split_code_block_at_boundary() { + let mut msg = String::new(); + msg.push_str("```python\n"); + msg.push_str(&"x".repeat(4085)); + msg.push_str("\n```\nMore text after code block"); + let parts = split_message_for_telegram(&msg); + assert!( + parts.len() >= 2, + "code block spanning boundary should split" + ); + for part in &parts { 
+ assert!( + part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "each part must be <= {TELEGRAM_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + } + + #[test] + fn telegram_split_single_long_word() { + let long_word = "a".repeat(5000); + let parts = split_message_for_telegram(&long_word); + assert!(parts.len() >= 2, "word exceeding limit must be split"); + for part in &parts { + assert!( + part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "hard-split part must be <= {TELEGRAM_MAX_MESSAGE_LENGTH}, got {}", + part.len() + ); + } + let reassembled: String = parts.join(""); + assert_eq!(reassembled, long_word); + } + + #[test] + fn telegram_split_exactly_at_limit_no_split() { + let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH); + let parts = split_message_for_telegram(&msg); + assert_eq!(parts.len(), 1, "message exactly at limit should not split"); + } + + #[test] + fn telegram_split_one_over_limit() { + let msg = "a".repeat(TELEGRAM_MAX_MESSAGE_LENGTH + 1); + let parts = split_message_for_telegram(&msg); + assert!(parts.len() >= 2, "message 1 char over limit must split"); + } + + #[test] + fn telegram_split_many_short_lines() { + let msg: String = (0..1000).map(|i| format!("line {i}\n")).collect(); + let parts = split_message_for_telegram(&msg); + for part in &parts { + assert!( + part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "short-line batch must be <= limit" + ); + } + } + + #[test] + fn telegram_split_only_whitespace() { + let msg = " \n\n\t "; + let parts = split_message_for_telegram(msg); + assert!(parts.len() <= 1); + } + + #[test] + fn telegram_split_emoji_at_boundary() { + let mut msg = "a".repeat(4094); + msg.push_str("🎉🎊"); // 4096 chars total + let parts = split_message_for_telegram(&msg); + for part in &parts { + // The function splits on character count, not byte count + assert!( + part.chars().count() <= TELEGRAM_MAX_MESSAGE_LENGTH, + "emoji boundary split must respect limit" + ); + } + } + + #[test] + fn telegram_split_consecutive_newlines() { + let mut msg = 
"a".repeat(4090); + msg.push_str("\n\n\n\n\n\n"); + msg.push_str(&"b".repeat(100)); + let parts = split_message_for_telegram(&msg); + for part in &parts { + assert!(part.len() <= TELEGRAM_MAX_MESSAGE_LENGTH); + } + } } diff --git a/src/channels/traits.rs b/src/channels/traits.rs index 3a7d9df..67546ce 100644 --- a/src/channels/traits.rs +++ b/src/channels/traits.rs @@ -9,6 +9,9 @@ pub struct ChannelMessage { pub content: String, pub channel: String, pub timestamp: u64, + /// Platform thread identifier (e.g. Slack `ts`, Discord thread ID). + /// When set, replies should be posted as threaded responses. + pub thread_ts: Option, } /// Message to send through a channel @@ -17,6 +20,8 @@ pub struct SendMessage { pub content: String, pub recipient: String, pub subject: Option, + /// Platform thread identifier for threaded replies (e.g. Slack `thread_ts`). + pub thread_ts: Option, } impl SendMessage { @@ -26,6 +31,7 @@ impl SendMessage { content: content.into(), recipient: recipient.into(), subject: None, + thread_ts: None, } } @@ -39,8 +45,15 @@ impl SendMessage { content: content.into(), recipient: recipient.into(), subject: Some(subject.into()), + thread_ts: None, } } + + /// Set the thread identifier for threaded replies. + pub fn in_thread(mut self, thread_ts: Option) -> Self { + self.thread_ts = thread_ts; + self + } } /// Core channel trait — implement for any messaging platform @@ -100,6 +113,11 @@ pub trait Channel: Send + Sync { ) -> anyhow::Result<()> { Ok(()) } + + /// Cancel and remove a previously sent draft message if the channel supports it. 
+ async fn cancel_draft(&self, _recipient: &str, _message_id: &str) -> anyhow::Result<()> { + Ok(()) + } } #[cfg(test)] @@ -129,6 +147,7 @@ mod tests { content: "hello".into(), channel: "dummy".into(), timestamp: 123, + thread_ts: None, }) .await .map_err(|e| anyhow::anyhow!(e.to_string())) @@ -144,6 +163,7 @@ mod tests { content: "ping".into(), channel: "dummy".into(), timestamp: 999, + thread_ts: None, }; let cloned = message.clone(); @@ -183,6 +203,7 @@ mod tests { .finalize_draft("bob", "msg_1", "final text") .await .is_ok()); + assert!(channel.cancel_draft("bob", "msg_1").await.is_ok()); } #[tokio::test] diff --git a/src/channels/whatsapp.rs b/src/channels/whatsapp.rs index c6e5baa..5401e60 100644 --- a/src/channels/whatsapp.rs +++ b/src/channels/whatsapp.rs @@ -8,6 +8,20 @@ use uuid::Uuid; /// Messages are received via the gateway's `/whatsapp` webhook endpoint. /// The `listen` method here is a no-op placeholder; actual message handling /// happens in the gateway when Meta sends webhook events. +fn ensure_https(url: &str) -> anyhow::Result<()> { + if !url.starts_with("https://") { + anyhow::bail!( + "Refusing to transmit sensitive data over non-HTTPS URL: URL scheme must be https" + ); + } + Ok(()) +} + +/// +/// # Runtime Negotiation +/// +/// This Cloud API channel is automatically selected when `phone_number_id` is set in the config. +/// Use `WhatsAppWebChannel` (with `session_path`) for native Web mode. pub struct WhatsAppChannel { access_token: String, endpoint_id: String, @@ -85,7 +99,8 @@ impl WhatsAppChannel { if !self.is_number_allowed(&normalized_from) { tracing::warn!( "WhatsApp: ignoring message from unauthorized number: {normalized_from}. \ - Add to allowed_numbers in config.toml, then run `zeroclaw onboard --channels-only`." + Add to channels.whatsapp.allowed_numbers in config.toml, \ + or run `zeroclaw onboard --channels-only` to configure interactively." 
); continue; } @@ -126,6 +141,7 @@ impl WhatsAppChannel { content, channel: "whatsapp".to_string(), timestamp, + thread_ts: None, }); } } @@ -165,6 +181,8 @@ impl Channel for WhatsAppChannel { } }); + ensure_https(&url)?; + let resp = self .http_client() .post(&url) @@ -203,6 +221,10 @@ impl Channel for WhatsAppChannel { // Check if we can reach the WhatsApp API let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id); + if ensure_https(&url).is_err() { + return false; + } + self.http_client() .get(&url) .bearer_auth(&self.access_token) diff --git a/src/channels/whatsapp_storage.rs b/src/channels/whatsapp_storage.rs new file mode 100644 index 0000000..87eebf7 --- /dev/null +++ b/src/channels/whatsapp_storage.rs @@ -0,0 +1,1345 @@ +//! Custom wa-rs storage backend using ZeroClaw's rusqlite +//! +//! This module implements all 4 wa-rs storage traits using rusqlite directly, +//! avoiding the Diesel/libsqlite3-sys dependency conflict from wa-rs-sqlite-storage. +//! +//! # Traits Implemented +//! +//! - [`SignalStore`]: Signal protocol cryptographic operations +//! - [`AppSyncStore`]: WhatsApp app state synchronization +//! - [`ProtocolStore`]: WhatsApp Web protocol alignment +//! 
- [`DeviceStore`]: Device persistence operations + +#[cfg(feature = "whatsapp-web")] +use async_trait::async_trait; +#[cfg(feature = "whatsapp-web")] +use parking_lot::Mutex; +#[cfg(feature = "whatsapp-web")] +use rusqlite::{params, Connection}; +#[cfg(feature = "whatsapp-web")] +use std::path::Path; +#[cfg(feature = "whatsapp-web")] +use std::sync::Arc; + +#[cfg(feature = "whatsapp-web")] +use prost::Message; +#[cfg(feature = "whatsapp-web")] +use wa_rs_binary::jid::Jid; +#[cfg(feature = "whatsapp-web")] +use wa_rs_core::appstate::hash::HashState; +#[cfg(feature = "whatsapp-web")] +use wa_rs_core::appstate::processor::AppStateMutationMAC; +#[cfg(feature = "whatsapp-web")] +use wa_rs_core::store::traits::DeviceInfo; +#[cfg(feature = "whatsapp-web")] +use wa_rs_core::store::traits::DeviceStore as DeviceStoreTrait; +#[cfg(feature = "whatsapp-web")] +use wa_rs_core::store::traits::*; +#[cfg(feature = "whatsapp-web")] +use wa_rs_core::store::Device as CoreDevice; + +/// Custom wa-rs storage backend using rusqlite +/// +/// This implements all 4 storage traits required by wa-rs. +/// The backend uses ZeroClaw's existing rusqlite setup, avoiding the +/// Diesel/libsqlite3-sys conflict from wa-rs-sqlite-storage. +#[cfg(feature = "whatsapp-web")] +#[derive(Clone)] +pub struct RusqliteStore { + /// Database file path + db_path: String, + /// SQLite connection (thread-safe via Mutex) + conn: Arc>, + /// Device ID for this session + device_id: i32, +} + +/// Helper macro to convert rusqlite errors to StoreError +/// For execute statements that return usize, maps to () +macro_rules! 
to_store_err { + // For expressions returning Result + (execute: $expr:expr) => { + $expr + .map(|_| ()) + .map_err(|e| wa_rs_core::store::error::StoreError::Database(e.to_string())) + }; + // For other expressions + ($expr:expr) => { + $expr.map_err(|e| wa_rs_core::store::error::StoreError::Database(e.to_string())) + }; +} + +#[cfg(feature = "whatsapp-web")] +impl RusqliteStore { + /// Create a new rusqlite-based storage backend + /// + /// # Arguments + /// + /// * `db_path` - Path to the SQLite database file (will be created if needed) + pub fn new>(db_path: P) -> anyhow::Result { + let db_path = db_path.as_ref().to_string_lossy().to_string(); + + // Create parent directory if needed + if let Some(parent) = Path::new(&db_path).parent() { + std::fs::create_dir_all(parent)?; + } + + let conn = Connection::open(&db_path)?; + + // Enable WAL mode for better concurrency + to_store_err!(conn.execute_batch( + "PRAGMA journal_mode = WAL; + PRAGMA synchronous = NORMAL;", + ))?; + + let store = Self { + db_path, + conn: Arc::new(Mutex::new(conn)), + device_id: 1, // Default device ID + }; + + store.init_schema()?; + + Ok(store) + } + + /// Initialize all database tables + fn init_schema(&self) -> anyhow::Result<()> { + let conn = self.conn.lock(); + to_store_err!(conn.execute_batch( + "-- Main device table + CREATE TABLE IF NOT EXISTS device ( + id INTEGER PRIMARY KEY, + lid TEXT, + pn TEXT, + registration_id INTEGER NOT NULL, + noise_key BLOB NOT NULL, + identity_key BLOB NOT NULL, + signed_pre_key BLOB NOT NULL, + signed_pre_key_id INTEGER NOT NULL, + signed_pre_key_signature BLOB NOT NULL, + adv_secret_key BLOB NOT NULL, + account BLOB, + push_name TEXT NOT NULL, + app_version_primary INTEGER NOT NULL, + app_version_secondary INTEGER NOT NULL, + app_version_tertiary INTEGER NOT NULL, + app_version_last_fetched_ms INTEGER NOT NULL, + edge_routing_info BLOB, + props_hash TEXT + ); + + -- Signal identity keys + CREATE TABLE IF NOT EXISTS identities ( + address TEXT NOT 
NULL, + key BLOB NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (address, device_id) + ); + + -- Signal protocol sessions + CREATE TABLE IF NOT EXISTS sessions ( + address TEXT NOT NULL, + record BLOB NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (address, device_id) + ); + + -- Pre-keys for key exchange + CREATE TABLE IF NOT EXISTS prekeys ( + id INTEGER NOT NULL, + key BLOB NOT NULL, + uploaded INTEGER NOT NULL DEFAULT 0, + device_id INTEGER NOT NULL, + PRIMARY KEY (id, device_id) + ); + + -- Signed pre-keys + CREATE TABLE IF NOT EXISTS signed_prekeys ( + id INTEGER NOT NULL, + record BLOB NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (id, device_id) + ); + + -- Sender keys for group messaging + CREATE TABLE IF NOT EXISTS sender_keys ( + address TEXT NOT NULL, + record BLOB NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (address, device_id) + ); + + -- App state sync keys + CREATE TABLE IF NOT EXISTS app_state_keys ( + key_id BLOB NOT NULL, + key_data BLOB NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (key_id, device_id) + ); + + -- App state versions + CREATE TABLE IF NOT EXISTS app_state_versions ( + name TEXT NOT NULL, + state_data BLOB NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (name, device_id) + ); + + -- App state mutation MACs + CREATE TABLE IF NOT EXISTS app_state_mutation_macs ( + name TEXT NOT NULL, + version INTEGER NOT NULL, + index_mac BLOB NOT NULL, + value_mac BLOB NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (name, index_mac, device_id) + ); + + -- LID to phone number mapping + CREATE TABLE IF NOT EXISTS lid_pn_mapping ( + lid TEXT NOT NULL, + phone_number TEXT NOT NULL, + created_at INTEGER NOT NULL, + learning_source TEXT NOT NULL, + updated_at INTEGER NOT NULL, + device_id INTEGER NOT NULL, + PRIMARY KEY (lid, device_id) + ); + + -- SKDM recipients tracking + CREATE TABLE IF NOT EXISTS skdm_recipients ( + group_jid TEXT NOT NULL, + device_jid TEXT NOT NULL, + device_id INTEGER NOT NULL, 
+ created_at INTEGER NOT NULL, + PRIMARY KEY (group_jid, device_jid, device_id) + ); + + -- Device registry for multi-device + CREATE TABLE IF NOT EXISTS device_registry ( + user_id TEXT NOT NULL, + devices_json TEXT NOT NULL, + timestamp INTEGER NOT NULL, + phash TEXT, + device_id INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + PRIMARY KEY (user_id, device_id) + ); + + -- Base keys for collision detection + CREATE TABLE IF NOT EXISTS base_keys ( + address TEXT NOT NULL, + message_id TEXT NOT NULL, + base_key BLOB NOT NULL, + device_id INTEGER NOT NULL, + created_at INTEGER NOT NULL, + PRIMARY KEY (address, message_id, device_id) + ); + + -- Sender key status for lazy deletion + CREATE TABLE IF NOT EXISTS sender_key_status ( + group_jid TEXT NOT NULL, + participant TEXT NOT NULL, + device_id INTEGER NOT NULL, + marked_at INTEGER NOT NULL, + PRIMARY KEY (group_jid, participant, device_id) + ); + + -- Trusted contact tokens + CREATE TABLE IF NOT EXISTS tc_tokens ( + jid TEXT NOT NULL, + token BLOB NOT NULL, + token_timestamp INTEGER NOT NULL, + sender_timestamp INTEGER, + device_id INTEGER NOT NULL, + updated_at INTEGER NOT NULL, + PRIMARY KEY (jid, device_id) + );", + ))?; + Ok(()) + } +} + +#[cfg(feature = "whatsapp-web")] +#[async_trait] +impl SignalStore for RusqliteStore { + // --- Identity Operations --- + + async fn put_identity( + &self, + address: &str, + key: [u8; 32], + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO identities (address, key, device_id) + VALUES (?1, ?2, ?3)", + params![address, key.to_vec(), self.device_id], + )) + } + + async fn load_identity( + &self, + address: &str, + ) -> wa_rs_core::store::error::Result>> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT key FROM identities WHERE address = ?1 AND device_id = ?2", + params![address, self.device_id], + |row| row.get::<_, Vec>(0), + ); + + match result { + Ok(key) => 
Ok(Some(key)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn delete_identity(&self, address: &str) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM identities WHERE address = ?1 AND device_id = ?2", + params![address, self.device_id], + )) + } + + // --- Session Operations --- + + async fn get_session( + &self, + address: &str, + ) -> wa_rs_core::store::error::Result>> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT record FROM sessions WHERE address = ?1 AND device_id = ?2", + params![address, self.device_id], + |row| row.get::<_, Vec>(0), + ); + + match result { + Ok(record) => Ok(Some(record)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn put_session( + &self, + address: &str, + session: &[u8], + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO sessions (address, record, device_id) + VALUES (?1, ?2, ?3)", + params![address, session, self.device_id], + )) + } + + async fn delete_session(&self, address: &str) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM sessions WHERE address = ?1 AND device_id = ?2", + params![address, self.device_id], + )) + } + + // --- PreKey Operations --- + + async fn store_prekey( + &self, + id: u32, + record: &[u8], + uploaded: bool, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO prekeys (id, key, uploaded, device_id) + VALUES (?1, ?2, ?3, ?4)", + params![id, record, uploaded, self.device_id], + )) + } + + async fn load_prekey(&self, id: 
u32) -> wa_rs_core::store::error::Result>> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT key FROM prekeys WHERE id = ?1 AND device_id = ?2", + params![id, self.device_id], + |row| row.get::<_, Vec>(0), + ); + + match result { + Ok(key) => Ok(Some(key)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn remove_prekey(&self, id: u32) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM prekeys WHERE id = ?1 AND device_id = ?2", + params![id, self.device_id], + )) + } + + // --- Signed PreKey Operations --- + + async fn store_signed_prekey( + &self, + id: u32, + record: &[u8], + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO signed_prekeys (id, record, device_id) + VALUES (?1, ?2, ?3)", + params![id, record, self.device_id], + )) + } + + async fn load_signed_prekey( + &self, + id: u32, + ) -> wa_rs_core::store::error::Result>> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT record FROM signed_prekeys WHERE id = ?1 AND device_id = ?2", + params![id, self.device_id], + |row| row.get::<_, Vec>(0), + ); + + match result { + Ok(record) => Ok(Some(record)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn load_all_signed_prekeys( + &self, + ) -> wa_rs_core::store::error::Result)>> { + let conn = self.conn.lock(); + let mut stmt = to_store_err!( + conn.prepare("SELECT id, record FROM signed_prekeys WHERE device_id = ?1") + )?; + + let rows = to_store_err!(stmt.query_map(params![self.device_id], |row| { + Ok((row.get::<_, u32>(0)?, row.get::<_, Vec>(1)?)) + }))?; + + let mut result = Vec::new(); + for row in rows { + 
result.push(to_store_err!(row)?); + } + + Ok(result) + } + + async fn remove_signed_prekey(&self, id: u32) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM signed_prekeys WHERE id = ?1 AND device_id = ?2", + params![id, self.device_id], + )) + } + + // --- Sender Key Operations --- + + async fn put_sender_key( + &self, + address: &str, + record: &[u8], + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO sender_keys (address, record, device_id) + VALUES (?1, ?2, ?3)", + params![address, record, self.device_id], + )) + } + + async fn get_sender_key( + &self, + address: &str, + ) -> wa_rs_core::store::error::Result>> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT record FROM sender_keys WHERE address = ?1 AND device_id = ?2", + params![address, self.device_id], + |row| row.get::<_, Vec>(0), + ); + + match result { + Ok(record) => Ok(Some(record)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn delete_sender_key(&self, address: &str) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM sender_keys WHERE address = ?1 AND device_id = ?2", + params![address, self.device_id], + )) + } +} + +#[cfg(feature = "whatsapp-web")] +#[async_trait] +impl AppSyncStore for RusqliteStore { + async fn get_sync_key( + &self, + key_id: &[u8], + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT key_data FROM app_state_keys WHERE key_id = ?1 AND device_id = ?2", + params![key_id, self.device_id], + |row| { + let key_data: Vec = row.get(0)?; + serde_json::from_slice(&key_data) + .map_err(|e| rusqlite::Error::ToSqlConversionFailure(Box::new(e))) + }, 
+ ); + + match result { + Ok(key) => Ok(Some(key)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn set_sync_key( + &self, + key_id: &[u8], + key: AppStateSyncKey, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + let key_data = to_store_err!(serde_json::to_vec(&key))?; + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO app_state_keys (key_id, key_data, device_id) + VALUES (?1, ?2, ?3)", + params![key_id, key_data, self.device_id], + )) + } + + async fn get_version(&self, name: &str) -> wa_rs_core::store::error::Result { + let conn = self.conn.lock(); + let state_data: Vec = to_store_err!(conn.query_row( + "SELECT state_data FROM app_state_versions WHERE name = ?1 AND device_id = ?2", + params![name, self.device_id], + |row| row.get(0), + ))?; + + to_store_err!(serde_json::from_slice(&state_data)) + } + + async fn set_version( + &self, + name: &str, + state: HashState, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + let state_data = to_store_err!(serde_json::to_vec(&state))?; + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO app_state_versions (name, state_data, device_id) + VALUES (?1, ?2, ?3)", + params![name, state_data, self.device_id], + )) + } + + async fn put_mutation_macs( + &self, + name: &str, + version: u64, + mutations: &[AppStateMutationMAC], + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + + for mutation in mutations { + let index_mac = to_store_err!(serde_json::to_vec(&mutation.index_mac))?; + let value_mac = to_store_err!(serde_json::to_vec(&mutation.value_mac))?; + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO app_state_mutation_macs + (name, version, index_mac, value_mac, device_id) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![name, i64::try_from(version).unwrap_or(i64::MAX), index_mac, 
value_mac, self.device_id], + ))?; + } + + Ok(()) + } + + async fn get_mutation_mac( + &self, + name: &str, + index_mac: &[u8], + ) -> wa_rs_core::store::error::Result>> { + let conn = self.conn.lock(); + let index_mac_json = to_store_err!(serde_json::to_vec(index_mac))?; + + let result = conn.query_row( + "SELECT value_mac FROM app_state_mutation_macs + WHERE name = ?1 AND index_mac = ?2 AND device_id = ?3", + params![name, index_mac_json, self.device_id], + |row| row.get::<_, Vec>(0), + ); + + match result { + Ok(mac) => Ok(Some(mac)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn delete_mutation_macs( + &self, + name: &str, + index_macs: &[Vec], + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + + for index_mac in index_macs { + let index_mac_json = to_store_err!(serde_json::to_vec(index_mac))?; + + to_store_err!(execute: conn.execute( + "DELETE FROM app_state_mutation_macs + WHERE name = ?1 AND index_mac = ?2 AND device_id = ?3", + params![name, index_mac_json, self.device_id], + ))?; + } + + Ok(()) + } +} + +#[cfg(feature = "whatsapp-web")] +#[async_trait] +impl ProtocolStore for RusqliteStore { + // --- SKDM Tracking --- + + async fn get_skdm_recipients( + &self, + group_jid: &str, + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let mut stmt = to_store_err!(conn.prepare( + "SELECT device_jid FROM skdm_recipients WHERE group_jid = ?1 AND device_id = ?2" + ))?; + + let rows = to_store_err!(stmt.query_map(params![group_jid, self.device_id], |row| { + row.get::<_, String>(0) + }))?; + + let mut result = Vec::new(); + for row in rows { + let jid_str = to_store_err!(row)?; + if let Ok(jid) = jid_str.parse() { + result.push(jid); + } + } + + Ok(result) + } + + async fn add_skdm_recipients( + &self, + group_jid: &str, + device_jids: &[Jid], + ) -> wa_rs_core::store::error::Result<()> { + let 
conn = self.conn.lock(); + let now = chrono::Utc::now().timestamp(); + + for device_jid in device_jids { + to_store_err!(execute: conn.execute( + "INSERT OR IGNORE INTO skdm_recipients (group_jid, device_jid, device_id, created_at) + VALUES (?1, ?2, ?3, ?4)", + params![group_jid, device_jid.to_string(), self.device_id, now], + ))?; + } + + Ok(()) + } + + async fn clear_skdm_recipients(&self, group_jid: &str) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM skdm_recipients WHERE group_jid = ?1 AND device_id = ?2", + params![group_jid, self.device_id], + )) + } + + // --- LID-PN Mapping --- + + async fn get_lid_mapping( + &self, + lid: &str, + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT lid, phone_number, created_at, learning_source, updated_at + FROM lid_pn_mapping WHERE lid = ?1 AND device_id = ?2", + params![lid, self.device_id], + |row| { + Ok(LidPnMappingEntry { + lid: row.get(0)?, + phone_number: row.get(1)?, + created_at: row.get(2)?, + learning_source: row.get(3)?, + updated_at: row.get(4)?, + }) + }, + ); + + match result { + Ok(entry) => Ok(Some(entry)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn get_pn_mapping( + &self, + phone: &str, + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT lid, phone_number, created_at, learning_source, updated_at + FROM lid_pn_mapping WHERE phone_number = ?1 AND device_id = ?2 + ORDER BY updated_at DESC LIMIT 1", + params![phone, self.device_id], + |row| { + Ok(LidPnMappingEntry { + lid: row.get(0)?, + phone_number: row.get(1)?, + created_at: row.get(2)?, + learning_source: row.get(3)?, + updated_at: row.get(4)?, + }) + }, + ); + + match result { + Ok(entry) => Ok(Some(entry)), + 
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn put_lid_mapping( + &self, + entry: &LidPnMappingEntry, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO lid_pn_mapping + (lid, phone_number, created_at, learning_source, updated_at, device_id) + VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + params![ + entry.lid, + entry.phone_number, + entry.created_at, + entry.learning_source, + entry.updated_at, + self.device_id, + ], + )) + } + + async fn get_all_lid_mappings( + &self, + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let mut stmt = to_store_err!(conn.prepare( + "SELECT lid, phone_number, created_at, learning_source, updated_at + FROM lid_pn_mapping WHERE device_id = ?1" + ))?; + + let rows = to_store_err!(stmt.query_map(params![self.device_id], |row| { + Ok(LidPnMappingEntry { + lid: row.get(0)?, + phone_number: row.get(1)?, + created_at: row.get(2)?, + learning_source: row.get(3)?, + updated_at: row.get(4)?, + }) + }))?; + + let mut result = Vec::new(); + for row in rows { + result.push(to_store_err!(row)?); + } + + Ok(result) + } + + // --- Base Key Collision Detection --- + + async fn save_base_key( + &self, + address: &str, + message_id: &str, + base_key: &[u8], + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + let now = chrono::Utc::now().timestamp(); + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO base_keys (address, message_id, base_key, device_id, created_at) + VALUES (?1, ?2, ?3, ?4, ?5)", + params![address, message_id, base_key, self.device_id, now], + )) + } + + async fn has_same_base_key( + &self, + address: &str, + message_id: &str, + current_base_key: &[u8], + ) -> wa_rs_core::store::error::Result { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT base_key FROM 
base_keys + WHERE address = ?1 AND message_id = ?2 AND device_id = ?3", + params![address, message_id, self.device_id], + |row| { + let saved_key: Vec = row.get(0)?; + Ok(saved_key == current_base_key) + }, + ); + + match result { + Ok(same) => Ok(same), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(false), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn delete_base_key( + &self, + address: &str, + message_id: &str, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM base_keys WHERE address = ?1 AND message_id = ?2 AND device_id = ?3", + params![address, message_id, self.device_id], + )) + } + + // --- Device Registry --- + + async fn update_device_list( + &self, + record: DeviceListRecord, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + let devices_json = to_store_err!(serde_json::to_string(&record.devices))?; + let now = chrono::Utc::now().timestamp(); + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO device_registry + (user_id, devices_json, timestamp, phash, device_id, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + params![ + record.user, + devices_json, + record.timestamp, + record.phash, + self.device_id, + now, + ], + )) + } + + async fn get_devices( + &self, + user: &str, + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT user_id, devices_json, timestamp, phash + FROM device_registry WHERE user_id = ?1 AND device_id = ?2", + params![user, self.device_id], + |row| { + // Helper to convert errors to rusqlite::Error + fn to_rusqlite_err( + e: E, + ) -> rusqlite::Error { + rusqlite::Error::ToSqlConversionFailure(Box::new(e)) + } + + let devices_json: String = row.get(1)?; + let devices: Vec = + serde_json::from_str(&devices_json).map_err(to_rusqlite_err)?; + Ok(DeviceListRecord { + user: row.get(0)?, + 
devices, + timestamp: row.get(2)?, + phash: row.get(3)?, + }) + }, + ); + + match result { + Ok(record) => Ok(Some(record)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + // --- Sender Key Status (Lazy Deletion) --- + + async fn mark_forget_sender_key( + &self, + group_jid: &str, + participant: &str, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + let now = chrono::Utc::now().timestamp(); + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO sender_key_status (group_jid, participant, device_id, marked_at) + VALUES (?1, ?2, ?3, ?4)", + params![group_jid, participant, self.device_id, now], + )) + } + + async fn consume_forget_marks( + &self, + group_jid: &str, + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let mut stmt = to_store_err!(conn.prepare( + "SELECT participant FROM sender_key_status + WHERE group_jid = ?1 AND device_id = ?2" + ))?; + + let rows = to_store_err!(stmt.query_map(params![group_jid, self.device_id], |row| { + row.get::<_, String>(0) + }))?; + + let mut result = Vec::new(); + for row in rows { + result.push(to_store_err!(row)?); + } + + // Delete the marks after consuming them + to_store_err!(execute: conn.execute( + "DELETE FROM sender_key_status WHERE group_jid = ?1 AND device_id = ?2", + params![group_jid, self.device_id], + ))?; + + Ok(result) + } + + // --- TcToken Storage --- + + async fn get_tc_token( + &self, + jid: &str, + ) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT token, token_timestamp, sender_timestamp FROM tc_tokens + WHERE jid = ?1 AND device_id = ?2", + params![jid, self.device_id], + |row| { + Ok(TcTokenEntry { + token: row.get(0)?, + token_timestamp: row.get(1)?, + sender_timestamp: row.get(2)?, + }) + }, + ); + + match result { + Ok(entry) => Ok(Some(entry)), + 
Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn put_tc_token( + &self, + jid: &str, + entry: &TcTokenEntry, + ) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + let now = chrono::Utc::now().timestamp(); + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO tc_tokens + (jid, token, token_timestamp, sender_timestamp, device_id, updated_at) + VALUES (?1, ?2, ?3, ?4, ?5, ?6)", + params![ + jid, + entry.token, + entry.token_timestamp, + entry.sender_timestamp, + self.device_id, + now, + ], + )) + } + + async fn delete_tc_token(&self, jid: &str) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + to_store_err!(execute: conn.execute( + "DELETE FROM tc_tokens WHERE jid = ?1 AND device_id = ?2", + params![jid, self.device_id], + )) + } + + async fn get_all_tc_token_jids(&self) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let mut stmt = + to_store_err!(conn.prepare("SELECT jid FROM tc_tokens WHERE device_id = ?1"))?; + + let rows = to_store_err!( + stmt.query_map(params![self.device_id], |row| { row.get::<_, String>(0) }) + )?; + + let mut result = Vec::new(); + for row in rows { + result.push(to_store_err!(row)?); + } + + Ok(result) + } + + async fn delete_expired_tc_tokens( + &self, + cutoff_timestamp: i64, + ) -> wa_rs_core::store::error::Result { + let conn = self.conn.lock(); + let deleted = conn + .execute( + "DELETE FROM tc_tokens WHERE token_timestamp < ?1 AND device_id = ?2", + params![cutoff_timestamp, self.device_id], + ) + .map_err(|e| wa_rs_core::store::error::StoreError::Database(e.to_string()))?; + + let deleted = u32::try_from(deleted).map_err(|_| { + wa_rs_core::store::error::StoreError::Database(format!( + "Affected row count overflowed u32: {deleted}" + )) + })?; + + Ok(deleted) + } +} + +#[cfg(feature = "whatsapp-web")] +#[async_trait] +impl 
DeviceStoreTrait for RusqliteStore { + async fn save(&self, device: &CoreDevice) -> wa_rs_core::store::error::Result<()> { + let conn = self.conn.lock(); + + // Serialize KeyPairs to bytes + let noise_key = { + let mut bytes = Vec::new(); + let priv_key = device.noise_key.private_key.serialize(); + bytes.extend_from_slice(priv_key.as_slice()); + bytes.extend_from_slice(device.noise_key.public_key.public_key_bytes()); + bytes + }; + + let identity_key = { + let mut bytes = Vec::new(); + let priv_key = device.identity_key.private_key.serialize(); + bytes.extend_from_slice(priv_key.as_slice()); + bytes.extend_from_slice(device.identity_key.public_key.public_key_bytes()); + bytes + }; + + let signed_pre_key = { + let mut bytes = Vec::new(); + let priv_key = device.signed_pre_key.private_key.serialize(); + bytes.extend_from_slice(priv_key.as_slice()); + bytes.extend_from_slice(device.signed_pre_key.public_key.public_key_bytes()); + bytes + }; + + let account = device.account.as_ref().map(|a| a.encode_to_vec()); + + to_store_err!(execute: conn.execute( + "INSERT OR REPLACE INTO device ( + id, lid, pn, registration_id, noise_key, identity_key, + signed_pre_key, signed_pre_key_id, signed_pre_key_signature, + adv_secret_key, account, push_name, app_version_primary, + app_version_secondary, app_version_tertiary, app_version_last_fetched_ms, + edge_routing_info, props_hash + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14, ?15, ?16, ?17, ?18)", + params![ + self.device_id, + device.lid.as_ref().map(|j| j.to_string()), + device.pn.as_ref().map(|j| j.to_string()), + device.registration_id, + noise_key, + identity_key, + signed_pre_key, + device.signed_pre_key_id, + device.signed_pre_key_signature.to_vec(), + device.adv_secret_key.to_vec(), + account, + &device.push_name, + device.app_version_primary, + device.app_version_secondary, + device.app_version_tertiary, + device.app_version_last_fetched_ms, + device.edge_routing_info.as_ref().map(|v| v.clone()), + 
device.props_hash.as_ref().map(|v| v.clone()), + ], + )) + } + + async fn load(&self) -> wa_rs_core::store::error::Result> { + let conn = self.conn.lock(); + let result = conn.query_row( + "SELECT * FROM device WHERE id = ?1", + params![self.device_id], + |row| { + // Helper to convert errors to rusqlite::Error + fn to_rusqlite_err( + e: E, + ) -> rusqlite::Error { + rusqlite::Error::ToSqlConversionFailure(Box::new(e)) + } + + // Deserialize KeyPairs from bytes (64 bytes each) + let noise_key_bytes: Vec = row.get("noise_key")?; + let identity_key_bytes: Vec = row.get("identity_key")?; + let signed_pre_key_bytes: Vec = row.get("signed_pre_key")?; + + if noise_key_bytes.len() != 64 + || identity_key_bytes.len() != 64 + || signed_pre_key_bytes.len() != 64 + { + return Err(rusqlite::Error::InvalidParameterName("key_pair".into())); + } + + use wa_rs_core::libsignal::protocol::{KeyPair, PrivateKey, PublicKey}; + + let noise_key = KeyPair::new( + PublicKey::from_djb_public_key_bytes(&noise_key_bytes[32..64]) + .map_err(to_rusqlite_err)?, + PrivateKey::deserialize(&noise_key_bytes[0..32]).map_err(to_rusqlite_err)?, + ); + + let identity_key = KeyPair::new( + PublicKey::from_djb_public_key_bytes(&identity_key_bytes[32..64]) + .map_err(to_rusqlite_err)?, + PrivateKey::deserialize(&identity_key_bytes[0..32]).map_err(to_rusqlite_err)?, + ); + + let signed_pre_key = KeyPair::new( + PublicKey::from_djb_public_key_bytes(&signed_pre_key_bytes[32..64]) + .map_err(to_rusqlite_err)?, + PrivateKey::deserialize(&signed_pre_key_bytes[0..32]) + .map_err(to_rusqlite_err)?, + ); + + let lid_str: Option = row.get("lid")?; + let pn_str: Option = row.get("pn")?; + let signature_bytes: Vec = row.get("signed_pre_key_signature")?; + let adv_secret_bytes: Vec = row.get("adv_secret_key")?; + let account_bytes: Option> = row.get("account")?; + + let mut signature = [0u8; 64]; + let mut adv_secret = [0u8; 32]; + signature.copy_from_slice(&signature_bytes); + 
adv_secret.copy_from_slice(&adv_secret_bytes); + + let account = if let Some(bytes) = account_bytes { + Some( + wa_rs_proto::whatsapp::AdvSignedDeviceIdentity::decode(&*bytes) + .map_err(to_rusqlite_err)?, + ) + } else { + None + }; + + Ok(CoreDevice { + lid: lid_str.and_then(|s| s.parse().ok()), + pn: pn_str.and_then(|s| s.parse().ok()), + registration_id: row.get("registration_id")?, + noise_key, + identity_key, + signed_pre_key, + signed_pre_key_id: row.get("signed_pre_key_id")?, + signed_pre_key_signature: signature, + adv_secret_key: adv_secret, + account, + push_name: row.get("push_name")?, + app_version_primary: row.get("app_version_primary")?, + app_version_secondary: row.get("app_version_secondary")?, + app_version_tertiary: row.get("app_version_tertiary")?, + app_version_last_fetched_ms: row.get("app_version_last_fetched_ms")?, + edge_routing_info: row.get("edge_routing_info")?, + props_hash: row.get("props_hash")?, + ..Default::default() + }) + }, + ); + + match result { + Ok(device) => Ok(Some(device)), + Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None), + Err(e) => Err(wa_rs_core::store::error::StoreError::Database( + e.to_string(), + )), + } + } + + async fn exists(&self) -> wa_rs_core::store::error::Result { + let conn = self.conn.lock(); + let count: i64 = to_store_err!(conn.query_row( + "SELECT COUNT(*) FROM device WHERE id = ?1", + params![self.device_id], + |row| row.get(0), + ))?; + + Ok(count > 0) + } + + async fn create(&self) -> wa_rs_core::store::error::Result { + // Device already created in constructor, just return the ID + Ok(self.device_id) + } + + async fn snapshot_db( + &self, + name: &str, + extra_content: Option<&[u8]>, + ) -> wa_rs_core::store::error::Result<()> { + // Create a snapshot by copying the database file + let snapshot_path = format!("{}.snapshot.{}", self.db_path, name); + + to_store_err!(std::fs::copy(&self.db_path, &snapshot_path))?; + + // If extra_content is provided, save it alongside + if let Some(content) = 
extra_content { + let content_path = format!("{}.extra", snapshot_path); + to_store_err!(std::fs::write(&content_path, content))?; + } + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[cfg(feature = "whatsapp-web")] + use wa_rs_core::store::traits::{LidPnMappingEntry, ProtocolStore, TcTokenEntry}; + + #[cfg(feature = "whatsapp-web")] + #[test] + fn rusqlite_store_creates_database() { + let tmp = tempfile::NamedTempFile::new().unwrap(); + let store = RusqliteStore::new(tmp.path()).unwrap(); + assert_eq!(store.device_id, 1); + } + + #[cfg(feature = "whatsapp-web")] + #[tokio::test] + async fn lid_mapping_round_trip_preserves_learning_source_and_updated_at() { + let tmp = tempfile::NamedTempFile::new().unwrap(); + let store = RusqliteStore::new(tmp.path()).unwrap(); + let entry = LidPnMappingEntry { + lid: "100000012345678".to_string(), + phone_number: "15551234567".to_string(), + created_at: 1_700_000_000, + updated_at: 1_700_000_100, + learning_source: "usync".to_string(), + }; + + ProtocolStore::put_lid_mapping(&store, &entry) + .await + .unwrap(); + + let loaded = ProtocolStore::get_lid_mapping(&store, &entry.lid) + .await + .unwrap() + .expect("expected lid mapping to be present"); + assert_eq!(loaded.learning_source, entry.learning_source); + assert_eq!(loaded.updated_at, entry.updated_at); + + let loaded_by_pn = ProtocolStore::get_pn_mapping(&store, &entry.phone_number) + .await + .unwrap() + .expect("expected pn mapping to be present"); + assert_eq!(loaded_by_pn.learning_source, entry.learning_source); + assert_eq!(loaded_by_pn.updated_at, entry.updated_at); + } + + #[cfg(feature = "whatsapp-web")] + #[tokio::test] + async fn delete_expired_tc_tokens_returns_deleted_row_count() { + let tmp = tempfile::NamedTempFile::new().unwrap(); + let store = RusqliteStore::new(tmp.path()).unwrap(); + + let expired = TcTokenEntry { + token: vec![1, 2, 3], + token_timestamp: 10, + sender_timestamp: None, + }; + let fresh = TcTokenEntry { + token: vec![4, 5, 
6], + token_timestamp: 1000, + sender_timestamp: Some(1000), + }; + + ProtocolStore::put_tc_token(&store, "15550000001", &expired) + .await + .unwrap(); + ProtocolStore::put_tc_token(&store, "15550000002", &fresh) + .await + .unwrap(); + + let deleted = ProtocolStore::delete_expired_tc_tokens(&store, 100) + .await + .unwrap(); + assert_eq!(deleted, 1); + assert!(ProtocolStore::get_tc_token(&store, "15550000001") + .await + .unwrap() + .is_none()); + assert!(ProtocolStore::get_tc_token(&store, "15550000002") + .await + .unwrap() + .is_some()); + } +} diff --git a/src/channels/whatsapp_web.rs b/src/channels/whatsapp_web.rs new file mode 100644 index 0000000..f6e89c2 --- /dev/null +++ b/src/channels/whatsapp_web.rs @@ -0,0 +1,564 @@ +//! WhatsApp Web channel using wa-rs (native Rust implementation) +//! +//! This channel provides direct WhatsApp Web integration with: +//! - QR code and pair code linking +//! - End-to-end encryption via Signal Protocol +//! - Full Baileys parity (groups, media, presence, reactions, editing/deletion) +//! +//! # Feature Flag +//! +//! This channel requires the `whatsapp-web` feature flag: +//! ```sh +//! cargo build --features whatsapp-web +//! ``` +//! +//! # Configuration +//! +//! ```toml +//! [channels_config.whatsapp] +//! session_path = "~/.zeroclaw/whatsapp-session.db" # Required for Web mode +//! pair_phone = "15551234567" # Optional: for pair code linking +//! allowed_numbers = ["+1234567890", "*"] # Same as Cloud API +//! ``` +//! +//! # Runtime Negotiation +//! +//! This channel is automatically selected when `session_path` is set in the config. +//! The Cloud API channel is used when `phone_number_id` is set. 
+ +use super::traits::{Channel, ChannelMessage, SendMessage}; +use super::whatsapp_storage::RusqliteStore; +use anyhow::{anyhow, Result}; +use async_trait::async_trait; +use parking_lot::Mutex; +use std::sync::Arc; +use tokio::select; + +/// WhatsApp Web channel using wa-rs with custom rusqlite storage +/// +/// # Status: Functional Implementation +/// +/// This implementation uses the wa-rs Bot with our custom RusqliteStore backend. +/// +/// # Configuration +/// +/// ```toml +/// [channels_config.whatsapp] +/// session_path = "~/.zeroclaw/whatsapp-session.db" +/// pair_phone = "15551234567" # Optional +/// allowed_numbers = ["+1234567890", "*"] +/// ``` +#[cfg(feature = "whatsapp-web")] +pub struct WhatsAppWebChannel { + /// Session database path + session_path: String, + /// Phone number for pair code linking (optional) + pair_phone: Option, + /// Custom pair code (optional) + pair_code: Option, + /// Allowed phone numbers (E.164 format) or "*" for all + allowed_numbers: Vec, + /// Bot handle for shutdown + bot_handle: Arc>>>, + /// Client handle for sending messages and typing indicators + client: Arc>>>, + /// Message sender channel + tx: Arc>>>, +} + +impl WhatsAppWebChannel { + /// Create a new WhatsApp Web channel + /// + /// # Arguments + /// + /// * `session_path` - Path to the SQLite session database + /// * `pair_phone` - Optional phone number for pair code linking (format: "15551234567") + /// * `pair_code` - Optional custom pair code (leave empty for auto-generated) + /// * `allowed_numbers` - Phone numbers allowed to interact (E.164 format) or "*" for all + #[cfg(feature = "whatsapp-web")] + pub fn new( + session_path: String, + pair_phone: Option, + pair_code: Option, + allowed_numbers: Vec, + ) -> Self { + Self { + session_path, + pair_phone, + pair_code, + allowed_numbers, + bot_handle: Arc::new(Mutex::new(None)), + client: Arc::new(Mutex::new(None)), + tx: Arc::new(Mutex::new(None)), + } + } + + /// Check if a phone number is allowed (E.164 
format: +1234567890) + #[cfg(feature = "whatsapp-web")] + fn is_number_allowed(&self, phone: &str) -> bool { + self.allowed_numbers.iter().any(|n| n == "*" || n == phone) + } + + /// Normalize phone number to E.164 format + #[cfg(feature = "whatsapp-web")] + fn normalize_phone(&self, phone: &str) -> String { + let trimmed = phone.trim(); + let user_part = trimmed + .split_once('@') + .map(|(user, _)| user) + .unwrap_or(trimmed); + let normalized_user = user_part.trim_start_matches('+'); + if user_part.starts_with('+') { + format!("+{normalized_user}") + } else { + format!("+{normalized_user}") + } + } + + /// Whether the recipient string is a WhatsApp JID (contains a domain suffix). + #[cfg(feature = "whatsapp-web")] + fn is_jid(recipient: &str) -> bool { + recipient.trim().contains('@') + } + + /// Convert a recipient to a wa-rs JID. + /// + /// Supports: + /// - Full JIDs (e.g. "12345@s.whatsapp.net") + /// - E.164-like numbers (e.g. "+1234567890") + #[cfg(feature = "whatsapp-web")] + fn recipient_to_jid(&self, recipient: &str) -> Result { + let trimmed = recipient.trim(); + if trimmed.is_empty() { + anyhow::bail!("Recipient cannot be empty"); + } + + if trimmed.contains('@') { + return trimmed + .parse::() + .map_err(|e| anyhow!("Invalid WhatsApp JID `{trimmed}`: {e}")); + } + + let digits: String = trimmed.chars().filter(|c| c.is_ascii_digit()).collect(); + if digits.is_empty() { + anyhow::bail!("Recipient `{trimmed}` does not contain a valid phone number"); + } + + Ok(wa_rs_binary::jid::Jid::pn(digits)) + } +} + +#[cfg(feature = "whatsapp-web")] +#[async_trait] +impl Channel for WhatsAppWebChannel { + fn name(&self) -> &str { + "whatsapp" + } + + async fn send(&self, message: &SendMessage) -> Result<()> { + let client = self.client.lock().clone(); + let Some(client) = client else { + anyhow::bail!("WhatsApp Web client not connected. Initialize the bot first."); + }; + + // Validate recipient allowlist only for direct phone-number targets. 
+ if !Self::is_jid(&message.recipient) { + let normalized = self.normalize_phone(&message.recipient); + if !self.is_number_allowed(&normalized) { + tracing::warn!( + "WhatsApp Web: recipient {} not in allowed list", + message.recipient + ); + return Ok(()); + } + } + + let to = self.recipient_to_jid(&message.recipient)?; + let outgoing = wa_rs_proto::whatsapp::Message { + conversation: Some(message.content.clone()), + ..Default::default() + }; + + let message_id = client.send_message(to, outgoing).await?; + tracing::debug!( + "WhatsApp Web: sent message to {} (id: {})", + message.recipient, + message_id + ); + Ok(()) + } + + async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> Result<()> { + // Store the sender channel for incoming messages + *self.tx.lock() = Some(tx.clone()); + + use wa_rs::bot::Bot; + use wa_rs::pair_code::PairCodeOptions; + use wa_rs::store::{Device, DeviceStore}; + use wa_rs_binary::jid::JidExt as _; + use wa_rs_core::proto_helpers::MessageExt; + use wa_rs_core::types::events::Event; + use wa_rs_tokio_transport::TokioWebSocketTransportFactory; + use wa_rs_ureq_http::UreqHttpClient; + + tracing::info!( + "WhatsApp Web channel starting (session: {})", + self.session_path + ); + + // Initialize storage backend + let storage = RusqliteStore::new(&self.session_path)?; + let backend = Arc::new(storage); + + // Check if we have a saved device to load + let mut device = Device::new(backend.clone()); + if backend.exists().await? { + tracing::info!("WhatsApp Web: found existing session, loading device"); + if let Some(core_device) = backend.load().await? 
{ + device.load_from_serializable(core_device); + } else { + anyhow::bail!("Device exists but failed to load"); + } + } else { + tracing::info!( + "WhatsApp Web: no existing session, new device will be created during pairing" + ); + }; + + // Create transport factory + let mut transport_factory = TokioWebSocketTransportFactory::new(); + if let Ok(ws_url) = std::env::var("WHATSAPP_WS_URL") { + transport_factory = transport_factory.with_url(ws_url); + } + + // Create HTTP client for media operations + let http_client = UreqHttpClient::new(); + + // Build the bot + let tx_clone = tx.clone(); + let allowed_numbers = self.allowed_numbers.clone(); + + let mut builder = Bot::builder() + .with_backend(backend) + .with_transport_factory(transport_factory) + .with_http_client(http_client) + .on_event(move |event, _client| { + let tx_inner = tx_clone.clone(); + let allowed_numbers = allowed_numbers.clone(); + async move { + match event { + Event::Message(msg, info) => { + // Extract message content + let text = msg.text_content().unwrap_or(""); + let sender = info.source.sender.user().to_string(); + let chat = info.source.chat.to_string(); + + tracing::info!( + "WhatsApp Web message from {} in {}: {}", + sender, + chat, + text + ); + + // Check if sender is allowed + let normalized = if sender.starts_with('+') { + sender.clone() + } else { + format!("+{sender}") + }; + + if allowed_numbers.iter().any(|n| n == "*" || n == &normalized) { + let trimmed = text.trim(); + if trimmed.is_empty() { + tracing::debug!( + "WhatsApp Web: ignoring empty or non-text message from {}", + normalized + ); + return; + } + + if let Err(e) = tx_inner + .send(ChannelMessage { + id: uuid::Uuid::new_v4().to_string(), + channel: "whatsapp".to_string(), + sender: normalized.clone(), + // Reply to the originating chat JID (DM or group). 
+ reply_target: chat, + content: trimmed.to_string(), + timestamp: chrono::Utc::now().timestamp() as u64, + thread_ts: None, + }) + .await + { + tracing::error!("Failed to send message to channel: {}", e); + } + } else { + tracing::warn!("WhatsApp Web: message from {} not in allowed list", normalized); + } + } + Event::Connected(_) => { + tracing::info!("WhatsApp Web connected successfully"); + } + Event::LoggedOut(_) => { + tracing::warn!("WhatsApp Web was logged out"); + } + Event::StreamError(stream_error) => { + tracing::error!("WhatsApp Web stream error: {:?}", stream_error); + } + Event::PairingCode { code, .. } => { + tracing::info!("WhatsApp Web pair code received: {}", code); + tracing::info!( + "Link your phone by entering this code in WhatsApp > Linked Devices" + ); + } + Event::PairingQrCode { code, .. } => { + tracing::info!( + "WhatsApp Web QR code received (scan with WhatsApp > Linked Devices)" + ); + tracing::debug!("QR code: {}", code); + } + _ => {} + } + } + }) + ; + + // Configure pair-code flow when a phone number is provided. + if let Some(ref phone) = self.pair_phone { + tracing::info!("WhatsApp Web: pair-code flow enabled for configured phone number"); + builder = builder.with_pair_code(PairCodeOptions { + phone_number: phone.clone(), + custom_code: self.pair_code.clone(), + ..Default::default() + }); + } else if self.pair_code.is_some() { + tracing::warn!( + "WhatsApp Web: pair_code is set but pair_phone is missing; pair code config is ignored" + ); + } + + let mut bot = builder.build().await?; + *self.client.lock() = Some(bot.client()); + + // Run the bot + let bot_handle = bot.run().await?; + + // Store the bot handle for later shutdown + *self.bot_handle.lock() = Some(bot_handle); + + // Wait for shutdown signal + let (_shutdown_tx, mut shutdown_rx) = tokio::sync::broadcast::channel::<()>(1); + + select! 
{ + _ = shutdown_rx.recv() => { + tracing::info!("WhatsApp Web channel shutting down"); + } + _ = tokio::signal::ctrl_c() => { + tracing::info!("WhatsApp Web channel received Ctrl+C"); + } + } + + *self.client.lock() = None; + if let Some(handle) = self.bot_handle.lock().take() { + handle.abort(); + } + + Ok(()) + } + + async fn health_check(&self) -> bool { + let bot_handle_guard = self.bot_handle.lock(); + bot_handle_guard.is_some() + } + + async fn start_typing(&self, recipient: &str) -> Result<()> { + let client = self.client.lock().clone(); + let Some(client) = client else { + anyhow::bail!("WhatsApp Web client not connected. Initialize the bot first."); + }; + + if !Self::is_jid(recipient) { + let normalized = self.normalize_phone(recipient); + if !self.is_number_allowed(&normalized) { + tracing::warn!( + "WhatsApp Web: typing target {} not in allowed list", + recipient + ); + return Ok(()); + } + } + + let to = self.recipient_to_jid(recipient)?; + client + .chatstate() + .send_composing(&to) + .await + .map_err(|e| anyhow!("Failed to send typing state (composing): {e}"))?; + + tracing::debug!("WhatsApp Web: start typing for {}", recipient); + Ok(()) + } + + async fn stop_typing(&self, recipient: &str) -> Result<()> { + let client = self.client.lock().clone(); + let Some(client) = client else { + anyhow::bail!("WhatsApp Web client not connected. 
Initialize the bot first."); + }; + + if !Self::is_jid(recipient) { + let normalized = self.normalize_phone(recipient); + if !self.is_number_allowed(&normalized) { + tracing::warn!( + "WhatsApp Web: typing target {} not in allowed list", + recipient + ); + return Ok(()); + } + } + + let to = self.recipient_to_jid(recipient)?; + client + .chatstate() + .send_paused(&to) + .await + .map_err(|e| anyhow!("Failed to send typing state (paused): {e}"))?; + + tracing::debug!("WhatsApp Web: stop typing for {}", recipient); + Ok(()) + } +} + +// Stub implementation when feature is not enabled +#[cfg(not(feature = "whatsapp-web"))] +pub struct WhatsAppWebChannel { + _private: (), +} + +#[cfg(not(feature = "whatsapp-web"))] +impl WhatsAppWebChannel { + pub fn new( + _session_path: String, + _pair_phone: Option, + _pair_code: Option, + _allowed_numbers: Vec, + ) -> Self { + Self { _private: () } + } +} + +#[cfg(not(feature = "whatsapp-web"))] +#[async_trait] +impl Channel for WhatsAppWebChannel { + fn name(&self) -> &str { + "whatsapp" + } + + async fn send(&self, _message: &SendMessage) -> Result<()> { + anyhow::bail!( + "WhatsApp Web channel requires the 'whatsapp-web' feature. \ + Enable with: cargo build --features whatsapp-web" + ); + } + + async fn listen(&self, _tx: tokio::sync::mpsc::Sender) -> Result<()> { + anyhow::bail!( + "WhatsApp Web channel requires the 'whatsapp-web' feature. \ + Enable with: cargo build --features whatsapp-web" + ); + } + + async fn health_check(&self) -> bool { + false + } + + async fn start_typing(&self, _recipient: &str) -> Result<()> { + anyhow::bail!( + "WhatsApp Web channel requires the 'whatsapp-web' feature. \ + Enable with: cargo build --features whatsapp-web" + ); + } + + async fn stop_typing(&self, _recipient: &str) -> Result<()> { + anyhow::bail!( + "WhatsApp Web channel requires the 'whatsapp-web' feature. 
\ + Enable with: cargo build --features whatsapp-web" + ); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[cfg(feature = "whatsapp-web")] + fn make_channel() -> WhatsAppWebChannel { + WhatsAppWebChannel::new( + "/tmp/test-whatsapp.db".into(), + None, + None, + vec!["+1234567890".into()], + ) + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn whatsapp_web_channel_name() { + let ch = make_channel(); + assert_eq!(ch.name(), "whatsapp"); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn whatsapp_web_number_allowed_exact() { + let ch = make_channel(); + assert!(ch.is_number_allowed("+1234567890")); + assert!(!ch.is_number_allowed("+9876543210")); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn whatsapp_web_number_allowed_wildcard() { + let ch = WhatsAppWebChannel::new("/tmp/test.db".into(), None, None, vec!["*".into()]); + assert!(ch.is_number_allowed("+1234567890")); + assert!(ch.is_number_allowed("+9999999999")); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn whatsapp_web_number_denied_empty() { + let ch = WhatsAppWebChannel::new("/tmp/test.db".into(), None, None, vec![]); + // Empty allowlist means "deny all" (matches channel-wide allowlist policy). 
+ assert!(!ch.is_number_allowed("+1234567890")); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn whatsapp_web_normalize_phone_adds_plus() { + let ch = make_channel(); + assert_eq!(ch.normalize_phone("1234567890"), "+1234567890"); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn whatsapp_web_normalize_phone_preserves_plus() { + let ch = make_channel(); + assert_eq!(ch.normalize_phone("+1234567890"), "+1234567890"); + } + + #[test] + #[cfg(feature = "whatsapp-web")] + fn whatsapp_web_normalize_phone_from_jid() { + let ch = make_channel(); + assert_eq!( + ch.normalize_phone("1234567890@s.whatsapp.net"), + "+1234567890" + ); + } + + #[tokio::test] + #[cfg(feature = "whatsapp-web")] + async fn whatsapp_web_health_check_disconnected() { + let ch = make_channel(); + assert!(!ch.health_check().await); + } +} diff --git a/src/config/mod.rs b/src/config/mod.rs index 7f3fe29..4649f9c 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -6,14 +6,14 @@ pub use schema::{ build_runtime_proxy_client_with_timeouts, runtime_proxy_config, set_runtime_proxy_config, AgentConfig, AuditConfig, AutonomyConfig, BrowserComputerUseConfig, BrowserConfig, ChannelsConfig, ClassificationRule, ComposioConfig, Config, CostConfig, CronConfig, - DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, GatewayConfig, HardwareConfig, - HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig, IdentityConfig, - LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, ObservabilityConfig, - PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, QueryClassificationConfig, - ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, SandboxBackend, SandboxConfig, - SchedulerConfig, SecretsConfig, SecurityConfig, SlackConfig, StorageConfig, - StorageProviderConfig, StorageProviderSection, StreamMode, TelegramConfig, TunnelConfig, - WebSearchConfig, WebhookConfig, + DelegateAgentConfig, DiscordConfig, DockerRuntimeConfig, EmbeddingRouteConfig, GatewayConfig, + 
HardwareConfig, HardwareTransport, HeartbeatConfig, HttpRequestConfig, IMessageConfig, + IdentityConfig, LarkConfig, MatrixConfig, MemoryConfig, ModelRouteConfig, MultimodalConfig, + ObservabilityConfig, PeripheralBoardConfig, PeripheralsConfig, ProxyConfig, ProxyScope, + QueryClassificationConfig, ReliabilityConfig, ResourceLimitsConfig, RuntimeConfig, + SandboxBackend, SandboxConfig, SchedulerConfig, SecretsConfig, SecurityConfig, SkillsConfig, + SlackConfig, StorageConfig, StorageProviderConfig, StorageProviderSection, StreamMode, + TelegramConfig, TunnelConfig, WebSearchConfig, WebhookConfig, }; #[cfg(test)] @@ -36,6 +36,7 @@ mod tests { allowed_users: vec!["alice".into()], stream_mode: StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }; diff --git a/src/config/schema.rs b/src/config/schema.rs index 8d9138f..f47bb9d 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -2,12 +2,15 @@ use crate::providers::{is_glm_alias, is_zai_alias}; use crate::security::AutonomyLevel; use anyhow::{Context, Result}; use directories::UserDirs; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::collections::HashMap; -use std::fs::{self, File, OpenOptions}; -use std::io::Write; use std::path::{Path, PathBuf}; use std::sync::{OnceLock, RwLock}; +#[cfg(unix)] +use tokio::fs::File; +use tokio::fs::{self, OpenOptions}; +use tokio::io::AsyncWriteExt; const SUPPORTED_PROXY_SERVICE_KEYS: &[&str] = &[ "provider.anthropic", @@ -45,7 +48,10 @@ static RUNTIME_PROXY_CLIENT_CACHE: OnceLock, /// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama) pub api_url: Option, + /// Default provider ID or alias (e.g. `"openrouter"`, `"ollama"`, `"anthropic"`). Default: `"openrouter"`. pub default_provider: Option, + /// Default model routed through the selected provider (e.g. `"anthropic/claude-sonnet-4-6"`). pub default_model: Option, + /// Default model temperature (0.0–2.0). 
Default: `0.7`. pub default_temperature: f64, + /// Observability backend configuration (`[observability]`). #[serde(default)] pub observability: ObservabilityConfig, + /// Autonomy and security policy configuration (`[autonomy]`). #[serde(default)] pub autonomy: AutonomyConfig, + /// Runtime adapter configuration (`[runtime]`). Controls native vs Docker execution. #[serde(default)] pub runtime: RuntimeConfig, + /// Reliability settings: retries, fallback providers, backoff (`[reliability]`). #[serde(default)] pub reliability: ReliabilityConfig, + /// Scheduler configuration for periodic task execution (`[scheduler]`). #[serde(default)] pub scheduler: SchedulerConfig, + /// Agent orchestration settings (`[agent]`). #[serde(default)] pub agent: AgentConfig, + /// Skills loading and community repository behavior (`[skills]`). + #[serde(default)] + pub skills: SkillsConfig, + /// Model routing rules — route `hint:` to specific provider+model combos. #[serde(default)] pub model_routes: Vec, + /// Embedding routing rules — route `hint:` to specific provider+model combos. + #[serde(default)] + pub embedding_routes: Vec, + /// Automatic query classification — maps user messages to model hints. #[serde(default)] pub query_classification: QueryClassificationConfig, + /// Heartbeat configuration for periodic health pings (`[heartbeat]`). #[serde(default)] pub heartbeat: HeartbeatConfig, + /// Cron job configuration (`[cron]`). #[serde(default)] pub cron: CronConfig, + /// Channel configurations: Telegram, Discord, Slack, etc. (`[channels_config]`). #[serde(default)] pub channels_config: ChannelsConfig, + /// Memory backend configuration: sqlite, markdown, embeddings (`[memory]`). #[serde(default)] pub memory: MemoryConfig, + /// Persistent storage provider configuration (`[storage]`). #[serde(default)] pub storage: StorageConfig, + /// Tunnel configuration for exposing the gateway publicly (`[tunnel]`). 
#[serde(default)] pub tunnel: TunnelConfig, + /// Gateway server configuration: host, port, pairing, rate limits (`[gateway]`). #[serde(default)] pub gateway: GatewayConfig, + /// Composio managed OAuth tools integration (`[composio]`). #[serde(default)] pub composio: ComposioConfig, + /// Secrets encryption configuration (`[secrets]`). #[serde(default)] pub secrets: SecretsConfig, + /// Browser automation configuration (`[browser]`). #[serde(default)] pub browser: BrowserConfig, + /// HTTP request tool configuration (`[http_request]`). #[serde(default)] pub http_request: HttpRequestConfig, + /// Multimodal (image) handling configuration (`[multimodal]`). + #[serde(default)] + pub multimodal: MultimodalConfig, + + /// Web search tool configuration (`[web_search]`). #[serde(default)] pub web_search: WebSearchConfig, + /// Proxy configuration for outbound HTTP/HTTPS/SOCKS5 traffic (`[proxy]`). #[serde(default)] pub proxy: ProxyConfig, + /// Identity format configuration: OpenClaw or AIEOS (`[identity]`). #[serde(default)] pub identity: IdentityConfig, + /// Cost tracking and budget enforcement configuration (`[cost]`). #[serde(default)] pub cost: CostConfig, + /// Peripheral board configuration for hardware integration (`[peripherals]`). #[serde(default)] pub peripherals: PeripheralsConfig, @@ -146,7 +190,7 @@ pub struct Config { // ── Delegate Agents ────────────────────────────────────────────── /// Configuration for a delegate sub-agent used by the `delegate` tool. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct DelegateAgentConfig { /// Provider name (e.g. "ollama", "openrouter", "anthropic") pub provider: String, @@ -173,7 +217,7 @@ fn default_max_depth() -> u32 { // ── Hardware Config (wizard-driven) ───────────────────────────── /// Hardware transport mode. 
-#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)] pub enum HardwareTransport { #[default] None, @@ -194,7 +238,7 @@ impl std::fmt::Display for HardwareTransport { } /// Wizard-driven hardware configuration for physical world interaction. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct HardwareConfig { /// Whether hardware access is enabled #[serde(default)] @@ -240,17 +284,23 @@ impl Default for HardwareConfig { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Agent orchestration configuration (`[agent]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct AgentConfig { /// When true: bootstrap_max_chars=6000, rag_chunk_limit=2. Use for 13B or smaller models. #[serde(default)] pub compact_context: bool, + /// Maximum tool-call loop turns per user message. Default: `10`. + /// Setting to `0` falls back to the safe default of `10`. #[serde(default = "default_agent_max_tool_iterations")] pub max_tool_iterations: usize, + /// Maximum conversation history messages retained per session. Default: `50`. #[serde(default = "default_agent_max_history_messages")] pub max_history_messages: usize, + /// Enable parallel tool execution within a single iteration. Default: `false`. #[serde(default)] pub parallel_tools: bool, + /// Tool dispatch strategy (e.g. `"auto"`). Default: `"auto"`. #[serde(default = "default_agent_tool_dispatcher")] pub tool_dispatcher: String, } @@ -279,9 +329,75 @@ impl Default for AgentConfig { } } +/// Skills loading configuration (`[skills]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct SkillsConfig { + /// Enable loading and syncing the community open-skills repository. + /// Default: `false` (opt-in). 
+ #[serde(default)] + pub open_skills_enabled: bool, + /// Optional path to a local open-skills repository. + /// If unset, defaults to `$HOME/open-skills` when enabled. + #[serde(default)] + pub open_skills_dir: Option, +} + +impl Default for SkillsConfig { + fn default() -> Self { + Self { + open_skills_enabled: false, + open_skills_dir: None, + } + } +} + +/// Multimodal (image) handling configuration (`[multimodal]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct MultimodalConfig { + /// Maximum number of image attachments accepted per request. + #[serde(default = "default_multimodal_max_images")] + pub max_images: usize, + /// Maximum image payload size in MiB before base64 encoding. + #[serde(default = "default_multimodal_max_image_size_mb")] + pub max_image_size_mb: usize, + /// Allow fetching remote image URLs (http/https). Disabled by default. + #[serde(default)] + pub allow_remote_fetch: bool, +} + +fn default_multimodal_max_images() -> usize { + 4 +} + +fn default_multimodal_max_image_size_mb() -> usize { + 5 +} + +impl MultimodalConfig { + /// Clamp configured values to safe runtime bounds. + pub fn effective_limits(&self) -> (usize, usize) { + let max_images = self.max_images.clamp(1, 16); + let max_image_size_mb = self.max_image_size_mb.clamp(1, 20); + (max_images, max_image_size_mb) + } +} + +impl Default for MultimodalConfig { + fn default() -> Self { + Self { + max_images: default_multimodal_max_images(), + max_image_size_mb: default_multimodal_max_image_size_mb(), + allow_remote_fetch: false, + } + } +} + // ── Identity (AIEOS / OpenClaw format) ────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Identity format configuration (`[identity]` section). +/// +/// Supports `"openclaw"` (default) or `"aieos"` identity documents. 
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct IdentityConfig { /// Identity format: "openclaw" (default) or "aieos" #[serde(default = "default_identity_format")] @@ -310,7 +426,8 @@ impl Default for IdentityConfig { // ── Cost tracking and budget enforcement ─────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Cost tracking and budget enforcement configuration (`[cost]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct CostConfig { /// Enable cost tracking (default: false) #[serde(default)] @@ -337,7 +454,8 @@ pub struct CostConfig { pub prices: std::collections::HashMap, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Per-model pricing entry (USD per 1M tokens). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ModelPricing { /// Input price per 1M tokens #[serde(default)] @@ -451,7 +569,10 @@ fn get_default_pricing() -> std::collections::HashMap { // ── Peripherals (hardware: STM32, RPi GPIO, etc.) ──────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +/// Peripheral board integration configuration (`[peripherals]` section). +/// +/// Boards become agent tools when enabled. +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct PeripheralsConfig { /// Enable peripheral support (boards become agent tools) #[serde(default)] @@ -465,7 +586,8 @@ pub struct PeripheralsConfig { pub datasheet_dir: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Configuration for a single peripheral board (e.g. STM32, RPi GPIO). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct PeripheralBoardConfig { /// Board type: "nucleo-f401re", "rpi-gpio", "esp32", etc. 
pub board: String, @@ -501,7 +623,10 @@ impl Default for PeripheralBoardConfig { // ── Gateway security ───────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Gateway server configuration (`[gateway]` section). +/// +/// Controls the HTTP gateway for webhook and pairing endpoints. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct GatewayConfig { /// Gateway port (default: 3000) #[serde(default = "default_gateway_port")] @@ -597,10 +722,13 @@ impl Default for GatewayConfig { // ── Composio (managed tool surface) ───────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Composio managed OAuth tools integration (`[composio]` section). +/// +/// Provides access to 1000+ OAuth-connected tools via the Composio platform. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ComposioConfig { /// Enable Composio integration for 1000+ OAuth tools - #[serde(default)] + #[serde(default, alias = "enable")] pub enabled: bool, /// Composio API key (stored encrypted when secrets.encrypt = true) #[serde(default)] @@ -626,7 +754,8 @@ impl Default for ComposioConfig { // ── Secrets (encrypted credential store) ──────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Secrets encryption configuration (`[secrets]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SecretsConfig { /// Enable encryption for API keys and tokens in config.toml #[serde(default = "default_true")] @@ -641,7 +770,10 @@ impl Default for SecretsConfig { // ── Browser (friendly-service browsing only) ─────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Computer-use sidecar configuration (`[browser.computer_use]` section). +/// +/// Delegates OS-level mouse, keyboard, and screenshot actions to a local sidecar. 
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct BrowserComputerUseConfig { /// Sidecar endpoint for computer-use actions (OS-level mouse/keyboard/screenshot) #[serde(default = "default_browser_computer_use_endpoint")] @@ -688,7 +820,10 @@ impl Default for BrowserComputerUseConfig { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Browser automation configuration (`[browser]` section). +/// +/// Controls the `browser_open` tool and browser automation backends. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct BrowserConfig { /// Enable `browser_open` tool (opens URLs in Brave without scraping) #[serde(default)] @@ -741,7 +876,10 @@ impl Default for BrowserConfig { // ── HTTP request tool ─────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +/// HTTP request tool configuration (`[http_request]` section). +/// +/// Deny-by-default: if `allowed_domains` is empty, all HTTP requests are rejected. +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct HttpRequestConfig { /// Enable `http_request` tool for API interactions #[serde(default)] @@ -767,10 +905,11 @@ fn default_http_timeout_secs() -> u64 { // ── Web search ─────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Web search tool configuration (`[web_search]` section). 
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct WebSearchConfig { /// Enable `web_search_tool` for web searches - #[serde(default = "default_true")] + #[serde(default)] pub enabled: bool, /// Search provider: "duckduckgo" (free, no API key) or "brave" (requires API key) #[serde(default = "default_web_search_provider")] @@ -801,7 +940,7 @@ fn default_web_search_timeout_secs() -> u64 { impl Default for WebSearchConfig { fn default() -> Self { Self { - enabled: true, + enabled: false, provider: default_web_search_provider(), brave_api_key: None, max_results: default_web_search_max_results(), @@ -812,16 +951,21 @@ impl Default for WebSearchConfig { // ── Proxy ─────────────────────────────────────────────────────── -#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq)] +/// Proxy application scope — determines which outbound traffic uses the proxy. +#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq, Eq, JsonSchema)] #[serde(rename_all = "snake_case")] pub enum ProxyScope { + /// Use system environment proxy variables only. Environment, + /// Apply proxy to all ZeroClaw-managed HTTP traffic (default). #[default] Zeroclaw, + /// Apply proxy only to explicitly listed service selectors. Services, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Proxy configuration for outbound HTTP/HTTPS/SOCKS5 traffic (`[proxy]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ProxyConfig { /// Enable proxy support for selected scope. #[serde(default)] @@ -1271,19 +1415,24 @@ fn parse_proxy_enabled(raw: &str) -> Option { } // ── Memory ─────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +/// Persistent storage configuration (`[storage]` section). +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct StorageConfig { + /// Storage provider settings (e.g. sqlite, postgres). 
#[serde(default)] pub provider: StorageProviderSection, } -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +/// Wrapper for the storage provider configuration section. +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct StorageProviderSection { + /// Storage provider backend settings. #[serde(default)] pub config: StorageProviderConfig, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Storage provider backend configuration (e.g. postgres connection details). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct StorageProviderConfig { /// Storage engine key (e.g. "postgres", "sqlite"). #[serde(default)] @@ -1332,14 +1481,18 @@ impl Default for StorageProviderConfig { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Memory backend configuration (`[memory]` section). +/// +/// Controls conversation memory storage, embeddings, hybrid search, response caching, +/// and memory snapshot/hydration. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[allow(clippy::struct_excessive_bools)] pub struct MemoryConfig { /// "sqlite" | "lucid" | "postgres" | "markdown" | "none" (`none` = explicit no-op memory) /// /// `postgres` requires `[storage.provider.config]` with `db_url` (`dbURL` alias supported). pub backend: String, - /// Auto-save conversation context to memory + /// Auto-save user-stated conversation input to memory (assistant output is excluded) pub auto_save: bool, /// Run memory/session hygiene (archiving + retention cleanup) #[serde(default = "default_hygiene_enabled")] @@ -1482,7 +1635,8 @@ impl Default for MemoryConfig { // ── Observability ───────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Observability backend configuration (`[observability]` section). 
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ObservabilityConfig { /// "none" | "log" | "prometheus" | "otel" pub backend: String, @@ -1508,13 +1662,23 @@ impl Default for ObservabilityConfig { // ── Autonomy / Security ────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Autonomy and security policy configuration (`[autonomy]` section). +/// +/// Controls what the agent is allowed to do: shell commands, filesystem access, +/// risk approval gates, and per-policy budgets. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct AutonomyConfig { + /// Autonomy level: `read_only`, `supervised` (default), or `full`. pub level: AutonomyLevel, + /// Restrict file writes and command paths to the workspace directory. Default: `true`. pub workspace_only: bool, + /// Allowlist of executable names permitted for shell execution. pub allowed_commands: Vec, + /// Explicit path denylist. Default includes system-critical paths. pub forbidden_paths: Vec, + /// Maximum actions allowed per hour per policy. Default: `100`. pub max_actions_per_hour: u32, + /// Maximum cost per day in cents per policy. Default: `1000`. pub max_cost_per_day_cents: u32, /// Require explicit approval for medium-risk shell commands. @@ -1593,7 +1757,8 @@ impl Default for AutonomyConfig { // ── Runtime ────────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Runtime adapter configuration (`[runtime]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct RuntimeConfig { /// Runtime kind (`native` | `docker`). #[serde(default = "default_runtime_kind")] @@ -1602,9 +1767,17 @@ pub struct RuntimeConfig { /// Docker runtime settings (used when `kind = "docker"`). #[serde(default)] pub docker: DockerRuntimeConfig, + + /// Global reasoning override for providers that expose explicit controls. 
+ /// - `None`: provider default behavior + /// - `Some(true)`: request reasoning/thinking when supported + /// - `Some(false)`: disable reasoning/thinking when supported + #[serde(default)] + pub reasoning_enabled: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Docker runtime configuration (`[runtime.docker]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct DockerRuntimeConfig { /// Runtime image used to execute shell commands. #[serde(default = "default_docker_image")] @@ -1674,13 +1847,17 @@ impl Default for RuntimeConfig { Self { kind: default_runtime_kind(), docker: DockerRuntimeConfig::default(), + reasoning_enabled: None, } } } // ── Reliability / supervision ──────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Reliability and supervision configuration (`[reliability]` section). +/// +/// Controls provider retries, fallback chains, API key rotation, and channel restart backoff. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ReliabilityConfig { /// Retries per provider before failing over. #[serde(default = "default_provider_retries")] @@ -1755,7 +1932,8 @@ impl Default for ReliabilityConfig { // ── Scheduler ──────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Scheduler configuration for periodic task execution (`[scheduler]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SchedulerConfig { /// Enable the built-in scheduler loop. #[serde(default = "default_scheduler_enabled")] @@ -1807,7 +1985,7 @@ impl Default for SchedulerConfig { /// ``` /// /// Usage: pass `hint:reasoning` as the model parameter to route the request. -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ModelRouteConfig { /// Task hint name (e.g. 
"reasoning", "fast", "code", "summarize") pub hint: String, @@ -1820,20 +1998,52 @@ pub struct ModelRouteConfig { pub api_key: Option, } +// ── Embedding routing ─────────────────────────────────────────── + +/// Route an embedding hint to a specific provider + model. +/// +/// ```toml +/// [[embedding_routes]] +/// hint = "semantic" +/// provider = "openai" +/// model = "text-embedding-3-small" +/// dimensions = 1536 +/// +/// [memory] +/// embedding_model = "hint:semantic" +/// ``` +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct EmbeddingRouteConfig { + /// Route hint name (e.g. "semantic", "archive", "faq") + pub hint: String, + /// Embedding provider (`none`, `openai`, or `custom:`) + pub provider: String, + /// Embedding model to use with that provider + pub model: String, + /// Optional embedding dimension override for this route + #[serde(default)] + pub dimensions: Option, + /// Optional API key override for this route's provider + #[serde(default)] + pub api_key: Option, +} + // ── Query Classification ───────────────────────────────────────── /// Automatic query classification — classifies user messages by keyword/pattern /// and routes to the appropriate model hint. Disabled by default. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct QueryClassificationConfig { + /// Enable automatic query classification. Default: `false`. #[serde(default)] pub enabled: bool, + /// Classification rules evaluated in priority order. #[serde(default)] pub rules: Vec, } /// A single classification rule mapping message patterns to a model hint. -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct ClassificationRule { /// Must match a `[[model_routes]]` hint value. 
pub hint: String, @@ -1856,9 +2066,12 @@ pub struct ClassificationRule { // ── Heartbeat ──────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Heartbeat configuration for periodic health pings (`[heartbeat]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct HeartbeatConfig { + /// Enable periodic heartbeat pings. Default: `false`. pub enabled: bool, + /// Interval in minutes between heartbeat pings. Default: `30`. pub interval_minutes: u32, } @@ -1873,10 +2086,13 @@ impl Default for HeartbeatConfig { // ── Cron ──────────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Cron job configuration (`[cron]` section). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct CronConfig { + /// Enable the cron subsystem. Default: `true`. #[serde(default = "default_true")] pub enabled: bool, + /// Maximum number of historical cron run records to retain. Default: `50`. #[serde(default = "default_max_run_history")] pub max_run_history: u32, } @@ -1896,20 +2112,27 @@ impl Default for CronConfig { // ── Tunnel ────────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Tunnel configuration for exposing the gateway publicly (`[tunnel]` section). +/// +/// Supported providers: `"none"` (default), `"cloudflare"`, `"tailscale"`, `"ngrok"`, `"custom"`. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct TunnelConfig { - /// "none", "cloudflare", "tailscale", "ngrok", "custom" + /// Tunnel provider: `"none"`, `"cloudflare"`, `"tailscale"`, `"ngrok"`, or `"custom"`. Default: `"none"`. pub provider: String, + /// Cloudflare Tunnel configuration (used when `provider = "cloudflare"`). #[serde(default)] pub cloudflare: Option, + /// Tailscale Funnel/Serve configuration (used when `provider = "tailscale"`). 
#[serde(default)] pub tailscale: Option, + /// ngrok tunnel configuration (used when `provider = "ngrok"`). #[serde(default)] pub ngrok: Option, + /// Custom tunnel command configuration (used when `provider = "custom"`). #[serde(default)] pub custom: Option, } @@ -1926,13 +2149,13 @@ impl Default for TunnelConfig { } } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct CloudflareTunnelConfig { /// Cloudflare Tunnel token (from Zero Trust dashboard) pub token: String, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct TailscaleTunnelConfig { /// Use Tailscale Funnel (public internet) vs Serve (tailnet only) #[serde(default)] @@ -1941,7 +2164,7 @@ pub struct TailscaleTunnelConfig { pub hostname: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct NgrokTunnelConfig { /// ngrok auth token pub auth_token: String, @@ -1949,7 +2172,7 @@ pub struct NgrokTunnelConfig { pub domain: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct CustomTunnelConfig { /// Command template to start the tunnel. Use {port} and {host} placeholders. /// Example: "bore local {port} --to bore.pub" @@ -1962,23 +2185,55 @@ pub struct CustomTunnelConfig { // ── Channels ───────────────────────────────────────────────────── -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Top-level channel configurations (`[channels_config]` section). +/// +/// Each channel sub-section (e.g. `telegram`, `discord`) is optional; +/// setting it to `Some(...)` enables that channel. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ChannelsConfig { + /// Enable the CLI interactive channel. Default: `true`. pub cli: bool, + /// Telegram bot channel configuration. 
pub telegram: Option, + /// Discord bot channel configuration. pub discord: Option, + /// Slack bot channel configuration. pub slack: Option, + /// Mattermost bot channel configuration. pub mattermost: Option, + /// Webhook channel configuration. pub webhook: Option, + /// iMessage channel configuration (macOS only). pub imessage: Option, + /// Matrix channel configuration. pub matrix: Option, + /// Signal channel configuration. pub signal: Option, + /// WhatsApp channel configuration (Cloud API or Web mode). pub whatsapp: Option, + /// Linq Partner API channel configuration. + pub linq: Option, + /// Email channel configuration. pub email: Option, + /// IRC channel configuration. pub irc: Option, + /// Lark/Feishu channel configuration. pub lark: Option, + /// DingTalk channel configuration. pub dingtalk: Option, + /// QQ Official Bot channel configuration. pub qq: Option, + /// Base timeout in seconds for processing a single channel message (LLM + tools). + /// Runtime uses this as a per-turn budget that scales with tool-loop depth + /// (up to 4x, capped) so one slow/retried model call does not consume the + /// entire conversation budget. + /// Default: 300s for on-device LLMs (Ollama) which are slower than cloud APIs. + #[serde(default = "default_channel_message_timeout_secs")] + pub message_timeout_secs: u64, +} + +fn default_channel_message_timeout_secs() -> u64 { + 300 } impl Default for ChannelsConfig { @@ -1994,17 +2249,19 @@ impl Default for ChannelsConfig { matrix: None, signal: None, whatsapp: None, + linq: None, email: None, irc: None, lark: None, dingtalk: None, qq: None, + message_timeout_secs: default_channel_message_timeout_secs(), } } } /// Streaming mode for channels that support progressive message updates. 
-#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default, JsonSchema)] #[serde(rename_all = "lowercase")] pub enum StreamMode { /// No streaming -- send the complete response as a single message (default). @@ -2018,9 +2275,12 @@ fn default_draft_update_interval_ms() -> u64 { 1000 } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Telegram bot channel configuration. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct TelegramConfig { + /// Telegram Bot API token (from @BotFather). pub bot_token: String, + /// Allowed Telegram user IDs or usernames. Empty = deny all. pub allowed_users: Vec, /// Streaming mode for progressive response delivery via message edits. #[serde(default)] @@ -2028,16 +2288,24 @@ pub struct TelegramConfig { /// Minimum interval (ms) between draft message edits to avoid rate limits. #[serde(default = "default_draft_update_interval_ms")] pub draft_update_interval_ms: u64, + /// When true, a newer Telegram message from the same sender in the same chat + /// cancels the in-flight request and starts a fresh response with preserved history. + #[serde(default)] + pub interrupt_on_new_message: bool, /// When true, only respond to messages that @-mention the bot in groups. /// Direct messages are always processed. #[serde(default)] pub mention_only: bool, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Discord bot channel configuration. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct DiscordConfig { + /// Discord bot token (from Discord Developer Portal). pub bot_token: String, + /// Optional guild (server) ID to restrict the bot to a single guild. pub guild_id: Option, + /// Allowed Discord user IDs. Empty = deny all. #[serde(default)] pub allowed_users: Vec, /// When true, process messages from other bots (not just humans). 
@@ -2050,20 +2318,30 @@ pub struct DiscordConfig { pub mention_only: bool, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Slack bot channel configuration. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SlackConfig { + /// Slack bot OAuth token (xoxb-...). pub bot_token: String, + /// Slack app-level token for Socket Mode (xapp-...). pub app_token: Option, + /// Optional channel ID to restrict the bot to a single channel. pub channel_id: Option, + /// Allowed Slack user IDs. Empty = deny all. #[serde(default)] pub allowed_users: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Mattermost bot channel configuration. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct MattermostConfig { + /// Mattermost server URL (e.g. `"https://mattermost.example.com"`). pub url: String, + /// Mattermost bot access token. pub bot_token: String, + /// Optional channel ID to restrict the bot to a single channel. pub channel_id: Option, + /// Allowed Mattermost user IDs. Empty = deny all. #[serde(default)] pub allowed_users: Vec, /// When true (default), replies thread on the original post. @@ -2076,30 +2354,42 @@ pub struct MattermostConfig { pub mention_only: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Webhook channel configuration. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct WebhookConfig { + /// Port to listen on for incoming webhooks. pub port: u16, + /// Optional shared secret for webhook signature verification. pub secret: Option, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// iMessage channel configuration (macOS only). +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct IMessageConfig { + /// Allowed iMessage contacts (phone numbers or email addresses). Empty = deny all. pub allowed_contacts: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// Matrix channel configuration. 
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct MatrixConfig { + /// Matrix homeserver URL (e.g. `"https://matrix.org"`). pub homeserver: String, + /// Matrix access token for the bot account. pub access_token: String, + /// Optional Matrix user ID (e.g. `"@bot:matrix.org"`). #[serde(default)] pub user_id: Option, + /// Optional Matrix device ID. #[serde(default)] pub device_id: Option, + /// Matrix room ID to listen in (e.g. `"!abc123:matrix.org"`). pub room_id: String, + /// Allowed Matrix user IDs. Empty = deny all. pub allowed_users: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SignalConfig { /// Base URL for the signal-cli HTTP daemon (e.g. "http://127.0.0.1:8686"). pub http_url: String, @@ -2122,24 +2412,92 @@ pub struct SignalConfig { pub ignore_stories: bool, } -#[derive(Debug, Clone, Serialize, Deserialize)] +/// WhatsApp channel configuration (Cloud API or Web mode). +/// +/// Set `phone_number_id` for Cloud API mode, or `session_path` for Web mode. 
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct WhatsAppConfig { - /// Access token from Meta Business Suite - pub access_token: String, - /// Phone number ID from Meta Business API - pub phone_number_id: String, + /// Access token from Meta Business Suite (Cloud API mode) + #[serde(default)] + pub access_token: Option, + /// Phone number ID from Meta Business API (Cloud API mode) + #[serde(default)] + pub phone_number_id: Option, /// Webhook verify token (you define this, Meta sends it back for verification) - pub verify_token: String, + /// Only used in Cloud API mode + #[serde(default)] + pub verify_token: Option, /// App secret from Meta Business Suite (for webhook signature verification) /// Can also be set via `ZEROCLAW_WHATSAPP_APP_SECRET` environment variable + /// Only used in Cloud API mode #[serde(default)] pub app_secret: Option, + /// Session database path for WhatsApp Web client (Web mode) + /// When set, enables native WhatsApp Web mode with wa-rs + #[serde(default)] + pub session_path: Option, + /// Phone number for pair code linking (Web mode, optional) + /// Format: country code + number (e.g., "15551234567") + /// If not set, QR code pairing will be used + #[serde(default)] + pub pair_phone: Option, + /// Custom pair code for linking (Web mode, optional) + /// Leave empty to let WhatsApp generate one + #[serde(default)] + pub pair_code: Option, /// Allowed phone numbers (E.164 format: +1234567890) or "*" for all #[serde(default)] pub allowed_numbers: Vec, } -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] +pub struct LinqConfig { + /// Linq Partner API token (Bearer auth) + pub api_token: String, + /// Phone number to send from (E.164 format) + pub from_phone: String, + /// Webhook signing secret for signature verification + #[serde(default)] + pub signing_secret: Option, + /// Allowed sender handles (phone numbers) or "*" for all + #[serde(default)] + pub 
allowed_senders: Vec, +} + +impl WhatsAppConfig { + /// Detect which backend to use based on config fields. + /// Returns "cloud" if phone_number_id is set, "web" if session_path is set. + pub fn backend_type(&self) -> &'static str { + if self.phone_number_id.is_some() { + "cloud" + } else if self.session_path.is_some() { + "web" + } else { + // Default to Cloud API for backward compatibility + "cloud" + } + } + + /// Check if this is a valid Cloud API config + pub fn is_cloud_config(&self) -> bool { + self.phone_number_id.is_some() && self.access_token.is_some() && self.verify_token.is_some() + } + + /// Check if this is a valid Web config + pub fn is_web_config(&self) -> bool { + self.session_path.is_some() + } + + /// Returns true when both Cloud and Web selectors are present. + /// + /// Runtime currently prefers Cloud mode in this case for backward compatibility. + pub fn is_ambiguous_config(&self) -> bool { + self.phone_number_id.is_some() && self.session_path.is_some() + } +} + +/// IRC channel configuration. +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct IrcConfig { /// IRC server hostname pub server: String, @@ -2174,7 +2532,7 @@ fn default_irc_port() -> u16 { /// /// - `websocket` (default) — persistent WSS long-connection; no public URL required. /// - `webhook` — HTTP callback server; requires a public HTTPS endpoint. -#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default, JsonSchema)] #[serde(rename_all = "lowercase")] pub enum LarkReceiveMode { #[default] @@ -2184,7 +2542,7 @@ pub enum LarkReceiveMode { /// Lark/Feishu configuration for messaging integration. /// Lark is the international version; Feishu is the Chinese version. 
-#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct LarkConfig { /// App ID from Lark/Feishu developer console pub app_id: String, @@ -2214,7 +2572,7 @@ pub struct LarkConfig { // ── Security Config ───────────────────────────────────────────────── /// Security configuration for sandboxing, resource limits, and audit logging -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] pub struct SecurityConfig { /// Sandbox configuration #[serde(default)] @@ -2230,7 +2588,7 @@ pub struct SecurityConfig { } /// Sandbox configuration for OS-level isolation -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct SandboxConfig { /// Enable sandboxing (None = auto-detect, Some = explicit) #[serde(default)] @@ -2256,7 +2614,7 @@ impl Default for SandboxConfig { } /// Sandbox backend selection -#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[derive(Debug, Clone, Serialize, Deserialize, Default, JsonSchema)] #[serde(rename_all = "lowercase")] pub enum SandboxBackend { /// Auto-detect best available (default) @@ -2275,7 +2633,7 @@ pub enum SandboxBackend { } /// Resource limits for command execution -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct ResourceLimitsConfig { /// Maximum memory in MB per command #[serde(default = "default_max_memory_mb")] @@ -2322,7 +2680,7 @@ impl Default for ResourceLimitsConfig { } /// Audit logging configuration -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct AuditConfig { /// Enable audit logging #[serde(default = "default_audit_enabled")] @@ -2365,7 +2723,7 @@ impl Default for AuditConfig { } /// DingTalk configuration for Stream Mode messaging -#[derive(Debug, Clone, Serialize, 
Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct DingTalkConfig { /// Client ID (AppKey) from DingTalk developer console pub client_id: String, @@ -2377,7 +2735,7 @@ pub struct DingTalkConfig { } /// QQ Official Bot configuration (Tencent QQ Bot SDK) -#[derive(Debug, Clone, Serialize, Deserialize)] +#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] pub struct QQConfig { /// App ID from QQ Bot developer console pub app_id: String, @@ -2410,7 +2768,9 @@ impl Default for Config { reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), agent: AgentConfig::default(), + skills: SkillsConfig::default(), model_routes: Vec::new(), + embedding_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), cron: CronConfig::default(), channels_config: ChannelsConfig::default(), @@ -2422,6 +2782,7 @@ impl Default for Config { secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + multimodal: MultimodalConfig::default(), web_search: WebSearchConfig::default(), proxy: ProxyConfig::default(), identity: IdentityConfig::default(), @@ -2457,13 +2818,15 @@ fn active_workspace_state_path(default_dir: &Path) -> PathBuf { default_dir.join(ACTIVE_WORKSPACE_STATE_FILE) } -fn load_persisted_workspace_dirs(default_config_dir: &Path) -> Result> { +async fn load_persisted_workspace_dirs( + default_config_dir: &Path, +) -> Result> { let state_path = active_workspace_state_path(default_config_dir); if !state_path.exists() { return Ok(None); } - let contents = match fs::read_to_string(&state_path) { + let contents = match fs::read_to_string(&state_path).await { Ok(contents) => contents, Err(error) => { tracing::warn!( @@ -2503,13 +2866,13 @@ fn load_persisted_workspace_dirs(default_config_dir: &Path) -> Result Result<()> { +pub(crate) async fn persist_active_workspace_config_dir(config_dir: &Path) -> Result<()> { let default_config_dir = default_config_dir()?; let 
state_path = active_workspace_state_path(&default_config_dir); if config_dir == default_config_dir { if state_path.exists() { - fs::remove_file(&state_path).with_context(|| { + fs::remove_file(&state_path).await.with_context(|| { format!( "Failed to clear active workspace marker: {}", state_path.display() @@ -2519,12 +2882,14 @@ pub(crate) fn persist_active_workspace_config_dir(config_dir: &Path) -> Result<( return Ok(()); } - fs::create_dir_all(&default_config_dir).with_context(|| { - format!( - "Failed to create default config directory: {}", - default_config_dir.display() - ) - })?; + fs::create_dir_all(&default_config_dir) + .await + .with_context(|| { + format!( + "Failed to create default config directory: {}", + default_config_dir.display() + ) + })?; let state = ActiveWorkspaceState { config_dir: config_dir.to_string_lossy().into_owned(), @@ -2536,22 +2901,22 @@ pub(crate) fn persist_active_workspace_config_dir(config_dir: &Path) -> Result<( ".{ACTIVE_WORKSPACE_STATE_FILE}.tmp-{}", uuid::Uuid::new_v4() )); - fs::write(&temp_path, serialized).with_context(|| { + fs::write(&temp_path, serialized).await.with_context(|| { format!( "Failed to write temporary active workspace marker: {}", temp_path.display() ) })?; - if let Err(error) = fs::rename(&temp_path, &state_path) { - let _ = fs::remove_file(&temp_path); + if let Err(error) = fs::rename(&temp_path, &state_path).await { + let _ = fs::remove_file(&temp_path).await; anyhow::bail!( "Failed to atomically persist active workspace marker {}: {error}", state_path.display() ); } - sync_directory(&default_config_dir)?; + sync_directory(&default_config_dir).await?; Ok(()) } @@ -2586,6 +2951,60 @@ fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> (PathBuf, PathBuf) ) } +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +enum ConfigResolutionSource { + EnvWorkspace, + ActiveWorkspaceMarker, + DefaultConfigDir, +} + +impl ConfigResolutionSource { + const fn as_str(self) -> &'static str { + match self { + 
Self::EnvWorkspace => "ZEROCLAW_WORKSPACE", + Self::ActiveWorkspaceMarker => "active_workspace.toml", + Self::DefaultConfigDir => "default", + } + } +} + +async fn resolve_runtime_config_dirs( + default_zeroclaw_dir: &Path, + default_workspace_dir: &Path, +) -> Result<(PathBuf, PathBuf, ConfigResolutionSource)> { + // Resolution priority: + // 1. ZEROCLAW_WORKSPACE env override + // 2. Persisted active workspace marker from onboarding/custom profile + // 3. Default ~/.zeroclaw layout + if let Ok(custom_workspace) = std::env::var("ZEROCLAW_WORKSPACE") { + if !custom_workspace.is_empty() { + let (zeroclaw_dir, workspace_dir) = + resolve_config_dir_for_workspace(&PathBuf::from(custom_workspace)); + return Ok(( + zeroclaw_dir, + workspace_dir, + ConfigResolutionSource::EnvWorkspace, + )); + } + } + + if let Some((zeroclaw_dir, workspace_dir)) = + load_persisted_workspace_dirs(default_zeroclaw_dir).await? + { + return Ok(( + zeroclaw_dir, + workspace_dir, + ConfigResolutionSource::ActiveWorkspaceMarker, + )); + } + + Ok(( + default_zeroclaw_dir.to_path_buf(), + default_workspace_dir.to_path_buf(), + ConfigResolutionSource::DefaultConfigDir, + )) +} + fn decrypt_optional_secret( store: &crate::security::SecretStore, value: &mut Option, @@ -2621,32 +3040,27 @@ fn encrypt_optional_secret( } impl Config { - pub fn load_or_init() -> Result { + pub async fn load_or_init() -> Result { let (default_zeroclaw_dir, default_workspace_dir) = default_config_and_workspace_dirs()?; - // Resolution priority: - // 1. ZEROCLAW_WORKSPACE env override - // 2. Persisted active workspace marker from onboarding/custom profile - // 3. Default ~/.zeroclaw layout - let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") { - Ok(custom_workspace) if !custom_workspace.is_empty() => { - resolve_config_dir_for_workspace(&PathBuf::from(custom_workspace)) - } - _ => load_persisted_workspace_dirs(&default_zeroclaw_dir)? 
- .unwrap_or((default_zeroclaw_dir, default_workspace_dir)), - }; + let (zeroclaw_dir, workspace_dir, resolution_source) = + resolve_runtime_config_dirs(&default_zeroclaw_dir, &default_workspace_dir).await?; let config_path = zeroclaw_dir.join("config.toml"); - fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?; - fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?; + fs::create_dir_all(&zeroclaw_dir) + .await + .context("Failed to create config directory")?; + fs::create_dir_all(&workspace_dir) + .await + .context("Failed to create workspace directory")?; if config_path.exists() { // Warn if config file is world-readable (may contain API keys) #[cfg(unix)] { use std::os::unix::fs::PermissionsExt; - if let Ok(meta) = fs::metadata(&config_path) { + if let Ok(meta) = fs::metadata(&config_path).await { if meta.permissions().mode() & 0o004 != 0 { tracing::warn!( "Config file {:?} is world-readable (mode {:o}). \ @@ -2659,8 +3073,9 @@ impl Config { } } - let contents = - fs::read_to_string(&config_path).context("Failed to read config file")?; + let contents = fs::read_to_string(&config_path) + .await + .context("Failed to read config file")?; let mut config: Config = toml::from_str(&contents).context("Failed to parse config file")?; // Set computed paths that are skipped during serialization @@ -2696,25 +3111,96 @@ impl Config { decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; } config.apply_env_overrides(); + config.validate()?; + tracing::info!( + path = %config.config_path.display(), + workspace = %config.workspace_dir.display(), + source = resolution_source.as_str(), + initialized = false, + "Config loaded" + ); Ok(config) } else { let mut config = Config::default(); config.config_path = config_path.clone(); config.workspace_dir = workspace_dir; - config.save()?; + config.save().await?; // Restrict permissions on newly created config file (may contain API keys) #[cfg(unix)] 
{ - use std::os::unix::fs::PermissionsExt; - let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600)); + use std::{fs::Permissions, os::unix::fs::PermissionsExt}; + let _ = fs::set_permissions(&config_path, Permissions::from_mode(0o600)).await; } config.apply_env_overrides(); + config.validate()?; + tracing::info!( + path = %config.config_path.display(), + workspace = %config.workspace_dir.display(), + source = resolution_source.as_str(), + initialized = true, + "Config loaded" + ); Ok(config) } } + /// Validate configuration values that would cause runtime failures. + /// + /// Called after TOML deserialization and env-override application to catch + /// obviously invalid values early instead of failing at arbitrary runtime points. + pub fn validate(&self) -> Result<()> { + // Gateway + if self.gateway.host.trim().is_empty() { + anyhow::bail!("gateway.host must not be empty"); + } + + // Autonomy + if self.autonomy.max_actions_per_hour == 0 { + anyhow::bail!("autonomy.max_actions_per_hour must be greater than 0"); + } + + // Scheduler + if self.scheduler.max_concurrent == 0 { + anyhow::bail!("scheduler.max_concurrent must be greater than 0"); + } + if self.scheduler.max_tasks == 0 { + anyhow::bail!("scheduler.max_tasks must be greater than 0"); + } + + // Model routes + for (i, route) in self.model_routes.iter().enumerate() { + if route.hint.trim().is_empty() { + anyhow::bail!("model_routes[{i}].hint must not be empty"); + } + if route.provider.trim().is_empty() { + anyhow::bail!("model_routes[{i}].provider must not be empty"); + } + if route.model.trim().is_empty() { + anyhow::bail!("model_routes[{i}].model must not be empty"); + } + } + + // Embedding routes + for (i, route) in self.embedding_routes.iter().enumerate() { + if route.hint.trim().is_empty() { + anyhow::bail!("embedding_routes[{i}].hint must not be empty"); + } + if route.provider.trim().is_empty() { + anyhow::bail!("embedding_routes[{i}].provider must not be empty"); + } + if 
route.model.trim().is_empty() { + anyhow::bail!("embedding_routes[{i}].model must not be empty"); + } + } + + // Proxy (delegate to existing validation) + self.proxy.validate()?; + + Ok(()) + } + /// Apply environment variable overrides to config pub fn apply_env_overrides(&mut self) { // API Key: ZEROCLAW_API_KEY or API_KEY (generic) @@ -2741,13 +3227,23 @@ impl Config { } } - // Provider: ZEROCLAW_PROVIDER or PROVIDER - if let Ok(provider) = - std::env::var("ZEROCLAW_PROVIDER").or_else(|_| std::env::var("PROVIDER")) - { + // Provider override precedence: + // 1) ZEROCLAW_PROVIDER always wins when set. + // 2) Legacy PROVIDER is only honored when config still uses the + // default provider (openrouter) or provider is unset. This prevents + // container defaults from overriding explicit custom providers. + if let Ok(provider) = std::env::var("ZEROCLAW_PROVIDER") { if !provider.is_empty() { self.default_provider = Some(provider); } + } else if let Ok(provider) = std::env::var("PROVIDER") { + let should_apply_legacy_provider = + self.default_provider.as_deref().map_or(true, |configured| { + configured.trim().eq_ignore_ascii_case("openrouter") + }); + if should_apply_legacy_provider && !provider.is_empty() { + self.default_provider = Some(provider); + } } // Model: ZEROCLAW_MODEL or MODEL @@ -2766,6 +3262,27 @@ impl Config { } } + // Open-skills opt-in flag: ZEROCLAW_OPEN_SKILLS_ENABLED + if let Ok(flag) = std::env::var("ZEROCLAW_OPEN_SKILLS_ENABLED") { + if !flag.trim().is_empty() { + match flag.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" => self.skills.open_skills_enabled = true, + "0" | "false" | "no" | "off" => self.skills.open_skills_enabled = false, + _ => tracing::warn!( + "Ignoring invalid ZEROCLAW_OPEN_SKILLS_ENABLED (valid: 1|0|true|false|yes|no|on|off)" + ), + } + } + } + + // Open-skills directory override: ZEROCLAW_OPEN_SKILLS_DIR + if let Ok(path) = std::env::var("ZEROCLAW_OPEN_SKILLS_DIR") { + let trimmed = path.trim(); + if 
!trimmed.is_empty() { + self.skills.open_skills_dir = Some(trimmed.to_string()); + } + } + // Gateway port: ZEROCLAW_GATEWAY_PORT or PORT if let Ok(port_str) = std::env::var("ZEROCLAW_GATEWAY_PORT").or_else(|_| std::env::var("PORT")) @@ -2797,6 +3314,18 @@ impl Config { } } + // Reasoning override: ZEROCLAW_REASONING_ENABLED or REASONING_ENABLED + if let Ok(flag) = std::env::var("ZEROCLAW_REASONING_ENABLED") + .or_else(|_| std::env::var("REASONING_ENABLED")) + { + let normalized = flag.trim().to_ascii_lowercase(); + match normalized.as_str() { + "1" | "true" | "yes" | "on" => self.runtime.reasoning_enabled = Some(true), + "0" | "false" | "no" | "off" => self.runtime.reasoning_enabled = Some(false), + _ => {} + } + } + // Web search enabled: ZEROCLAW_WEB_SEARCH_ENABLED or WEB_SEARCH_ENABLED if let Ok(enabled) = std::env::var("ZEROCLAW_WEB_SEARCH_ENABLED") .or_else(|_| std::env::var("WEB_SEARCH_ENABLED")) @@ -2940,7 +3469,7 @@ impl Config { set_runtime_proxy_config(self.proxy.clone()); } - pub fn save(&self) -> Result<()> { + pub async fn save(&self) -> Result<()> { // Encrypt secrets before serialization let mut config_to_save = self.clone(); let zeroclaw_dir = self @@ -2985,7 +3514,8 @@ impl Config { .config_path .parent() .context("Config path must have a parent directory")?; - fs::create_dir_all(parent_dir).with_context(|| { + + fs::create_dir_all(parent_dir).await.with_context(|| { format!( "Failed to create config directory: {}", parent_dir.display() @@ -3004,6 +3534,7 @@ impl Config { .create_new(true) .write(true) .open(&temp_path) + .await .with_context(|| { format!( "Failed to create temporary config file: {}", @@ -3012,80 +3543,131 @@ impl Config { })?; temp_file .write_all(toml_str.as_bytes()) + .await .context("Failed to write temporary config contents")?; temp_file .sync_all() + .await .context("Failed to fsync temporary config file")?; drop(temp_file); let had_existing_config = self.config_path.exists(); if had_existing_config { - 
fs::copy(&self.config_path, &backup_path).with_context(|| { - format!( - "Failed to create config backup before atomic replace: {}", - backup_path.display() - ) - })?; + fs::copy(&self.config_path, &backup_path) + .await + .with_context(|| { + format!( + "Failed to create config backup before atomic replace: {}", + backup_path.display() + ) + })?; } - if let Err(e) = fs::rename(&temp_path, &self.config_path) { - let _ = fs::remove_file(&temp_path); + if let Err(e) = fs::rename(&temp_path, &self.config_path).await { + let _ = fs::remove_file(&temp_path).await; if had_existing_config && backup_path.exists() { - let _ = fs::copy(&backup_path, &self.config_path); + fs::copy(&backup_path, &self.config_path) + .await + .context("Failed to restore config backup")?; } anyhow::bail!("Failed to atomically replace config file: {e}"); } - sync_directory(parent_dir)?; + sync_directory(parent_dir).await?; if had_existing_config { - let _ = fs::remove_file(&backup_path); + let _ = fs::remove_file(&backup_path).await; } Ok(()) } } -#[cfg(unix)] -fn sync_directory(path: &Path) -> Result<()> { - let dir = File::open(path) - .with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?; - dir.sync_all() - .with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?; - Ok(()) -} +async fn sync_directory(path: &Path) -> Result<()> { + #[cfg(unix)] + { + let dir = File::open(path) + .await + .with_context(|| format!("Failed to open directory for fsync: {}", path.display()))?; + dir.sync_all() + .await + .with_context(|| format!("Failed to fsync directory metadata: {}", path.display()))?; + return Ok(()); + } -#[cfg(not(unix))] -fn sync_directory(_path: &Path) -> Result<()> { - Ok(()) + #[cfg(not(unix))] + { + let _ = path; + Ok(()) + } } #[cfg(test)] mod tests { use super::*; use std::path::PathBuf; + #[cfg(unix)] + use std::{fs::Permissions, os::unix::fs::PermissionsExt}; + use tokio::sync::{Mutex, MutexGuard}; + use tokio::test; + use 
tokio_stream::wrappers::ReadDirStream; + use tokio_stream::StreamExt; // ── Defaults ───────────────────────────────────────────── #[test] - fn config_default_has_sane_values() { + async fn config_default_has_sane_values() { let c = Config::default(); assert_eq!(c.default_provider.as_deref(), Some("openrouter")); assert!(c.default_model.as_deref().unwrap().contains("claude")); assert!((c.default_temperature - 0.7).abs() < f64::EPSILON); assert!(c.api_key.is_none()); + assert!(!c.skills.open_skills_enabled); assert!(c.workspace_dir.to_string_lossy().contains("workspace")); assert!(c.config_path.to_string_lossy().contains("config.toml")); } #[test] - fn observability_config_default() { + async fn config_schema_export_contains_expected_contract_shape() { + let schema = schemars::schema_for!(Config); + let schema_json = serde_json::to_value(&schema).expect("schema should serialize to json"); + + assert_eq!( + schema_json + .get("$schema") + .and_then(serde_json::Value::as_str), + Some("https://json-schema.org/draft/2020-12/schema") + ); + + let properties = schema_json + .get("properties") + .and_then(serde_json::Value::as_object) + .expect("schema should expose top-level properties"); + + assert!(properties.contains_key("default_provider")); + assert!(properties.contains_key("skills")); + assert!(properties.contains_key("gateway")); + assert!(properties.contains_key("channels_config")); + assert!(!properties.contains_key("workspace_dir")); + assert!(!properties.contains_key("config_path")); + + assert!( + schema_json + .get("$defs") + .and_then(serde_json::Value::as_object) + .is_some(), + "schema should include reusable type definitions" + ); + } + + #[test] + async fn observability_config_default() { let o = ObservabilityConfig::default(); assert_eq!(o.backend, "none"); } #[test] - fn autonomy_config_default() { + async fn autonomy_config_default() { let a = AutonomyConfig::default(); assert_eq!(a.level, AutonomyLevel::Supervised); assert!(a.workspace_only); @@ 
-3099,7 +3681,7 @@ mod tests { } #[test] - fn runtime_config_default() { + async fn runtime_config_default() { let r = RuntimeConfig::default(); assert_eq!(r.kind, "native"); assert_eq!(r.docker.image, "alpine:3.20"); @@ -3111,21 +3693,21 @@ mod tests { } #[test] - fn heartbeat_config_default() { + async fn heartbeat_config_default() { let h = HeartbeatConfig::default(); assert!(!h.enabled); assert_eq!(h.interval_minutes, 30); } #[test] - fn cron_config_default() { + async fn cron_config_default() { let c = CronConfig::default(); assert!(c.enabled); assert_eq!(c.max_run_history, 50); } #[test] - fn cron_config_serde_roundtrip() { + async fn cron_config_serde_roundtrip() { let c = CronConfig { enabled: false, max_run_history: 100, @@ -3137,7 +3719,7 @@ mod tests { } #[test] - fn config_defaults_cron_when_section_missing() { + async fn config_defaults_cron_when_section_missing() { let toml_str = r#" workspace_dir = "/tmp/workspace" config_path = "/tmp/config.toml" @@ -3150,7 +3732,7 @@ default_temperature = 0.7 } #[test] - fn memory_config_default_hygiene_settings() { + async fn memory_config_default_hygiene_settings() { let m = MemoryConfig::default(); assert_eq!(m.backend, "sqlite"); assert!(m.auto_save); @@ -3162,7 +3744,7 @@ default_temperature = 0.7 } #[test] - fn storage_provider_config_defaults() { + async fn storage_provider_config_defaults() { let storage = StorageConfig::default(); assert!(storage.provider.config.provider.is_empty()); assert!(storage.provider.config.db_url.is_none()); @@ -3172,7 +3754,7 @@ default_temperature = 0.7 } #[test] - fn channels_config_default() { + async fn channels_config_default() { let c = ChannelsConfig::default(); assert!(c.cli); assert!(c.telegram.is_none()); @@ -3182,7 +3764,7 @@ default_temperature = 0.7 // ── Serde round-trip ───────────────────────────────────── #[test] - fn config_toml_roundtrip() { + async fn config_toml_roundtrip() { let config = Config { workspace_dir: PathBuf::from("/tmp/test/workspace"), 
config_path: PathBuf::from("/tmp/test/config.toml"), @@ -3213,7 +3795,9 @@ default_temperature = 0.7 }, reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), + skills: SkillsConfig::default(), model_routes: Vec::new(), + embedding_routes: Vec::new(), query_classification: QueryClassificationConfig::default(), heartbeat: HeartbeatConfig { enabled: true, @@ -3227,6 +3811,7 @@ default_temperature = 0.7 allowed_users: vec!["user1".into()], stream_mode: StreamMode::default(), draft_update_interval_ms: default_draft_update_interval_ms(), + interrupt_on_new_message: false, mention_only: false, }), discord: None, @@ -3237,11 +3822,13 @@ default_temperature = 0.7 matrix: None, signal: None, whatsapp: None, + linq: None, email: None, irc: None, lark: None, dingtalk: None, qq: None, + message_timeout_secs: 300, }, memory: MemoryConfig::default(), storage: StorageConfig::default(), @@ -3251,6 +3838,7 @@ default_temperature = 0.7 secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + multimodal: MultimodalConfig::default(), web_search: WebSearchConfig::default(), proxy: ProxyConfig::default(), agent: AgentConfig::default(), @@ -3282,7 +3870,7 @@ default_temperature = 0.7 } #[test] - fn config_minimal_toml_uses_defaults() { + async fn config_minimal_toml_uses_defaults() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -3303,7 +3891,7 @@ default_temperature = 0.7 } #[test] - fn storage_provider_dburl_alias_deserializes() { + async fn storage_provider_dburl_alias_deserializes() { let raw = r#" default_temperature = 0.7 @@ -3330,7 +3918,20 @@ connect_timeout_secs = 12 } #[test] - fn agent_config_defaults() { + async fn runtime_reasoning_enabled_deserializes() { + let raw = r#" +default_temperature = 0.7 + +[runtime] +reasoning_enabled = false +"#; + + let parsed: Config = toml::from_str(raw).unwrap(); + assert_eq!(parsed.runtime.reasoning_enabled, Some(false)); 
+ } + + #[test] + async fn agent_config_defaults() { let cfg = AgentConfig::default(); assert!(!cfg.compact_context); assert_eq!(cfg.max_tool_iterations, 10); @@ -3340,7 +3941,7 @@ connect_timeout_secs = 12 } #[test] - fn agent_config_deserializes() { + async fn agent_config_deserializes() { let raw = r#" default_temperature = 0.7 [agent] @@ -3358,11 +3959,24 @@ tool_dispatcher = "xml" assert_eq!(parsed.agent.tool_dispatcher, "xml"); } - #[test] - fn config_save_and_load_tmpdir() { + #[tokio::test] + async fn sync_directory_handles_existing_directory() { + let dir = std::env::temp_dir().join(format!( + "zeroclaw_test_sync_directory_{}", + uuid::Uuid::new_v4() + )); + fs::create_dir_all(&dir).await.unwrap(); + + sync_directory(&dir).await.unwrap(); + + let _ = fs::remove_dir_all(&dir).await; + } + + #[tokio::test] + async fn config_save_and_load_tmpdir() { let dir = std::env::temp_dir().join("zeroclaw_test_config"); - let _ = fs::remove_dir_all(&dir); - fs::create_dir_all(&dir).unwrap(); + let _ = fs::remove_dir_all(&dir).await; + fs::create_dir_all(&dir).await.unwrap(); let config_path = dir.join("config.toml"); let config = Config { @@ -3378,7 +3992,9 @@ tool_dispatcher = "xml" runtime: RuntimeConfig::default(), reliability: ReliabilityConfig::default(), scheduler: SchedulerConfig::default(), + skills: SkillsConfig::default(), model_routes: Vec::new(), + embedding_routes: Vec::new(), query_classification: QueryClassificationConfig::default(), heartbeat: HeartbeatConfig::default(), cron: CronConfig::default(), @@ -3391,6 +4007,7 @@ tool_dispatcher = "xml" secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: HttpRequestConfig::default(), + multimodal: MultimodalConfig::default(), web_search: WebSearchConfig::default(), proxy: ProxyConfig::default(), agent: AgentConfig::default(), @@ -3401,10 +4018,10 @@ tool_dispatcher = "xml" hardware: HardwareConfig::default(), }; - config.save().unwrap(); + config.save().await.unwrap(); 
assert!(config_path.exists()); - let contents = fs::read_to_string(&config_path).unwrap(); + let contents = tokio::fs::read_to_string(&config_path).await.unwrap(); let loaded: Config = toml::from_str(&contents).unwrap(); assert!(loaded .api_key @@ -3416,16 +4033,16 @@ tool_dispatcher = "xml" assert_eq!(loaded.default_model.as_deref(), Some("test-model")); assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON); - let _ = fs::remove_dir_all(&dir); + let _ = fs::remove_dir_all(&dir).await; } - #[test] - fn config_save_encrypts_nested_credentials() { + #[tokio::test] + async fn config_save_encrypts_nested_credentials() { let dir = std::env::temp_dir().join(format!( "zeroclaw_test_nested_credentials_{}", uuid::Uuid::new_v4() )); - fs::create_dir_all(&dir).unwrap(); + fs::create_dir_all(&dir).await.unwrap(); let mut config = Config::default(); config.workspace_dir = dir.join("workspace"); @@ -3448,9 +4065,11 @@ tool_dispatcher = "xml" }, ); - config.save().unwrap(); + config.save().await.unwrap(); - let contents = fs::read_to_string(config.config_path.clone()).unwrap(); + let contents = tokio::fs::read_to_string(config.config_path.clone()) + .await + .unwrap(); let stored: Config = toml::from_str(&contents).unwrap(); let store = crate::security::SecretStore::new(&dir, true); @@ -3497,49 +4116,49 @@ tool_dispatcher = "xml" "postgres://user:pw@host/db" ); - let _ = fs::remove_dir_all(&dir); + let _ = fs::remove_dir_all(&dir).await; } - #[test] - fn config_save_atomic_cleanup() { + #[tokio::test] + async fn config_save_atomic_cleanup() { let dir = std::env::temp_dir().join(format!("zeroclaw_test_config_{}", uuid::Uuid::new_v4())); - fs::create_dir_all(&dir).unwrap(); + fs::create_dir_all(&dir).await.unwrap(); let config_path = dir.join("config.toml"); let mut config = Config::default(); config.workspace_dir = dir.join("workspace"); config.config_path = config_path.clone(); config.default_model = Some("model-a".into()); - - config.save().unwrap(); + 
config.save().await.unwrap(); assert!(config_path.exists()); config.default_model = Some("model-b".into()); - config.save().unwrap(); + config.save().await.unwrap(); - let contents = fs::read_to_string(&config_path).unwrap(); + let contents = tokio::fs::read_to_string(&config_path).await.unwrap(); assert!(contents.contains("model-b")); - let names: Vec = fs::read_dir(&dir) - .unwrap() + let names: Vec = ReadDirStream::new(fs::read_dir(&dir).await.unwrap()) .map(|entry| entry.unwrap().file_name().to_string_lossy().to_string()) - .collect(); + .collect() + .await; assert!(!names.iter().any(|name| name.contains(".tmp-"))); assert!(!names.iter().any(|name| name.ends_with(".bak"))); - let _ = fs::remove_dir_all(&dir); + let _ = fs::remove_dir_all(&dir).await; } // ── Telegram / Discord config ──────────────────────────── #[test] - fn telegram_config_serde() { + async fn telegram_config_serde() { let tc = TelegramConfig { bot_token: "123:XYZ".into(), allowed_users: vec!["alice".into(), "bob".into()], stream_mode: StreamMode::Partial, draft_update_interval_ms: 500, + interrupt_on_new_message: true, mention_only: false, }; let json = serde_json::to_string(&tc).unwrap(); @@ -3548,18 +4167,20 @@ tool_dispatcher = "xml" assert_eq!(parsed.allowed_users.len(), 2); assert_eq!(parsed.stream_mode, StreamMode::Partial); assert_eq!(parsed.draft_update_interval_ms, 500); + assert!(parsed.interrupt_on_new_message); } #[test] - fn telegram_config_defaults_stream_off() { + async fn telegram_config_defaults_stream_off() { let json = r#"{"bot_token":"tok","allowed_users":[]}"#; let parsed: TelegramConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.stream_mode, StreamMode::Off); assert_eq!(parsed.draft_update_interval_ms, 1000); + assert!(!parsed.interrupt_on_new_message); } #[test] - fn discord_config_serde() { + async fn discord_config_serde() { let dc = DiscordConfig { bot_token: "discord-token".into(), guild_id: Some("12345".into()), @@ -3574,7 +4195,7 @@ tool_dispatcher = 
"xml" } #[test] - fn discord_config_optional_guild() { + async fn discord_config_optional_guild() { let dc = DiscordConfig { bot_token: "tok".into(), guild_id: None, @@ -3590,7 +4211,7 @@ tool_dispatcher = "xml" // ── iMessage / Matrix config ──────────────────────────── #[test] - fn imessage_config_serde() { + async fn imessage_config_serde() { let ic = IMessageConfig { allowed_contacts: vec!["+1234567890".into(), "user@icloud.com".into()], }; @@ -3601,7 +4222,7 @@ tool_dispatcher = "xml" } #[test] - fn imessage_config_empty_contacts() { + async fn imessage_config_empty_contacts() { let ic = IMessageConfig { allowed_contacts: vec![], }; @@ -3611,7 +4232,7 @@ tool_dispatcher = "xml" } #[test] - fn imessage_config_wildcard() { + async fn imessage_config_wildcard() { let ic = IMessageConfig { allowed_contacts: vec!["*".into()], }; @@ -3621,7 +4242,7 @@ tool_dispatcher = "xml" } #[test] - fn matrix_config_serde() { + async fn matrix_config_serde() { let mc = MatrixConfig { homeserver: "https://matrix.org".into(), access_token: "syt_token_abc".into(), @@ -3641,7 +4262,7 @@ tool_dispatcher = "xml" } #[test] - fn matrix_config_toml_roundtrip() { + async fn matrix_config_toml_roundtrip() { let mc = MatrixConfig { homeserver: "https://synapse.local:8448".into(), access_token: "tok".into(), @@ -3657,7 +4278,7 @@ tool_dispatcher = "xml" } #[test] - fn matrix_config_backward_compatible_without_session_hints() { + async fn matrix_config_backward_compatible_without_session_hints() { let toml = r#" homeserver = "https://matrix.org" access_token = "tok" @@ -3672,7 +4293,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn signal_config_serde() { + async fn signal_config_serde() { let sc = SignalConfig { http_url: "http://127.0.0.1:8686".into(), account: "+1234567890".into(), @@ -3692,7 +4313,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn signal_config_toml_roundtrip() { + async fn signal_config_toml_roundtrip() { let sc = SignalConfig { http_url: 
"http://localhost:8080".into(), account: "+9876543210".into(), @@ -3710,7 +4331,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn signal_config_defaults() { + async fn signal_config_defaults() { let json = r#"{"http_url":"http://127.0.0.1:8686","account":"+1234567890"}"#; let parsed: SignalConfig = serde_json::from_str(json).unwrap(); assert!(parsed.group_id.is_none()); @@ -3720,7 +4341,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn channels_config_with_imessage_and_matrix() { + async fn channels_config_with_imessage_and_matrix() { let c = ChannelsConfig { cli: true, telegram: None, @@ -3741,11 +4362,13 @@ allowed_users = ["@ops:matrix.org"] }), signal: None, whatsapp: None, + linq: None, email: None, irc: None, lark: None, dingtalk: None, qq: None, + message_timeout_secs: 300, }; let toml_str = toml::to_string_pretty(&c).unwrap(); let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap(); @@ -3756,7 +4379,7 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn channels_config_default_has_no_imessage_matrix() { + async fn channels_config_default_has_no_imessage_matrix() { let c = ChannelsConfig::default(); assert!(c.imessage.is_none()); assert!(c.matrix.is_none()); @@ -3765,7 +4388,7 @@ allowed_users = ["@ops:matrix.org"] // ── Edge cases: serde(default) for allowed_users ───────── #[test] - fn discord_config_deserializes_without_allowed_users() { + async fn discord_config_deserializes_without_allowed_users() { // Old configs won't have allowed_users — serde(default) should fill vec![] let json = r#"{"bot_token":"tok","guild_id":"123"}"#; let parsed: DiscordConfig = serde_json::from_str(json).unwrap(); @@ -3773,28 +4396,28 @@ allowed_users = ["@ops:matrix.org"] } #[test] - fn discord_config_deserializes_with_allowed_users() { + async fn discord_config_deserializes_with_allowed_users() { let json = r#"{"bot_token":"tok","guild_id":"123","allowed_users":["111","222"]}"#; let parsed: DiscordConfig = serde_json::from_str(json).unwrap(); 
assert_eq!(parsed.allowed_users, vec!["111", "222"]); } #[test] - fn slack_config_deserializes_without_allowed_users() { + async fn slack_config_deserializes_without_allowed_users() { let json = r#"{"bot_token":"xoxb-tok"}"#; let parsed: SlackConfig = serde_json::from_str(json).unwrap(); assert!(parsed.allowed_users.is_empty()); } #[test] - fn slack_config_deserializes_with_allowed_users() { + async fn slack_config_deserializes_with_allowed_users() { let json = r#"{"bot_token":"xoxb-tok","allowed_users":["U111"]}"#; let parsed: SlackConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.allowed_users, vec!["U111"]); } #[test] - fn discord_config_toml_backward_compat() { + async fn discord_config_toml_backward_compat() { let toml_str = r#" bot_token = "tok" guild_id = "123" @@ -3805,7 +4428,7 @@ guild_id = "123" } #[test] - fn slack_config_toml_backward_compat() { + async fn slack_config_toml_backward_compat() { let toml_str = r#" bot_token = "xoxb-tok" channel_id = "C123" @@ -3816,14 +4439,14 @@ channel_id = "C123" } #[test] - fn webhook_config_with_secret() { + async fn webhook_config_with_secret() { let json = r#"{"port":8080,"secret":"my-secret-key"}"#; let parsed: WebhookConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.secret.as_deref(), Some("my-secret-key")); } #[test] - fn webhook_config_without_secret() { + async fn webhook_config_without_secret() { let json = r#"{"port":8080}"#; let parsed: WebhookConfig = serde_json::from_str(json).unwrap(); assert!(parsed.secret.is_none()); @@ -3833,51 +4456,60 @@ channel_id = "C123" // ── WhatsApp config ────────────────────────────────────── #[test] - fn whatsapp_config_serde() { + async fn whatsapp_config_serde() { let wc = WhatsAppConfig { - access_token: "EAABx...".into(), - phone_number_id: "123456789".into(), - verify_token: "my-verify-token".into(), + access_token: Some("EAABx...".into()), + phone_number_id: Some("123456789".into()), + verify_token: Some("my-verify-token".into()), 
app_secret: None, + session_path: None, + pair_phone: None, + pair_code: None, allowed_numbers: vec!["+1234567890".into(), "+9876543210".into()], }; let json = serde_json::to_string(&wc).unwrap(); let parsed: WhatsAppConfig = serde_json::from_str(&json).unwrap(); - assert_eq!(parsed.access_token, "EAABx..."); - assert_eq!(parsed.phone_number_id, "123456789"); - assert_eq!(parsed.verify_token, "my-verify-token"); + assert_eq!(parsed.access_token, Some("EAABx...".into())); + assert_eq!(parsed.phone_number_id, Some("123456789".into())); + assert_eq!(parsed.verify_token, Some("my-verify-token".into())); assert_eq!(parsed.allowed_numbers.len(), 2); } #[test] - fn whatsapp_config_toml_roundtrip() { + async fn whatsapp_config_toml_roundtrip() { let wc = WhatsAppConfig { - access_token: "tok".into(), - phone_number_id: "12345".into(), - verify_token: "verify".into(), + access_token: Some("tok".into()), + phone_number_id: Some("12345".into()), + verify_token: Some("verify".into()), app_secret: Some("secret123".into()), + session_path: None, + pair_phone: None, + pair_code: None, allowed_numbers: vec!["+1".into()], }; let toml_str = toml::to_string(&wc).unwrap(); let parsed: WhatsAppConfig = toml::from_str(&toml_str).unwrap(); - assert_eq!(parsed.phone_number_id, "12345"); + assert_eq!(parsed.phone_number_id, Some("12345".into())); assert_eq!(parsed.allowed_numbers, vec!["+1"]); } #[test] - fn whatsapp_config_deserializes_without_allowed_numbers() { + async fn whatsapp_config_deserializes_without_allowed_numbers() { let json = r#"{"access_token":"tok","phone_number_id":"123","verify_token":"ver"}"#; let parsed: WhatsAppConfig = serde_json::from_str(json).unwrap(); assert!(parsed.allowed_numbers.is_empty()); } #[test] - fn whatsapp_config_wildcard_allowed() { + async fn whatsapp_config_wildcard_allowed() { let wc = WhatsAppConfig { - access_token: "tok".into(), - phone_number_id: "123".into(), - verify_token: "ver".into(), + access_token: Some("tok".into()), + 
phone_number_id: Some("123".into()), + verify_token: Some("ver".into()), app_secret: None, + session_path: None, + pair_phone: None, + pair_code: None, allowed_numbers: vec!["*".into()], }; let toml_str = toml::to_string(&wc).unwrap(); @@ -3886,7 +4518,39 @@ channel_id = "C123" } #[test] - fn channels_config_with_whatsapp() { + async fn whatsapp_config_backend_type_cloud_precedence_when_ambiguous() { + let wc = WhatsAppConfig { + access_token: Some("tok".into()), + phone_number_id: Some("123".into()), + verify_token: Some("ver".into()), + app_secret: None, + session_path: Some("~/.zeroclaw/state/whatsapp-web/session.db".into()), + pair_phone: None, + pair_code: None, + allowed_numbers: vec!["+1".into()], + }; + assert!(wc.is_ambiguous_config()); + assert_eq!(wc.backend_type(), "cloud"); + } + + #[test] + async fn whatsapp_config_backend_type_web() { + let wc = WhatsAppConfig { + access_token: None, + phone_number_id: None, + verify_token: None, + app_secret: None, + session_path: Some("~/.zeroclaw/state/whatsapp-web/session.db".into()), + pair_phone: None, + pair_code: None, + allowed_numbers: vec![], + }; + assert!(!wc.is_ambiguous_config()); + assert_eq!(wc.backend_type(), "web"); + } + + #[test] + async fn channels_config_with_whatsapp() { let c = ChannelsConfig { cli: true, telegram: None, @@ -3898,28 +4562,33 @@ channel_id = "C123" matrix: None, signal: None, whatsapp: Some(WhatsAppConfig { - access_token: "tok".into(), - phone_number_id: "123".into(), - verify_token: "ver".into(), + access_token: Some("tok".into()), + phone_number_id: Some("123".into()), + verify_token: Some("ver".into()), app_secret: None, + session_path: None, + pair_phone: None, + pair_code: None, allowed_numbers: vec!["+1".into()], }), + linq: None, email: None, irc: None, lark: None, dingtalk: None, qq: None, + message_timeout_secs: 300, }; let toml_str = toml::to_string_pretty(&c).unwrap(); let parsed: ChannelsConfig = toml::from_str(&toml_str).unwrap(); 
assert!(parsed.whatsapp.is_some()); let wa = parsed.whatsapp.unwrap(); - assert_eq!(wa.phone_number_id, "123"); + assert_eq!(wa.phone_number_id, Some("123".into())); assert_eq!(wa.allowed_numbers, vec!["+1"]); } #[test] - fn channels_config_default_has_no_whatsapp() { + async fn channels_config_default_has_no_whatsapp() { let c = ChannelsConfig::default(); assert!(c.whatsapp.is_none()); } @@ -3929,13 +4598,13 @@ channel_id = "C123" // ══════════════════════════════════════════════════════════ #[test] - fn checklist_gateway_default_requires_pairing() { + async fn checklist_gateway_default_requires_pairing() { let g = GatewayConfig::default(); assert!(g.require_pairing, "Pairing must be required by default"); } #[test] - fn checklist_gateway_default_blocks_public_bind() { + async fn checklist_gateway_default_blocks_public_bind() { let g = GatewayConfig::default(); assert!( !g.allow_public_bind, @@ -3944,7 +4613,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_default_no_tokens() { + async fn checklist_gateway_default_no_tokens() { let g = GatewayConfig::default(); assert!( g.paired_tokens.is_empty(), @@ -3959,7 +4628,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_cli_default_host_is_localhost() { + async fn checklist_gateway_cli_default_host_is_localhost() { // The CLI default for --host is 127.0.0.1 (checked in main.rs) // Here we verify the config default matches let c = Config::default(); @@ -3974,7 +4643,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_serde_roundtrip() { + async fn checklist_gateway_serde_roundtrip() { let g = GatewayConfig { port: 3000, host: "127.0.0.1".into(), @@ -4002,7 +4671,7 @@ channel_id = "C123" } #[test] - fn checklist_gateway_backward_compat_no_gateway_section() { + async fn checklist_gateway_backward_compat_no_gateway_section() { // Old configs without [gateway] should get secure defaults let minimal = r#" workspace_dir = "/tmp/ws" @@ -4021,7 +4690,7 @@ default_temperature = 0.7 } #[test] - fn 
checklist_autonomy_default_is_workspace_scoped() { + async fn checklist_autonomy_default_is_workspace_scoped() { let a = AutonomyConfig::default(); assert!(a.workspace_only, "Default autonomy must be workspace_only"); assert!( @@ -4043,7 +4712,7 @@ default_temperature = 0.7 // ══════════════════════════════════════════════════════════ #[test] - fn composio_config_default_disabled() { + async fn composio_config_default_disabled() { let c = ComposioConfig::default(); assert!(!c.enabled, "Composio must be disabled by default"); assert!(c.api_key.is_none(), "No API key by default"); @@ -4051,7 +4720,7 @@ default_temperature = 0.7 } #[test] - fn composio_config_serde_roundtrip() { + async fn composio_config_serde_roundtrip() { let c = ComposioConfig { enabled: true, api_key: Some("comp-key-123".into()), @@ -4065,7 +4734,7 @@ default_temperature = 0.7 } #[test] - fn composio_config_backward_compat_missing_section() { + async fn composio_config_backward_compat_missing_section() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -4080,7 +4749,7 @@ default_temperature = 0.7 } #[test] - fn composio_config_partial_toml() { + async fn composio_config_partial_toml() { let toml_str = r" enabled = true "; @@ -4090,18 +4759,29 @@ enabled = true assert_eq!(parsed.entity_id, "default"); } + #[test] + async fn composio_config_enable_alias_supported() { + let toml_str = r" +enable = true +"; + let parsed: ComposioConfig = toml::from_str(toml_str).unwrap(); + assert!(parsed.enabled); + assert!(parsed.api_key.is_none()); + assert_eq!(parsed.entity_id, "default"); + } + // ══════════════════════════════════════════════════════════ // SECRETS CONFIG TESTS // ══════════════════════════════════════════════════════════ #[test] - fn secrets_config_default_encrypts() { + async fn secrets_config_default_encrypts() { let s = SecretsConfig::default(); assert!(s.encrypt, "Encryption must be enabled by default"); } #[test] - fn secrets_config_serde_roundtrip() { + 
async fn secrets_config_serde_roundtrip() { let s = SecretsConfig { encrypt: false }; let toml_str = toml::to_string(&s).unwrap(); let parsed: SecretsConfig = toml::from_str(&toml_str).unwrap(); @@ -4109,7 +4789,7 @@ enabled = true } #[test] - fn secrets_config_backward_compat_missing_section() { + async fn secrets_config_backward_compat_missing_section() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -4123,7 +4803,7 @@ default_temperature = 0.7 } #[test] - fn config_default_has_composio_and_secrets() { + async fn config_default_has_composio_and_secrets() { let c = Config::default(); assert!(!c.composio.enabled); assert!(c.composio.api_key.is_none()); @@ -4133,7 +4813,7 @@ default_temperature = 0.7 } #[test] - fn browser_config_default_disabled() { + async fn browser_config_default_disabled() { let b = BrowserConfig::default(); assert!(!b.enabled); assert!(b.allowed_domains.is_empty()); @@ -4150,7 +4830,7 @@ default_temperature = 0.7 } #[test] - fn browser_config_serde_roundtrip() { + async fn browser_config_serde_roundtrip() { let b = BrowserConfig { enabled: true, allowed_domains: vec!["example.com".into(), "docs.example.com".into()], @@ -4194,7 +4874,7 @@ default_temperature = 0.7 } #[test] - fn browser_config_backward_compat_missing_section() { + async fn browser_config_backward_compat_missing_section() { let minimal = r#" workspace_dir = "/tmp/ws" config_path = "/tmp/config.toml" @@ -4207,11 +4887,9 @@ default_temperature = 0.7 // ── Environment variable overrides (Docker support) ───────── - fn env_override_test_guard() -> std::sync::MutexGuard<'static, ()> { - static ENV_OVERRIDE_TEST_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(()); - ENV_OVERRIDE_TEST_LOCK - .lock() - .expect("env override test lock poisoned") + async fn env_override_lock() -> MutexGuard<'static, ()> { + static ENV_OVERRIDE_TEST_LOCK: Mutex<()> = Mutex::const_new(()); + ENV_OVERRIDE_TEST_LOCK.lock().await } fn clear_proxy_env_test_vars() { @@ 
-4237,8 +4915,8 @@ default_temperature = 0.7 } #[test] - fn env_override_api_key() { - let _env_guard = env_override_test_guard(); + async fn env_override_api_key() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); assert!(config.api_key.is_none()); @@ -4250,8 +4928,8 @@ default_temperature = 0.7 } #[test] - fn env_override_api_key_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_api_key_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_API_KEY"); @@ -4263,8 +4941,8 @@ default_temperature = 0.7 } #[test] - fn env_override_provider() { - let _env_guard = env_override_test_guard(); + async fn env_override_provider() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_PROVIDER", "anthropic"); @@ -4275,8 +4953,42 @@ default_temperature = 0.7 } #[test] - fn env_override_provider_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_open_skills_enabled_and_dir() { + let _env_guard = env_override_lock().await; + let mut config = Config::default(); + assert!(!config.skills.open_skills_enabled); + assert!(config.skills.open_skills_dir.is_none()); + + std::env::set_var("ZEROCLAW_OPEN_SKILLS_ENABLED", "true"); + std::env::set_var("ZEROCLAW_OPEN_SKILLS_DIR", "/tmp/open-skills"); + config.apply_env_overrides(); + + assert!(config.skills.open_skills_enabled); + assert_eq!( + config.skills.open_skills_dir.as_deref(), + Some("/tmp/open-skills") + ); + + std::env::remove_var("ZEROCLAW_OPEN_SKILLS_ENABLED"); + std::env::remove_var("ZEROCLAW_OPEN_SKILLS_DIR"); + } + + #[test] + async fn env_override_open_skills_enabled_invalid_value_keeps_existing_value() { + let _env_guard = env_override_lock().await; + let mut config = Config::default(); + config.skills.open_skills_enabled = true; + + std::env::set_var("ZEROCLAW_OPEN_SKILLS_ENABLED", "maybe"); + 
config.apply_env_overrides(); + + assert!(config.skills.open_skills_enabled); + std::env::remove_var("ZEROCLAW_OPEN_SKILLS_ENABLED"); + } + + #[test] + async fn env_override_provider_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_PROVIDER"); @@ -4288,8 +5000,44 @@ default_temperature = 0.7 } #[test] - fn env_override_glm_api_key_for_regional_aliases() { - let _env_guard = env_override_test_guard(); + async fn env_override_provider_fallback_does_not_replace_non_default_provider() { + let _env_guard = env_override_lock().await; + let mut config = Config { + default_provider: Some("custom:https://proxy.example.com/v1".to_string()), + ..Config::default() + }; + + std::env::remove_var("ZEROCLAW_PROVIDER"); + std::env::set_var("PROVIDER", "openrouter"); + config.apply_env_overrides(); + assert_eq!( + config.default_provider.as_deref(), + Some("custom:https://proxy.example.com/v1") + ); + + std::env::remove_var("PROVIDER"); + } + + #[test] + async fn env_override_zero_claw_provider_overrides_non_default_provider() { + let _env_guard = env_override_lock().await; + let mut config = Config { + default_provider: Some("custom:https://proxy.example.com/v1".to_string()), + ..Config::default() + }; + + std::env::set_var("ZEROCLAW_PROVIDER", "openrouter"); + std::env::set_var("PROVIDER", "anthropic"); + config.apply_env_overrides(); + assert_eq!(config.default_provider.as_deref(), Some("openrouter")); + + std::env::remove_var("ZEROCLAW_PROVIDER"); + std::env::remove_var("PROVIDER"); + } + + #[test] + async fn env_override_glm_api_key_for_regional_aliases() { + let _env_guard = env_override_lock().await; let mut config = Config { default_provider: Some("glm-cn".to_string()), ..Config::default() @@ -4303,8 +5051,8 @@ default_temperature = 0.7 } #[test] - fn env_override_zai_api_key_for_regional_aliases() { - let _env_guard = env_override_test_guard(); + async fn 
env_override_zai_api_key_for_regional_aliases() { + let _env_guard = env_override_lock().await; let mut config = Config { default_provider: Some("zai-cn".to_string()), ..Config::default() @@ -4318,8 +5066,8 @@ default_temperature = 0.7 } #[test] - fn env_override_model() { - let _env_guard = env_override_test_guard(); + async fn env_override_model() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_MODEL", "gpt-4o"); @@ -4330,8 +5078,8 @@ default_temperature = 0.7 } #[test] - fn env_override_model_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_model_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_MODEL"); @@ -4346,8 +5094,8 @@ default_temperature = 0.7 } #[test] - fn env_override_workspace() { - let _env_guard = env_override_test_guard(); + async fn env_override_workspace() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_WORKSPACE", "/custom/workspace"); @@ -4358,8 +5106,77 @@ default_temperature = 0.7 } #[test] - fn load_or_init_workspace_override_uses_workspace_root_for_config() { - let _env_guard = env_override_test_guard(); + async fn resolve_runtime_config_dirs_uses_env_workspace_first() { + let _env_guard = env_override_lock().await; + let default_config_dir = std::env::temp_dir().join(uuid::Uuid::new_v4().to_string()); + let default_workspace_dir = default_config_dir.join("workspace"); + let workspace_dir = default_config_dir.join("profile-a"); + + std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); + let (config_dir, resolved_workspace_dir, source) = + resolve_runtime_config_dirs(&default_config_dir, &default_workspace_dir) + .await + .unwrap(); + + assert_eq!(source, ConfigResolutionSource::EnvWorkspace); + assert_eq!(config_dir, workspace_dir); + assert_eq!(resolved_workspace_dir, workspace_dir.join("workspace")); 
+ + std::env::remove_var("ZEROCLAW_WORKSPACE"); + let _ = fs::remove_dir_all(default_config_dir).await; + } + + #[test] + async fn resolve_runtime_config_dirs_uses_active_workspace_marker() { + let _env_guard = env_override_lock().await; + let default_config_dir = std::env::temp_dir().join(uuid::Uuid::new_v4().to_string()); + let default_workspace_dir = default_config_dir.join("workspace"); + let marker_config_dir = default_config_dir.join("profiles").join("alpha"); + let state_path = default_config_dir.join(ACTIVE_WORKSPACE_STATE_FILE); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + fs::create_dir_all(&default_config_dir).await.unwrap(); + let state = ActiveWorkspaceState { + config_dir: marker_config_dir.to_string_lossy().into_owned(), + }; + fs::write(&state_path, toml::to_string(&state).unwrap()) + .await + .unwrap(); + + let (config_dir, resolved_workspace_dir, source) = + resolve_runtime_config_dirs(&default_config_dir, &default_workspace_dir) + .await + .unwrap(); + + assert_eq!(source, ConfigResolutionSource::ActiveWorkspaceMarker); + assert_eq!(config_dir, marker_config_dir); + assert_eq!(resolved_workspace_dir, marker_config_dir.join("workspace")); + + let _ = fs::remove_dir_all(default_config_dir).await; + } + + #[test] + async fn resolve_runtime_config_dirs_falls_back_to_default_layout() { + let _env_guard = env_override_lock().await; + let default_config_dir = std::env::temp_dir().join(uuid::Uuid::new_v4().to_string()); + let default_workspace_dir = default_config_dir.join("workspace"); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + let (config_dir, resolved_workspace_dir, source) = + resolve_runtime_config_dirs(&default_config_dir, &default_workspace_dir) + .await + .unwrap(); + + assert_eq!(source, ConfigResolutionSource::DefaultConfigDir); + assert_eq!(config_dir, default_config_dir); + assert_eq!(resolved_workspace_dir, default_workspace_dir); + + let _ = fs::remove_dir_all(default_config_dir).await; + } + + #[test] + async fn 
load_or_init_workspace_override_uses_workspace_root_for_config() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let workspace_dir = temp_home.join("profile-a"); @@ -4368,7 +5185,7 @@ default_temperature = 0.7 std::env::set_var("HOME", &temp_home); std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, workspace_dir.join("workspace")); assert_eq!(config.config_path, workspace_dir.join("config.toml")); @@ -4380,12 +5197,12 @@ default_temperature = 0.7 } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_workspace_suffix_uses_legacy_config_layout() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_workspace_suffix_uses_legacy_config_layout() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let workspace_dir = temp_home.join("workspace"); @@ -4395,7 +5212,7 @@ default_temperature = 0.7 std::env::set_var("HOME", &temp_home); std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, workspace_dir); assert_eq!(config.config_path, legacy_config_path); @@ -4407,32 +5224,33 @@ default_temperature = 0.7 } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_workspace_override_keeps_existing_legacy_config() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_workspace_override_keeps_existing_legacy_config() { + let _env_guard = env_override_lock().await; let temp_home = 
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let workspace_dir = temp_home.join("custom-workspace"); let legacy_config_dir = temp_home.join(".zeroclaw"); let legacy_config_path = legacy_config_dir.join("config.toml"); - fs::create_dir_all(&legacy_config_dir).unwrap(); + fs::create_dir_all(&legacy_config_dir).await.unwrap(); fs::write( &legacy_config_path, r#"default_temperature = 0.7 default_model = "legacy-model" "#, ) + .await .unwrap(); let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, workspace_dir); assert_eq!(config.config_path, legacy_config_path); @@ -4444,30 +5262,33 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_uses_persisted_active_workspace_marker() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_uses_persisted_active_workspace_marker() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let custom_config_dir = temp_home.join("profiles").join("agent-alpha"); - fs::create_dir_all(&custom_config_dir).unwrap(); + fs::create_dir_all(&custom_config_dir).await.unwrap(); fs::write( custom_config_dir.join("config.toml"), "default_temperature = 0.7\ndefault_model = \"persisted-profile\"\n", ) + .await .unwrap(); let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); std::env::remove_var("ZEROCLAW_WORKSPACE"); - persist_active_workspace_config_dir(&custom_config_dir).unwrap(); + persist_active_workspace_config_dir(&custom_config_dir) + .await + .unwrap(); - let config = Config::load_or_init().unwrap(); + let 
config = Config::load_or_init().await.unwrap(); assert_eq!(config.config_path, custom_config_dir.join("config.toml")); assert_eq!(config.workspace_dir, custom_config_dir.join("workspace")); @@ -4478,30 +5299,33 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn load_or_init_env_workspace_override_takes_priority_over_marker() { - let _env_guard = env_override_test_guard(); + async fn load_or_init_env_workspace_override_takes_priority_over_marker() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let marker_config_dir = temp_home.join("profiles").join("persisted-profile"); let env_workspace_dir = temp_home.join("env-workspace"); - fs::create_dir_all(&marker_config_dir).unwrap(); + fs::create_dir_all(&marker_config_dir).await.unwrap(); fs::write( marker_config_dir.join("config.toml"), "default_temperature = 0.7\ndefault_model = \"marker-model\"\n", ) + .await .unwrap(); let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); - persist_active_workspace_config_dir(&marker_config_dir).unwrap(); + persist_active_workspace_config_dir(&marker_config_dir) + .await + .unwrap(); std::env::set_var("ZEROCLAW_WORKSPACE", &env_workspace_dir); - let config = Config::load_or_init().unwrap(); + let config = Config::load_or_init().await.unwrap(); assert_eq!(config.workspace_dir, env_workspace_dir.join("workspace")); assert_eq!(config.config_path, env_workspace_dir.join("config.toml")); @@ -4512,12 +5336,12 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn persist_active_workspace_marker_is_cleared_for_default_config_dir() { - let _env_guard = env_override_test_guard(); + async fn 
persist_active_workspace_marker_is_cleared_for_default_config_dir() { + let _env_guard = env_override_lock().await; let temp_home = std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); let default_config_dir = temp_home.join(".zeroclaw"); @@ -4527,10 +5351,14 @@ default_model = "legacy-model" let original_home = std::env::var("HOME").ok(); std::env::set_var("HOME", &temp_home); - persist_active_workspace_config_dir(&custom_config_dir).unwrap(); + persist_active_workspace_config_dir(&custom_config_dir) + .await + .unwrap(); assert!(marker_path.exists()); - persist_active_workspace_config_dir(&default_config_dir).unwrap(); + persist_active_workspace_config_dir(&default_config_dir) + .await + .unwrap(); assert!(!marker_path.exists()); if let Some(home) = original_home { @@ -4538,12 +5366,12 @@ default_model = "legacy-model" } else { std::env::remove_var("HOME"); } - let _ = fs::remove_dir_all(temp_home); + let _ = fs::remove_dir_all(temp_home).await; } #[test] - fn env_override_empty_values_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_empty_values_ignored() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); let original_provider = config.default_provider.clone(); @@ -4555,8 +5383,8 @@ default_model = "legacy-model" } #[test] - fn env_override_gateway_port() { - let _env_guard = env_override_test_guard(); + async fn env_override_gateway_port() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); assert_eq!(config.gateway.port, 3000); @@ -4568,8 +5396,8 @@ default_model = "legacy-model" } #[test] - fn env_override_port_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_port_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_GATEWAY_PORT"); @@ -4581,8 +5409,8 @@ default_model = "legacy-model" } #[test] - fn env_override_gateway_host() { 
- let _env_guard = env_override_test_guard(); + async fn env_override_gateway_host() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); assert_eq!(config.gateway.host, "127.0.0.1"); @@ -4594,8 +5422,8 @@ default_model = "legacy-model" } #[test] - fn env_override_host_fallback() { - let _env_guard = env_override_test_guard(); + async fn env_override_host_fallback() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::remove_var("ZEROCLAW_GATEWAY_HOST"); @@ -4607,8 +5435,8 @@ default_model = "legacy-model" } #[test] - fn env_override_temperature() { - let _env_guard = env_override_test_guard(); + async fn env_override_temperature() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_TEMPERATURE", "0.5"); @@ -4619,8 +5447,8 @@ default_model = "legacy-model" } #[test] - fn env_override_temperature_out_of_range_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_temperature_out_of_range_ignored() { + let _env_guard = env_override_lock().await; // Clean up any leftover env vars from other tests std::env::remove_var("ZEROCLAW_TEMPERATURE"); @@ -4639,8 +5467,38 @@ default_model = "legacy-model" } #[test] - fn env_override_invalid_port_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_reasoning_enabled() { + let _env_guard = env_override_lock().await; + let mut config = Config::default(); + assert_eq!(config.runtime.reasoning_enabled, None); + + std::env::set_var("ZEROCLAW_REASONING_ENABLED", "false"); + config.apply_env_overrides(); + assert_eq!(config.runtime.reasoning_enabled, Some(false)); + + std::env::set_var("ZEROCLAW_REASONING_ENABLED", "true"); + config.apply_env_overrides(); + assert_eq!(config.runtime.reasoning_enabled, Some(true)); + + std::env::remove_var("ZEROCLAW_REASONING_ENABLED"); + } + + #[test] + async fn env_override_reasoning_invalid_value_ignored() { + let 
_env_guard = env_override_lock().await; + let mut config = Config::default(); + config.runtime.reasoning_enabled = Some(false); + + std::env::set_var("ZEROCLAW_REASONING_ENABLED", "maybe"); + config.apply_env_overrides(); + assert_eq!(config.runtime.reasoning_enabled, Some(false)); + + std::env::remove_var("ZEROCLAW_REASONING_ENABLED"); + } + + #[test] + async fn env_override_invalid_port_ignored() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); let original_port = config.gateway.port; @@ -4652,8 +5510,8 @@ default_model = "legacy-model" } #[test] - fn env_override_web_search_config() { - let _env_guard = env_override_test_guard(); + async fn env_override_web_search_config() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("WEB_SEARCH_ENABLED", "false"); @@ -4681,8 +5539,8 @@ default_model = "legacy-model" } #[test] - fn env_override_web_search_invalid_values_ignored() { - let _env_guard = env_override_test_guard(); + async fn env_override_web_search_invalid_values_ignored() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); let original_max_results = config.web_search.max_results; let original_timeout = config.web_search.timeout_secs; @@ -4700,8 +5558,8 @@ default_model = "legacy-model" } #[test] - fn env_override_storage_provider_config() { - let _env_guard = env_override_test_guard(); + async fn env_override_storage_provider_config() { + let _env_guard = env_override_lock().await; let mut config = Config::default(); std::env::set_var("ZEROCLAW_STORAGE_PROVIDER", "postgres"); @@ -4726,7 +5584,7 @@ default_model = "legacy-model" } #[test] - fn proxy_config_scope_services_requires_entries_when_enabled() { + async fn proxy_config_scope_services_requires_entries_when_enabled() { let proxy = ProxyConfig { enabled: true, http_proxy: Some("http://127.0.0.1:7890".into()), @@ -4742,8 +5600,8 @@ default_model = "legacy-model" } #[test] - fn 
env_override_proxy_scope_services() { - let _env_guard = env_override_test_guard(); + async fn env_override_proxy_scope_services() { + let _env_guard = env_override_lock().await; clear_proxy_env_test_vars(); let mut config = Config::default(); @@ -4771,8 +5629,8 @@ default_model = "legacy-model" } #[test] - fn env_override_proxy_scope_environment_applies_process_env() { - let _env_guard = env_override_test_guard(); + async fn env_override_proxy_scope_environment_applies_process_env() { + let _env_guard = env_override_lock().await; clear_proxy_env_test_vars(); let mut config = Config::default(); @@ -4808,7 +5666,7 @@ default_model = "legacy-model" } #[test] - fn runtime_proxy_client_cache_reuses_default_profile_key() { + async fn runtime_proxy_client_cache_reuses_default_profile_key() { let service_key = format!( "provider.cache_test.{}", std::time::SystemTime::now() @@ -4829,7 +5687,7 @@ default_model = "legacy-model" } #[test] - fn set_runtime_proxy_config_clears_runtime_proxy_client_cache() { + async fn set_runtime_proxy_config_clears_runtime_proxy_client_cache() { let service_key = format!( "provider.cache_timeout_test.{}", std::time::SystemTime::now() @@ -4848,7 +5706,7 @@ default_model = "legacy-model" } #[test] - fn gateway_config_default_values() { + async fn gateway_config_default_values() { let g = GatewayConfig::default(); assert_eq!(g.port, 3000); assert_eq!(g.host, "127.0.0.1"); @@ -4863,14 +5721,14 @@ default_model = "legacy-model" // ── Peripherals config ─────────────────────────────────────── #[test] - fn peripherals_config_default_disabled() { + async fn peripherals_config_default_disabled() { let p = PeripheralsConfig::default(); assert!(!p.enabled); assert!(p.boards.is_empty()); } #[test] - fn peripheral_board_config_defaults() { + async fn peripheral_board_config_defaults() { let b = PeripheralBoardConfig::default(); assert!(b.board.is_empty()); assert_eq!(b.transport, "serial"); @@ -4879,7 +5737,7 @@ default_model = "legacy-model" } #[test] - 
fn peripherals_config_toml_roundtrip() { + async fn peripherals_config_toml_roundtrip() { let p = PeripheralsConfig { enabled: true, boards: vec![PeripheralBoardConfig { @@ -4899,7 +5757,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_serde() { + async fn lark_config_serde() { let lc = LarkConfig { app_id: "cli_123456".into(), app_secret: "secret_abc".into(), @@ -4921,7 +5779,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_toml_roundtrip() { + async fn lark_config_toml_roundtrip() { let lc = LarkConfig { app_id: "cli_123456".into(), app_secret: "secret_abc".into(), @@ -4940,7 +5798,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_deserializes_without_optional_fields() { + async fn lark_config_deserializes_without_optional_fields() { let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; let parsed: LarkConfig = serde_json::from_str(json).unwrap(); assert!(parsed.encrypt_key.is_none()); @@ -4950,7 +5808,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_defaults_to_lark_endpoint() { + async fn lark_config_defaults_to_lark_endpoint() { let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; let parsed: LarkConfig = serde_json::from_str(json).unwrap(); assert!( @@ -4960,7 +5818,7 @@ default_model = "legacy-model" } #[test] - fn lark_config_with_wildcard_allowed_users() { + async fn lark_config_with_wildcard_allowed_users() { let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#; let parsed: LarkConfig = serde_json::from_str(json).unwrap(); assert_eq!(parsed.allowed_users, vec!["*"]); @@ -4970,21 +5828,21 @@ default_model = "legacy-model" #[cfg(unix)] #[test] - fn new_config_file_has_restricted_permissions() { - use std::os::unix::fs::PermissionsExt; - + async fn new_config_file_has_restricted_permissions() { let tmp = tempfile::TempDir::new().unwrap(); let config_path = tmp.path().join("config.toml"); // Create a config and save it let mut config = Config::default(); 
config.config_path = config_path.clone(); - config.save().unwrap(); + config.save().await.unwrap(); // Apply the same permission logic as load_or_init - let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600)); + fs::set_permissions(&config_path, Permissions::from_mode(0o600)) + .await + .expect("Failed to set permissions"); - let meta = std::fs::metadata(&config_path).unwrap(); + let meta = fs::metadata(&config_path).await.unwrap(); let mode = meta.permissions().mode() & 0o777; assert_eq!( mode, 0o600, @@ -4994,7 +5852,7 @@ default_model = "legacy-model" #[cfg(unix)] #[test] - fn world_readable_config_is_detectable() { + async fn world_readable_config_is_detectable() { use std::os::unix::fs::PermissionsExt; let tmp = tempfile::TempDir::new().unwrap(); diff --git a/src/cron/mod.rs b/src/cron/mod.rs index 0f39bc7..49db429 100644 --- a/src/cron/mod.rs +++ b/src/cron/mod.rs @@ -1,5 +1,6 @@ use crate::config::Config; -use anyhow::Result; +use crate::security::SecurityPolicy; +use anyhow::{bail, Result}; mod schedule; mod store; @@ -96,6 +97,58 @@ pub fn handle_command(command: crate::CronCommands, config: &Config) -> Result<( println!(" Cmd : {}", job.command); Ok(()) } + crate::CronCommands::Update { + id, + expression, + tz, + command, + name, + } => { + if expression.is_none() && tz.is_none() && command.is_none() && name.is_none() { + bail!("At least one of --expression, --tz, --command, or --name must be provided"); + } + + // Merge expression/tz with the existing schedule so that + // --tz alone updates the timezone and --expression alone + // preserves the existing timezone. 
+ let schedule = if expression.is_some() || tz.is_some() { + let existing = get_job(config, &id)?; + let (existing_expr, existing_tz) = match existing.schedule { + Schedule::Cron { + expr, + tz: existing_tz, + } => (expr, existing_tz), + _ => bail!("Cannot update expression/tz on a non-cron schedule"), + }; + Some(Schedule::Cron { + expr: expression.unwrap_or(existing_expr), + tz: tz.or(existing_tz), + }) + } else { + None + }; + + if let Some(ref cmd) = command { + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + if !security.is_command_allowed(cmd) { + bail!("Command blocked by security policy: {cmd}"); + } + } + + let patch = CronJobPatch { + schedule, + command, + name, + ..CronJobPatch::default() + }; + + let job = update_job(config, &id, patch)?; + println!("\u{2705} Updated cron job {}", job.id); + println!(" Expr: {}", job.expression); + println!(" Next: {}", job.next_run.to_rfc3339()); + println!(" Cmd : {}", job.command); + Ok(()) + } crate::CronCommands::Remove { id } => remove_job(config, &id), crate::CronCommands::Pause { id } => { pause_job(config, &id)?; @@ -167,3 +220,197 @@ fn parse_delay(input: &str) -> Result { }; Ok(duration) } + +#[cfg(test)] +mod tests { + use super::*; + use tempfile::TempDir; + + fn test_config(tmp: &TempDir) -> Config { + let config = Config { + workspace_dir: tmp.path().join("workspace"), + config_path: tmp.path().join("config.toml"), + ..Config::default() + }; + std::fs::create_dir_all(&config.workspace_dir).unwrap(); + config + } + + fn make_job(config: &Config, expr: &str, tz: Option<&str>, cmd: &str) -> CronJob { + add_shell_job( + config, + None, + Schedule::Cron { + expr: expr.into(), + tz: tz.map(Into::into), + }, + cmd, + ) + .unwrap() + } + + fn run_update( + config: &Config, + id: &str, + expression: Option<&str>, + tz: Option<&str>, + command: Option<&str>, + name: Option<&str>, + ) -> Result<()> { + handle_command( + crate::CronCommands::Update { + id: id.into(), + 
expression: expression.map(Into::into), + tz: tz.map(Into::into), + command: command.map(Into::into), + name: name.map(Into::into), + }, + config, + ) + } + + #[test] + fn update_changes_command_via_handler() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = make_job(&config, "*/5 * * * *", None, "echo original"); + + run_update(&config, &job.id, None, None, Some("echo updated"), None).unwrap(); + + let updated = get_job(&config, &job.id).unwrap(); + assert_eq!(updated.command, "echo updated"); + assert_eq!(updated.id, job.id); + } + + #[test] + fn update_changes_expression_via_handler() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = make_job(&config, "*/5 * * * *", None, "echo test"); + + run_update(&config, &job.id, Some("0 9 * * *"), None, None, None).unwrap(); + + let updated = get_job(&config, &job.id).unwrap(); + assert_eq!(updated.expression, "0 9 * * *"); + } + + #[test] + fn update_changes_name_via_handler() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = make_job(&config, "*/5 * * * *", None, "echo test"); + + run_update(&config, &job.id, None, None, None, Some("new-name")).unwrap(); + + let updated = get_job(&config, &job.id).unwrap(); + assert_eq!(updated.name.as_deref(), Some("new-name")); + } + + #[test] + fn update_tz_alone_sets_timezone() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = make_job(&config, "*/5 * * * *", None, "echo test"); + + run_update( + &config, + &job.id, + None, + Some("America/Los_Angeles"), + None, + None, + ) + .unwrap(); + + let updated = get_job(&config, &job.id).unwrap(); + assert_eq!( + updated.schedule, + Schedule::Cron { + expr: "*/5 * * * *".into(), + tz: Some("America/Los_Angeles".into()), + } + ); + } + + #[test] + fn update_expression_preserves_existing_tz() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = make_job( + &config, + "*/5 * 
* * *", + Some("America/Los_Angeles"), + "echo test", + ); + + run_update(&config, &job.id, Some("0 9 * * *"), None, None, None).unwrap(); + + let updated = get_job(&config, &job.id).unwrap(); + assert_eq!( + updated.schedule, + Schedule::Cron { + expr: "0 9 * * *".into(), + tz: Some("America/Los_Angeles".into()), + } + ); + } + + #[test] + fn update_preserves_unchanged_fields() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = add_shell_job( + &config, + Some("original-name".into()), + Schedule::Cron { + expr: "*/5 * * * *".into(), + tz: None, + }, + "echo original", + ) + .unwrap(); + + run_update(&config, &job.id, None, None, Some("echo changed"), None).unwrap(); + + let updated = get_job(&config, &job.id).unwrap(); + assert_eq!(updated.command, "echo changed"); + assert_eq!(updated.name.as_deref(), Some("original-name")); + assert_eq!(updated.expression, "*/5 * * * *"); + } + + #[test] + fn update_no_flags_fails() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + let job = make_job(&config, "*/5 * * * *", None, "echo test"); + + let result = run_update(&config, &job.id, None, None, None, None); + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("At least one of")); + } + + #[test] + fn update_nonexistent_job_fails() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let result = run_update( + &config, + "nonexistent-id", + None, + None, + Some("echo test"), + None, + ); + assert!(result.is_err()); + } + + #[test] + fn update_security_allows_safe_command() { + let tmp = TempDir::new().unwrap(); + let config = test_config(&tmp); + + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); + assert!(security.is_command_allowed("echo safe")); + } +} diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index ce9c6c3..8d0d7b7 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -61,7 +61,7 @@ async fn 
execute_job_with_retry( for attempt in 0..=retries { let (success, output) = match job.job_type { JobType::Shell => run_job_command(config, security, job).await, - JobType::Agent => run_agent_job(config, job).await, + JobType::Agent => run_agent_job(config, security, job).await, }; last_output = output; @@ -116,7 +116,31 @@ async fn execute_and_persist_job( (job.id.clone(), success) } -async fn run_agent_job(config: &Config, job: &CronJob) -> (bool, String) { +async fn run_agent_job( + config: &Config, + security: &SecurityPolicy, + job: &CronJob, +) -> (bool, String) { + if !security.can_act() { + return ( + false, + "blocked by security policy: autonomy is read-only".to_string(), + ); + } + + if security.is_rate_limited() { + return ( + false, + "blocked by security policy: rate limit exceeded".to_string(), + ); + } + + if !security.record_action() { + return ( + false, + "blocked by security policy: action budget exhausted".to_string(), + ); + } let name = job.name.clone().unwrap_or_else(|| "cron-job".to_string()); let prompt = job.prompt.clone().unwrap_or_default(); let prefixed_prompt = format!("[cron:{} {name}] {prompt}", job.id); @@ -475,13 +499,15 @@ mod tests { use chrono::{Duration as ChronoDuration, Utc}; use tempfile::TempDir; - fn test_config(tmp: &TempDir) -> Config { + async fn test_config(tmp: &TempDir) -> Config { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); config } @@ -513,7 +539,7 @@ mod tests { #[tokio::test] async fn run_job_command_success() { let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); + let config = test_config(&tmp).await; let job = test_job("echo scheduler-ok"); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -526,7 +552,7 @@ mod tests { #[tokio::test] 
async fn run_job_command_failure() { let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); + let config = test_config(&tmp).await; let job = test_job("ls definitely_missing_file_for_scheduler_test"); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -539,7 +565,7 @@ mod tests { #[tokio::test] async fn run_job_command_times_out() { let tmp = TempDir::new().unwrap(); - let mut config = test_config(&tmp); + let mut config = test_config(&tmp).await; config.autonomy.allowed_commands = vec!["sleep".into()]; let job = test_job("sleep 1"); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -553,7 +579,7 @@ mod tests { #[tokio::test] async fn run_job_command_blocks_disallowed_command() { let tmp = TempDir::new().unwrap(); - let mut config = test_config(&tmp); + let mut config = test_config(&tmp).await; config.autonomy.allowed_commands = vec!["echo".into()]; let job = test_job("curl https://evil.example"); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -567,7 +593,7 @@ mod tests { #[tokio::test] async fn run_job_command_blocks_forbidden_path_argument() { let tmp = TempDir::new().unwrap(); - let mut config = test_config(&tmp); + let mut config = test_config(&tmp).await; config.autonomy.allowed_commands = vec!["cat".into()]; let job = test_job("cat /etc/passwd"); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -582,7 +608,7 @@ mod tests { #[tokio::test] async fn run_job_command_blocks_readonly_mode() { let tmp = TempDir::new().unwrap(); - let mut config = test_config(&tmp); + let mut config = test_config(&tmp).await; config.autonomy.level = crate::security::AutonomyLevel::ReadOnly; let job = test_job("echo should-not-run"); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -596,7 +622,7 @@ mod tests { #[tokio::test] async fn run_job_command_blocks_rate_limited() { let tmp 
= TempDir::new().unwrap(); - let mut config = test_config(&tmp); + let mut config = test_config(&tmp).await; config.autonomy.max_actions_per_hour = 0; let job = test_job("echo should-not-run"); let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -610,16 +636,17 @@ mod tests { #[tokio::test] async fn execute_job_with_retry_recovers_after_first_failure() { let tmp = TempDir::new().unwrap(); - let mut config = test_config(&tmp); + let mut config = test_config(&tmp).await; config.reliability.scheduler_retries = 1; config.reliability.provider_backoff_ms = 1; config.autonomy.allowed_commands = vec!["sh".into()]; let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); - std::fs::write( + tokio::fs::write( config.workspace_dir.join("retry-once.sh"), "#!/bin/sh\nif [ -f retry-ok.flag ]; then\n echo recovered\n exit 0\nfi\ntouch retry-ok.flag\nexit 1\n", ) + .await .unwrap(); let job = test_job("sh ./retry-once.sh"); @@ -631,7 +658,7 @@ mod tests { #[tokio::test] async fn execute_job_with_retry_exhausts_attempts() { let tmp = TempDir::new().unwrap(); - let mut config = test_config(&tmp); + let mut config = test_config(&tmp).await; config.reliability.scheduler_retries = 1; config.reliability.provider_backoff_ms = 1; let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); @@ -646,23 +673,53 @@ mod tests { #[tokio::test] async fn run_agent_job_returns_error_without_provider_key() { let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); + let config = test_config(&tmp).await; let mut job = test_job(""); job.job_type = JobType::Agent; job.prompt = Some("Say hello".into()); + let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir); - let (success, output) = run_agent_job(&config, &job).await; - assert!(!success, "Agent job without provider key should fail"); - assert!( - !output.is_empty(), - "Expected non-empty error output from failed agent job" 
- );
+ let (success, output) = run_agent_job(&config, &security, &job).await;
+ assert!(!success);
+ assert!(output.contains("agent job failed:"));
+ }
+
+ #[tokio::test]
+ async fn run_agent_job_blocks_readonly_mode() {
+ let tmp = TempDir::new().unwrap();
+ let mut config = test_config(&tmp).await;
+ config.autonomy.level = crate::security::AutonomyLevel::ReadOnly;
+ let mut job = test_job("");
+ job.job_type = JobType::Agent;
+ job.prompt = Some("Say hello".into());
+ let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
+
+ let (success, output) = run_agent_job(&config, &security, &job).await;
+ assert!(!success);
+ assert!(output.contains("blocked by security policy"));
+ assert!(output.contains("read-only"));
+ }
+
+ #[tokio::test]
+ async fn run_agent_job_blocks_rate_limited() {
+ let tmp = TempDir::new().unwrap();
+ let mut config = test_config(&tmp).await;
+ config.autonomy.max_actions_per_hour = 0;
+ let mut job = test_job("");
+ job.job_type = JobType::Agent;
+ job.prompt = Some("Say hello".into());
+ let security = SecurityPolicy::from_config(&config.autonomy, &config.workspace_dir);
+
+ let (success, output) = run_agent_job(&config, &security, &job).await;
+ assert!(!success);
+ assert!(output.contains("blocked by security policy"));
+ assert!(output.contains("rate limit exceeded")); } #[tokio::test] async fn persist_job_result_records_run_and_reschedules_shell_job() { let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); + let config = test_config(&tmp).await; let job = cron::add_job(&config, "*/5 * * * *", "echo ok").unwrap(); let started = Utc::now(); let finished = started + ChronoDuration::milliseconds(10); @@ -679,7 +736,7 @@ mod tests { #[tokio::test] async fn persist_job_result_success_deletes_one_shot() { let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); + let config = test_config(&tmp).await; let at = Utc::now() + ChronoDuration::minutes(10); let job = cron::add_agent_job( &config, @@
-704,7 +761,7 @@ mod tests { #[tokio::test] async fn persist_job_result_failure_disables_one_shot() { let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); + let config = test_config(&tmp).await; let at = Utc::now() + ChronoDuration::minutes(10); let job = cron::add_agent_job( &config, @@ -730,7 +787,7 @@ mod tests { #[tokio::test] async fn deliver_if_configured_handles_none_and_invalid_channel() { let tmp = TempDir::new().unwrap(); - let config = test_config(&tmp); + let config = test_config(&tmp).await; let mut job = test_job("echo ok"); assert!(deliver_if_configured(&config, &job, "x").await.is_ok()); diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index ca0834b..a2dfee2 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -209,17 +209,40 @@ async fn run_heartbeat_worker(config: Config) -> Result<()> { } fn has_supervised_channels(config: &Config) -> bool { - config.channels_config.telegram.is_some() - || config.channels_config.discord.is_some() - || config.channels_config.slack.is_some() - || config.channels_config.imessage.is_some() - || config.channels_config.matrix.is_some() - || config.channels_config.signal.is_some() - || config.channels_config.whatsapp.is_some() - || config.channels_config.email.is_some() - || config.channels_config.irc.is_some() - || config.channels_config.lark.is_some() - || config.channels_config.dingtalk.is_some() + let crate::config::ChannelsConfig { + cli: _, // `cli` is used only when running the CLI manually + webhook: _, // Managed by the gateway + telegram, + discord, + slack, + mattermost, + imessage, + matrix, + signal, + whatsapp, + email, + irc, + lark, + dingtalk, + linq, + qq, + .. 
+ } = &config.channels_config; + + telegram.is_some() + || discord.is_some() + || slack.is_some() + || mattermost.is_some() + || imessage.is_some() + || matrix.is_some() + || signal.is_some() + || whatsapp.is_some() + || email.is_some() + || irc.is_some() + || lark.is_some() + || dingtalk.is_some() + || linq.is_some() + || qq.is_some() } #[cfg(test)] @@ -298,6 +321,7 @@ mod tests { allowed_users: vec![], stream_mode: crate::config::StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }); assert!(has_supervised_channels(&config)); @@ -313,4 +337,29 @@ mod tests { }); assert!(has_supervised_channels(&config)); } + + #[test] + fn detects_mattermost_as_supervised_channel() { + let mut config = Config::default(); + config.channels_config.mattermost = Some(crate::config::schema::MattermostConfig { + url: "https://mattermost.example.com".into(), + bot_token: "token".into(), + channel_id: Some("channel-id".into()), + allowed_users: vec!["*".into()], + thread_replies: Some(true), + mention_only: Some(false), + }); + assert!(has_supervised_channels(&config)); + } + + #[test] + fn detects_qq_as_supervised_channel() { + let mut config = Config::default(); + config.channels_config.qq = Some(crate::config::schema::QQConfig { + app_id: "app-id".into(), + app_secret: "app-secret".into(), + allowed_users: vec!["*".into()], + }); + assert!(has_supervised_channels(&config)); + } } diff --git a/src/doctor/mod.rs b/src/doctor/mod.rs index 210f860..f0335db 100644 --- a/src/doctor/mod.rs +++ b/src/doctor/mod.rs @@ -344,6 +344,58 @@ fn check_config_semantics(config: &Config, items: &mut Vec) { } } + // Embedding routes validation + for route in &config.embedding_routes { + if route.hint.trim().is_empty() { + items.push(DiagItem::warn(cat, "embedding route with empty hint")); + } + if let Some(reason) = embedding_provider_validation_error(&route.provider) { + items.push(DiagItem::warn( + cat, + format!( + "embedding route \"{}\" 
uses invalid provider \"{}\": {}", + route.hint, route.provider, reason + ), + )); + } + if route.model.trim().is_empty() { + items.push(DiagItem::warn( + cat, + format!("embedding route \"{}\" has empty model", route.hint), + )); + } + if route.dimensions.is_some_and(|value| value == 0) { + items.push(DiagItem::warn( + cat, + format!( + "embedding route \"{}\" has invalid dimensions=0", + route.hint + ), + )); + } + } + + if let Some(hint) = config + .memory + .embedding_model + .strip_prefix("hint:") + .map(str::trim) + .filter(|value| !value.is_empty()) + { + if !config + .embedding_routes + .iter() + .any(|route| route.hint.trim() == hint) + { + items.push(DiagItem::warn( + cat, + format!( + "memory.embedding_model uses hint \"{hint}\" but no matching [[embedding_routes]] entry exists" + ), + )); + } + } + // Channel: at least one configured let cc = &config.channels_config; let has_channel = cc.telegram.is_some() @@ -396,6 +448,31 @@ fn provider_validation_error(name: &str) -> Option { } } +fn embedding_provider_validation_error(name: &str) -> Option { + let normalized = name.trim(); + if normalized.eq_ignore_ascii_case("none") || normalized.eq_ignore_ascii_case("openai") { + return None; + } + + let Some(url) = normalized.strip_prefix("custom:") else { + return Some("supported values: none, openai, custom:".into()); + }; + + let url = url.trim(); + if url.is_empty() { + return Some("custom provider requires a non-empty URL after 'custom:'".into()); + } + + match reqwest::Url::parse(url) { + Ok(parsed) if matches!(parsed.scheme(), "http" | "https") => None, + Ok(parsed) => Some(format!( + "custom provider URL must use http/https, got '{}'", + parsed.scheme() + )), + Err(err) => Some(format!("invalid custom provider URL: {err}")), + } +} + // ── Workspace integrity ────────────────────────────────────────── fn check_workspace(config: &Config, items: &mut Vec) { @@ -891,6 +968,62 @@ mod tests { assert_eq!(route_item.unwrap().severity, Severity::Warn); } + 
#[test] + fn config_validation_warns_empty_embedding_route_model() { + let mut config = Config::default(); + config.embedding_routes = vec![crate::config::EmbeddingRouteConfig { + hint: "semantic".into(), + provider: "openai".into(), + model: String::new(), + dimensions: Some(1536), + api_key: None, + }]; + + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let route_item = items.iter().find(|item| { + item.message + .contains("embedding route \"semantic\" has empty model") + }); + assert!(route_item.is_some()); + assert_eq!(route_item.unwrap().severity, Severity::Warn); + } + + #[test] + fn config_validation_warns_invalid_embedding_route_provider() { + let mut config = Config::default(); + config.embedding_routes = vec![crate::config::EmbeddingRouteConfig { + hint: "semantic".into(), + provider: "groq".into(), + model: "text-embedding-3-small".into(), + dimensions: None, + api_key: None, + }]; + + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let route_item = items + .iter() + .find(|item| item.message.contains("uses invalid provider \"groq\"")); + assert!(route_item.is_some()); + assert_eq!(route_item.unwrap().severity, Severity::Warn); + } + + #[test] + fn config_validation_warns_missing_embedding_hint_target() { + let mut config = Config::default(); + config.memory.embedding_model = "hint:semantic".into(); + + let mut items = Vec::new(); + check_config_semantics(&config, &mut items); + let route_item = items.iter().find(|item| { + item.message + .contains("no matching [[embedding_routes]] entry exists") + }); + assert!(route_item.is_some()); + assert_eq!(route_item.unwrap().severity, Severity::Warn); + } + #[test] fn environment_check_finds_git() { let mut items = Vec::new(); @@ -910,8 +1043,8 @@ mod tests { #[test] fn truncate_for_display_preserves_utf8_boundaries() { - let preview = truncate_for_display("版本号-alpha-build", 3); - assert_eq!(preview, "版本号…"); + let preview = 
truncate_for_display("🙂example-alpha-build", 3); + assert_eq!(preview, "🙂ex…"); } #[test] diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 3027638..a7f6777 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -7,10 +7,10 @@ //! - Request timeouts (30s) to prevent slow-loris attacks //! - Header sanitization (handled by axum/hyper) -use crate::channels::{Channel, SendMessage, WhatsAppChannel}; +use crate::channels::{Channel, LinqChannel, SendMessage, WhatsAppChannel}; use crate::config::Config; use crate::memory::{self, Memory, MemoryCategory}; -use crate::providers::{self, Provider}; +use crate::providers::{self, ChatMessage, Provider, ProviderCapabilityError}; use crate::runtime; use crate::security::pairing::{constant_time_eq, is_public_bind, PairingGuard}; use crate::security::SecurityPolicy; @@ -53,6 +53,10 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String format!("whatsapp_{}_{}", msg.sender, msg.id) } +fn linq_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String { + format!("linq_{}_{}", msg.sender, msg.id) +} + fn hash_webhook_secret(value: &str) -> String { use sha2::{Digest, Sha256}; @@ -274,6 +278,9 @@ pub struct AppState { pub whatsapp: Option>, /// `WhatsApp` app secret for webhook signature verification (`X-Hub-Signature-256`) pub whatsapp_app_secret: Option>, + pub linq: Option>, + /// Linq webhook signing secret for signature verification + pub linq_signing_secret: Option>, /// Observability backend for metrics scraping pub observer: Arc, } @@ -306,6 +313,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { auth_profile_override: None, zeroclaw_dir: config.config_path.parent().map(std::path::PathBuf::from), secrets_encrypt: config.secrets.encrypt, + reasoning_enabled: config.runtime.reasoning_enabled, }, )?); let model = config @@ -360,12 +368,16 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { }); // WhatsApp channel (if 
configured) - let whatsapp_channel: Option> = - config.channels_config.whatsapp.as_ref().map(|wa| { + let whatsapp_channel: Option> = config + .channels_config + .whatsapp + .as_ref() + .filter(|wa| wa.is_cloud_config()) + .map(|wa| { Arc::new(WhatsAppChannel::new( - wa.access_token.clone(), - wa.phone_number_id.clone(), - wa.verify_token.clone(), + wa.access_token.clone().unwrap_or_default(), + wa.phone_number_id.clone().unwrap_or_default(), + wa.verify_token.clone().unwrap_or_default(), wa.allowed_numbers.clone(), )) }); @@ -389,6 +401,34 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { }) .map(Arc::from); + // Linq channel (if configured) + let linq_channel: Option> = config.channels_config.linq.as_ref().map(|lq| { + Arc::new(LinqChannel::new( + lq.api_token.clone(), + lq.from_phone.clone(), + lq.allowed_senders.clone(), + )) + }); + + // Linq signing secret for webhook signature verification + // Priority: environment variable > config file + let linq_signing_secret: Option> = std::env::var("ZEROCLAW_LINQ_SIGNING_SECRET") + .ok() + .and_then(|secret| { + let secret = secret.trim(); + (!secret.is_empty()).then(|| secret.to_owned()) + }) + .or_else(|| { + config.channels_config.linq.as_ref().and_then(|lq| { + lq.signing_secret + .as_deref() + .map(str::trim) + .filter(|secret| !secret.is_empty()) + .map(ToOwned::to_owned) + }) + }) + .map(Arc::from); + // ── Pairing guard ────────────────────────────────────── let pairing = Arc::new(PairingGuard::new( config.gateway.require_pairing, @@ -440,6 +480,9 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { println!(" GET /whatsapp — Meta webhook verification"); println!(" POST /whatsapp — WhatsApp message webhook"); } + if linq_channel.is_some() { + println!(" POST /linq — Linq message webhook (iMessage/RCS/SMS)"); + } println!(" GET /health — health check"); println!(" GET /metrics — Prometheus metrics"); if let Some(code) = pairing.pairing_code() { @@ 
-476,6 +519,8 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { idempotency_store, whatsapp: whatsapp_channel, whatsapp_app_secret, + linq: linq_channel, + linq_signing_secret, observer, }; @@ -487,6 +532,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { .route("/webhook", post(handle_webhook)) .route("/whatsapp", get(handle_whatsapp_verify)) .route("/whatsapp", post(handle_whatsapp_message)) + .route("/linq", post(handle_linq_webhook)) .with_state(state) .layer(RequestBodyLimitLayer::new(MAX_BODY_SIZE)) .layer(TimeoutLayer::with_status_code( @@ -542,15 +588,16 @@ async fn handle_metrics(State(state): State) -> impl IntoResponse { } /// POST /pair — exchange one-time code for bearer token +#[axum::debug_handler] async fn handle_pair( State(state): State, ConnectInfo(peer_addr): ConnectInfo, headers: HeaderMap, ) -> impl IntoResponse { - let client_key = + let rate_key = client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers); - if !state.rate_limiter.allow_pair(&client_key) { - tracing::warn!("/pair rate limit exceeded for key: {client_key}"); + if !state.rate_limiter.allow_pair(&rate_key) { + tracing::warn!("/pair rate limit exceeded"); let err = serde_json::json!({ "error": "Too many pairing requests. 
Please retry later.", "retry_after": RATE_LIMIT_WINDOW_SECS, @@ -563,10 +610,10 @@ async fn handle_pair( .and_then(|v| v.to_str().ok()) .unwrap_or(""); - match state.pairing.try_pair(code) { + match state.pairing.try_pair(code, &rate_key).await { Ok(Some(token)) => { tracing::info!("🔐 New client paired successfully"); - if let Err(err) = persist_pairing_tokens(&state.config, &state.pairing) { + if let Err(err) = persist_pairing_tokens(state.config.clone(), &state.pairing).await { tracing::error!("🔐 Pairing succeeded but token persistence failed: {err:#}"); let body = serde_json::json!({ "paired": true, @@ -603,12 +650,66 @@ async fn handle_pair( } } -fn persist_pairing_tokens(config: &Arc>, pairing: &PairingGuard) -> Result<()> { +async fn persist_pairing_tokens(config: Arc>, pairing: &PairingGuard) -> Result<()> { let paired_tokens = pairing.tokens(); - let mut cfg = config.lock(); - cfg.gateway.paired_tokens = paired_tokens; - cfg.save() - .context("Failed to persist paired tokens to config.toml") + // This is needed because parking_lot's guard is not Send so we clone the inner + // this should be removed once async mutexes are used everywhere + let mut updated_cfg = { config.lock().clone() }; + updated_cfg.gateway.paired_tokens = paired_tokens; + updated_cfg + .save() + .await + .context("Failed to persist paired tokens to config.toml")?; + + // Keep shared runtime config in sync with persisted tokens. 
+ *config.lock() = updated_cfg; + Ok(()) +} + +async fn run_gateway_chat_with_multimodal( + state: &AppState, + provider_label: &str, + message: &str, +) -> anyhow::Result { + let user_messages = vec![ChatMessage::user(message)]; + let image_marker_count = crate::multimodal::count_image_markers(&user_messages); + if image_marker_count > 0 && !state.provider.supports_vision() { + return Err(ProviderCapabilityError { + provider: provider_label.to_string(), + capability: "vision".to_string(), + message: format!( + "received {image_marker_count} image marker(s), but this provider does not support vision input" + ), + } + .into()); + } + + // Keep webhook/gateway prompts aligned with channel behavior by injecting + // workspace-aware system context before model invocation. + let system_prompt = { + let config_guard = state.config.lock(); + crate::channels::build_system_prompt( + &config_guard.workspace_dir, + &state.model, + &[], // tools - empty for simple chat + &[], // skills + Some(&config_guard.identity), + None, // bootstrap_max_chars - use default + ) + }; + + let mut messages = Vec::with_capacity(1 + user_messages.len()); + messages.push(ChatMessage::system(system_prompt)); + messages.extend(user_messages); + + let multimodal_config = state.config.lock().multimodal.clone(); + let prepared = + crate::multimodal::prepare_messages_for_provider(&messages, &multimodal_config).await?; + + state + .provider + .chat_with_history(&prepared.messages, &state.model, state.temperature) + .await } /// Webhook request body @@ -624,10 +725,10 @@ async fn handle_webhook( headers: HeaderMap, body: Result, axum::extract::rejection::JsonRejection>, ) -> impl IntoResponse { - let client_key = + let rate_key = client_key_from_request(Some(peer_addr), &headers, state.trust_forwarded_headers); - if !state.rate_limiter.allow_webhook(&client_key) { - tracing::warn!("/webhook rate limit exceeded for key: {client_key}"); + if !state.rate_limiter.allow_webhook(&rate_key) { + 
tracing::warn!("/webhook rate limit exceeded"); let err = serde_json::json!({ "error": "Too many webhook requests. Please retry later.", "retry_after": RATE_LIMIT_WINDOW_SECS, @@ -732,11 +833,7 @@ async fn handle_webhook( messages_count: 1, }); - match state - .provider - .simple_chat(message, &state.model, state.temperature) - .await - { + match run_gateway_chat_with_multimodal(&state, &provider_label, message).await { Ok(response) => { let duration = started_at.elapsed(); state @@ -920,6 +1017,12 @@ async fn handle_whatsapp_message( } // Process each message + let provider_label = state + .config + .lock() + .default_provider + .clone() + .unwrap_or_else(|| "unknown".to_string()); for msg in &messages { tracing::info!( "WhatsApp message from {}: {}", @@ -936,12 +1039,7 @@ async fn handle_whatsapp_message( .await; } - // Call the LLM - match state - .provider - .simple_chat(&msg.content, &state.model, state.temperature) - .await - { + match run_gateway_chat_with_multimodal(&state, &provider_label, &msg.content).await { Ok(response) => { // Send reply via WhatsApp if let Err(e) = wa @@ -967,6 +1065,120 @@ async fn handle_whatsapp_message( (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) } +/// POST /linq — incoming message webhook (iMessage/RCS/SMS via Linq) +async fn handle_linq_webhook( + State(state): State, + headers: HeaderMap, + body: Bytes, +) -> impl IntoResponse { + let Some(ref linq) = state.linq else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({"error": "Linq not configured"})), + ); + }; + + let body_str = String::from_utf8_lossy(&body); + + // ── Security: Verify X-Webhook-Signature if signing_secret is configured ── + if let Some(ref signing_secret) = state.linq_signing_secret { + let timestamp = headers + .get("X-Webhook-Timestamp") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + let signature = headers + .get("X-Webhook-Signature") + .and_then(|v| v.to_str().ok()) + .unwrap_or(""); + + if 
!crate::channels::linq::verify_linq_signature( + signing_secret, + &body_str, + timestamp, + signature, + ) { + tracing::warn!( + "Linq webhook signature verification failed (signature: {})", + if signature.is_empty() { + "missing" + } else { + "invalid" + } + ); + return ( + StatusCode::UNAUTHORIZED, + Json(serde_json::json!({"error": "Invalid signature"})), + ); + } + } + + // Parse JSON body + let Ok(payload) = serde_json::from_slice::(&body) else { + return ( + StatusCode::BAD_REQUEST, + Json(serde_json::json!({"error": "Invalid JSON payload"})), + ); + }; + + // Parse messages from the webhook payload + let messages = linq.parse_webhook_payload(&payload); + + if messages.is_empty() { + // Acknowledge the webhook even if no messages (could be status/delivery events) + return (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))); + } + + // Process each message + let provider_label = state + .config + .lock() + .default_provider + .clone() + .unwrap_or_else(|| "unknown".to_string()); + for msg in &messages { + tracing::info!( + "Linq message from {}: {}", + msg.sender, + truncate_with_ellipsis(&msg.content, 50) + ); + + // Auto-save to memory + if state.auto_save { + let key = linq_memory_key(msg); + let _ = state + .mem + .store(&key, &msg.content, MemoryCategory::Conversation, None) + .await; + } + + // Call the LLM + match run_gateway_chat_with_multimodal(&state, &provider_label, &msg.content).await { + Ok(response) => { + // Send reply via Linq + if let Err(e) = linq + .send(&SendMessage::new(response, &msg.reply_target)) + .await + { + tracing::error!("Failed to send Linq reply: {e}"); + } + } + Err(e) => { + tracing::error!("LLM error for Linq message: {e:#}"); + let _ = linq + .send(&SendMessage::new( + "Sorry, I couldn't process your message right now.", + &msg.reply_target, + )) + .await; + } + } + } + + // Acknowledge the webhook + (StatusCode::OK, Json(serde_json::json!({"status": "ok"}))) +} + #[cfg(test)] mod tests { use super::*; @@ -980,6 
+1192,13 @@ mod tests { use parking_lot::Mutex; use std::sync::atomic::{AtomicUsize, Ordering}; + /// Generate a random hex secret at runtime to avoid hard-coded cryptographic values. + fn generate_test_secret() -> String { + use rand::Rng; + let bytes: [u8; 32] = rand::rng().random(); + hex::encode(bytes) + } + #[test] fn security_body_limit_is_64kb() { assert_eq!(MAX_BODY_SIZE, 65_536); @@ -1034,6 +1253,8 @@ mod tests { idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, observer: Arc::new(crate::observability::NoopObserver), }; @@ -1075,6 +1296,8 @@ mod tests { idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, observer, }; @@ -1221,8 +1444,8 @@ mod tests { assert_eq!(normalize_max_keys(1, 10_000), 1); } - #[test] - fn persist_pairing_tokens_writes_config_tokens() { + #[tokio::test] + async fn persist_pairing_tokens_writes_config_tokens() { let temp = tempfile::tempdir().unwrap(); let config_path = temp.path().join("config.toml"); let workspace_path = temp.path().join("workspace"); @@ -1230,22 +1453,28 @@ mod tests { let mut config = Config::default(); config.config_path = config_path.clone(); config.workspace_dir = workspace_path; - config.save().unwrap(); + config.save().await.unwrap(); let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap(); - let token = guard.try_pair(&code).unwrap().unwrap(); + let token = guard.try_pair(&code, "test_client").await.unwrap().unwrap(); assert!(guard.is_authenticated(&token)); let shared_config = Arc::new(Mutex::new(config)); - persist_pairing_tokens(&shared_config, &guard).unwrap(); + persist_pairing_tokens(shared_config.clone(), &guard) + .await + .unwrap(); - let saved = std::fs::read_to_string(config_path).unwrap(); + let saved = 
tokio::fs::read_to_string(config_path).await.unwrap(); let parsed: Config = toml::from_str(&saved).unwrap(); assert_eq!(parsed.gateway.paired_tokens.len(), 1); let persisted = &parsed.gateway.paired_tokens[0]; assert_eq!(persisted.len(), 64); assert!(persisted.chars().all(|c| c.is_ascii_hexdigit())); + + let in_memory = shared_config.lock(); + assert_eq!(in_memory.gateway.paired_tokens.len(), 1); + assert_eq!(&in_memory.gateway.paired_tokens[0], persisted); } #[test] @@ -1267,6 +1496,7 @@ mod tests { content: "hello".into(), channel: "whatsapp".into(), timestamp: 1, + thread_ts: None, }; let key = whatsapp_memory_key(&msg); @@ -1426,6 +1656,8 @@ mod tests { idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, observer: Arc::new(crate::observability::NoopObserver), }; @@ -1482,6 +1714,8 @@ mod tests { idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, observer: Arc::new(crate::observability::NoopObserver), }; @@ -1518,9 +1752,11 @@ mod tests { #[test] fn webhook_secret_hash_is_deterministic_and_nonempty() { - let one = hash_webhook_secret("secret-value"); - let two = hash_webhook_secret("secret-value"); - let other = hash_webhook_secret("other-value"); + let secret_a = generate_test_secret(); + let secret_b = generate_test_secret(); + let one = hash_webhook_secret(&secret_a); + let two = hash_webhook_secret(&secret_a); + let other = hash_webhook_secret(&secret_b); assert_eq!(one, two); assert_ne!(one, other); @@ -1532,6 +1768,7 @@ mod tests { let provider_impl = Arc::new(MockProvider::default()); let provider: Arc = provider_impl.clone(); let memory: Arc = Arc::new(MockMemory); + let secret = generate_test_secret(); let state = AppState { config: Arc::new(Mutex::new(Config::default())), @@ -1540,13 +1777,15 @@ mod tests { temperature: 
0.0, mem: memory, auto_save: false, - webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + webhook_secret_hash: Some(Arc::from(hash_webhook_secret(&secret))), pairing: Arc::new(PairingGuard::new(false, &[])), trust_forwarded_headers: false, rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, observer: Arc::new(crate::observability::NoopObserver), }; @@ -1570,6 +1809,8 @@ mod tests { let provider_impl = Arc::new(MockProvider::default()); let provider: Arc = provider_impl.clone(); let memory: Arc = Arc::new(MockMemory); + let valid_secret = generate_test_secret(); + let wrong_secret = generate_test_secret(); let state = AppState { config: Arc::new(Mutex::new(Config::default())), @@ -1578,18 +1819,23 @@ mod tests { temperature: 0.0, mem: memory, auto_save: false, - webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + webhook_secret_hash: Some(Arc::from(hash_webhook_secret(&valid_secret))), pairing: Arc::new(PairingGuard::new(false, &[])), trust_forwarded_headers: false, rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, observer: Arc::new(crate::observability::NoopObserver), }; let mut headers = HeaderMap::new(); - headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret")); + headers.insert( + "X-Webhook-Secret", + HeaderValue::from_str(&wrong_secret).unwrap(), + ); let response = handle_webhook( State(state), @@ -1611,6 +1857,7 @@ mod tests { let provider_impl = Arc::new(MockProvider::default()); let provider: Arc = provider_impl.clone(); let memory: Arc = Arc::new(MockMemory); + let secret = generate_test_secret(); let state = AppState { 
config: Arc::new(Mutex::new(Config::default())), @@ -1619,18 +1866,20 @@ mod tests { temperature: 0.0, mem: memory, auto_save: false, - webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + webhook_secret_hash: Some(Arc::from(hash_webhook_secret(&secret))), pairing: Arc::new(PairingGuard::new(false, &[])), trust_forwarded_headers: false, rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100, 100)), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300), 1000)), whatsapp: None, whatsapp_app_secret: None, + linq: None, + linq_signing_secret: None, observer: Arc::new(crate::observability::NoopObserver), }; let mut headers = HeaderMap::new(); - headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret")); + headers.insert("X-Webhook-Secret", HeaderValue::from_str(&secret).unwrap()); let response = handle_webhook( State(state), @@ -1666,14 +1915,13 @@ mod tests { #[test] fn whatsapp_signature_valid() { - // Test with known values - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b"test body content"; - let signature_header = compute_whatsapp_signature_header(app_secret, body); + let signature_header = compute_whatsapp_signature_header(&app_secret, body); assert!(verify_whatsapp_signature( - app_secret, + &app_secret, body, &signature_header )); @@ -1681,14 +1929,14 @@ mod tests { #[test] fn whatsapp_signature_invalid_wrong_secret() { - let app_secret = "correct_secret_key_abc"; - let wrong_secret = "wrong_secret_key_xyz"; + let app_secret = generate_test_secret(); + let wrong_secret = generate_test_secret(); let body = b"test body content"; - let signature_header = compute_whatsapp_signature_header(wrong_secret, body); + let signature_header = compute_whatsapp_signature_header(&wrong_secret, body); assert!(!verify_whatsapp_signature( - app_secret, + &app_secret, body, &signature_header )); @@ -1696,15 +1944,15 @@ mod tests { #[test] fn 
whatsapp_signature_invalid_wrong_body() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let original_body = b"original body"; let tampered_body = b"tampered body"; - let signature_header = compute_whatsapp_signature_header(app_secret, original_body); + let signature_header = compute_whatsapp_signature_header(&app_secret, original_body); // Verify with tampered body should fail assert!(!verify_whatsapp_signature( - app_secret, + &app_secret, tampered_body, &signature_header )); @@ -1712,14 +1960,14 @@ mod tests { #[test] fn whatsapp_signature_missing_prefix() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b"test body"; // Signature without "sha256=" prefix let signature_header = "abc123def456"; assert!(!verify_whatsapp_signature( - app_secret, + &app_secret, body, signature_header )); @@ -1727,22 +1975,22 @@ mod tests { #[test] fn whatsapp_signature_empty_header() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b"test body"; - assert!(!verify_whatsapp_signature(app_secret, body, "")); + assert!(!verify_whatsapp_signature(&app_secret, body, "")); } #[test] fn whatsapp_signature_invalid_hex() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b"test body"; // Invalid hex characters let signature_header = "sha256=not_valid_hex_zzz"; assert!(!verify_whatsapp_signature( - app_secret, + &app_secret, body, signature_header )); @@ -1750,13 +1998,13 @@ mod tests { #[test] fn whatsapp_signature_empty_body() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b""; - let signature_header = compute_whatsapp_signature_header(app_secret, body); + let signature_header = compute_whatsapp_signature_header(&app_secret, body); assert!(verify_whatsapp_signature( - app_secret, + &app_secret, body, &signature_header )); @@ -1764,13 +2012,13 @@ mod 
tests { #[test] fn whatsapp_signature_unicode_body() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = "Hello 🦀 World".as_bytes(); - let signature_header = compute_whatsapp_signature_header(app_secret, body); + let signature_header = compute_whatsapp_signature_header(&app_secret, body); assert!(verify_whatsapp_signature( - app_secret, + &app_secret, body, &signature_header )); @@ -1778,13 +2026,13 @@ mod tests { #[test] fn whatsapp_signature_json_payload() { - let app_secret = "test_app_secret_key_xyz"; + let app_secret = generate_test_secret(); let body = br#"{"entry":[{"changes":[{"value":{"messages":[{"from":"1234567890","text":{"body":"Hello"}}]}}]}]}"#; - let signature_header = compute_whatsapp_signature_header(app_secret, body); + let signature_header = compute_whatsapp_signature_header(&app_secret, body); assert!(verify_whatsapp_signature( - app_secret, + &app_secret, body, &signature_header )); @@ -1792,31 +2040,35 @@ mod tests { #[test] fn whatsapp_signature_case_sensitive_prefix() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b"test body"; - let hex_sig = compute_whatsapp_signature_hex(app_secret, body); + let hex_sig = compute_whatsapp_signature_hex(&app_secret, body); // Wrong case prefix should fail let wrong_prefix = format!("SHA256={hex_sig}"); - assert!(!verify_whatsapp_signature(app_secret, body, &wrong_prefix)); + assert!(!verify_whatsapp_signature(&app_secret, body, &wrong_prefix)); // Correct prefix should pass let correct_prefix = format!("sha256={hex_sig}"); - assert!(verify_whatsapp_signature(app_secret, body, &correct_prefix)); + assert!(verify_whatsapp_signature( + &app_secret, + body, + &correct_prefix + )); } #[test] fn whatsapp_signature_truncated_hex() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b"test body"; - let hex_sig = compute_whatsapp_signature_hex(app_secret, body); + let 
hex_sig = compute_whatsapp_signature_hex(&app_secret, body); let truncated = &hex_sig[..32]; // Only half the signature let signature_header = format!("sha256={truncated}"); assert!(!verify_whatsapp_signature( - app_secret, + &app_secret, body, &signature_header )); @@ -1824,17 +2076,65 @@ mod tests { #[test] fn whatsapp_signature_extra_bytes() { - let app_secret = "test_secret_key_12345"; + let app_secret = generate_test_secret(); let body = b"test body"; - let hex_sig = compute_whatsapp_signature_hex(app_secret, body); + let hex_sig = compute_whatsapp_signature_hex(&app_secret, body); let extended = format!("{hex_sig}deadbeef"); let signature_header = format!("sha256={extended}"); assert!(!verify_whatsapp_signature( - app_secret, + &app_secret, body, &signature_header )); } + + // ══════════════════════════════════════════════════════════ + // IdempotencyStore Edge-Case Tests + // ══════════════════════════════════════════════════════════ + + #[test] + fn idempotency_store_allows_different_keys() { + let store = IdempotencyStore::new(Duration::from_secs(60), 100); + assert!(store.record_if_new("key-a")); + assert!(store.record_if_new("key-b")); + assert!(store.record_if_new("key-c")); + assert!(store.record_if_new("key-d")); + } + + #[test] + fn idempotency_store_max_keys_clamped_to_one() { + let store = IdempotencyStore::new(Duration::from_secs(60), 0); + assert!(store.record_if_new("only-key")); + assert!(!store.record_if_new("only-key")); + } + + #[test] + fn idempotency_store_rapid_duplicate_rejected() { + let store = IdempotencyStore::new(Duration::from_secs(300), 100); + assert!(store.record_if_new("rapid")); + assert!(!store.record_if_new("rapid")); + } + + #[test] + fn idempotency_store_accepts_after_ttl_expires() { + let store = IdempotencyStore::new(Duration::from_millis(1), 100); + assert!(store.record_if_new("ttl-key")); + std::thread::sleep(Duration::from_millis(10)); + assert!(store.record_if_new("ttl-key")); + } + + #[test] + fn 
idempotency_store_eviction_preserves_newest() { + let store = IdempotencyStore::new(Duration::from_secs(300), 1); + assert!(store.record_if_new("old-key")); + std::thread::sleep(Duration::from_millis(2)); + assert!(store.record_if_new("new-key")); + + let keys = store.keys.lock(); + assert_eq!(keys.len(), 1); + assert!(!keys.contains_key("old-key")); + assert!(keys.contains_key("new-key")); + } } diff --git a/src/hardware/discover.rs b/src/hardware/discover.rs index 4bbf31f..9f514da 100644 --- a/src/hardware/discover.rs +++ b/src/hardware/discover.rs @@ -1,4 +1,10 @@ //! USB device discovery — enumerate devices and enrich with board registry. +//! +//! USB enumeration via `nusb` is only supported on Linux, macOS, and Windows. +//! On Android (Termux) and other unsupported platforms this module is excluded +//! from compilation; callers in `hardware/mod.rs` fall back to an empty result. + +#![cfg(any(target_os = "linux", target_os = "macos", target_os = "windows"))] use super::registry; use anyhow::Result; diff --git a/src/hardware/mod.rs b/src/hardware/mod.rs index 18f6dcc..d9dbc1c 100644 --- a/src/hardware/mod.rs +++ b/src/hardware/mod.rs @@ -4,10 +4,10 @@ pub mod registry; -#[cfg(feature = "hardware")] +#[cfg(all(feature = "hardware", any(target_os = "linux", target_os = "macos", target_os = "windows")))] pub mod discover; -#[cfg(feature = "hardware")] +#[cfg(all(feature = "hardware", any(target_os = "linux", target_os = "macos", target_os = "windows")))] pub mod introspect; use crate::config::Config; @@ -28,8 +28,9 @@ pub struct DiscoveredDevice { /// Auto-discover connected hardware devices. /// Returns an empty vec on platforms without hardware support. pub fn discover_hardware() -> Vec { - // USB/serial discovery is behind the "hardware" feature gate. - #[cfg(feature = "hardware")] + // USB/serial discovery is behind the "hardware" feature gate and only + // available on platforms where nusb supports device enumeration. 
+ #[cfg(all(feature = "hardware", any(target_os = "linux", target_os = "macos", target_os = "windows")))] { if let Ok(devices) = discover::list_usb_devices() { return devices @@ -102,7 +103,15 @@ pub fn handle_command(cmd: crate::HardwareCommands, _config: &Config) -> Result< return Ok(()); } - #[cfg(feature = "hardware")] + #[cfg(all(feature = "hardware", not(any(target_os = "linux", target_os = "macos", target_os = "windows"))))] + { + let _ = &cmd; + println!("Hardware USB discovery is not supported on this platform."); + println!("Supported platforms: Linux, macOS, Windows."); + return Ok(()); + } + + #[cfg(all(feature = "hardware", any(target_os = "linux", target_os = "macos", target_os = "windows")))] match cmd { crate::HardwareCommands::Discover => run_discover(), crate::HardwareCommands::Introspect { path } => run_introspect(&path), @@ -110,7 +119,7 @@ pub fn handle_command(cmd: crate::HardwareCommands, _config: &Config) -> Result< } } -#[cfg(feature = "hardware")] +#[cfg(all(feature = "hardware", any(target_os = "linux", target_os = "macos", target_os = "windows")))] fn run_discover() -> Result<()> { let devices = discover::list_usb_devices()?; @@ -138,7 +147,7 @@ fn run_discover() -> Result<()> { Ok(()) } -#[cfg(feature = "hardware")] +#[cfg(all(feature = "hardware", any(target_os = "linux", target_os = "macos", target_os = "windows")))] fn run_introspect(path: &str) -> Result<()> { let result = introspect::introspect_device(path)?; @@ -160,7 +169,7 @@ fn run_introspect(path: &str) -> Result<()> { Ok(()) } -#[cfg(feature = "hardware")] +#[cfg(all(feature = "hardware", any(target_os = "linux", target_os = "macos", target_os = "windows")))] fn run_info(chip: &str) -> Result<()> { #[cfg(feature = "probe")] { @@ -192,7 +201,7 @@ fn run_info(chip: &str) -> Result<()> { } } -#[cfg(all(feature = "hardware", feature = "probe"))] +#[cfg(all(feature = "hardware", feature = "probe", any(target_os = "linux", target_os = "macos", target_os = "windows")))] fn 
info_via_probe(chip: &str) -> anyhow::Result<()> { use probe_rs::config::MemoryRegion; use probe_rs::{Session, SessionConfig}; diff --git a/src/integrations/registry.rs b/src/integrations/registry.rs index cc91082..9e28f5c 100644 --- a/src/integrations/registry.rs +++ b/src/integrations/registry.rs @@ -790,6 +790,7 @@ mod tests { allowed_users: vec!["user".into()], stream_mode: StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }); let entries = all_integrations(); diff --git a/src/lib.rs b/src/lib.rs index 0166bd5..6df3187 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -39,46 +39,49 @@ use clap::Subcommand; use serde::{Deserialize, Serialize}; pub mod agent; -pub mod approval; -pub mod auth; +pub(crate) mod approval; +pub(crate) mod auth; pub mod channels; pub mod config; -pub mod cost; -pub mod cron; -pub mod daemon; -pub mod doctor; +pub(crate) mod cost; +pub(crate) mod cron; +pub(crate) mod daemon; +pub(crate) mod doctor; pub mod gateway; -pub mod hardware; -pub mod health; -pub mod heartbeat; -pub mod identity; -pub mod integrations; +pub(crate) mod hardware; +pub(crate) mod health; +pub(crate) mod heartbeat; +pub(crate) mod identity; +pub(crate) mod integrations; pub mod memory; -pub mod migration; +pub(crate) mod migration; +pub(crate) mod multimodal; pub mod observability; -pub mod onboard; +pub(crate) mod onboard; pub mod peripherals; pub mod providers; pub mod rag; pub mod runtime; -pub mod security; -pub mod service; -pub mod skills; +pub(crate) mod security; +pub(crate) mod service; +pub(crate) mod skills; pub mod tools; -pub mod tunnel; -pub mod util; +pub(crate) mod tunnel; +pub(crate) mod util; pub use config::Config; /// Service management subcommands #[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum ServiceCommands { +pub(crate) enum ServiceCommands { /// Install daemon service unit for auto-start and restart Install, /// Start daemon service Start, /// Stop 
daemon service Stop, + /// Restart daemon service to apply latest config + Restart, /// Check daemon service status Status, /// Uninstall daemon service unit @@ -87,7 +90,7 @@ pub enum ServiceCommands { /// Channel management subcommands #[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum ChannelCommands { +pub(crate) enum ChannelCommands { /// List all configured channels List, /// Start all configured channels (handled in main.rs for async) @@ -95,6 +98,17 @@ pub enum ChannelCommands { /// Run health checks for configured channels (handled in main.rs for async) Doctor, /// Add a new channel configuration + #[command(long_about = "\ +Add a new channel configuration. + +Provide the channel type and a JSON object with the required \ +configuration keys for that channel type. + +Supported types: telegram, discord, slack, whatsapp, matrix, imessage, email. + +Examples: + zeroclaw channel add telegram '{\"bot_token\":\"...\",\"name\":\"my-bot\"}' + zeroclaw channel add discord '{\"bot_token\":\"...\",\"name\":\"my-discord\"}'")] Add { /// Channel type (telegram, discord, slack, whatsapp, matrix, imessage, email) channel_type: String, @@ -107,6 +121,16 @@ pub enum ChannelCommands { name: String, }, /// Bind a Telegram identity (username or numeric user ID) into allowlist + #[command(long_about = "\ +Bind a Telegram identity into the allowlist. + +Adds a Telegram username (without the '@' prefix) or numeric user \ +ID to the channel allowlist so the agent will respond to messages \ +from that identity. 
+ +Examples: + zeroclaw channel bind-telegram zeroclaw_user + zeroclaw channel bind-telegram 123456789")] BindTelegram { /// Telegram identity to allow (username without '@' or numeric user ID) identity: String, @@ -115,12 +139,12 @@ pub enum ChannelCommands { /// Skills management subcommands #[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum SkillCommands { +pub(crate) enum SkillCommands { /// List all installed skills List, - /// Install a new skill from a URL or local path + /// Install a new skill from a git URL (HTTPS/SSH) or local path Install { - /// Source URL or local path + /// Source git URL (HTTPS/SSH) or local path source: String, }, /// Remove an installed skill @@ -132,7 +156,7 @@ pub enum SkillCommands { /// Migration subcommands #[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum MigrateCommands { +pub(crate) enum MigrateCommands { /// Import memory from an `OpenClaw` workspace into this `ZeroClaw` workspace Openclaw { /// Optional path to `OpenClaw` workspace (defaults to ~/.openclaw/workspace) @@ -147,10 +171,20 @@ pub enum MigrateCommands { /// Cron subcommands #[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum CronCommands { +pub(crate) enum CronCommands { /// List all scheduled tasks List, /// Add a new scheduled task + #[command(long_about = "\ +Add a new recurring scheduled task. + +Uses standard 5-field cron syntax: 'min hour day month weekday'. \ +Times are evaluated in UTC by default; use --tz with an IANA \ +timezone name to override. + +Examples: + zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York + zeroclaw cron add '*/30 * * * *' 'Check system health'")] Add { /// Cron expression expression: String, @@ -161,6 +195,14 @@ pub enum CronCommands { command: String, }, /// Add a one-shot scheduled task at an RFC3339 timestamp + #[command(long_about = "\ +Add a one-shot task that fires at a specific UTC timestamp. 
+ +The timestamp must be in RFC 3339 format (e.g. 2025-01-15T14:00:00Z). + +Examples: + zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder' + zeroclaw cron add-at 2025-12-31T23:59:00Z 'Happy New Year!'")] AddAt { /// One-shot timestamp in RFC3339 format at: String, @@ -168,6 +210,14 @@ pub enum CronCommands { command: String, }, /// Add a fixed-interval scheduled task + #[command(long_about = "\ +Add a task that repeats at a fixed interval. + +Interval is specified in milliseconds. For example, 60000 = 1 minute. + +Examples: + zeroclaw cron add-every 60000 'Ping heartbeat' # every minute + zeroclaw cron add-every 3600000 'Hourly report' # every hour")] AddEvery { /// Interval in milliseconds every_ms: u64, @@ -175,6 +225,16 @@ pub enum CronCommands { command: String, }, /// Add a one-shot delayed task (e.g. "30m", "2h", "1d") + #[command(long_about = "\ +Add a one-shot task that fires after a delay from now. + +Accepts human-readable durations: s (seconds), m (minutes), \ +h (hours), d (days). + +Examples: + zeroclaw cron once 30m 'Run backup in 30 minutes' + zeroclaw cron once 2h 'Follow up on deployment' + zeroclaw cron once 1d 'Daily check'")] Once { /// Delay duration delay: String, @@ -186,6 +246,32 @@ pub enum CronCommands { /// Task ID id: String, }, + /// Update a scheduled task + #[command(long_about = "\ +Update one or more fields of an existing scheduled task. + +Only the fields you specify are changed; others remain unchanged. 
+ +Examples: + zeroclaw cron update --expression '0 8 * * *' + zeroclaw cron update --tz Europe/London --name 'Morning check' + zeroclaw cron update --command 'Updated message'")] + Update { + /// Task ID + id: String, + /// New cron expression + #[arg(long)] + expression: Option, + /// New IANA timezone + #[arg(long)] + tz: Option, + /// New command to run + #[arg(long)] + command: Option, + /// New job name + #[arg(long)] + name: Option, + }, /// Pause a scheduled task Pause { /// Task ID @@ -200,7 +286,7 @@ pub enum CronCommands { /// Integration subcommands #[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub enum IntegrationCommands { +pub(crate) enum IntegrationCommands { /// Show details about a specific integration Info { /// Integration name @@ -212,13 +298,39 @@ pub enum IntegrationCommands { #[derive(Subcommand, Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] pub enum HardwareCommands { /// Enumerate USB devices (VID/PID) and show known boards + #[command(long_about = "\ +Enumerate USB devices and show known boards. + +Scans connected USB devices by VID/PID and matches them against \ +known development boards (STM32 Nucleo, Arduino, ESP32). + +Examples: + zeroclaw hardware discover")] Discover, /// Introspect a device by path (e.g. /dev/ttyACM0) + #[command(long_about = "\ +Introspect a device by its serial or device path. + +Opens the specified device path and queries for board information, \ +firmware version, and supported capabilities. + +Examples: + zeroclaw hardware introspect /dev/ttyACM0 + zeroclaw hardware introspect COM3")] Introspect { /// Serial or device path path: String, }, /// Get chip info via USB (probe-rs over ST-Link). No firmware needed on target. + #[command(long_about = "\ +Get chip info via USB using probe-rs over ST-Link. + +Queries the target MCU directly through the debug probe without \ +requiring any firmware on the target board. 
+ +Examples: + zeroclaw hardware info + zeroclaw hardware info --chip STM32F401RETx")] Info { /// Chip name (e.g. STM32F401RETx). Default: STM32F401RETx for Nucleo-F401RE #[arg(long, default_value = "STM32F401RETx")] @@ -232,6 +344,19 @@ pub enum PeripheralCommands { /// List configured peripherals List, /// Add a peripheral (board path, e.g. nucleo-f401re /dev/ttyACM0) + #[command(long_about = "\ +Add a peripheral by board type and transport path. + +Registers a hardware board so the agent can use its tools (GPIO, \ +sensors, actuators). Use 'native' as path for local GPIO on \ +single-board computers like Raspberry Pi. + +Supported boards: nucleo-f401re, rpi-gpio, esp32, arduino-uno. + +Examples: + zeroclaw peripheral add nucleo-f401re /dev/ttyACM0 + zeroclaw peripheral add rpi-gpio native + zeroclaw peripheral add esp32 /dev/ttyUSB0")] Add { /// Board type (nucleo-f401re, rpi-gpio, esp32) board: String, @@ -239,6 +364,16 @@ pub enum PeripheralCommands { path: String, }, /// Flash ZeroClaw firmware to Arduino (creates .ino, installs arduino-cli if needed, uploads) + #[command(long_about = "\ +Flash ZeroClaw firmware to an Arduino board. + +Generates the .ino sketch, installs arduino-cli if it is not \ +already available, compiles, and uploads the firmware. + +Examples: + zeroclaw peripheral flash + zeroclaw peripheral flash --port /dev/cu.usbmodem12345 + zeroclaw peripheral flash -p COM3")] Flash { /// Serial port (e.g. /dev/cu.usbmodem12345). If omitted, uses first arduino-uno from config. 
#[arg(short, long)] diff --git a/src/main.rs b/src/main.rs index b7e1f66..488f8ae 100644 --- a/src/main.rs +++ b/src/main.rs @@ -39,6 +39,14 @@ use serde::{Deserialize, Serialize}; use tracing::{info, warn}; use tracing_subscriber::{fmt, EnvFilter}; +fn parse_temperature(s: &str) -> std::result::Result { + let t: f64 = s.parse().map_err(|e| format!("{e}"))?; + if !(0.0..=2.0).contains(&t) { + return Err("temperature must be between 0.0 and 2.0".to_string()); + } + Ok(t) +} + mod agent; mod approval; mod auth; @@ -58,6 +66,7 @@ mod identity; mod integrations; mod memory; mod migration; +mod multimodal; mod observability; mod onboard; mod peripherals; @@ -95,6 +104,8 @@ enum ServiceCommands { Start, /// Stop daemon service Stop, + /// Restart daemon service to apply latest config + Restart, /// Check daemon service status Status, /// Uninstall daemon service unit @@ -120,13 +131,26 @@ enum Commands { /// Provider name (used in quick mode, default: openrouter) #[arg(long)] provider: Option, - + /// Model ID override (used in quick mode) + #[arg(long)] + model: Option, /// Memory backend (sqlite, lucid, markdown, none) - used in quick mode, default: sqlite #[arg(long)] memory: Option, }, /// Start the AI agent loop + #[command(long_about = "\ +Start the AI agent loop. + +Launches an interactive chat session with the configured AI provider. \ +Use --message for single-shot queries without entering interactive mode. 
+ +Examples: + zeroclaw agent # interactive session + zeroclaw agent -m \"Summarize today's logs\" # single message + zeroclaw agent -p anthropic --model claude-sonnet-4-20250514 + zeroclaw agent --peripheral nucleo-f401re:/dev/ttyACM0")] Agent { /// Single message mode (don't enter interactive mode) #[arg(short, long)] @@ -141,7 +165,7 @@ enum Commands { model: Option, /// Temperature (0.0 - 2.0) - #[arg(short, long, default_value = "0.7")] + #[arg(short, long, default_value = "0.7", value_parser = parse_temperature)] temperature: f64, /// Attach a peripheral (board:path, e.g. nucleo-f401re:/dev/ttyACM0) @@ -150,6 +174,18 @@ enum Commands { }, /// Start the gateway server (webhooks, websockets) + #[command(long_about = "\ +Start the gateway server (webhooks, websockets). + +Runs the HTTP/WebSocket gateway that accepts incoming webhook events \ +and WebSocket connections. Bind address defaults to the values in \ +your config file (gateway.host / gateway.port). + +Examples: + zeroclaw gateway # use config defaults + zeroclaw gateway -p 8080 # listen on port 8080 + zeroclaw gateway --host 0.0.0.0 # bind to all interfaces + zeroclaw gateway -p 0 # random available port")] Gateway { /// Port to listen on (use 0 for random available port); defaults to config gateway.port #[arg(short, long)] @@ -161,6 +197,21 @@ enum Commands { }, /// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler) + #[command(long_about = "\ +Start the long-running autonomous daemon. + +Launches the full ZeroClaw runtime: gateway server, all configured \ +channels (Telegram, Discord, Slack, etc.), heartbeat monitor, and \ +the cron scheduler. This is the recommended way to run ZeroClaw in \ +production or as an always-on assistant. + +Use 'zeroclaw service install' to register the daemon as an OS \ +service (systemd/launchd) for auto-start on boot. 
+ +Examples: + zeroclaw daemon # use config defaults + zeroclaw daemon -p 9090 # gateway on port 9090 + zeroclaw daemon --host 127.0.0.1 # localhost only")] Daemon { /// Port to listen on (use 0 for random available port); defaults to config gateway.port #[arg(short, long)] @@ -187,6 +238,25 @@ enum Commands { Status, /// Configure and manage scheduled tasks + #[command(long_about = "\ +Configure and manage scheduled tasks. + +Schedule recurring, one-shot, or interval-based tasks using cron \ +expressions, RFC 3339 timestamps, durations, or fixed intervals. + +Cron expressions use the standard 5-field format: \ +'min hour day month weekday'. Timezones default to UTC; \ +override with --tz and an IANA timezone name. + +Examples: + zeroclaw cron list + zeroclaw cron add '0 9 * * 1-5' 'Good morning' --tz America/New_York + zeroclaw cron add '*/30 * * * *' 'Check system health' + zeroclaw cron add-at 2025-01-15T14:00:00Z 'Send reminder' + zeroclaw cron add-every 60000 'Ping heartbeat' + zeroclaw cron once 30m 'Run backup in 30 minutes' + zeroclaw cron pause + zeroclaw cron update --expression '0 8 * * *' --tz Europe/London")] Cron { #[command(subcommand)] cron_command: CronCommands, @@ -202,6 +272,19 @@ enum Commands { Providers, /// Manage channels (telegram, discord, slack) + #[command(long_about = "\ +Manage communication channels. + +Add, remove, list, and health-check channels that connect ZeroClaw \ +to messaging platforms. Supported channel types: telegram, discord, \ +slack, whatsapp, matrix, imessage, email. + +Examples: + zeroclaw channel list + zeroclaw channel doctor + zeroclaw channel add telegram '{\"bot_token\":\"...\",\"name\":\"my-bot\"}' + zeroclaw channel remove my-bot + zeroclaw channel bind-telegram zeroclaw_user")] Channel { #[command(subcommand)] channel_command: ChannelCommands, @@ -232,16 +315,62 @@ enum Commands { }, /// Discover and introspect USB hardware + #[command(long_about = "\ +Discover and introspect USB hardware. 
+ +Enumerate connected USB devices, identify known development boards \ +(STM32 Nucleo, Arduino, ESP32), and retrieve chip information via \ +probe-rs / ST-Link. + +Examples: + zeroclaw hardware discover + zeroclaw hardware introspect /dev/ttyACM0 + zeroclaw hardware info --chip STM32F401RETx")] Hardware { #[command(subcommand)] hardware_command: zeroclaw::HardwareCommands, }, /// Manage hardware peripherals (STM32, RPi GPIO, etc.) + #[command(long_about = "\ +Manage hardware peripherals. + +Add, list, flash, and configure hardware boards that expose tools \ +to the agent (GPIO, sensors, actuators). Supported boards: \ +nucleo-f401re, rpi-gpio, esp32, arduino-uno. + +Examples: + zeroclaw peripheral list + zeroclaw peripheral add nucleo-f401re /dev/ttyACM0 + zeroclaw peripheral add rpi-gpio native + zeroclaw peripheral flash --port /dev/cu.usbmodem12345 + zeroclaw peripheral flash-nucleo")] Peripheral { #[command(subcommand)] peripheral_command: zeroclaw::PeripheralCommands, }, + + /// Manage configuration + #[command(long_about = "\ +Manage ZeroClaw configuration. + +Inspect and export configuration settings. Use 'schema' to dump \ +the full JSON Schema for the config file, which documents every \ +available key, type, and default value. 
+ +Examples: + zeroclaw config schema # print JSON Schema to stdout + zeroclaw config schema > schema.json")] + Config { + #[command(subcommand)] + config_command: ConfigCommands, + }, +} + +#[derive(Subcommand, Debug)] +enum ConfigCommands { + /// Dump the full configuration JSON Schema to stdout + Schema, } #[derive(Subcommand, Debug)] @@ -381,6 +510,23 @@ enum CronCommands { /// Task ID id: String, }, + /// Update a scheduled task + Update { + /// Task ID + id: String, + /// New cron expression + #[arg(long)] + expression: Option, + /// New IANA timezone + #[arg(long)] + tz: Option, + /// New command to run + #[arg(long)] + command: Option, + /// New job name + #[arg(long)] + name: Option, + }, /// Pause a scheduled task Pause { /// Task ID @@ -452,9 +598,9 @@ enum ChannelCommands { enum SkillCommands { /// List installed skills List, - /// Install a skill from a GitHub URL or local path + /// Install a skill from a git URL (HTTPS/SSH) or local path Install { - /// GitHub URL or local path + /// Git URL (HTTPS/SSH) or local path source: String, }, /// Remove an installed skill @@ -503,6 +649,7 @@ async fn main() -> Result<()> { channels_only, api_key, provider, + model, memory, } = &cli.command { @@ -510,25 +657,30 @@ async fn main() -> Result<()> { let channels_only = *channels_only; let api_key = api_key.clone(); let provider = provider.clone(); + let model = model.clone(); let memory = memory.clone(); if interactive && channels_only { bail!("Use either --interactive or --channels-only, not both"); } - if channels_only && (api_key.is_some() || provider.is_some() || memory.is_some()) { - bail!("--channels-only does not accept --api-key, --provider, or --memory"); + if channels_only + && (api_key.is_some() || provider.is_some() || model.is_some() || memory.is_some()) + { + bail!("--channels-only does not accept --api-key, --provider, --model, or --memory"); } - - let config = tokio::task::spawn_blocking(move || { - if channels_only { - 
onboard::run_channels_repair_wizard() - } else if interactive { - onboard::run_wizard() - } else { - onboard::run_quick_setup(api_key.as_deref(), provider.as_deref(), memory.as_deref()) - } - }) - .await??; + let config = if channels_only { + onboard::run_channels_repair_wizard().await + } else if interactive { + onboard::run_wizard().await + } else { + onboard::run_quick_setup( + api_key.as_deref(), + provider.as_deref(), + model.as_deref(), + memory.as_deref(), + ) + .await + }?; // Auto-start channels if user said yes during wizard if std::env::var("ZEROCLAW_AUTOSTART_CHANNELS").as_deref() == Ok("1") { channels::start_channels(config).await?; @@ -537,7 +689,7 @@ async fn main() -> Result<()> { } // All other commands need config loaded first - let mut config = Config::load_or_init()?; + let mut config = Config::load_or_init().await?; config.apply_env_overrides(); match cli.command { @@ -725,16 +877,14 @@ async fn main() -> Result<()> { Commands::Channel { channel_command } => match channel_command { ChannelCommands::Start => channels::start_channels(config).await, ChannelCommands::Doctor => channels::doctor_channels(config).await, - other => channels::handle_command(other, &config), + other => channels::handle_command(other, &config).await, }, Commands::Integrations { integration_command, } => integrations::handle_command(integration_command, &config), - Commands::Skills { skill_command } => { - skills::handle_command(skill_command, &config.workspace_dir) - } + Commands::Skills { skill_command } => skills::handle_command(skill_command, &config), Commands::Migrate { migrate_command } => { migration::handle_command(migrate_command, &config).await @@ -747,8 +897,19 @@ async fn main() -> Result<()> { } Commands::Peripheral { peripheral_command } => { - peripherals::handle_command(peripheral_command.clone(), &config) + peripherals::handle_command(peripheral_command.clone(), &config).await } + + Commands::Config { config_command } => match config_command { + 
ConfigCommands::Schema => { + let schema = schemars::schema_for!(config::Config); + println!( + "{}", + serde_json::to_string_pretty(&schema).expect("failed to serialize JSON Schema") + ); + Ok(()) + } + }, } } @@ -934,12 +1095,11 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res let account_id = extract_openai_account_id_for_profile(&token_set.access_token); - let saved = auth_service - .store_openai_tokens(&profile, token_set, account_id, true)?; + auth_service.store_openai_tokens(&profile, token_set, account_id, true)?; clear_pending_openai_login(config); - println!("Saved profile {}", saved.id); - println!("Active profile for openai-codex: {}", saved.id); + println!("Saved profile {profile}"); + println!("Active profile for openai-codex: {profile}"); return Ok(()); } Err(e) => { @@ -985,11 +1145,11 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res auth::openai_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; let account_id = extract_openai_account_id_for_profile(&token_set.access_token); - let saved = auth_service.store_openai_tokens(&profile, token_set, account_id, true)?; + auth_service.store_openai_tokens(&profile, token_set, account_id, true)?; clear_pending_openai_login(config); - println!("Saved profile {}", saved.id); - println!("Active profile for openai-codex: {}", saved.id); + println!("Saved profile {profile}"); + println!("Active profile for openai-codex: {profile}"); Ok(()) } @@ -1038,11 +1198,11 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res auth::openai_oauth::exchange_code_for_tokens(&client, &code, &pkce).await?; let account_id = extract_openai_account_id_for_profile(&token_set.access_token); - let saved = auth_service.store_openai_tokens(&profile, token_set, account_id, true)?; + auth_service.store_openai_tokens(&profile, token_set, account_id, true)?; clear_pending_openai_login(config); - println!("Saved profile {}", 
saved.id); - println!("Active profile for openai-codex: {}", saved.id); + println!("Saved profile {profile}"); + println!("Active profile for openai-codex: {profile}"); Ok(()) } @@ -1068,10 +1228,9 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res kind.as_metadata_value().to_string(), ); - let saved = - auth_service.store_provider_token(&provider, &profile, &token, metadata, true)?; - println!("Saved profile {}", saved.id); - println!("Active profile for {provider}: {}", saved.id); + auth_service.store_provider_token(&provider, &profile, &token, metadata, true)?; + println!("Saved profile {profile}"); + println!("Active profile for {provider}: {profile}"); Ok(()) } @@ -1089,10 +1248,9 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res kind.as_metadata_value().to_string(), ); - let saved = - auth_service.store_provider_token(&provider, &profile, &token, metadata, true)?; - println!("Saved profile {}", saved.id); - println!("Active profile for {provider}: {}", saved.id); + auth_service.store_provider_token(&provider, &profile, &token, metadata, true)?; + println!("Saved profile {profile}"); + println!("Active profile for {provider}: {profile}"); Ok(()) } @@ -1131,8 +1289,8 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res AuthCommands::Use { provider, profile } => { let provider = auth::normalize_provider(&provider)?; - let active = auth_service.set_active_profile(&provider, &profile)?; - println!("Active profile for {provider}: {active}"); + auth_service.set_active_profile(&provider, &profile)?; + println!("Active profile for {provider}: {profile}"); Ok(()) } @@ -1173,15 +1331,15 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res marker, id, profile.kind, - profile.account_id.as_deref().unwrap_or("unknown"), + crate::security::redact(profile.account_id.as_deref().unwrap_or("unknown")), format_expiry(profile) ); } println!(); 
println!("Active profiles:"); - for (provider, active) in &data.active_profiles { - println!(" {provider}: {active}"); + for (provider, profile_id) in &data.active_profiles { + println!(" {provider}: {profile_id}"); } Ok(()) @@ -1192,10 +1350,61 @@ async fn handle_auth_command(auth_command: AuthCommands, config: &Config) -> Res #[cfg(test)] mod tests { use super::*; - use clap::CommandFactory; + use clap::{CommandFactory, Parser}; #[test] fn cli_definition_has_no_flag_conflicts() { Cli::command().debug_assert(); } + + #[test] + fn onboard_help_includes_model_flag() { + let cmd = Cli::command(); + let onboard = cmd + .get_subcommands() + .find(|subcommand| subcommand.get_name() == "onboard") + .expect("onboard subcommand must exist"); + + let has_model_flag = onboard + .get_arguments() + .any(|arg| arg.get_id().as_str() == "model" && arg.get_long() == Some("model")); + + assert!( + has_model_flag, + "onboard help should include --model for quick setup overrides" + ); + } + + #[test] + fn onboard_cli_accepts_model_provider_and_api_key_in_quick_mode() { + let cli = Cli::try_parse_from([ + "zeroclaw", + "onboard", + "--provider", + "openrouter", + "--model", + "custom-model-946", + "--api-key", + "sk-issue946", + ]) + .expect("quick onboard invocation should parse"); + + match cli.command { + Commands::Onboard { + interactive, + channels_only, + api_key, + provider, + model, + .. + } => { + assert!(!interactive); + assert!(!channels_only); + assert_eq!(provider.as_deref(), Some("openrouter")); + assert_eq!(model.as_deref(), Some("custom-model-946")); + assert_eq!(api_key.as_deref(), Some("sk-issue946")); + } + other => panic!("expected onboard command, got {other:?}"), + } + } } diff --git a/src/memory/chunker.rs b/src/memory/chunker.rs index 97bddfa..590079a 100644 --- a/src/memory/chunker.rs +++ b/src/memory/chunker.rs @@ -3,12 +3,14 @@ // Splits on markdown headings and paragraph boundaries, respecting // a max token limit per chunk. Preserves heading context. 
+use std::rc::Rc; + /// A single chunk of text with metadata. #[derive(Debug, Clone)] pub struct Chunk { pub index: usize, pub content: String, - pub heading: Option, + pub heading: Option>, } /// Split markdown text into chunks, each under `max_tokens` approximate tokens. @@ -26,9 +28,10 @@ pub fn chunk_markdown(text: &str, max_tokens: usize) -> Vec { let max_chars = max_tokens * 4; let sections = split_on_headings(text); - let mut chunks = Vec::new(); + let mut chunks = Vec::with_capacity(sections.len()); for (heading, body) in sections { + let heading: Option> = heading.map(Rc::from); let full = if let Some(ref h) = heading { format!("{h}\n{body}") } else { @@ -45,7 +48,7 @@ pub fn chunk_markdown(text: &str, max_tokens: usize) -> Vec { // Split on paragraphs (blank lines) let paragraphs = split_on_blank_lines(&body); let mut current = heading - .as_ref() + .as_deref() .map_or_else(String::new, |h| format!("{h}\n")); for para in paragraphs { @@ -56,7 +59,7 @@ pub fn chunk_markdown(text: &str, max_tokens: usize) -> Vec { heading: heading.clone(), }); current = heading - .as_ref() + .as_deref() .map_or_else(String::new, |h| format!("{h}\n")); } @@ -69,7 +72,7 @@ pub fn chunk_markdown(text: &str, max_tokens: usize) -> Vec { heading: heading.clone(), }); current = heading - .as_ref() + .as_deref() .map_or_else(String::new, |h| format!("{h}\n")); } for line_chunk in split_on_lines(¶, max_chars) { @@ -115,8 +118,7 @@ fn split_on_headings(text: &str) -> Vec<(Option, String)> { for line in text.lines() { if line.starts_with("# ") || line.starts_with("## ") || line.starts_with("### ") { if !current_body.trim().is_empty() || current_heading.is_some() { - sections.push((current_heading.take(), current_body.clone())); - current_body.clear(); + sections.push((current_heading.take(), std::mem::take(&mut current_body))); } current_heading = Some(line.to_string()); } else { @@ -140,8 +142,7 @@ fn split_on_blank_lines(text: &str) -> Vec { for line in text.lines() { if 
line.trim().is_empty() { if !current.trim().is_empty() { - paragraphs.push(current.clone()); - current.clear(); + paragraphs.push(std::mem::take(&mut current)); } } else { current.push_str(line); @@ -158,13 +159,12 @@ fn split_on_blank_lines(text: &str) -> Vec { /// Split text on line boundaries to fit within `max_chars` fn split_on_lines(text: &str, max_chars: usize) -> Vec { - let mut chunks = Vec::new(); + let mut chunks = Vec::with_capacity(text.len() / max_chars.max(1) + 1); let mut current = String::new(); for line in text.lines() { if current.len() + line.len() + 1 > max_chars && !current.is_empty() { - chunks.push(current.clone()); - current.clear(); + chunks.push(std::mem::take(&mut current)); } current.push_str(line); current.push('\n'); diff --git a/src/memory/embeddings.rs b/src/memory/embeddings.rs index 058d077..4557ed4 100644 --- a/src/memory/embeddings.rs +++ b/src/memory/embeddings.rs @@ -172,6 +172,15 @@ pub fn create_embedding_provider( dims, )) } + "openrouter" => { + let key = api_key.unwrap_or(""); + Box::new(OpenAiEmbedding::new( + "https://openrouter.ai/api/v1", + key, + model, + dims, + )) + } name if name.starts_with("custom:") => { let base_url = name.strip_prefix("custom:").unwrap_or(""); let key = api_key.unwrap_or(""); @@ -212,6 +221,18 @@ mod tests { assert_eq!(p.dimensions(), 1536); } + #[test] + fn factory_openrouter() { + let p = create_embedding_provider( + "openrouter", + Some("sk-or-test"), + "openai/text-embedding-3-small", + 1536, + ); + assert_eq!(p.name(), "openai"); // uses OpenAiEmbedding internally + assert_eq!(p.dimensions(), 1536); + } + #[test] fn factory_custom_url() { let p = create_embedding_provider("custom:http://localhost:1234", None, "model", 768); @@ -281,6 +302,20 @@ mod tests { assert_eq!(p.dimensions(), 384); } + #[test] + fn embeddings_url_openrouter() { + let p = OpenAiEmbedding::new( + "https://openrouter.ai/api/v1", + "key", + "openai/text-embedding-3-small", + 1536, + ); + assert_eq!( + 
p.embeddings_url(), + "https://openrouter.ai/api/v1/embeddings" + ); + } + #[test] fn embeddings_url_standard_openai() { let p = OpenAiEmbedding::new("https://api.openai.com", "key", "model", 1536); diff --git a/src/memory/lucid.rs b/src/memory/lucid.rs index 62af08f..763f5f7 100644 --- a/src/memory/lucid.rs +++ b/src/memory/lucid.rs @@ -608,7 +608,7 @@ exit 1 .iter() .any(|e| e.content.contains("Rust should stay local-first"))); - let context_calls = fs::read_to_string(&marker).unwrap_or_default(); + let context_calls = tokio::fs::read_to_string(&marker).await.unwrap_or_default(); assert!( context_calls.trim().is_empty(), "Expected local-hit short-circuit; got calls: {context_calls}" @@ -669,7 +669,7 @@ exit 1 assert!(first.is_empty()); assert!(second.is_empty()); - let calls = fs::read_to_string(&marker).unwrap_or_default(); + let calls = tokio::fs::read_to_string(&marker).await.unwrap_or_default(); assert_eq!(calls.lines().count(), 1); } } diff --git a/src/memory/markdown.rs b/src/memory/markdown.rs index 9038683..5bc093f 100644 --- a/src/memory/markdown.rs +++ b/src/memory/markdown.rs @@ -229,7 +229,6 @@ impl Memory for MarkdownMemory { #[cfg(test)] mod tests { use super::*; - use std::fs as sync_fs; use tempfile::TempDir; fn temp_workspace() -> (TempDir, MarkdownMemory) { @@ -256,7 +255,7 @@ mod tests { mem.store("pref", "User likes Rust", MemoryCategory::Core, None) .await .unwrap(); - let content = sync_fs::read_to_string(mem.core_path()).unwrap(); + let content = fs::read_to_string(mem.core_path()).await.unwrap(); assert!(content.contains("User likes Rust")); } @@ -267,7 +266,7 @@ mod tests { .await .unwrap(); let path = mem.daily_path(); - let content = sync_fs::read_to_string(path).unwrap(); + let content = fs::read_to_string(path).await.unwrap(); assert!(content.contains("Finished tests")); } diff --git a/src/memory/mod.rs b/src/memory/mod.rs index b4ea5e7..f60d926 100644 --- a/src/memory/mod.rs +++ b/src/memory/mod.rs @@ -27,7 +27,7 @@ pub use 
traits::Memory; #[allow(unused_imports)] pub use traits::{MemoryCategory, MemoryEntry}; -use crate::config::{MemoryConfig, StorageProviderConfig}; +use crate::config::{EmbeddingRouteConfig, MemoryConfig, StorageProviderConfig}; use anyhow::Context; use std::path::Path; use std::sync::Arc; @@ -75,13 +75,101 @@ pub fn effective_memory_backend_name( memory_backend.trim().to_ascii_lowercase() } +/// Legacy auto-save key used for model-authored assistant summaries. +/// These entries are treated as untrusted context and should not be re-injected. +pub fn is_assistant_autosave_key(key: &str) -> bool { + let normalized = key.trim().to_ascii_lowercase(); + normalized == "assistant_resp" || normalized.starts_with("assistant_resp_") +} + +#[derive(Clone, PartialEq, Eq)] +struct ResolvedEmbeddingConfig { + provider: String, + model: String, + dimensions: usize, + api_key: Option, +} + +impl std::fmt::Debug for ResolvedEmbeddingConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ResolvedEmbeddingConfig") + .field("provider", &self.provider) + .field("model", &self.model) + .field("dimensions", &self.dimensions) + .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]")) + .finish() + } +} + +fn resolve_embedding_config( + config: &MemoryConfig, + embedding_routes: &[EmbeddingRouteConfig], + api_key: Option<&str>, +) -> ResolvedEmbeddingConfig { + let fallback_api_key = api_key + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(str::to_string); + let fallback = ResolvedEmbeddingConfig { + provider: config.embedding_provider.trim().to_string(), + model: config.embedding_model.trim().to_string(), + dimensions: config.embedding_dimensions, + api_key: fallback_api_key.clone(), + }; + + let Some(hint) = config + .embedding_model + .strip_prefix("hint:") + .map(str::trim) + .filter(|value| !value.is_empty()) + else { + return fallback; + }; + + let Some(route) = embedding_routes + .iter() + .find(|route| route.hint.trim() 
== hint) + else { + tracing::warn!( + hint, + "Unknown embedding route hint; falling back to [memory] embedding settings" + ); + return fallback; + }; + + let provider = route.provider.trim(); + let model = route.model.trim(); + let dimensions = route.dimensions.unwrap_or(config.embedding_dimensions); + if provider.is_empty() || model.is_empty() || dimensions == 0 { + tracing::warn!( + hint, + "Invalid embedding route configuration; falling back to [memory] embedding settings" + ); + return fallback; + } + + let routed_api_key = route + .api_key + .as_deref() + .map(str::trim) + .filter(|value: &&str| !value.is_empty()) + .map(|value| value.to_string()); + + ResolvedEmbeddingConfig { + provider: provider.to_string(), + model: model.to_string(), + dimensions, + api_key: routed_api_key.or(fallback_api_key), + } +} + /// Factory: create the right memory backend from config pub fn create_memory( config: &MemoryConfig, workspace_dir: &Path, api_key: Option<&str>, ) -> anyhow::Result> { - create_memory_with_storage(config, None, workspace_dir, api_key) + create_memory_with_storage_and_routes(config, &[], None, workspace_dir, api_key) } /// Factory: create memory with optional storage-provider override. @@ -90,9 +178,21 @@ pub fn create_memory_with_storage( storage_provider: Option<&StorageProviderConfig>, workspace_dir: &Path, api_key: Option<&str>, +) -> anyhow::Result> { + create_memory_with_storage_and_routes(config, &[], storage_provider, workspace_dir, api_key) +} + +/// Factory: create memory with optional storage-provider override and embedding routes. 
+pub fn create_memory_with_storage_and_routes( + config: &MemoryConfig, + embedding_routes: &[EmbeddingRouteConfig], + storage_provider: Option<&StorageProviderConfig>, + workspace_dir: &Path, + api_key: Option<&str>, ) -> anyhow::Result> { let backend_name = effective_memory_backend_name(&config.backend, storage_provider); let backend_kind = classify_memory_backend(&backend_name); + let resolved_embedding = resolve_embedding_config(config, embedding_routes, api_key); // Best-effort memory hygiene/retention pass (throttled by state file). if let Err(e) = hygiene::run_if_due(config, workspace_dir) { @@ -137,14 +237,14 @@ pub fn create_memory_with_storage( fn build_sqlite_memory( config: &MemoryConfig, workspace_dir: &Path, - api_key: Option<&str>, + resolved_embedding: &ResolvedEmbeddingConfig, ) -> anyhow::Result { let embedder: Arc = Arc::from(embeddings::create_embedding_provider( - &config.embedding_provider, - api_key, - &config.embedding_model, - config.embedding_dimensions, + &resolved_embedding.provider, + resolved_embedding.api_key.as_deref(), + &resolved_embedding.model, + resolved_embedding.dimensions, )); #[allow(clippy::cast_possible_truncation)] @@ -184,7 +284,7 @@ pub fn create_memory_with_storage( create_memory_with_builders( &backend_name, workspace_dir, - || build_sqlite_memory(config, workspace_dir, api_key), + || build_sqlite_memory(config, workspace_dir, &resolved_embedding), || build_postgres_memory(storage_provider), "", ) @@ -247,7 +347,7 @@ pub fn create_response_cache(config: &MemoryConfig, workspace_dir: &Path) -> Opt #[cfg(test)] mod tests { use super::*; - use crate::config::StorageProviderConfig; + use crate::config::{EmbeddingRouteConfig, StorageProviderConfig}; use tempfile::TempDir; #[test] @@ -261,6 +361,15 @@ mod tests { assert_eq!(mem.name(), "sqlite"); } + #[test] + fn assistant_autosave_key_detection_matches_legacy_patterns() { + assert!(is_assistant_autosave_key("assistant_resp")); + 
assert!(is_assistant_autosave_key("assistant_resp_1234")); + assert!(is_assistant_autosave_key("ASSISTANT_RESP_abcd")); + assert!(!is_assistant_autosave_key("assistant_response")); + assert!(!is_assistant_autosave_key("user_msg_1234")); + } + #[test] fn factory_markdown() { let tmp = TempDir::new().unwrap(); @@ -353,4 +462,102 @@ mod tests { .expect("postgres without db_url should be rejected"); assert!(error.to_string().contains("db_url")); } + + #[test] + fn resolve_embedding_config_uses_base_config_when_model_is_not_hint() { + let cfg = MemoryConfig { + embedding_provider: "openai".into(), + embedding_model: "text-embedding-3-small".into(), + embedding_dimensions: 1536, + ..MemoryConfig::default() + }; + + let resolved = resolve_embedding_config(&cfg, &[], Some("base-key")); + assert_eq!( + resolved, + ResolvedEmbeddingConfig { + provider: "openai".into(), + model: "text-embedding-3-small".into(), + dimensions: 1536, + api_key: Some("base-key".into()), + } + ); + } + + #[test] + fn resolve_embedding_config_uses_matching_route_with_api_key_override() { + let cfg = MemoryConfig { + embedding_provider: "none".into(), + embedding_model: "hint:semantic".into(), + embedding_dimensions: 1536, + ..MemoryConfig::default() + }; + let routes = vec![EmbeddingRouteConfig { + hint: "semantic".into(), + provider: "custom:https://api.example.com/v1".into(), + model: "custom-embed-v2".into(), + dimensions: Some(1024), + api_key: Some("route-key".into()), + }]; + + let resolved = resolve_embedding_config(&cfg, &routes, Some("base-key")); + assert_eq!( + resolved, + ResolvedEmbeddingConfig { + provider: "custom:https://api.example.com/v1".into(), + model: "custom-embed-v2".into(), + dimensions: 1024, + api_key: Some("route-key".into()), + } + ); + } + + #[test] + fn resolve_embedding_config_falls_back_when_hint_is_missing() { + let cfg = MemoryConfig { + embedding_provider: "openai".into(), + embedding_model: "hint:semantic".into(), + embedding_dimensions: 1536, + 
..MemoryConfig::default() + }; + + let resolved = resolve_embedding_config(&cfg, &[], Some("base-key")); + assert_eq!( + resolved, + ResolvedEmbeddingConfig { + provider: "openai".into(), + model: "hint:semantic".into(), + dimensions: 1536, + api_key: Some("base-key".into()), + } + ); + } + + #[test] + fn resolve_embedding_config_falls_back_when_route_is_invalid() { + let cfg = MemoryConfig { + embedding_provider: "openai".into(), + embedding_model: "hint:semantic".into(), + embedding_dimensions: 1536, + ..MemoryConfig::default() + }; + let routes = vec![EmbeddingRouteConfig { + hint: "semantic".into(), + provider: String::new(), + model: "text-embedding-3-small".into(), + dimensions: Some(0), + api_key: None, + }]; + + let resolved = resolve_embedding_config(&cfg, &routes, Some("base-key")); + assert_eq!( + resolved, + ResolvedEmbeddingConfig { + provider: "openai".into(), + model: "hint:semantic".into(), + dimensions: 1536, + api_key: Some("base-key".into()), + } + ); + } } diff --git a/src/memory/postgres.rs b/src/memory/postgres.rs index 4f21293..4382751 100644 --- a/src/memory/postgres.rs +++ b/src/memory/postgres.rs @@ -30,24 +30,16 @@ impl PostgresMemory { validate_identifier(schema, "storage schema")?; validate_identifier(table, "storage table")?; - let mut config: postgres::Config = db_url - .parse() - .context("invalid PostgreSQL connection URL")?; - - if let Some(timeout_secs) = connect_timeout_secs { - let bounded = timeout_secs.min(POSTGRES_CONNECT_TIMEOUT_CAP_SECS); - config.connect_timeout(Duration::from_secs(bounded)); - } - - let mut client = config - .connect(NoTls) - .context("failed to connect to PostgreSQL memory backend")?; - let schema_ident = quote_identifier(schema); let table_ident = quote_identifier(table); let qualified_table = format!("{schema_ident}.{table_ident}"); - Self::init_schema(&mut client, &schema_ident, &qualified_table)?; + let client = Self::initialize_client( + db_url.to_string(), + connect_timeout_secs, + 
schema_ident.clone(), + qualified_table.clone(), + )?; Ok(Self { client: Arc::new(Mutex::new(client)), @@ -55,6 +47,40 @@ impl PostgresMemory { }) } + fn initialize_client( + db_url: String, + connect_timeout_secs: Option, + schema_ident: String, + qualified_table: String, + ) -> Result { + let init_handle = std::thread::Builder::new() + .name("postgres-memory-init".to_string()) + .spawn(move || -> Result { + let mut config: postgres::Config = db_url + .parse() + .context("invalid PostgreSQL connection URL")?; + + if let Some(timeout_secs) = connect_timeout_secs { + let bounded = timeout_secs.min(POSTGRES_CONNECT_TIMEOUT_CAP_SECS); + config.connect_timeout(Duration::from_secs(bounded)); + } + + let mut client = config + .connect(NoTls) + .context("failed to connect to PostgreSQL memory backend")?; + + Self::init_schema(&mut client, &schema_ident, &qualified_table)?; + Ok(client) + }) + .context("failed to spawn PostgreSQL initializer thread")?; + + let init_result = init_handle + .join() + .map_err(|_| anyhow::anyhow!("PostgreSQL initializer thread panicked"))?; + + init_result + } + fn init_schema(client: &mut Client, schema_ident: &str, qualified_table: &str) -> Result<()> { client.batch_execute(&format!( " @@ -157,7 +183,7 @@ impl Memory for PostgresMemory { let key = key.to_string(); let content = content.to_string(); let category = Self::category_to_str(&category); - let session_id = session_id.map(str::to_string); + let sid = session_id.map(str::to_string); tokio::task::spawn_blocking(move || -> Result<()> { let now = Utc::now(); @@ -177,10 +203,7 @@ impl Memory for PostgresMemory { ); let id = Uuid::new_v4().to_string(); - client.execute( - &stmt, - &[&id, &key, &content, &category, &now, &now, &session_id], - )?; + client.execute(&stmt, &[&id, &key, &content, &category, &now, &now, &sid])?; Ok(()) }) .await? 
@@ -195,7 +218,7 @@ impl Memory for PostgresMemory { let client = self.client.clone(); let qualified_table = self.qualified_table.clone(); let query = query.trim().to_string(); - let session_id = session_id.map(str::to_string); + let sid = session_id.map(str::to_string); tokio::task::spawn_blocking(move || -> Result> { let mut client = client.lock(); @@ -217,7 +240,7 @@ impl Memory for PostgresMemory { #[allow(clippy::cast_possible_wrap)] let limit_i64 = limit as i64; - let rows = client.query(&stmt, &[&query, &session_id, &limit_i64])?; + let rows = client.query(&stmt, &[&query, &sid, &limit_i64])?; rows.iter() .map(Self::row_to_entry) .collect::>>() @@ -255,7 +278,7 @@ impl Memory for PostgresMemory { let client = self.client.clone(); let qualified_table = self.qualified_table.clone(); let category = category.map(Self::category_to_str); - let session_id = session_id.map(str::to_string); + let sid = session_id.map(str::to_string); tokio::task::spawn_blocking(move || -> Result> { let mut client = client.lock(); @@ -270,7 +293,7 @@ impl Memory for PostgresMemory { ); let category_ref = category.as_deref(); - let session_ref = session_id.as_deref(); + let session_ref = sid.as_deref(); let rows = client.query(&stmt, &[&category_ref, &session_ref])?; rows.iter() .map(Self::row_to_entry) @@ -349,4 +372,22 @@ mod tests { MemoryCategory::Custom("custom_notes".into()) ); } + + #[tokio::test(flavor = "current_thread")] + async fn new_does_not_panic_inside_tokio_runtime() { + let outcome = std::panic::catch_unwind(|| { + PostgresMemory::new( + "postgres://zeroclaw:password@127.0.0.1:1/zeroclaw", + "public", + "memories", + Some(1), + ) + }); + + assert!(outcome.is_ok(), "PostgresMemory::new should not panic"); + assert!( + outcome.unwrap().is_err(), + "PostgresMemory::new should return a connect error for an unreachable endpoint" + ); + } } diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index f0d0bd1..3e90ec6 100644 --- a/src/memory/sqlite.rs +++ 
b/src/memory/sqlite.rs @@ -452,7 +452,7 @@ impl Memory for SqliteMemory { let conn = self.conn.clone(); let key = key.to_string(); let content = content.to_string(); - let session_id = session_id.map(String::from); + let sid = session_id.map(String::from); tokio::task::spawn_blocking(move || -> anyhow::Result<()> { let conn = conn.lock(); @@ -469,7 +469,7 @@ impl Memory for SqliteMemory { embedding = excluded.embedding, updated_at = excluded.updated_at, session_id = excluded.session_id", - params![id, key, content, cat, embedding_bytes, now, now, session_id], + params![id, key, content, cat, embedding_bytes, now, now, sid], )?; Ok(()) }) @@ -491,13 +491,13 @@ impl Memory for SqliteMemory { let conn = self.conn.clone(); let query = query.to_string(); - let session_id = session_id.map(String::from); + let sid = session_id.map(String::from); let vector_weight = self.vector_weight; let keyword_weight = self.keyword_weight; tokio::task::spawn_blocking(move || -> anyhow::Result> { let conn = conn.lock(); - let session_ref = session_id.as_deref(); + let session_ref = sid.as_deref(); // FTS5 BM25 keyword search let keyword_results = Self::fts5_search(&conn, &query, limit * 2).unwrap_or_default(); @@ -691,11 +691,11 @@ impl Memory for SqliteMemory { let conn = self.conn.clone(); let category = category.cloned(); - let session_id = session_id.map(String::from); + let sid = session_id.map(String::from); tokio::task::spawn_blocking(move || -> anyhow::Result> { let conn = conn.lock(); - let session_ref = session_id.as_deref(); + let session_ref = sid.as_deref(); let mut results = Vec::new(); let row_mapper = |row: &rusqlite::Row| -> rusqlite::Result { diff --git a/src/multimodal.rs b/src/multimodal.rs new file mode 100644 index 0000000..bd15900 --- /dev/null +++ b/src/multimodal.rs @@ -0,0 +1,568 @@ +use crate::config::{build_runtime_proxy_client_with_timeouts, MultimodalConfig}; +use crate::providers::ChatMessage; +use base64::{engine::general_purpose::STANDARD, Engine as _}; 
+use reqwest::Client; +use std::path::Path; + +const IMAGE_MARKER_PREFIX: &str = "[IMAGE:"; +const ALLOWED_IMAGE_MIME_TYPES: &[&str] = &[ + "image/png", + "image/jpeg", + "image/webp", + "image/gif", + "image/bmp", +]; + +#[derive(Debug, Clone)] +pub struct PreparedMessages { + pub messages: Vec, + pub contains_images: bool, +} + +#[derive(Debug, thiserror::Error)] +pub enum MultimodalError { + #[error("multimodal image limit exceeded: max_images={max_images}, found={found}")] + TooManyImages { max_images: usize, found: usize }, + + #[error("multimodal image size limit exceeded for '{input}': {size_bytes} bytes > {max_bytes} bytes")] + ImageTooLarge { + input: String, + size_bytes: usize, + max_bytes: usize, + }, + + #[error("multimodal image MIME type is not allowed for '{input}': {mime}")] + UnsupportedMime { input: String, mime: String }, + + #[error("multimodal remote image fetch is disabled for '{input}'")] + RemoteFetchDisabled { input: String }, + + #[error("multimodal image source not found or unreadable: '{input}'")] + ImageSourceNotFound { input: String }, + + #[error("invalid multimodal image marker '{input}': {reason}")] + InvalidMarker { input: String, reason: String }, + + #[error("failed to download remote image '{input}': {reason}")] + RemoteFetchFailed { input: String, reason: String }, + + #[error("failed to read local image '{input}': {reason}")] + LocalReadFailed { input: String, reason: String }, +} + +pub fn parse_image_markers(content: &str) -> (String, Vec) { + let mut refs = Vec::new(); + let mut cleaned = String::with_capacity(content.len()); + let mut cursor = 0usize; + + while let Some(rel_start) = content[cursor..].find(IMAGE_MARKER_PREFIX) { + let start = cursor + rel_start; + cleaned.push_str(&content[cursor..start]); + + let marker_start = start + IMAGE_MARKER_PREFIX.len(); + let Some(rel_end) = content[marker_start..].find(']') else { + cleaned.push_str(&content[start..]); + cursor = content.len(); + break; + }; + + let end = 
marker_start + rel_end; + let candidate = content[marker_start..end].trim(); + + if candidate.is_empty() { + cleaned.push_str(&content[start..=end]); + } else { + refs.push(candidate.to_string()); + } + + cursor = end + 1; + } + + if cursor < content.len() { + cleaned.push_str(&content[cursor..]); + } + + (cleaned.trim().to_string(), refs) +} + +pub fn count_image_markers(messages: &[ChatMessage]) -> usize { + messages + .iter() + .filter(|m| m.role == "user") + .map(|m| parse_image_markers(&m.content).1.len()) + .sum() +} + +pub fn contains_image_markers(messages: &[ChatMessage]) -> bool { + count_image_markers(messages) > 0 +} + +pub fn extract_ollama_image_payload(image_ref: &str) -> Option { + if image_ref.starts_with("data:") { + let comma_idx = image_ref.find(',')?; + let (_, payload) = image_ref.split_at(comma_idx + 1); + let payload = payload.trim(); + if payload.is_empty() { + None + } else { + Some(payload.to_string()) + } + } else { + Some(image_ref.trim().to_string()).filter(|value| !value.is_empty()) + } +} + +pub async fn prepare_messages_for_provider( + messages: &[ChatMessage], + config: &MultimodalConfig, +) -> anyhow::Result { + let (max_images, max_image_size_mb) = config.effective_limits(); + let max_bytes = max_image_size_mb.saturating_mul(1024 * 1024); + + let found_images = count_image_markers(messages); + if found_images > max_images { + return Err(MultimodalError::TooManyImages { + max_images, + found: found_images, + } + .into()); + } + + if found_images == 0 { + return Ok(PreparedMessages { + messages: messages.to_vec(), + contains_images: false, + }); + } + + let remote_client = build_runtime_proxy_client_with_timeouts("provider.ollama", 30, 10); + + let mut normalized_messages = Vec::with_capacity(messages.len()); + for message in messages { + if message.role != "user" { + normalized_messages.push(message.clone()); + continue; + } + + let (cleaned_text, refs) = parse_image_markers(&message.content); + if refs.is_empty() { + 
normalized_messages.push(message.clone()); + continue; + } + + let mut normalized_refs = Vec::with_capacity(refs.len()); + for reference in refs { + let data_uri = + normalize_image_reference(&reference, config, max_bytes, &remote_client).await?; + normalized_refs.push(data_uri); + } + + let content = compose_multimodal_message(&cleaned_text, &normalized_refs); + normalized_messages.push(ChatMessage { + role: message.role.clone(), + content, + }); + } + + Ok(PreparedMessages { + messages: normalized_messages, + contains_images: true, + }) +} + +fn compose_multimodal_message(text: &str, data_uris: &[String]) -> String { + let mut content = String::new(); + let trimmed = text.trim(); + + if !trimmed.is_empty() { + content.push_str(trimmed); + content.push_str("\n\n"); + } + + for (index, data_uri) in data_uris.iter().enumerate() { + if index > 0 { + content.push('\n'); + } + content.push_str(IMAGE_MARKER_PREFIX); + content.push_str(data_uri); + content.push(']'); + } + + content +} + +async fn normalize_image_reference( + source: &str, + config: &MultimodalConfig, + max_bytes: usize, + remote_client: &Client, +) -> anyhow::Result { + if source.starts_with("data:") { + return normalize_data_uri(source, max_bytes); + } + + if source.starts_with("http://") || source.starts_with("https://") { + if !config.allow_remote_fetch { + return Err(MultimodalError::RemoteFetchDisabled { + input: source.to_string(), + } + .into()); + } + + return normalize_remote_image(source, max_bytes, remote_client).await; + } + + normalize_local_image(source, max_bytes).await +} + +fn normalize_data_uri(source: &str, max_bytes: usize) -> anyhow::Result { + let Some(comma_idx) = source.find(',') else { + return Err(MultimodalError::InvalidMarker { + input: source.to_string(), + reason: "expected data URI payload".to_string(), + } + .into()); + }; + + let header = &source[..comma_idx]; + let payload = source[comma_idx + 1..].trim(); + + if !header.contains(";base64") { + return 
Err(MultimodalError::InvalidMarker { + input: source.to_string(), + reason: "only base64 data URIs are supported".to_string(), + } + .into()); + } + + let mime = header + .trim_start_matches("data:") + .split(';') + .next() + .unwrap_or_default() + .trim() + .to_ascii_lowercase(); + + validate_mime(source, &mime)?; + + let decoded = STANDARD + .decode(payload) + .map_err(|error| MultimodalError::InvalidMarker { + input: source.to_string(), + reason: format!("invalid base64 payload: {error}"), + })?; + + validate_size(source, decoded.len(), max_bytes)?; + + Ok(format!("data:{mime};base64,{}", STANDARD.encode(decoded))) +} + +async fn normalize_remote_image( + source: &str, + max_bytes: usize, + remote_client: &Client, +) -> anyhow::Result { + let response = remote_client.get(source).send().await.map_err(|error| { + MultimodalError::RemoteFetchFailed { + input: source.to_string(), + reason: error.to_string(), + } + })?; + + let status = response.status(); + if !status.is_success() { + return Err(MultimodalError::RemoteFetchFailed { + input: source.to_string(), + reason: format!("HTTP {status}"), + } + .into()); + } + + if let Some(content_length) = response.content_length() { + let content_length = content_length as usize; + validate_size(source, content_length, max_bytes)?; + } + + let content_type = response + .headers() + .get(reqwest::header::CONTENT_TYPE) + .and_then(|value| value.to_str().ok()) + .map(ToString::to_string); + + let bytes = response + .bytes() + .await + .map_err(|error| MultimodalError::RemoteFetchFailed { + input: source.to_string(), + reason: error.to_string(), + })?; + + validate_size(source, bytes.len(), max_bytes)?; + + let mime = detect_mime(None, bytes.as_ref(), content_type.as_deref()).ok_or_else(|| { + MultimodalError::UnsupportedMime { + input: source.to_string(), + mime: "unknown".to_string(), + } + })?; + + validate_mime(source, &mime)?; + + Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes))) +} + +async fn 
normalize_local_image(source: &str, max_bytes: usize) -> anyhow::Result { + let path = Path::new(source); + if !path.exists() || !path.is_file() { + return Err(MultimodalError::ImageSourceNotFound { + input: source.to_string(), + } + .into()); + } + + let metadata = + tokio::fs::metadata(path) + .await + .map_err(|error| MultimodalError::LocalReadFailed { + input: source.to_string(), + reason: error.to_string(), + })?; + + validate_size(source, metadata.len() as usize, max_bytes)?; + + let bytes = tokio::fs::read(path) + .await + .map_err(|error| MultimodalError::LocalReadFailed { + input: source.to_string(), + reason: error.to_string(), + })?; + + validate_size(source, bytes.len(), max_bytes)?; + + let mime = + detect_mime(Some(path), &bytes, None).ok_or_else(|| MultimodalError::UnsupportedMime { + input: source.to_string(), + mime: "unknown".to_string(), + })?; + + validate_mime(source, &mime)?; + + Ok(format!("data:{mime};base64,{}", STANDARD.encode(bytes))) +} + +fn validate_size(source: &str, size_bytes: usize, max_bytes: usize) -> anyhow::Result<()> { + if size_bytes > max_bytes { + return Err(MultimodalError::ImageTooLarge { + input: source.to_string(), + size_bytes, + max_bytes, + } + .into()); + } + + Ok(()) +} + +fn validate_mime(source: &str, mime: &str) -> anyhow::Result<()> { + if ALLOWED_IMAGE_MIME_TYPES + .iter() + .any(|allowed| *allowed == mime) + { + return Ok(()); + } + + Err(MultimodalError::UnsupportedMime { + input: source.to_string(), + mime: mime.to_string(), + } + .into()) +} + +fn detect_mime( + path: Option<&Path>, + bytes: &[u8], + header_content_type: Option<&str>, +) -> Option { + if let Some(header_mime) = header_content_type.and_then(normalize_content_type) { + return Some(header_mime); + } + + if let Some(path) = path { + if let Some(ext) = path.extension().and_then(|value| value.to_str()) { + if let Some(mime) = mime_from_extension(ext) { + return Some(mime.to_string()); + } + } + } + + 
mime_from_magic(bytes).map(ToString::to_string) +} + +fn normalize_content_type(content_type: &str) -> Option { + let mime = content_type.split(';').next()?.trim().to_ascii_lowercase(); + if mime.is_empty() { + None + } else { + Some(mime) + } +} + +fn mime_from_extension(ext: &str) -> Option<&'static str> { + match ext.to_ascii_lowercase().as_str() { + "png" => Some("image/png"), + "jpg" | "jpeg" => Some("image/jpeg"), + "webp" => Some("image/webp"), + "gif" => Some("image/gif"), + "bmp" => Some("image/bmp"), + _ => None, + } +} + +fn mime_from_magic(bytes: &[u8]) -> Option<&'static str> { + if bytes.len() >= 8 && bytes.starts_with(&[0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n']) { + return Some("image/png"); + } + + if bytes.len() >= 3 && bytes.starts_with(&[0xff, 0xd8, 0xff]) { + return Some("image/jpeg"); + } + + if bytes.len() >= 6 && (bytes.starts_with(b"GIF87a") || bytes.starts_with(b"GIF89a")) { + return Some("image/gif"); + } + + if bytes.len() >= 12 && bytes.starts_with(b"RIFF") && &bytes[8..12] == b"WEBP" { + return Some("image/webp"); + } + + if bytes.len() >= 2 && bytes.starts_with(b"BM") { + return Some("image/bmp"); + } + + None +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_image_markers_extracts_multiple_markers() { + let input = "Check this [IMAGE:/tmp/a.png] and this [IMAGE:https://example.com/b.jpg]"; + let (cleaned, refs) = parse_image_markers(input); + + assert_eq!(cleaned, "Check this and this"); + assert_eq!(refs.len(), 2); + assert_eq!(refs[0], "/tmp/a.png"); + assert_eq!(refs[1], "https://example.com/b.jpg"); + } + + #[test] + fn parse_image_markers_keeps_invalid_empty_marker() { + let input = "hello [IMAGE:] world"; + let (cleaned, refs) = parse_image_markers(input); + + assert_eq!(cleaned, "hello [IMAGE:] world"); + assert!(refs.is_empty()); + } + + #[tokio::test] + async fn prepare_messages_normalizes_local_image_to_data_uri() { + let temp = tempfile::tempdir().unwrap(); + let image_path = 
temp.path().join("sample.png"); + + // Minimal PNG signature bytes are enough for MIME detection. + std::fs::write( + &image_path, + [0x89, b'P', b'N', b'G', b'\r', b'\n', 0x1a, b'\n'], + ) + .unwrap(); + + let messages = vec![ChatMessage::user(format!( + "Please inspect this screenshot [IMAGE:{}]", + image_path.display() + ))]; + + let prepared = prepare_messages_for_provider(&messages, &MultimodalConfig::default()) + .await + .unwrap(); + + assert!(prepared.contains_images); + assert_eq!(prepared.messages.len(), 1); + + let (cleaned, refs) = parse_image_markers(&prepared.messages[0].content); + assert_eq!(cleaned, "Please inspect this screenshot"); + assert_eq!(refs.len(), 1); + assert!(refs[0].starts_with("data:image/png;base64,")); + } + + #[tokio::test] + async fn prepare_messages_rejects_too_many_images() { + let messages = vec![ChatMessage::user( + "[IMAGE:/tmp/1.png]\n[IMAGE:/tmp/2.png]".to_string(), + )]; + + let config = MultimodalConfig { + max_images: 1, + max_image_size_mb: 5, + allow_remote_fetch: false, + }; + + let error = prepare_messages_for_provider(&messages, &config) + .await + .expect_err("should reject image count overflow"); + + assert!(error + .to_string() + .contains("multimodal image limit exceeded")); + } + + #[tokio::test] + async fn prepare_messages_rejects_remote_url_when_disabled() { + let messages = vec![ChatMessage::user( + "Look [IMAGE:https://example.com/img.png]".to_string(), + )]; + + let error = prepare_messages_for_provider(&messages, &MultimodalConfig::default()) + .await + .expect_err("should reject remote image URL when fetch is disabled"); + + assert!(error + .to_string() + .contains("multimodal remote image fetch is disabled")); + } + + #[tokio::test] + async fn prepare_messages_rejects_oversized_local_image() { + let temp = tempfile::tempdir().unwrap(); + let image_path = temp.path().join("big.png"); + + let bytes = vec![0u8; 1024 * 1024 + 1]; + std::fs::write(&image_path, bytes).unwrap(); + + let messages = 
vec![ChatMessage::user(format!( + "[IMAGE:{}]", + image_path.display() + ))]; + let config = MultimodalConfig { + max_images: 4, + max_image_size_mb: 1, + allow_remote_fetch: false, + }; + + let error = prepare_messages_for_provider(&messages, &config) + .await + .expect_err("should reject oversized local image"); + + assert!(error + .to_string() + .contains("multimodal image size limit exceeded")); + } + + #[test] + fn extract_ollama_image_payload_supports_data_uris() { + let payload = extract_ollama_image_payload("data:image/png;base64,abcd==") + .expect("payload should be extracted"); + assert_eq!(payload, "abcd=="); + } +} diff --git a/src/observability/traits.rs b/src/observability/traits.rs index ea5f5d1..0249938 100644 --- a/src/observability/traits.rs +++ b/src/observability/traits.rs @@ -1,12 +1,15 @@ use std::time::Duration; -/// Events the observer can record +/// Discrete events emitted by the agent runtime for observability. +/// +/// Each variant represents a lifecycle event that observers can record, +/// aggregate, or forward to external monitoring systems. Events carry +/// just enough context for tracing and diagnostics without exposing +/// sensitive prompt or response content. #[derive(Debug, Clone)] pub enum ObserverEvent { - AgentStart { - provider: String, - model: String, - }, + /// The agent orchestration loop has started a new session. + AgentStart { provider: String, model: String }, /// A request is about to be sent to an LLM provider. /// /// This is emitted immediately before a provider call so observers can print @@ -24,6 +27,9 @@ pub enum ObserverEvent { success: bool, error_message: Option, }, + /// The agent session has finished. + /// + /// Carries aggregate usage data (tokens, cost) when the provider reports it. AgentEnd { provider: String, model: String, @@ -32,9 +38,8 @@ pub enum ObserverEvent { cost_usd: Option, }, /// A tool call is about to be executed. 
- ToolCallStart { - tool: String, - }, + ToolCallStart { tool: String }, + /// A tool call has completed with a success/failure outcome. ToolCall { tool: String, duration: Duration, @@ -42,41 +47,80 @@ pub enum ObserverEvent { }, /// The agent produced a final answer for the current user message. TurnComplete, + /// A message was sent or received through a channel. ChannelMessage { + /// Channel name (e.g., `"telegram"`, `"discord"`). channel: String, + /// `"inbound"` or `"outbound"`. direction: String, }, + /// Periodic heartbeat tick from the runtime keep-alive loop. HeartbeatTick, + /// An error occurred in a named component. Error { + /// Subsystem where the error originated (e.g., `"provider"`, `"gateway"`). component: String, + /// Human-readable error description. Must not contain secrets or tokens. message: String, }, } -/// Numeric metrics +/// Numeric metrics emitted by the agent runtime. +/// +/// Observers can aggregate these into dashboards, alerts, or structured logs. +/// Each variant carries a single scalar value with implicit units. #[derive(Debug, Clone)] pub enum ObserverMetric { + /// Time elapsed for a single LLM or tool request. RequestLatency(Duration), + /// Number of tokens consumed by an LLM call. TokensUsed(u64), + /// Current number of active concurrent sessions. ActiveSessions(u64), + /// Current depth of the inbound message queue. QueueDepth(u64), } -/// Core observability trait — implement for any backend +/// Core observability trait for recording agent runtime telemetry. +/// +/// Implement this trait to integrate with any monitoring backend (structured +/// logging, Prometheus, OpenTelemetry, etc.). The agent runtime holds one or +/// more `Observer` instances and calls [`record_event`](Observer::record_event) +/// and [`record_metric`](Observer::record_metric) at key lifecycle points. +/// +/// Implementations must be `Send + Sync + 'static` because the observer is +/// shared across async tasks via `Arc`. 
pub trait Observer: Send + Sync + 'static { - /// Record a discrete event + /// Record a discrete lifecycle event. + /// + /// Called synchronously on the hot path; implementations should avoid + /// blocking I/O. Buffer events internally and flush asynchronously + /// when possible. fn record_event(&self, event: &ObserverEvent); - /// Record a numeric metric + /// Record a numeric metric sample. + /// + /// Called synchronously; same non-blocking guidance as + /// [`record_event`](Observer::record_event). fn record_metric(&self, metric: &ObserverMetric); - /// Flush any buffered data (no-op for most backends) + /// Flush any buffered telemetry data to the backend. + /// + /// The runtime calls this during graceful shutdown. The default + /// implementation is a no-op, which is appropriate for backends + /// that write synchronously. fn flush(&self) {} - /// Human-readable name of this observer + /// Return the human-readable name of this observer backend. + /// + /// Used in logs and diagnostics (e.g., `"console"`, `"prometheus"`, + /// `"opentelemetry"`). fn name(&self) -> &str; - /// Downcast to `Any` for backend-specific operations + /// Downcast to `Any` for backend-specific operations. + /// + /// Enables callers to access concrete observer types when needed + /// (e.g., retrieving a Prometheus registry handle for custom metrics). 
fn as_any(&self) -> &dyn std::any::Any; } diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index d7aed97..9ba0975 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -1,5 +1,5 @@ use crate::config::schema::{ - DingTalkConfig, IrcConfig, LarkReceiveMode, QQConfig, StreamMode, WhatsAppConfig, + DingTalkConfig, IrcConfig, LarkReceiveMode, LinqConfig, QQConfig, StreamMode, WhatsAppConfig, }; use crate::config::{ AutonomyConfig, BrowserConfig, ChannelsConfig, ComposioConfig, Config, DiscordConfig, @@ -12,7 +12,8 @@ use crate::memory::{ }; use crate::providers::{ canonical_china_provider_name, is_glm_alias, is_glm_cn_alias, is_minimax_alias, - is_moonshot_alias, is_qianfan_alias, is_qwen_alias, is_zai_alias, is_zai_cn_alias, + is_moonshot_alias, is_qianfan_alias, is_qwen_alias, is_qwen_oauth_alias, is_zai_alias, + is_zai_cn_alias, }; use anyhow::{bail, Context, Result}; use console::style; @@ -58,9 +59,46 @@ const MODEL_CACHE_FILE: &str = "models_cache.json"; const MODEL_CACHE_TTL_SECS: u64 = 12 * 60 * 60; const CUSTOM_MODEL_SENTINEL: &str = "__custom_model__"; +fn has_launchable_channels(channels: &ChannelsConfig) -> bool { + let ChannelsConfig { + cli: _, // `cli` is always available and does not require channel server startup + webhook: _, // webhook traffic is handled by gateway, not `zeroclaw channel start` + telegram, + discord, + slack, + mattermost, + imessage, + matrix, + signal, + whatsapp, + email, + irc, + lark, + dingtalk, + linq, + qq, + .. 
+ } = channels; + + telegram.is_some() + || discord.is_some() + || slack.is_some() + || mattermost.is_some() + || imessage.is_some() + || matrix.is_some() + || signal.is_some() + || whatsapp.is_some() + || email.is_some() + || irc.is_some() + || lark.is_some() + || dingtalk.is_some() + || linq.is_some() + || qq.is_some() +} + // ── Main wizard entry point ────────────────────────────────────── -pub fn run_wizard() -> Result { +pub async fn run_wizard() -> Result { println!("{}", style(BANNER).cyan().bold()); println!( @@ -122,7 +160,9 @@ pub fn run_wizard() -> Result { reliability: crate::config::ReliabilityConfig::default(), scheduler: crate::config::schema::SchedulerConfig::default(), agent: crate::config::schema::AgentConfig::default(), + skills: crate::config::SkillsConfig::default(), model_routes: Vec::new(), + embedding_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), cron: crate::config::CronConfig::default(), channels_config, @@ -134,6 +174,7 @@ pub fn run_wizard() -> Result { secrets: secrets_config, browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), + multimodal: crate::config::MultimodalConfig::default(), web_search: crate::config::WebSearchConfig::default(), proxy: crate::config::ProxyConfig::default(), identity: crate::config::IdentityConfig::default(), @@ -156,22 +197,14 @@ pub fn run_wizard() -> Result { if config.memory.auto_save { "on" } else { "off" } ); - config.save()?; - persist_workspace_selection(&config.config_path)?; + config.save().await?; + persist_workspace_selection(&config.config_path).await?; // ── Final summary ──────────────────────────────────────────── print_summary(&config); // ── Offer to launch channels immediately ───────────────────── - let has_channels = config.channels_config.telegram.is_some() - || config.channels_config.discord.is_some() - || config.channels_config.slack.is_some() - || config.channels_config.imessage.is_some() - || 
config.channels_config.matrix.is_some() - || config.channels_config.email.is_some() - || config.channels_config.dingtalk.is_some() - || config.channels_config.qq.is_some() - || config.channels_config.lark.is_some(); + let has_channels = has_launchable_channels(&config.channels_config); if has_channels && config.api_key.is_some() { let launch: bool = Confirm::new() @@ -199,7 +232,7 @@ pub fn run_wizard() -> Result { } /// Interactive repair flow: rerun channel setup only without redoing full onboarding. -pub fn run_channels_repair_wizard() -> Result { +pub async fn run_channels_repair_wizard() -> Result { println!("{}", style(BANNER).cyan().bold()); println!( " {}", @@ -209,12 +242,12 @@ pub fn run_channels_repair_wizard() -> Result { ); println!(); - let mut config = Config::load_or_init()?; + let mut config = Config::load_or_init().await?; print_step(1, 1, "Channels (How You Talk to ZeroClaw)"); config.channels_config = setup_channels()?; - config.save()?; - persist_workspace_selection(&config.config_path)?; + config.save().await?; + persist_workspace_selection(&config.config_path).await?; println!(); println!( @@ -223,15 +256,7 @@ pub fn run_channels_repair_wizard() -> Result { style(config.config_path.display()).green() ); - let has_channels = config.channels_config.telegram.is_some() - || config.channels_config.discord.is_some() - || config.channels_config.slack.is_some() - || config.channels_config.imessage.is_some() - || config.channels_config.matrix.is_some() - || config.channels_config.email.is_some() - || config.channels_config.dingtalk.is_some() - || config.channels_config.qq.is_some() - || config.channels_config.lark.is_some(); + let has_channels = has_launchable_channels(&config.channels_config); if has_channels && config.api_key.is_some() { let launch: bool = Confirm::new() @@ -302,10 +327,33 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig { } #[allow(clippy::too_many_lines)] -pub fn run_quick_setup( +pub async fn 
run_quick_setup( credential_override: Option<&str>, provider: Option<&str>, + model_override: Option<&str>, memory_backend: Option<&str>, +) -> Result { + let home = directories::UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + + run_quick_setup_with_home( + credential_override, + provider, + model_override, + memory_backend, + &home, + ) + .await +} + +#[allow(clippy::too_many_lines)] +async fn run_quick_setup_with_home( + credential_override: Option<&str>, + provider: Option<&str>, + model_override: Option<&str>, + memory_backend: Option<&str>, + home: &Path, ) -> Result { println!("{}", style(BANNER).cyan().bold()); println!( @@ -316,9 +364,6 @@ pub fn run_quick_setup( ); println!(); - let home = directories::UserDirs::new() - .map(|u| u.home_dir().to_path_buf()) - .context("Could not find home directory")?; let zeroclaw_dir = home.join(".zeroclaw"); let workspace_dir = zeroclaw_dir.join("workspace"); let config_path = zeroclaw_dir.join("config.toml"); @@ -326,7 +371,9 @@ pub fn run_quick_setup( fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?; let provider_name = provider.unwrap_or("openrouter").to_string(); - let model = default_model_for_provider(&provider_name); + let model = model_override + .map(str::to_string) + .unwrap_or_else(|| default_model_for_provider(&provider_name)); let memory_backend_name = memory_backend .unwrap_or(default_memory_backend_key()) .to_string(); @@ -337,7 +384,11 @@ pub fn run_quick_setup( let config = Config { workspace_dir: workspace_dir.clone(), config_path: config_path.clone(), - api_key: credential_override.map(String::from), + api_key: credential_override.map(|c| { + let mut s = String::with_capacity(c.len()); + s.push_str(c); + s + }), api_url: None, default_provider: Some(provider_name.clone()), default_model: Some(model.clone()), @@ -348,7 +399,9 @@ pub fn run_quick_setup( reliability: crate::config::ReliabilityConfig::default(), 
scheduler: crate::config::schema::SchedulerConfig::default(), agent: crate::config::schema::AgentConfig::default(), + skills: crate::config::SkillsConfig::default(), model_routes: Vec::new(), + embedding_routes: Vec::new(), heartbeat: HeartbeatConfig::default(), cron: crate::config::CronConfig::default(), channels_config: ChannelsConfig::default(), @@ -360,6 +413,7 @@ pub fn run_quick_setup( secrets: SecretsConfig::default(), browser: BrowserConfig::default(), http_request: crate::config::HttpRequestConfig::default(), + multimodal: crate::config::MultimodalConfig::default(), web_search: crate::config::WebSearchConfig::default(), proxy: crate::config::ProxyConfig::default(), identity: crate::config::IdentityConfig::default(), @@ -370,8 +424,8 @@ pub fn run_quick_setup( query_classification: crate::config::QueryClassificationConfig::default(), }; - config.save()?; - persist_workspace_selection(&config.config_path)?; + config.save().await?; + persist_workspace_selection(&config.config_path).await?; // Scaffold minimal workspace files let default_ctx = ProjectContext { @@ -467,6 +521,10 @@ pub fn run_quick_setup( } fn canonical_provider_name(provider_name: &str) -> &str { + if is_qwen_oauth_alias(provider_name) { + return "qwen-code"; + } + if let Some(canonical) = canonical_china_provider_name(provider_name) { return canonical; } @@ -477,6 +535,7 @@ fn canonical_provider_name(provider_name: &str) -> &str { "google" | "google-gemini" => "gemini", "kimi_coding" | "kimi_for_coding" => "kimi-code", "nvidia-nim" | "build.nvidia.com" => "nvidia", + "aws-bedrock" => "bedrock", _ => provider_name, } } @@ -516,9 +575,11 @@ fn default_model_for_provider(provider: &str) -> String { "glm" | "zai" => "glm-5".into(), "minimax" => "MiniMax-M2.5".into(), "qwen" => "qwen-plus".into(), + "qwen-code" => "qwen3-coder-plus".into(), "ollama" => "llama3.2".into(), "gemini" => "gemini-2.5-pro".into(), "kimi-code" => "kimi-for-coding".into(), + "bedrock" => 
"anthropic.claude-sonnet-4-5-20250929-v1:0".into(), "nvidia" => "meta/llama-3.3-70b-instruct".into(), "astrai" => "anthropic/claude-sonnet-4.6".into(), _ => "anthropic/claude-sonnet-4.6".into(), @@ -791,6 +852,20 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> { "Qwen Turbo (fast and cost-efficient)".to_string(), ), ], + "qwen-code" => vec![ + ( + "qwen3-coder-plus".to_string(), + "Qwen3 Coder Plus (recommended for coding workflows)".to_string(), + ), + ( + "qwen3.5-plus".to_string(), + "Qwen3.5 Plus (reasoning + coding)".to_string(), + ), + ( + "qwen3-max-2026-01-23".to_string(), + "Qwen3 Max (high-capability coding model)".to_string(), + ), + ], "nvidia" => vec![ ( "meta/llama-3.3-70b-instruct".to_string(), @@ -836,6 +911,24 @@ fn curated_models_for_provider(provider_name: &str) -> Vec<(String, String)> { ("codellama".to_string(), "Code Llama".to_string()), ("phi3".to_string(), "Phi-3 (small, fast)".to_string()), ], + "bedrock" => vec![ + ( + "anthropic.claude-sonnet-4-6".to_string(), + "Claude Sonnet 4.6 (latest, recommended)".to_string(), + ), + ( + "anthropic.claude-opus-4-6-v1".to_string(), + "Claude Opus 4.6 (strongest)".to_string(), + ), + ( + "anthropic.claude-haiku-4-5-20251001-v1:0".to_string(), + "Claude Haiku 4.5 (fastest, cheapest)".to_string(), + ), + ( + "anthropic.claude-sonnet-4-5-20250929-v1:0".to_string(), + "Claude Sonnet 4.5".to_string(), + ), + ], "gemini" => vec![ ( "gemini-3-pro-preview".to_string(), @@ -1433,16 +1526,18 @@ fn print_bullet(text: &str) { println!(" {} {}", style("›").cyan(), text); } -fn persist_workspace_selection(config_path: &Path) -> Result<()> { +async fn persist_workspace_selection(config_path: &Path) -> Result<()> { let config_dir = config_path .parent() .context("Config path must have a parent directory")?; - crate::config::schema::persist_active_workspace_config_dir(config_dir).with_context(|| { - format!( - "Failed to persist active workspace selection for {}", - config_dir.display() 
- ) - }) + crate::config::schema::persist_active_workspace_config_dir(config_dir) + .await + .with_context(|| { + format!( + "Failed to persist active workspace selection for {}", + config_dir.display() + ) + }) } // ── Step 1: Workspace ──────────────────────────────────────────── @@ -1549,6 +1644,10 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio "kimi-code", "Kimi Code — coding-optimized Kimi API (KimiCLI)", ), + ( + "qwen-code", + "Qwen Code — OAuth tokens reused from ~/.qwen/oauth_creds.json", + ), ("moonshot", "Moonshot — Kimi API (China endpoint)"), ( "moonshot-intl", @@ -1757,11 +1856,48 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio key } + } else if canonical_provider_name(provider_name) == "qwen-code" { + if std::env::var("QWEN_OAUTH_TOKEN").is_ok() { + print_bullet(&format!( + "{} QWEN_OAUTH_TOKEN environment variable detected!", + style("✓").green().bold() + )); + "qwen-oauth".to_string() + } else { + print_bullet( + "Qwen Code OAuth credentials are usually stored in ~/.qwen/oauth_creds.json.", + ); + print_bullet( + "Run `qwen` once and complete OAuth login to populate cached credentials.", + ); + print_bullet("You can also set QWEN_OAUTH_TOKEN directly."); + println!(); + + let key: String = Input::new() + .with_prompt( + " Paste your Qwen OAuth token (or press Enter to auto-detect cached OAuth)", + ) + .allow_empty(true) + .interact_text()?; + + if key.trim().is_empty() { + print_bullet(&format!( + "Using OAuth auto-detection. 
Set {} and optional {} if needed.", + style("QWEN_OAUTH_TOKEN").yellow(), + style("QWEN_OAUTH_RESOURCE_URL").yellow() + )); + "qwen-oauth".to_string() + } else { + key + } + } } else { let key_url = if is_moonshot_alias(provider_name) || canonical_provider_name(provider_name) == "kimi-code" { "https://platform.moonshot.cn/console/api-keys" + } else if canonical_provider_name(provider_name) == "qwen-code" { + "https://qwen.readthedocs.io/en/latest/getting_started/installation.html" } else if is_glm_cn_alias(provider_name) || is_zai_cn_alias(provider_name) { "https://open.bigmodel.cn/usercenter/proj-mgmt/apikeys" } else if is_glm_alias(provider_name) || is_zai_alias(provider_name) { @@ -1796,29 +1932,51 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio }; println!(); - if !key_url.is_empty() { + if matches!(provider_name, "bedrock" | "aws-bedrock") { + // Bedrock uses AWS AKSK, not a single API key. + print_bullet("Bedrock uses AWS credentials (not a single API key)."); print_bullet(&format!( - "Get your API key at: {}", - style(key_url).cyan().underlined() + "Set {} and {} environment variables.", + style("AWS_ACCESS_KEY_ID").yellow(), + style("AWS_SECRET_ACCESS_KEY").yellow(), )); - } - print_bullet("You can also set it later via env var or config file."); - println!(); - - let key: String = Input::new() - .with_prompt(" Paste your API key (or press Enter to skip)") - .allow_empty(true) - .interact_text()?; - - if key.is_empty() { - let env_var = provider_env_var(provider_name); print_bullet(&format!( - "Skipped. 
Set {} or edit config.toml later.", - style(env_var).yellow() + "Optionally set {} for the region (default: us-east-1).", + style("AWS_REGION").yellow(), )); - } + if !key_url.is_empty() { + print_bullet(&format!( + "Manage IAM credentials at: {}", + style(key_url).cyan().underlined() + )); + } + println!(); + String::new() + } else { + if !key_url.is_empty() { + print_bullet(&format!( + "Get your API key at: {}", + style(key_url).cyan().underlined() + )); + } + print_bullet("You can also set it later via env var or config file."); + println!(); - key + let key: String = Input::new() + .with_prompt(" Paste your API key (or press Enter to skip)") + .allow_empty(true) + .interact_text()?; + + if key.is_empty() { + let env_var = provider_env_var(provider_name); + print_bullet(&format!( + "Skipped. Set {} or edit config.toml later.", + style(env_var).yellow() + )); + } + + key + } }; // ── Model selection ── @@ -1992,6 +2150,10 @@ fn setup_provider(workspace_dir: &Path) -> Result<(String, String, String, Optio /// Map provider name to its conventional env var fn provider_env_var(name: &str) -> &'static str { + if canonical_provider_name(name) == "qwen-code" { + return "QWEN_OAUTH_TOKEN"; + } + match canonical_provider_name(name) { "openrouter" => "OPENROUTER_API_KEY", "anthropic" => "ANTHROPIC_API_KEY", @@ -2448,125 +2610,156 @@ fn setup_channels() -> Result { print_bullet("CLI is always available. 
Connect more channels now."); println!(); - let mut config = ChannelsConfig { - cli: true, - telegram: None, - discord: None, - slack: None, - mattermost: None, - webhook: None, - imessage: None, - matrix: None, - signal: None, - whatsapp: None, - email: None, - irc: None, - lark: None, - dingtalk: None, - qq: None, - }; + let mut config = ChannelsConfig::default(); + #[derive(Clone, Copy)] + enum ChannelMenuChoice { + Telegram, + Discord, + Slack, + IMessage, + Matrix, + WhatsApp, + Linq, + Irc, + Webhook, + DingTalk, + QqOfficial, + LarkFeishu, + Done, + } + let menu_choices = [ + ChannelMenuChoice::Telegram, + ChannelMenuChoice::Discord, + ChannelMenuChoice::Slack, + ChannelMenuChoice::IMessage, + ChannelMenuChoice::Matrix, + ChannelMenuChoice::WhatsApp, + ChannelMenuChoice::Linq, + ChannelMenuChoice::Irc, + ChannelMenuChoice::Webhook, + ChannelMenuChoice::DingTalk, + ChannelMenuChoice::QqOfficial, + ChannelMenuChoice::LarkFeishu, + ChannelMenuChoice::Done, + ]; loop { - let options = vec![ - format!( - "Telegram {}", - if config.telegram.is_some() { - "✅ connected" - } else { - "— connect your bot" - } - ), - format!( - "Discord {}", - if config.discord.is_some() { - "✅ connected" - } else { - "— connect your bot" - } - ), - format!( - "Slack {}", - if config.slack.is_some() { - "✅ connected" - } else { - "— connect your bot" - } - ), - format!( - "iMessage {}", - if config.imessage.is_some() { - "✅ configured" - } else { - "— macOS only" - } - ), - format!( - "Matrix {}", - if config.matrix.is_some() { - "✅ connected" - } else { - "— self-hosted chat" - } - ), - format!( - "WhatsApp {}", - if config.whatsapp.is_some() { - "✅ connected" - } else { - "— Business Cloud API" - } - ), - format!( - "IRC {}", - if config.irc.is_some() { - "✅ configured" - } else { - "— IRC over TLS" - } - ), - format!( - "Webhook {}", - if config.webhook.is_some() { - "✅ configured" - } else { - "— HTTP endpoint" - } - ), - format!( - "DingTalk {}", - if config.dingtalk.is_some() { - 
"✅ connected" - } else { - "— DingTalk Stream Mode" - } - ), - format!( - "QQ Official {}", - if config.qq.is_some() { - "✅ connected" - } else { - "— Tencent QQ Bot" - } - ), - format!( - "Lark/Feishu {}", - if config.lark.is_some() { - "✅ connected" - } else { - "— Lark/Feishu Bot" - } - ), - "Done — finish setup".to_string(), - ]; + let options: Vec = menu_choices + .iter() + .map(|choice| match choice { + ChannelMenuChoice::Telegram => format!( + "Telegram {}", + if config.telegram.is_some() { + "✅ connected" + } else { + "— connect your bot" + } + ), + ChannelMenuChoice::Discord => format!( + "Discord {}", + if config.discord.is_some() { + "✅ connected" + } else { + "— connect your bot" + } + ), + ChannelMenuChoice::Slack => format!( + "Slack {}", + if config.slack.is_some() { + "✅ connected" + } else { + "— connect your bot" + } + ), + ChannelMenuChoice::IMessage => format!( + "iMessage {}", + if config.imessage.is_some() { + "✅ configured" + } else { + "— macOS only" + } + ), + ChannelMenuChoice::Matrix => format!( + "Matrix {}", + if config.matrix.is_some() { + "✅ connected" + } else { + "— self-hosted chat" + } + ), + ChannelMenuChoice::WhatsApp => format!( + "WhatsApp {}", + if config.whatsapp.is_some() { + "✅ connected" + } else { + "— Business Cloud API" + } + ), + ChannelMenuChoice::Linq => format!( + "Linq {}", + if config.linq.is_some() { + "✅ connected" + } else { + "— iMessage/RCS/SMS via Linq API" + } + ), + ChannelMenuChoice::Irc => format!( + "IRC {}", + if config.irc.is_some() { + "✅ configured" + } else { + "— IRC over TLS" + } + ), + ChannelMenuChoice::Webhook => format!( + "Webhook {}", + if config.webhook.is_some() { + "✅ configured" + } else { + "— HTTP endpoint" + } + ), + ChannelMenuChoice::DingTalk => format!( + "DingTalk {}", + if config.dingtalk.is_some() { + "✅ connected" + } else { + "— DingTalk Stream Mode" + } + ), + ChannelMenuChoice::QqOfficial => format!( + "QQ Official {}", + if config.qq.is_some() { + "✅ connected" + } else { 
+ "— Tencent QQ Bot" + } + ), + ChannelMenuChoice::LarkFeishu => format!( + "Lark/Feishu {}", + if config.lark.is_some() { + "✅ connected" + } else { + "— Lark/Feishu Bot" + } + ), + ChannelMenuChoice::Done => "Done — finish setup".to_string(), + }) + .collect(); - let choice = Select::new() + let selection = Select::new() .with_prompt(" Connect a channel (or Done to continue)") .items(&options) - .default(11) + .default(options.len() - 1) .interact()?; + let choice = menu_choices + .get(selection) + .copied() + .unwrap_or(ChannelMenuChoice::Done); + match choice { - 0 => { + ChannelMenuChoice::Telegram => { // ── Telegram ── println!(); println!( @@ -2660,10 +2853,11 @@ fn setup_channels() -> Result { allowed_users, stream_mode: StreamMode::default(), draft_update_interval_ms: 1000, + interrupt_on_new_message: false, mention_only: false, }); } - 1 => { + ChannelMenuChoice::Discord => { // ── Discord ── println!(); println!( @@ -2762,7 +2956,7 @@ fn setup_channels() -> Result { mention_only: false, }); } - 2 => { + ChannelMenuChoice::Slack => { // ── Slack ── println!(); println!( @@ -2887,7 +3081,7 @@ fn setup_channels() -> Result { allowed_users, }); } - 3 => { + ChannelMenuChoice::IMessage => { // ── iMessage ── println!(); println!( @@ -2931,7 +3125,7 @@ fn setup_channels() -> Result { style(&contacts_str).cyan() ); } - 4 => { + ChannelMenuChoice::Matrix => { // ── Matrix ── println!(); println!( @@ -3043,13 +3237,95 @@ fn setup_channels() -> Result { allowed_users, }); } - 5 => { + ChannelMenuChoice::WhatsApp => { // ── WhatsApp ── println!(); + println!(" {}", style("WhatsApp Setup").white().bold()); + + let mode_options = vec![ + "WhatsApp Web (QR / pair-code, no Meta Business API)", + "WhatsApp Business Cloud API (webhook)", + ]; + let mode_idx = Select::new() + .with_prompt(" Choose WhatsApp mode") + .items(&mode_options) + .default(0) + .interact()?; + + if mode_idx == 0 { + println!(" {}", style("Mode: WhatsApp Web").dim()); + print_bullet("1. 
Build with --features whatsapp-web"); + print_bullet( + "2. Start channel/daemon and scan QR in WhatsApp > Linked Devices", + ); + print_bullet("3. Keep session_path persistent so relogin is not required"); + println!(); + + let session_path: String = Input::new() + .with_prompt(" Session database path") + .default("~/.zeroclaw/state/whatsapp-web/session.db".into()) + .interact_text()?; + + if session_path.trim().is_empty() { + println!(" {} Skipped — session path required", style("→").dim()); + continue; + } + + let pair_phone: String = Input::new() + .with_prompt( + " Pair phone (optional, digits only; leave empty to use QR flow)", + ) + .allow_empty(true) + .interact_text()?; + + let pair_code: String = if pair_phone.trim().is_empty() { + String::new() + } else { + Input::new() + .with_prompt( + " Custom pair code (optional, leave empty for auto-generated)", + ) + .allow_empty(true) + .interact_text()? + }; + + let users_str: String = Input::new() + .with_prompt( + " Allowed phone numbers (comma-separated +1234567890, or * for all)", + ) + .default("*".into()) + .interact_text()?; + + let allowed_numbers = if users_str.trim() == "*" { + vec!["*".into()] + } else { + users_str.split(',').map(|s| s.trim().to_string()).collect() + }; + + config.whatsapp = Some(WhatsAppConfig { + access_token: None, + phone_number_id: None, + verify_token: None, + app_secret: None, + session_path: Some(session_path.trim().to_string()), + pair_phone: (!pair_phone.trim().is_empty()) + .then(|| pair_phone.trim().to_string()), + pair_code: (!pair_code.trim().is_empty()) + .then(|| pair_code.trim().to_string()), + allowed_numbers, + }); + + println!( + " {} WhatsApp Web configuration saved.", + style("✅").green().bold() + ); + continue; + } + println!( " {} {}", - style("WhatsApp Setup").white().bold(), - style("— Business Cloud API").dim() + style("Mode:").dim(), + style("Business Cloud API").dim() ); print_bullet("1. 
Go to developers.facebook.com and create a WhatsApp app"); print_bullet("2. Add the WhatsApp product and get your phone number ID"); @@ -3130,14 +3406,109 @@ fn setup_channels() -> Result { }; config.whatsapp = Some(WhatsAppConfig { - access_token: access_token.trim().to_string(), - phone_number_id: phone_number_id.trim().to_string(), - verify_token: verify_token.trim().to_string(), + access_token: Some(access_token.trim().to_string()), + phone_number_id: Some(phone_number_id.trim().to_string()), + verify_token: Some(verify_token.trim().to_string()), app_secret: None, // Can be set via ZEROCLAW_WHATSAPP_APP_SECRET env var + session_path: None, + pair_phone: None, + pair_code: None, allowed_numbers, }); } - 6 => { + ChannelMenuChoice::Linq => { + // ── Linq ── + println!(); + println!( + " {} {}", + style("Linq Setup").white().bold(), + style("— iMessage/RCS/SMS via Linq API").dim() + ); + print_bullet("1. Sign up at linqapp.com and get your Partner API token"); + print_bullet("2. Note your Linq phone number (E.164 format)"); + print_bullet("3. Configure webhook URL to: https://your-domain/linq"); + println!(); + + let api_token: String = Input::new() + .with_prompt(" API token (Linq Partner API token)") + .interact_text()?; + + if api_token.trim().is_empty() { + println!(" {} Skipped", style("→").dim()); + continue; + } + + let from_phone: String = Input::new() + .with_prompt(" From phone number (E.164 format, e.g. +12223334444)") + .interact_text()?; + + if from_phone.trim().is_empty() { + println!(" {} Skipped — phone number required", style("→").dim()); + continue; + } + + // Test connection + print!(" {} Testing connection... 
", style("⏳").dim()); + let api_token_clone = api_token.clone(); + let thread_result = std::thread::spawn(move || { + let client = reqwest::blocking::Client::new(); + let url = "https://api.linqapp.com/api/partner/v3/phonenumbers"; + let resp = client + .get(url) + .header( + "Authorization", + format!("Bearer {}", api_token_clone.trim()), + ) + .send()?; + Ok::<_, reqwest::Error>(resp.status().is_success()) + }) + .join(); + match thread_result { + Ok(Ok(true)) => { + println!( + "\r {} Connected to Linq API ", + style("✅").green().bold() + ); + } + _ => { + println!( + "\r {} Connection failed — check API token", + style("❌").red().bold() + ); + continue; + } + } + + let users_str: String = Input::new() + .with_prompt( + " Allowed sender numbers (comma-separated +1234567890, or * for all)", + ) + .default("*".into()) + .interact_text()?; + + let allowed_senders = if users_str.trim() == "*" { + vec!["*".into()] + } else { + users_str.split(',').map(|s| s.trim().to_string()).collect() + }; + + let signing_secret: String = Input::new() + .with_prompt(" Webhook signing secret (optional, press Enter to skip)") + .allow_empty(true) + .interact_text()?; + + config.linq = Some(LinqConfig { + api_token: api_token.trim().to_string(), + from_phone: from_phone.trim().to_string(), + signing_secret: if signing_secret.trim().is_empty() { + None + } else { + Some(signing_secret.trim().to_string()) + }, + allowed_senders, + }); + } + ChannelMenuChoice::Irc => { // ── IRC ── println!(); println!( @@ -3276,7 +3647,7 @@ fn setup_channels() -> Result { verify_tls: Some(verify_tls), }); } - 7 => { + ChannelMenuChoice::Webhook => { // ── Webhook ── println!(); println!( @@ -3309,7 +3680,7 @@ fn setup_channels() -> Result { style(&port).cyan() ); } - 8 => { + ChannelMenuChoice::DingTalk => { // ── DingTalk ── println!(); println!( @@ -3379,7 +3750,7 @@ fn setup_channels() -> Result { allowed_users, }); } - 9 => { + ChannelMenuChoice::QqOfficial => { // ── QQ Official ── println!(); 
println!( @@ -3455,7 +3826,7 @@ fn setup_channels() -> Result { allowed_users, }); } - 10 => { + ChannelMenuChoice::LarkFeishu => { // ── Lark/Feishu ── println!(); println!( @@ -3642,7 +4013,7 @@ fn setup_channels() -> Result { port, }); } - _ => break, // Done + ChannelMenuChoice::Done => break, } println!(); } @@ -3667,6 +4038,9 @@ fn setup_channels() -> Result { if config.whatsapp.is_some() { active.push("WhatsApp"); } + if config.linq.is_some() { + active.push("Linq"); + } if config.email.is_some() { active.push("Email"); } @@ -3726,10 +4100,10 @@ fn setup_tunnel() -> Result { 1 => { println!(); print_bullet("Get your tunnel token from the Cloudflare Zero Trust dashboard."); - let token: String = Input::new() + let tunnel_value: String = Input::new() .with_prompt(" Cloudflare tunnel token") .interact_text()?; - if token.trim().is_empty() { + if tunnel_value.trim().is_empty() { println!(" {} Skipped", style("→").dim()); TunnelConfig::default() } else { @@ -3740,7 +4114,9 @@ fn setup_tunnel() -> Result { ); TunnelConfig { provider: "cloudflare".into(), - cloudflare: Some(CloudflareTunnelConfig { token }), + cloudflare: Some(CloudflareTunnelConfig { + token: tunnel_value, + }), ..TunnelConfig::default() } } @@ -4130,15 +4506,7 @@ fn scaffold_workspace(workspace_dir: &Path, ctx: &ProjectContext) -> Result<()> #[allow(clippy::too_many_lines)] fn print_summary(config: &Config) { - let has_channels = config.channels_config.telegram.is_some() - || config.channels_config.discord.is_some() - || config.channels_config.slack.is_some() - || config.channels_config.imessage.is_some() - || config.channels_config.matrix.is_some() - || config.channels_config.email.is_some() - || config.channels_config.dingtalk.is_some() - || config.channels_config.qq.is_some() - || config.channels_config.lark.is_some(); + let has_channels = has_launchable_channels(&config.channels_config); println!(); println!( @@ -4408,6 +4776,48 @@ mod tests { assert!(ctx.communication_style.is_empty()); } + 
#[tokio::test] + async fn quick_setup_model_override_persists_to_config_toml() { + let tmp = TempDir::new().unwrap(); + + let config = run_quick_setup_with_home( + Some("sk-issue946"), + Some("openrouter"), + Some("custom-model-946"), + Some("sqlite"), + tmp.path(), + ) + .await + .unwrap(); + + assert_eq!(config.default_provider.as_deref(), Some("openrouter")); + assert_eq!(config.default_model.as_deref(), Some("custom-model-946")); + assert_eq!(config.api_key.as_deref(), Some("sk-issue946")); + + let config_raw = tokio::fs::read_to_string(config.config_path).await.unwrap(); + assert!(config_raw.contains("default_provider = \"openrouter\"")); + assert!(config_raw.contains("default_model = \"custom-model-946\"")); + } + + #[tokio::test] + async fn quick_setup_without_model_uses_provider_default_model() { + let tmp = TempDir::new().unwrap(); + + let config = run_quick_setup_with_home( + Some("sk-issue946"), + Some("anthropic"), + None, + Some("sqlite"), + tmp.path(), + ) + .await + .unwrap(); + + let expected = default_model_for_provider("anthropic"); + assert_eq!(config.default_provider.as_deref(), Some("anthropic")); + assert_eq!(config.default_model.as_deref(), Some(expected.as_str())); + } + // ── scaffold_workspace: basic file creation ───────────────── #[test] @@ -4444,8 +4854,8 @@ mod tests { // ── scaffold_workspace: personalization ───────────────────── - #[test] - fn scaffold_bakes_user_name_into_files() { + #[tokio::test] + async fn scaffold_bakes_user_name_into_files() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext { user_name: "Alice".into(), @@ -4453,21 +4863,25 @@ mod tests { }; scaffold_workspace(tmp.path(), &ctx).unwrap(); - let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); + let user_md = tokio::fs::read_to_string(tmp.path().join("USER.md")) + .await + .unwrap(); assert!( user_md.contains("**Name:** Alice"), "USER.md should contain user name" ); - let bootstrap = 
fs::read_to_string(tmp.path().join("BOOTSTRAP.md")).unwrap(); + let bootstrap = tokio::fs::read_to_string(tmp.path().join("BOOTSTRAP.md")) + .await + .unwrap(); assert!( bootstrap.contains("**Alice**"), "BOOTSTRAP.md should contain user name" ); } - #[test] - fn scaffold_bakes_timezone_into_files() { + #[tokio::test] + async fn scaffold_bakes_timezone_into_files() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext { timezone: "US/Pacific".into(), @@ -4475,21 +4889,25 @@ mod tests { }; scaffold_workspace(tmp.path(), &ctx).unwrap(); - let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); + let user_md = tokio::fs::read_to_string(tmp.path().join("USER.md")) + .await + .unwrap(); assert!( user_md.contains("**Timezone:** US/Pacific"), "USER.md should contain timezone" ); - let bootstrap = fs::read_to_string(tmp.path().join("BOOTSTRAP.md")).unwrap(); + let bootstrap = tokio::fs::read_to_string(tmp.path().join("BOOTSTRAP.md")) + .await + .unwrap(); assert!( bootstrap.contains("US/Pacific"), "BOOTSTRAP.md should contain timezone" ); } - #[test] - fn scaffold_bakes_agent_name_into_files() { + #[tokio::test] + async fn scaffold_bakes_agent_name_into_files() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext { agent_name: "Crabby".into(), @@ -4497,39 +4915,49 @@ mod tests { }; scaffold_workspace(tmp.path(), &ctx).unwrap(); - let identity = fs::read_to_string(tmp.path().join("IDENTITY.md")).unwrap(); + let identity = tokio::fs::read_to_string(tmp.path().join("IDENTITY.md")) + .await + .unwrap(); assert!( identity.contains("**Name:** Crabby"), "IDENTITY.md should contain agent name" ); - let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); assert!( soul.contains("You are **Crabby**"), "SOUL.md should contain agent name" ); - let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap(); + let agents = 
tokio::fs::read_to_string(tmp.path().join("AGENTS.md")) + .await + .unwrap(); assert!( agents.contains("Crabby Personal Assistant"), "AGENTS.md should contain agent name" ); - let heartbeat = fs::read_to_string(tmp.path().join("HEARTBEAT.md")).unwrap(); + let heartbeat = tokio::fs::read_to_string(tmp.path().join("HEARTBEAT.md")) + .await + .unwrap(); assert!( heartbeat.contains("Crabby"), "HEARTBEAT.md should contain agent name" ); - let bootstrap = fs::read_to_string(tmp.path().join("BOOTSTRAP.md")).unwrap(); + let bootstrap = tokio::fs::read_to_string(tmp.path().join("BOOTSTRAP.md")) + .await + .unwrap(); assert!( bootstrap.contains("Introduce yourself as Crabby"), "BOOTSTRAP.md should contain agent name" ); } - #[test] - fn scaffold_bakes_communication_style() { + #[tokio::test] + async fn scaffold_bakes_communication_style() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext { communication_style: "Be technical and detailed.".into(), @@ -4537,19 +4965,25 @@ mod tests { }; scaffold_workspace(tmp.path(), &ctx).unwrap(); - let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); assert!( soul.contains("Be technical and detailed."), "SOUL.md should contain communication style" ); - let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); + let user_md = tokio::fs::read_to_string(tmp.path().join("USER.md")) + .await + .unwrap(); assert!( user_md.contains("Be technical and detailed."), "USER.md should contain communication style" ); - let bootstrap = fs::read_to_string(tmp.path().join("BOOTSTRAP.md")).unwrap(); + let bootstrap = tokio::fs::read_to_string(tmp.path().join("BOOTSTRAP.md")) + .await + .unwrap(); assert!( bootstrap.contains("Be technical and detailed."), "BOOTSTRAP.md should contain communication style" @@ -4558,19 +4992,23 @@ mod tests { // ── scaffold_workspace: defaults when context is empty ────── - #[test] - fn 
scaffold_uses_defaults_for_empty_context() { + #[tokio::test] + async fn scaffold_uses_defaults_for_empty_context() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext::default(); // all empty scaffold_workspace(tmp.path(), &ctx).unwrap(); - let identity = fs::read_to_string(tmp.path().join("IDENTITY.md")).unwrap(); + let identity = tokio::fs::read_to_string(tmp.path().join("IDENTITY.md")) + .await + .unwrap(); assert!( identity.contains("**Name:** ZeroClaw"), "should default agent name to ZeroClaw" ); - let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); + let user_md = tokio::fs::read_to_string(tmp.path().join("USER.md")) + .await + .unwrap(); assert!( user_md.contains("**Name:** User"), "should default user name to User" @@ -4580,7 +5018,9 @@ mod tests { "should default timezone to UTC" ); - let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); assert!( soul.contains("Be warm, natural, and clear."), "should default communication style" @@ -4589,8 +5029,8 @@ mod tests { // ── scaffold_workspace: skip existing files ───────────────── - #[test] - fn scaffold_does_not_overwrite_existing_files() { + #[tokio::test] + async fn scaffold_does_not_overwrite_existing_files() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext { user_name: "Bob".into(), @@ -4604,7 +5044,7 @@ mod tests { scaffold_workspace(tmp.path(), &ctx).unwrap(); // SOUL.md should be untouched - let soul = fs::read_to_string(&soul_path).unwrap(); + let soul = tokio::fs::read_to_string(&soul_path).await.unwrap(); assert!( soul.contains("Do not overwrite me"), "existing files should not be overwritten" @@ -4615,14 +5055,16 @@ mod tests { ); // But USER.md should be created fresh - let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); + let user_md = tokio::fs::read_to_string(tmp.path().join("USER.md")) + .await + .unwrap(); 
assert!(user_md.contains("**Name:** Bob")); } // ── scaffold_workspace: idempotent ────────────────────────── - #[test] - fn scaffold_is_idempotent() { + #[tokio::test] + async fn scaffold_is_idempotent() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext { user_name: "Eve".into(), @@ -4631,19 +5073,23 @@ mod tests { }; scaffold_workspace(tmp.path(), &ctx).unwrap(); - let soul_v1 = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul_v1 = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); // Run again — should not change anything scaffold_workspace(tmp.path(), &ctx).unwrap(); - let soul_v2 = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul_v2 = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); assert_eq!(soul_v1, soul_v2, "scaffold should be idempotent"); } // ── scaffold_workspace: all files are non-empty ───────────── - #[test] - fn scaffold_files_are_non_empty() { + #[tokio::test] + async fn scaffold_files_are_non_empty() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext::default(); scaffold_workspace(tmp.path(), &ctx).unwrap(); @@ -4658,20 +5104,22 @@ mod tests { "BOOTSTRAP.md", "MEMORY.md", ] { - let content = fs::read_to_string(tmp.path().join(f)).unwrap(); + let content = tokio::fs::read_to_string(tmp.path().join(f)).await.unwrap(); assert!(!content.trim().is_empty(), "{f} should not be empty"); } } // ── scaffold_workspace: AGENTS.md references on-demand memory - #[test] - fn agents_md_references_on_demand_memory() { + #[tokio::test] + async fn agents_md_references_on_demand_memory() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext::default(); scaffold_workspace(tmp.path(), &ctx).unwrap(); - let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap(); + let agents = tokio::fs::read_to_string(tmp.path().join("AGENTS.md")) + .await + .unwrap(); assert!( agents.contains("memory_recall"), "AGENTS.md should reference memory_recall 
for on-demand access" @@ -4684,13 +5132,15 @@ mod tests { // ── scaffold_workspace: MEMORY.md warns about token cost ──── - #[test] - fn memory_md_warns_about_token_cost() { + #[tokio::test] + async fn memory_md_warns_about_token_cost() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext::default(); scaffold_workspace(tmp.path(), &ctx).unwrap(); - let memory = fs::read_to_string(tmp.path().join("MEMORY.md")).unwrap(); + let memory = tokio::fs::read_to_string(tmp.path().join("MEMORY.md")) + .await + .unwrap(); assert!( memory.contains("costs tokens"), "MEMORY.md should warn about token cost" @@ -4703,13 +5153,15 @@ mod tests { // ── scaffold_workspace: TOOLS.md lists memory_forget ──────── - #[test] - fn tools_md_lists_all_builtin_tools() { + #[tokio::test] + async fn tools_md_lists_all_builtin_tools() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext::default(); scaffold_workspace(tmp.path(), &ctx).unwrap(); - let tools = fs::read_to_string(tmp.path().join("TOOLS.md")).unwrap(); + let tools = tokio::fs::read_to_string(tmp.path().join("TOOLS.md")) + .await + .unwrap(); for tool in &[ "shell", "file_read", @@ -4733,13 +5185,15 @@ mod tests { ); } - #[test] - fn soul_md_includes_emoji_awareness_guidance() { + #[tokio::test] + async fn soul_md_includes_emoji_awareness_guidance() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext::default(); scaffold_workspace(tmp.path(), &ctx).unwrap(); - let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); assert!( soul.contains("Use emojis naturally (0-2 max"), "SOUL.md should include emoji usage guidance" @@ -4752,8 +5206,8 @@ mod tests { // ── scaffold_workspace: special characters in names ───────── - #[test] - fn scaffold_handles_special_characters_in_names() { + #[tokio::test] + async fn scaffold_handles_special_characters_in_names() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext 
{ user_name: "José María".into(), @@ -4763,17 +5217,21 @@ mod tests { }; scaffold_workspace(tmp.path(), &ctx).unwrap(); - let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); + let user_md = tokio::fs::read_to_string(tmp.path().join("USER.md")) + .await + .unwrap(); assert!(user_md.contains("José María")); - let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); assert!(soul.contains("ZeroClaw-v2")); } // ── scaffold_workspace: full personalization round-trip ───── - #[test] - fn scaffold_full_personalization() { + #[tokio::test] + async fn scaffold_full_personalization() { let tmp = TempDir::new().unwrap(); let ctx = ProjectContext { user_name: "Argenis".into(), @@ -4786,27 +5244,39 @@ mod tests { scaffold_workspace(tmp.path(), &ctx).unwrap(); // Verify every file got personalized - let identity = fs::read_to_string(tmp.path().join("IDENTITY.md")).unwrap(); + let identity = tokio::fs::read_to_string(tmp.path().join("IDENTITY.md")) + .await + .unwrap(); assert!(identity.contains("**Name:** Claw")); - let soul = fs::read_to_string(tmp.path().join("SOUL.md")).unwrap(); + let soul = tokio::fs::read_to_string(tmp.path().join("SOUL.md")) + .await + .unwrap(); assert!(soul.contains("You are **Claw**")); assert!(soul.contains("Be friendly, human, and conversational")); - let user_md = fs::read_to_string(tmp.path().join("USER.md")).unwrap(); + let user_md = tokio::fs::read_to_string(tmp.path().join("USER.md")) + .await + .unwrap(); assert!(user_md.contains("**Name:** Argenis")); assert!(user_md.contains("**Timezone:** US/Eastern")); assert!(user_md.contains("Be friendly, human, and conversational")); - let agents = fs::read_to_string(tmp.path().join("AGENTS.md")).unwrap(); + let agents = tokio::fs::read_to_string(tmp.path().join("AGENTS.md")) + .await + .unwrap(); assert!(agents.contains("Claw Personal Assistant")); - let bootstrap = 
fs::read_to_string(tmp.path().join("BOOTSTRAP.md")).unwrap(); + let bootstrap = tokio::fs::read_to_string(tmp.path().join("BOOTSTRAP.md")) + .await + .unwrap(); assert!(bootstrap.contains("**Argenis**")); assert!(bootstrap.contains("US/Eastern")); assert!(bootstrap.contains("Introduce yourself as Claw")); - let heartbeat = fs::read_to_string(tmp.path().join("HEARTBEAT.md")).unwrap(); + let heartbeat = tokio::fs::read_to_string(tmp.path().join("HEARTBEAT.md")) + .await + .unwrap(); assert!(heartbeat.contains("Claw")); } @@ -4826,12 +5296,17 @@ mod tests { ); assert_eq!(default_model_for_provider("qwen"), "qwen-plus"); assert_eq!(default_model_for_provider("qwen-intl"), "qwen-plus"); + assert_eq!(default_model_for_provider("qwen-code"), "qwen3-coder-plus"); assert_eq!(default_model_for_provider("glm-cn"), "glm-5"); assert_eq!(default_model_for_provider("minimax-cn"), "MiniMax-M2.5"); assert_eq!(default_model_for_provider("zai-cn"), "glm-5"); assert_eq!(default_model_for_provider("gemini"), "gemini-2.5-pro"); assert_eq!(default_model_for_provider("google"), "gemini-2.5-pro"); assert_eq!(default_model_for_provider("kimi-code"), "kimi-for-coding"); + assert_eq!( + default_model_for_provider("bedrock"), + "anthropic.claude-sonnet-4-5-20250929-v1:0" + ); assert_eq!( default_model_for_provider("google-gemini"), "gemini-2.5-pro" @@ -4856,6 +5331,8 @@ mod tests { fn canonical_provider_name_normalizes_regional_aliases() { assert_eq!(canonical_provider_name("qwen-intl"), "qwen"); assert_eq!(canonical_provider_name("dashscope-us"), "qwen"); + assert_eq!(canonical_provider_name("qwen-code"), "qwen-code"); + assert_eq!(canonical_provider_name("qwen-oauth"), "qwen-code"); assert_eq!(canonical_provider_name("moonshot-intl"), "moonshot"); assert_eq!(canonical_provider_name("kimi-cn"), "moonshot"); assert_eq!(canonical_provider_name("kimi_coding"), "kimi-code"); @@ -4866,6 +5343,7 @@ mod tests { assert_eq!(canonical_provider_name("zai-cn"), "zai"); 
assert_eq!(canonical_provider_name("z.ai-global"), "zai"); assert_eq!(canonical_provider_name("nvidia-nim"), "nvidia"); + assert_eq!(canonical_provider_name("aws-bedrock"), "bedrock"); assert_eq!(canonical_provider_name("build.nvidia.com"), "nvidia"); } @@ -4915,6 +5393,19 @@ mod tests { assert!(ids.contains(&"anthropic/claude-sonnet-4.6".to_string())); } + #[test] + fn curated_models_for_bedrock_include_verified_model_ids() { + let ids: Vec = curated_models_for_provider("bedrock") + .into_iter() + .map(|(id, _)| id) + .collect(); + + assert!(ids.contains(&"anthropic.claude-sonnet-4-6".to_string())); + assert!(ids.contains(&"anthropic.claude-opus-4-6-v1".to_string())); + assert!(ids.contains(&"anthropic.claude-haiku-4-5-20251001-v1:0".to_string())); + assert!(ids.contains(&"anthropic.claude-sonnet-4-5-20250929-v1:0".to_string())); + } + #[test] fn curated_models_for_moonshot_drop_deprecated_aliases() { let ids: Vec = curated_models_for_provider("moonshot") @@ -4952,6 +5443,18 @@ mod tests { assert!(ids.contains(&"kimi-k2.5".to_string())); } + #[test] + fn curated_models_for_qwen_code_include_coding_plan_models() { + let ids: Vec = curated_models_for_provider("qwen-code") + .into_iter() + .map(|(id, _)| id) + .collect(); + + assert!(ids.contains(&"qwen3-coder-plus".to_string())); + assert!(ids.contains(&"qwen3.5-plus".to_string())); + assert!(ids.contains(&"qwen3-max-2026-01-23".to_string())); + } + #[test] fn supports_live_model_fetch_for_supported_and_unsupported_providers() { assert!(supports_live_model_fetch("openai")); @@ -5014,6 +5517,10 @@ mod tests { curated_models_for_provider("nvidia"), curated_models_for_provider("build.nvidia.com") ); + assert_eq!( + curated_models_for_provider("bedrock"), + curated_models_for_provider("aws-bedrock") + ); } #[test] @@ -5190,7 +5697,8 @@ mod tests { let config = Config { workspace_dir: tmp.path().to_path_buf(), - default_provider: Some("venice".to_string()), + // Use a non-provider channel key to keep this test 
deterministic and offline. + default_provider: Some("imessage".to_string()), ..Config::default() }; @@ -5218,6 +5726,8 @@ mod tests { assert_eq!(provider_env_var("qwen"), "DASHSCOPE_API_KEY"); assert_eq!(provider_env_var("qwen-intl"), "DASHSCOPE_API_KEY"); assert_eq!(provider_env_var("dashscope-us"), "DASHSCOPE_API_KEY"); + assert_eq!(provider_env_var("qwen-code"), "QWEN_OAUTH_TOKEN"); + assert_eq!(provider_env_var("qwen-oauth"), "QWEN_OAUTH_TOKEN"); assert_eq!(provider_env_var("glm-cn"), "GLM_API_KEY"); assert_eq!(provider_env_var("minimax-cn"), "MINIMAX_API_KEY"); assert_eq!(provider_env_var("kimi-code"), "KIMI_CODE_API_KEY"); @@ -5289,4 +5799,28 @@ mod tests { assert_eq!(config.purge_after_days, 0); assert_eq!(config.embedding_cache_size, 0); } + + #[test] + fn launchable_channels_include_mattermost_and_qq() { + let mut channels = ChannelsConfig::default(); + assert!(!has_launchable_channels(&channels)); + + channels.mattermost = Some(crate::config::schema::MattermostConfig { + url: "https://mattermost.example.com".into(), + bot_token: "token".into(), + channel_id: Some("channel".into()), + allowed_users: vec!["*".into()], + thread_replies: Some(true), + mention_only: Some(false), + }); + assert!(has_launchable_channels(&channels)); + + channels.mattermost = None; + channels.qq = Some(crate::config::schema::QQConfig { + app_id: "app-id".into(), + app_secret: "app-secret".into(), + allowed_users: vec!["*".into()], + }); + assert!(has_launchable_channels(&channels)); + } } diff --git a/src/peripherals/arduino_upload.rs b/src/peripherals/arduino_upload.rs index e11b19f..57a4f61 100644 --- a/src/peripherals/arduino_upload.rs +++ b/src/peripherals/arduino_upload.rs @@ -75,7 +75,7 @@ impl Tool for ArduinoUploadTool { let sketch_dir = temp_dir.join(sketch_name); let ino_path = sketch_dir.join(format!("{}.ino", sketch_name)); - if let Err(e) = std::fs::create_dir_all(&sketch_dir) { + if let Err(e) = tokio::fs::create_dir_all(&sketch_dir).await { return Ok(ToolResult { 
success: false, output: format!("Failed to create sketch dir: {}", e), @@ -83,8 +83,8 @@ impl Tool for ArduinoUploadTool { }); } - if let Err(e) = std::fs::write(&ino_path, code) { - let _ = std::fs::remove_dir_all(&temp_dir); + if let Err(e) = tokio::fs::write(&ino_path, code).await { + let _ = tokio::fs::remove_dir_all(&temp_dir).await; return Ok(ToolResult { success: false, output: format!("Failed to write sketch: {}", e), @@ -103,7 +103,7 @@ impl Tool for ArduinoUploadTool { let compile_output = match compile { Ok(o) => o, Err(e) => { - let _ = std::fs::remove_dir_all(&temp_dir); + let _ = tokio::fs::remove_dir_all(&temp_dir).await; return Ok(ToolResult { success: false, output: format!("arduino-cli compile failed: {}", e), @@ -114,7 +114,7 @@ impl Tool for ArduinoUploadTool { if !compile_output.status.success() { let stderr = String::from_utf8_lossy(&compile_output.stderr); - let _ = std::fs::remove_dir_all(&temp_dir); + let _ = tokio::fs::remove_dir_all(&temp_dir).await; return Ok(ToolResult { success: false, output: format!("Compile failed:\n{}", stderr), @@ -130,7 +130,7 @@ impl Tool for ArduinoUploadTool { let upload_output = match upload { Ok(o) => o, Err(e) => { - let _ = std::fs::remove_dir_all(&temp_dir); + let _ = tokio::fs::remove_dir_all(&temp_dir).await; return Ok(ToolResult { success: false, output: format!("arduino-cli upload failed: {}", e), @@ -139,7 +139,7 @@ impl Tool for ArduinoUploadTool { } }; - let _ = std::fs::remove_dir_all(&temp_dir); + let _ = tokio::fs::remove_dir_all(&temp_dir).await; if !upload_output.status.success() { let stderr = String::from_utf8_lossy(&upload_output.stderr); diff --git a/src/peripherals/mod.rs b/src/peripherals/mod.rs index f3f8a8a..6ae1c49 100644 --- a/src/peripherals/mod.rs +++ b/src/peripherals/mod.rs @@ -42,7 +42,7 @@ pub fn list_configured_boards(config: &PeripheralsConfig) -> Vec<&PeripheralBoar /// Handle `zeroclaw peripheral` subcommands. 
#[allow(clippy::module_name_repetitions)] -pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> { +pub async fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result<()> { match cmd { crate::PeripheralCommands::List => { let boards = list_configured_boards(&config.peripherals); @@ -76,7 +76,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result Some(path.clone()) }; - let mut cfg = crate::config::Config::load_or_init()?; + let mut cfg = crate::config::Config::load_or_init().await?; cfg.peripherals.enabled = true; if cfg @@ -95,7 +95,7 @@ pub fn handle_command(cmd: crate::PeripheralCommands, config: &Config) -> Result path: path_opt, baud: 115_200, }); - cfg.save()?; + cfg.save().await?; println!("Added {} at {}. Restart daemon to apply.", board, path); } #[cfg(feature = "hardware")] @@ -231,3 +231,73 @@ pub async fn create_peripheral_tools(config: &PeripheralsConfig) -> Result Result>> { Ok(Vec::new()) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::config::{PeripheralBoardConfig, PeripheralsConfig}; + + #[test] + fn list_configured_boards_when_disabled_returns_empty() { + let config = PeripheralsConfig { + enabled: false, + boards: vec![PeripheralBoardConfig { + board: "nucleo-f401re".into(), + transport: "serial".into(), + path: Some("/dev/ttyACM0".into()), + baud: 115_200, + }], + datasheet_dir: None, + }; + let result = list_configured_boards(&config); + assert!(result.is_empty(), "disabled peripherals should return no boards"); + } + + #[test] + fn list_configured_boards_when_enabled_with_boards() { + let config = PeripheralsConfig { + enabled: true, + boards: vec![ + PeripheralBoardConfig { + board: "nucleo-f401re".into(), + transport: "serial".into(), + path: Some("/dev/ttyACM0".into()), + baud: 115_200, + }, + PeripheralBoardConfig { + board: "rpi-gpio".into(), + transport: "native".into(), + path: None, + baud: 115_200, + }, + ], + datasheet_dir: None, + }; + let 
result = list_configured_boards(&config); + assert_eq!(result.len(), 2); + assert_eq!(result[0].board, "nucleo-f401re"); + assert_eq!(result[1].board, "rpi-gpio"); + } + + #[test] + fn list_configured_boards_when_enabled_but_no_boards() { + let config = PeripheralsConfig { + enabled: true, + boards: vec![], + datasheet_dir: None, + }; + let result = list_configured_boards(&config); + assert!(result.is_empty(), "enabled with no boards should return empty"); + } + + #[tokio::test] + async fn create_peripheral_tools_returns_empty_when_disabled() { + let config = PeripheralsConfig { + enabled: false, + boards: vec![], + datasheet_dir: None, + }; + let tools = create_peripheral_tools(&config).await.unwrap(); + assert!(tools.is_empty(), "disabled peripherals should produce no tools"); + } +} diff --git a/src/peripherals/traits.rs b/src/peripherals/traits.rs index 6081d1d..0e27065 100644 --- a/src/peripherals/traits.rs +++ b/src/peripherals/traits.rs @@ -2,32 +2,74 @@ //! //! Peripherals are the agent's "arms and legs": remote devices that run minimal //! firmware and expose capabilities (GPIO, sensors, actuators) as tools. +//! See `docs/hardware-peripherals-design.md` for the communication protocol +//! and firmware integration guide. use async_trait::async_trait; use crate::tools::Tool; -/// A hardware peripheral that exposes capabilities as tools. +/// A hardware peripheral that exposes capabilities as agent tools. /// -/// Implement this for boards like Nucleo-F401RE (serial), RPi GPIO (native), etc. -/// When connected, the peripheral's tools are merged into the agent's tool registry. +/// Implement this trait for each supported board type (e.g., Nucleo-F401RE +/// over serial, Raspberry Pi GPIO via sysfs/gpiod). When the agent connects +/// to a peripheral, the tools returned by [`tools`](Peripheral::tools) are +/// merged into the agent's tool registry, making hardware capabilities +/// available to the LLM as callable functions. 
+/// +/// The lifecycle follows a connect → use → disconnect pattern. Implementations +/// must be `Send + Sync` because the peripheral may be accessed from multiple +/// async tasks after connection. #[async_trait] pub trait Peripheral: Send + Sync { - /// Human-readable peripheral name (e.g. "nucleo-f401re-0") + /// Return the human-readable instance name of this peripheral. + /// + /// Should uniquely identify a specific device instance, including an index + /// or serial number when multiple boards of the same type are connected + /// (e.g., `"nucleo-f401re-0"`, `"rpi-gpio-hat-1"`). fn name(&self) -> &str; - /// Board type identifier (e.g. "nucleo-f401re", "rpi-gpio") + /// Return the board type identifier for this peripheral. + /// + /// A stable, lowercase string used in configuration and factory registration + /// (e.g., `"nucleo-f401re"`, `"rpi-gpio"`). Must match the key used in + /// the config schema's peripheral section. fn board_type(&self) -> &str; - /// Connect to the peripheral (open serial, init GPIO, etc.) + /// Establish a connection to the peripheral hardware. + /// + /// Opens the underlying transport (serial port, GPIO bus, I²C, etc.) and + /// performs any initialization handshake required by the firmware. + /// + /// # Errors + /// + /// Returns an error if the device is unreachable, the transport cannot be + /// opened, or the firmware handshake fails. async fn connect(&mut self) -> anyhow::Result<()>; - /// Disconnect and release resources + /// Disconnect from the peripheral and release all held resources. + /// + /// Closes serial ports, unexports GPIO pins, and performs any cleanup + /// required for a safe shutdown. After this call, [`health_check`](Peripheral::health_check) + /// should return `false` until [`connect`](Peripheral::connect) is called again. + /// + /// # Errors + /// + /// Returns an error if resource cleanup fails (e.g., serial port busy). 
async fn disconnect(&mut self) -> anyhow::Result<()>; - /// Check if the peripheral is reachable and responsive + /// Check whether the peripheral is reachable and responsive. + /// + /// Performs a lightweight probe (e.g., a ping command over serial) without + /// altering device state. Returns `true` if the device responds within an + /// implementation-defined timeout. async fn health_check(&self) -> bool; - /// Tools this peripheral provides (e.g. gpio_read, gpio_write, sensor_read) + /// Return the tools this peripheral exposes to the agent. + /// + /// Each returned [`Tool`] delegates execution to the underlying hardware + /// (e.g., `gpio_read`, `gpio_write`, `sensor_read`). The agent merges + /// these into its tool registry after a successful + /// [`connect`](Peripheral::connect). fn tools(&self) -> Vec>; } diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 469c981..31798fb 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -42,7 +42,7 @@ struct ContentBlock { } #[derive(Debug, Serialize)] -struct NativeChatRequest { +struct NativeChatRequest<'a> { model: String, max_tokens: u32, #[serde(skip_serializing_if = "Option::is_none")] @@ -50,7 +50,7 @@ struct NativeChatRequest { messages: Vec, temperature: f64, #[serde(skip_serializing_if = "Option::is_none")] - tools: Option>, + tools: Option>>, } #[derive(Debug, Serialize)] @@ -86,10 +86,10 @@ enum NativeContentOut { } #[derive(Debug, Serialize)] -struct NativeToolSpec { - name: String, - description: String, - input_schema: serde_json::Value, +struct NativeToolSpec<'a> { + name: &'a str, + description: &'a str, + input_schema: &'a serde_json::Value, #[serde(skip_serializing_if = "Option::is_none")] cache_control: Option, } @@ -206,17 +206,17 @@ impl AnthropicProvider { } } - fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + fn convert_tools<'a>(tools: Option<&'a [ToolSpec]>) -> Option>> { let items = tools?; if items.is_empty() { return None; } - let 
mut native_tools: Vec = items + let mut native_tools: Vec> = items .iter() .map(|tool| NativeToolSpec { - name: tool.name.clone(), - description: tool.description.clone(), - input_schema: tool.parameters.clone(), + name: &tool.name, + description: &tool.description, + input_schema: &tool.parameters, cache_control: None, }) .collect(); @@ -497,6 +497,53 @@ impl Provider for AnthropicProvider { true } + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + // Convert OpenAI-format tool JSON to ToolSpec so we can reuse the + // existing `chat()` method which handles full message history, + // system prompt extraction, caching, and Anthropic native formatting. + let tool_specs: Vec = tools + .iter() + .filter_map(|t| { + let func = t.get("function").or_else(|| { + tracing::warn!("Skipping malformed tool definition (missing 'function' key)"); + None + })?; + let name = func.get("name").and_then(|n| n.as_str()).or_else(|| { + tracing::warn!("Skipping tool with missing or non-string 'name'"); + None + })?; + Some(ToolSpec { + name: name.to_string(), + description: func + .get("description") + .and_then(|d| d.as_str()) + .unwrap_or("") + .to_string(), + parameters: func + .get("parameters") + .cloned() + .unwrap_or(serde_json::json!({"type": "object"})), + }) + }) + .collect(); + + let request = ProviderChatRequest { + messages, + tools: if tool_specs.is_empty() { + None + } else { + Some(&tool_specs) + }, + }; + self.chat(request, model, temperature).await + } + async fn warmup(&self) -> anyhow::Result<()> { if let Some(credential) = self.credential.as_ref() { let mut request = self @@ -828,10 +875,11 @@ mod tests { #[test] fn native_tool_spec_without_cache_control() { + let schema = serde_json::json!({"type": "object"}); let tool = NativeToolSpec { - name: "get_weather".to_string(), - description: "Get weather info".to_string(), - input_schema: serde_json::json!({"type": 
"object"}), + name: "get_weather", + description: "Get weather info", + input_schema: &schema, cache_control: None, }; let json = serde_json::to_string(&tool).unwrap(); @@ -841,10 +889,11 @@ mod tests { #[test] fn native_tool_spec_with_cache_control() { + let schema = serde_json::json!({"type": "object"}); let tool = NativeToolSpec { - name: "get_weather".to_string(), - description: "Get weather info".to_string(), - input_schema: serde_json::json!({"type": "object"}), + name: "get_weather", + description: "Get weather info", + input_schema: &schema, cache_control: Some(CacheControl::ephemeral()), }; let json = serde_json::to_string(&tool).unwrap(); @@ -1103,4 +1152,167 @@ mod tests { let result = provider.warmup().await; assert!(result.is_ok()); } + + #[test] + fn convert_messages_preserves_multi_turn_history() { + let messages = vec![ + ChatMessage { + role: "system".to_string(), + content: "You are helpful.".to_string(), + }, + ChatMessage { + role: "user".to_string(), + content: "gen a 2 sum in golang".to_string(), + }, + ChatMessage { + role: "assistant".to_string(), + content: "```go\nfunc twoSum(nums []int) {}\n```".to_string(), + }, + ChatMessage { + role: "user".to_string(), + content: "what's meaning of make here?".to_string(), + }, + ]; + + let (system, native_msgs) = AnthropicProvider::convert_messages(&messages); + + // System prompt extracted + assert!(system.is_some()); + // All 3 non-system messages preserved in order + assert_eq!(native_msgs.len(), 3); + assert_eq!(native_msgs[0].role, "user"); + assert_eq!(native_msgs[1].role, "assistant"); + assert_eq!(native_msgs[2].role, "user"); + } + + /// Integration test: spin up a mock Anthropic API server, call chat_with_tools + /// with a multi-turn conversation + tools, and verify the request body contains + /// ALL conversation turns and native tool definitions. 
+ #[tokio::test] + async fn chat_with_tools_sends_full_history_and_native_tools() { + use axum::{routing::post, Json, Router}; + use std::sync::{Arc, Mutex}; + use tokio::net::TcpListener; + + // Captured request body for assertion + let captured: Arc>> = Arc::new(Mutex::new(None)); + let captured_clone = captured.clone(); + + let app = Router::new().route( + "/v1/messages", + post(move |Json(body): Json| { + let cap = captured_clone.clone(); + async move { + *cap.lock().unwrap() = Some(body); + // Return a minimal valid Anthropic response + Json(serde_json::json!({ + "id": "msg_test", + "type": "message", + "role": "assistant", + "content": [{"type": "text", "text": "The make function creates a map."}], + "model": "claude-opus-4-6", + "stop_reason": "end_turn", + "usage": {"input_tokens": 100, "output_tokens": 20} + })) + } + }), + ); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server_handle = tokio::spawn(async move { + axum::serve(listener, app).await.unwrap(); + }); + + // Create provider pointing at mock server + let provider = AnthropicProvider { + credential: Some("test-key".to_string()), + base_url: format!("http://{addr}"), + }; + + // Multi-turn conversation: system → user (Go code) → assistant (code response) → user (follow-up) + let messages = vec![ + ChatMessage::system("You are a helpful assistant."), + ChatMessage::user("gen a 2 sum in golang"), + ChatMessage::assistant("```go\nfunc twoSum(nums []int, target int) []int {\n m := make(map[int]int)\n for i, n := range nums {\n if j, ok := m[target-n]; ok {\n return []int{j, i}\n }\n m[n] = i\n }\n return nil\n}\n```"), + ChatMessage::user("what's meaning of make here?"), + ]; + + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": { + "type": "object", + "properties": { + "command": {"type": "string"} + }, + "required": ["command"] + 
} + } + })]; + + let result = provider + .chat_with_tools(&messages, &tools, "claude-opus-4-6", 0.7) + .await; + assert!(result.is_ok(), "chat_with_tools failed: {:?}", result.err()); + + let body = captured + .lock() + .unwrap() + .take() + .expect("No request captured"); + + // Verify system prompt extracted to top-level field + let system = &body["system"]; + assert!( + system.to_string().contains("helpful assistant"), + "System prompt missing: {system}" + ); + + // Verify ALL conversation turns present in messages array + let msgs = body["messages"].as_array().expect("messages not an array"); + assert_eq!( + msgs.len(), + 3, + "Expected 3 messages (2 user + 1 assistant), got {}", + msgs.len() + ); + + // Turn 1: user with Go request + assert_eq!(msgs[0]["role"], "user"); + let turn1_text = msgs[0]["content"].to_string(); + assert!( + turn1_text.contains("2 sum"), + "Turn 1 missing Go request: {turn1_text}" + ); + + // Turn 2: assistant with Go code + assert_eq!(msgs[1]["role"], "assistant"); + let turn2_text = msgs[1]["content"].to_string(); + assert!( + turn2_text.contains("make(map[int]int)"), + "Turn 2 missing Go code: {turn2_text}" + ); + + // Turn 3: user follow-up + assert_eq!(msgs[2]["role"], "user"); + let turn3_text = msgs[2]["content"].to_string(); + assert!( + turn3_text.contains("meaning of make"), + "Turn 3 missing follow-up: {turn3_text}" + ); + + // Verify native tools are present + let api_tools = body["tools"].as_array().expect("tools not an array"); + assert_eq!(api_tools.len(), 1); + assert_eq!(api_tools[0]["name"], "shell"); + assert!( + api_tools[0]["input_schema"].is_object(), + "Missing input_schema" + ); + + server_handle.abort(); + } } diff --git a/src/providers/bedrock.rs b/src/providers/bedrock.rs new file mode 100644 index 0000000..2ec13a1 --- /dev/null +++ b/src/providers/bedrock.rs @@ -0,0 +1,1244 @@ +//! AWS Bedrock provider using the Converse API. +//! +//! Authentication: AWS AKSK (Access Key ID + Secret Access Key) +//! 
via environment variables. SigV4 signing is implemented manually +//! using hmac/sha2 crates — no AWS SDK dependency. + +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ProviderCapabilities, ToolCall as ProviderToolCall, ToolsPayload, +}; +use crate::tools::ToolSpec; +use async_trait::async_trait; +use hmac::{Hmac, Mac}; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use sha2::{Digest, Sha256}; + +/// Hostname prefix for the Bedrock Runtime endpoint. +const ENDPOINT_PREFIX: &str = "bedrock-runtime"; +/// SigV4 signing service name (AWS uses "bedrock", not "bedrock-runtime"). +const SIGNING_SERVICE: &str = "bedrock"; +const DEFAULT_REGION: &str = "us-east-1"; +const DEFAULT_MAX_TOKENS: u32 = 4096; + +// ── AWS Credentials ───────────────────────────────────────────── + +/// Resolved AWS credentials for SigV4 signing. +struct AwsCredentials { + access_key_id: String, + secret_access_key: String, + session_token: Option, + region: String, +} + +impl AwsCredentials { + /// Resolve credentials from environment variables. + /// + /// Required: `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`. + /// Optional: `AWS_SESSION_TOKEN`, `AWS_REGION` / `AWS_DEFAULT_REGION`. 
+ fn from_env() -> anyhow::Result { + let access_key_id = env_required("AWS_ACCESS_KEY_ID")?; + let secret_access_key = env_required("AWS_SECRET_ACCESS_KEY")?; + + let session_token = env_optional("AWS_SESSION_TOKEN"); + + let region = env_optional("AWS_REGION") + .or_else(|| env_optional("AWS_DEFAULT_REGION")) + .unwrap_or_else(|| DEFAULT_REGION.to_string()); + + Ok(Self { + access_key_id, + secret_access_key, + session_token, + region, + }) + } + + fn host(&self) -> String { + format!("{ENDPOINT_PREFIX}.{}.amazonaws.com", self.region) + } +} + +fn env_required(name: &str) -> anyhow::Result { + std::env::var(name) + .ok() + .map(|v| v.trim().to_string()) + .filter(|v| !v.is_empty()) + .ok_or_else(|| anyhow::anyhow!("Environment variable {name} is required for Bedrock")) +} + +fn env_optional(name: &str) -> Option { + std::env::var(name) + .ok() + .map(|v| v.trim().to_string()) + .filter(|v| !v.is_empty()) +} + +// ── AWS SigV4 Signing ─────────────────────────────────────────── + +fn sha256_hex(data: &[u8]) -> String { + let mut hasher = Sha256::new(); + hasher.update(data); + hex::encode(hasher.finalize()) +} + +fn hmac_sha256(key: &[u8], data: &[u8]) -> Vec { + let mut mac = Hmac::::new_from_slice(key).expect("HMAC can take key of any size"); + mac.update(data); + mac.finalize().into_bytes().to_vec() +} + +/// Derive the SigV4 signing key via HMAC chain. +fn derive_signing_key(secret: &str, date: &str, region: &str, service: &str) -> Vec { + let k_date = hmac_sha256(format!("AWS4{secret}").as_bytes(), date.as_bytes()); + let k_region = hmac_sha256(&k_date, region.as_bytes()); + let k_service = hmac_sha256(&k_region, service.as_bytes()); + hmac_sha256(&k_service, b"aws4_request") +} + +/// Build the SigV4 `Authorization` header value. +/// +/// `headers` must be sorted by lowercase header name. 
+fn build_authorization_header( + credentials: &AwsCredentials, + method: &str, + canonical_uri: &str, + query_string: &str, + headers: &[(String, String)], + payload: &[u8], + timestamp: &chrono::DateTime, +) -> String { + let date_stamp = timestamp.format("%Y%m%d").to_string(); + let amz_date = timestamp.format("%Y%m%dT%H%M%SZ").to_string(); + + let mut canonical_headers = String::new(); + for (k, v) in headers { + canonical_headers.push_str(k); + canonical_headers.push(':'); + canonical_headers.push_str(v); + canonical_headers.push('\n'); + } + + let signed_headers: String = headers + .iter() + .map(|(k, _)| k.as_str()) + .collect::>() + .join(";"); + + let payload_hash = sha256_hex(payload); + + let canonical_request = format!( + "{method}\n{canonical_uri}\n{query_string}\n{canonical_headers}\n{signed_headers}\n{payload_hash}" + ); + + let credential_scope = format!( + "{date_stamp}/{}/{SIGNING_SERVICE}/aws4_request", + credentials.region + ); + + let string_to_sign = format!( + "AWS4-HMAC-SHA256\n{amz_date}\n{credential_scope}\n{}", + sha256_hex(canonical_request.as_bytes()) + ); + + let signing_key = derive_signing_key( + &credentials.secret_access_key, + &date_stamp, + &credentials.region, + SIGNING_SERVICE, + ); + + let signature = hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes())); + + format!( + "AWS4-HMAC-SHA256 Credential={}/{credential_scope}, SignedHeaders={signed_headers}, Signature={signature}", + credentials.access_key_id + ) +} + +// ── Converse API Types (Request) ──────────────────────────────── + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct ConverseRequest { + messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + system: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + inference_config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_config: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ConverseMessage { + role: String, + content: Vec, +} 
+ +/// Content blocks use Bedrock's union style: +/// `{"text": "..."}`, `{"toolUse": {...}}`, `{"toolResult": {...}}`, `{"cachePoint": {...}}`. +/// +/// Note: `text` is a simple string value, not a nested object. `toolUse` and `toolResult` +/// are nested objects. We use `#[serde(untagged)]` with manual struct wrappers to +/// match this mixed format. +#[derive(Debug, Serialize, Deserialize)] +#[serde(untagged)] +enum ContentBlock { + Text(TextBlock), + ToolUse(ToolUseWrapper), + ToolResult(ToolResultWrapper), + CachePointBlock(CachePointWrapper), +} + +#[derive(Debug, Serialize, Deserialize)] +struct TextBlock { + text: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ToolUseWrapper { + tool_use: ToolUseBlock, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ToolUseBlock { + tool_use_id: String, + name: String, + input: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ToolResultWrapper { + tool_result: ToolResultBlock, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ToolResultBlock { + tool_use_id: String, + content: Vec, + status: String, +} + +#[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +struct CachePointWrapper { + cache_point: CachePoint, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ToolResultContent { + text: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct CachePoint { + #[serde(rename = "type")] + cache_type: String, +} + +impl CachePoint { + fn default_cache() -> Self { + Self { + cache_type: "default".to_string(), + } + } +} + +/// System prompt blocks: either `{"text": "..."}` or `{"cachePoint": {...}}`. 
+#[derive(Debug, Serialize)] +#[serde(untagged)] +enum SystemBlock { + Text(TextBlock), + CachePoint(CachePointWrapper), +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct InferenceConfig { + max_tokens: u32, + temperature: f64, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct ToolConfig { + tools: Vec, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct ToolDefinition { + tool_spec: ToolSpecDef, +} + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct ToolSpecDef { + name: String, + description: String, + input_schema: InputSchema, +} + +#[derive(Debug, Serialize)] +struct InputSchema { + json: serde_json::Value, +} + +// ── Converse API Types (Response) ─────────────────────────────── + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ConverseResponse { + #[serde(default)] + output: Option, + #[serde(default)] + #[allow(dead_code)] + stop_reason: Option, +} + +#[derive(Debug, Deserialize)] +struct ConverseOutput { + #[serde(default)] + message: Option, +} + +#[derive(Debug, Deserialize)] +struct ConverseOutputMessage { + #[allow(dead_code)] + role: String, + content: Vec, +} + +/// Response content blocks from the Converse API. +/// +/// Uses `#[serde(untagged)]` to match Bedrock's union format where `text` is a +/// simple string value and `toolUse` is a nested object. Unknown block types +/// (e.g. `reasoningContent`, `guardContent`) are captured as `Other` to prevent +/// deserialization failures. 
+#[derive(Debug, Deserialize)] +#[serde(untagged)] +enum ResponseContentBlock { + ToolUse(ResponseToolUseWrapper), + Text(TextBlock), + Other(serde_json::Value), +} + +#[derive(Debug, Deserialize)] +#[serde(rename_all = "camelCase")] +struct ResponseToolUseWrapper { + tool_use: ToolUseBlock, +} + +// ── BedrockProvider ───────────────────────────────────────────── + +pub struct BedrockProvider { + credentials: Option, +} + +impl BedrockProvider { + pub fn new() -> Self { + Self { + credentials: AwsCredentials::from_env().ok(), + } + } + + fn http_client(&self) -> Client { + crate::config::build_runtime_proxy_client_with_timeouts("provider.bedrock", 120, 10) + } + + /// Percent-encode the model ID for URL path: only encode `:` to `%3A`. + /// Colons in model IDs (e.g. `v1:0`) must be encoded because `reqwest::Url` + /// may misparse them. Dots, hyphens, and alphanumerics are safe. + fn encode_model_path(model_id: &str) -> String { + model_id.replace(':', "%3A") + } + + /// Build the actual request URL. Uses raw model ID (reqwest sends colons as-is). + fn endpoint_url(region: &str, model_id: &str) -> String { + format!("https://{ENDPOINT_PREFIX}.{region}.amazonaws.com/model/{model_id}/converse") + } + + /// Build the canonical URI for SigV4 signing. Must URI-encode the path + /// per SigV4 spec: colons become `%3A`. AWS verifies the signature against + /// the encoded form even though the wire request uses raw colons. + fn canonical_uri(model_id: &str) -> String { + let encoded = Self::encode_model_path(model_id); + format!("/model/{encoded}/converse") + } + + fn require_credentials(&self) -> anyhow::Result<&AwsCredentials> { + self.credentials.as_ref().ok_or_else(|| { + anyhow::anyhow!( + "AWS Bedrock credentials not set. Set AWS_ACCESS_KEY_ID and \ + AWS_SECRET_ACCESS_KEY environment variables." + ) + }) + } + + // ── Cache heuristics (same thresholds as AnthropicProvider) ── + + /// Cache system prompts larger than ~1024 tokens (3KB of text). 
+ fn should_cache_system(text: &str) -> bool { + text.len() > 3072 + } + + /// Cache conversations with more than 4 messages (excluding system). + fn should_cache_conversation(messages: &[ChatMessage]) -> bool { + messages.iter().filter(|m| m.role != "system").count() > 4 + } + + // ── Message conversion ────────────────────────────────────── + + fn convert_messages( + messages: &[ChatMessage], + ) -> (Option>, Vec) { + let mut system_blocks = Vec::new(); + let mut converse_messages = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + if system_blocks.is_empty() { + system_blocks.push(SystemBlock::Text(TextBlock { + text: msg.content.clone(), + })); + } + } + "assistant" => { + if let Some(blocks) = Self::parse_assistant_tool_call_message(&msg.content) { + converse_messages.push(ConverseMessage { + role: "assistant".to_string(), + content: blocks, + }); + } else { + converse_messages.push(ConverseMessage { + role: "assistant".to_string(), + content: vec![ContentBlock::Text(TextBlock { + text: msg.content.clone(), + })], + }); + } + } + "tool" => { + if let Some(tool_result_msg) = Self::parse_tool_result_message(&msg.content) { + converse_messages.push(tool_result_msg); + } else { + converse_messages.push(ConverseMessage { + role: "user".to_string(), + content: vec![ContentBlock::Text(TextBlock { + text: msg.content.clone(), + })], + }); + } + } + _ => { + converse_messages.push(ConverseMessage { + role: "user".to_string(), + content: vec![ContentBlock::Text(TextBlock { + text: msg.content.clone(), + })], + }); + } + } + } + + let system = if system_blocks.is_empty() { + None + } else { + Some(system_blocks) + }; + (system, converse_messages) + } + + /// Parse assistant message containing structured tool calls. 
+ fn parse_assistant_tool_call_message(content: &str) -> Option> { + let value = serde_json::from_str::(content).ok()?; + let tool_calls = value + .get("tool_calls") + .and_then(|v| serde_json::from_value::>(v.clone()).ok())?; + + let mut blocks = Vec::new(); + if let Some(text) = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(str::trim) + .filter(|t| !t.is_empty()) + { + blocks.push(ContentBlock::Text(TextBlock { + text: text.to_string(), + })); + } + for call in tool_calls { + let input = serde_json::from_str::(&call.arguments) + .unwrap_or_else(|_| serde_json::Value::Object(serde_json::Map::new())); + blocks.push(ContentBlock::ToolUse(ToolUseWrapper { + tool_use: ToolUseBlock { + tool_use_id: call.id, + name: call.name, + input, + }, + })); + } + Some(blocks) + } + + /// Parse tool result message into a user message with ToolResult block. + fn parse_tool_result_message(content: &str) -> Option { + let value = serde_json::from_str::(content).ok()?; + let tool_use_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str)? 
+ .to_string(); + let result = value + .get("content") + .and_then(serde_json::Value::as_str) + .unwrap_or("") + .to_string(); + Some(ConverseMessage { + role: "user".to_string(), + content: vec![ContentBlock::ToolResult(ToolResultWrapper { + tool_result: ToolResultBlock { + tool_use_id, + content: vec![ToolResultContent { text: result }], + status: "success".to_string(), + }, + })], + }) + } + + // ── Tool conversion ───────────────────────────────────────── + + fn convert_tools_to_converse(tools: Option<&[ToolSpec]>) -> Option { + let items = tools?; + if items.is_empty() { + return None; + } + let tool_defs: Vec = items + .iter() + .map(|tool| ToolDefinition { + tool_spec: ToolSpecDef { + name: tool.name.clone(), + description: tool.description.clone(), + input_schema: InputSchema { + json: tool.parameters.clone(), + }, + }, + }) + .collect(); + Some(ToolConfig { tools: tool_defs }) + } + + // ── Response parsing ──────────────────────────────────────── + + fn parse_converse_response(response: ConverseResponse) -> ProviderChatResponse { + let mut text_parts = Vec::new(); + let mut tool_calls = Vec::new(); + + if let Some(output) = response.output { + if let Some(message) = output.message { + for block in message.content { + match block { + ResponseContentBlock::Text(tb) => { + let trimmed = tb.text.trim().to_string(); + if !trimmed.is_empty() { + text_parts.push(trimmed); + } + } + ResponseContentBlock::ToolUse(wrapper) => { + if !wrapper.tool_use.name.is_empty() { + tool_calls.push(ProviderToolCall { + id: wrapper.tool_use.tool_use_id, + name: wrapper.tool_use.name, + arguments: wrapper.tool_use.input.to_string(), + }); + } + } + ResponseContentBlock::Other(_) => {} + } + } + } + } + + ProviderChatResponse { + text: if text_parts.is_empty() { + None + } else { + Some(text_parts.join("\n")) + }, + tool_calls, + } + } + + // ── HTTP request ──────────────────────────────────────────── + + async fn send_converse_request( + &self, + credentials: &AwsCredentials, + 
model: &str, + request_body: &ConverseRequest, + ) -> anyhow::Result { + let payload = serde_json::to_vec(request_body)?; + let url = Self::endpoint_url(&credentials.region, model); + let canonical_uri = Self::canonical_uri(model); + let now = chrono::Utc::now(); + let host = credentials.host(); + let amz_date = now.format("%Y%m%dT%H%M%SZ").to_string(); + + let mut headers_to_sign = vec![ + ("content-type".to_string(), "application/json".to_string()), + ("host".to_string(), host), + ("x-amz-date".to_string(), amz_date.clone()), + ]; + if let Some(ref token) = credentials.session_token { + headers_to_sign.push(("x-amz-security-token".to_string(), token.clone())); + } + headers_to_sign.sort_by(|a, b| a.0.cmp(&b.0)); + + let authorization = build_authorization_header( + credentials, + "POST", + &canonical_uri, + "", + &headers_to_sign, + &payload, + &now, + ); + + let mut request = self + .http_client() + .post(&url) + .header("content-type", "application/json") + .header("x-amz-date", &amz_date) + .header("authorization", &authorization); + + if let Some(ref token) = credentials.session_token { + request = request.header("x-amz-security-token", token); + } + + let response: reqwest::Response = request.body(payload).send().await?; + + if !response.status().is_success() { + return Err(super::api_error("Bedrock", response).await); + } + + let converse_response: ConverseResponse = response.json().await?; + Ok(converse_response) + } +} + +// ── Provider trait implementation ─────────────────────────────── + +#[async_trait] +impl Provider for BedrockProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + vision: false, + } + } + + fn supports_native_tools(&self) -> bool { + true + } + + fn convert_tools(&self, tools: &[ToolSpec]) -> ToolsPayload { + let tool_values: Vec = tools + .iter() + .map(|t| { + serde_json::json!({ + "toolSpec": { + "name": t.name, + "description": t.description, + "inputSchema": { 
"json": t.parameters } + } + }) + }) + .collect(); + ToolsPayload::Anthropic { tools: tool_values } + } + + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credentials = self.require_credentials()?; + + let system = system_prompt.map(|text| { + let mut blocks = vec![SystemBlock::Text(TextBlock { + text: text.to_string(), + })]; + if Self::should_cache_system(text) { + blocks.push(SystemBlock::CachePoint(CachePointWrapper { + cache_point: CachePoint::default_cache(), + })); + } + blocks + }); + + let request = ConverseRequest { + system, + messages: vec![ConverseMessage { + role: "user".to_string(), + content: vec![ContentBlock::Text(TextBlock { + text: message.to_string(), + })], + }], + inference_config: Some(InferenceConfig { + max_tokens: DEFAULT_MAX_TOKENS, + temperature, + }), + tool_config: None, + }; + + let response = self + .send_converse_request(credentials, model, &request) + .await?; + + Self::parse_converse_response(response) + .text + .ok_or_else(|| anyhow::anyhow!("No response from Bedrock")) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credentials = self.require_credentials()?; + + let (system_blocks, mut converse_messages) = Self::convert_messages(request.messages); + + // Apply cachePoint to system if large. + let system = system_blocks.map(|mut blocks| { + let has_large_system = blocks + .iter() + .any(|b| matches!(b, SystemBlock::Text(tb) if Self::should_cache_system(&tb.text))); + if has_large_system { + blocks.push(SystemBlock::CachePoint(CachePointWrapper { + cache_point: CachePoint::default_cache(), + })); + } + blocks + }); + + // Apply cachePoint to last message if conversation is long. 
+ if Self::should_cache_conversation(request.messages) { + if let Some(last_msg) = converse_messages.last_mut() { + last_msg + .content + .push(ContentBlock::CachePointBlock(CachePointWrapper { + cache_point: CachePoint::default_cache(), + })); + } + } + + let tool_config = Self::convert_tools_to_converse(request.tools); + + let converse_request = ConverseRequest { + system, + messages: converse_messages, + inference_config: Some(InferenceConfig { + max_tokens: DEFAULT_MAX_TOKENS, + temperature, + }), + tool_config, + }; + + let response = self + .send_converse_request(credentials, model, &converse_request) + .await?; + + Ok(Self::parse_converse_response(response)) + } + + async fn warmup(&self) -> anyhow::Result<()> { + if let Some(ref creds) = self.credentials { + let url = format!("https://{ENDPOINT_PREFIX}.{}.amazonaws.com/", creds.region); + let _ = self.http_client().get(&url).send().await; + } + Ok(()) + } +} + +// ── Tests ─────────────────────────────────────────────────────── + +#[cfg(test)] +mod tests { + use super::*; + use crate::providers::traits::ChatMessage; + + // ── SigV4 signing tests ───────────────────────────────────── + + #[test] + fn sha256_hex_empty_string() { + // Known SHA-256 of empty input + assert_eq!( + sha256_hex(b""), + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855" + ); + } + + #[test] + fn sha256_hex_known_input() { + // SHA-256 of "hello" + assert_eq!( + sha256_hex(b"hello"), + "2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824" + ); + } + + /// AWS documentation example key for SigV4 test vectors (not a real credential). 
+ const TEST_VECTOR_SECRET: &str = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + + #[test] + fn hmac_sha256_known_input() { + let test_key: &[u8] = b"key"; + let result = hmac_sha256(test_key, b"message"); + assert_eq!( + hex::encode(&result), + "6e9ef29b75fffc5b7abae527d58fdadb2fe42e7219011976917343065f58ed4a" + ); + } + + #[test] + fn derive_signing_key_structure() { + // Verify the key derivation produces a 32-byte key (SHA-256 output). + let key = derive_signing_key( + TEST_VECTOR_SECRET, + "20150830", + "us-east-1", + "iam", + ); + assert_eq!(key.len(), 32); + } + + #[test] + fn derive_signing_key_known_test_vector() { + // AWS SigV4 test vector from documentation. + let key = derive_signing_key( + TEST_VECTOR_SECRET, + "20150830", + "us-east-1", + "iam", + ); + assert_eq!( + hex::encode(&key), + "c4afb1cc5771d871763a393e44b703571b55cc28424d1a5e86da6ed3c154a4b9" + ); + } + + #[test] + fn build_authorization_header_format() { + let credentials = AwsCredentials { + access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(), + session_token: None, + region: "us-east-1".to_string(), + }; + + let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z") + .unwrap() + .with_timezone(&chrono::Utc); + + let headers = vec![ + ("content-type".to_string(), "application/json".to_string()), + ( + "host".to_string(), + "bedrock-runtime.us-east-1.amazonaws.com".to_string(), + ), + ("x-amz-date".to_string(), "20240115T120000Z".to_string()), + ]; + + let auth = build_authorization_header( + &credentials, + "POST", + "/model/anthropic.claude-3-sonnet/converse", + "", + &headers, + b"{}", + &timestamp, + ); + + // Verify structure + assert!(auth.starts_with("AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/")); + assert!(auth.contains("SignedHeaders=content-type;host;x-amz-date")); + assert!(auth.contains("Signature=")); + assert!(auth.contains("/us-east-1/bedrock/aws4_request")); + } + + #[test] + fn 
build_authorization_header_includes_security_token_in_signed_headers() { + let credentials = AwsCredentials { + access_key_id: "AKIAIOSFODNN7EXAMPLE".to_string(), + secret_access_key: "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY".to_string(), + session_token: Some("session-token-value".to_string()), + region: "us-east-1".to_string(), + }; + + let timestamp = chrono::DateTime::parse_from_rfc3339("2024-01-15T12:00:00Z") + .unwrap() + .with_timezone(&chrono::Utc); + + let headers = vec![ + ("content-type".to_string(), "application/json".to_string()), + ( + "host".to_string(), + "bedrock-runtime.us-east-1.amazonaws.com".to_string(), + ), + ("x-amz-date".to_string(), "20240115T120000Z".to_string()), + ( + "x-amz-security-token".to_string(), + "session-token-value".to_string(), + ), + ]; + + let auth = build_authorization_header( + &credentials, + "POST", + "/model/test-model/converse", + "", + &headers, + b"{}", + &timestamp, + ); + + assert!(auth.contains("x-amz-security-token")); + } + + // ── Credential tests ──────────────────────────────────────── + + #[test] + fn credentials_host_formats_correctly() { + let creds = AwsCredentials { + access_key_id: "AKID".to_string(), + secret_access_key: "secret".to_string(), + session_token: None, + region: "us-west-2".to_string(), + }; + assert_eq!(creds.host(), "bedrock-runtime.us-west-2.amazonaws.com"); + } + + // ── Provider construction tests ───────────────────────────── + + #[test] + fn creates_without_credentials() { + // Provider should construct even without env vars. 
+ let _provider = BedrockProvider::new(); + } + + #[tokio::test] + async fn chat_fails_without_credentials() { + let provider = BedrockProvider { credentials: None }; + let result = provider + .chat_with_system(None, "hello", "anthropic.claude-sonnet-4-6", 0.7) + .await; + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!( + err.contains("credentials not set"), + "Expected credentials error, got: {err}" + ); + } + + // ── Endpoint URL tests ────────────────────────────────────── + + #[test] + fn endpoint_url_formats_correctly() { + let url = BedrockProvider::endpoint_url("us-east-1", "anthropic.claude-sonnet-4-6"); + assert_eq!( + url, + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-sonnet-4-6/converse" + ); + } + + #[test] + fn endpoint_url_keeps_raw_colon() { + // Endpoint URL uses raw colon so reqwest sends `:` on the wire. + let url = + BedrockProvider::endpoint_url("us-west-2", "anthropic.claude-3-5-haiku-20241022-v1:0"); + assert!(url.contains("/model/anthropic.claude-3-5-haiku-20241022-v1:0/converse")); + } + + #[test] + fn canonical_uri_encodes_colon() { + // Canonical URI must encode `:` as `%3A` for SigV4 signing. 
+ let uri = BedrockProvider::canonical_uri("anthropic.claude-3-5-haiku-20241022-v1:0"); + assert_eq!( + uri, + "/model/anthropic.claude-3-5-haiku-20241022-v1%3A0/converse" + ); + } + + #[test] + fn canonical_uri_no_colon_unchanged() { + let uri = BedrockProvider::canonical_uri("anthropic.claude-sonnet-4-6"); + assert_eq!(uri, "/model/anthropic.claude-sonnet-4-6/converse"); + } + + // ── Message conversion tests ──────────────────────────────── + + #[test] + fn convert_messages_system_extracted() { + let messages = vec![ + ChatMessage::system("You are helpful"), + ChatMessage::user("Hello"), + ]; + let (system, msgs) = BedrockProvider::convert_messages(&messages); + assert!(system.is_some()); + let system_blocks = system.unwrap(); + assert_eq!(system_blocks.len(), 1); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].role, "user"); + } + + #[test] + fn convert_messages_user_and_assistant() { + let messages = vec![ + ChatMessage::user("Hello"), + ChatMessage::assistant("Hi there"), + ]; + let (system, msgs) = BedrockProvider::convert_messages(&messages); + assert!(system.is_none()); + assert_eq!(msgs.len(), 2); + assert_eq!(msgs[0].role, "user"); + assert_eq!(msgs[1].role, "assistant"); + } + + #[test] + fn convert_messages_tool_role_to_tool_result() { + let tool_json = r#"{"tool_call_id": "call_123", "content": "Result data"}"#; + let messages = vec![ChatMessage::tool(tool_json)]; + let (_, msgs) = BedrockProvider::convert_messages(&messages); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].role, "user"); + assert!(matches!(msgs[0].content[0], ContentBlock::ToolResult(_))); + } + + #[test] + fn convert_messages_assistant_tool_calls_parsed() { + let tool_call_json = r#"{"content": "Let me check", "tool_calls": [{"id": "call_1", "name": "shell", "arguments": "{\"command\":\"ls\"}"}]}"#; + let messages = vec![ChatMessage::assistant(tool_call_json)]; + let (_, msgs) = BedrockProvider::convert_messages(&messages); + assert_eq!(msgs.len(), 1); + assert_eq!(msgs[0].role, 
"assistant"); + assert_eq!(msgs[0].content.len(), 2); + assert!(matches!(msgs[0].content[0], ContentBlock::Text(_))); + assert!(matches!(msgs[0].content[1], ContentBlock::ToolUse(_))); + } + + #[test] + fn convert_messages_plain_assistant_text() { + let messages = vec![ChatMessage::assistant("Just text")]; + let (_, msgs) = BedrockProvider::convert_messages(&messages); + assert_eq!(msgs.len(), 1); + assert!(matches!(msgs[0].content[0], ContentBlock::Text(_))); + } + + // ── Cache tests ───────────────────────────────────────────── + + #[test] + fn should_cache_system_small_prompt() { + assert!(!BedrockProvider::should_cache_system("Short prompt")); + } + + #[test] + fn should_cache_system_large_prompt() { + let large = "a".repeat(3073); + assert!(BedrockProvider::should_cache_system(&large)); + } + + #[test] + fn should_cache_system_boundary() { + assert!(!BedrockProvider::should_cache_system(&"a".repeat(3072))); + assert!(BedrockProvider::should_cache_system(&"a".repeat(3073))); + } + + #[test] + fn should_cache_conversation_short() { + let messages = vec![ + ChatMessage::system("System"), + ChatMessage::user("Hello"), + ChatMessage::assistant("Hi"), + ]; + assert!(!BedrockProvider::should_cache_conversation(&messages)); + } + + #[test] + fn should_cache_conversation_long() { + let mut messages = vec![ChatMessage::system("System")]; + for i in 0..5 { + messages.push(ChatMessage { + role: if i % 2 == 0 { "user" } else { "assistant" }.to_string(), + content: format!("Message {i}"), + }); + } + assert!(BedrockProvider::should_cache_conversation(&messages)); + } + + // ── Tool conversion tests ─────────────────────────────────── + + #[test] + fn convert_tools_to_converse_formats_correctly() { + let tools = vec![ToolSpec { + name: "shell".to_string(), + description: "Run commands".to_string(), + parameters: serde_json::json!({"type": "object", "properties": {"command": {"type": "string"}}}), + }]; + let config = BedrockProvider::convert_tools_to_converse(Some(&tools)); 
+ assert!(config.is_some()); + let config = config.unwrap(); + assert_eq!(config.tools.len(), 1); + assert_eq!(config.tools[0].tool_spec.name, "shell"); + } + + #[test] + fn convert_tools_to_converse_empty_returns_none() { + assert!(BedrockProvider::convert_tools_to_converse(Some(&[])).is_none()); + assert!(BedrockProvider::convert_tools_to_converse(None).is_none()); + } + + // ── Serde tests ───────────────────────────────────────────── + + #[test] + fn converse_request_serializes_without_system() { + let req = ConverseRequest { + system: None, + messages: vec![ConverseMessage { + role: "user".to_string(), + content: vec![ContentBlock::Text(TextBlock { + text: "Hello".to_string(), + })], + }], + inference_config: Some(InferenceConfig { + max_tokens: 4096, + temperature: 0.7, + }), + tool_config: None, + }; + let json = serde_json::to_string(&req).unwrap(); + assert!(!json.contains("system")); + assert!(json.contains("Hello")); + assert!(json.contains("maxTokens")); + } + + #[test] + fn converse_response_deserializes_text() { + let json = r#"{ + "output": { + "message": { + "role": "assistant", + "content": [{"text": "Hello from Bedrock"}] + } + }, + "stopReason": "end_turn" + }"#; + let resp: ConverseResponse = serde_json::from_str(json).unwrap(); + let parsed = BedrockProvider::parse_converse_response(resp); + assert_eq!(parsed.text.as_deref(), Some("Hello from Bedrock")); + assert!(parsed.tool_calls.is_empty()); + } + + #[test] + fn converse_response_deserializes_tool_use() { + let json = r#"{ + "output": { + "message": { + "role": "assistant", + "content": [ + {"toolUse": {"toolUseId": "call_1", "name": "shell", "input": {"command": "ls"}}} + ] + } + }, + "stopReason": "tool_use" + }"#; + let resp: ConverseResponse = serde_json::from_str(json).unwrap(); + let parsed = BedrockProvider::parse_converse_response(resp); + assert!(parsed.text.is_none()); + assert_eq!(parsed.tool_calls.len(), 1); + assert_eq!(parsed.tool_calls[0].name, "shell"); + 
assert_eq!(parsed.tool_calls[0].id, "call_1"); + } + + #[test] + fn converse_response_empty_output() { + let json = r#"{"output": null, "stopReason": null}"#; + let resp: ConverseResponse = serde_json::from_str(json).unwrap(); + let parsed = BedrockProvider::parse_converse_response(resp); + assert!(parsed.text.is_none()); + assert!(parsed.tool_calls.is_empty()); + } + + #[test] + fn content_block_text_serializes_as_flat_string() { + let block = ContentBlock::Text(TextBlock { + text: "Hello".to_string(), + }); + let json = serde_json::to_string(&block).unwrap(); + // Must be {"text":"Hello"}, NOT {"text":{"text":"Hello"}} + assert_eq!(json, r#"{"text":"Hello"}"#); + } + + #[test] + fn content_block_tool_use_serializes_with_nested_object() { + let block = ContentBlock::ToolUse(ToolUseWrapper { + tool_use: ToolUseBlock { + tool_use_id: "call_1".to_string(), + name: "shell".to_string(), + input: serde_json::json!({"command": "ls"}), + }, + }); + let json = serde_json::to_string(&block).unwrap(); + assert!(json.contains(r#""toolUse""#)); + assert!(json.contains(r#""toolUseId":"call_1""#)); + } + + #[test] + fn content_block_cache_point_serializes() { + let block = ContentBlock::CachePointBlock(CachePointWrapper { + cache_point: CachePoint::default_cache(), + }); + let json = serde_json::to_string(&block).unwrap(); + assert_eq!(json, r#"{"cachePoint":{"type":"default"}}"#); + } + + #[test] + fn content_block_text_round_trips() { + let original = ContentBlock::Text(TextBlock { + text: "Hello".to_string(), + }); + let json = serde_json::to_string(&original).unwrap(); + let deserialized: ContentBlock = serde_json::from_str(&json).unwrap(); + assert!(matches!(deserialized, ContentBlock::Text(tb) if tb.text == "Hello")); + } + + #[test] + fn cache_point_serializes() { + let cp = CachePoint::default_cache(); + let json = serde_json::to_string(&cp).unwrap(); + assert_eq!(json, r#"{"type":"default"}"#); + } + + #[tokio::test] + async fn warmup_without_credentials_is_noop() { + 
let provider = BedrockProvider { credentials: None }; + let result = provider.warmup().await; + assert!(result.is_ok()); + } + + #[test] + fn capabilities_reports_native_tool_calling() { + let provider = BedrockProvider { credentials: None }; + let caps = provider.capabilities(); + assert!(caps.native_tool_calling); + } +} diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index 074ee45..615ac6d 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -26,6 +26,10 @@ pub struct OpenAiCompatibleProvider { /// GLM/Zhipu does not support the responses API. supports_responses_fallback: bool, user_agent: Option, + /// When true, collect all `system` messages and prepend their content + /// to the first `user` message, then drop the system messages. + /// Required for providers that reject `role: system` (e.g. MiniMax). + merge_system_into_user: bool, } /// How the provider expects the API key to be sent. @@ -46,7 +50,7 @@ impl OpenAiCompatibleProvider { credential: Option<&str>, auth_style: AuthStyle, ) -> Self { - Self::new_with_options(name, base_url, credential, auth_style, true, None) + Self::new_with_options(name, base_url, credential, auth_style, true, None, false) } /// Same as `new` but skips the /v1/responses fallback on 404. @@ -57,7 +61,7 @@ impl OpenAiCompatibleProvider { credential: Option<&str>, auth_style: AuthStyle, ) -> Self { - Self::new_with_options(name, base_url, credential, auth_style, false, None) + Self::new_with_options(name, base_url, credential, auth_style, false, None, false) } /// Create a provider with a custom User-Agent header. @@ -78,9 +82,21 @@ impl OpenAiCompatibleProvider { auth_style, true, Some(user_agent), + false, ) } + /// For providers that do not support `role: system` (e.g. MiniMax). + /// System prompt content is prepended to the first user message instead. 
+ pub fn new_merge_system_into_user( + name: &str, + base_url: &str, + credential: Option<&str>, + auth_style: AuthStyle, + ) -> Self { + Self::new_with_options(name, base_url, credential, auth_style, false, None, true) + } + fn new_with_options( name: &str, base_url: &str, @@ -88,6 +104,7 @@ impl OpenAiCompatibleProvider { auth_style: AuthStyle, supports_responses_fallback: bool, user_agent: Option<&str>, + merge_system_into_user: bool, ) -> Self { Self { name: name.to_string(), @@ -96,9 +113,41 @@ impl OpenAiCompatibleProvider { auth_header: auth_style, supports_responses_fallback, user_agent: user_agent.map(ToString::to_string), + merge_system_into_user, } } + /// Collect all `system` role messages, concatenate their content, + /// and prepend to the first `user` message. Drop all system messages. + /// Used for providers (e.g. MiniMax) that reject `role: system`. + fn flatten_system_messages(messages: &[ChatMessage]) -> Vec { + let system_content: String = messages + .iter() + .filter(|m| m.role == "system") + .map(|m| m.content.as_str()) + .collect::>() + .join("\n\n"); + + if system_content.is_empty() { + return messages.to_vec(); + } + + let mut result: Vec = messages + .iter() + .filter(|m| m.role != "system") + .cloned() + .collect(); + + if let Some(first_user) = result.iter_mut().find(|m| m.role == "user") { + first_user.content = format!("{system_content}\n\n{}", first_user.content); + } else { + // No user message found: insert a synthetic user message with system content + result.insert(0, ChatMessage::user(&system_content)); + } + + result + } + fn http_client(&self) -> Client { if let Some(ua) = self.user_agent.as_deref() { let mut headers = HeaderMap::new(); @@ -230,6 +279,30 @@ struct Choice { message: ResponseMessage, } +/// Remove `...` blocks from model output. +/// Some reasoning models (e.g. MiniMax) embed their chain-of-thought inline +/// in the `content` field rather than a separate `reasoning_content` field. 
+/// The resulting `` tags must be stripped before returning to the user. +fn strip_think_tags(s: &str) -> String { + let mut result = String::with_capacity(s.len()); + let mut rest = s; + loop { + if let Some(start) = rest.find("") { + result.push_str(&rest[..start]); + if let Some(end) = rest[start..].find("") { + rest = &rest[start + end + "".len()..]; + } else { + // Unclosed tag: drop the rest to avoid leaking partial reasoning. + break; + } + } else { + result.push_str(rest); + break; + } + } + result.trim().to_string() +} + #[derive(Debug, Deserialize, Serialize)] struct ResponseMessage { #[serde(default)] @@ -246,18 +319,35 @@ impl ResponseMessage { /// Extract text content, falling back to `reasoning_content` when `content` /// is missing or empty. Reasoning/thinking models (Qwen3, GLM-4, etc.) /// often return their output solely in `reasoning_content`. + /// Strips `...` blocks that some models (e.g. MiniMax) embed + /// inline in `content` instead of using a separate field. fn effective_content(&self) -> String { - match &self.content { - Some(c) if !c.is_empty() => c.clone(), - _ => self.reasoning_content.clone().unwrap_or_default(), + if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) { + let stripped = strip_think_tags(content); + if !stripped.is_empty() { + return stripped; + } } + + self.reasoning_content + .as_ref() + .map(|c| strip_think_tags(c)) + .filter(|c| !c.is_empty()) + .unwrap_or_default() } fn effective_content_optional(&self) -> Option { - match &self.content { - Some(c) if !c.is_empty() => Some(c.clone()), - _ => self.reasoning_content.clone().filter(|c| !c.is_empty()), + if let Some(content) = self.content.as_ref().filter(|c| !c.is_empty()) { + let stripped = strip_think_tags(content); + if !stripped.is_empty() { + return Some(stripped); + } } + + self.reasoning_content + .as_ref() + .map(|c| strip_think_tags(c)) + .filter(|c| !c.is_empty()) } } @@ -495,6 +585,43 @@ fn first_nonempty(text: Option<&str>) -> Option { 
}) } +fn normalize_responses_role(role: &str) -> &'static str { + match role { + "assistant" => "assistant", + "tool" => "assistant", + _ => "user", + } +} + +fn build_responses_prompt(messages: &[ChatMessage]) -> (Option, Vec) { + let mut instructions_parts = Vec::new(); + let mut input = Vec::new(); + + for message in messages { + if message.content.trim().is_empty() { + continue; + } + + if message.role == "system" { + instructions_parts.push(message.content.clone()); + continue; + } + + input.push(ResponsesInput { + role: normalize_responses_role(&message.role).to_string(), + content: message.content.clone(), + }); + } + + let instructions = if instructions_parts.is_empty() { + None + } else { + Some(instructions_parts.join("\n\n")) + }; + + (instructions, input) +} + fn extract_responses_text(response: ResponsesResponse) -> Option { if let Some(text) = first_nonempty(response.output_text.as_deref()) { return Some(text); @@ -565,17 +692,21 @@ impl OpenAiCompatibleProvider { async fn chat_via_responses( &self, credential: &str, - system_prompt: Option<&str>, - message: &str, + messages: &[ChatMessage], model: &str, ) -> anyhow::Result { + let (instructions, input) = build_responses_prompt(messages); + if input.is_empty() { + anyhow::bail!( + "{} Responses API fallback requires at least one non-system message", + self.name + ); + } + let request = ResponsesRequest { model: model.to_string(), - input: vec![ResponsesInput { - role: "user".to_string(), - content: message.to_string(), - }], - instructions: system_prompt.map(str::to_string), + input, + instructions, stream: Some(false), }; @@ -767,6 +898,7 @@ impl Provider for OpenAiCompatibleProvider { fn capabilities(&self) -> crate::providers::traits::ProviderCapabilities { crate::providers::traits::ProviderCapabilities { native_tool_calling: true, + vision: false, } } @@ -786,18 +918,28 @@ impl Provider for OpenAiCompatibleProvider { let mut messages = Vec::new(); - if let Some(sys) = system_prompt { + if 
self.merge_system_into_user { + let content = match system_prompt { + Some(sys) => format!("{sys}\n\n{message}"), + None => message.to_string(), + }; messages.push(Message { - role: "system".to_string(), - content: sys.to_string(), + role: "user".to_string(), + content, + }); + } else { + if let Some(sys) = system_prompt { + messages.push(Message { + role: "system".to_string(), + content: sys.to_string(), + }); + } + messages.push(Message { + role: "user".to_string(), + content: message.to_string(), }); } - messages.push(Message { - role: "user".to_string(), - content: message.to_string(), - }); - let request = ApiChatRequest { model: model.to_string(), messages, @@ -809,10 +951,40 @@ impl Provider for OpenAiCompatibleProvider { let url = self.chat_completions_url(); - let response = self + let mut fallback_messages = Vec::new(); + if let Some(system_prompt) = system_prompt { + fallback_messages.push(ChatMessage::system(system_prompt)); + } + fallback_messages.push(ChatMessage::user(message)); + let fallback_messages = if self.merge_system_into_user { + Self::flatten_system_messages(&fallback_messages) + } else { + fallback_messages + }; + + let response = match self .apply_auth_header(self.http_client().post(&url).json(&request), credential) .send() - .await?; + .await + { + Ok(response) => response, + Err(chat_error) => { + if self.supports_responses_fallback { + let sanitized = super::sanitize_api_error(&chat_error.to_string()); + return self + .chat_via_responses(credential, &fallback_messages, model) + .await + .map_err(|responses_err| { + anyhow::anyhow!( + "{} chat completions transport error: {sanitized} (responses fallback failed: {responses_err})", + self.name + ) + }); + } + + return Err(chat_error.into()); + } + }; if !response.status().is_success() { let status = response.status(); @@ -821,7 +993,7 @@ impl Provider for OpenAiCompatibleProvider { if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { return self - 
.chat_via_responses(credential, system_prompt, message, model) + .chat_via_responses(credential, &fallback_messages, model) .await .map_err(|responses_err| { anyhow::anyhow!( @@ -873,7 +1045,12 @@ impl Provider for OpenAiCompatibleProvider { ) })?; - let api_messages: Vec = messages + let effective_messages = if self.merge_system_into_user { + Self::flatten_system_messages(messages) + } else { + messages.to_vec() + }; + let api_messages: Vec = effective_messages .iter() .map(|m| Message { role: m.role.clone(), @@ -891,35 +1068,44 @@ impl Provider for OpenAiCompatibleProvider { }; let url = self.chat_completions_url(); - let response = self + let response = match self .apply_auth_header(self.http_client().post(&url).json(&request), credential) .send() - .await?; + .await + { + Ok(response) => response, + Err(chat_error) => { + if self.supports_responses_fallback { + let sanitized = super::sanitize_api_error(&chat_error.to_string()); + return self + .chat_via_responses(credential, &effective_messages, model) + .await + .map_err(|responses_err| { + anyhow::anyhow!( + "{} chat completions transport error: {sanitized} (responses fallback failed: {responses_err})", + self.name + ) + }); + } + + return Err(chat_error.into()); + } + }; if !response.status().is_success() { let status = response.status(); // Mirror chat_with_system: 404 may mean this provider uses the Responses API if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { - // Extract system prompt and last user message for responses fallback - let system = messages.iter().find(|m| m.role == "system"); - let last_user = messages.iter().rfind(|m| m.role == "user"); - if let Some(user_msg) = last_user { - return self - .chat_via_responses( - credential, - system.map(|m| m.content.as_str()), - &user_msg.content, - model, + return self + .chat_via_responses(credential, &effective_messages, model) + .await + .map_err(|responses_err| { + anyhow::anyhow!( + "{} API error (chat completions 
unavailable; responses fallback failed: {responses_err})", + self.name ) - .await - .map_err(|responses_err| { - anyhow::anyhow!( - "{} API error (chat completions unavailable; responses fallback failed: {responses_err})", - self.name - ) - }); - } + }); } return Err(super::api_error(&self.name, response).await); @@ -965,7 +1151,12 @@ impl Provider for OpenAiCompatibleProvider { ) })?; - let api_messages: Vec = messages + let effective_messages = if self.merge_system_into_user { + Self::flatten_system_messages(messages) + } else { + messages.to_vec() + }; + let api_messages: Vec = effective_messages .iter() .map(|m| Message { role: m.role.clone(), @@ -991,10 +1182,24 @@ impl Provider for OpenAiCompatibleProvider { }; let url = self.chat_completions_url(); - let response = self + let response = match self .apply_auth_header(self.http_client().post(&url).json(&request), credential) .send() - .await?; + .await + { + Ok(response) => response, + Err(error) => { + tracing::warn!( + "{} native tool call transport failed: {error}; falling back to history path", + self.name + ); + let text = self.chat_with_history(messages, model, temperature).await?; + return Ok(ProviderChatResponse { + text: Some(text), + tool_calls: vec![], + }); + } + }; if !response.status().is_success() { return Err(super::api_error(&self.name, response).await); @@ -1043,9 +1248,14 @@ impl Provider for OpenAiCompatibleProvider { })?; let tools = Self::convert_tool_specs(request.tools); + let effective_messages = if self.merge_system_into_user { + Self::flatten_system_messages(request.messages) + } else { + request.messages.to_vec() + }; let native_request = NativeChatRequest { model: model.to_string(), - messages: Self::convert_messages_for_native(request.messages), + messages: Self::convert_messages_for_native(&effective_messages), temperature, stream: Some(false), tool_choice: tools.as_ref().map(|_| "auto".to_string()), @@ -1053,10 +1263,36 @@ impl Provider for OpenAiCompatibleProvider { }; let url 
= self.chat_completions_url(); - let response = self - .apply_auth_header(self.http_client().post(&url).json(&native_request), credential) + let response = match self + .apply_auth_header( + self.http_client().post(&url).json(&native_request), + credential, + ) .send() - .await?; + .await + { + Ok(response) => response, + Err(chat_error) => { + if self.supports_responses_fallback { + let sanitized = super::sanitize_api_error(&chat_error.to_string()); + return self + .chat_via_responses(credential, &effective_messages, model) + .await + .map(|text| ProviderChatResponse { + text: Some(text), + tool_calls: vec![], + }) + .map_err(|responses_err| { + anyhow::anyhow!( + "{} native chat transport error: {sanitized} (responses fallback failed: {responses_err})", + self.name + ) + }); + } + + return Err(chat_error.into()); + } + }; if !response.status().is_success() { let status = response.status(); @@ -1076,28 +1312,19 @@ impl Provider for OpenAiCompatibleProvider { } if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { - let system = request.messages.iter().find(|m| m.role == "system"); - let last_user = request.messages.iter().rfind(|m| m.role == "user"); - if let Some(user_msg) = last_user { - return self - .chat_via_responses( - credential, - system.map(|m| m.content.as_str()), - &user_msg.content, - model, + return self + .chat_via_responses(credential, &effective_messages, model) + .await + .map(|text| ProviderChatResponse { + text: Some(text), + tool_calls: vec![], + }) + .map_err(|responses_err| { + anyhow::anyhow!( + "{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})", + self.name ) - .await - .map(|text| ProviderChatResponse { - text: Some(text), - tool_calls: vec![], - }) - .map_err(|responses_err| { - anyhow::anyhow!( - "{} API error ({status}): {sanitized} (chat completions unavailable; responses fallback failed: {responses_err})", - self.name - ) - }); - } + }); } 
anyhow::bail!("{} API error ({status}): {sanitized}", self.name); @@ -1432,6 +1659,43 @@ mod tests { ); } + #[test] + fn build_responses_prompt_preserves_multi_turn_history() { + let messages = vec![ + ChatMessage::system("policy"), + ChatMessage::user("step 1"), + ChatMessage::assistant("ack 1"), + ChatMessage::tool("{\"result\":\"ok\"}"), + ChatMessage::user("step 2"), + ]; + + let (instructions, input) = build_responses_prompt(&messages); + + assert_eq!(instructions.as_deref(), Some("policy")); + assert_eq!(input.len(), 4); + assert_eq!(input[0].role, "user"); + assert_eq!(input[0].content, "step 1"); + assert_eq!(input[1].role, "assistant"); + assert_eq!(input[1].content, "ack 1"); + assert_eq!(input[2].role, "assistant"); + assert_eq!(input[2].content, "{\"result\":\"ok\"}"); + assert_eq!(input[3].role, "user"); + assert_eq!(input[3].content, "step 2"); + } + + #[tokio::test] + async fn chat_via_responses_requires_non_system_message() { + let provider = make_provider("custom", "https://api.example.com", Some("test-key")); + let err = provider + .chat_via_responses("test-key", &[ChatMessage::system("policy")], "gpt-test") + .await + .expect_err("system-only fallback payload should fail"); + + assert!(err + .to_string() + .contains("requires at least one non-system message")); + } + // ---------------------------------------------------------- // Custom endpoint path tests (Issue #114) // ---------------------------------------------------------- @@ -1657,6 +1921,48 @@ mod tests { assert_eq!(converted[0].content.as_deref(), Some("done")); } + #[test] + fn flatten_system_messages_merges_into_first_user() { + let input = vec![ + ChatMessage::system("core policy"), + ChatMessage::assistant("ack"), + ChatMessage::system("delivery rules"), + ChatMessage::user("hello"), + ChatMessage::assistant("post-user"), + ]; + + let output = OpenAiCompatibleProvider::flatten_system_messages(&input); + assert_eq!(output.len(), 3); + assert_eq!(output[0].role, "assistant"); + 
assert_eq!(output[0].content, "ack"); + assert_eq!(output[1].role, "user"); + assert_eq!(output[1].content, "core policy\n\ndelivery rules\n\nhello"); + assert_eq!(output[2].role, "assistant"); + assert_eq!(output[2].content, "post-user"); + assert!(output.iter().all(|m| m.role != "system")); + } + + #[test] + fn flatten_system_messages_inserts_user_when_missing() { + let input = vec![ + ChatMessage::system("core policy"), + ChatMessage::assistant("ack"), + ]; + + let output = OpenAiCompatibleProvider::flatten_system_messages(&input); + assert_eq!(output.len(), 2); + assert_eq!(output[0].role, "user"); + assert_eq!(output[0].content, "core policy"); + assert_eq!(output[1].role, "assistant"); + assert_eq!(output[1].content, "ack"); + } + + #[test] + fn strip_think_tags_drops_unclosed_block_suffix() { + let input = "visiblehidden"; + assert_eq!(strip_think_tags(input), "visible"); + } + #[test] fn native_tool_schema_unsupported_detection_is_precise() { assert!(OpenAiCompatibleProvider::is_native_tool_schema_unsupported( @@ -1876,6 +2182,56 @@ mod tests { assert!(msg.tool_calls.is_none()); } + #[test] + fn flatten_system_messages_merges_into_first_user_and_removes_system_roles() { + let messages = vec![ + ChatMessage::system("System A"), + ChatMessage::assistant("Earlier assistant turn"), + ChatMessage::system("System B"), + ChatMessage::user("User turn"), + ChatMessage::tool(r#"{"ok":true}"#), + ]; + + let flattened = OpenAiCompatibleProvider::flatten_system_messages(&messages); + assert_eq!(flattened.len(), 3); + assert_eq!(flattened[0].role, "assistant"); + assert_eq!( + flattened[1].content, + "System A\n\nSystem B\n\nUser turn".to_string() + ); + assert_eq!(flattened[1].role, "user"); + assert_eq!(flattened[2].role, "tool"); + assert!(!flattened.iter().any(|m| m.role == "system")); + } + + #[test] + fn flatten_system_messages_inserts_synthetic_user_when_no_user_exists() { + let messages = vec![ + ChatMessage::assistant("Assistant only"), + 
ChatMessage::system("Synthetic system"), + ]; + + let flattened = OpenAiCompatibleProvider::flatten_system_messages(&messages); + assert_eq!(flattened.len(), 2); + assert_eq!(flattened[0].role, "user"); + assert_eq!(flattened[0].content, "Synthetic system"); + assert_eq!(flattened[1].role, "assistant"); + } + + #[test] + fn strip_think_tags_removes_multiple_blocks_with_surrounding_text() { + let input = "Answer A hidden 1 and B hidden 2 done"; + let output = strip_think_tags(input); + assert_eq!(output, "Answer A and B done"); + } + + #[test] + fn strip_think_tags_drops_tail_for_unclosed_block() { + let input = "Visiblehidden tail"; + let output = strip_think_tags(input); + assert_eq!(output, "Visible"); + } + // ---------------------------------------------------------- // Reasoning model fallback tests (reasoning_content) // ---------------------------------------------------------- @@ -1917,6 +2273,18 @@ mod tests { assert_eq!(msg.effective_content(), "Normal response"); } + #[test] + fn reasoning_content_used_when_content_only_think_tags() { + let json = r#"{"choices":[{"message":{"content":"secret","reasoning_content":"Fallback text"}}]}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + let msg = &resp.choices[0].message; + assert_eq!(msg.effective_content(), "Fallback text"); + assert_eq!( + msg.effective_content_optional().as_deref(), + Some("Fallback text") + ); + } + #[test] fn reasoning_content_both_absent_returns_empty() { // Neither content nor reasoning_content - returns empty string diff --git a/src/providers/copilot.rs b/src/providers/copilot.rs index b49f7dd..6c72e63 100644 --- a/src/providers/copilot.rs +++ b/src/providers/copilot.rs @@ -81,12 +81,12 @@ struct CachedApiKey { // ── Chat completions types ─────────────────────────────────────── #[derive(Debug, Serialize)] -struct ApiChatRequest { +struct ApiChatRequest<'a> { model: String, messages: Vec, temperature: f64, #[serde(skip_serializing_if = "Option::is_none")] - 
tools: Option>, + tools: Option>>, #[serde(skip_serializing_if = "Option::is_none")] tool_choice: Option, } @@ -103,17 +103,17 @@ struct ApiMessage { } #[derive(Debug, Serialize)] -struct NativeToolSpec { +struct NativeToolSpec<'a> { #[serde(rename = "type")] - kind: String, - function: NativeToolFunctionSpec, + kind: &'static str, + function: NativeToolFunctionSpec<'a>, } #[derive(Debug, Serialize)] -struct NativeToolFunctionSpec { - name: String, - description: String, - parameters: serde_json::Value, +struct NativeToolFunctionSpec<'a> { + name: &'a str, + description: &'a str, + parameters: &'a serde_json::Value, } #[derive(Debug, Serialize, Deserialize)] @@ -219,16 +219,16 @@ impl CopilotProvider { ("Accept", "application/json"), ]; - fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + fn convert_tools<'a>(tools: Option<&'a [ToolSpec]>) -> Option>> { tools.map(|items| { items .iter() .map(|tool| NativeToolSpec { - kind: "function".to_string(), + kind: "function", function: NativeToolFunctionSpec { - name: tool.name.clone(), - description: tool.description.clone(), - parameters: tool.parameters.clone(), + name: &tool.name, + description: &tool.description, + parameters: &tool.parameters, }, }) .collect() diff --git a/src/providers/gemini.rs b/src/providers/gemini.rs index 4da916c..b3b7110 100644 --- a/src/providers/gemini.rs +++ b/src/providers/gemini.rs @@ -3,7 +3,7 @@ //! - Gemini CLI OAuth tokens (reuse existing ~/.gemini/ authentication) //! 
- Google Cloud ADC (`GOOGLE_APPLICATION_CREDENTIALS`) -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatMessage, Provider}; use async_trait::async_trait; use directories::UserDirs; use reqwest::Client; @@ -58,10 +58,10 @@ impl GeminiAuth { // API REQUEST/RESPONSE TYPES // ══════════════════════════════════════════════════════════════════════════════ -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] struct GenerateContentRequest { contents: Vec, - #[serde(skip_serializing_if = "Option::is_none")] + #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")] system_instruction: Option, #[serde(rename = "generationConfig")] generation_config: GenerationConfig, @@ -70,23 +70,33 @@ struct GenerateContentRequest { /// Request envelope for the internal cloudcode-pa API. /// OAuth tokens from Gemini CLI are scoped for this endpoint. #[derive(Debug, Serialize)] -struct InternalGenerateContentRequest { +struct InternalGenerateContentEnvelope { model: String, - #[serde(rename = "generationConfig")] - generation_config: GenerationConfig, - contents: Vec, #[serde(skip_serializing_if = "Option::is_none")] - system_instruction: Option, + project: Option, + #[serde(skip_serializing_if = "Option::is_none")] + user_prompt_id: Option, + request: InternalGenerateContentRequest, } +/// Nested request payload for cloudcode-pa's code assist APIs. 
#[derive(Debug, Serialize)] +struct InternalGenerateContentRequest { + contents: Vec, + #[serde(rename = "systemInstruction", skip_serializing_if = "Option::is_none")] + system_instruction: Option, + #[serde(rename = "generationConfig")] + generation_config: GenerationConfig, +} + +#[derive(Debug, Serialize, Clone)] struct Content { #[serde(skip_serializing_if = "Option::is_none")] role: Option, parts: Vec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Clone)] struct Part { text: String, } @@ -102,6 +112,8 @@ struct GenerationConfig { struct GenerateContentResponse { candidates: Option>, error: Option, + #[serde(default)] + response: Option>, } #[derive(Debug, Deserialize)] @@ -124,6 +136,19 @@ struct ApiError { message: String, } +impl GenerateContentResponse { + /// cloudcode-pa wraps the actual response under `response`. + fn into_effective_response(self) -> Self { + match self { + Self { + response: Some(inner), + .. + } => *inner, + other => other, + } + } +} + // ══════════════════════════════════════════════════════════════════════════════ // GEMINI CLI TOKEN STRUCTURES // ══════════════════════════════════════════════════════════════════════════════ @@ -243,6 +268,10 @@ impl GeminiProvider { } } + fn format_internal_model_name(model: &str) -> String { + model.strip_prefix("models/").unwrap_or(model).to_string() + } + /// Build the API URL based on auth type. 
/// /// - API key users → public `generativelanguage.googleapis.com/v1beta` @@ -287,34 +316,16 @@ impl GeminiProvider { let req = self.http_client().post(url).json(request); match auth { GeminiAuth::OAuthToken(token) => { - // Internal API expects the model in the request body envelope - let internal_request = InternalGenerateContentRequest { - model: Self::format_model_name(model), - generation_config: request.generation_config.clone(), - contents: request - .contents - .iter() - .map(|c| Content { - role: c.role.clone(), - parts: c - .parts - .iter() - .map(|p| Part { - text: p.text.clone(), - }) - .collect(), - }) - .collect(), - system_instruction: request.system_instruction.as_ref().map(|si| Content { - role: si.role.clone(), - parts: si - .parts - .iter() - .map(|p| Part { - text: p.text.clone(), - }) - .collect(), - }), + // cloudcode-pa expects an outer envelope with `request`. + let internal_request = InternalGenerateContentEnvelope { + model: Self::format_internal_model_name(model), + project: None, + user_prompt_id: None, + request: InternalGenerateContentRequest { + contents: request.contents.clone(), + system_instruction: request.system_instruction.clone(), + generation_config: request.generation_config.clone(), + }, }; self.http_client() .post(url) @@ -326,12 +337,11 @@ impl GeminiProvider { } } -#[async_trait] -impl Provider for GeminiProvider { - async fn chat_with_system( +impl GeminiProvider { + async fn send_generate_content( &self, - system_prompt: Option<&str>, - message: &str, + contents: Vec, + system_instruction: Option, model: &str, temperature: f64, ) -> anyhow::Result { @@ -345,21 +355,8 @@ impl Provider for GeminiProvider { ) })?; - // Build request - let system_instruction = system_prompt.map(|sys| Content { - role: None, - parts: vec![Part { - text: sys.to_string(), - }], - }); - let request = GenerateContentRequest { - contents: vec![Content { - role: Some("user".to_string()), - parts: vec![Part { - text: message.to_string(), - }], - 
}], + contents, system_instruction, generation_config: GenerationConfig { temperature, @@ -381,13 +378,14 @@ impl Provider for GeminiProvider { } let result: GenerateContentResponse = response.json().await?; - - // Check for API error in response body + if let Some(err) = &result.error { + anyhow::bail!("Gemini API error: {}", err.message); + } + let result = result.into_effective_response(); if let Some(err) = result.error { anyhow::bail!("Gemini API error: {}", err.message); } - // Extract text from response result .candidates .and_then(|c| c.into_iter().next()) @@ -395,9 +393,93 @@ impl Provider for GeminiProvider { .and_then(|p| p.text) .ok_or_else(|| anyhow::anyhow!("No response from Gemini")) } +} + +#[async_trait] +impl Provider for GeminiProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let system_instruction = system_prompt.map(|sys| Content { + role: None, + parts: vec![Part { + text: sys.to_string(), + }], + }); + + let contents = vec![Content { + role: Some("user".to_string()), + parts: vec![Part { + text: message.to_string(), + }], + }]; + + self.send_generate_content(contents, system_instruction, model, temperature) + .await + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let mut system_parts: Vec<&str> = Vec::new(); + let mut contents: Vec = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => { + system_parts.push(&msg.content); + } + "user" => { + contents.push(Content { + role: Some("user".to_string()), + parts: vec![Part { + text: msg.content.clone(), + }], + }); + } + "assistant" => { + // Gemini API uses "model" role instead of "assistant" + contents.push(Content { + role: Some("model".to_string()), + parts: vec![Part { + text: msg.content.clone(), + }], + }); + } + _ => {} + } + } + + let system_instruction = if 
system_parts.is_empty() { + None + } else { + Some(Content { + role: None, + parts: vec![Part { + text: system_parts.join("\n\n"), + }], + }) + }; + + self.send_generate_content(contents, system_instruction, model, temperature) + .await + } async fn warmup(&self) -> anyhow::Result<()> { if let Some(auth) = self.auth.as_ref() { + // cloudcode-pa does not expose a lightweight model-list probe like the public API. + // Avoid false negatives for valid Gemini CLI OAuth credentials. + if auth.is_oauth() { + return Ok(()); + } + let url = if auth.is_api_key() { format!( "https://generativelanguage.googleapis.com/v1beta/models?key={}", @@ -407,12 +489,11 @@ impl Provider for GeminiProvider { "https://generativelanguage.googleapis.com/v1beta/models".to_string() }; - let mut request = self.http_client().get(&url); - if let GeminiAuth::OAuthToken(token) = auth { - request = request.bearer_auth(token); - } - - request.send().await?.error_for_status()?; + self.http_client() + .get(&url) + .send() + .await? 
+ .error_for_status()?; } Ok(()) } @@ -497,6 +578,14 @@ mod tests { GeminiProvider::format_model_name("models/gemini-1.5-pro"), "models/gemini-1.5-pro" ); + assert_eq!( + GeminiProvider::format_internal_model_name("models/gemini-2.5-flash"), + "gemini-2.5-flash" + ); + assert_eq!( + GeminiProvider::format_internal_model_name("gemini-2.5-flash"), + "gemini-2.5-flash" + ); } #[test] @@ -559,6 +648,44 @@ mod tests { ); } + #[test] + fn oauth_request_wraps_payload_in_request_envelope() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())), + }; + let auth = GeminiAuth::OAuthToken("ya29.mock-token".into()); + let url = GeminiProvider::build_generate_content_url("gemini-2.0-flash", &auth); + let body = GenerateContentRequest { + contents: vec![Content { + role: Some("user".into()), + parts: vec![Part { + text: "hello".into(), + }], + }], + system_instruction: None, + generation_config: GenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + }; + + let request = provider + .build_generate_content_request(&auth, &url, &body, "models/gemini-2.0-flash") + .build() + .unwrap(); + + let payload = request + .body() + .and_then(|b| b.as_bytes()) + .expect("json request body should be bytes"); + let json: serde_json::Value = serde_json::from_slice(payload).unwrap(); + + assert_eq!(json["model"], "gemini-2.0-flash"); + assert!(json.get("generationConfig").is_none()); + assert!(json.get("request").is_some()); + assert!(json["request"].get("generationConfig").is_some()); + } + #[test] fn api_key_request_does_not_set_bearer_header() { let provider = GeminiProvider { @@ -612,31 +739,38 @@ mod tests { let json = serde_json::to_string(&request).unwrap(); assert!(json.contains("\"role\":\"user\"")); assert!(json.contains("\"text\":\"Hello\"")); + assert!(json.contains("\"systemInstruction\"")); + assert!(!json.contains("\"system_instruction\"")); assert!(json.contains("\"temperature\":0.7")); 
assert!(json.contains("\"maxOutputTokens\":8192")); } #[test] fn internal_request_includes_model() { - let request = InternalGenerateContentRequest { - model: "models/gemini-3-pro-preview".to_string(), - generation_config: GenerationConfig { - temperature: 0.7, - max_output_tokens: 8192, - }, - contents: vec![Content { - role: Some("user".to_string()), - parts: vec![Part { - text: "Hello".to_string(), + let request = InternalGenerateContentEnvelope { + model: "gemini-test-model".to_string(), + project: None, + user_prompt_id: None, + request: InternalGenerateContentRequest { + contents: vec![Content { + role: Some("user".to_string()), + parts: vec![Part { + text: "Hello".to_string(), + }], }], - }], - system_instruction: None, + system_instruction: None, + generation_config: GenerationConfig { + temperature: 0.7, + max_output_tokens: 8192, + }, + }, }; - let json = serde_json::to_string(&request).unwrap(); - assert!(json.contains("\"model\":\"models/gemini-3-pro-preview\"")); - assert!(json.contains("\"role\":\"user\"")); - assert!(json.contains("\"temperature\":0.7")); + let json: serde_json::Value = serde_json::to_value(&request).unwrap(); + assert_eq!(json["model"], "gemini-test-model"); + assert!(json.get("generationConfig").is_none()); + assert!(json["request"].get("generationConfig").is_some()); + assert_eq!(json["request"]["contents"][0]["role"], "user"); } #[test] @@ -679,10 +813,48 @@ mod tests { assert_eq!(response.error.unwrap().message, "Invalid API key"); } + #[test] + fn internal_response_deserialization() { + let json = r#"{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "Hello from internal"}] + } + }] + } + }"#; + + let response: GenerateContentResponse = serde_json::from_str(json).unwrap(); + let text = response + .into_effective_response() + .candidates + .unwrap() + .into_iter() + .next() + .unwrap() + .content + .parts + .into_iter() + .next() + .unwrap() + .text; + assert_eq!(text, Some("Hello from 
internal".to_string())); + } + #[tokio::test] async fn warmup_without_key_is_noop() { let provider = GeminiProvider { auth: None }; let result = provider.warmup().await; assert!(result.is_ok()); } + + #[tokio::test] + async fn warmup_oauth_is_noop() { + let provider = GeminiProvider { + auth: Some(GeminiAuth::OAuthToken("ya29.mock-token".into())), + }; + let result = provider.warmup().await; + assert!(result.is_ok()); + } } diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 119c14e..1de2956 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,4 +1,23 @@ +//! Provider subsystem for model inference backends. +//! +//! This module implements the factory pattern for AI model providers. Each provider +//! implements the [`Provider`] trait defined in [`traits`], and is registered in the +//! factory function [`create_provider`] by its canonical string key (e.g., `"openai"`, +//! `"anthropic"`, `"ollama"`, `"gemini"`). Provider aliases are resolved internally +//! so that user-facing keys remain stable. +//! +//! The subsystem supports resilient multi-provider configurations through the +//! [`ReliableProvider`](reliable::ReliableProvider) wrapper, which handles fallback +//! chains and automatic retry. Model routing across providers is available via +//! [`create_routed_provider`]. +//! +//! # Extension +//! +//! To add a new provider, implement [`Provider`] in a new submodule and register it +//! in [`create_provider_with_url`]. See `AGENTS.md` §7.1 for the full change playbook. 
+ pub mod anthropic; +pub mod bedrock; pub mod compatible; pub mod copilot; pub mod gemini; @@ -12,8 +31,8 @@ pub mod traits; #[allow(unused_imports)] pub use traits::{ - ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ToolCall, - ToolResultMessage, + ChatMessage, ChatRequest, ChatResponse, ConversationMessage, Provider, ProviderCapabilityError, + ToolCall, ToolResultMessage, }; use compatible::{AuthStyle, OpenAiCompatibleProvider}; @@ -41,6 +60,15 @@ const MOONSHOT_CN_BASE_URL: &str = "https://api.moonshot.cn/v1"; const QWEN_CN_BASE_URL: &str = "https://dashscope.aliyuncs.com/compatible-mode/v1"; const QWEN_INTL_BASE_URL: &str = "https://dashscope-intl.aliyuncs.com/compatible-mode/v1"; const QWEN_US_BASE_URL: &str = "https://dashscope-us.aliyuncs.com/compatible-mode/v1"; +const QWEN_OAUTH_BASE_FALLBACK_URL: &str = QWEN_CN_BASE_URL; +const QWEN_OAUTH_TOKEN_ENDPOINT: &str = "https://chat.qwen.ai/api/v1/oauth2/token"; +const QWEN_OAUTH_PLACEHOLDER: &str = "qwen-oauth"; +const QWEN_OAUTH_TOKEN_ENV: &str = "QWEN_OAUTH_TOKEN"; +const QWEN_OAUTH_REFRESH_TOKEN_ENV: &str = "QWEN_OAUTH_REFRESH_TOKEN"; +const QWEN_OAUTH_RESOURCE_URL_ENV: &str = "QWEN_OAUTH_RESOURCE_URL"; +const QWEN_OAUTH_CLIENT_ID_ENV: &str = "QWEN_OAUTH_CLIENT_ID"; +const QWEN_OAUTH_DEFAULT_CLIENT_ID: &str = "f0304373b74a44d2b584a3fb70ca9e56"; +const QWEN_OAUTH_CREDENTIAL_FILE: &str = ".qwen/oauth_creds.json"; const ZAI_GLOBAL_BASE_URL: &str = "https://api.z.ai/api/coding/paas/v4"; const ZAI_CN_BASE_URL: &str = "https://open.bigmodel.cn/api/coding/paas/v4"; @@ -111,8 +139,15 @@ pub(crate) fn is_qwen_us_alias(name: &str) -> bool { matches!(name, "qwen-us" | "dashscope-us") } +pub(crate) fn is_qwen_oauth_alias(name: &str) -> bool { + matches!(name, "qwen-code" | "qwen-oauth" | "qwen_oauth") +} + pub(crate) fn is_qwen_alias(name: &str) -> bool { - is_qwen_cn_alias(name) || is_qwen_intl_alias(name) || is_qwen_us_alias(name) + is_qwen_cn_alias(name) + || is_qwen_intl_alias(name) + || 
is_qwen_us_alias(name) + || is_qwen_oauth_alias(name) } pub(crate) fn is_zai_global_alias(name: &str) -> bool { @@ -162,6 +197,60 @@ struct MinimaxOauthBaseResponse { status_msg: Option, } +#[derive(Clone, Deserialize, Default)] +struct QwenOauthCredentials { + #[serde(default)] + access_token: Option, + #[serde(default)] + refresh_token: Option, + #[serde(default)] + resource_url: Option, + #[serde(default)] + expiry_date: Option, +} + +impl std::fmt::Debug for QwenOauthCredentials { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QwenOauthCredentials") + .field("access_token", &self.access_token.as_ref().map(|_| "[REDACTED]")) + .field("refresh_token", &self.refresh_token.as_ref().map(|_| "[REDACTED]")) + .field("resource_url", &self.resource_url) + .field("expiry_date", &self.expiry_date) + .finish() + } +} + +#[derive(Debug, Deserialize)] +struct QwenOauthTokenResponse { + #[serde(default)] + access_token: Option, + #[serde(default)] + refresh_token: Option, + #[serde(default)] + expires_in: Option, + #[serde(default)] + resource_url: Option, + #[serde(default)] + error: Option, + #[serde(default)] + error_description: Option, +} + +#[derive(Clone, Default)] +struct QwenOauthProviderContext { + credential: Option, + base_url: Option, +} + +impl std::fmt::Debug for QwenOauthProviderContext { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("QwenOauthProviderContext") + .field("credential", &self.credential.as_ref().map(|_| "[REDACTED]")) + .field("base_url", &self.base_url) + .finish() + } +} + fn read_non_empty_env(name: &str) -> Option { std::env::var(name) .ok() @@ -197,6 +286,233 @@ fn minimax_oauth_client_id() -> String { .unwrap_or_else(|| MINIMAX_OAUTH_DEFAULT_CLIENT_ID.to_string()) } +fn qwen_oauth_client_id() -> String { + read_non_empty_env(QWEN_OAUTH_CLIENT_ID_ENV) + .unwrap_or_else(|| QWEN_OAUTH_DEFAULT_CLIENT_ID.to_string()) +} + +fn qwen_oauth_credentials_file_path() -> 
Option { + std::env::var_os("HOME") + .map(PathBuf::from) + .or_else(|| std::env::var_os("USERPROFILE").map(PathBuf::from)) + .map(|home| home.join(QWEN_OAUTH_CREDENTIAL_FILE)) +} + +fn normalize_qwen_oauth_base_url(raw: &str) -> Option { + let trimmed = raw.trim().trim_end_matches('/'); + if trimmed.is_empty() { + return None; + } + + let with_scheme = if trimmed.starts_with("http://") || trimmed.starts_with("https://") { + trimmed.to_string() + } else { + format!("https://{trimmed}") + }; + + let normalized = with_scheme.trim_end_matches('/').to_string(); + if normalized.ends_with("/v1") { + Some(normalized) + } else { + Some(format!("{normalized}/v1")) + } +} + +fn read_qwen_oauth_cached_credentials() -> Option { + let path = qwen_oauth_credentials_file_path()?; + let content = std::fs::read_to_string(path).ok()?; + serde_json::from_str::(&content).ok() +} + +fn normalized_qwen_expiry_millis(raw: i64) -> i64 { + if raw < 10_000_000_000 { + raw.saturating_mul(1000) + } else { + raw + } +} + +fn qwen_oauth_token_expired(credentials: &QwenOauthCredentials) -> bool { + let Some(expiry) = credentials.expiry_date else { + return false; + }; + + let expiry_millis = normalized_qwen_expiry_millis(expiry); + let now_millis = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .ok() + .and_then(|duration| i64::try_from(duration.as_millis()).ok()) + .unwrap_or(i64::MAX); + + expiry_millis <= now_millis.saturating_add(30_000) +} + +fn refresh_qwen_oauth_access_token(refresh_token: &str) -> anyhow::Result { + let client_id = qwen_oauth_client_id(); + let client = reqwest::blocking::Client::builder() + .timeout(std::time::Duration::from_secs(15)) + .connect_timeout(std::time::Duration::from_secs(5)) + .build() + .unwrap_or_else(|_| reqwest::blocking::Client::new()); + + let response = client + .post(QWEN_OAUTH_TOKEN_ENDPOINT) + .header("Content-Type", "application/x-www-form-urlencoded") + .header("Accept", "application/json") + .form(&[ + ("grant_type", 
"refresh_token"), + ("refresh_token", refresh_token), + ("client_id", client_id.as_str()), + ]) + .send() + .map_err(|error| anyhow::anyhow!("Qwen OAuth refresh request failed: {error}"))?; + + let status = response.status(); + let body = response + .text() + .unwrap_or_else(|_| "".to_string()); + + let parsed = serde_json::from_str::(&body).ok(); + + if !status.is_success() { + let detail = parsed + .as_ref() + .and_then(|payload| payload.error_description.as_deref()) + .or_else(|| parsed.as_ref().and_then(|payload| payload.error.as_deref())) + .filter(|msg| !msg.trim().is_empty()) + .unwrap_or(body.as_str()); + anyhow::bail!("Qwen OAuth refresh failed (HTTP {status}): {detail}"); + } + + let payload = + parsed.ok_or_else(|| anyhow::anyhow!("Qwen OAuth refresh response is not JSON"))?; + + if let Some(error_code) = payload + .error + .as_deref() + .filter(|value| !value.trim().is_empty()) + { + let detail = payload.error_description.as_deref().unwrap_or(error_code); + anyhow::bail!("Qwen OAuth refresh failed: {detail}"); + } + + let access_token = payload + .access_token + .as_deref() + .map(str::trim) + .filter(|token| !token.is_empty()) + .ok_or_else(|| anyhow::anyhow!("Qwen OAuth refresh response missing access_token"))? 
+ .to_string(); + + let expiry_date = payload.expires_in.and_then(|seconds| { + let now_secs = std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .ok() + .and_then(|duration| i64::try_from(duration.as_secs()).ok())?; + now_secs + .checked_add(seconds) + .and_then(|unix_secs| unix_secs.checked_mul(1000)) + }); + + Ok(QwenOauthCredentials { + access_token: Some(access_token), + refresh_token: payload + .refresh_token + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToString::to_string), + resource_url: payload + .resource_url + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToString::to_string), + expiry_date, + }) +} + +fn resolve_qwen_oauth_context(credential_override: Option<&str>) -> QwenOauthProviderContext { + let override_value = credential_override + .map(str::trim) + .filter(|value| !value.is_empty()); + let placeholder_requested = override_value + .map(|value| value.eq_ignore_ascii_case(QWEN_OAUTH_PLACEHOLDER)) + .unwrap_or(false); + + if let Some(explicit) = override_value { + if !placeholder_requested { + return QwenOauthProviderContext { + credential: Some(explicit.to_string()), + base_url: None, + }; + } + } + + let mut cached = read_qwen_oauth_cached_credentials(); + + let env_token = read_non_empty_env(QWEN_OAUTH_TOKEN_ENV); + let env_refresh_token = read_non_empty_env(QWEN_OAUTH_REFRESH_TOKEN_ENV); + let env_resource_url = read_non_empty_env(QWEN_OAUTH_RESOURCE_URL_ENV); + + if env_token.is_none() { + let refresh_token = env_refresh_token.clone().or_else(|| { + cached + .as_ref() + .and_then(|credentials| credentials.refresh_token.clone()) + }); + + let should_refresh = cached.as_ref().is_some_and(qwen_oauth_token_expired) + || cached + .as_ref() + .and_then(|credentials| credentials.access_token.as_deref()) + .is_none_or(|value| value.trim().is_empty()); + + if should_refresh { + if let Some(refresh_token) = refresh_token.as_deref() { + match 
refresh_qwen_oauth_access_token(refresh_token) { + Ok(refreshed) => { + cached = Some(refreshed); + } + Err(error) => { + tracing::warn!(error = %error, "Qwen OAuth refresh failed"); + } + } + } + } + } + + let mut credential = env_token.or_else(|| { + cached + .as_ref() + .and_then(|credentials| credentials.access_token.clone()) + }); + credential = credential + .as_deref() + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToString::to_string); + + if credential.is_none() && !placeholder_requested { + credential = read_non_empty_env("DASHSCOPE_API_KEY"); + } + + let base_url = env_resource_url + .as_deref() + .and_then(normalize_qwen_oauth_base_url) + .or_else(|| { + cached + .as_ref() + .and_then(|credentials| credentials.resource_url.as_deref()) + .and_then(normalize_qwen_oauth_base_url) + }); + + QwenOauthProviderContext { + credential, + base_url, + } +} + fn resolve_minimax_static_credential() -> Option { read_non_empty_env(MINIMAX_OAUTH_TOKEN_ENV).or_else(|| read_non_empty_env(MINIMAX_API_KEY_ENV)) } @@ -326,7 +642,7 @@ fn moonshot_base_url(name: &str) -> Option<&'static str> { } fn qwen_base_url(name: &str) -> Option<&'static str> { - if is_qwen_cn_alias(name) { + if is_qwen_cn_alias(name) || is_qwen_oauth_alias(name) { Some(QWEN_CN_BASE_URL) } else if is_qwen_intl_alias(name) { Some(QWEN_INTL_BASE_URL) @@ -352,6 +668,7 @@ pub struct ProviderRuntimeOptions { pub auth_profile_override: Option, pub zeroclaw_dir: Option, pub secrets_encrypt: bool, + pub reasoning_enabled: Option, } impl Default for ProviderRuntimeOptions { @@ -360,6 +677,7 @@ impl Default for ProviderRuntimeOptions { auth_profile_override: None, zeroclaw_dir: None, secrets_encrypt: true, + reasoning_enabled: None, } } } @@ -502,6 +820,9 @@ fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> } name if is_glm_alias(name) => vec!["GLM_API_KEY"], name if is_minimax_alias(name) => vec![MINIMAX_OAUTH_TOKEN_ENV, MINIMAX_API_KEY_ENV], + // Bedrock uses AWS AKSK 
from env vars (AWS_ACCESS_KEY_ID + AWS_SECRET_ACCESS_KEY), + // not a single API key. Credential resolution happens inside BedrockProvider. + "bedrock" | "aws-bedrock" => return None, name if is_qianfan_alias(name) => vec!["QIANFAN_API_KEY"], name if is_qwen_alias(name) => vec!["DASHSCOPE_API_KEY"], name if is_zai_alias(name) => vec!["ZAI_API_KEY"], @@ -584,18 +905,38 @@ pub fn create_provider_with_options( "openai-codex" | "openai_codex" | "codex" => { Ok(Box::new(openai_codex::OpenAiCodexProvider::new(options))) } - _ => create_provider_with_url(name, api_key, None), + _ => create_provider_with_url_and_options(name, api_key, None, options), } } /// Factory: create the right provider from config with optional custom base URL -#[allow(clippy::too_many_lines)] pub fn create_provider_with_url( name: &str, api_key: Option<&str>, api_url: Option<&str>, ) -> anyhow::Result> { - let resolved_credential = resolve_provider_credential(name, api_key); + create_provider_with_url_and_options(name, api_key, api_url, &ProviderRuntimeOptions::default()) +} + +/// Factory: create provider with optional base URL and runtime options. +#[allow(clippy::too_many_lines)] +fn create_provider_with_url_and_options( + name: &str, + api_key: Option<&str>, + api_url: Option<&str>, + options: &ProviderRuntimeOptions, +) -> anyhow::Result> { + let qwen_oauth_context = is_qwen_oauth_alias(name).then(|| resolve_qwen_oauth_context(api_key)); + + // Resolve credential and break static-analysis taint chain from the + // `api_key` parameter so that downstream provider storage of the value + // is not linked to the original sensitive-named source. 
+ let resolved_credential = if let Some(context) = qwen_oauth_context.as_ref() { + context.credential.clone() + } else { + resolve_provider_credential(name, api_key) + } + .map(|v| String::from_utf8(v.into_bytes()).unwrap_or_default()); #[allow(clippy::option_as_ref_deref)] let key = resolved_credential.as_ref().map(String::as_str); match name { @@ -604,7 +945,11 @@ pub fn create_provider_with_url( "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))), "openai" => Ok(Box::new(openai::OpenAiProvider::with_base_url(api_url, key))), // Ollama uses api_url for custom base URL (e.g. remote Ollama instance) - "ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url, key))), + "ollama" => Ok(Box::new(ollama::OllamaProvider::new_with_reasoning( + api_url, + key, + options.reasoning_enabled, + ))), "gemini" | "google" | "google-gemini" => { Ok(Box::new(gemini::GeminiProvider::new(key))) } @@ -657,18 +1002,31 @@ pub fn create_provider_with_url( AuthStyle::Bearer, ))) } - name if minimax_base_url(name).is_some() => Ok(Box::new(OpenAiCompatibleProvider::new( - "MiniMax", - minimax_base_url(name).expect("checked in guard"), - key, - AuthStyle::Bearer, - ))), - "bedrock" | "aws-bedrock" => Ok(Box::new(OpenAiCompatibleProvider::new( - "Amazon Bedrock", - "https://bedrock-runtime.us-east-1.amazonaws.com", - key, - AuthStyle::Bearer, - ))), + name if minimax_base_url(name).is_some() => Ok(Box::new( + OpenAiCompatibleProvider::new_merge_system_into_user( + "MiniMax", + minimax_base_url(name).expect("checked in guard"), + key, + AuthStyle::Bearer, + ) + )), + "bedrock" | "aws-bedrock" => Ok(Box::new(bedrock::BedrockProvider::new())), + name if is_qwen_oauth_alias(name) => { + let base_url = api_url + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(ToString::to_string) + .or_else(|| qwen_oauth_context.as_ref().and_then(|context| context.base_url.clone())) + .unwrap_or_else(|| QWEN_OAUTH_BASE_FALLBACK_URL.to_string()); + + 
Ok(Box::new(OpenAiCompatibleProvider::new_with_user_agent( + "Qwen Code", + &base_url, + key, + AuthStyle::Bearer, + "QwenCode/1.0", + ))) + } name if is_qianfan_alias(name) => Ok(Box::new(OpenAiCompatibleProvider::new( "Qianfan", "https://aip.baidubce.com", key, AuthStyle::Bearer, ))), @@ -704,11 +1062,9 @@ pub fn create_provider_with_url( "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new( "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer, ))), - "copilot" | "github-copilot" => { - Ok(Box::new(copilot::CopilotProvider::new(api_key))) - }, + "copilot" | "github-copilot" => Ok(Box::new(copilot::CopilotProvider::new(key))), "lmstudio" | "lm-studio" => { - let lm_studio_key = api_key + let lm_studio_key = key .map(str::trim) .filter(|value| !value.is_empty()) .unwrap_or("lm-studio"); @@ -807,7 +1163,7 @@ pub fn create_resilient_provider_with_options( "openai-codex" | "openai_codex" | "codex" => { create_provider_with_options(primary_name, api_key, options)? } - _ => create_provider_with_url(primary_name, api_key, api_url)?, + _ => create_provider_with_url_and_options(primary_name, api_key, api_url, options)?, }; providers.push((primary_name.to_string(), primary_provider)); @@ -816,8 +1172,16 @@ pub fn create_resilient_provider_with_options( continue; } - // Fallback providers don't use the custom api_url (it's specific to primary). - match create_provider_with_options(fallback, api_key, options) { + // Each fallback provider resolves its own credential via provider- + // specific env vars (e.g. DEEPSEEK_API_KEY for "deepseek") instead + // of inheriting the primary provider's key. Passing `None` lets + // `resolve_provider_credential` check the correct env var for the + // fallback provider name. + // + // Keep using `create_provider_with_options` so fallback entries that + // require runtime options (for example Codex auth profile overrides) + // continue to work. 
+ match create_provider_with_options(fallback, None, options) { Ok(provider) => providers.push((fallback.clone(), provider)), Err(_error) => { tracing::warn!( @@ -849,9 +1213,36 @@ pub fn create_routed_provider( reliability: &crate::config::ReliabilityConfig, model_routes: &[crate::config::ModelRouteConfig], default_model: &str, +) -> anyhow::Result> { + create_routed_provider_with_options( + primary_name, + api_key, + api_url, + reliability, + model_routes, + default_model, + &ProviderRuntimeOptions::default(), + ) +} + +/// Create a routed provider using explicit runtime options. +pub fn create_routed_provider_with_options( + primary_name: &str, + api_key: Option<&str>, + api_url: Option<&str>, + reliability: &crate::config::ReliabilityConfig, + model_routes: &[crate::config::ModelRouteConfig], + default_model: &str, + options: &ProviderRuntimeOptions, ) -> anyhow::Result> { if model_routes.is_empty() { - return create_resilient_provider(primary_name, api_key, api_url, reliability); + return create_resilient_provider_with_options( + primary_name, + api_key, + api_url, + reliability, + options, + ); } // Collect unique provider names needed @@ -877,7 +1268,7 @@ pub fn create_routed_provider( let key = routed_credential.or(api_key); // Only use api_url for the primary provider let url = if name == primary_name { api_url } else { None }; - match create_resilient_provider(name, key, url, reliability) { + match create_resilient_provider_with_options(name, key, url, reliability, options) { Ok(provider) => providers.push((name.clone(), provider)), Err(e) => { if name == primary_name { @@ -1052,13 +1443,16 @@ pub fn list_providers() -> Vec { }, ProviderInfo { name: "qwen", - display_name: "Qwen (DashScope)", + display_name: "Qwen (DashScope / Qwen Code OAuth)", aliases: &[ "dashscope", "qwen-intl", "dashscope-intl", "qwen-us", "dashscope-us", + "qwen-code", + "qwen-oauth", + "qwen_oauth", ], local: false, }, @@ -1140,6 +1534,7 @@ pub fn list_providers() -> Vec { 
#[cfg(test)] mod tests { use super::*; + use std::sync::{Mutex, OnceLock}; struct EnvGuard { key: &'static str, @@ -1168,6 +1563,13 @@ mod tests { } } + fn env_lock() -> std::sync::MutexGuard<'static, ()> { + static LOCK: OnceLock> = OnceLock::new(); + LOCK.get_or_init(|| Mutex::new(())) + .lock() + .expect("env lock poisoned") + } + #[test] fn resolve_provider_credential_prefers_explicit_argument() { let resolved = resolve_provider_credential("openrouter", Some(" explicit-key ")); @@ -1176,6 +1578,7 @@ mod tests { #[test] fn resolve_provider_credential_uses_minimax_oauth_env_for_placeholder() { + let _env_lock = env_lock(); let _oauth_guard = EnvGuard::set(MINIMAX_OAUTH_TOKEN_ENV, Some("oauth-token")); let _api_guard = EnvGuard::set(MINIMAX_API_KEY_ENV, Some("api-key")); let _refresh_guard = EnvGuard::set(MINIMAX_OAUTH_REFRESH_TOKEN_ENV, None); @@ -1187,6 +1590,7 @@ mod tests { #[test] fn resolve_provider_credential_falls_back_to_minimax_api_key_for_placeholder() { + let _env_lock = env_lock(); let _oauth_guard = EnvGuard::set(MINIMAX_OAUTH_TOKEN_ENV, None); let _api_guard = EnvGuard::set(MINIMAX_API_KEY_ENV, Some("api-key")); let _refresh_guard = EnvGuard::set(MINIMAX_OAUTH_REFRESH_TOKEN_ENV, None); @@ -1198,6 +1602,7 @@ mod tests { #[test] fn resolve_provider_credential_placeholder_ignores_generic_api_key_fallback() { + let _env_lock = env_lock(); let _oauth_guard = EnvGuard::set(MINIMAX_OAUTH_TOKEN_ENV, None); let _api_guard = EnvGuard::set(MINIMAX_API_KEY_ENV, None); let _refresh_guard = EnvGuard::set(MINIMAX_OAUTH_REFRESH_TOKEN_ENV, None); @@ -1208,6 +1613,104 @@ mod tests { assert!(resolved.is_none()); } + #[test] + fn resolve_provider_credential_bedrock_uses_internal_credential_path() { + let _generic_guard = EnvGuard::set("API_KEY", Some("generic-key")); + let _override_guard = EnvGuard::set("OPENROUTER_API_KEY", Some("openrouter-key")); + + assert_eq!( + resolve_provider_credential("bedrock", Some("explicit")), + Some("explicit".to_string()) + ); + 
assert!(resolve_provider_credential("bedrock", None).is_none()); + assert!(resolve_provider_credential("aws-bedrock", None).is_none()); + } + + #[test] + fn resolve_qwen_oauth_context_prefers_explicit_override() { + let _env_lock = env_lock(); + let fake_home = format!("/tmp/zeroclaw-qwen-oauth-home-{}", std::process::id()); + let _home_guard = EnvGuard::set("HOME", Some(fake_home.as_str())); + let _token_guard = EnvGuard::set(QWEN_OAUTH_TOKEN_ENV, Some("oauth-token")); + let _resource_guard = EnvGuard::set( + QWEN_OAUTH_RESOURCE_URL_ENV, + Some("coding-intl.dashscope.aliyuncs.com"), + ); + + let context = resolve_qwen_oauth_context(Some(" explicit-qwen-token ")); + + assert_eq!(context.credential.as_deref(), Some("explicit-qwen-token")); + assert!(context.base_url.is_none()); + } + + #[test] + fn resolve_qwen_oauth_context_uses_env_token_and_resource_url() { + let _env_lock = env_lock(); + let fake_home = format!("/tmp/zeroclaw-qwen-oauth-home-{}-env", std::process::id()); + let _home_guard = EnvGuard::set("HOME", Some(fake_home.as_str())); + let _token_guard = EnvGuard::set(QWEN_OAUTH_TOKEN_ENV, Some("oauth-token")); + let _refresh_guard = EnvGuard::set(QWEN_OAUTH_REFRESH_TOKEN_ENV, None); + let _resource_guard = EnvGuard::set( + QWEN_OAUTH_RESOURCE_URL_ENV, + Some("coding-intl.dashscope.aliyuncs.com"), + ); + let _dashscope_guard = EnvGuard::set("DASHSCOPE_API_KEY", Some("dashscope-fallback")); + + let context = resolve_qwen_oauth_context(Some(QWEN_OAUTH_PLACEHOLDER)); + + assert_eq!(context.credential.as_deref(), Some("oauth-token")); + assert_eq!( + context.base_url.as_deref(), + Some("https://coding-intl.dashscope.aliyuncs.com/v1") + ); + } + + #[test] + fn resolve_qwen_oauth_context_reads_cached_credentials_file() { + let _env_lock = env_lock(); + let fake_home = format!("/tmp/zeroclaw-qwen-oauth-home-{}-file", std::process::id()); + let creds_dir = PathBuf::from(&fake_home).join(".qwen"); + std::fs::create_dir_all(&creds_dir).unwrap(); + let creds_path = 
creds_dir.join("oauth_creds.json"); + std::fs::write( + &creds_path, + r#"{"access_token":"cached-token","refresh_token":"cached-refresh","resource_url":"https://resource.example.com","expiry_date":4102444800000}"#, + ) + .unwrap(); + + let _home_guard = EnvGuard::set("HOME", Some(fake_home.as_str())); + let _token_guard = EnvGuard::set(QWEN_OAUTH_TOKEN_ENV, None); + let _refresh_guard = EnvGuard::set(QWEN_OAUTH_REFRESH_TOKEN_ENV, None); + let _resource_guard = EnvGuard::set(QWEN_OAUTH_RESOURCE_URL_ENV, None); + let _dashscope_guard = EnvGuard::set("DASHSCOPE_API_KEY", None); + + let context = resolve_qwen_oauth_context(Some(QWEN_OAUTH_PLACEHOLDER)); + + assert_eq!(context.credential.as_deref(), Some("cached-token")); + assert_eq!( + context.base_url.as_deref(), + Some("https://resource.example.com/v1") + ); + } + + #[test] + fn resolve_qwen_oauth_context_placeholder_does_not_use_dashscope_fallback() { + let _env_lock = env_lock(); + let fake_home = format!( + "/tmp/zeroclaw-qwen-oauth-home-{}-placeholder", + std::process::id() + ); + let _home_guard = EnvGuard::set("HOME", Some(fake_home.as_str())); + let _token_guard = EnvGuard::set(QWEN_OAUTH_TOKEN_ENV, None); + let _refresh_guard = EnvGuard::set(QWEN_OAUTH_REFRESH_TOKEN_ENV, None); + let _resource_guard = EnvGuard::set(QWEN_OAUTH_RESOURCE_URL_ENV, None); + let _dashscope_guard = EnvGuard::set("DASHSCOPE_API_KEY", Some("dashscope-fallback")); + + let context = resolve_qwen_oauth_context(Some(QWEN_OAUTH_PLACEHOLDER)); + + assert!(context.credential.is_none()); + } + #[test] fn regional_alias_predicates_cover_expected_variants() { assert!(is_moonshot_alias("moonshot")); @@ -1220,6 +1723,9 @@ mod tests { assert!(is_minimax_alias("minimax-portal-cn")); assert!(is_qwen_alias("dashscope")); assert!(is_qwen_alias("qwen-us")); + assert!(is_qwen_alias("qwen-code")); + assert!(is_qwen_oauth_alias("qwen-code")); + assert!(is_qwen_oauth_alias("qwen_oauth")); assert!(is_zai_alias("z.ai")); assert!(is_zai_alias("zai-cn")); 
assert!(is_qianfan_alias("qianfan")); @@ -1242,6 +1748,7 @@ mod tests { assert_eq!(canonical_china_provider_name("minimax-cn"), Some("minimax")); assert_eq!(canonical_china_provider_name("qwen"), Some("qwen")); assert_eq!(canonical_china_provider_name("dashscope-us"), Some("qwen")); + assert_eq!(canonical_china_provider_name("qwen-code"), Some("qwen")); assert_eq!(canonical_china_provider_name("zai"), Some("zai")); assert_eq!(canonical_china_provider_name("z.ai-cn"), Some("zai")); assert_eq!(canonical_china_provider_name("qianfan"), Some("qianfan")); @@ -1272,6 +1779,7 @@ mod tests { assert_eq!(qwen_base_url("qwen-cn"), Some(QWEN_CN_BASE_URL)); assert_eq!(qwen_base_url("qwen-intl"), Some(QWEN_INTL_BASE_URL)); assert_eq!(qwen_base_url("qwen-us"), Some(QWEN_US_BASE_URL)); + assert_eq!(qwen_base_url("qwen-code"), Some(QWEN_CN_BASE_URL)); assert_eq!(zai_base_url("zai"), Some(ZAI_GLOBAL_BASE_URL)); assert_eq!(zai_base_url("z.ai"), Some(ZAI_GLOBAL_BASE_URL)); @@ -1405,8 +1913,11 @@ mod tests { #[test] fn factory_bedrock() { - assert!(create_provider("bedrock", Some("key")).is_ok()); - assert!(create_provider("aws-bedrock", Some("key")).is_ok()); + // Bedrock uses AWS env vars for credentials, not API key. + assert!(create_provider("bedrock", None).is_ok()); + assert!(create_provider("aws-bedrock", None).is_ok()); + // Passing an api_key is harmless (ignored). 
+ assert!(create_provider("bedrock", Some("ignored")).is_ok()); } #[test] @@ -1427,6 +1938,8 @@ mod tests { assert!(create_provider("dashscope-international", Some("key")).is_ok()); assert!(create_provider("qwen-us", Some("key")).is_ok()); assert!(create_provider("dashscope-us", Some("key")).is_ok()); + assert!(create_provider("qwen-code", Some("key")).is_ok()); + assert!(create_provider("qwen-oauth", Some("key")).is_ok()); } #[test] @@ -1669,6 +2182,76 @@ mod tests { assert!(provider.is_err()); } + /// Fallback providers resolve their own credentials via provider-specific + /// env vars rather than inheriting the primary provider's key. A provider + /// that requires no key (e.g. lmstudio, ollama) must initialize + /// successfully even when the primary uses a completely different key. + #[test] + fn resilient_fallback_resolves_own_credential() { + let reliability = crate::config::ReliabilityConfig { + provider_retries: 1, + provider_backoff_ms: 100, + fallback_providers: vec!["lmstudio".into(), "ollama".into()], + api_keys: Vec::new(), + model_fallbacks: std::collections::HashMap::new(), + channel_initial_backoff_secs: 2, + channel_max_backoff_secs: 60, + scheduler_poll_secs: 15, + scheduler_retries: 2, + }; + + // Primary uses a ZAI key; fallbacks (lmstudio, ollama) should NOT + // receive this key; they resolve their own credentials independently. + let provider = create_resilient_provider("zai", Some("zai-test-key"), None, &reliability); + assert!(provider.is_ok()); + } + + /// `custom:` URL entries work as fallback providers, enabling arbitrary + /// OpenAI-compatible endpoints (e.g. local LM Studio on a Docker host). 
+ #[test] + fn resilient_fallback_supports_custom_url() { + let reliability = crate::config::ReliabilityConfig { + provider_retries: 1, + provider_backoff_ms: 100, + fallback_providers: vec!["custom:http://host.docker.internal:1234/v1".into()], + api_keys: Vec::new(), + model_fallbacks: std::collections::HashMap::new(), + channel_initial_backoff_secs: 2, + channel_max_backoff_secs: 60, + scheduler_poll_secs: 15, + scheduler_retries: 2, + }; + + let provider = + create_resilient_provider("openai", Some("openai-test-key"), None, &reliability); + assert!(provider.is_ok()); + } + + /// Mixed fallback chain: named providers, custom URLs, and invalid entries + /// all coexist. Invalid entries are silently ignored; valid ones initialize. + #[test] + fn resilient_fallback_mixed_chain() { + let reliability = crate::config::ReliabilityConfig { + provider_retries: 1, + provider_backoff_ms: 100, + fallback_providers: vec![ + "deepseek".into(), + "custom:http://localhost:8080/v1".into(), + "nonexistent-provider".into(), + "lmstudio".into(), + ], + api_keys: Vec::new(), + model_fallbacks: std::collections::HashMap::new(), + channel_initial_backoff_secs: 2, + channel_max_backoff_secs: 60, + scheduler_poll_secs: 15, + scheduler_retries: 2, + }; + + let provider = create_resilient_provider("zai", Some("zai-test-key"), None, &reliability); + assert!(provider.is_ok()); + } + #[test] fn ollama_with_custom_url() { let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434")); @@ -1712,6 +2295,7 @@ mod tests { "qwen-intl", "qwen-cn", "qwen-us", + "qwen-code", "lmstudio", "groq", "mistral", diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index 7fdc06f..4131d29 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -1,11 +1,16 @@ -use crate::providers::traits::Provider; +use crate::multimodal; +use crate::providers::traits::{ + ChatMessage, ChatResponse, Provider, ProviderCapabilities, ToolCall, +}; use async_trait::async_trait; 
use reqwest::Client; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; pub struct OllamaProvider { base_url: String, api_key: Option, + reasoning_enabled: Option, } // ─── Request Structures ─────────────────────────────────────────────────────── @@ -16,12 +21,36 @@ struct ChatRequest { messages: Vec, stream: bool, options: Options, + #[serde(skip_serializing_if = "Option::is_none")] + think: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, } #[derive(Debug, Serialize)] struct Message { role: String, - content: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + images: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_name: Option, +} + +#[derive(Debug, Serialize)] +struct OutgoingToolCall { + #[serde(rename = "type")] + kind: String, + function: OutgoingFunction, +} + +#[derive(Debug, Serialize)] +struct OutgoingFunction { + name: String, + arguments: serde_json::Value, } #[derive(Debug, Serialize)] @@ -64,6 +93,14 @@ struct OllamaFunction { impl OllamaProvider { pub fn new(base_url: Option<&str>, api_key: Option<&str>) -> Self { + Self::new_with_reasoning(base_url, api_key, None) + } + + pub fn new_with_reasoning( + base_url: Option<&str>, + api_key: Option<&str>, + reasoning_enabled: Option, + ) -> Self { let api_key = api_key.and_then(|value| { let trimmed = value.trim(); (!trimmed.is_empty()).then(|| trimmed.to_string()) @@ -75,6 +112,7 @@ impl OllamaProvider { .trim_end_matches('/') .to_string(), api_key, + reasoning_enabled, } } @@ -112,29 +150,176 @@ impl OllamaProvider { Ok((normalized_model, should_auth)) } - /// Send a request to Ollama and get the parsed response + fn parse_tool_arguments(arguments: &str) -> serde_json::Value { + serde_json::from_str(arguments).unwrap_or_else(|_| serde_json::json!({})) + } + + fn build_chat_request( + 
&self, + messages: Vec, + model: &str, + temperature: f64, + tools: Option<&[serde_json::Value]>, + ) -> ChatRequest { + ChatRequest { + model: model.to_string(), + messages, + stream: false, + options: Options { temperature }, + think: self.reasoning_enabled, + tools: tools.map(|t| t.to_vec()), + } + } + + fn convert_user_message_content(&self, content: &str) -> (Option, Option>) { + let (cleaned, image_refs) = multimodal::parse_image_markers(content); + if image_refs.is_empty() { + return (Some(content.to_string()), None); + } + + let images: Vec = image_refs + .iter() + .filter_map(|reference| multimodal::extract_ollama_image_payload(reference)) + .collect(); + + if images.is_empty() { + return (Some(content.to_string()), None); + } + + let cleaned = cleaned.trim(); + let content = if cleaned.is_empty() { + None + } else { + Some(cleaned.to_string()) + }; + + (content, Some(images)) + } + + /// Convert internal chat history format to Ollama's native tool-call message schema. + /// + /// `run_tool_call_loop` stores native assistant/tool entries as JSON strings in + /// `ChatMessage.content`. We decode those payloads here so follow-up requests send + /// structured `assistant.tool_calls` and `tool.tool_name`, as expected by Ollama. 
+ fn convert_messages(&self, messages: &[ChatMessage]) -> Vec { + let mut tool_name_by_id: HashMap = HashMap::new(); + + messages + .iter() + .map(|message| { + if message.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&message.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let outgoing_calls: Vec = parsed_calls + .into_iter() + .map(|call| { + tool_name_by_id.insert(call.id.clone(), call.name.clone()); + OutgoingToolCall { + kind: "function".to_string(), + function: OutgoingFunction { + name: call.name, + arguments: Self::parse_tool_arguments( + &call.arguments, + ), + }, + } + }) + .collect(); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + return Message { + role: "assistant".to_string(), + content, + images: None, + tool_calls: Some(outgoing_calls), + tool_name: None, + }; + } + } + } + } + + if message.role == "tool" { + if let Ok(value) = serde_json::from_str::(&message.content) { + let tool_name = value + .get("tool_name") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + .or_else(|| { + value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .and_then(|id| tool_name_by_id.get(id)) + .cloned() + }); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string) + .or_else(|| { + (!message.content.trim().is_empty()) + .then_some(message.content.clone()) + }); + + return Message { + role: "tool".to_string(), + content, + images: None, + tool_calls: None, + tool_name, + }; + } + } + + if message.role == "user" { + let (content, images) = self.convert_user_message_content(&message.content); + return Message { + role: "user".to_string(), + content, + images, + tool_calls: None, + tool_name: None, + }; + } + + Message { + role: message.role.clone(), + content: Some(message.content.clone()), + 
images: None, + tool_calls: None, + tool_name: None, + } + }) + .collect() + } + + /// Send a request to Ollama and get the parsed response. + /// Pass `tools` to enable native function-calling for models that support it. async fn send_request( &self, messages: Vec, model: &str, temperature: f64, should_auth: bool, + tools: Option<&[serde_json::Value]>, ) -> anyhow::Result { - let request = ChatRequest { - model: model.to_string(), - messages, - stream: false, - options: Options { temperature }, - }; + let request = self.build_chat_request(messages, model, temperature, tools); let url = format!("{}/api/chat", self.base_url); tracing::debug!( - "Ollama request: url={} model={} message_count={} temperature={}", + "Ollama request: url={} model={} message_count={} temperature={} think={:?} tool_count={}", url, model, request.messages.len(), - temperature + temperature, + request.think, + request.tools.as_ref().map_or(0, |t| t.len()), ); let mut request_builder = self.http_client().post(&url).json(&request); @@ -257,6 +442,13 @@ impl OllamaProvider { #[async_trait] impl Provider for OllamaProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + vision: true, + } + } + async fn chat_with_system( &self, system_prompt: Option<&str>, @@ -271,17 +463,24 @@ impl Provider for OllamaProvider { if let Some(sys) = system_prompt { messages.push(Message { role: "system".to_string(), - content: sys.to_string(), + content: Some(sys.to_string()), + images: None, + tool_calls: None, + tool_name: None, }); } + let (user_content, user_images) = self.convert_user_message_content(message); messages.push(Message { role: "user".to_string(), - content: message.to_string(), + content: user_content, + images: user_images, + tool_calls: None, + tool_name: None, }); let response = self - .send_request(messages, &normalized_model, temperature, should_auth) + .send_request(messages, &normalized_model, temperature, should_auth, None) 
.await?; // If model returned tool calls, format them for loop_.rs's parse_tool_calls @@ -322,16 +521,16 @@ impl Provider for OllamaProvider { ) -> anyhow::Result { let (normalized_model, should_auth) = self.resolve_request_details(model)?; - let api_messages: Vec = messages - .iter() - .map(|m| Message { - role: m.role.clone(), - content: m.content.clone(), - }) - .collect(); + let api_messages = self.convert_messages(messages); let response = self - .send_request(api_messages, &normalized_model, temperature, should_auth) + .send_request( + api_messages, + &normalized_model, + temperature, + should_auth, + None, + ) .await?; // If model returned tool calls, format them for loop_.rs's parse_tool_calls @@ -366,11 +565,87 @@ impl Provider for OllamaProvider { Ok(content) } + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (normalized_model, should_auth) = self.resolve_request_details(model)?; + + let api_messages = self.convert_messages(messages); + + // Tools arrive pre-formatted in OpenAI/Ollama-compatible JSON from + // tools_to_openai_format() in loop_.rs — pass them through directly. + let tools_opt = if tools.is_empty() { None } else { Some(tools) }; + + let response = self + .send_request( + api_messages, + &normalized_model, + temperature, + should_auth, + tools_opt, + ) + .await?; + + // Native tool calls returned by the model. 
+ if !response.message.tool_calls.is_empty() { + let tool_calls: Vec = response + .message + .tool_calls + .iter() + .map(|tc| { + let (name, args) = self.extract_tool_name_and_args(tc); + ToolCall { + id: tc + .id + .clone() + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name, + arguments: serde_json::to_string(&args) + .unwrap_or_else(|_| "{}".to_string()), + } + }) + .collect(); + let text = if response.message.content.is_empty() { + None + } else { + Some(response.message.content) + }; + return Ok(ChatResponse { text, tool_calls }); + } + + // Plain text response. + let content = response.message.content; + if content.is_empty() { + if let Some(thinking) = &response.message.thinking { + tracing::warn!( + "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.", + if thinking.len() > 100 { &thinking[..100] } else { thinking } + ); + return Ok(ChatResponse { + text: Some(format!( + "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", + if thinking.len() > 200 { &thinking[..200] } else { thinking } + )), + tool_calls: vec![], + }); + } + tracing::warn!("Ollama returned empty content with no tool calls"); + } + Ok(ChatResponse { + text: Some(content), + tool_calls: vec![], + }) + } + fn supports_native_tools(&self) -> bool { - // Return false since loop_.rs uses XML-style tool parsing via system prompt - // The model may return native tool_calls but we convert them to JSON format - // that parse_tool_calls() understands - false + // Ollama's /api/chat supports native function-calling for capable models + // (qwen2.5, llama3.1, mistral-nemo, etc.). chat_with_tools() sends tool + // definitions in the request and returns structured ToolCall objects. 
+ true } } @@ -448,6 +723,46 @@ mod tests { assert!(!should_auth); } + #[test] + fn request_omits_think_when_reasoning_not_configured() { + let provider = OllamaProvider::new(None, None); + let request = provider.build_chat_request( + vec![Message { + role: "user".to_string(), + content: Some("hello".to_string()), + images: None, + tool_calls: None, + tool_name: None, + }], + "llama3", + 0.7, + None, + ); + + let json = serde_json::to_value(request).unwrap(); + assert!(json.get("think").is_none()); + } + + #[test] + fn request_includes_think_when_reasoning_configured() { + let provider = OllamaProvider::new_with_reasoning(None, None, Some(false)); + let request = provider.build_chat_request( + vec![Message { + role: "user".to_string(), + content: Some("hello".to_string()), + images: None, + tool_calls: None, + tool_name: None, + }], + "llama3", + 0.7, + None, + ); + + let json = serde_json::to_value(request).unwrap(); + assert_eq!(json.get("think"), Some(&serde_json::json!(false))); + } + #[test] fn response_deserializes() { let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; @@ -557,4 +872,80 @@ mod tests { // arguments should be a string (JSON-encoded) assert!(func.get("arguments").unwrap().is_string()); } + + #[test] + fn convert_messages_parses_native_assistant_tool_calls() { + let provider = OllamaProvider::new(None, None); + let messages = vec![ChatMessage { + role: "assistant".into(), + content: r#"{"content":null,"tool_calls":[{"id":"call_1","name":"shell","arguments":"{\"command\":\"ls\"}"}]}"#.into(), + }]; + + let converted = provider.convert_messages(&messages); + + assert_eq!(converted.len(), 1); + assert_eq!(converted[0].role, "assistant"); + assert!(converted[0].content.is_none()); + let calls = converted[0] + .tool_calls + .as_ref() + .expect("tool calls expected"); + assert_eq!(calls.len(), 1); + assert_eq!(calls[0].kind, "function"); + assert_eq!(calls[0].function.name, "shell"); + 
assert_eq!(calls[0].function.arguments.get("command").unwrap(), "ls"); + } + + #[test] + fn convert_messages_maps_tool_result_call_id_to_tool_name() { + let provider = OllamaProvider::new(None, None); + let messages = vec![ + ChatMessage { + role: "assistant".into(), + content: r#"{"content":null,"tool_calls":[{"id":"call_7","name":"file_read","arguments":"{\"path\":\"README.md\"}"}]}"#.into(), + }, + ChatMessage { + role: "tool".into(), + content: r#"{"tool_call_id":"call_7","content":"ok"}"#.into(), + }, + ]; + + let converted = provider.convert_messages(&messages); + + assert_eq!(converted.len(), 2); + assert_eq!(converted[1].role, "tool"); + assert_eq!(converted[1].tool_name.as_deref(), Some("file_read")); + assert_eq!(converted[1].content.as_deref(), Some("ok")); + assert!(converted[1].tool_calls.is_none()); + } + + #[test] + fn convert_messages_extracts_images_from_user_marker() { + let provider = OllamaProvider::new(None, None); + let messages = vec![ChatMessage { + role: "user".into(), + content: "Inspect this screenshot [IMAGE:data:image/png;base64,abcd==]".into(), + }]; + + let converted = provider.convert_messages(&messages); + assert_eq!(converted.len(), 1); + assert_eq!(converted[0].role, "user"); + assert_eq!( + converted[0].content.as_deref(), + Some("Inspect this screenshot") + ); + let images = converted[0] + .images + .as_ref() + .expect("images should be present"); + assert_eq!(images, &vec!["abcd==".to_string()]); + } + + #[test] + fn capabilities_include_native_tools_and_vision() { + let provider = OllamaProvider::new(None, None); + let caps = ::capabilities(&provider); + assert!(caps.native_tool_calling); + assert!(caps.vision); + } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index 90ed340..067f943 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -75,20 +75,34 @@ struct NativeMessage { tool_calls: Option>, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] struct NativeToolSpec { 
#[serde(rename = "type")] kind: String, function: NativeToolFunctionSpec, } -#[derive(Debug, Serialize)] +#[derive(Debug, Serialize, Deserialize)] struct NativeToolFunctionSpec { name: String, description: String, parameters: serde_json::Value, } +fn parse_native_tool_spec(value: serde_json::Value) -> anyhow::Result { + let spec: NativeToolSpec = serde_json::from_value(value) + .map_err(|e| anyhow::anyhow!("Invalid OpenAI tool specification: {e}"))?; + + if spec.kind != "function" { + anyhow::bail!( + "Invalid OpenAI tool specification: unsupported tool type '{}', expected 'function'", + spec.kind + ); + } + + Ok(spec) +} + #[derive(Debug, Serialize, Deserialize)] struct NativeToolCall { #[serde(skip_serializing_if = "Option::is_none")] @@ -354,6 +368,59 @@ impl Provider for OpenAiProvider { true } + async fn chat_with_tools( + &self, + messages: &[ChatMessage], + tools: &[serde_json::Value], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let credential = self.credential.as_ref().ok_or_else(|| { + anyhow::anyhow!("OpenAI API key not set. 
Set OPENAI_API_KEY or edit config.toml.") + })?; + + let native_tools: Option> = if tools.is_empty() { + None + } else { + Some( + tools + .iter() + .cloned() + .map(parse_native_tool_spec) + .collect::, _>>()?, + ) + }; + + let native_request = NativeChatRequest { + model: model.to_string(), + messages: Self::convert_messages(messages), + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let response = self + .http_client() + .post(format!("{}/chat/completions", self.base_url)) + .header("Authorization", format!("Bearer {credential}")) + .json(&native_request) + .send() + .await?; + + if !response.status().is_success() { + return Err(super::api_error("OpenAI", response).await); + } + + let native_response: NativeChatResponse = response.json().await?; + let message = native_response + .choices + .into_iter() + .next() + .map(|c| c.message) + .ok_or_else(|| anyhow::anyhow!("No response from OpenAI"))?; + Ok(Self::parse_native_response(message)) + } + async fn warmup(&self) -> anyhow::Result<()> { if let Some(credential) = self.credential.as_ref() { self.http_client() @@ -537,4 +604,74 @@ mod tests { let msg = &resp.choices[0].message; assert_eq!(msg.effective_content(), Some("Real answer".to_string())); } + + #[tokio::test] + async fn chat_with_tools_fails_without_key() { + let p = OpenAiProvider::new(None); + let messages = vec![ChatMessage::user("hello".to_string())]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + })]; + let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await; + assert!(result.is_err()); + assert!(result.unwrap_err().to_string().contains("API key not set")); + } + + #[tokio::test] + async fn chat_with_tools_rejects_invalid_tool_shape() { + let p = 
OpenAiProvider::new(Some("openai-test-credential")); + let messages = vec![ChatMessage::user("hello".to_string())]; + let tools = vec![serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + })]; + + let result = p.chat_with_tools(&messages, &tools, "gpt-4o", 0.7).await; + assert!(result.is_err()); + assert!(result + .unwrap_err() + .to_string() + .contains("Invalid OpenAI tool specification")); + } + + #[test] + fn native_tool_spec_deserializes_from_openai_format() { + let json = serde_json::json!({ + "type": "function", + "function": { + "name": "shell", + "description": "Run a shell command", + "parameters": { + "type": "object", + "properties": { + "command": { "type": "string" } + }, + "required": ["command"] + } + } + }); + let spec = parse_native_tool_spec(json).unwrap(); + assert_eq!(spec.kind, "function"); + assert_eq!(spec.function.name, "shell"); + } } diff --git a/src/providers/openai_codex.rs b/src/providers/openai_codex.rs index e01dd82..eb9fc2f 100644 --- a/src/providers/openai_codex.rs +++ b/src/providers/openai_codex.rs @@ -1,6 +1,6 @@ use crate::auth::openai_oauth::extract_account_id_from_jwt; use crate::auth::AuthService; -use crate::providers::traits::Provider; +use crate::providers::traits::{ChatMessage, Provider}; use crate::providers::ProviderRuntimeOptions; use async_trait::async_trait; use reqwest::Client; @@ -123,6 +123,44 @@ fn normalize_model_id(model: &str) -> &str { model.rsplit('/').next().unwrap_or(model) } +fn build_responses_input(messages: &[ChatMessage]) -> (String, Vec) { + let mut system_parts: Vec<&str> = Vec::new(); + let mut input: Vec = Vec::new(); + + for msg in messages { + match msg.role.as_str() { + "system" => system_parts.push(&msg.content), + "user" => { + input.push(ResponsesInput { + role: "user".to_string(), + content: vec![ResponsesInputContent { + kind: 
"input_text".to_string(), + text: msg.content.clone(), + }], + }); + } + "assistant" => { + input.push(ResponsesInput { + role: "assistant".to_string(), + content: vec![ResponsesInputContent { + kind: "output_text".to_string(), + text: msg.content.clone(), + }], + }); + } + _ => {} + } + } + + let instructions = if system_parts.is_empty() { + DEFAULT_CODEX_INSTRUCTIONS.to_string() + } else { + system_parts.join("\n\n") + }; + + (instructions, input) +} + fn clamp_reasoning_effort(model: &str, effort: &str) -> String { let id = normalize_model_id(model); if (id.starts_with("gpt-5.2") || id.starts_with("gpt-5.3")) && effort == "minimal" { @@ -335,14 +373,12 @@ async fn decode_responses_body(response: reqwest::Response) -> anyhow::Result, - message: &str, + input: Vec, + instructions: String, model: &str, - _temperature: f64, ) -> anyhow::Result { let profile = self .auth @@ -368,14 +404,8 @@ impl Provider for OpenAiCodexProvider { let request = ResponsesRequest { model: normalized_model.to_string(), - input: vec![ResponsesInput { - role: "user".to_string(), - content: vec![ResponsesInputContent { - kind: "input_text".to_string(), - text: message.to_string(), - }], - }], - instructions: resolve_instructions(system_prompt), + input, + instructions, store: false, stream: true, text: ResponsesTextOptions { @@ -411,6 +441,38 @@ impl Provider for OpenAiCodexProvider { } } +#[async_trait] +impl Provider for OpenAiCodexProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + _temperature: f64, + ) -> anyhow::Result { + let input = vec![ResponsesInput { + role: "user".to_string(), + content: vec![ResponsesInputContent { + kind: "input_text".to_string(), + text: message.to_string(), + }], + }]; + self.send_responses_request(input, resolve_instructions(system_prompt), model) + .await + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + _temperature: f64, + ) -> anyhow::Result { + let 
(instructions, input) = build_responses_input(messages); + self.send_responses_request(input, instructions, model) + .await + } +} + #[cfg(test)] mod tests { use super::*; @@ -516,4 +578,70 @@ data: [DONE] assert_eq!(parse_sse_text(payload).unwrap().as_deref(), Some("Done")); } + + #[test] + fn build_responses_input_maps_content_types_by_role() { + let messages = vec![ + ChatMessage { + role: "system".into(), + content: "You are helpful.".into(), + }, + ChatMessage { + role: "user".into(), + content: "Hi".into(), + }, + ChatMessage { + role: "assistant".into(), + content: "Hello!".into(), + }, + ChatMessage { + role: "user".into(), + content: "Thanks".into(), + }, + ]; + let (instructions, input) = build_responses_input(&messages); + assert_eq!(instructions, "You are helpful."); + assert_eq!(input.len(), 3); + + let json: Vec = input + .iter() + .map(|item| serde_json::to_value(item).unwrap()) + .collect(); + assert_eq!(json[0]["role"], "user"); + assert_eq!(json[0]["content"][0]["type"], "input_text"); + assert_eq!(json[1]["role"], "assistant"); + assert_eq!(json[1]["content"][0]["type"], "output_text"); + assert_eq!(json[2]["role"], "user"); + assert_eq!(json[2]["content"][0]["type"], "input_text"); + } + + #[test] + fn build_responses_input_uses_default_instructions_without_system() { + let messages = vec![ChatMessage { + role: "user".into(), + content: "Hello".into(), + }]; + let (instructions, input) = build_responses_input(&messages); + assert_eq!(instructions, DEFAULT_CODEX_INSTRUCTIONS); + assert_eq!(input.len(), 1); + } + + #[test] + fn build_responses_input_ignores_unknown_roles() { + let messages = vec![ + ChatMessage { + role: "tool".into(), + content: "result".into(), + }, + ChatMessage { + role: "user".into(), + content: "Go".into(), + }, + ]; + let (instructions, input) = build_responses_input(&messages); + assert_eq!(instructions, DEFAULT_CODEX_INSTRUCTIONS); + assert_eq!(input.len(), 1); + let json = serde_json::to_value(&input[0]).unwrap(); + 
assert_eq!(json["role"], "user"); + } } diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 85f9019..94c855a 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -6,14 +6,28 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicUsize, Ordering}; use std::time::Duration; +// ── Error Classification ───────────────────────────────────────────────── +// Errors are split into retryable (transient server/network failures) and +// non-retryable (permanent client errors). This distinction drives whether +// the retry loop continues, falls back to the next provider, or aborts +// immediately — avoiding wasted latency on errors that cannot self-heal. + /// Check if an error is non-retryable (client errors that won't resolve with retries). fn is_non_retryable(err: &anyhow::Error) -> bool { + if is_context_window_exceeded(err) { + return true; + } + + // 4xx errors are generally non-retryable (bad request, auth failure, etc.), + // except 429 (rate-limit — transient) and 408 (timeout — worth retrying). if let Some(reqwest_err) = err.downcast_ref::() { if let Some(status) = reqwest_err.status() { let code = status.as_u16(); return status.is_client_error() && code != 429 && code != 408; } } + // Fallback: parse status codes from stringified errors (some providers + // embed codes in error messages rather than returning typed HTTP errors). let msg = err.to_string(); for word in msg.split(|c: char| !c.is_ascii_digit()) { if let Ok(code) = word.parse::() { @@ -23,6 +37,8 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { } } + // Heuristic: detect auth/model failures by keyword when no HTTP status + // is available (e.g. gRPC or custom transport errors). 
let msg_lower = msg.to_lowercase(); let auth_failure_hints = [ "invalid api key", @@ -45,14 +61,28 @@ fn is_non_retryable(err: &anyhow::Error) -> bool { return true; } - let model_catalog_mismatch = msg_lower.contains("model") + msg_lower.contains("model") && (msg_lower.contains("not found") || msg_lower.contains("unknown") || msg_lower.contains("unsupported") || msg_lower.contains("does not exist") - || msg_lower.contains("invalid")); + || msg_lower.contains("invalid")) +} - model_catalog_mismatch +fn is_context_window_exceeded(err: &anyhow::Error) -> bool { + let lower = err.to_string().to_lowercase(); + let hints = [ + "exceeds the context window", + "context window of this model", + "maximum context length", + "context length exceeded", + "too many tokens", + "token limit exceeded", + "prompt is too long", + "input is too long", + ]; + + hints.iter().any(|hint| lower.contains(hint)) } /// Check if an error is a rate-limit (429) error. @@ -179,6 +209,16 @@ fn push_failure( )); } +// ── Resilient Provider Wrapper ──────────────────────────────────────────── +// Three-level failover strategy: model chain → provider chain → retry loop. +// Outer loop: iterate model fallback chain (original model first, then +// configured alternatives). +// Middle loop: iterate registered providers in priority order. +// Inner loop: retry the same (provider, model) pair with exponential +// backoff, rotating API keys on rate-limit errors. +// Loop invariant: `failures` accumulates every failed attempt so the final +// error message gives operators a complete diagnostic trail. + /// Provider wrapper with retry, fallback, auth rotation, and model failover. pub struct ReliableProvider { providers: Vec<(String, Box)>, @@ -270,6 +310,10 @@ impl Provider for ReliableProvider { let models = self.model_chain(model); let mut failures = Vec::new(); + // Outer: model fallback chain. Middle: provider priority. Inner: retries. + // Each iteration: attempt one (provider, model) call. 
On success, return + // immediately. On non-retryable error, break to next provider. On + // retryable error, sleep with exponential backoff and retry. for current_model in &models { for (provider_name, provider) in &self.providers { let mut backoff_ms = self.base_backoff_ms; @@ -308,13 +352,16 @@ impl Provider for ReliableProvider { &error_detail, ); - // On rate-limit, try rotating API key + // Rate-limit with rotatable keys: cycle to the next API key + // so the retry hits a different quota bucket. if rate_limited && !non_retryable_rate_limit { if let Some(new_key) = self.rotate_key() { - tracing::info!( + tracing::warn!( provider = provider_name, error = %error_detail, - "Rate limited, rotated API key (key ending ...{})", + "Rate limited; key rotation selected key ending ...{} \ + but cannot apply (Provider trait has no set_api_key). \ + Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..] ); } @@ -327,6 +374,14 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); + + if is_context_window_exceeded(&e) { + anyhow::bail!( + "Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}", + failures.join("\n") + ); + } + break; } @@ -419,10 +474,12 @@ impl Provider for ReliableProvider { if rate_limited && !non_retryable_rate_limit { if let Some(new_key) = self.rotate_key() { - tracing::info!( + tracing::warn!( provider = provider_name, error = %error_detail, - "Rate limited, rotated API key (key ending ...{})", + "Rate limited; key rotation selected key ending ...{} \ + but cannot apply (Provider trait has no set_api_key). \ + Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..] ); } @@ -435,6 +492,14 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); + + if is_context_window_exceeded(&e) { + anyhow::bail!( + "Request exceeds model context window; retries and fallbacks were skipped. 
Attempts:\n{}", + failures.join("\n") + ); + } + break; } @@ -477,6 +542,12 @@ impl Provider for ReliableProvider { .unwrap_or(false) } + fn supports_vision(&self) -> bool { + self.providers + .iter() + .any(|(_, provider)| provider.supports_vision()) + } + async fn chat_with_tools( &self, messages: &[ChatMessage], @@ -527,10 +598,12 @@ impl Provider for ReliableProvider { if rate_limited && !non_retryable_rate_limit { if let Some(new_key) = self.rotate_key() { - tracing::info!( + tracing::warn!( provider = provider_name, error = %error_detail, - "Rate limited, rotated API key (key ending ...{})", + "Rate limited; key rotation selected key ending ...{} \ + but cannot apply (Provider trait has no set_api_key). \ + Retrying with original key.", &new_key[new_key.len().saturating_sub(4)..] ); } @@ -543,6 +616,14 @@ impl Provider for ReliableProvider { error = %error_detail, "Non-retryable error, moving on" ); + + if is_context_window_exceeded(&e) { + anyhow::bail!( + "Request exceeds model context window; retries and fallbacks were skipped. Attempts:\n{}", + failures.join("\n") + ); + } + break; } @@ -869,6 +950,44 @@ mod tests { assert!(!is_non_retryable(&anyhow::anyhow!( "model overloaded, try again later" ))); + assert!(is_non_retryable(&anyhow::anyhow!( + "OpenAI Codex stream error: Your input exceeds the context window of this model." + ))); + } + + #[tokio::test] + async fn context_window_error_aborts_retries_and_model_fallbacks() { + let calls = Arc::new(AtomicUsize::new(0)); + let mut model_fallbacks = std::collections::HashMap::new(); + model_fallbacks.insert( + "gpt-5.3-codex".to_string(), + vec!["gpt-5.2-codex".to_string()], + ); + + let provider = ReliableProvider::new( + vec![( + "openai-codex".into(), + Box::new(MockProvider { + calls: Arc::clone(&calls), + fail_until_attempt: usize::MAX, + response: "never", + error: "OpenAI Codex stream error: Your input exceeds the context window of this model. 
Please adjust your input and try again.", + }), + )], + 4, + 1, + ) + .with_model_fallbacks(model_fallbacks); + + let err = provider + .simple_chat("hello", "gpt-5.3-codex", 0.0) + .await + .expect_err("context window overflow should fail fast"); + let msg = err.to_string(); + + assert!(msg.contains("context window")); + assert!(msg.contains("skipped")); + assert_eq!(calls.load(Ordering::SeqCst), 1); } #[tokio::test] diff --git a/src/providers/router.rs b/src/providers/router.rs index 2d55869..b12bd52 100644 --- a/src/providers/router.rs +++ b/src/providers/router.rs @@ -158,6 +158,12 @@ impl Provider for RouterProvider { .unwrap_or(false) } + fn supports_vision(&self) -> bool { + self.providers + .iter() + .any(|(_, provider)| provider.supports_vision()) + } + async fn warmup(&self) -> anyhow::Result<()> { for (name, provider) in &self.providers { tracing::info!(provider = name, "Warming up routed provider"); diff --git a/src/providers/traits.rs b/src/providers/traits.rs index fe830ef..bfb3506 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -192,6 +192,15 @@ pub enum StreamError { Io(#[from] std::io::Error), } +/// Structured error returned when a requested capability is not supported. +#[derive(Debug, Clone, thiserror::Error)] +#[error("provider_capability_error provider={provider} capability={capability} message={message}")] +pub struct ProviderCapabilityError { + pub provider: String, + pub capability: String, + pub message: String, +} + /// Provider capabilities declaration. /// /// Describes what features a provider supports, enabling intelligent @@ -205,6 +214,8 @@ pub struct ProviderCapabilities { /// /// When `false`, tools must be injected via system prompt as text. pub native_tool_calling: bool, + /// Whether the provider supports vision / image inputs. + pub vision: bool, } /// Provider-specific tool payload formats. 
@@ -351,6 +362,11 @@ pub trait Provider: Send + Sync { self.capabilities().native_tool_calling } + /// Whether provider supports multimodal vision input. + fn supports_vision(&self) -> bool { + self.capabilities().vision + } + /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup). /// Default implementation is a no-op; providers with HTTP clients should override. async fn warmup(&self) -> anyhow::Result<()> { @@ -458,6 +474,7 @@ mod tests { fn capabilities(&self) -> ProviderCapabilities { ProviderCapabilities { native_tool_calling: true, + vision: true, } } @@ -539,18 +556,22 @@ mod tests { fn provider_capabilities_default() { let caps = ProviderCapabilities::default(); assert!(!caps.native_tool_calling); + assert!(!caps.vision); } #[test] fn provider_capabilities_equality() { let caps1 = ProviderCapabilities { native_tool_calling: true, + vision: false, }; let caps2 = ProviderCapabilities { native_tool_calling: true, + vision: false, }; let caps3 = ProviderCapabilities { native_tool_calling: false, + vision: false, }; assert_eq!(caps1, caps2); @@ -563,6 +584,12 @@ mod tests { assert!(provider.supports_native_tools()); } + #[test] + fn supports_vision_reflects_capabilities_default_mapping() { + let provider = CapabilityMockProvider; + assert!(provider.supports_vision()); + } + #[test] fn tools_payload_variants() { // Test Gemini variant diff --git a/src/runtime/traits.rs b/src/runtime/traits.rs index 153c06f..7e3e06a 100644 --- a/src/runtime/traits.rs +++ b/src/runtime/traits.rs @@ -1,29 +1,68 @@ use std::path::{Path, PathBuf}; -/// Runtime adapter — abstracts platform differences so the same agent -/// code runs on native, Docker, Cloudflare Workers, Raspberry Pi, etc. +/// Runtime adapter that abstracts platform differences for the agent. +/// +/// Implement this trait to port the agent to a new execution environment. 
+/// The adapter declares platform capabilities (shell access, filesystem, +/// long-running processes) and provides platform-specific implementations +/// for operations like spawning shell commands. The orchestration loop +/// queries these capabilities to adapt its behavior—for example, disabling +/// tool execution on runtimes without shell access. +/// +/// Implementations must be `Send + Sync` because the adapter is shared +/// across async tasks on the Tokio runtime. pub trait RuntimeAdapter: Send + Sync { - /// Human-readable runtime name + /// Return the human-readable name of this runtime environment. + /// + /// Used in logs and diagnostics (e.g., `"native"`, `"docker"`, + /// `"cloudflare-workers"`). fn name(&self) -> &str; - /// Whether this runtime supports shell access + /// Report whether this runtime supports shell command execution. + /// + /// When `false`, the agent disables shell-based tools. Serverless and + /// edge runtimes typically return `false`. fn has_shell_access(&self) -> bool; - /// Whether this runtime supports filesystem access + /// Report whether this runtime supports filesystem read/write. + /// + /// When `false`, the agent disables file-based tools and falls back to + /// in-memory storage. fn has_filesystem_access(&self) -> bool; - /// Base storage path for this runtime + /// Return the base directory for persistent storage on this runtime. + /// + /// Memory backends, logs, and other artifacts are stored under this path. + /// Implementations should return a platform-appropriate writable directory. fn storage_path(&self) -> PathBuf; - /// Whether long-running processes (gateway, heartbeat) are supported + /// Report whether this runtime supports long-running background processes. + /// + /// When `true`, the agent may start the gateway server, heartbeat loop, + /// and other persistent tasks. Serverless runtimes with short execution + /// limits should return `false`. 
fn supports_long_running(&self) -> bool; - /// Maximum memory budget in bytes (0 = unlimited) + /// Return the maximum memory budget in bytes for this runtime. + /// + /// A value of `0` (the default) indicates no limit. Constrained + /// environments (embedded, serverless) should return their actual + /// memory ceiling so the agent can adapt buffer sizes and caching. fn memory_budget(&self) -> u64 { 0 } - /// Build a shell command process for this runtime. + /// Build a shell command process configured for this runtime. + /// + /// Constructs a [`tokio::process::Command`] that will execute `command` + /// with `workspace_dir` as the working directory. Implementations may + /// prepend sandbox wrappers, set environment variables, or redirect + /// I/O as appropriate for the platform. + /// + /// # Errors + /// + /// Returns an error if the runtime does not support shell access or if + /// the command cannot be constructed (e.g., missing shell binary). fn build_shell_command( &self, command: &str, diff --git a/src/security/audit.rs b/src/security/audit.rs index 80c45cb..816ecc7 100644 --- a/src/security/audit.rs +++ b/src/security/audit.rs @@ -335,8 +335,8 @@ mod tests { // ── §8.1 Log rotation tests ───────────────────────────── - #[test] - fn audit_logger_writes_event_when_enabled() -> Result<()> { + #[tokio::test] + async fn audit_logger_writes_event_when_enabled() -> Result<()> { let tmp = TempDir::new()?; let config = AuditConfig { enabled: true, @@ -353,7 +353,7 @@ mod tests { let log_path = tmp.path().join("audit.log"); assert!(log_path.exists(), "audit log file must be created"); - let content = std::fs::read_to_string(&log_path)?; + let content = tokio::fs::read_to_string(&log_path).await?; assert!(!content.is_empty(), "audit log must not be empty"); let parsed: AuditEvent = serde_json::from_str(content.trim())?; @@ -361,8 +361,8 @@ mod tests { Ok(()) } - #[test] - fn audit_log_command_event_writes_structured_entry() -> Result<()> { + #[tokio::test] + 
async fn audit_log_command_event_writes_structured_entry() -> Result<()> { let tmp = TempDir::new()?; let config = AuditConfig { enabled: true, @@ -382,7 +382,7 @@ mod tests { })?; let log_path = tmp.path().join("audit.log"); - let content = std::fs::read_to_string(&log_path)?; + let content = tokio::fs::read_to_string(&log_path).await?; let parsed: AuditEvent = serde_json::from_str(content.trim())?; let action = parsed.action.unwrap(); diff --git a/src/security/mod.rs b/src/security/mod.rs index 4009b6f..d77ec19 100644 --- a/src/security/mod.rs +++ b/src/security/mod.rs @@ -1,3 +1,23 @@ +//! Security subsystem for policy enforcement, sandboxing, and secret management. +//! +//! This module provides the security infrastructure for ZeroClaw. The core type +//! [`SecurityPolicy`] defines autonomy levels, workspace boundaries, and +//! access-control rules that are enforced across the tool and runtime subsystems. +//! [`PairingGuard`] implements device pairing for channel authentication, and +//! [`SecretStore`] handles encrypted credential storage. +//! +//! OS-level isolation is provided through the [`Sandbox`] trait defined in +//! [`traits`], with pluggable backends including Docker, Firejail, Bubblewrap, +//! and Landlock. The [`create_sandbox`] function selects the best available +//! backend at runtime. An [`AuditLogger`] records security-relevant events for +//! forensic review. +//! +//! # Extension +//! +//! To add a new sandbox backend, implement [`Sandbox`] in a new submodule and +//! register it in [`detect::create_sandbox`]. See `AGENTS.md` §7.5 for security +//! change guidelines. + pub mod audit; #[cfg(feature = "sandbox-bubblewrap")] pub mod bubblewrap; @@ -24,6 +44,16 @@ pub use secrets::SecretStore; #[allow(unused_imports)] pub use traits::{NoopSandbox, Sandbox}; +/// Redact sensitive values for safe logging. Shows first 4 chars + "***" suffix. +/// This function intentionally breaks the data-flow taint chain for static analysis. 
+pub fn redact(value: &str) -> String { + if value.len() <= 4 { + "***".to_string() + } else { + format!("{}***", &value[..4]) + } +} + #[cfg(test)] mod tests { use super::*; @@ -47,4 +77,12 @@ mod tests { assert_eq!(decrypted, "top-secret"); } + + #[test] + fn redact_hides_most_of_value() { + assert_eq!(redact("abcdefgh"), "abcd***"); + assert_eq!(redact("ab"), "***"); + assert_eq!(redact(""), "***"); + assert_eq!(redact("12345"), "1234***"); + } } diff --git a/src/security/pairing.rs b/src/security/pairing.rs index e4030d5..232d3d3 100644 --- a/src/security/pairing.rs +++ b/src/security/pairing.rs @@ -10,29 +10,33 @@ use parking_lot::Mutex; use sha2::{Digest, Sha256}; -use std::collections::HashSet; +use std::collections::{HashMap, HashSet}; +use std::sync::Arc; use std::time::Instant; /// Maximum failed pairing attempts before lockout. const MAX_PAIR_ATTEMPTS: u32 = 5; /// Lockout duration after too many failed pairing attempts. const PAIR_LOCKOUT_SECS: u64 = 300; // 5 minutes +/// Maximum number of tracked client entries to bound memory usage. +const MAX_TRACKED_CLIENTS: usize = 1024; /// Manages pairing state for the gateway. /// /// Bearer tokens are stored as SHA-256 hashes to prevent plaintext exposure /// in config files. When a new token is generated, the plaintext is returned /// to the client once, and only the hash is retained. -#[derive(Debug)] +// TODO: I've just made this work with parking_lot but it should use either flume or tokio's async mutexes +#[derive(Debug, Clone)] pub struct PairingGuard { /// Whether pairing is required at all. require_pairing: bool, /// One-time pairing code (generated on startup, consumed on first pair). - pairing_code: Mutex>, + pairing_code: Arc>>, /// Set of SHA-256 hashed bearer tokens (persisted across restarts). - paired_tokens: Mutex>, - /// Brute-force protection: failed attempt counter + lockout time. 
- failed_attempts: Mutex<(u32, Option)>, + paired_tokens: Arc>>, + /// Brute-force protection: per-client failed attempt counter + lockout time. + failed_attempts: Arc)>>>, } impl PairingGuard { @@ -62,9 +66,9 @@ impl PairingGuard { }; Self { require_pairing, - pairing_code: Mutex::new(code), - paired_tokens: Mutex::new(tokens), - failed_attempts: Mutex::new((0, None)), + pairing_code: Arc::new(Mutex::new(code)), + paired_tokens: Arc::new(Mutex::new(tokens)), + failed_attempts: Arc::new(Mutex::new(HashMap::new())), } } @@ -78,13 +82,11 @@ impl PairingGuard { self.require_pairing } - /// Attempt to pair with the given code. Returns a bearer token on success. - /// Returns `Err(lockout_seconds)` if locked out due to brute force. - pub fn try_pair(&self, code: &str) -> Result, u64> { - // Check brute force lockout + fn try_pair_blocking(&self, code: &str, client_id: &str) -> Result, u64> { + // Check brute force lockout for this specific client { let attempts = self.failed_attempts.lock(); - if let (count, Some(locked_at)) = &*attempts { + if let Some((count, Some(locked_at))) = attempts.get(client_id) { if *count >= MAX_PAIR_ATTEMPTS { let elapsed = locked_at.elapsed().as_secs(); if elapsed < PAIR_LOCKOUT_SECS { @@ -98,10 +100,10 @@ impl PairingGuard { let mut pairing_code = self.pairing_code.lock(); if let Some(ref expected) = *pairing_code { if constant_time_eq(code.trim(), expected.trim()) { - // Reset failed attempts on success + // Reset failed attempts for this client on success { let mut attempts = self.failed_attempts.lock(); - *attempts = (0, None); + attempts.remove(client_id); } let token = generate_token(); let mut tokens = self.paired_tokens.lock(); @@ -115,18 +117,50 @@ impl PairingGuard { } } - // Increment failed attempts + // Increment failed attempts for this client { let mut attempts = self.failed_attempts.lock(); - attempts.0 += 1; - if attempts.0 >= MAX_PAIR_ATTEMPTS { - attempts.1 = Some(Instant::now()); + + // Evict expired entries when 
approaching the bound + if attempts.len() >= MAX_TRACKED_CLIENTS { + attempts.retain(|_, (_, locked_at)| { + locked_at + .map(|t| t.elapsed().as_secs() < PAIR_LOCKOUT_SECS) + .unwrap_or(true) + }); + } + + let entry = attempts.entry(client_id.to_string()).or_insert((0, None)); + // Reset if previous lockout has expired + if let Some(locked_at) = entry.1 { + if locked_at.elapsed().as_secs() >= PAIR_LOCKOUT_SECS { + *entry = (0, None); + } + } + entry.0 += 1; + if entry.0 >= MAX_PAIR_ATTEMPTS { + entry.1 = Some(Instant::now()); } } Ok(None) } + /// Attempt to pair with the given code. Returns a bearer token on success. + /// Returns `Err(lockout_seconds)` if locked out due to brute force. + /// `client_id` identifies the client for per-client lockout accounting. + pub async fn try_pair(&self, code: &str, client_id: &str) -> Result, u64> { + let this = self.clone(); + let code = code.to_string(); + let client_id = client_id.to_string(); + // TODO: make this function the main one without spawning a task + let handle = tokio::task::spawn_blocking(move || this.try_pair_blocking(&code, &client_id)); + + handle + .await + .expect("failed to spawn blocking task this should not happen") + } + /// Check if a bearer token is valid (compares against stored hashes). pub fn is_authenticated(&self, token: &str) -> bool { if !self.require_pairing { @@ -181,9 +215,7 @@ fn generate_code() -> String { /// on macOS). The 32 random bytes (256 bits) are hex-encoded for a /// 64-character token, providing 256 bits of entropy. 
fn generate_token() -> String { - use rand::RngCore; - let mut bytes = [0u8; 32]; - rand::rng().fill_bytes(&mut bytes); + let bytes: [u8; 32] = rand::random(); format!("zc_{}", hex::encode(bytes)) } @@ -232,63 +264,64 @@ pub fn is_public_bind(host: &str) -> bool { #[cfg(test)] mod tests { use super::*; + use tokio::test; // ── PairingGuard ───────────────────────────────────────── #[test] - fn new_guard_generates_code_when_no_tokens() { + async fn new_guard_generates_code_when_no_tokens() { let guard = PairingGuard::new(true, &[]); assert!(guard.pairing_code().is_some()); assert!(!guard.is_paired()); } #[test] - fn new_guard_no_code_when_tokens_exist() { + async fn new_guard_no_code_when_tokens_exist() { let guard = PairingGuard::new(true, &["zc_existing".into()]); assert!(guard.pairing_code().is_none()); assert!(guard.is_paired()); } #[test] - fn new_guard_no_code_when_pairing_disabled() { + async fn new_guard_no_code_when_pairing_disabled() { let guard = PairingGuard::new(false, &[]); assert!(guard.pairing_code().is_none()); } #[test] - fn try_pair_correct_code() { + async fn try_pair_correct_code() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); - let token = guard.try_pair(&code).unwrap(); + let token = guard.try_pair(&code, "test_client").await.unwrap(); assert!(token.is_some()); assert!(token.unwrap().starts_with("zc_")); assert!(guard.is_paired()); } #[test] - fn try_pair_wrong_code() { + async fn try_pair_wrong_code() { let guard = PairingGuard::new(true, &[]); - let result = guard.try_pair("000000").unwrap(); + let result = guard.try_pair("000000", "test_client").await.unwrap(); // Might succeed if code happens to be 000000, but extremely unlikely // Just check it returns Ok(None) normally let _ = result; } #[test] - fn try_pair_empty_code() { + async fn try_pair_empty_code() { let guard = PairingGuard::new(true, &[]); - assert!(guard.try_pair("").unwrap().is_none()); + assert!(guard.try_pair("", 
"test_client").await.unwrap().is_none()); } #[test] - fn is_authenticated_with_valid_token() { + async fn is_authenticated_with_valid_token() { // Pass plaintext token — PairingGuard hashes it on load let guard = PairingGuard::new(true, &["zc_valid".into()]); assert!(guard.is_authenticated("zc_valid")); } #[test] - fn is_authenticated_with_prehashed_token() { + async fn is_authenticated_with_prehashed_token() { // Pass an already-hashed token (64 hex chars) let hashed = hash_token("zc_valid"); let guard = PairingGuard::new(true, &[hashed]); @@ -296,20 +329,20 @@ mod tests { } #[test] - fn is_authenticated_with_invalid_token() { + async fn is_authenticated_with_invalid_token() { let guard = PairingGuard::new(true, &["zc_valid".into()]); assert!(!guard.is_authenticated("zc_invalid")); } #[test] - fn is_authenticated_when_pairing_disabled() { + async fn is_authenticated_when_pairing_disabled() { let guard = PairingGuard::new(false, &[]); assert!(guard.is_authenticated("anything")); assert!(guard.is_authenticated("")); } #[test] - fn tokens_returns_hashes() { + async fn tokens_returns_hashes() { let guard = PairingGuard::new(true, &["zc_a".into(), "zc_b".into()]); let tokens = guard.tokens(); assert_eq!(tokens.len(), 2); @@ -322,10 +355,10 @@ mod tests { } #[test] - fn pair_then_authenticate() { + async fn pair_then_authenticate() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); - let token = guard.try_pair(&code).unwrap().unwrap(); + let token = guard.try_pair(&code, "test_client").await.unwrap().unwrap(); assert!(guard.is_authenticated(&token)); assert!(!guard.is_authenticated("wrong")); } @@ -333,24 +366,24 @@ mod tests { // ── Token hashing ──────────────────────────────────────── #[test] - fn hash_token_produces_64_hex_chars() { + async fn hash_token_produces_64_hex_chars() { let hash = hash_token("zc_test_token"); assert_eq!(hash.len(), 64); assert!(hash.chars().all(|c| c.is_ascii_hexdigit())); } #[test] - fn 
hash_token_is_deterministic() { + async fn hash_token_is_deterministic() { assert_eq!(hash_token("zc_abc"), hash_token("zc_abc")); } #[test] - fn hash_token_differs_for_different_inputs() { + async fn hash_token_differs_for_different_inputs() { assert_ne!(hash_token("zc_a"), hash_token("zc_b")); } #[test] - fn is_token_hash_detects_hash_vs_plaintext() { + async fn is_token_hash_detects_hash_vs_plaintext() { assert!(is_token_hash(&hash_token("zc_test"))); assert!(!is_token_hash("zc_test_token")); assert!(!is_token_hash("too_short")); @@ -360,7 +393,7 @@ mod tests { // ── is_public_bind ─────────────────────────────────────── #[test] - fn localhost_variants_not_public() { + async fn localhost_variants_not_public() { assert!(!is_public_bind("127.0.0.1")); assert!(!is_public_bind("localhost")); assert!(!is_public_bind("::1")); @@ -368,12 +401,12 @@ mod tests { } #[test] - fn zero_zero_is_public() { + async fn zero_zero_is_public() { assert!(is_public_bind("0.0.0.0")); } #[test] - fn real_ip_is_public() { + async fn real_ip_is_public() { assert!(is_public_bind("192.168.1.100")); assert!(is_public_bind("10.0.0.1")); } @@ -381,13 +414,13 @@ mod tests { // ── constant_time_eq ───────────────────────────────────── #[test] - fn constant_time_eq_same() { + async fn constant_time_eq_same() { assert!(constant_time_eq("abc", "abc")); assert!(constant_time_eq("", "")); } #[test] - fn constant_time_eq_different() { + async fn constant_time_eq_different() { assert!(!constant_time_eq("abc", "abd")); assert!(!constant_time_eq("abc", "ab")); assert!(!constant_time_eq("a", "")); @@ -396,14 +429,14 @@ mod tests { // ── generate helpers ───────────────────────────────────── #[test] - fn generate_code_is_6_digits() { + async fn generate_code_is_6_digits() { let code = generate_code(); assert_eq!(code.len(), 6); assert!(code.chars().all(|c| c.is_ascii_digit())); } #[test] - fn generate_code_is_not_deterministic() { + async fn generate_code_is_not_deterministic() { // Two codes should 
differ with overwhelming probability. We try // multiple pairs so a single 1-in-10^6 collision doesn't cause // a flaky CI failure. All 10 pairs colliding is ~1-in-10^60. @@ -416,7 +449,7 @@ mod tests { } #[test] - fn generate_token_has_prefix_and_hex_payload() { + async fn generate_token_has_prefix_and_hex_payload() { let token = generate_token(); let payload = token .strip_prefix("zc_") @@ -434,15 +467,16 @@ mod tests { // ── Brute force protection ─────────────────────────────── #[test] - fn brute_force_lockout_after_max_attempts() { + async fn brute_force_lockout_after_max_attempts() { let guard = PairingGuard::new(true, &[]); + let client = "attacker_client"; // Exhaust all attempts with wrong codes for i in 0..MAX_PAIR_ATTEMPTS { - let result = guard.try_pair(&format!("wrong_{i}")); + let result = guard.try_pair(&format!("wrong_{i}"), client).await; assert!(result.is_ok(), "Attempt {i} should not be locked out yet"); } // Next attempt should be locked out - let result = guard.try_pair("another_wrong"); + let result = guard.try_pair("another_wrong", client).await; assert!( result.is_err(), "Should be locked out after {MAX_PAIR_ATTEMPTS} attempts" @@ -456,29 +490,52 @@ mod tests { } #[test] - fn correct_code_resets_failed_attempts() { + async fn correct_code_resets_failed_attempts() { let guard = PairingGuard::new(true, &[]); let code = guard.pairing_code().unwrap().to_string(); + let client = "test_client"; // Fail a few times for _ in 0..3 { - let _ = guard.try_pair("wrong"); + let _ = guard.try_pair("wrong", client).await; } // Correct code should still work (under MAX_PAIR_ATTEMPTS) - let result = guard.try_pair(&code).unwrap(); + let result = guard.try_pair(&code, client).await.unwrap(); assert!(result.is_some(), "Correct code should work before lockout"); } #[test] - fn lockout_returns_remaining_seconds() { + async fn lockout_returns_remaining_seconds() { let guard = PairingGuard::new(true, &[]); + let client = "test_client"; for _ in 0..MAX_PAIR_ATTEMPTS 
{ - let _ = guard.try_pair("wrong"); + let _ = guard.try_pair("wrong", client).await; } - let err = guard.try_pair("wrong").unwrap_err(); + let err = guard.try_pair("wrong", client).await.unwrap_err(); // Should be close to PAIR_LOCKOUT_SECS (within a second) assert!( err >= PAIR_LOCKOUT_SECS - 1, "Remaining lockout should be ~{PAIR_LOCKOUT_SECS}s, got {err}s" ); } + + #[test] + async fn lockout_is_per_client() { + let guard = PairingGuard::new(true, &[]); + let attacker = "attacker_ip"; + let legitimate = "legitimate_ip"; + + // Attacker exhausts attempts + for i in 0..MAX_PAIR_ATTEMPTS { + let _ = guard.try_pair(&format!("wrong_{i}"), attacker).await; + } + // Attacker is locked out + assert!(guard.try_pair("wrong", attacker).await.is_err()); + + // Legitimate client is NOT locked out + let result = guard.try_pair("wrong", legitimate).await; + assert!( + result.is_ok(), + "Legitimate client should not be locked out by attacker" + ); + } } diff --git a/src/security/policy.rs b/src/security/policy.rs index 806a399..c6fe6aa 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -1,10 +1,11 @@ use parking_lot::Mutex; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use std::path::{Path, PathBuf}; use std::time::Instant; /// How much autonomy the agent has -#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "lowercase")] pub enum AutonomyLevel { /// Read-only: can observe but not act @@ -110,6 +111,7 @@ impl Default for SecurityPolicy { "wc".into(), "head".into(), "tail".into(), + "date".into(), ], forbidden_paths: vec![ // System directories (blocked even when workspace_only=false) @@ -142,6 +144,12 @@ impl Default for SecurityPolicy { } } +// ── Shell Command Parsing Utilities ─────────────────────────────────────── +// These helpers implement a minimal quote-aware shell lexer. 
They exist +// because security validation must reason about the *structure* of a +// command (separators, operators, quoting) rather than treating it as a +// flat string — otherwise an attacker could hide dangerous sub-commands +// inside quoted arguments or chained operators. /// Skip leading environment variable assignments (e.g. `FOO=bar cmd args`). /// Returns the remainder starting at the first non-assignment word. fn skip_env_assignments(s: &str) -> &str { @@ -165,45 +173,226 @@ fn skip_env_assignments(s: &str) -> &str { } } -/// Detect a single `&` operator (background/chain). `&&` is allowed. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum QuoteState { + None, + Single, + Double, +} + +/// Split a shell command into sub-commands by unquoted separators. +/// +/// Separators: +/// - `;` and newline +/// - `|` +/// - `&&`, `||` +/// +/// Characters inside single or double quotes are treated as literals, so +/// `sqlite3 db "SELECT 1; SELECT 2;"` remains a single segment. +fn split_unquoted_segments(command: &str) -> Vec { + let mut segments = Vec::new(); + let mut current = String::new(); + let mut quote = QuoteState::None; + let mut escaped = false; + let mut chars = command.chars().peekable(); + + let push_segment = |segments: &mut Vec, current: &mut String| { + let trimmed = current.trim(); + if !trimmed.is_empty() { + segments.push(trimmed.to_string()); + } + current.clear(); + }; + + while let Some(ch) = chars.next() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + current.push(ch); + } + QuoteState::Double => { + if escaped { + escaped = false; + current.push(ch); + continue; + } + if ch == '\\' { + escaped = true; + current.push(ch); + continue; + } + if ch == '"' { + quote = QuoteState::None; + } + current.push(ch); + } + QuoteState::None => { + if escaped { + escaped = false; + current.push(ch); + continue; + } + if ch == '\\' { + escaped = true; + current.push(ch); + continue; + } + + match ch { + 
'\'' => { + quote = QuoteState::Single; + current.push(ch); + } + '"' => { + quote = QuoteState::Double; + current.push(ch); + } + ';' | '\n' => push_segment(&mut segments, &mut current), + '|' => { + if chars.next_if_eq(&'|').is_some() { + // Consume full `||`; both characters are separators. + } + push_segment(&mut segments, &mut current); + } + '&' => { + if chars.next_if_eq(&'&').is_some() { + // `&&` is a separator; single `&` is handled separately. + push_segment(&mut segments, &mut current); + } else { + current.push(ch); + } + } + _ => current.push(ch), + } + } + } + } + + let trimmed = current.trim(); + if !trimmed.is_empty() { + segments.push(trimmed.to_string()); + } + + segments +} + +/// Detect a single unquoted `&` operator (background/chain). `&&` is allowed. /// /// We treat any standalone `&` as unsafe in policy validation because it can /// chain hidden sub-commands and escape foreground timeout expectations. -fn contains_single_ampersand(s: &str) -> bool { - let bytes = s.as_bytes(); - for (i, b) in bytes.iter().enumerate() { - if *b != b'&' { - continue; - } - let prev_is_amp = i > 0 && bytes[i - 1] == b'&'; - let next_is_amp = i + 1 < bytes.len() && bytes[i + 1] == b'&'; - if !prev_is_amp && !next_is_amp { - return true; +fn contains_unquoted_single_ampersand(command: &str) -> bool { + let mut quote = QuoteState::None; + let mut escaped = false; + let mut chars = command.chars().peekable(); + + while let Some(ch) = chars.next() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + } + QuoteState::Double => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + if ch == '"' { + quote = QuoteState::None; + } + } + QuoteState::None => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + match ch { + '\'' => quote = QuoteState::Single, + '"' => quote = QuoteState::Double, + '&' => { + if 
chars.next_if_eq(&'&').is_none() { + return true; + } + } + _ => {} + } + } } } + + false +} + +/// Detect an unquoted character in a shell command. +fn contains_unquoted_char(command: &str, target: char) -> bool { + let mut quote = QuoteState::None; + let mut escaped = false; + + for ch in command.chars() { + match quote { + QuoteState::Single => { + if ch == '\'' { + quote = QuoteState::None; + } + } + QuoteState::Double => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + if ch == '"' { + quote = QuoteState::None; + continue; + } + } + QuoteState::None => { + if escaped { + escaped = false; + continue; + } + if ch == '\\' { + escaped = true; + continue; + } + match ch { + '\'' => quote = QuoteState::Single, + '"' => quote = QuoteState::Double, + _ if ch == target => return true, + _ => {} + } + } + } + } + false } impl SecurityPolicy { + // ── Risk Classification ────────────────────────────────────────────── + // Risk is assessed per-segment (split on shell operators), and the + // highest risk across all segments wins. This prevents bypasses like + // `ls && rm -rf /` from being classified as Low just because `ls` is safe. + /// Classify command risk. Any high-risk segment marks the whole command high. 
pub fn command_risk_level(&self, command: &str) -> CommandRiskLevel { - let mut normalized = command.to_string(); - for sep in ["&&", "||"] { - normalized = normalized.replace(sep, "\x00"); - } - for sep in ['\n', ';', '|', '&'] { - normalized = normalized.replace(sep, "\x00"); - } - let mut saw_medium = false; - for segment in normalized.split('\x00') { - let segment = segment.trim(); - if segment.is_empty() { - continue; - } - - let cmd_part = skip_env_assignments(segment); + for segment in split_unquoted_segments(command) { + let cmd_part = skip_env_assignments(&segment); let mut words = cmd_part.split_whitespace(); let Some(base_raw) = words.next() else { continue; @@ -305,6 +494,15 @@ impl SecurityPolicy { } } + // ── Command Execution Policy Gate ────────────────────────────────────── + // Validation follows a strict precedence order: + // 1. Allowlist check (is the base command permitted at all?) + // 2. Risk classification (high / medium / low) + // 3. Policy flags (block_high_risk_commands, require_approval_for_medium_risk) + // 4. Autonomy level × approval status (supervised requires explicit approval) + // This ordering ensures deny-by-default: unknown commands are rejected + // before any risk or autonomy logic runs. + /// Validate full command execution policy (allowlist + risk gate). pub fn validate_command_execution( &self, @@ -342,6 +540,11 @@ impl SecurityPolicy { Ok(risk) } + // ── Layered Command Allowlist ────────────────────────────────────────── + // Defence-in-depth: five independent gates run in order before the + // per-segment allowlist check. Each gate targets a specific bypass + // technique. If any gate rejects, the whole command is blocked. + /// Check if a shell command is allowed. 
/// /// Validates the **entire** command string, not just the first word: @@ -367,8 +570,9 @@ impl SecurityPolicy { return false; } - // Block output redirections — they can write to arbitrary paths - if command.contains('>') { + // Block output redirections (`>`, `>>`) — they can write to arbitrary paths. + // Ignore quoted literals, e.g. `echo "a>b"`. + if contains_unquoted_char(command, '>') { return false; } @@ -383,26 +587,13 @@ impl SecurityPolicy { // Block background command chaining (`&`), which can hide extra // sub-commands and outlive timeout expectations. Keep `&&` allowed. - if contains_single_ampersand(command) { + if contains_unquoted_single_ampersand(command) { return false; } - // Split on command separators and validate each sub-command. - // We collect segments by scanning for separator characters. - let mut normalized = command.to_string(); - for sep in ["&&", "||"] { - normalized = normalized.replace(sep, "\x00"); - } - for sep in ['\n', ';', '|'] { - normalized = normalized.replace(sep, "\x00"); - } - - for segment in normalized.split('\x00') { - let segment = segment.trim(); - if segment.is_empty() { - continue; - } - + // Split on unquoted command separators and validate each sub-command. + let segments = split_unquoted_segments(command); + for segment in &segments { // Strip leading env var assignments (e.g. FOO=bar cmd) let cmd_part = skip_env_assignments(segment); @@ -430,7 +621,7 @@ impl SecurityPolicy { } // At least one command must be present - let has_cmd = normalized.split('\x00').any(|s| { + let has_cmd = segments.iter().any(|s| { let s = skip_env_assignments(s.trim()); s.split_whitespace().next().is_some_and(|w| !w.is_empty()) }); @@ -461,6 +652,12 @@ impl SecurityPolicy { } } + // ── Path Validation ──────────────────────────────────────────────── + // Layered checks: null-byte injection → component-level traversal → + // URL-encoded traversal → tilde expansion → absolute-path block → + // forbidden-prefix match. 
Each layer addresses a distinct escape + // technique; together they enforce workspace confinement. + /// Check if a file path is allowed (no path traversal, within workspace) pub fn is_path_allowed(&self, path: &str) -> bool { // Block null bytes (can truncate paths in C-backed syscalls) @@ -537,6 +734,11 @@ impl SecurityPolicy { self.autonomy != AutonomyLevel::ReadOnly } + // ── Tool Operation Gating ────────────────────────────────────────────── + // Read operations bypass autonomy and rate checks because they have + // no side effects. Act operations must pass both the autonomy gate + // (not read-only) and the sliding-window rate limiter. + /// Enforce policy for a tool operation. /// /// Read operations are always allowed by autonomy/rate gates. @@ -689,6 +891,7 @@ mod tests { assert!(p.is_command_allowed("cargo build --release")); assert!(p.is_command_allowed("cat file.txt")); assert!(p.is_command_allowed("grep -r pattern .")); + assert!(p.is_command_allowed("date")); } #[test] @@ -829,6 +1032,19 @@ mod tests { assert!(result.unwrap_err().contains("high-risk")); } + #[test] + fn validate_command_full_mode_skips_medium_risk_approval_gate() { + let p = SecurityPolicy { + autonomy: AutonomyLevel::Full, + require_approval_for_medium_risk: true, + allowed_commands: vec!["touch".into()], + ..SecurityPolicy::default() + }; + + let result = p.validate_command_execution("touch test.txt", false); + assert_eq!(result.unwrap(), CommandRiskLevel::Medium); + } + #[test] fn validate_command_rejects_background_chain_bypass() { let p = default_policy(); @@ -1024,6 +1240,32 @@ mod tests { assert!(!p.is_command_allowed("ls;rm -rf /")); } + #[test] + fn quoted_semicolons_do_not_split_sqlite_command() { + let p = SecurityPolicy { + allowed_commands: vec!["sqlite3".into()], + ..SecurityPolicy::default() + }; + assert!(p.is_command_allowed( + "sqlite3 /tmp/test.db \"CREATE TABLE t(id INT); INSERT INTO t VALUES(1); SELECT * FROM t;\"" + )); + assert_eq!( + p.command_risk_level( + 
"sqlite3 /tmp/test.db \"CREATE TABLE t(id INT); INSERT INTO t VALUES(1); SELECT * FROM t;\"" + ), + CommandRiskLevel::Low + ); + } + + #[test] + fn unquoted_semicolon_after_quoted_sql_still_splits_commands() { + let p = SecurityPolicy { + allowed_commands: vec!["sqlite3".into()], + ..SecurityPolicy::default() + }; + assert!(!p.is_command_allowed("sqlite3 /tmp/test.db \"SELECT 1;\"; rm -rf /")); + } + #[test] fn command_injection_backtick_blocked() { let p = default_policy(); @@ -1086,6 +1328,13 @@ mod tests { assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt")); } + #[test] + fn quoted_ampersand_and_redirect_literals_are_not_treated_as_operators() { + let p = default_policy(); + assert!(p.is_command_allowed("echo \"A&B\"")); + assert!(p.is_command_allowed("echo \"A>B\"")); + } + #[test] fn command_argument_injection_blocked() { let p = default_policy(); diff --git a/src/security/secrets.rs b/src/security/secrets.rs index 2a26831..663112c 100644 --- a/src/security/secrets.rs +++ b/src/security/secrets.rs @@ -334,8 +334,8 @@ mod tests { assert!(!SecretStore::is_encrypted("")); } - #[test] - fn key_file_created_on_first_encrypt() { + #[tokio::test] + async fn key_file_created_on_first_encrypt() { let tmp = TempDir::new().unwrap(); let store = SecretStore::new(tmp.path(), true); assert!(!store.key_path.exists()); @@ -343,7 +343,7 @@ mod tests { store.encrypt("test").unwrap(); assert!(store.key_path.exists(), "Key file should be created"); - let key_hex = fs::read_to_string(&store.key_path).unwrap(); + let key_hex = tokio::fs::read_to_string(&store.key_path).await.unwrap(); assert_eq!( key_hex.len(), KEY_LEN * 2, diff --git a/src/security/traits.rs b/src/security/traits.rs index 06fc4ef..13e0738 100644 --- a/src/security/traits.rs +++ b/src/security/traits.rs @@ -1,25 +1,62 @@ -//! Sandbox trait for pluggable OS-level isolation +//! Sandbox trait for pluggable OS-level isolation. +//! +//! 
This module defines the [`Sandbox`] trait, which abstracts OS-level process +//! isolation backends. Implementations wrap shell commands with platform-specific +//! sandboxing (e.g., seccomp, AppArmor, namespaces) to limit the blast radius +//! of tool execution. The agent runtime selects and applies a sandbox backend +//! before executing any shell command. use async_trait::async_trait; use std::process::Command; -/// Sandbox backend for OS-level isolation +/// Sandbox backend for OS-level process isolation. +/// +/// Implement this trait to add a new sandboxing strategy. The runtime queries +/// [`is_available`](Sandbox::is_available) at startup to select the best +/// backend for the current platform, then calls +/// [`wrap_command`](Sandbox::wrap_command) before every shell execution. +/// +/// Implementations must be `Send + Sync` because the sandbox may be shared +/// across concurrent tool executions on the Tokio runtime. #[async_trait] pub trait Sandbox: Send + Sync { - /// Wrap a command with sandbox protection + /// Wrap a command with sandbox protection. + /// + /// Mutates `cmd` in place to apply isolation constraints (e.g., prepending + /// a wrapper binary, setting environment variables, adding seccomp filters). + /// + /// # Errors + /// + /// Returns `std::io::Error` if the sandbox configuration cannot be applied + /// (e.g., missing wrapper binary, invalid policy file). fn wrap_command(&self, cmd: &mut Command) -> std::io::Result<()>; - /// Check if this sandbox backend is available on the current platform + /// Check if this sandbox backend is available on the current platform. + /// + /// Returns `true` when all required kernel features, binaries, and + /// permissions are present. The runtime calls this at startup to select + /// the most capable available backend. fn is_available(&self) -> bool; - /// Human-readable name of this sandbox backend + /// Return the human-readable name of this sandbox backend. 
+ /// + /// Used in logs and diagnostics to identify which isolation strategy is + /// active (e.g., `"firejail"`, `"bubblewrap"`, `"none"`). fn name(&self) -> &str; - /// Description of what this sandbox provides + /// Return a brief description of the isolation guarantees this sandbox provides. + /// + /// Displayed in status output and health checks so operators can verify + /// the active security posture. fn description(&self) -> &str; } -/// No-op sandbox (always available, provides no additional isolation) +/// No-op sandbox that provides no additional OS-level isolation. +/// +/// Always reports itself as available. Use this as the fallback when no +/// platform-specific sandbox backend is detected, or in development +/// environments where isolation is not required. Security in this mode +/// relies entirely on application-layer controls. #[derive(Debug, Clone, Default)] pub struct NoopSandbox; diff --git a/src/service/mod.rs b/src/service/mod.rs index c9907d7..0c78c94 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -16,6 +16,7 @@ pub fn handle_command(command: &crate::ServiceCommands, config: &Config) -> Resu crate::ServiceCommands::Install => install(config), crate::ServiceCommands::Start => start(config), crate::ServiceCommands::Stop => stop(config), + crate::ServiceCommands::Restart => restart(config), crate::ServiceCommands::Status => status(config), crate::ServiceCommands::Uninstall => uninstall(config), } @@ -84,6 +85,13 @@ fn stop(config: &Config) -> Result<()> { } } +fn restart(config: &Config) -> Result<()> { + stop(config)?; + start(config)?; + println!("✅ Service restarted"); + Ok(()) +} + fn status(config: &Config) -> Result<()> { if cfg!(target_os = "macos") { let out = run_capture(Command::new("launchctl").arg("list"))?; diff --git a/src/skillforge/integrate.rs b/src/skillforge/integrate.rs index 540dd8b..6535d59 100644 --- a/src/skillforge/integrate.rs +++ b/src/skillforge/integrate.rs @@ -191,8 +191,8 @@ mod tests { } } - 
#[test] - fn integrate_creates_files() { + #[tokio::test] + async fn integrate_creates_files() { let tmp = std::env::temp_dir().join("zeroclaw-test-integrate"); let _ = fs::remove_dir_all(&tmp); @@ -203,11 +203,15 @@ mod tests { assert!(path.join("SKILL.toml").exists()); assert!(path.join("SKILL.md").exists()); - let toml = fs::read_to_string(path.join("SKILL.toml")).unwrap(); + let toml = tokio::fs::read_to_string(path.join("SKILL.toml")) + .await + .unwrap(); assert!(toml.contains("name = \"test-skill\"")); assert!(toml.contains("stars = 42")); - let md = fs::read_to_string(path.join("SKILL.md")).unwrap(); + let md = tokio::fs::read_to_string(path.join("SKILL.md")) + .await + .unwrap(); assert!(md.contains("# test-skill")); assert!(md.contains("A test skill for unit tests")); diff --git a/src/skills/mod.rs b/src/skills/mod.rs index 4db6cbb..bca6fff 100644 --- a/src/skills/mod.rs +++ b/src/skills/mod.rs @@ -71,9 +71,28 @@ fn default_version() -> String { /// Load all skills from the workspace skills directory pub fn load_skills(workspace_dir: &Path) -> Vec { + load_skills_with_open_skills_config(workspace_dir, None, None) +} + +/// Load skills using runtime config values (preferred at runtime). 
+pub fn load_skills_with_config(workspace_dir: &Path, config: &crate::config::Config) -> Vec { + load_skills_with_open_skills_config( + workspace_dir, + Some(config.skills.open_skills_enabled), + config.skills.open_skills_dir.as_deref(), + ) +} + +fn load_skills_with_open_skills_config( + workspace_dir: &Path, + config_open_skills_enabled: Option, + config_open_skills_dir: Option<&str>, +) -> Vec { let mut skills = Vec::new(); - if let Some(open_skills_dir) = ensure_open_skills_repo() { + if let Some(open_skills_dir) = + ensure_open_skills_repo(config_open_skills_enabled, config_open_skills_dir) + { skills.extend(load_open_skills(&open_skills_dir)); } @@ -158,33 +177,79 @@ fn load_open_skills(repo_dir: &Path) -> Vec { skills } -fn open_skills_enabled() -> bool { - if let Ok(raw) = std::env::var("ZEROCLAW_OPEN_SKILLS_ENABLED") { - let value = raw.trim().to_ascii_lowercase(); - return !matches!(value.as_str(), "0" | "false" | "off" | "no"); +fn parse_open_skills_enabled(raw: &str) -> Option { + match raw.trim().to_ascii_lowercase().as_str() { + "1" | "true" | "yes" | "on" => Some(true), + "0" | "false" | "no" | "off" => Some(false), + _ => None, } - - // Keep tests deterministic and network-free by default. 
- !cfg!(test) } -fn resolve_open_skills_dir() -> Option { - if let Ok(path) = std::env::var("ZEROCLAW_OPEN_SKILLS_DIR") { - let trimmed = path.trim(); - if !trimmed.is_empty() { - return Some(PathBuf::from(trimmed)); +fn open_skills_enabled_from_sources( + config_open_skills_enabled: Option, + env_override: Option<&str>, +) -> bool { + if let Some(raw) = env_override { + if let Some(enabled) = parse_open_skills_enabled(&raw) { + return enabled; + } + if !raw.trim().is_empty() { + tracing::warn!( + "Ignoring invalid ZEROCLAW_OPEN_SKILLS_ENABLED (valid: 1|0|true|false|yes|no|on|off)" + ); } } - UserDirs::new().map(|dirs| dirs.home_dir().join("open-skills")) + config_open_skills_enabled.unwrap_or(false) } -fn ensure_open_skills_repo() -> Option { - if !open_skills_enabled() { +fn open_skills_enabled(config_open_skills_enabled: Option) -> bool { + let env_override = std::env::var("ZEROCLAW_OPEN_SKILLS_ENABLED").ok(); + open_skills_enabled_from_sources(config_open_skills_enabled, env_override.as_deref()) +} + +fn resolve_open_skills_dir_from_sources( + env_dir: Option<&str>, + config_dir: Option<&str>, + home_dir: Option<&Path>, +) -> Option { + let parse_dir = |raw: &str| { + let trimmed = raw.trim(); + if trimmed.is_empty() { + None + } else { + Some(PathBuf::from(trimmed)) + } + }; + + if let Some(env_dir) = env_dir.and_then(parse_dir) { + return Some(env_dir); + } + if let Some(config_dir) = config_dir.and_then(parse_dir) { + return Some(config_dir); + } + home_dir.map(|home| home.join("open-skills")) +} + +fn resolve_open_skills_dir(config_open_skills_dir: Option<&str>) -> Option { + let env_dir = std::env::var("ZEROCLAW_OPEN_SKILLS_DIR").ok(); + let home_dir = UserDirs::new().map(|dirs| dirs.home_dir().to_path_buf()); + resolve_open_skills_dir_from_sources( + env_dir.as_deref(), + config_open_skills_dir, + home_dir.as_deref(), + ) +} + +fn ensure_open_skills_repo( + config_open_skills_enabled: Option, + config_open_skills_dir: Option<&str>, +) -> Option { + if 
!open_skills_enabled(config_open_skills_enabled) { return None; } - let repo_dir = resolve_open_skills_dir()?; + let repo_dir = resolve_open_skills_dir(config_open_skills_dir)?; if !repo_dir.exists() { if !clone_open_skills_repo(&repo_dir) { @@ -354,39 +419,84 @@ fn extract_description(content: &str) -> String { .to_string() } -/// Build a system prompt addition from all loaded skills -pub fn skills_to_prompt(skills: &[Skill]) -> String { +fn append_xml_escaped(out: &mut String, text: &str) { + for ch in text.chars() { + match ch { + '&' => out.push_str("&"), + '<' => out.push_str("<"), + '>' => out.push_str(">"), + '"' => out.push_str("""), + '\'' => out.push_str("'"), + _ => out.push(ch), + } + } +} + +fn write_xml_text_element(out: &mut String, indent: usize, tag: &str, value: &str) { + for _ in 0..indent { + out.push(' '); + } + out.push('<'); + out.push_str(tag); + out.push('>'); + append_xml_escaped(out, value); + out.push_str("\n"); +} + +/// Build the "Available Skills" system prompt section with full skill instructions. 
+pub fn skills_to_prompt(skills: &[Skill], workspace_dir: &Path) -> String { use std::fmt::Write; if skills.is_empty() { return String::new(); } - let mut prompt = String::from("\n## Active Skills\n\n"); + let mut prompt = String::from( + "## Available Skills\n\n\ + Skill instructions and tool metadata are preloaded below.\n\ + Follow these instructions directly; do not read skill files at runtime unless the user asks.\n\n\ + \n", + ); for skill in skills { - let _ = writeln!(prompt, "### {} (v{})", skill.name, skill.version); - let _ = writeln!(prompt, "{}", skill.description); + let _ = writeln!(prompt, " "); + write_xml_text_element(&mut prompt, 4, "name", &skill.name); + write_xml_text_element(&mut prompt, 4, "description", &skill.description); + + let location = skill.location.clone().unwrap_or_else(|| { + workspace_dir + .join("skills") + .join(&skill.name) + .join("SKILL.md") + }); + write_xml_text_element(&mut prompt, 4, "location", &location.display().to_string()); + + if !skill.prompts.is_empty() { + let _ = writeln!(prompt, " "); + for instruction in &skill.prompts { + write_xml_text_element(&mut prompt, 6, "instruction", instruction); + } + let _ = writeln!(prompt, " "); + } if !skill.tools.is_empty() { - prompt.push_str("Tools:\n"); + let _ = writeln!(prompt, " "); for tool in &skill.tools { - let _ = writeln!( - prompt, - "- **{}**: {} ({})", - tool.name, tool.description, tool.kind - ); + let _ = writeln!(prompt, " "); + write_xml_text_element(&mut prompt, 8, "name", &tool.name); + write_xml_text_element(&mut prompt, 8, "description", &tool.description); + write_xml_text_element(&mut prompt, 8, "kind", &tool.kind); + let _ = writeln!(prompt, " "); } + let _ = writeln!(prompt, " "); } - for p in &skill.prompts { - prompt.push_str(p); - prompt.push('\n'); - } - - prompt.push('\n'); + let _ = writeln!(prompt, " "); } + prompt.push_str(""); prompt } @@ -425,7 +535,7 @@ pub fn init_skills_dir(workspace_dir: &Path) -> Result<()> { The agent will read it 
and follow the instructions.\n\n\ ## Installing community skills\n\n\ ```bash\n\ - zeroclaw skills install \n\ + zeroclaw skills install \n\ zeroclaw skills list\n\ ```\n", )?; @@ -434,6 +544,50 @@ pub fn init_skills_dir(workspace_dir: &Path) -> Result<()> { Ok(()) } +fn is_git_source(source: &str) -> bool { + is_git_scheme_source(source, "https://") + || is_git_scheme_source(source, "http://") + || is_git_scheme_source(source, "ssh://") + || is_git_scheme_source(source, "git://") + || is_git_scp_source(source) +} + +fn is_git_scheme_source(source: &str, scheme: &str) -> bool { + let Some(rest) = source.strip_prefix(scheme) else { + return false; + }; + if rest.is_empty() || rest.starts_with('/') { + return false; + } + + let host = rest.split(['/', '?', '#']).next().unwrap_or_default(); + !host.is_empty() +} + +fn is_git_scp_source(source: &str) -> bool { + // SCP-like syntax accepted by git, e.g. git@host:owner/repo.git + // Keep this strict enough to avoid treating local paths as git remotes. 
+ let Some((user_host, remote_path)) = source.split_once(':') else { + return false; + }; + if remote_path.is_empty() { + return false; + } + if source.contains("://") { + return false; + } + + let Some((user, host)) = user_host.split_once('@') else { + return false; + }; + !user.is_empty() + && !host.is_empty() + && !user.contains('/') + && !user.contains('\\') + && !host.contains('/') + && !host.contains('\\') +} + /// Recursively copy a directory (used as fallback when symlinks aren't available) #[cfg(any(windows, not(unix)))] fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<()> { @@ -453,17 +607,18 @@ fn copy_dir_recursive(src: &Path, dest: &Path) -> Result<()> { /// Handle the `skills` CLI command #[allow(clippy::too_many_lines)] -pub fn handle_command(command: crate::SkillCommands, workspace_dir: &Path) -> Result<()> { +pub fn handle_command(command: crate::SkillCommands, config: &crate::config::Config) -> Result<()> { + let workspace_dir = &config.workspace_dir; match command { crate::SkillCommands::List => { - let skills = load_skills(workspace_dir); + let skills = load_skills_with_config(workspace_dir, config); if skills.is_empty() { println!("No skills installed."); println!(); println!(" Create one: mkdir -p ~/.zeroclaw/workspace/skills/my-skill"); println!(" echo '# My Skill' > ~/.zeroclaw/workspace/skills/my-skill/SKILL.md"); println!(); - println!(" Or install: zeroclaw skills install "); + println!(" Or install: zeroclaw skills install "); } else { println!("Installed skills ({}):", skills.len()); println!(); @@ -499,7 +654,7 @@ pub fn handle_command(command: crate::SkillCommands, workspace_dir: &Path) -> Re let skills_path = skills_dir(workspace_dir); std::fs::create_dir_all(&skills_path)?; - if source.starts_with("https://") || source.starts_with("http://") { + if is_git_source(&source) { // Git clone let output = std::process::Command::new("git") .args(["clone", "--depth", "1", &source]) @@ -622,6 +777,35 @@ pub fn handle_command(command: 
crate::SkillCommands, workspace_dir: &Path) -> Re mod tests { use super::*; use std::fs; + use std::sync::{Mutex, OnceLock}; + + fn open_skills_env_lock() -> &'static Mutex<()> { + static ENV_LOCK: OnceLock> = OnceLock::new(); + ENV_LOCK.get_or_init(|| Mutex::new(())) + } + + struct EnvVarGuard { + key: &'static str, + original: Option, + } + + impl EnvVarGuard { + fn unset(key: &'static str) -> Self { + let original = std::env::var(key).ok(); + std::env::remove_var(key); + Self { key, original } + } + } + + impl Drop for EnvVarGuard { + fn drop(&mut self) { + if let Some(value) = &self.original { + std::env::set_var(self.key, value); + } else { + std::env::remove_var(self.key); + } + } + } #[test] fn load_empty_skills_dir() { @@ -683,7 +867,7 @@ command = "echo hello" #[test] fn skills_to_prompt_empty() { - let prompt = skills_to_prompt(&[]); + let prompt = skills_to_prompt(&[], Path::new("/tmp")); assert!(prompt.is_empty()); } @@ -699,9 +883,10 @@ command = "echo hello" prompts: vec!["Do the thing.".to_string()], location: None, }]; - let prompt = skills_to_prompt(&skills); - assert!(prompt.contains("test")); - assert!(prompt.contains("Do the thing")); + let prompt = skills_to_prompt(&skills, Path::new("/tmp")); + assert!(prompt.contains("")); + assert!(prompt.contains("test")); + assert!(prompt.contains("Do the thing.")); } #[test] @@ -889,11 +1074,71 @@ description = "Bare minimum" prompts: vec![], location: None, }]; - let prompt = skills_to_prompt(&skills); + let prompt = skills_to_prompt(&skills, Path::new("/tmp")); assert!(prompt.contains("weather")); - assert!(prompt.contains("get_weather")); - assert!(prompt.contains("Fetch forecast")); - assert!(prompt.contains("shell")); + assert!(prompt.contains("get_weather")); + assert!(prompt.contains("Fetch forecast")); + assert!(prompt.contains("shell")); + } + + #[test] + fn skills_to_prompt_escapes_xml_content() { + let skills = vec![Skill { + name: "xml".to_string(), + description: "A & B".to_string(), + 
version: "1.0.0".to_string(), + author: None, + tags: vec![], + tools: vec![], + prompts: vec!["Use & check \"quotes\".".to_string()], + location: None, + }]; + + let prompt = skills_to_prompt(&skills, Path::new("/tmp")); + assert!(prompt.contains("xml<skill>")); + assert!(prompt.contains("A & B")); + assert!(prompt.contains( + "Use <tool> & check "quotes"." + )); + } + + #[test] + fn git_source_detection_accepts_remote_protocols_and_scp_style() { + let sources = [ + "https://github.com/some-org/some-skill.git", + "http://github.com/some-org/some-skill.git", + "ssh://git@github.com/some-org/some-skill.git", + "git://github.com/some-org/some-skill.git", + "git@github.com:some-org/some-skill.git", + "git@localhost:skills/some-skill.git", + ]; + + for source in sources { + assert!( + is_git_source(source), + "expected git source detection for '{source}'" + ); + } + } + + #[test] + fn git_source_detection_rejects_local_paths_and_invalid_inputs() { + let sources = [ + "./skills/local-skill", + "/tmp/skills/local-skill", + "C:\\skills\\local-skill", + "git@github.com", + "ssh://", + "not-a-url", + "dir/git@github.com:org/repo.git", + ]; + + for source in sources { + assert!( + !is_git_source(source), + "expected local/invalid source detection for '{source}'" + ); + } } #[test] @@ -921,6 +1166,78 @@ description = "Bare minimum" assert_eq!(skills.len(), 1); assert_eq!(skills[0].name, "from-toml"); // TOML takes priority } + + #[test] + fn open_skills_enabled_resolution_prefers_env_then_config_then_default_false() { + assert!(!open_skills_enabled_from_sources(None, None)); + assert!(open_skills_enabled_from_sources(Some(true), None)); + assert!(!open_skills_enabled_from_sources(Some(true), Some("0"))); + assert!(open_skills_enabled_from_sources(Some(false), Some("yes"))); + // Invalid env values should fall back to config. 
+ assert!(open_skills_enabled_from_sources( + Some(true), + Some("invalid") + )); + assert!(!open_skills_enabled_from_sources( + Some(false), + Some("invalid") + )); + } + + #[test] + fn resolve_open_skills_dir_resolution_prefers_env_then_config_then_home() { + let home = Path::new("/tmp/home-dir"); + assert_eq!( + resolve_open_skills_dir_from_sources( + Some("/tmp/env-skills"), + Some("/tmp/config"), + Some(home) + ), + Some(PathBuf::from("/tmp/env-skills")) + ); + assert_eq!( + resolve_open_skills_dir_from_sources( + Some(" "), + Some("/tmp/config-skills"), + Some(home) + ), + Some(PathBuf::from("/tmp/config-skills")) + ); + assert_eq!( + resolve_open_skills_dir_from_sources(None, None, Some(home)), + Some(PathBuf::from("/tmp/home-dir/open-skills")) + ); + assert_eq!(resolve_open_skills_dir_from_sources(None, None, None), None); + } + + #[test] + fn load_skills_with_config_reads_open_skills_dir_without_network() { + let _env_guard = open_skills_env_lock().lock().unwrap(); + let _enabled_guard = EnvVarGuard::unset("ZEROCLAW_OPEN_SKILLS_ENABLED"); + let _dir_guard = EnvVarGuard::unset("ZEROCLAW_OPEN_SKILLS_DIR"); + + let dir = tempfile::tempdir().unwrap(); + let workspace_dir = dir.path().join("workspace"); + fs::create_dir_all(workspace_dir.join("skills")).unwrap(); + + let open_skills_dir = dir.path().join("open-skills-local"); + fs::create_dir_all(&open_skills_dir).unwrap(); + fs::write(open_skills_dir.join("README.md"), "# open skills\n").unwrap(); + fs::write( + open_skills_dir.join("http_request.md"), + "# HTTP request\nFetch API responses.\n", + ) + .unwrap(); + + let mut config = crate::config::Config::default(); + config.workspace_dir = workspace_dir.clone(); + config.skills.open_skills_enabled = true; + config.skills.open_skills_dir = Some(open_skills_dir.to_string_lossy().to_string()); + + let skills = load_skills_with_config(&workspace_dir, &config); + assert_eq!(skills.len(), 1); + assert_eq!(skills[0].name, "http_request"); + } } #[cfg(test)] diff 
--git a/src/skills/symlink_tests.rs b/src/skills/symlink_tests.rs index c77393a..da50891 100644 --- a/src/skills/symlink_tests.rs +++ b/src/skills/symlink_tests.rs @@ -4,21 +4,23 @@ mod tests { use std::path::Path; use tempfile::TempDir; - #[test] - fn test_skills_symlink_unix_edge_cases() { + #[tokio::test] + async fn test_skills_symlink_unix_edge_cases() { let tmp = TempDir::new().unwrap(); let workspace_dir = tmp.path().join("workspace"); - std::fs::create_dir_all(&workspace_dir).unwrap(); + tokio::fs::create_dir_all(&workspace_dir).await.unwrap(); let skills_path = skills_dir(&workspace_dir); - std::fs::create_dir_all(&skills_path).unwrap(); + tokio::fs::create_dir_all(&skills_path).await.unwrap(); // Test case 1: Valid symlink creation on Unix #[cfg(unix)] { let source_dir = tmp.path().join("source_skill"); - std::fs::create_dir_all(&source_dir).unwrap(); - std::fs::write(source_dir.join("SKILL.md"), "# Test Skill\nContent").unwrap(); + tokio::fs::create_dir_all(&source_dir).await.unwrap(); + tokio::fs::write(source_dir.join("SKILL.md"), "# Test Skill\nContent") + .await + .unwrap(); let dest_link = skills_path.join("linked_skill"); @@ -31,7 +33,7 @@ mod tests { assert!(dest_link.is_symlink()); // Verify we can read through symlink - let content = std::fs::read_to_string(dest_link.join("SKILL.md")); + let content = tokio::fs::read_to_string(dest_link.join("SKILL.md")).await; assert!(content.is_ok()); assert!(content.unwrap().contains("Test Skill")); @@ -45,7 +47,7 @@ mod tests { ); // But reading through it should fail - let content = std::fs::read_to_string(broken_link.join("SKILL.md")); + let content = tokio::fs::read_to_string(broken_link.join("SKILL.md")).await; assert!(content.is_err()); } @@ -53,7 +55,7 @@ mod tests { #[cfg(windows)] { let source_dir = tmp.path().join("source_skill"); - std::fs::create_dir_all(&source_dir).unwrap(); + tokio::fs::create_dir_all(&source_dir).await.unwrap(); let dest_link = skills_path.join("linked_skill"); @@ -64,7 +66,7 
@@ mod tests { assert!(!dest_link.exists()); } else { // Clean up if it succeeded - let _ = std::fs::remove_dir(&dest_link); + let _ = tokio::fs::remove_dir(&dest_link).await; } } @@ -80,21 +82,23 @@ mod tests { assert!(!empty_skills_path.exists()); } - #[test] - fn test_skills_symlink_permissions_and_safety() { + #[tokio::test] + async fn test_skills_symlink_permissions_and_safety() { let tmp = TempDir::new().unwrap(); let workspace_dir = tmp.path().join("workspace"); - std::fs::create_dir_all(&workspace_dir).unwrap(); + tokio::fs::create_dir_all(&workspace_dir).await.unwrap(); let skills_path = skills_dir(&workspace_dir); - std::fs::create_dir_all(&skills_path).unwrap(); + tokio::fs::create_dir_all(&skills_path).await.unwrap(); #[cfg(unix)] { // Test case: Symlink outside workspace should be allowed (user responsibility) let outside_dir = tmp.path().join("outside_skill"); - std::fs::create_dir_all(&outside_dir).unwrap(); - std::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent").unwrap(); + tokio::fs::create_dir_all(&outside_dir).await.unwrap(); + tokio::fs::write(outside_dir.join("SKILL.md"), "# Outside Skill\nContent") + .await + .unwrap(); let dest_link = skills_path.join("outside_skill"); let result = std::os::unix::fs::symlink(&outside_dir, &dest_link); @@ -104,7 +108,7 @@ mod tests { ); // Should still be readable - let content = std::fs::read_to_string(dest_link.join("SKILL.md")); + let content = tokio::fs::read_to_string(dest_link.join("SKILL.md")).await; assert!(content.is_ok()); assert!(content.unwrap().contains("Outside Skill")); } diff --git a/src/tools/browser.rs b/src/tools/browser.rs index 519e317..55251d7 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -19,7 +19,7 @@ use tokio::process::Command; use tracing::debug; /// Computer-use sidecar settings. 
-#[derive(Debug, Clone)] +#[derive(Clone)] pub struct ComputerUseConfig { pub endpoint: String, pub api_key: Option, @@ -30,6 +30,20 @@ pub struct ComputerUseConfig { pub max_coordinate_y: Option, } +impl std::fmt::Debug for ComputerUseConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("ComputerUseConfig") + .field("endpoint", &self.endpoint) + .field("api_key", &self.api_key.as_ref().map(|_| "[REDACTED]")) + .field("timeout_ms", &self.timeout_ms) + .field("allow_remote_endpoint", &self.allow_remote_endpoint) + .field("window_allowlist", &self.window_allowlist) + .field("max_coordinate_x", &self.max_coordinate_x) + .field("max_coordinate_y", &self.max_coordinate_y) + .finish() + } +} + impl Default for ComputerUseConfig { fn default() -> Self { Self { @@ -1211,7 +1225,8 @@ mod native_backend { }); if let Some(path_str) = path { - std::fs::write(&path_str, &png) + tokio::fs::write(&path_str, &png) + .await .with_context(|| format!("Failed to write screenshot to {path_str}"))?; payload["path"] = Value::String(path_str); } else { diff --git a/src/tools/composio.rs b/src/tools/composio.rs index bfa5a0d..c191ac1 100644 --- a/src/tools/composio.rs +++ b/src/tools/composio.rs @@ -11,19 +11,32 @@ use crate::security::policy::ToolOperation; use crate::security::SecurityPolicy; use anyhow::Context; use async_trait::async_trait; +use parking_lot::RwLock; use reqwest::Client; use serde::{Deserialize, Serialize}; use serde_json::json; +use std::collections::HashMap; use std::sync::Arc; const COMPOSIO_API_BASE_V2: &str = "https://backend.composio.dev/api/v2"; const COMPOSIO_API_BASE_V3: &str = "https://backend.composio.dev/api/v3"; +const COMPOSIO_TOOL_VERSION_LATEST: &str = "latest"; + +fn ensure_https(url: &str) -> anyhow::Result<()> { + if !url.starts_with("https://") { + anyhow::bail!( + "Refusing to transmit sensitive data over non-HTTPS URL: URL scheme must be https" + ); + } + Ok(()) +} /// A tool that proxies actions to the 
Composio managed tool platform. pub struct ComposioTool { api_key: String, default_entity_id: String, security: Arc, + recent_connected_accounts: RwLock>, } impl ComposioTool { @@ -36,6 +49,7 @@ impl ComposioTool { api_key: api_key.to_string(), default_entity_id: normalize_entity_id(default_entity_id.unwrap_or("default")), security, + recent_connected_accounts: RwLock::new(HashMap::new()), } } @@ -66,12 +80,11 @@ impl ComposioTool { async fn list_actions_v3(&self, app_name: Option<&str>) -> anyhow::Result> { let url = format!("{COMPOSIO_API_BASE_V3}/tools"); - let mut req = self.client().get(&url).header("x-api-key", &self.api_key); - - req = req.query(&[("limit", "200")]); - if let Some(app) = app_name.map(str::trim).filter(|app| !app.is_empty()) { - req = req.query(&[("toolkits", app), ("toolkit_slug", app)]); - } + let req = self + .client() + .get(&url) + .header("x-api-key", &self.api_key) + .query(&Self::build_list_actions_v3_query(app_name)); let resp = req.send().await?; if !resp.status().is_success() { @@ -111,39 +124,186 @@ impl ComposioTool { Ok(body.items) } + /// List connected accounts for a user and optional toolkit/app. 
+ async fn list_connected_accounts( + &self, + app_name: Option<&str>, + entity_id: Option<&str>, + ) -> anyhow::Result> { + let url = format!("{COMPOSIO_API_BASE_V3}/connected_accounts"); + let mut req = self.client().get(&url).header("x-api-key", &self.api_key); + + req = req.query(&[ + ("limit", "50"), + ("order_by", "updated_at"), + ("order_direction", "desc"), + ("statuses", "INITIALIZING"), + ("statuses", "ACTIVE"), + ("statuses", "INITIATED"), + ]); + + if let Some(app) = app_name + .map(normalize_app_slug) + .filter(|app| !app.is_empty()) + { + req = req.query(&[("toolkit_slugs", app.as_str())]); + } + + if let Some(entity) = entity_id { + req = req.query(&[("user_ids", entity)]); + } + + let resp = req.send().await?; + if !resp.status().is_success() { + let err = response_error(resp).await; + anyhow::bail!("Composio v3 connected accounts lookup failed: {err}"); + } + + let body: ComposioConnectedAccountsResponse = resp + .json() + .await + .context("Failed to decode Composio v3 connected accounts response")?; + Ok(body.items) + } + + fn cache_connected_account(&self, app_name: &str, entity_id: &str, connected_account_id: &str) { + let key = connected_account_cache_key(app_name, entity_id); + self.recent_connected_accounts + .write() + .insert(key, connected_account_id.to_string()); + } + + fn get_cached_connected_account(&self, app_name: &str, entity_id: &str) -> Option { + let key = connected_account_cache_key(app_name, entity_id); + self.recent_connected_accounts.read().get(&key).cloned() + } + + async fn resolve_connected_account_ref( + &self, + app_name: Option<&str>, + entity_id: Option<&str>, + ) -> anyhow::Result> { + let app = app_name + .map(normalize_app_slug) + .filter(|app| !app.is_empty()); + let entity = entity_id.map(normalize_entity_id); + let (Some(app), Some(entity)) = (app, entity) else { + return Ok(None); + }; + + if let Some(cached) = self.get_cached_connected_account(&app, &entity) { + return Ok(Some(cached)); + } + + let accounts = 
self + .list_connected_accounts(Some(&app), Some(&entity)) + .await?; + // The API returns accounts ordered by updated_at DESC, so the first + // usable account is the most recently active one. We always pick it + // rather than giving up when multiple accounts exist — giving up was + // the root cause of the "cannot find connected account" loop reported + // in issue #959. + let Some(first) = accounts.into_iter().find(|acct| acct.is_usable()) else { + return Ok(None); + }; + + self.cache_connected_account(&app, &entity, &first.id); + Ok(Some(first.id)) + } + /// Execute a Composio action/tool with given parameters. /// /// Uses v3 endpoint first and falls back to v2 for compatibility. pub async fn execute_action( &self, action_name: &str, + app_name_hint: Option<&str>, params: serde_json::Value, entity_id: Option<&str>, connected_account_ref: Option<&str>, ) -> anyhow::Result { let tool_slug = normalize_tool_slug(action_name); + let app_hint = app_name_hint + .map(normalize_app_slug) + .filter(|app| !app.is_empty()) + .or_else(|| infer_app_slug_from_action_name(action_name)); + let normalized_entity_id = entity_id.map(normalize_entity_id); + let explicit_account_ref = connected_account_ref.and_then(|candidate| { + let trimmed = candidate.trim(); + (!trimmed.is_empty()).then_some(trimmed.to_string()) + }); + let resolved_account_ref = if explicit_account_ref.is_some() { + explicit_account_ref + } else { + self.resolve_connected_account_ref(app_hint.as_deref(), entity_id) + .await? 
+ }; match self - .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref) + .execute_action_v3( + &tool_slug, + params.clone(), + entity_id, + resolved_account_ref.as_deref(), + ) .await { Ok(result) => Ok(result), - Err(v3_err) => match self.execute_action_v2(action_name, params, entity_id).await { - Ok(result) => Ok(result), - Err(v2_err) => anyhow::bail!( - "Composio execute failed on v3 ({v3_err}) and v2 fallback ({v2_err})" - ), - }, + Err(v3_err) => { + let mut v2_candidates = vec![action_name.trim().to_string()]; + let legacy_action_name = normalize_legacy_action_name(action_name); + if !legacy_action_name.is_empty() && !v2_candidates.contains(&legacy_action_name) { + v2_candidates.push(legacy_action_name); + } + + let mut v2_errors = Vec::new(); + for candidate in v2_candidates { + match self + .execute_action_v2(&candidate, params.clone(), entity_id) + .await + { + Ok(result) => return Ok(result), + Err(v2_err) => v2_errors.push(format!("{candidate}: {v2_err}")), + } + } + + anyhow::bail!( + "Composio execute failed on v3 ({v3_err}) and v2 fallback attempts ({}){}", + v2_errors.join(" | "), + build_connected_account_hint( + app_hint.as_deref(), + normalized_entity_id.as_deref(), + resolved_account_ref.as_deref(), + ) + ); + } } } + fn build_list_actions_v3_query(app_name: Option<&str>) -> Vec<(String, String)> { + let mut query = vec![ + ("limit".to_string(), "200".to_string()), + ( + "toolkit_versions".to_string(), + COMPOSIO_TOOL_VERSION_LATEST.to_string(), + ), + ]; + + if let Some(app) = app_name.map(str::trim).filter(|app| !app.is_empty()) { + query.push(("toolkits".to_string(), app.to_string())); + query.push(("toolkit_slug".to_string(), app.to_string())); + } + + query + } + fn build_execute_action_v3_request( tool_slug: &str, params: serde_json::Value, entity_id: Option<&str>, connected_account_ref: Option<&str>, ) -> (String, serde_json::Value) { - let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute"); + 
let url = format!("{COMPOSIO_API_BASE_V3}/tools/execute/{tool_slug}"); let account_ref = connected_account_ref.and_then(|candidate| { let trimmed_candidate = candidate.trim(); (!trimmed_candidate.is_empty()).then_some(trimmed_candidate) @@ -151,6 +311,7 @@ impl ComposioTool { let mut body = json!({ "arguments": params, + "version": COMPOSIO_TOOL_VERSION_LATEST, }); if let Some(entity) = entity_id { @@ -177,6 +338,8 @@ impl ComposioTool { connected_account_ref, ); + ensure_https(&url)?; + let resp = self .client() .post(&url) @@ -241,7 +404,7 @@ impl ComposioTool { app_name: Option<&str>, auth_config_id: Option<&str>, entity_id: &str, - ) -> anyhow::Result { + ) -> anyhow::Result { let v3 = self .get_connection_url_v3(app_name, auth_config_id, entity_id) .await; @@ -268,7 +431,7 @@ impl ComposioTool { app_name: Option<&str>, auth_config_id: Option<&str>, entity_id: &str, - ) -> anyhow::Result { + ) -> anyhow::Result { let auth_config_id = match auth_config_id { Some(id) => id.to_string(), None => { @@ -302,15 +465,19 @@ impl ComposioTool { .json() .await .context("Failed to decode Composio v3 connect response")?; - extract_redirect_url(&result) - .ok_or_else(|| anyhow::anyhow!("No redirect URL in Composio v3 response")) + let redirect_url = extract_redirect_url(&result) + .ok_or_else(|| anyhow::anyhow!("No redirect URL in Composio v3 response"))?; + Ok(ComposioConnectionLink { + redirect_url, + connected_account_id: extract_connected_account_id(&result), + }) } async fn get_connection_url_v2( &self, app_name: &str, entity_id: &str, - ) -> anyhow::Result { + ) -> anyhow::Result { let url = format!("{COMPOSIO_API_BASE_V2}/connectedAccounts"); let body = json!({ @@ -335,8 +502,12 @@ impl ComposioTool { .json() .await .context("Failed to decode Composio v2 connect response")?; - extract_redirect_url(&result) - .ok_or_else(|| anyhow::anyhow!("No redirect URL in Composio v2 response")) + let redirect_url = extract_redirect_url(&result) + .ok_or_else(|| anyhow::anyhow!("No 
redirect URL in Composio v2 response"))?; + Ok(ComposioConnectionLink { + redirect_url, + connected_account_id: extract_connected_account_id(&result), + }) } async fn resolve_auth_config_id(&self, app_name: &str) -> anyhow::Result { @@ -389,7 +560,9 @@ impl Tool for ComposioTool { fn description(&self) -> &str { "Execute actions on 1000+ apps via Composio (Gmail, Notion, GitHub, Slack, etc.). \ - Use action='list' to see available actions, action='execute' with action_name/tool_slug, params, and optional connected_account_id, \ + Use action='list' to see available actions, \ + action='list_accounts' or action='connected_accounts' to list OAuth-connected accounts after login, \ + action='execute' with action_name/tool_slug and params (connected_account_id auto-resolved when omitted), \ or action='connect' with app/auth_config_id to get OAuth URL." } @@ -399,12 +572,12 @@ impl Tool for ComposioTool { "properties": { "action": { "type": "string", - "description": "The operation: 'list' (list available actions), 'execute' (run an action), or 'connect' (get OAuth URL)", - "enum": ["list", "execute", "connect"] + "description": "The operation: 'list' (list available actions), 'list_accounts'/'connected_accounts' (list connected accounts), 'execute' (run an action), or 'connect' (get OAuth URL)", + "enum": ["list", "list_accounts", "connected_accounts", "execute", "connect"] }, "app": { "type": "string", - "description": "Toolkit slug filter for 'list', or toolkit/app for 'connect' (e.g. 'gmail', 'notion', 'github')" + "description": "Toolkit slug filter for 'list' or 'list_accounts', optional app hint for 'execute', or toolkit/app for 'connect' (e.g. 'gmail', 'notion', 'github')" }, "action_name": { "type": "string", @@ -487,6 +660,56 @@ impl Tool for ComposioTool { } } + // Accept both spellings so the LLM can use either. 
+ "list_accounts" | "connected_accounts" => { + let app = args.get("app").and_then(|v| v.as_str()); + match self.list_connected_accounts(app, Some(entity_id)).await { + Ok(accounts) => { + if accounts.is_empty() { + let app_hint = app + .map(|value| format!(" for app '{value}'")) + .unwrap_or_default(); + return Ok(ToolResult { + success: true, + output: format!( + "No connected accounts found{app_hint} for entity '{entity_id}'. Run action='connect' first." + ), + error: None, + }); + } + + let summary: Vec = accounts + .iter() + .take(20) + .map(|account| { + let toolkit = account.toolkit_slug().unwrap_or("?"); + format!("- {} [{}] toolkit={toolkit}", account.id, account.status) + }) + .collect(); + let total = accounts.len(); + let output = format!( + "Found {total} connected accounts (entity '{entity_id}'):\n{}{}\nUse connected_account_id in action='execute' when needed.", + summary.join("\n"), + if total > 20 { + format!("\n... and {} more", total - 20) + } else { + String::new() + } + ); + Ok(ToolResult { + success: true, + output, + error: None, + }) + } + Err(e) => Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!("Failed to list connected accounts: {e}")), + }), + } + } + "execute" => { if let Err(error) = self .security @@ -507,12 +730,12 @@ impl Tool for ComposioTool { anyhow::anyhow!("Missing 'action_name' (or 'tool_slug') for execute") })?; + let app = args.get("app").and_then(|v| v.as_str()); let params = args.get("params").cloned().unwrap_or(json!({})); - let connected_account_ref = - args.get("connected_account_id").and_then(|v| v.as_str()); + let acct_ref = args.get("connected_account_id").and_then(|v| v.as_str()); match self - .execute_action(action_name, params, Some(entity_id), connected_account_ref) + .execute_action(action_name, app, params, Some(entity_id), acct_ref) .await { Ok(result) => { @@ -555,12 +778,24 @@ impl Tool for ComposioTool { .get_connection_url(app, auth_config_id, entity_id) .await { - Ok(url) 
=> { + Ok(link) => { let target = app.unwrap_or(auth_config_id.unwrap_or("provided auth config")); + let mut output = format!( + "Open this URL to connect {target}:\n{}", + link.redirect_url + ); + if let Some(connected_account_id) = link.connected_account_id.as_deref() { + if let Some(app_name) = app { + self.cache_connected_account(app_name, entity_id, connected_account_id); + } + output.push_str(&format!( + "\nConnected account ID: {connected_account_id}" + )); + } Ok(ToolResult { success: true, - output: format!("Open this URL to connect {target}:\n{url}"), + output, error: None, }) } @@ -576,7 +811,7 @@ impl Tool for ComposioTool { success: false, output: String::new(), error: Some(format!( - "Unknown action '{action}'. Use 'list', 'execute', or 'connect'." + "Unknown action '{action}'. Use 'list', 'list_accounts', 'execute', or 'connect'." )), }), } @@ -596,6 +831,71 @@ fn normalize_tool_slug(action_name: &str) -> String { action_name.trim().replace('_', "-").to_ascii_lowercase() } +fn normalize_legacy_action_name(action_name: &str) -> String { + action_name.trim().replace('-', "_").to_ascii_uppercase() +} + +fn normalize_app_slug(app_name: &str) -> String { + app_name + .trim() + .replace('_', "-") + .to_ascii_lowercase() + .split('-') + .filter(|part| !part.is_empty()) + .collect::>() + .join("-") +} + +fn infer_app_slug_from_action_name(action_name: &str) -> Option { + let trimmed = action_name.trim(); + if trimmed.is_empty() { + return None; + } + + let raw = if trimmed.contains('-') { + trimmed.split('-').next() + } else if trimmed.contains('_') { + trimmed.split('_').next() + } else { + None + }?; + + let app = normalize_app_slug(raw); + (!app.is_empty()).then_some(app) +} + +fn connected_account_cache_key(app_name: &str, entity_id: &str) -> String { + format!( + "{}:{}", + normalize_entity_id(entity_id), + normalize_app_slug(app_name) + ) +} + +fn build_connected_account_hint( + app_hint: Option<&str>, + entity_id: Option<&str>, + 
connected_account_ref: Option<&str>, +) -> String { + if connected_account_ref.is_some() { + return String::new(); + } + + let Some(entity) = entity_id else { + return String::new(); + }; + + if let Some(app) = app_hint { + format!( + " Hint: use action='list_accounts' with app='{app}' and entity_id='{entity}' to retrieve connected_account_id." + ) + } else { + format!( + " Hint: use action='list_accounts' with entity_id='{entity}' to retrieve connected_account_id." + ) + } +} + fn map_v3_tools_to_actions(items: Vec) -> Vec { items .into_iter() @@ -631,6 +931,26 @@ fn extract_redirect_url(result: &serde_json::Value) -> Option { .map(ToString::to_string) } +fn extract_connected_account_id(result: &serde_json::Value) -> Option { + result + .get("connected_account_id") + .and_then(|v| v.as_str()) + .or_else(|| result.get("connectedAccountId").and_then(|v| v.as_str())) + .or_else(|| { + result + .get("data") + .and_then(|v| v.get("connected_account_id")) + .and_then(|v| v.as_str()) + }) + .or_else(|| { + result + .get("data") + .and_then(|v| v.get("connectedAccountId")) + .and_then(|v| v.as_str()) + }) + .map(ToString::to_string) +} + async fn response_error(resp: reqwest::Response) -> String { let status = resp.status(); let body = resp.text().await.unwrap_or_default(); @@ -703,6 +1023,35 @@ struct ComposioToolsResponse { items: Vec, } +#[derive(Debug, Deserialize)] +struct ComposioConnectedAccountsResponse { + #[serde(default)] + items: Vec, +} + +#[derive(Debug, Clone, Deserialize)] +struct ComposioConnectedAccount { + id: String, + #[serde(default)] + status: String, + #[serde(default)] + toolkit: Option, +} + +impl ComposioConnectedAccount { + fn is_usable(&self) -> bool { + self.status.eq_ignore_ascii_case("INITIALIZING") + || self.status.eq_ignore_ascii_case("ACTIVE") + || self.status.eq_ignore_ascii_case("INITIATED") + } + + fn toolkit_slug(&self) -> Option<&str> { + self.toolkit + .as_ref() + .and_then(|toolkit| toolkit.slug.as_deref()) + } +} + 
#[derive(Debug, Clone, Deserialize)] struct ComposioV3Tool { #[serde(default)] @@ -731,6 +1080,12 @@ struct ComposioAuthConfigsResponse { items: Vec, } +#[derive(Debug, Clone)] +pub struct ComposioConnectionLink { + pub redirect_url: String, + pub connected_account_id: Option, +} + #[derive(Debug, Clone, Deserialize)] struct ComposioAuthConfig { id: String, @@ -797,6 +1152,13 @@ mod tests { assert!(schema["properties"]["connected_account_id"].is_object()); let required = schema["required"].as_array().unwrap(); assert!(required.contains(&json!("action"))); + let enum_values = schema["properties"]["action"]["enum"] + .as_array() + .unwrap() + .iter() + .filter_map(|v| v.as_str()) + .collect::>(); + assert!(enum_values.contains(&"list_accounts")); } #[test] @@ -956,6 +1318,93 @@ mod tests { ); } + #[test] + fn normalize_legacy_action_name_supports_v3_slug_input() { + assert_eq!( + normalize_legacy_action_name("gmail-fetch-emails"), + "GMAIL_FETCH_EMAILS" + ); + assert_eq!( + normalize_legacy_action_name(" GITHUB_LIST_REPOS "), + "GITHUB_LIST_REPOS" + ); + } + + #[test] + fn normalize_app_slug_removes_spaces_and_normalizes_case() { + assert_eq!(normalize_app_slug(" Gmail "), "gmail"); + assert_eq!(normalize_app_slug("GITHUB_APP"), "github-app"); + } + + #[test] + fn infer_app_slug_from_action_name_handles_v2_and_v3_formats() { + assert_eq!( + infer_app_slug_from_action_name("gmail-fetch-emails").as_deref(), + Some("gmail") + ); + assert_eq!( + infer_app_slug_from_action_name("GMAIL_FETCH_EMAILS").as_deref(), + Some("gmail") + ); + assert!(infer_app_slug_from_action_name("execute").is_none()); + } + + #[test] + fn connected_account_cache_key_is_stable() { + assert_eq!( + connected_account_cache_key("GMAIL", " default "), + "default:gmail" + ); + } + + #[test] + fn build_connected_account_hint_returns_guidance_when_missing_ref() { + let hint = build_connected_account_hint(Some("gmail"), Some("default"), None); + assert!(hint.contains("list_accounts")); + 
assert!(hint.contains("gmail")); + assert!(hint.contains("default")); + } + + #[test] + fn build_connected_account_hint_without_app_is_still_actionable() { + let hint = build_connected_account_hint(None, Some("default"), None); + assert!(hint.contains("list_accounts")); + assert!(hint.contains("entity_id='default'")); + assert!(!hint.contains("app='")); + } + + #[test] + fn connected_account_is_usable_for_initializing_active_and_initiated() { + for status in ["INITIALIZING", "ACTIVE", "INITIATED"] { + let account = ComposioConnectedAccount { + id: "ca_1".to_string(), + status: status.to_string(), + toolkit: None, + }; + assert!(account.is_usable(), "status {status} should be usable"); + } + } + + #[test] + fn extract_connected_account_id_supports_common_shapes() { + let root = json!({"connected_account_id": "ca_root"}); + let camel = json!({"connectedAccountId": "ca_camel"}); + let nested = json!({"data": {"connected_account_id": "ca_nested"}}); + + assert_eq!( + extract_connected_account_id(&root).as_deref(), + Some("ca_root") + ); + assert_eq!( + extract_connected_account_id(&camel).as_deref(), + Some("ca_camel") + ); + assert_eq!( + extract_connected_account_id(&nested).as_deref(), + Some("ca_nested") + ); + } + #[test] fn extract_redirect_url_supports_v2_and_v3_shapes() { let v2 = json!({"redirectUrl": "https://app.composio.dev/connect-v2"}); @@ -1031,10 +1480,10 @@ mod tests { #[test] fn composio_action_with_unicode() { - let json_str = r#"{"name": "SLACK_SEND_MESSAGE", "appName": "slack", "description": "Send message with emoji 🎉 and unicode 中文", "enabled": true}"#; + let json_str = r#"{"name": "SLACK_SEND_MESSAGE", "appName": "slack", "description": "Send message with emoji 🎉 and unicode Ω", "enabled": true}"#; let action: ComposioAction = serde_json::from_str(json_str).unwrap(); assert!(action.description.as_ref().unwrap().contains("🎉")); - assert!(action.description.as_ref().unwrap().contains("中文")); + 
assert!(action.description.as_ref().unwrap().contains("Ω")); } #[test] @@ -1083,13 +1532,145 @@ mod tests { assert_eq!( url, - "https://backend.composio.dev/api/v3/tools/gmail-send-email/execute" + "https://backend.composio.dev/api/v3/tools/execute/gmail-send-email" ); assert_eq!(body["arguments"]["to"], json!("test@example.com")); + assert_eq!(body["version"], json!(COMPOSIO_TOOL_VERSION_LATEST)); assert_eq!(body["user_id"], json!("workspace-user")); assert_eq!(body["connected_account_id"], json!("account-42")); } + #[test] + fn build_list_actions_v3_query_requests_latest_versions() { + let query = ComposioTool::build_list_actions_v3_query(None) + .into_iter() + .collect::>(); + assert_eq!( + query.get("toolkit_versions"), + Some(&COMPOSIO_TOOL_VERSION_LATEST.to_string()) + ); + assert_eq!(query.get("limit"), Some(&"200".to_string())); + assert!(!query.contains_key("toolkits")); + assert!(!query.contains_key("toolkit_slug")); + } + + #[test] + fn build_list_actions_v3_query_adds_app_filters_when_present() { + let query = ComposioTool::build_list_actions_v3_query(Some(" github ")) + .into_iter() + .collect::>(); + assert_eq!( + query.get("toolkit_versions"), + Some(&COMPOSIO_TOOL_VERSION_LATEST.to_string()) + ); + assert_eq!(query.get("toolkits"), Some(&"github".to_string())); + assert_eq!(query.get("toolkit_slug"), Some(&"github".to_string())); + } + + // ── resolve_connected_account_ref (multi-account fix) ──── + + #[test] + fn resolve_picks_first_usable_when_multiple_accounts_exist() { + // Regression test for issue #959: previously returned None when + // multiple accounts existed, causing the LLM to loop on the OAuth URL. 
+ let tool = ComposioTool::new("test-key", None, test_security()); + let accounts = vec![ + ComposioConnectedAccount { + id: "ca_old".to_string(), + status: "ACTIVE".to_string(), + toolkit: None, + }, + ComposioConnectedAccount { + id: "ca_new".to_string(), + status: "ACTIVE".to_string(), + toolkit: None, + }, + ]; + // Simulate what resolve_connected_account_ref does: find first usable. + let resolved = accounts.into_iter().find(|a| a.is_usable()).map(|a| a.id); + assert_eq!(resolved.as_deref(), Some("ca_old")); + } + + #[test] + fn resolve_picks_first_usable_skipping_unusable_head() { + let accounts = vec![ + ComposioConnectedAccount { + id: "ca_dead".to_string(), + status: "DISCONNECTED".to_string(), + toolkit: None, + }, + ComposioConnectedAccount { + id: "ca_live".to_string(), + status: "ACTIVE".to_string(), + toolkit: None, + }, + ]; + let resolved = accounts.into_iter().find(|a| a.is_usable()).map(|a| a.id); + assert_eq!(resolved.as_deref(), Some("ca_live")); + } + + #[test] + fn resolve_returns_none_when_no_usable_accounts() { + let accounts = vec![ComposioConnectedAccount { + id: "ca_dead".to_string(), + status: "DISCONNECTED".to_string(), + toolkit: None, + }]; + let resolved = accounts.into_iter().find(|a| a.is_usable()).map(|a| a.id); + assert!(resolved.is_none()); + } + + #[test] + fn resolve_returns_none_for_empty_accounts() { + let accounts: Vec = vec![]; + let resolved = accounts.into_iter().find(|a| a.is_usable()).map(|a| a.id); + assert!(resolved.is_none()); + } + + // ── connected_accounts alias ────────────────────────────── + + #[tokio::test] + async fn connected_accounts_alias_dispatches_same_as_list_accounts() { + // Both spellings should reach the same handler and return the same + // shape of error (network failure in test, not a dispatch error). 
+ let tool = ComposioTool::new("test-key", None, test_security()); + let r1 = tool + .execute(json!({"action": "list_accounts"})) + .await + .unwrap(); + let r2 = tool + .execute(json!({"action": "connected_accounts"})) + .await + .unwrap(); + // Both fail the same way (network) — neither is a dispatch error. + assert!(!r1.success); + assert!(!r2.success); + let e1 = r1.error.unwrap_or_default(); + let e2 = r2.error.unwrap_or_default(); + assert!(!e1.contains("Unknown action"), "list_accounts: {e1}"); + assert!(!e2.contains("Unknown action"), "connected_accounts: {e2}"); + } + + #[test] + fn schema_enum_includes_connected_accounts_alias() { + let tool = ComposioTool::new("test-key", None, test_security()); + let schema = tool.parameters_schema(); + let values: Vec<&str> = schema["properties"]["action"]["enum"] + .as_array() + .unwrap() + .iter() + .filter_map(|v| v.as_str()) + .collect(); + assert!(values.contains(&"connected_accounts")); + assert!(values.contains(&"list_accounts")); + } + + #[test] + fn description_mentions_connected_accounts() { + let tool = ComposioTool::new("test-key", None, test_security()); + assert!(tool.description().contains("connected_accounts")); + } + #[test] fn build_execute_action_v3_request_drops_blank_optional_fields() { let (url, body) = ComposioTool::build_execute_action_v3_request( @@ -1101,9 +1682,10 @@ mod tests { assert_eq!( url, - "https://backend.composio.dev/api/v3/tools/github-list-repos/execute" + "https://backend.composio.dev/api/v3/tools/execute/github-list-repos" ); assert_eq!(body["arguments"], json!({})); + assert_eq!(body["version"], json!(COMPOSIO_TOOL_VERSION_LATEST)); assert!(body.get("connected_account_id").is_none()); assert!(body.get("user_id").is_none()); } diff --git a/src/tools/cron_add.rs b/src/tools/cron_add.rs index bd3abea..a0847b5 100644 --- a/src/tools/cron_add.rs +++ b/src/tools/cron_add.rs @@ -217,13 +217,15 @@ mod tests { use crate::security::AutonomyLevel; use tempfile::TempDir; - fn 
test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); Arc::new(config) } @@ -237,7 +239,7 @@ mod tests { #[tokio::test] async fn adds_shell_job() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); let result = tool .execute(json!({ @@ -262,7 +264,9 @@ mod tests { }; config.autonomy.allowed_commands = vec!["echo".into()]; config.autonomy.level = AutonomyLevel::Supervised; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); let cfg = Arc::new(config); let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); @@ -285,7 +289,7 @@ mod tests { #[tokio::test] async fn rejects_invalid_schedule() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); let result = tool @@ -307,7 +311,7 @@ mod tests { #[tokio::test] async fn agent_job_requires_prompt() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let tool = CronAddTool::new(cfg.clone(), test_security(&cfg)); let result = tool diff --git a/src/tools/cron_list.rs b/src/tools/cron_list.rs index 0392370..d83855f 100644 --- a/src/tools/cron_list.rs +++ b/src/tools/cron_list.rs @@ -63,20 +63,22 @@ mod tests { use crate::config::Config; use tempfile::TempDir; - fn test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() 
}; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); Arc::new(config) } #[tokio::test] async fn returns_empty_list_when_no_jobs() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let tool = CronListTool::new(cfg); let result = tool.execute(json!({})).await.unwrap(); @@ -87,7 +89,7 @@ mod tests { #[tokio::test] async fn errors_when_cron_disabled() { let tmp = TempDir::new().unwrap(); - let mut cfg = (*test_config(&tmp)).clone(); + let mut cfg = (*test_config(&tmp).await).clone(); cfg.cron.enabled = false; let tool = CronListTool::new(Arc::new(cfg)); diff --git a/src/tools/cron_remove.rs b/src/tools/cron_remove.rs index 01a70dc..5249212 100644 --- a/src/tools/cron_remove.rs +++ b/src/tools/cron_remove.rs @@ -76,20 +76,22 @@ mod tests { use crate::config::Config; use tempfile::TempDir; - fn test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); Arc::new(config) } #[tokio::test] async fn removes_existing_job() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); let tool = CronRemoveTool::new(cfg.clone()); @@ -101,7 +103,7 @@ mod tests { #[tokio::test] async fn errors_when_job_id_missing() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let tool = CronRemoveTool::new(cfg); let result = tool.execute(json!({})).await.unwrap(); diff --git a/src/tools/cron_run.rs b/src/tools/cron_run.rs index a4e5f75..ad77344 100644 --- a/src/tools/cron_run.rs +++ b/src/tools/cron_run.rs @@ 
-107,20 +107,22 @@ mod tests { use crate::config::Config; use tempfile::TempDir; - fn test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); Arc::new(config) } #[tokio::test] async fn force_runs_job_and_records_history() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let job = cron::add_job(&cfg, "*/5 * * * *", "echo run-now").unwrap(); let tool = CronRunTool::new(cfg.clone()); @@ -134,7 +136,7 @@ mod tests { #[tokio::test] async fn errors_for_missing_job() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let tool = CronRunTool::new(cfg); let result = tool diff --git a/src/tools/cron_runs.rs b/src/tools/cron_runs.rs index 280baa1..649b10f 100644 --- a/src/tools/cron_runs.rs +++ b/src/tools/cron_runs.rs @@ -121,20 +121,22 @@ mod tests { use chrono::{Duration as ChronoDuration, Utc}; use tempfile::TempDir; - fn test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); Arc::new(config) } #[tokio::test] async fn lists_runs_with_truncation() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); let long_output = "x".repeat(1000); @@ -163,7 +165,7 @@ mod tests { #[tokio::test] async fn errors_when_job_id_missing() { let tmp = TempDir::new().unwrap(); - let cfg = 
test_config(&tmp); + let cfg = test_config(&tmp).await; let tool = CronRunsTool::new(cfg); let result = tool.execute(json!({})).await.unwrap(); assert!(!result.success); diff --git a/src/tools/cron_update.rs b/src/tools/cron_update.rs index c224b17..d8df72d 100644 --- a/src/tools/cron_update.rs +++ b/src/tools/cron_update.rs @@ -111,13 +111,15 @@ mod tests { use crate::config::Config; use tempfile::TempDir; - fn test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); Arc::new(config) } @@ -131,7 +133,7 @@ mod tests { #[tokio::test] async fn updates_enabled_flag() { let tmp = TempDir::new().unwrap(); - let cfg = test_config(&tmp); + let cfg = test_config(&tmp).await; let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); let tool = CronUpdateTool::new(cfg.clone(), test_security(&cfg)); @@ -156,7 +158,9 @@ mod tests { ..Config::default() }; config.autonomy.allowed_commands = vec!["echo".into()]; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); let cfg = Arc::new(config); let job = cron::add_job(&cfg, "*/5 * * * *", "echo ok").unwrap(); let tool = CronUpdateTool::new(cfg.clone(), test_security(&cfg)); diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index fabb99c..9fa20ee 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -21,6 +21,8 @@ pub struct DelegateTool { security: Arc, /// Global credential fallback (from config.api_key) fallback_credential: Option, + /// Provider runtime options inherited from root config. + provider_runtime_options: providers::ProviderRuntimeOptions, /// Depth at which this tool instance lives in the delegation chain. 
depth: u32, } @@ -30,11 +32,26 @@ impl DelegateTool { agents: HashMap, fallback_credential: Option, security: Arc, + ) -> Self { + Self::new_with_options( + agents, + fallback_credential, + security, + providers::ProviderRuntimeOptions::default(), + ) + } + + pub fn new_with_options( + agents: HashMap, + fallback_credential: Option, + security: Arc, + provider_runtime_options: providers::ProviderRuntimeOptions, ) -> Self { Self { agents: Arc::new(agents), security, fallback_credential, + provider_runtime_options, depth: 0, } } @@ -47,11 +64,28 @@ impl DelegateTool { fallback_credential: Option, security: Arc, depth: u32, + ) -> Self { + Self::with_depth_and_options( + agents, + fallback_credential, + security, + depth, + providers::ProviderRuntimeOptions::default(), + ) + } + + pub fn with_depth_and_options( + agents: HashMap, + fallback_credential: Option, + security: Arc, + depth: u32, + provider_runtime_options: providers::ProviderRuntimeOptions, ) -> Self { Self { agents: Arc::new(agents), security, fallback_credential, + provider_runtime_options, depth, } } @@ -190,20 +224,23 @@ impl Tool for DelegateTool { #[allow(clippy::option_as_ref_deref)] let provider_credential = provider_credential_owned.as_ref().map(String::as_str); - let provider: Box = - match providers::create_provider(&agent_config.provider, provider_credential) { - Ok(p) => p, - Err(e) => { - return Ok(ToolResult { - success: false, - output: String::new(), - error: Some(format!( - "Failed to create provider '{}' for agent '{agent_name}': {e}", - agent_config.provider - )), - }); - } - }; + let provider: Box = match providers::create_provider_with_options( + &agent_config.provider, + provider_credential, + &self.provider_runtime_options, + ) { + Ok(p) => p, + Err(e) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Failed to create provider '{}' for agent '{agent_name}': {e}", + agent_config.provider + )), + }); + } + }; // Build the message let 
full_prompt = if context.is_empty() { diff --git a/src/tools/image_info.rs b/src/tools/image_info.rs index 349f707..558fbb7 100644 --- a/src/tools/image_info.rs +++ b/src/tools/image_info.rs @@ -428,7 +428,7 @@ mod tests { async fn execute_real_file() { // Create a minimal valid PNG let dir = std::env::temp_dir().join("zeroclaw_image_info_test"); - let _ = std::fs::create_dir_all(&dir); + let _ = tokio::fs::create_dir_all(&dir).await; let png_path = dir.join("test.png"); // Minimal 1x1 red PNG (67 bytes) @@ -448,7 +448,7 @@ mod tests { 0x49, 0x45, 0x4E, 0x44, // IEND 0xAE, 0x42, 0x60, 0x82, // CRC ]; - std::fs::write(&png_path, &png_bytes).unwrap(); + tokio::fs::write(&png_path, &png_bytes).await.unwrap(); let tool = ImageInfoTool::new(test_security()); let result = tool @@ -461,13 +461,13 @@ mod tests { assert!(!result.output.contains("data:")); // Clean up - let _ = std::fs::remove_dir_all(&dir); + let _ = tokio::fs::remove_dir_all(&dir).await; } #[tokio::test] async fn execute_with_base64() { let dir = std::env::temp_dir().join("zeroclaw_image_info_b64"); - let _ = std::fs::create_dir_all(&dir); + let _ = tokio::fs::create_dir_all(&dir).await; let png_path = dir.join("test_b64.png"); // Minimal 1x1 PNG @@ -478,7 +478,7 @@ mod tests { 0xD7, 0x63, 0xF8, 0xCF, 0xC0, 0x00, 0x00, 0x00, 0x02, 0x00, 0x01, 0xE2, 0x21, 0xBC, 0x33, 0x00, 0x00, 0x00, 0x00, 0x49, 0x45, 0x4E, 0x44, 0xAE, 0x42, 0x60, 0x82, ]; - std::fs::write(&png_path, &png_bytes).unwrap(); + tokio::fs::write(&png_path, &png_bytes).await.unwrap(); let tool = ImageInfoTool::new(test_security()); let result = tool @@ -488,6 +488,6 @@ mod tests { assert!(result.success); assert!(result.output.contains("data:image/png;base64,")); - let _ = std::fs::remove_dir_all(&dir); + let _ = tokio::fs::remove_dir_all(&dir).await; } } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index a472afc..fa13949 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,3 +1,20 @@ +//! 
Tool subsystem for agent-callable capabilities. +//! +//! This module implements the tool execution surface exposed to the LLM during +//! agentic loops. Each tool implements the [`Tool`] trait defined in [`traits`], +//! which requires a name, description, JSON parameter schema, and an async +//! `execute` method returning a structured [`ToolResult`]. +//! +//! Tools are assembled into registries by [`default_tools`] (shell, file read/write) +//! and [`all_tools`] (full set including memory, browser, cron, HTTP, delegation, +//! and optional integrations). Security policy enforcement is injected via +//! [`SecurityPolicy`](crate::security::SecurityPolicy) at construction time. +//! +//! # Extension +//! +//! To add a new tool, implement [`Tool`] in a new submodule and register it in +//! [`all_tools_with_runtime`]. See `AGENTS.md` §7.3 for the full change playbook. + pub mod browser; pub mod browser_open; pub mod composio; @@ -227,10 +244,19 @@ pub fn all_tools_with_runtime( let trimmed_value = value.trim(); (!trimmed_value.is_empty()).then(|| trimmed_value.to_owned()) }); - tools.push(Box::new(DelegateTool::new( + tools.push(Box::new(DelegateTool::new_with_options( delegate_agents, delegate_fallback_credential, security.clone(), + crate::providers::ProviderRuntimeOptions { + auth_profile_override: None, + zeroclaw_dir: root_config + .config_path + .parent() + .map(std::path::PathBuf::from), + secrets_encrypt: root_config.secrets.encrypt, + reasoning_enabled: root_config.runtime.reasoning_enabled, + }, ))); } diff --git a/src/tools/proxy_config.rs b/src/tools/proxy_config.rs index 3ddde9e..213a57e 100644 --- a/src/tools/proxy_config.rs +++ b/src/tools/proxy_config.rs @@ -3,6 +3,7 @@ use crate::config::{ runtime_proxy_config, set_runtime_proxy_config, Config, ProxyConfig, ProxyScope, }; use crate::security::SecurityPolicy; +use crate::util::MaybeSet; use async_trait::async_trait; use serde_json::{json, Value}; use std::fs; @@ -93,16 +94,13 @@ impl ProxyConfigTool { 
anyhow::bail!("'{field}' must be a string or string[]") } - fn parse_optional_string_update( - args: &Value, - field: &str, - ) -> anyhow::Result>> { + fn parse_optional_string_update(args: &Value, field: &str) -> anyhow::Result> { let Some(raw) = args.get(field) else { - return Ok(None); + return Ok(MaybeSet::Unset); }; if raw.is_null() { - return Ok(Some(None)); + return Ok(MaybeSet::Null); } let value = raw @@ -110,7 +108,13 @@ impl ProxyConfigTool { .ok_or_else(|| anyhow::anyhow!("'{field}' must be a string or null"))? .trim() .to_string(); - Ok(Some((!value.is_empty()).then_some(value))) + + let output = if value.is_empty() { + MaybeSet::Null + } else { + MaybeSet::Set(value) + }; + Ok(output) } fn env_snapshot() -> Value { @@ -164,7 +168,7 @@ impl ProxyConfigTool { }) } - fn handle_set(&self, args: &Value) -> anyhow::Result { + async fn handle_set(&self, args: &Value) -> anyhow::Result { let mut cfg = self.load_config_without_env()?; let previous_scope = cfg.proxy.scope; let mut proxy = cfg.proxy.clone(); @@ -185,23 +189,45 @@ impl ProxyConfigTool { })?; } - if let Some(update) = Self::parse_optional_string_update(args, "http_proxy")? { - proxy.http_proxy = update; - touched_proxy_url = true; + match Self::parse_optional_string_update(args, "http_proxy")? { + MaybeSet::Set(update) => { + proxy.http_proxy = Some(update); + touched_proxy_url = true; + } + MaybeSet::Null => { + proxy.http_proxy = None; + touched_proxy_url = true; + } + MaybeSet::Unset => {} } - if let Some(update) = Self::parse_optional_string_update(args, "https_proxy")? { - proxy.https_proxy = update; - touched_proxy_url = true; + match Self::parse_optional_string_update(args, "https_proxy")? { + MaybeSet::Set(update) => { + proxy.https_proxy = Some(update); + touched_proxy_url = true; + } + MaybeSet::Null => { + proxy.https_proxy = None; + touched_proxy_url = true; + } + MaybeSet::Unset => {} } - if let Some(update) = Self::parse_optional_string_update(args, "all_proxy")? 
{ - proxy.all_proxy = update; - touched_proxy_url = true; + match Self::parse_optional_string_update(args, "all_proxy")? { + MaybeSet::Set(update) => { + proxy.all_proxy = Some(update); + touched_proxy_url = true; + } + MaybeSet::Null => { + proxy.all_proxy = None; + touched_proxy_url = true; + } + MaybeSet::Unset => {} } if let Some(no_proxy_raw) = args.get("no_proxy") { proxy.no_proxy = Self::parse_string_list(no_proxy_raw, "no_proxy")?; + touched_proxy_url = true; } if let Some(services_raw) = args.get("services") { @@ -209,7 +235,9 @@ impl ProxyConfigTool { } if args.get("enabled").is_none() && touched_proxy_url { - proxy.enabled = true; + // Keep auto-enable behavior when users provide a proxy URL, but + // auto-disable when all proxy URLs are cleared in the same update. + proxy.enabled = proxy.has_any_proxy_url(); } proxy.no_proxy = proxy.normalized_no_proxy(); @@ -217,7 +245,7 @@ impl ProxyConfigTool { proxy.validate()?; cfg.proxy = proxy.clone(); - cfg.save()?; + cfg.save().await?; set_runtime_proxy_config(proxy.clone()); if proxy.enabled && proxy.scope == ProxyScope::Environment { @@ -237,11 +265,11 @@ impl ProxyConfigTool { }) } - fn handle_disable(&self, args: &Value) -> anyhow::Result { + async fn handle_disable(&self, args: &Value) -> anyhow::Result { let mut cfg = self.load_config_without_env()?; let clear_env_default = cfg.proxy.scope == ProxyScope::Environment; cfg.proxy.enabled = false; - cfg.save()?; + cfg.save().await?; set_runtime_proxy_config(cfg.proxy.clone()); @@ -384,8 +412,8 @@ impl Tool for ProxyConfigTool { } match action.as_str() { - "set" => self.handle_set(&args), - "disable" => self.handle_disable(&args), + "set" => self.handle_set(&args).await, + "disable" => self.handle_disable(&args).await, "apply_env" => self.handle_apply_env(), "clear_env" => self.handle_clear_env(), _ => unreachable!("handled above"), @@ -421,20 +449,20 @@ mod tests { }) } - fn test_config(tmp: &TempDir) -> Arc { + async fn test_config(tmp: &TempDir) -> Arc { 
let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - config.save().unwrap(); + config.save().await.unwrap(); Arc::new(config) } #[tokio::test] async fn list_services_action_returns_known_keys() { let tmp = TempDir::new().unwrap(); - let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); + let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security()); let result = tool .execute(json!({"action": "list_services"})) @@ -448,7 +476,7 @@ mod tests { #[tokio::test] async fn set_scope_services_requires_services_entries() { let tmp = TempDir::new().unwrap(); - let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); + let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security()); let result = tool .execute(json!({ @@ -471,7 +499,7 @@ mod tests { #[tokio::test] async fn set_and_get_round_trip_proxy_scope() { let tmp = TempDir::new().unwrap(); - let tool = ProxyConfigTool::new(test_config(&tmp), test_security()); + let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security()); let set_result = tool .execute(json!({ @@ -489,4 +517,34 @@ mod tests { assert!(get_result.output.contains("provider.openai")); assert!(get_result.output.contains("services")); } + + #[tokio::test] + async fn set_null_proxy_url_clears_existing_value() { + let tmp = TempDir::new().unwrap(); + let tool = ProxyConfigTool::new(test_config(&tmp).await, test_security()); + + let set_result = tool + .execute(json!({ + "action": "set", + "http_proxy": "http://127.0.0.1:7890" + })) + .await + .unwrap(); + assert!(set_result.success, "{:?}", set_result.error); + + let clear_result = tool + .execute(json!({ + "action": "set", + "http_proxy": null + })) + .await + .unwrap(); + assert!(clear_result.success, "{:?}", clear_result.error); + + let get_result = tool.execute(json!({"action": "get"})).await.unwrap(); + assert!(get_result.success); + let parsed: Value = 
serde_json::from_str(&get_result.output).unwrap(); + assert!(parsed["proxy"]["http_proxy"].is_null()); + assert!(parsed["runtime_proxy"]["http_proxy"].is_null()); + } } diff --git a/src/tools/pushover.rs b/src/tools/pushover.rs index 23d980b..7e64e9a 100644 --- a/src/tools/pushover.rs +++ b/src/tools/pushover.rs @@ -41,9 +41,10 @@ impl PushoverTool { ) } - fn get_credentials(&self) -> anyhow::Result<(String, String)> { + async fn get_credentials(&self) -> anyhow::Result<(String, String)> { let env_path = self.workspace_dir.join(".env"); - let content = std::fs::read_to_string(&env_path) + let content = tokio::fs::read_to_string(&env_path) + .await .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?; let mut token = None; @@ -99,7 +100,6 @@ impl Tool for PushoverTool { }, "priority": { "type": "integer", - "enum": [-2, -1, 0, 1, 2], "description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)" }, "sound": { @@ -154,7 +154,7 @@ impl Tool for PushoverTool { let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from); - let (token, user_key) = self.get_credentials()?; + let (token, user_key) = self.get_credentials().await?; let mut form = reqwest::multipart::Form::new() .text("token", token) @@ -270,8 +270,8 @@ mod tests { assert!(required.contains(&serde_json::Value::String("message".to_string()))); } - #[test] - fn credentials_parsed_from_env_file() { + #[tokio::test] + async fn credentials_parsed_from_env_file() { let tmp = TempDir::new().unwrap(); let env_path = tmp.path().join(".env"); fs::write( @@ -284,7 +284,7 @@ mod tests { test_security(AutonomyLevel::Full, 100), tmp.path().to_path_buf(), ); - let result = tool.get_credentials(); + let result = tool.get_credentials().await; assert!(result.is_ok()); let (token, user_key) = result.unwrap(); @@ -292,20 +292,20 @@ mod tests { assert_eq!(user_key, "userkey456"); } - #[test] - fn credentials_fail_without_env_file() { + 
#[tokio::test] + async fn credentials_fail_without_env_file() { let tmp = TempDir::new().unwrap(); let tool = PushoverTool::new( test_security(AutonomyLevel::Full, 100), tmp.path().to_path_buf(), ); - let result = tool.get_credentials(); + let result = tool.get_credentials().await; assert!(result.is_err()); } - #[test] - fn credentials_fail_without_token() { + #[tokio::test] + async fn credentials_fail_without_token() { let tmp = TempDir::new().unwrap(); let env_path = tmp.path().join(".env"); fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap(); @@ -314,13 +314,13 @@ mod tests { test_security(AutonomyLevel::Full, 100), tmp.path().to_path_buf(), ); - let result = tool.get_credentials(); + let result = tool.get_credentials().await; assert!(result.is_err()); } - #[test] - fn credentials_fail_without_user_key() { + #[tokio::test] + async fn credentials_fail_without_user_key() { let tmp = TempDir::new().unwrap(); let env_path = tmp.path().join(".env"); fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap(); @@ -329,13 +329,13 @@ mod tests { test_security(AutonomyLevel::Full, 100), tmp.path().to_path_buf(), ); - let result = tool.get_credentials(); + let result = tool.get_credentials().await; assert!(result.is_err()); } - #[test] - fn credentials_ignore_comments() { + #[tokio::test] + async fn credentials_ignore_comments() { let tmp = TempDir::new().unwrap(); let env_path = tmp.path().join(".env"); fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap(); @@ -344,7 +344,7 @@ mod tests { test_security(AutonomyLevel::Full, 100), tmp.path().to_path_buf(), ); - let result = tool.get_credentials(); + let result = tool.get_credentials().await; assert!(result.is_ok()); let (token, user_key) = result.unwrap(); @@ -372,8 +372,8 @@ mod tests { assert!(schema["properties"].get("sound").is_some()); } - #[test] - fn credentials_support_export_and_quoted_values() { + #[tokio::test] + async fn 
credentials_support_export_and_quoted_values() { let tmp = TempDir::new().unwrap(); let env_path = tmp.path().join(".env"); fs::write( @@ -386,7 +386,7 @@ mod tests { test_security(AutonomyLevel::Full, 100), tmp.path().to_path_buf(), ); - let result = tool.get_credentials(); + let result = tool.get_credentials().await; assert!(result.is_ok()); let (token, user_key) = result.unwrap(); diff --git a/src/tools/schedule.rs b/src/tools/schedule.rs index 96c3023..fcf46fe 100644 --- a/src/tools/schedule.rs +++ b/src/tools/schedule.rs @@ -368,14 +368,16 @@ mod tests { use crate::security::AutonomyLevel; use tempfile::TempDir; - fn test_setup() -> (TempDir, Config, Arc) { + async fn test_setup() -> (TempDir, Config, Arc) { let tmp = TempDir::new().unwrap(); let config = Config { workspace_dir: tmp.path().join("workspace"), config_path: tmp.path().join("config.toml"), ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); let security = Arc::new(SecurityPolicy::from_config( &config.autonomy, &config.workspace_dir, @@ -383,9 +385,9 @@ mod tests { (tmp, config, security) } - #[test] - fn tool_name_and_schema() { - let (_tmp, config, security) = test_setup(); + #[tokio::test] + async fn tool_name_and_schema() { + let (_tmp, config, security) = test_setup().await; let tool = ScheduleTool::new(security, config); assert_eq!(tool.name(), "schedule"); let schema = tool.parameters_schema(); @@ -394,7 +396,7 @@ mod tests { #[tokio::test] async fn list_empty() { - let (_tmp, config, security) = test_setup(); + let (_tmp, config, security) = test_setup().await; let tool = ScheduleTool::new(security, config); let result = tool.execute(json!({"action": "list"})).await.unwrap(); @@ -404,7 +406,7 @@ mod tests { #[tokio::test] async fn create_get_and_cancel_roundtrip() { - let (_tmp, config, security) = test_setup(); + let (_tmp, config, security) = test_setup().await; let tool = 
ScheduleTool::new(security, config); let create = tool @@ -440,7 +442,7 @@ mod tests { #[tokio::test] async fn once_and_pause_resume_aliases_work() { - let (_tmp, config, security) = test_setup(); + let (_tmp, config, security) = test_setup().await; let tool = ScheduleTool::new(security, config); let once = tool @@ -489,7 +491,9 @@ mod tests { }, ..Config::default() }; - std::fs::create_dir_all(&config.workspace_dir).unwrap(); + tokio::fs::create_dir_all(&config.workspace_dir) + .await + .unwrap(); let security = Arc::new(SecurityPolicy::from_config( &config.autonomy, &config.workspace_dir, @@ -514,7 +518,7 @@ mod tests { #[tokio::test] async fn unknown_action_returns_failure() { - let (_tmp, config, security) = test_setup(); + let (_tmp, config, security) = test_setup().await; let tool = ScheduleTool::new(security, config); let result = tool.execute(json!({"action": "explode"})).await.unwrap(); diff --git a/src/tools/shell.rs b/src/tools/shell.rs index 031ed4b..4392bdb 100644 --- a/src/tools/shell.rs +++ b/src/tools/shell.rs @@ -198,7 +198,7 @@ mod tests { assert!(schema["properties"]["command"].is_object()); assert!(schema["required"] .as_array() - .unwrap() + .expect("schema required field should be an array") .contains(&json!("command"))); assert!(schema["properties"]["approved"].is_object()); } @@ -209,7 +209,7 @@ mod tests { let result = tool .execute(json!({"command": "echo hello"})) .await - .unwrap(); + .expect("echo command execution should succeed"); assert!(result.success); assert!(result.output.trim().contains("hello")); assert!(result.error.is_none()); @@ -218,7 +218,10 @@ mod tests { #[tokio::test] async fn shell_blocks_disallowed_command() { let tool = ShellTool::new(test_security(AutonomyLevel::Supervised), test_runtime()); - let result = tool.execute(json!({"command": "rm -rf /"})).await.unwrap(); + let result = tool + .execute(json!({"command": "rm -rf /"})) + .await + .expect("disallowed command execution should return a result"); 
assert!(!result.success); let error = result.error.as_deref().unwrap_or(""); assert!(error.contains("not allowed") || error.contains("high-risk")); @@ -227,9 +230,16 @@ mod tests { #[tokio::test] async fn shell_blocks_readonly() { let tool = ShellTool::new(test_security(AutonomyLevel::ReadOnly), test_runtime()); - let result = tool.execute(json!({"command": "ls"})).await.unwrap(); + let result = tool + .execute(json!({"command": "ls"})) + .await + .expect("readonly command execution should return a result"); assert!(!result.success); - assert!(result.error.as_ref().unwrap().contains("not allowed")); + assert!(result + .error + .as_ref() + .expect("error field should be present for blocked command") + .contains("not allowed")); } #[tokio::test] @@ -253,7 +263,7 @@ mod tests { let result = tool .execute(json!({"command": "ls /nonexistent_dir_xyz"})) .await - .unwrap(); + .expect("command with nonexistent path should return a result"); assert!(!result.success); } @@ -296,7 +306,10 @@ mod tests { let _g2 = EnvGuard::set("ZEROCLAW_API_KEY", "sk-test-secret-67890"); let tool = ShellTool::new(test_security_with_env_cmd(), test_runtime()); - let result = tool.execute(json!({"command": "env"})).await.unwrap(); + let result = tool + .execute(json!({"command": "env"})) + .await + .expect("env command execution should succeed"); assert!(result.success); assert!( !result.output.contains("sk-test-secret-12345"), @@ -315,7 +328,7 @@ mod tests { let result = tool .execute(json!({"command": "echo $HOME"})) .await - .unwrap(); + .expect("echo HOME command should succeed"); assert!(result.success); assert!( !result.output.trim().is_empty(), @@ -325,7 +338,7 @@ mod tests { let result = tool .execute(json!({"command": "echo $PATH"})) .await - .unwrap(); + .expect("echo PATH command should succeed"); assert!(result.success); assert!( !result.output.trim().is_empty(), @@ -346,7 +359,7 @@ mod tests { let denied = tool .execute(json!({"command": "touch zeroclaw_shell_approval_test"})) 
.await - .unwrap(); + .expect("unapproved command should return a result"); assert!(!denied.success); assert!(denied .error @@ -360,10 +373,11 @@ mod tests { "approved": true })) .await - .unwrap(); + .expect("approved command execution should succeed"); assert!(allowed.success); - let _ = std::fs::remove_file(std::env::temp_dir().join("zeroclaw_shell_approval_test")); + let _ = + tokio::fs::remove_file(std::env::temp_dir().join("zeroclaw_shell_approval_test")).await; } // ── §5.2 Shell timeout enforcement tests ───────────────── @@ -419,7 +433,10 @@ mod tests { ..SecurityPolicy::default() }); let tool = ShellTool::new(security, test_runtime()); - let result = tool.execute(json!({"command": "echo test"})).await.unwrap(); + let result = tool + .execute(json!({"command": "echo test"})) + .await + .expect("rate-limited command should return a result"); assert!(!result.success); assert!(result.error.as_deref().unwrap_or("").contains("Rate limit")); } diff --git a/src/tools/web_search_tool.rs b/src/tools/web_search_tool.rs index fa3b750..974410e 100644 --- a/src/tools/web_search_tool.rs +++ b/src/tools/web_search_tool.rs @@ -219,7 +219,10 @@ impl Tool for WebSearchTool { let result = match self.provider.as_str() { "duckduckgo" | "ddg" => self.search_duckduckgo(query).await?, "brave" => self.search_brave(query).await?, - _ => anyhow::bail!("Unknown search provider: {}", self.provider), + _ => anyhow::bail!( + "Unknown search provider: '{}'. Set tools.web_search.provider to 'duckduckgo' or 'brave' in config.toml", + self.provider + ), }; Ok(ToolResult { diff --git a/src/util.rs b/src/util.rs index 9a218e7..85c7856 100644 --- a/src/util.rs +++ b/src/util.rs @@ -43,6 +43,13 @@ pub fn truncate_with_ellipsis(s: &str, max_chars: usize) -> String { } } +/// Utility enum for handling optional values. 
+pub enum MaybeSet { + Set(T), + Unset, + Null, +} + #[cfg(test)] mod tests { use super::*; diff --git a/tests/agent_e2e.rs b/tests/agent_e2e.rs index 9ca3287..6bdfb36 100644 --- a/tests/agent_e2e.rs +++ b/tests/agent_e2e.rs @@ -13,11 +13,15 @@ use serde_json::json; use std::sync::{Arc, Mutex}; use zeroclaw::agent::agent::Agent; use zeroclaw::agent::dispatcher::{NativeToolDispatcher, XmlToolDispatcher}; +use zeroclaw::agent::memory_loader::MemoryLoader; use zeroclaw::config::MemoryConfig; use zeroclaw::memory; use zeroclaw::memory::Memory; use zeroclaw::observability::{NoopObserver, Observer}; -use zeroclaw::providers::{ChatRequest, ChatResponse, Provider, ToolCall}; +use zeroclaw::providers::traits::ChatMessage; +use zeroclaw::providers::{ + ChatRequest, ChatResponse, ConversationMessage, Provider, ProviderRuntimeOptions, ToolCall, +}; use zeroclaw::tools::{Tool, ToolResult}; // ───────────────────────────────────────────────────────────────────────────── @@ -138,6 +142,79 @@ impl Tool for CountingTool { } } +/// Mock provider that returns scripted responses AND records every request. +/// Pattern from `ScriptedProvider` in `src/agent/tests.rs`. 
+struct RecordingProvider { + responses: Mutex>, + recorded_requests: Arc>>>, +} + +impl RecordingProvider { + fn new(responses: Vec) -> (Self, Arc>>>) { + let recorded = Arc::new(Mutex::new(Vec::new())); + let provider = Self { + responses: Mutex::new(responses), + recorded_requests: recorded.clone(), + }; + (provider, recorded) + } +} + +#[async_trait] +impl Provider for RecordingProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("fallback".into()) + } + + async fn chat( + &self, + request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + self.recorded_requests + .lock() + .unwrap() + .push(request.messages.to_vec()); + + let mut guard = self.responses.lock().unwrap(); + if guard.is_empty() { + return Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }); + } + Ok(guard.remove(0)) + } +} + +/// Mock memory loader that returns a static context string, +/// simulating RAG recall without a real memory backend. 
+struct StaticMemoryLoader { + context: String, +} + +impl StaticMemoryLoader { + fn new(context: &str) -> Self { + Self { + context: context.to_string(), + } + } +} + +#[async_trait] +impl MemoryLoader for StaticMemoryLoader { + async fn load_context(&self, _memory: &dyn Memory, _user_message: &str) -> Result { + Ok(self.context.clone()) + } +} + // ───────────────────────────────────────────────────────────────────────────── // Test helpers // ───────────────────────────────────────────────────────────────────────────── @@ -192,6 +269,26 @@ fn build_agent_xml(provider: Box, tools: Vec>) -> Ag .unwrap() } +fn build_recording_agent( + provider: Box, + tools: Vec>, + memory_loader: Option>, +) -> Agent { + let mut builder = Agent::builder() + .provider(provider) + .tools(tools) + .memory(make_memory()) + .observer(make_observer()) + .tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::env::temp_dir()); + + if let Some(loader) = memory_loader { + builder = builder.memory_loader(loader); + } + + builder.build().unwrap() +} + // ═════════════════════════════════════════════════════════════════════════════ // E2E smoke tests — full agent turn cycle // ═════════════════════════════════════════════════════════════════════════════ @@ -352,3 +449,243 @@ async fn e2e_parallel_tool_dispatch() { ); assert_eq!(*count.lock().unwrap(), 2); } + +// ═════════════════════════════════════════════════════════════════════════════ +// Multi-turn history fidelity & memory enrichment tests +// ═════════════════════════════════════════════════════════════════════════════ + +/// Validates that multi-turn conversation correctly accumulates history +/// and passes growing message sequences to the provider on each turn. 
+#[tokio::test] +async fn e2e_multi_turn_history_fidelity() { + let (provider, recorded) = RecordingProvider::new(vec![ + text_response("response 1"), + text_response("response 2"), + text_response("response 3"), + ]); + + let mut agent = build_recording_agent(Box::new(provider), vec![], None); + + let r1 = agent.turn("msg 1").await.unwrap(); + assert_eq!(r1, "response 1"); + + let r2 = agent.turn("msg 2").await.unwrap(); + assert_eq!(r2, "response 2"); + + let r3 = agent.turn("msg 3").await.unwrap(); + assert_eq!(r3, "response 3"); + + let requests = recorded.lock().unwrap(); + assert_eq!(requests.len(), 3, "Provider should receive 3 requests"); + + // Request 1: system + user("msg 1") + let req1 = &requests[0]; + assert!(req1.len() >= 2); + assert_eq!(req1[0].role, "system"); + assert_eq!(req1[1].role, "user"); + assert!(req1[1].content.contains("msg 1")); + + // Request 2: system + user("msg 1") + assistant("response 1") + user("msg 2") + let req2 = &requests[1]; + let req2_users: Vec<&ChatMessage> = req2.iter().filter(|m| m.role == "user").collect(); + let req2_assts: Vec<&ChatMessage> = req2.iter().filter(|m| m.role == "assistant").collect(); + assert_eq!(req2_users.len(), 2, "Request 2: expected 2 user messages"); + assert_eq!( + req2_assts.len(), + 1, + "Request 2: expected 1 assistant message" + ); + assert!(req2_users[0].content.contains("msg 1")); + assert!(req2_users[1].content.contains("msg 2")); + assert_eq!(req2_assts[0].content, "response 1"); + + // Request 3: full history — 3 user + 2 assistant messages + let req3 = &requests[2]; + let req3_users: Vec<&ChatMessage> = req3.iter().filter(|m| m.role == "user").collect(); + let req3_assts: Vec<&ChatMessage> = req3.iter().filter(|m| m.role == "assistant").collect(); + assert_eq!(req3_users.len(), 3, "Request 3: expected 3 user messages"); + assert_eq!( + req3_assts.len(), + 2, + "Request 3: expected 2 assistant messages" + ); + assert!(req3_users[0].content.contains("msg 1")); + 
assert!(req3_users[1].content.contains("msg 2")); + assert!(req3_users[2].content.contains("msg 3")); + assert_eq!(req3_assts[0].content, "response 1"); + assert_eq!(req3_assts[1].content, "response 2"); + + // Verify agent history: system + 3*(user + assistant) = 7 + let history = agent.history(); + assert_eq!(history.len(), 7); + assert!(matches!(&history[0], ConversationMessage::Chat(c) if c.role == "system")); + assert!(matches!(&history[1], ConversationMessage::Chat(c) if c.role == "user")); + assert!(matches!(&history[2], ConversationMessage::Chat(c) if c.role == "assistant")); + assert!( + matches!(&history[6], ConversationMessage::Chat(c) if c.role == "assistant" && c.content == "response 3") + ); +} + +/// Validates that a custom MemoryLoader injects RAG context into user +/// messages before they reach the provider. +#[tokio::test] +async fn e2e_memory_enrichment_injects_context() { + let (provider, recorded) = RecordingProvider::new(vec![text_response("enriched response")]); + + let memory_context = "[Memory context]\n- user_name: test_user\n\n"; + let loader = StaticMemoryLoader::new(memory_context); + + let mut agent = build_recording_agent(Box::new(provider), vec![], Some(Box::new(loader))); + + let response = agent.turn("hello").await.unwrap(); + assert_eq!(response, "enriched response"); + + // Provider received enriched message + let requests = recorded.lock().unwrap(); + assert_eq!(requests.len(), 1); + let user_msg = requests[0].iter().find(|m| m.role == "user").unwrap(); + assert!( + user_msg.content.starts_with("[Memory context]"), + "User message should start with memory context, got: {}", + user_msg.content, + ); + assert!( + user_msg.content.contains("user_name: test_user"), + "User message should contain memory key-value pair", + ); + assert!( + user_msg.content.ends_with("hello"), + "User message should end with original text, got: {}", + user_msg.content, + ); + + // Agent history also stores enriched message + let history = 
agent.history(); + match &history[1] { + ConversationMessage::Chat(c) => { + assert_eq!(c.role, "user"); + assert!(c.content.starts_with("[Memory context]")); + assert!(c.content.ends_with("hello")); + } + other => panic!("Expected Chat variant for user message, got: {other:?}"), + } +} + +/// Validates multi-turn conversation with memory enrichment: every user +/// message is enriched, and the provider sees the full enriched history. +#[tokio::test] +async fn e2e_multi_turn_with_memory_enrichment() { + let (provider, recorded) = + RecordingProvider::new(vec![text_response("answer 1"), text_response("answer 2")]); + + let memory_context = "[Memory context]\n- project: zeroclaw\n\n"; + let loader = StaticMemoryLoader::new(memory_context); + + let mut agent = build_recording_agent(Box::new(provider), vec![], Some(Box::new(loader))); + + let r1 = agent.turn("first question").await.unwrap(); + assert_eq!(r1, "answer 1"); + + let r2 = agent.turn("second question").await.unwrap(); + assert_eq!(r2, "answer 2"); + + let requests = recorded.lock().unwrap(); + assert_eq!(requests.len(), 2); + + // Turn 1: user message is enriched + let req1_user = requests[0].iter().find(|m| m.role == "user").unwrap(); + assert!(req1_user.content.contains("[Memory context]")); + assert!(req1_user.content.contains("project: zeroclaw")); + assert!(req1_user.content.ends_with("first question")); + + // Turn 2: both user messages enriched, assistant from turn 1 present + let req2_users: Vec<&ChatMessage> = requests[1].iter().filter(|m| m.role == "user").collect(); + assert_eq!(req2_users.len(), 2, "Request 2 should have 2 user messages"); + + // Turn 1 user message still enriched in history + assert!(req2_users[0].content.contains("[Memory context]")); + assert!(req2_users[0].content.ends_with("first question")); + + // Turn 2 user message also enriched + assert!(req2_users[1].content.contains("[Memory context]")); + assert!(req2_users[1].content.ends_with("second question")); + + // Assistant 
response from turn 1 preserved + let req2_assts: Vec<&ChatMessage> = requests[1] + .iter() + .filter(|m| m.role == "assistant") + .collect(); + assert_eq!(req2_assts.len(), 1); + assert_eq!(req2_assts[0].content, "answer 1"); + + // History: system + 2*(enriched_user + assistant) = 5 + assert_eq!(agent.history().len(), 5); +} + +/// Validates that empty memory context passes user message through unmodified. +#[tokio::test] +async fn e2e_empty_memory_context_passthrough() { + let (provider, recorded) = RecordingProvider::new(vec![text_response("plain response")]); + + let loader = StaticMemoryLoader::new(""); + + let mut agent = build_recording_agent(Box::new(provider), vec![], Some(Box::new(loader))); + + let response = agent.turn("hello").await.unwrap(); + assert_eq!(response, "plain response"); + + let requests = recorded.lock().unwrap(); + let user_msg = requests[0].iter().find(|m| m.role == "user").unwrap(); + assert_eq!( + user_msg.content, "hello", + "Empty context should not prepend anything to user message", + ); +} + +// ═════════════════════════════════════════════════════════════════════════════ +// Live integration test — real OpenAI Codex API (requires credentials) +// ═════════════════════════════════════════════════════════════════════════════ + +/// Sends a real multi-turn conversation to OpenAI Codex and verifies +/// the model retains context from earlier messages. +/// +/// Requires valid OAuth credentials in `~/.zeroclaw/`. +/// Run manually: `cargo test e2e_live_openai_codex_multi_turn -- --ignored` +#[tokio::test] +#[ignore] +async fn e2e_live_openai_codex_multi_turn() { + use zeroclaw::providers::openai_codex::OpenAiCodexProvider; + use zeroclaw::providers::traits::Provider; + + let provider = OpenAiCodexProvider::new(&ProviderRuntimeOptions::default()); + let model = "gpt-5.3-codex"; + + // Turn 1: establish a fact + let messages_turn1 = vec![ + ChatMessage::system("You are a concise assistant. 
Reply in one short sentence."), + ChatMessage::user("The secret word is \"zephyr\". Just confirm you noted it."), + ]; + let response1 = provider + .chat_with_history(&messages_turn1, model, 0.0) + .await; + assert!(response1.is_ok(), "Turn 1 failed: {:?}", response1.err()); + let r1 = response1.unwrap(); + assert!(!r1.is_empty(), "Turn 1 returned empty response"); + + // Turn 2: ask the model to recall the fact + let messages_turn2 = vec![ + ChatMessage::system("You are a concise assistant. Reply in one short sentence."), + ChatMessage::user("The secret word is \"zephyr\". Just confirm you noted it."), + ChatMessage::assistant(&r1), + ChatMessage::user("What is the secret word?"), + ]; + let response2 = provider + .chat_with_history(&messages_turn2, model, 0.0) + .await; + assert!(response2.is_ok(), "Turn 2 failed: {:?}", response2.err()); + let r2 = response2.unwrap().to_lowercase(); + assert!( + r2.contains("zephyr"), + "Model should recall 'zephyr' from history, got: {r2}", + ); +} diff --git a/tests/agent_loop_robustness.rs b/tests/agent_loop_robustness.rs new file mode 100644 index 0000000..fadcd9f --- /dev/null +++ b/tests/agent_loop_robustness.rs @@ -0,0 +1,440 @@ +//! TG4: Agent Loop Robustness Tests +//! +//! Prevents: Pattern 4 — Agent loop & tool call processing bugs (13% of user bugs). +//! Issues: #746, #418, #777, #848 +//! +//! Tests agent behavior with malformed tool calls, empty responses, +//! max iteration limits, and cascading tool failures using mock providers. +//! Complements inline parse_tool_calls tests in `src/agent/loop_.rs`. 
+ +use anyhow::Result; +use async_trait::async_trait; +use serde_json::json; +use std::sync::{Arc, Mutex}; +use zeroclaw::agent::agent::Agent; +use zeroclaw::agent::dispatcher::NativeToolDispatcher; +use zeroclaw::config::MemoryConfig; +use zeroclaw::memory; +use zeroclaw::memory::Memory; +use zeroclaw::observability::{NoopObserver, Observer}; +use zeroclaw::providers::{ChatRequest, ChatResponse, Provider, ToolCall}; +use zeroclaw::tools::{Tool, ToolResult}; + +// ───────────────────────────────────────────────────────────────────────────── +// Mock infrastructure +// ───────────────────────────────────────────────────────────────────────────── + +struct MockProvider { + responses: Mutex>, +} + +impl MockProvider { + fn new(responses: Vec) -> Self { + Self { + responses: Mutex::new(responses), + } + } +} + +#[async_trait] +impl Provider for MockProvider { + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> Result { + Ok("fallback".into()) + } + + async fn chat( + &self, + _request: ChatRequest<'_>, + _model: &str, + _temperature: f64, + ) -> Result { + let mut guard = self.responses.lock().unwrap(); + if guard.is_empty() { + return Ok(ChatResponse { + text: Some("done".into()), + tool_calls: vec![], + }); + } + Ok(guard.remove(0)) + } +} + +struct EchoTool; + +#[async_trait] +impl Tool for EchoTool { + fn name(&self) -> &str { + "echo" + } + fn description(&self) -> &str { + "Echoes the input message" + } + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "message": {"type": "string"} + } + }) + } + async fn execute(&self, args: serde_json::Value) -> Result { + let msg = args + .get("message") + .and_then(|v| v.as_str()) + .unwrap_or("(empty)") + .to_string(); + Ok(ToolResult { + success: true, + output: msg, + error: None, + }) + } +} + +/// Tool that always fails, simulating a broken external service +struct FailingTool; + 
+#[async_trait] +impl Tool for FailingTool { + fn name(&self) -> &str { + "failing_tool" + } + fn description(&self) -> &str { + "Always fails" + } + fn parameters_schema(&self) -> serde_json::Value { + json!({"type": "object"}) + } + async fn execute(&self, _args: serde_json::Value) -> Result { + Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Service unavailable: connection timeout".into()), + }) + } +} + +/// Tool that tracks invocations +struct CountingTool { + count: Arc>, +} + +impl CountingTool { + fn new() -> (Self, Arc>) { + let count = Arc::new(Mutex::new(0)); + ( + Self { + count: count.clone(), + }, + count, + ) + } +} + +#[async_trait] +impl Tool for CountingTool { + fn name(&self) -> &str { + "counter" + } + fn description(&self) -> &str { + "Counts invocations" + } + fn parameters_schema(&self) -> serde_json::Value { + json!({"type": "object"}) + } + async fn execute(&self, _args: serde_json::Value) -> Result { + let mut c = self.count.lock().unwrap(); + *c += 1; + Ok(ToolResult { + success: true, + output: format!("call #{}", *c), + error: None, + }) + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Test helpers +// ───────────────────────────────────────────────────────────────────────────── + +fn make_memory() -> Arc { + let cfg = MemoryConfig { + backend: "none".into(), + ..MemoryConfig::default() + }; + Arc::from(memory::create_memory(&cfg, &std::env::temp_dir(), None).unwrap()) +} + +fn make_observer() -> Arc { + Arc::from(NoopObserver {}) +} + +fn text_response(text: &str) -> ChatResponse { + ChatResponse { + text: Some(text.into()), + tool_calls: vec![], + } +} + +fn tool_response(calls: Vec) -> ChatResponse { + ChatResponse { + text: Some(String::new()), + tool_calls: calls, + } +} + +fn build_agent(provider: Box, tools: Vec>) -> Agent { + Agent::builder() + .provider(provider) + .tools(tools) + .memory(make_memory()) + .observer(make_observer()) + 
.tool_dispatcher(Box::new(NativeToolDispatcher)) + .workspace_dir(std::env::temp_dir()) + .build() + .unwrap() +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.1: Malformed tool call recovery +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should recover when LLM returns text with residual XML tags (#746) +#[tokio::test] +async fn agent_recovers_from_text_with_xml_residue() { + let provider = Box::new(MockProvider::new(vec![text_response( + "Here is the result. Some leftover text after.", + )])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("test").await.unwrap(); + assert!( + !response.is_empty(), + "agent should produce non-empty response despite XML residue" + ); +} + +/// Agent should handle tool call with empty arguments gracefully +#[tokio::test] +async fn agent_handles_tool_call_with_empty_arguments() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: "{}".into(), + }]), + text_response("Tool with empty args executed"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("call with empty args").await.unwrap(); + assert!(!response.is_empty()); +} + +/// Agent should handle unknown tool name without crashing (#848 related) +#[tokio::test] +async fn agent_handles_nonexistent_tool_gracefully() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "absolutely_nonexistent_tool".into(), + arguments: "{}".into(), + }]), + text_response("Recovered from unknown tool"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("call missing tool").await.unwrap(); + assert!( + !response.is_empty(), + "agent should recover from unknown tool" + ); +} + +// 
═════════════════════════════════════════════════════════════════════════════ +// TG4.2: Tool failure cascade handling (#848) +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should handle repeated tool failures without infinite loop +#[tokio::test] +async fn agent_handles_failing_tool() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "failing_tool".into(), + arguments: "{}".into(), + }]), + text_response("Tool failed but I recovered"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(FailingTool)]); + let response = agent.turn("use failing tool").await.unwrap(); + assert!( + !response.is_empty(), + "agent should produce response even after tool failure" + ); +} + +/// Agent should handle mixed tool calls (some succeed, some fail) +#[tokio::test] +async fn agent_handles_mixed_tool_success_and_failure() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ + ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "success"}"#.into(), + }, + ToolCall { + id: "tc2".into(), + name: "failing_tool".into(), + arguments: "{}".into(), + }, + ]), + text_response("Mixed results processed"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool), Box::new(FailingTool)]); + let response = agent.turn("mixed tools").await.unwrap(); + assert!(!response.is_empty()); +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.3: Iteration limit enforcement (#777) +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should not exceed max_tool_iterations (default=10) even with +/// a provider that keeps returning tool calls +#[tokio::test] +async fn agent_respects_max_tool_iterations() { + let (counting_tool, count) = CountingTool::new(); + + // Create 20 tool call responses - more than the default limit of 10 + let mut 
responses: Vec<ChatResponse> = (0..20)
+ .map(|i| {
+ tool_response(vec![ToolCall {
+ id: format!("tc_{i}"),
+ name: "counter".into(),
+ arguments: "{}".into(),
+ }])
+ })
+ .collect();
+ // Add a final text response that would be used if limit is reached
+ responses.push(text_response("Final response after iterations"));
+
+ let provider = Box::new(MockProvider::new(responses));
+ let mut agent = build_agent(provider, vec![Box::new(counting_tool)]);
+
+ // Agent should complete (either by hitting iteration limit or running out of responses)
+ let result = agent.turn("keep calling tools").await;
+ // The agent should complete without hanging
+ assert!(result.is_ok() || result.is_err());
+
+ let invocations = *count.lock().unwrap();
+ assert!(
+ invocations <= 10,
+ "tool invocations ({invocations}) should not exceed default max_tool_iterations (10)"
+ );
+}
+
+// ═════════════════════════════════════════════════════════════════════════════
+// TG4.4: Empty and whitespace responses
+// ═════════════════════════════════════════════════════════════════════════════
+
+/// Agent should handle empty text response from provider (#418 related)
+#[tokio::test]
+async fn agent_handles_empty_provider_response() {
+ let provider = Box::new(MockProvider::new(vec![ChatResponse {
+ text: Some(String::new()),
+ tool_calls: vec![],
+ }]));
+
+ let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
+ // Should not panic
+ let _result = agent.turn("test").await;
+}
+
+/// Agent should handle None text response from provider
+#[tokio::test]
+async fn agent_handles_none_text_response() {
+ let provider = Box::new(MockProvider::new(vec![ChatResponse {
+ text: None,
+ tool_calls: vec![],
+ }]));
+
+ let mut agent = build_agent(provider, vec![Box::new(EchoTool)]);
+ let _result = agent.turn("test").await;
+}
+
+/// Agent should handle whitespace-only response
+#[tokio::test]
+async fn agent_handles_whitespace_only_response() {
+ let provider = Box::new(MockProvider::new(vec![text_response(" 
\n\t ")])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let _result = agent.turn("test").await; +} + +// ═════════════════════════════════════════════════════════════════════════════ +// TG4.5: Tool call with special content +// ═════════════════════════════════════════════════════════════════════════════ + +/// Agent should handle tool arguments with unicode content +#[tokio::test] +async fn agent_handles_unicode_tool_arguments() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "こんにちは世界 🌍"}"#.into(), + }]), + text_response("Unicode tool executed"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("unicode test").await.unwrap(); + assert!(!response.is_empty()); +} + +/// Agent should handle tool arguments with nested JSON +#[tokio::test] +async fn agent_handles_nested_json_tool_arguments() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "{\"nested\": true, \"deep\": {\"level\": 3}}"}"#.into(), + }]), + text_response("Nested JSON tool executed"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("nested json test").await.unwrap(); + assert!(!response.is_empty()); +} + +/// Agent should handle tool call followed by immediate text (no second LLM call) +#[tokio::test] +async fn agent_handles_sequential_tool_then_text() { + let provider = Box::new(MockProvider::new(vec![ + tool_response(vec![ToolCall { + id: "tc1".into(), + name: "echo".into(), + arguments: r#"{"message": "step 1"}"#.into(), + }]), + text_response("Final answer after tool"), + ])); + + let mut agent = build_agent(provider, vec![Box::new(EchoTool)]); + let response = agent.turn("two step").await.unwrap(); + assert!( + !response.is_empty(), + "should produce 
final text after tool execution" + ); +} diff --git a/tests/channel_routing.rs b/tests/channel_routing.rs new file mode 100644 index 0000000..178c85a --- /dev/null +++ b/tests/channel_routing.rs @@ -0,0 +1,318 @@ +//! TG3: Channel Message Identity & Routing Tests +//! +//! Prevents: Pattern 3 — Channel message routing & identity bugs (17% of user bugs). +//! Issues: #496, #483, #620, #415, #503 +//! +//! Tests that ChannelMessage fields are used consistently and that the +//! SendMessage → Channel trait contract preserves correct identity semantics. +//! Verifies sender/reply_target field contracts to prevent field swaps. + +use async_trait::async_trait; +use zeroclaw::channels::traits::{Channel, ChannelMessage, SendMessage}; + +// ───────────────────────────────────────────────────────────────────────────── +// ChannelMessage construction and field semantics +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn channel_message_sender_field_holds_platform_user_id() { + // Simulates Telegram: sender should be numeric chat_id, not username + let msg = ChannelMessage { + id: "msg_1".into(), + sender: "123456789".into(), // numeric chat_id + reply_target: "msg_0".into(), + content: "test message".into(), + channel: "telegram".into(), + timestamp: 1700000000, + thread_ts: None, + }; + + assert_eq!(msg.sender, "123456789"); + // Sender should be the platform-level user/chat identifier + assert!( + msg.sender.chars().all(|c| c.is_ascii_digit()), + "Telegram sender should be numeric chat_id, got: {}", + msg.sender + ); +} + +#[test] +fn channel_message_reply_target_distinct_from_sender() { + // Simulates Discord: reply_target should be channel_id, not sender user_id + let msg = ChannelMessage { + id: "msg_1".into(), + sender: "user_987654".into(), // Discord user ID + reply_target: "channel_123".into(), // Discord channel ID for replies + content: "test message".into(), + channel: "discord".into(), + timestamp: 1700000000, + 
thread_ts: None, + }; + + assert_ne!( + msg.sender, msg.reply_target, + "sender and reply_target should be distinct for Discord" + ); + assert_eq!(msg.reply_target, "channel_123"); +} + +#[test] +fn channel_message_fields_not_swapped() { + // Guards against #496 (Telegram) and #483 (Discord) field swap bugs + let msg = ChannelMessage { + id: "msg_42".into(), + sender: "sender_value".into(), + reply_target: "target_value".into(), + content: "payload".into(), + channel: "test".into(), + timestamp: 1700000000, + thread_ts: None, + }; + + assert_eq!( + msg.sender, "sender_value", + "sender field should not be swapped" + ); + assert_eq!( + msg.reply_target, "target_value", + "reply_target field should not be swapped" + ); + assert_ne!( + msg.sender, msg.reply_target, + "sender and reply_target should remain distinct" + ); +} + +#[test] +fn channel_message_preserves_all_fields_on_clone() { + let original = ChannelMessage { + id: "clone_test".into(), + sender: "sender_123".into(), + reply_target: "target_456".into(), + content: "cloned content".into(), + channel: "test_channel".into(), + timestamp: 1700000001, + thread_ts: None, + }; + + let cloned = original.clone(); + + assert_eq!(cloned.id, original.id); + assert_eq!(cloned.sender, original.sender); + assert_eq!(cloned.reply_target, original.reply_target); + assert_eq!(cloned.content, original.content); + assert_eq!(cloned.channel, original.channel); + assert_eq!(cloned.timestamp, original.timestamp); +} + +// ───────────────────────────────────────────────────────────────────────────── +// SendMessage construction +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn send_message_new_sets_content_and_recipient() { + let msg = SendMessage::new("Hello", "recipient_123"); + + assert_eq!(msg.content, "Hello"); + assert_eq!(msg.recipient, "recipient_123"); + assert!(msg.subject.is_none(), "subject should be None by default"); +} + +#[test] +fn 
send_message_with_subject_sets_all_fields() {
+ let msg = SendMessage::with_subject("Hello", "recipient_123", "Re: Test");
+
+ assert_eq!(msg.content, "Hello");
+ assert_eq!(msg.recipient, "recipient_123");
+ assert_eq!(msg.subject.as_deref(), Some("Re: Test"));
+}
+
+#[test]
+fn send_message_recipient_carries_platform_target() {
+ // Verifies that SendMessage::recipient is used as the platform delivery target
+ // For Telegram: this should be the chat_id
+ // For Discord: this should be the channel_id
+ let telegram_msg = SendMessage::new("response", "123456789");
+ assert_eq!(
+ telegram_msg.recipient, "123456789",
+ "Telegram SendMessage recipient should be chat_id"
+ );
+
+ let discord_msg = SendMessage::new("response", "channel_987654");
+ assert_eq!(
+ discord_msg.recipient, "channel_987654",
+ "Discord SendMessage recipient should be channel_id"
+ );
+}
+
+// ─────────────────────────────────────────────────────────────────────────────
+// Channel trait contract: send/listen roundtrip via DummyChannel
+// ─────────────────────────────────────────────────────────────────────────────
+
+/// Test channel that captures sent messages for assertion
+struct CapturingChannel {
+ sent: std::sync::Mutex<Vec<SendMessage>>,
+}
+
+impl CapturingChannel {
+ fn new() -> Self {
+ Self {
+ sent: std::sync::Mutex::new(Vec::new()),
+ }
+ }
+
+ fn sent_messages(&self) -> Vec<SendMessage> {
+ self.sent.lock().unwrap().clone()
+ }
+}
+
+#[async_trait]
+impl Channel for CapturingChannel {
+ fn name(&self) -> &str {
+ "capturing"
+ }
+
+ async fn send(&self, message: &SendMessage) -> anyhow::Result<()> {
+ self.sent.lock().unwrap().push(message.clone());
+ Ok(())
+ }
+
+ async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
+ tx.send(ChannelMessage {
+ id: "listen_1".into(),
+ sender: "test_sender".into(),
+ reply_target: "test_target".into(),
+ content: "incoming".into(),
+ channel: "capturing".into(),
+ timestamp: 1700000000,
+ thread_ts: None,
+ })
+ .await
+ .map_err(|e| 
anyhow::anyhow!(e.to_string())) + } +} + +#[tokio::test] +async fn channel_send_preserves_recipient() { + let channel = CapturingChannel::new(); + let msg = SendMessage::new("Hello", "target_123"); + + channel.send(&msg).await.unwrap(); + + let sent = channel.sent_messages(); + assert_eq!(sent.len(), 1); + assert_eq!(sent[0].recipient, "target_123"); + assert_eq!(sent[0].content, "Hello"); +} + +#[tokio::test] +async fn channel_listen_produces_correct_identity_fields() { + let channel = CapturingChannel::new(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + + channel.listen(tx).await.unwrap(); + let received = rx.recv().await.expect("should receive message"); + + assert_eq!(received.sender, "test_sender"); + assert_eq!(received.reply_target, "test_target"); + assert_ne!( + received.sender, received.reply_target, + "listen() should populate sender and reply_target distinctly" + ); +} + +#[tokio::test] +async fn channel_send_reply_uses_sender_from_listen() { + let channel = CapturingChannel::new(); + let (tx, mut rx) = tokio::sync::mpsc::channel(1); + + // Simulate: listen() → receive message → send reply using sender + channel.listen(tx).await.unwrap(); + let incoming = rx.recv().await.expect("should receive message"); + + // Reply should go to the reply_target, not sender + let reply = SendMessage::new("reply content", &incoming.reply_target); + channel.send(&reply).await.unwrap(); + + let sent = channel.sent_messages(); + assert_eq!(sent.len(), 1); + assert_eq!( + sent[0].recipient, "test_target", + "reply should use reply_target as recipient" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Channel trait default methods +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn channel_health_check_default_returns_true() { + let channel = CapturingChannel::new(); + assert!( + channel.health_check().await, + "default health_check should return true" + ); +} + 
+#[tokio::test] +async fn channel_typing_defaults_succeed() { + let channel = CapturingChannel::new(); + assert!(channel.start_typing("target").await.is_ok()); + assert!(channel.stop_typing("target").await.is_ok()); +} + +#[tokio::test] +async fn channel_draft_defaults() { + let channel = CapturingChannel::new(); + assert!(!channel.supports_draft_updates()); + + let draft_result = channel + .send_draft(&SendMessage::new("draft", "target")) + .await + .unwrap(); + assert!( + draft_result.is_none(), + "default send_draft should return None" + ); + + assert!(channel + .update_draft("target", "msg_1", "updated") + .await + .is_ok()); + assert!(channel + .finalize_draft("target", "msg_1", "final") + .await + .is_ok()); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Multiple messages: conversation context preservation +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn channel_multiple_sends_preserve_order_and_recipients() { + let channel = CapturingChannel::new(); + + channel + .send(&SendMessage::new("msg 1", "target_a")) + .await + .unwrap(); + channel + .send(&SendMessage::new("msg 2", "target_b")) + .await + .unwrap(); + channel + .send(&SendMessage::new("msg 3", "target_a")) + .await + .unwrap(); + + let sent = channel.sent_messages(); + assert_eq!(sent.len(), 3); + assert_eq!(sent[0].recipient, "target_a"); + assert_eq!(sent[1].recipient, "target_b"); + assert_eq!(sent[2].recipient, "target_a"); + assert_eq!(sent[0].content, "msg 1"); + assert_eq!(sent[1].content, "msg 2"); + assert_eq!(sent[2].content, "msg 3"); +} diff --git a/tests/config_persistence.rs b/tests/config_persistence.rs new file mode 100644 index 0000000..079b9df --- /dev/null +++ b/tests/config_persistence.rs @@ -0,0 +1,248 @@ +//! TG2: Config Load/Save Round-Trip Tests +//! +//! Prevents: Pattern 2 — Config persistence & workspace discovery bugs (13% of user bugs). +//! 
Issues: #547, #417, #621, #802 +//! +//! Tests Config::load_or_init() with isolated temp directories, env var overrides, +//! and config file round-trips to verify workspace discovery and persistence. + +use std::fs; +use zeroclaw::config::{AgentConfig, Config, MemoryConfig}; + +// ───────────────────────────────────────────────────────────────────────────── +// Config default construction +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn config_default_has_expected_provider() { + let config = Config::default(); + assert!( + config.default_provider.is_some(), + "default config should have a default_provider" + ); +} + +#[test] +fn config_default_has_expected_model() { + let config = Config::default(); + assert!( + config.default_model.is_some(), + "default config should have a default_model" + ); +} + +#[test] +fn config_default_temperature_positive() { + let config = Config::default(); + assert!( + config.default_temperature > 0.0, + "default temperature should be positive" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// AgentConfig defaults +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn agent_config_default_max_tool_iterations() { + let agent = AgentConfig::default(); + assert_eq!( + agent.max_tool_iterations, 10, + "default max_tool_iterations should be 10" + ); +} + +#[test] +fn agent_config_default_max_history_messages() { + let agent = AgentConfig::default(); + assert_eq!( + agent.max_history_messages, 50, + "default max_history_messages should be 50" + ); +} + +#[test] +fn agent_config_default_tool_dispatcher() { + let agent = AgentConfig::default(); + assert_eq!( + agent.tool_dispatcher, "auto", + "default tool_dispatcher should be 'auto'" + ); +} + +#[test] +fn agent_config_default_compact_context_off() { + let agent = AgentConfig::default(); + assert!( + !agent.compact_context, + "compact_context should default 
to false" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// MemoryConfig defaults +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn memory_config_default_backend() { + let memory = MemoryConfig::default(); + assert!( + !memory.backend.is_empty(), + "memory backend should have a default value" + ); +} + +#[test] +fn memory_config_default_embedding_provider() { + let memory = MemoryConfig::default(); + // Default embedding_provider should be set (even if "none") + assert!( + !memory.embedding_provider.is_empty(), + "embedding_provider should have a default value" + ); +} + +#[test] +fn memory_config_default_vector_keyword_weights_sum_to_one() { + let memory = MemoryConfig::default(); + let sum = memory.vector_weight + memory.keyword_weight; + assert!( + (sum - 1.0).abs() < 0.01, + "vector_weight + keyword_weight should sum to ~1.0, got {sum}" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Config TOML serialization round-trip +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn config_toml_roundtrip_preserves_provider() { + let mut config = Config::default(); + config.default_provider = Some("deepseek".into()); + config.default_model = Some("deepseek-chat".into()); + config.default_temperature = 0.5; + + let toml_str = toml::to_string(&config).expect("config should serialize to TOML"); + let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back"); + + assert_eq!(parsed.default_provider.as_deref(), Some("deepseek")); + assert_eq!(parsed.default_model.as_deref(), Some("deepseek-chat")); + assert!((parsed.default_temperature - 0.5).abs() < f64::EPSILON); +} + +#[test] +fn config_toml_roundtrip_preserves_agent_config() { + let mut config = Config::default(); + config.agent.max_tool_iterations = 5; + config.agent.max_history_messages = 25; + 
config.agent.compact_context = true; + + let toml_str = toml::to_string(&config).expect("config should serialize to TOML"); + let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back"); + + assert_eq!(parsed.agent.max_tool_iterations, 5); + assert_eq!(parsed.agent.max_history_messages, 25); + assert!(parsed.agent.compact_context); +} + +#[test] +fn config_toml_roundtrip_preserves_memory_config() { + let mut config = Config::default(); + config.memory.embedding_provider = "openai".into(); + config.memory.embedding_model = "text-embedding-3-small".into(); + config.memory.vector_weight = 0.8; + config.memory.keyword_weight = 0.2; + + let toml_str = toml::to_string(&config).expect("config should serialize to TOML"); + let parsed: Config = toml::from_str(&toml_str).expect("TOML should deserialize back"); + + assert_eq!(parsed.memory.embedding_provider, "openai"); + assert_eq!(parsed.memory.embedding_model, "text-embedding-3-small"); + assert!((parsed.memory.vector_weight - 0.8).abs() < f64::EPSILON); + assert!((parsed.memory.keyword_weight - 0.2).abs() < f64::EPSILON); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Config file write/read round-trip with tempdir +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn config_file_write_read_roundtrip() { + let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed"); + let config_path = tmp.path().join("config.toml"); + + let mut config = Config::default(); + config.default_provider = Some("mistral".into()); + config.default_model = Some("mistral-large".into()); + config.agent.max_tool_iterations = 15; + + let toml_str = toml::to_string(&config).expect("config should serialize"); + fs::write(&config_path, &toml_str).expect("config file write should succeed"); + + let read_back = fs::read_to_string(&config_path).expect("config file read should succeed"); + let parsed: Config = 
toml::from_str(&read_back).expect("TOML should parse back"); + + assert_eq!(parsed.default_provider.as_deref(), Some("mistral")); + assert_eq!(parsed.default_model.as_deref(), Some("mistral-large")); + assert_eq!(parsed.agent.max_tool_iterations, 15); +} + +#[test] +fn config_file_with_missing_optional_fields_uses_defaults() { + // Simulate a minimal config TOML that omits optional sections + let minimal_toml = r#" +default_temperature = 0.7 +"#; + let parsed: Config = toml::from_str(minimal_toml).expect("minimal TOML should parse"); + + // Agent config should use defaults + assert_eq!(parsed.agent.max_tool_iterations, 10); + assert_eq!(parsed.agent.max_history_messages, 50); + assert!(!parsed.agent.compact_context); +} + +#[test] +fn config_file_with_custom_agent_section() { + let toml_with_agent = r#" +default_temperature = 0.7 + +[agent] +max_tool_iterations = 3 +compact_context = true +"#; + let parsed: Config = + toml::from_str(toml_with_agent).expect("TOML with agent section should parse"); + + assert_eq!(parsed.agent.max_tool_iterations, 3); + assert!(parsed.agent.compact_context); + // max_history_messages should still use default + assert_eq!(parsed.agent.max_history_messages, 50); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Workspace directory creation +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn workspace_dir_creation_in_tempdir() { + let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed"); + let workspace_dir = tmp.path().join("workspace"); + + fs::create_dir_all(&workspace_dir).expect("workspace dir creation should succeed"); + assert!(workspace_dir.exists(), "workspace dir should exist"); + assert!( + workspace_dir.is_dir(), + "workspace path should be a directory" + ); +} + +#[test] +fn nested_workspace_dir_creation() { + let tmp = tempfile::TempDir::new().expect("tempdir creation should succeed"); + let nested_dir = 
tmp.path().join("deep").join("nested").join("workspace"); + + fs::create_dir_all(&nested_dir).expect("nested dir creation should succeed"); + assert!(nested_dir.exists(), "nested workspace dir should exist"); +} diff --git a/tests/dockerignore_test.rs b/tests/dockerignore_test.rs index f321753..8af6fa8 100644 --- a/tests/dockerignore_test.rs +++ b/tests/dockerignore_test.rs @@ -6,7 +6,6 @@ //! 3. All build-essential paths are NOT excluded //! 4. Pattern syntax is valid -use std::fs; use std::path::Path; /// Paths that MUST be excluded from Docker build context (security/performance) @@ -96,8 +95,8 @@ fn is_excluded(patterns: &[String], path: &str) -> bool { excluded } -#[test] -fn dockerignore_file_exists() { +#[tokio::test] +async fn dockerignore_file_exists() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); assert!( path.exists(), @@ -105,10 +104,12 @@ fn dockerignore_file_exists() { ); } -#[test] -fn dockerignore_excludes_security_critical_paths() { +#[tokio::test] +async fn dockerignore_excludes_security_critical_paths() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); for must_exclude in MUST_EXCLUDE { @@ -129,10 +130,12 @@ fn dockerignore_excludes_security_critical_paths() { } } -#[test] -fn dockerignore_does_not_exclude_build_essentials() { +#[tokio::test] +async fn dockerignore_does_not_exclude_build_essentials() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); for must_include in MUST_INCLUDE { @@ -144,10 +147,12 @@ fn 
dockerignore_does_not_exclude_build_essentials() { } } -#[test] -fn dockerignore_excludes_git_directory() { +#[tokio::test] +async fn dockerignore_excludes_git_directory() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); // .git directory and its contents must be excluded @@ -162,10 +167,12 @@ fn dockerignore_excludes_git_directory() { ); } -#[test] -fn dockerignore_excludes_target_directory() { +#[tokio::test] +async fn dockerignore_excludes_target_directory() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); assert!(is_excluded(&patterns, "target"), "target must be excluded"); @@ -179,10 +186,12 @@ fn dockerignore_excludes_target_directory() { ); } -#[test] -fn dockerignore_excludes_database_files() { +#[tokio::test] +async fn dockerignore_excludes_database_files() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); assert!( @@ -199,10 +208,12 @@ fn dockerignore_excludes_database_files() { ); } -#[test] -fn dockerignore_excludes_markdown_files() { +#[tokio::test] +async fn dockerignore_excludes_markdown_files() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + 
.expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); assert!( @@ -219,10 +230,12 @@ fn dockerignore_excludes_markdown_files() { ); } -#[test] -fn dockerignore_excludes_image_files() { +#[tokio::test] +async fn dockerignore_excludes_image_files() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); assert!( @@ -235,10 +248,12 @@ fn dockerignore_excludes_image_files() { ); } -#[test] -fn dockerignore_excludes_env_files() { +#[tokio::test] +async fn dockerignore_excludes_env_files() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); assert!( @@ -247,10 +262,12 @@ fn dockerignore_excludes_env_files() { ); } -#[test] -fn dockerignore_excludes_ci_configs() { +#[tokio::test] +async fn dockerignore_excludes_ci_configs() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); let patterns = parse_dockerignore(&content); assert!( @@ -263,10 +280,12 @@ fn dockerignore_excludes_ci_configs() { ); } -#[test] -fn dockerignore_has_valid_syntax() { +#[tokio::test] +async fn dockerignore_has_valid_syntax() { let path = Path::new(env!("CARGO_MANIFEST_DIR")).join(".dockerignore"); - let content = fs::read_to_string(&path).expect("Failed to read .dockerignore"); + let content = tokio::fs::read_to_string(&path) + .await + .expect("Failed to read .dockerignore"); for 
(line_num, line) in content.lines().enumerate() { let trimmed = line.trim(); @@ -294,8 +313,8 @@ fn dockerignore_has_valid_syntax() { } } -#[test] -fn dockerignore_pattern_matching_edge_cases() { +#[tokio::test] +async fn dockerignore_pattern_matching_edge_cases() { // Test the pattern matching logic itself let patterns = vec![ ".git".to_string(), diff --git a/tests/memory_restart.rs b/tests/memory_restart.rs new file mode 100644 index 0000000..fe63f16 --- /dev/null +++ b/tests/memory_restart.rs @@ -0,0 +1,367 @@ +//! TG5: Memory Restart Resilience Tests +//! +//! Prevents: Pattern 5 — Memory & state persistence bugs (10% of user bugs). +//! Issues: #430, #693, #802 +//! +//! Tests SqliteMemory deduplication on restart, session scoping, concurrent +//! message ordering, and recall behavior after re-initialization. + +use std::sync::Arc; +use zeroclaw::memory::sqlite::SqliteMemory; +use zeroclaw::memory::traits::{Memory, MemoryCategory}; + +// ───────────────────────────────────────────────────────────────────────────── +// Deduplication: same key overwrites instead of duplicating (#430) +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_store_same_key_deduplicates() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + // Store same key twice with different content + mem.store("greeting", "hello world", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("greeting", "hello updated", MemoryCategory::Core, None) + .await + .unwrap(); + + // Should have exactly 1 entry, not 2 + let count = mem.count().await.unwrap(); + assert_eq!( + count, 1, + "storing same key twice should not create duplicates" + ); + + // Content should be the latest version + let entry = mem + .get("greeting") + .await + .unwrap() + .expect("entry should exist"); + assert_eq!(entry.content, "hello updated"); +} + +#[tokio::test] +async fn 
sqlite_memory_store_different_keys_creates_separate_entries() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("key_a", "content a", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("key_b", "content b", MemoryCategory::Core, None) + .await + .unwrap(); + + let count = mem.count().await.unwrap(); + assert_eq!(count, 2, "different keys should create separate entries"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Restart resilience: data persists across memory re-initialization +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_persists_across_reinitialization() { + let tmp = tempfile::TempDir::new().unwrap(); + + // First "session": store data + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store( + "persistent_fact", + "Rust is great", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + } + + // Second "session": re-create memory from same path + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + let entry = mem + .get("persistent_fact") + .await + .unwrap() + .expect("entry should survive reinitialization"); + assert_eq!(entry.content, "Rust is great"); + } +} + +#[tokio::test] +async fn sqlite_memory_restart_does_not_duplicate_on_rewrite() { + let tmp = tempfile::TempDir::new().unwrap(); + + // First session: store entries + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("fact_1", "original content", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("fact_2", "another fact", MemoryCategory::Core, None) + .await + .unwrap(); + } + + // Second session: re-store same keys (simulates channel re-reading history) + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("fact_1", "original content", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("fact_2", "another fact", MemoryCategory::Core, 
None) + .await + .unwrap(); + + let count = mem.count().await.unwrap(); + assert_eq!( + count, 2, + "re-storing same keys after restart should not create duplicates" + ); + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Session scoping: messages scoped to sessions don't leak +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_session_scoped_store_and_recall() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + // Store in different sessions + mem.store( + "session_a_fact", + "fact from session A", + MemoryCategory::Conversation, + Some("session_a"), + ) + .await + .unwrap(); + mem.store( + "session_b_fact", + "fact from session B", + MemoryCategory::Conversation, + Some("session_b"), + ) + .await + .unwrap(); + + // List scoped to session_a + let session_a_entries = mem + .list(Some(&MemoryCategory::Conversation), Some("session_a")) + .await + .unwrap(); + assert_eq!( + session_a_entries.len(), + 1, + "session_a should have exactly 1 entry" + ); + assert_eq!(session_a_entries[0].content, "fact from session A"); +} + +#[tokio::test] +async fn sqlite_memory_global_recall_includes_all_sessions() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store( + "global_a", + "alpha content", + MemoryCategory::Core, + Some("s1"), + ) + .await + .unwrap(); + mem.store("global_b", "beta content", MemoryCategory::Core, Some("s2")) + .await + .unwrap(); + + // Global count should include all + let count = mem.count().await.unwrap(); + assert_eq!( + count, 2, + "global count should include entries from all sessions" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Recall and search behavior +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn 
sqlite_memory_recall_returns_relevant_results() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store( + "lang_pref", + "User prefers Rust programming", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store( + "food_pref", + "User likes sushi for lunch", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + + let results = mem.recall("Rust programming", 10, None).await.unwrap(); + assert!(!results.is_empty(), "recall should find matching entries"); + // The Rust-related entry should be in results + assert!( + results.iter().any(|e| e.content.contains("Rust")), + "recall for 'Rust' should include the Rust-related entry" + ); +} + +#[tokio::test] +async fn sqlite_memory_recall_respects_limit() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + for i in 0..10 { + mem.store( + &format!("entry_{i}"), + &format!("test content number {i}"), + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + } + + let results = mem.recall("test content", 3, None).await.unwrap(); + assert!( + results.len() <= 3, + "recall should respect limit of 3, got {}", + results.len() + ); +} + +#[tokio::test] +async fn sqlite_memory_recall_empty_query_returns_empty() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("fact", "some content", MemoryCategory::Core, None) + .await + .unwrap(); + + let results = mem.recall("", 10, None).await.unwrap(); + assert!(results.is_empty(), "empty query should return no results"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Forget and health check +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_forget_removes_entry() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + 
mem.store("to_forget", "temporary info", MemoryCategory::Core, None) + .await + .unwrap(); + assert_eq!(mem.count().await.unwrap(), 1); + + let removed = mem.forget("to_forget").await.unwrap(); + assert!(removed, "forget should return true for existing key"); + assert_eq!(mem.count().await.unwrap(), 0); +} + +#[tokio::test] +async fn sqlite_memory_forget_nonexistent_returns_false() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + let removed = mem.forget("nonexistent_key").await.unwrap(); + assert!(!removed, "forget should return false for nonexistent key"); +} + +#[tokio::test] +async fn sqlite_memory_health_check_returns_true() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + assert!(mem.health_check().await, "health_check should return true"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Concurrent access +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_concurrent_stores_no_data_loss() { + let tmp = tempfile::TempDir::new().unwrap(); + let mem = Arc::new(SqliteMemory::new(tmp.path()).unwrap()); + + let mut handles = Vec::new(); + for i in 0..5 { + let mem_clone = mem.clone(); + handles.push(tokio::spawn(async move { + mem_clone + .store( + &format!("concurrent_{i}"), + &format!("content from task {i}"), + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + })); + } + + for handle in handles { + handle.await.unwrap(); + } + + let count = mem.count().await.unwrap(); + assert_eq!( + count, 5, + "all concurrent stores should succeed, got {count}" + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Memory categories +// ───────────────────────────────────────────────────────────────────────────── + +#[tokio::test] +async fn sqlite_memory_list_by_category() { + let tmp = 
tempfile::TempDir::new().unwrap(); + let mem = SqliteMemory::new(tmp.path()).unwrap(); + + mem.store("core_fact", "core info", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("daily_note", "daily note", MemoryCategory::Daily, None) + .await + .unwrap(); + mem.store( + "conv_msg", + "conversation msg", + MemoryCategory::Conversation, + None, + ) + .await + .unwrap(); + + let core_entries = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); + assert_eq!(core_entries.len(), 1, "should have 1 Core entry"); + assert_eq!(core_entries[0].key, "core_fact"); + + let daily_entries = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + assert_eq!(daily_entries.len(), 1, "should have 1 Daily entry"); +} diff --git a/tests/otel_dependency_feature_regression.rs b/tests/otel_dependency_feature_regression.rs new file mode 100644 index 0000000..0620b75 --- /dev/null +++ b/tests/otel_dependency_feature_regression.rs @@ -0,0 +1,17 @@ +#[test] +fn opentelemetry_otlp_uses_blocking_reqwest_client() { + let manifest = include_str!("../Cargo.toml"); + let otlp_line = manifest + .lines() + .find(|line| line.trim_start().starts_with("opentelemetry-otlp =")) + .expect("Cargo.toml must define opentelemetry-otlp dependency"); + + assert!( + otlp_line.contains("\"reqwest-blocking-client\""), + "opentelemetry-otlp must include reqwest-blocking-client to avoid Tokio reactor panics" + ); + assert!( + !otlp_line.contains("\"reqwest-client\""), + "opentelemetry-otlp must not include async reqwest-client in this runtime mode" + ); +} diff --git a/tests/provider_resolution.rs b/tests/provider_resolution.rs new file mode 100644 index 0000000..c88fa93 --- /dev/null +++ b/tests/provider_resolution.rs @@ -0,0 +1,244 @@ +//! TG1: Provider End-to-End Resolution Tests +//! +//! Prevents: Pattern 1 — Provider configuration & resolution bugs (27% of user bugs). +//! Issues: #831, #834, #721, #580, #452, #451, #796, #843 +//! +//! 
Tests the full pipeline from config values through `create_provider_with_url()` +//! to provider construction, verifying factory resolution, URL construction, +//! credential wiring, and auth header format. + +use zeroclaw::providers::compatible::{AuthStyle, OpenAiCompatibleProvider}; +use zeroclaw::providers::{create_provider, create_provider_with_url}; + +/// Helper: assert provider creation succeeds +fn assert_provider_ok(name: &str, key: Option<&str>, url: Option<&str>) { + let result = create_provider_with_url(name, key, url); + assert!( + result.is_ok(), + "{name} provider should resolve: {}", + result.err().map(|e| e.to_string()).unwrap_or_default() + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Factory resolution: each major provider name resolves without error +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_resolves_openai_provider() { + assert_provider_ok("openai", Some("test-key"), None); +} + +#[test] +fn factory_resolves_anthropic_provider() { + assert_provider_ok("anthropic", Some("test-key"), None); +} + +#[test] +fn factory_resolves_deepseek_provider() { + assert_provider_ok("deepseek", Some("test-key"), None); +} + +#[test] +fn factory_resolves_mistral_provider() { + assert_provider_ok("mistral", Some("test-key"), None); +} + +#[test] +fn factory_resolves_ollama_provider() { + assert_provider_ok("ollama", None, None); +} + +#[test] +fn factory_resolves_groq_provider() { + assert_provider_ok("groq", Some("test-key"), None); +} + +#[test] +fn factory_resolves_xai_provider() { + assert_provider_ok("xai", Some("test-key"), None); +} + +#[test] +fn factory_resolves_together_provider() { + assert_provider_ok("together", Some("test-key"), None); +} + +#[test] +fn factory_resolves_fireworks_provider() { + assert_provider_ok("fireworks", Some("test-key"), None); +} + +#[test] +fn factory_resolves_perplexity_provider() { + 
assert_provider_ok("perplexity", Some("test-key"), None); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Factory resolution: alias variants map to same provider +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_grok_alias_resolves_to_xai() { + assert_provider_ok("grok", Some("test-key"), None); +} + +#[test] +fn factory_kimi_alias_resolves_to_moonshot() { + assert_provider_ok("kimi", Some("test-key"), None); +} + +#[test] +fn factory_zhipu_alias_resolves_to_glm() { + assert_provider_ok("zhipu", Some("test-key"), None); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Custom URL provider creation +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_custom_http_url_resolves() { + assert_provider_ok("custom:http://localhost:8080", Some("test-key"), None); +} + +#[test] +fn factory_custom_https_url_resolves() { + assert_provider_ok("custom:https://api.example.com/v1", Some("test-key"), None); +} + +#[test] +fn factory_custom_ftp_url_rejected() { + let result = create_provider_with_url("custom:ftp://example.com", None, None); + assert!(result.is_err(), "ftp scheme should be rejected"); + let err_msg = result.err().unwrap().to_string(); + assert!( + err_msg.contains("http://") || err_msg.contains("https://"), + "error should mention valid schemes: {err_msg}" + ); +} + +#[test] +fn factory_custom_empty_url_rejected() { + let result = create_provider_with_url("custom:", None, None); + assert!(result.is_err(), "empty custom URL should be rejected"); +} + +#[test] +fn factory_unknown_provider_rejected() { + let result = create_provider_with_url("nonexistent_provider_xyz", None, None); + assert!(result.is_err(), "unknown provider name should be rejected"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// OpenAiCompatibleProvider: credential and 
auth style wiring +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn compatible_provider_bearer_auth_style() { + // Construction with Bearer auth should succeed + let _provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com", + Some("sk-test-key-12345"), + AuthStyle::Bearer, + ); +} + +#[test] +fn compatible_provider_xapikey_auth_style() { + // Construction with XApiKey auth should succeed + let _provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com", + Some("sk-test-key-12345"), + AuthStyle::XApiKey, + ); +} + +#[test] +fn compatible_provider_custom_auth_header() { + // Construction with Custom auth should succeed + let _provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com", + Some("sk-test-key-12345"), + AuthStyle::Custom("X-Custom-Auth".into()), + ); +} + +#[test] +fn compatible_provider_no_credential() { + // Construction without credential should succeed (for local providers) + let _provider = OpenAiCompatibleProvider::new( + "TestLocal", + "http://localhost:11434", + None, + AuthStyle::Bearer, + ); +} + +#[test] +fn compatible_provider_base_url_trailing_slash_normalized() { + // Construction with trailing slash URL should succeed + let _provider = OpenAiCompatibleProvider::new( + "TestProvider", + "https://api.test.com/v1/", + Some("key"), + AuthStyle::Bearer, + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Provider with api_url override (simulates #721 - Ollama api_url config) +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn factory_ollama_with_custom_api_url() { + assert_provider_ok("ollama", None, Some("http://192.168.1.100:11434")); +} + +#[test] +fn factory_openai_with_custom_api_url() { + assert_provider_ok( + "openai", + Some("test-key"), + Some("https://custom-openai-proxy.example.com/v1"), + ); +} + +// 
───────────────────────────────────────────────────────────────────────────── +// Provider default convenience factory +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn convenience_factory_resolves_major_providers() { + for provider_name in &[ + "openai", + "anthropic", + "deepseek", + "mistral", + "groq", + "xai", + "together", + "fireworks", + "perplexity", + ] { + let result = create_provider(provider_name, Some("test-key")); + assert!( + result.is_ok(), + "convenience factory should resolve {provider_name}: {}", + result.err().map(|e| e.to_string()).unwrap_or_default() + ); + } +} + +#[test] +fn convenience_factory_ollama_no_key() { + let result = create_provider("ollama", None); + assert!( + result.is_ok(), + "ollama should not require api key: {}", + result.err().map(|e| e.to_string()).unwrap_or_default() + ); +} diff --git a/tests/provider_schema.rs b/tests/provider_schema.rs new file mode 100644 index 0000000..84e2c84 --- /dev/null +++ b/tests/provider_schema.rs @@ -0,0 +1,319 @@ +//! TG7: Provider Schema Conformance Tests +//! +//! Prevents: Pattern 7 — External schema compatibility bugs (7% of user bugs). +//! Issues: #769, #843 +//! +//! Tests request/response serialization to verify required fields are present +//! for each provider's API specification. Validates ChatMessage, ChatResponse, +//! ToolCall, and AuthStyle serialization contracts. 
+ +use zeroclaw::providers::compatible::AuthStyle; +use zeroclaw::providers::traits::{ChatMessage, ChatResponse, ToolCall}; + +// ───────────────────────────────────────────────────────────────────────────── +// ChatMessage serialization +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn chat_message_system_role_correct() { + let msg = ChatMessage::system("You are a helpful assistant"); + assert_eq!(msg.role, "system"); + assert_eq!(msg.content, "You are a helpful assistant"); +} + +#[test] +fn chat_message_user_role_correct() { + let msg = ChatMessage::user("Hello"); + assert_eq!(msg.role, "user"); + assert_eq!(msg.content, "Hello"); +} + +#[test] +fn chat_message_assistant_role_correct() { + let msg = ChatMessage::assistant("Hi there!"); + assert_eq!(msg.role, "assistant"); + assert_eq!(msg.content, "Hi there!"); +} + +#[test] +fn chat_message_tool_role_correct() { + let msg = ChatMessage::tool("tool result"); + assert_eq!(msg.role, "tool"); + assert_eq!(msg.content, "tool result"); +} + +#[test] +fn chat_message_serializes_to_json_with_required_fields() { + let msg = ChatMessage::user("test message"); + let json = serde_json::to_value(&msg).unwrap(); + + assert!(json.get("role").is_some(), "JSON must have 'role' field"); + assert!( + json.get("content").is_some(), + "JSON must have 'content' field" + ); + assert_eq!(json["role"], "user"); + assert_eq!(json["content"], "test message"); +} + +#[test] +fn chat_message_json_roundtrip() { + let original = ChatMessage::assistant("response text"); + let json_str = serde_json::to_string(&original).unwrap(); + let parsed: ChatMessage = serde_json::from_str(&json_str).unwrap(); + + assert_eq!(parsed.role, original.role); + assert_eq!(parsed.content, original.content); +} + +// ───────────────────────────────────────────────────────────────────────────── +// ToolCall serialization (#843 - tool_call_id field) +// 
───────────────────────────────────────────────────────────────────────────── + +#[test] +fn tool_call_has_required_fields() { + let tc = ToolCall { + id: "call_abc123".into(), + name: "web_search".into(), + arguments: r#"{"query": "rust programming"}"#.into(), + }; + + let json = serde_json::to_value(&tc).unwrap(); + assert!(json.get("id").is_some(), "ToolCall must have 'id' field"); + assert!( + json.get("name").is_some(), + "ToolCall must have 'name' field" + ); + assert!( + json.get("arguments").is_some(), + "ToolCall must have 'arguments' field" + ); +} + +#[test] +fn tool_call_id_preserved_in_serialization() { + let tc = ToolCall { + id: "call_deepseek_42".into(), + name: "shell".into(), + arguments: r#"{"command": "ls"}"#.into(), + }; + + let json_str = serde_json::to_string(&tc).unwrap(); + let parsed: ToolCall = serde_json::from_str(&json_str).unwrap(); + + assert_eq!( + parsed.id, "call_deepseek_42", + "tool_call_id must survive roundtrip" + ); + assert_eq!(parsed.name, "shell"); +} + +#[test] +fn tool_call_arguments_contain_valid_json() { + let tc = ToolCall { + id: "call_1".into(), + name: "file_write".into(), + arguments: r#"{"path": "/tmp/test.txt", "content": "hello"}"#.into(), + }; + + // Arguments should parse as valid JSON + let args: serde_json::Value = + serde_json::from_str(&tc.arguments).expect("tool call arguments should be valid JSON"); + assert!(args.get("path").is_some()); + assert!(args.get("content").is_some()); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Tool message with tool_call_id (DeepSeek requirement) +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn tool_response_message_can_embed_tool_call_id() { + // DeepSeek requires tool_call_id in tool response messages. + // The tool message content can embed the tool_call_id as JSON. 
+ let tool_response = + ChatMessage::tool(r#"{"tool_call_id": "call_abc123", "content": "search results here"}"#); + + let parsed: serde_json::Value = serde_json::from_str(&tool_response.content) + .expect("tool response content should be valid JSON"); + + assert!( + parsed.get("tool_call_id").is_some(), + "tool response should include tool_call_id for DeepSeek compatibility" + ); + assert_eq!(parsed["tool_call_id"], "call_abc123"); +} + +// ───────────────────────────────────────────────────────────────────────────── +// ChatResponse structure +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn chat_response_text_only() { + let resp = ChatResponse { + text: Some("Hello world".into()), + tool_calls: vec![], + }; + + assert_eq!(resp.text_or_empty(), "Hello world"); + assert!(!resp.has_tool_calls()); +} + +#[test] +fn chat_response_with_tool_calls() { + let resp = ChatResponse { + text: Some(String::new()), + tool_calls: vec![ToolCall { + id: "tc_1".into(), + name: "echo".into(), + arguments: "{}".into(), + }], + }; + + assert!(resp.has_tool_calls()); + assert_eq!(resp.tool_calls.len(), 1); + assert_eq!(resp.tool_calls[0].name, "echo"); +} + +#[test] +fn chat_response_text_or_empty_handles_none() { + let resp = ChatResponse { + text: None, + tool_calls: vec![], + }; + + assert_eq!(resp.text_or_empty(), ""); +} + +#[test] +fn chat_response_multiple_tool_calls() { + let resp = ChatResponse { + text: None, + tool_calls: vec![ + ToolCall { + id: "tc_1".into(), + name: "shell".into(), + arguments: r#"{"command": "ls"}"#.into(), + }, + ToolCall { + id: "tc_2".into(), + name: "file_read".into(), + arguments: r#"{"path": "test.txt"}"#.into(), + }, + ], + }; + + assert!(resp.has_tool_calls()); + assert_eq!(resp.tool_calls.len(), 2); + // Each tool call should have a distinct id + assert_ne!(resp.tool_calls[0].id, resp.tool_calls[1].id); +} + +// ───────────────────────────────────────────────────────────────────────────── +// 
AuthStyle variants +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn auth_style_bearer_is_constructible() { + let style = AuthStyle::Bearer; + assert!(matches!(style, AuthStyle::Bearer)); +} + +#[test] +fn auth_style_xapikey_is_constructible() { + let style = AuthStyle::XApiKey; + assert!(matches!(style, AuthStyle::XApiKey)); +} + +#[test] +fn auth_style_custom_header() { + let style = AuthStyle::Custom("X-Custom-Auth".into()); + if let AuthStyle::Custom(header) = style { + assert_eq!(header, "X-Custom-Auth"); + } else { + panic!("expected AuthStyle::Custom"); + } +} + +// ───────────────────────────────────────────────────────────────────────────── +// Provider naming consistency +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn provider_construction_with_different_names() { + use zeroclaw::providers::compatible::OpenAiCompatibleProvider; + + // Construction with various names should succeed + let _p1 = OpenAiCompatibleProvider::new( + "DeepSeek", + "https://api.deepseek.com", + Some("test-key"), + AuthStyle::Bearer, + ); + let _p2 = + OpenAiCompatibleProvider::new("deepseek", "https://api.test.com", None, AuthStyle::Bearer); +} + +#[test] +fn provider_construction_with_different_auth_styles() { + use zeroclaw::providers::compatible::OpenAiCompatibleProvider; + + let _bearer = OpenAiCompatibleProvider::new( + "Test", + "https://api.test.com", + Some("key"), + AuthStyle::Bearer, + ); + let _xapi = OpenAiCompatibleProvider::new( + "Test", + "https://api.test.com", + Some("key"), + AuthStyle::XApiKey, + ); + let _custom = OpenAiCompatibleProvider::new( + "Test", + "https://api.test.com", + Some("key"), + AuthStyle::Custom("X-My-Auth".into()), + ); +} + +// ───────────────────────────────────────────────────────────────────────────── +// Conversation history message ordering +// ───────────────────────────────────────────────────────────────────────────── + +#[test] +fn 
chat_messages_maintain_role_sequence() { + let history = vec![ + ChatMessage::system("You are helpful"), + ChatMessage::user("What is Rust?"), + ChatMessage::assistant("Rust is a systems programming language"), + ChatMessage::user("Tell me more"), + ChatMessage::assistant("It emphasizes safety and performance"), + ]; + + assert_eq!(history[0].role, "system"); + assert_eq!(history[1].role, "user"); + assert_eq!(history[2].role, "assistant"); + assert_eq!(history[3].role, "user"); + assert_eq!(history[4].role, "assistant"); +} + +#[test] +fn chat_messages_with_tool_calls_maintain_sequence() { + let history = vec![ + ChatMessage::system("You are helpful"), + ChatMessage::user("Search for Rust"), + ChatMessage::assistant("I'll search for that"), + ChatMessage::tool(r#"{"tool_call_id": "tc_1", "content": "search results"}"#), + ChatMessage::assistant("Based on the search results..."), + ]; + + assert_eq!(history.len(), 5); + assert_eq!(history[3].role, "tool"); + assert_eq!(history[4].role, "assistant"); + + // Verify tool message content is valid JSON with tool_call_id + let tool_content: serde_json::Value = serde_json::from_str(&history[3].content).unwrap(); + assert!(tool_content.get("tool_call_id").is_some()); +} diff --git a/zeroclaw_install.sh b/zeroclaw_install.sh new file mode 100755 index 0000000..4279e1a --- /dev/null +++ b/zeroclaw_install.sh @@ -0,0 +1,88 @@ +#!/usr/bin/env sh +set -eu + +have_cmd() { + command -v "$1" >/dev/null 2>&1 +} + +run_privileged() { + if [ "$(id -u)" -eq 0 ]; then + "$@" + elif have_cmd sudo; then + sudo "$@" + else + echo "error: sudo is required to install missing dependencies." >&2 + exit 1 + fi +} + +is_container_runtime() { + if [ -f /.dockerenv ] || [ -f /run/.containerenv ]; then + return 0 + fi + + if [ -r /proc/1/cgroup ] && grep -Eq '(docker|containerd|kubepods|podman|lxc)' /proc/1/cgroup; then + return 0 + fi + + return 1 +} + +run_pacman() { + if ! is_container_runtime; then + run_privileged pacman "$@" + return $? 
# Run pacman through run_privileged. Inside a container runtime, pacman's
# download sandbox can fail, so a throwaway copy of /etc/pacman.conf with
# DisableSandboxSyscalls appended is used instead of the real config.
# NOTE(review): option name is taken as-is from the project; confirm it
# matches the pacman.conf option supported by the targeted pacman version.
run_pacman() {
  if ! is_container_runtime; then
    run_privileged pacman "$@"
    return $?
  fi

  # Work on a temp copy; never mutate the system pacman.conf.
  PACMAN_CFG_TMP="$(mktemp /tmp/zeroclaw-pacman.XXXXXX.conf)"
  cp /etc/pacman.conf "$PACMAN_CFG_TMP"
  grep -Eq '^[[:space:]]*DisableSandboxSyscalls([[:space:]]|$)' "$PACMAN_CFG_TMP" \
    || printf '\nDisableSandboxSyscalls\n' >> "$PACMAN_CFG_TMP"

  # Capture the exit status so the temp config is removed on every path.
  if run_privileged pacman --config "$PACMAN_CFG_TMP" "$@"; then
    PACMAN_RC=0
  else
    PACMAN_RC=$?
  fi
  rm -f "$PACMAN_CFG_TMP"
  return "$PACMAN_RC"
}

# Make sure bash exists (scripts/bootstrap.sh requires it), installing it via
# whichever supported package manager is present.
ensure_bash() {
  if have_cmd bash; then
    return 0
  fi

  echo "==> bash not found; attempting to install it"
  if have_cmd apk; then
    run_privileged apk add --no-cache bash
  elif have_cmd apt-get; then
    run_privileged apt-get update -qq
    run_privileged apt-get install -y bash
  elif have_cmd dnf; then
    run_privileged dnf install -y bash
  elif have_cmd pacman; then
    run_pacman -Sy --noconfirm
    run_pacman -S --noconfirm --needed bash
  else
    echo "error: unsupported package manager; install bash manually and retry." >&2
    exit 1
  fi
}

# Resolve the repository root from this script's own location. CDPATH is
# cleared so `cd` cannot be redirected by the caller's environment; if the
# cd fails we fall back to the current directory and let the existence
# check below catch a wrong root.
ROOT_DIR="$(CDPATH= cd -- "$(dirname -- "$0")" >/dev/null 2>&1 && pwd || pwd)"
BOOTSTRAP_SCRIPT="$ROOT_DIR/scripts/bootstrap.sh"

if [ ! -f "$BOOTSTRAP_SCRIPT" ]; then
  echo "error: scripts/bootstrap.sh not found from repository root." >&2
  exit 1
fi

ensure_bash

# With no arguments, run the guided install; otherwise forward everything.
[ "$#" -eq 0 ] && set -- --guided
exec bash "$BOOTSTRAP_SCRIPT" "$@"