diff --git a/.env.example b/.env.example index 17686d3..7a2c253 100644 --- a/.env.example +++ b/.env.example @@ -1,25 +1,69 @@ # ZeroClaw Environment Variables -# Copy this file to .env and fill in your values. -# NEVER commit .env — it is listed in .gitignore. +# Copy this file to `.env` and fill in your local values. +# Never commit `.env` or any real secrets. -# ── Required ────────────────────────────────────────────────── -# Your LLM provider API key -# ZEROCLAW_API_KEY=sk-your-key-here +# ── Core Runtime ────────────────────────────────────────────── +# Provider key resolution at runtime: +# 1) explicit key passed from config/CLI +# 2) provider-specific env var (OPENROUTER_API_KEY, OPENAI_API_KEY, ...) +# 3) generic fallback env vars below + +# Generic fallback API key (used when provider-specific key is absent) API_KEY=your-api-key-here +# ZEROCLAW_API_KEY=your-api-key-here -# ── Provider & Model ───────────────────────────────────────── -# LLM provider: openrouter, openai, anthropic, ollama, glm +# Default provider/model (can be overridden by CLI flags) PROVIDER=openrouter +# ZEROCLAW_PROVIDER=openrouter # ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514 # ZEROCLAW_TEMPERATURE=0.7 +# Workspace directory override +# ZEROCLAW_WORKSPACE=/path/to/workspace + +# ── Provider-Specific API Keys ──────────────────────────────── +# OpenRouter +# OPENROUTER_API_KEY=sk-or-v1-... + +# Anthropic +# ANTHROPIC_OAUTH_TOKEN=... +# ANTHROPIC_API_KEY=sk-ant-... + +# OpenAI / Gemini +# OPENAI_API_KEY=sk-... +# GEMINI_API_KEY=... +# GOOGLE_API_KEY=... + +# Other supported providers +# VENICE_API_KEY=... +# GROQ_API_KEY=... +# MISTRAL_API_KEY=... +# DEEPSEEK_API_KEY=... +# XAI_API_KEY=... +# TOGETHER_API_KEY=... +# FIREWORKS_API_KEY=... +# PERPLEXITY_API_KEY=... +# COHERE_API_KEY=... +# MOONSHOT_API_KEY=... +# GLM_API_KEY=... +# MINIMAX_API_KEY=... +# QIANFAN_API_KEY=... +# DASHSCOPE_API_KEY=... +# ZAI_API_KEY=... +# SYNTHETIC_API_KEY=... +# OPENCODE_API_KEY=... +# VERCEL_API_KEY=... +# CLOUDFLARE_API_KEY=... + # ── Gateway ────────────────────────────────────────────────── # ZEROCLAW_GATEWAY_PORT=3000 # ZEROCLAW_GATEWAY_HOST=127.0.0.1 # ZEROCLAW_ALLOW_PUBLIC_BIND=false -# ── Workspace ──────────────────────────────────────────────── -# ZEROCLAW_WORKSPACE=/path/to/workspace +# ── Optional Integrations ──────────────────────────────────── +# Pushover notifications (`pushover` tool) +# PUSHOVER_TOKEN=your-pushover-app-token +# PUSHOVER_USER_KEY=your-pushover-user-key # ── Docker Compose ─────────────────────────────────────────── # Host port mapping (used by docker-compose.yml) diff --git a/.githooks/pre-commit b/.githooks/pre-commit new file mode 100755 index 0000000..d162ba3 --- /dev/null +++ b/.githooks/pre-commit @@ -0,0 +1,8 @@ +#!/usr/bin/env bash +set -euo pipefail + +if command -v gitleaks >/dev/null 2>&1; then + gitleaks protect --staged --redact +else + echo "warning: gitleaks not found; skipping staged secret scan" >&2 +fi diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md index 550bd95..7c9e601 100644 --- a/.github/pull_request_template.md +++ b/.github/pull_request_template.md @@ -12,7 +12,11 @@ Describe this PR in 2-5 bullets: - Risk label (`risk: low|medium|high`): - Size label (`size: XS|S|M|L|XL`, auto-managed/read-only): - Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated): +<<<<<<< chore/labeler-spacing-trusted-tier +- Module labels (`: `, for example `channel: telegram`, `provider: kimi`, `tool: shell`): +======= - Module labels (`:`, for example `channel:telegram`, `provider:kimi`, `tool:shell`): +>>>>>>> main - Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50): - If any auto-label is incorrect, note requested correction: diff --git a/.github/workflows/auto-response.yml b/.github/workflows/auto-response.yml index 4398085..753bb52 100644 --- a/.github/workflows/auto-response.yml +++ b/.github/workflows/auto-response.yml @@ -18,6 +18,7 @@ jobs: runs-on: blacksmith-2vcpu-ubuntu-2404 permissions: issues: write + pull-requests: write steps: - name: Apply contributor tier label for issue author uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml index 63ea2ad..67005c6 100644 --- a/.github/workflows/docker.yml +++ b/.github/workflows/docker.yml @@ -35,7 +35,7 @@ jobs: uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Setup Blacksmith Builder - uses: useblacksmith/setup-docker-builder@v1 + uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1 - name: Extract metadata (tags, labels) id: meta @@ -46,7 +46,7 @@ jobs: type=ref,event=pr - name: Build smoke image - uses: useblacksmith/build-push-action@v2 + uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2 with: context: . push: false @@ -71,7 +71,7 @@ jobs: uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 - name: Setup Blacksmith Builder - uses: useblacksmith/setup-docker-builder@v1 + uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1 - name: Log in to Container Registry uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3 @@ -102,7 +102,7 @@ jobs: echo "tags=${TAGS}" >> "$GITHUB_OUTPUT" - name: Build and push Docker image - uses: useblacksmith/build-push-action@v2 + uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2 with: context: . push: true diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index d629a1f..10d8bfb 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -325,13 +325,18 @@ jobs: return pattern.test(text); } + function formatModuleLabel(prefix, segment) { + return `${prefix}: ${segment}`; + } + function parseModuleLabel(label) { - const separatorIndex = label.indexOf(":"); - if (separatorIndex <= 0 || separatorIndex >= label.length - 1) return null; - return { - prefix: label.slice(0, separatorIndex), - segment: label.slice(separatorIndex + 1), - }; + if (typeof label !== "string") return null; + const match = label.match(/^([^:]+):\s*(.+)$/); + if (!match) return null; + const prefix = match[1].trim().toLowerCase(); + const segment = (match[2] || "").trim().toLowerCase(); + if (!prefix || !segment) return null; + return { prefix, segment }; } function sortByPriority(labels, priorityIndex) { @@ -389,7 +394,7 @@ jobs: for (const [prefix, segments] of segmentsByPrefix) { const hasSpecificSegment = [...segments].some((segment) => segment !== "core"); if (hasSpecificSegment) { - refined.delete(`${prefix}:core`); + refined.delete(formatModuleLabel(prefix, "core")); } } @@ -418,7 +423,7 @@ jobs: if (uniqueSegments.length === 0) continue; if (uniqueSegments.length === 1) { - compactedModuleLabels.add(`${prefix}:${uniqueSegments[0]}`); + compactedModuleLabels.add(formatModuleLabel(prefix, uniqueSegments[0])); } else { forcePathPrefixes.add(prefix); } @@ -609,7 +614,7 @@ jobs: segment = normalizeLabelSegment(segment); if (!segment) continue; - detectedModuleLabels.add(`${rule.prefix}:${segment}`); + detectedModuleLabels.add(formatModuleLabel(rule.prefix, segment)); } } @@ -635,7 +640,7 @@ jobs: for (const keyword of providerKeywordHints) { if (containsKeyword(searchableText, keyword)) { - detectedModuleLabels.add(`provider:${keyword}`); + detectedModuleLabels.add(formatModuleLabel("provider", keyword)); } } } @@ -661,7 +666,7 @@ jobs: for (const keyword of channelKeywordHints) { if (containsKeyword(searchableText, keyword)) { - detectedModuleLabels.add(`channel:${keyword}`); + detectedModuleLabels.add(formatModuleLabel("channel", keyword)); } } } diff --git a/.gitignore b/.gitignore index 49980c2..9440b79 100644 --- a/.gitignore +++ b/.gitignore @@ -4,6 +4,26 @@ firmware/*/target *.db-journal .DS_Store .wt-pr37/ -.env __pycache__/ *.pyc +docker-compose.override.yml + +# Environment files (may contain secrets) +.env + +# Python virtual environments + +.venv/ +venv/ + +# ESP32 build cache (esp-idf-sys managed) + +.embuild/ +.env.local +.env.*.local + +# Secret keys and credentials +.secret_key +*.key +*.pem +credentials.json diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index a25ad4e..d98a2ce 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -79,6 +79,94 @@ git push --no-verify > **Note:** CI runs the same checks, so skipped hooks will be caught on the PR. +## Local Secret Management (Required) + +ZeroClaw supports layered secret management for local development and CI hygiene. + +### Secret Storage Options + +1. **Environment variables** (recommended for local development) + - Copy `.env.example` to `.env` and fill in values + - `.env` files are Git-ignored and should stay local + - Best for temporary/local API keys + +2. **Config file** (`~/.zeroclaw/config.toml`) + - Persistent setup for long-term use + - When `secrets.encrypt = true` (default), secret values are encrypted before save + - Secret key is stored at `~/.zeroclaw/.secret_key` with restricted permissions + - Use `zeroclaw onboard` for guided setup + +### Runtime Resolution Rules + +API key resolution follows this order: + +1. Explicit key passed from config/CLI +2. Provider-specific env vars (`OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, ...) +3. Generic env vars (`ZEROCLAW_API_KEY`, `API_KEY`) + +Provider/model config overrides: + +- `ZEROCLAW_PROVIDER` / `PROVIDER` +- `ZEROCLAW_MODEL` + +See `.env.example` for practical examples and currently supported provider key env vars. + +### Pre-Commit Secret Hygiene (Mandatory) + +Before every commit, verify: + +- [ ] No `.env` files are staged (`.env.example` only) +- [ ] No raw API keys/tokens in code, tests, fixtures, examples, logs, or commit messages +- [ ] No credentials in debug output or error payloads +- [ ] `git diff --cached` has no accidental secret-like strings + +Quick local audit: + +```bash +# Search staged diff for common secret markers +git diff --cached | grep -iE '(api[_-]?key|secret|token|password|bearer|sk-)' + +# Confirm no .env file is staged +git status --short | grep -E '\.env$' +``` + +### Optional Local Secret Scanning + +For extra guardrails, install one of: + +- **gitleaks**: [GitHub - gitleaks/gitleaks](https://github.com/gitleaks/gitleaks) +- **truffleHog**: [GitHub - trufflesecurity/trufflehog](https://github.com/trufflesecurity/trufflehog) +- **git-secrets**: [GitHub - awslabs/git-secrets](https://github.com/awslabs/git-secrets) + +This repo includes `.githooks/pre-commit` to run `gitleaks protect --staged --redact` when gitleaks is installed. + +Enable hooks with: + +```bash +git config core.hooksPath .githooks +``` + +If gitleaks is not installed, the pre-commit hook prints a warning and continues. + +### What Must Never Be Committed + +- `.env` files (use `.env.example` only) +- API keys, tokens, passwords, or credentials (plain or encrypted) +- OAuth tokens or session identifiers +- Webhook signing secrets +- `~/.zeroclaw/.secret_key` or similar key files +- Personal identifiers or real user data in tests/fixtures + +### If a Secret Is Committed Accidentally + +1. Revoke/rotate the credential immediately +2. Do not rely only on `git revert` (history still contains the secret) +3. Purge history with `git filter-repo` or BFG +4. Force-push cleaned history (coordinate with maintainers) +5. Ensure the leaked value is removed from PR/issue/discussion/comment history + +Reference: [GitHub guide: removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository) + ## Collaboration Tracks (Risk-Based) To keep review throughput high without lowering quality, every PR should map to one track: diff --git a/Cargo.lock b/Cargo.lock index d940f9f..e19c5c9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -209,6 +209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" dependencies = [ "axum-core", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -227,8 +228,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite 0.28.0", "tower", "tower-layer", "tower-service", @@ -2057,6 +2060,15 @@ dependencies = [ "hashify", ] +[[package]] +name = "matchers" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" +dependencies = [ + "regex-automata", +] + [[package]] name = "matchit" version = "0.8.4" @@ -3747,10 +3759,22 @@ dependencies = [ "rustls-pki-types", "tokio", "tokio-rustls", - "tungstenite", + "tungstenite 0.24.0", "webpki-roots 0.26.11", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.28.0", +] + [[package]] name = "tokio-util" version = "0.7.18" @@ -3940,9 +3964,13 @@ version = "0.3.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" dependencies = [ + "matchers", "nu-ansi-term", + "once_cell", + "regex-automata", "sharded-slab", "thread_local", + "tracing", "tracing-core", ] @@ -3978,6 +4006,23 @@ dependencies = [ "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http 1.4.0", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror 2.0.18", + "utf-8", +] + [[package]] name = "twox-hash" version = "2.1.2" @@ -4880,7 +4925,9 @@ dependencies = [ "pdf-extract", "probe-rs", "prometheus", + "prost", "rand 0.8.5", + "regex", "reqwest", "rppal", "rusqlite", @@ -4896,7 +4943,7 @@ dependencies = [ "tokio-rustls", "tokio-serial", "tokio-test", - "tokio-tungstenite", + "tokio-tungstenite 0.24.0", "toml", "tower", "tower-http", diff --git a/Cargo.toml b/Cargo.toml index 79dcdfe..15d4665 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,3 +1,7 @@ +[workspace] +members = ["."] +resolver = "2" + [package] name = "zeroclaw" version = "0.1.0" @@ -31,7 +35,7 @@ shellexpand = "3.1" # Logging - minimal tracing = { version = "0.1", default-features = false } -tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] } +tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] } # Observability - Prometheus metrics prometheus = { version = "0.14", default-features = false } @@ -63,12 +67,12 @@ rand = "0.8" # Fast mutexes that don't poison on panic parking_lot = "0.12" -# Landlock (Linux sandbox) - optional dependency -landlock = { version = "0.4", optional = true } - # Async traits async-trait = "0.1" +# Protobuf encode/decode (Feishu WS long-connection frame codec) +prost = { version = "0.14", default-features = false } + # Memory / persistence rusqlite = { version = "0.38", features = ["bundled"] } chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] } @@ -86,6 +90,7 @@ glob = "0.3" tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] } futures-util = { version = "0.3", default-features = false, features = ["sink"] } futures = "0.3" +regex = "1.10" hostname = "0.4.2" lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] } mail-parser = "0.11.2" @@ -95,7 +100,7 @@ tokio-rustls = "0.26.4" webpki-roots = "1.0.6" # HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance -axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query"] } +axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] } tower = { version = "0.5", default-features = false } tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] } http-body-util = "0.1" @@ -117,19 +122,28 @@ probe-rs = { version = "0.30", optional = true } # PDF extraction for datasheet RAG (optional, enable with --features rag-pdf) pdf-extract = { version = "0.10", optional = true } -# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS +# Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS [target.'cfg(target_os = "linux")'.dependencies] rppal = { version = "0.14", optional = true } +landlock = { version = "0.4", optional = true } [features] default = ["hardware"] hardware = ["nusb", "tokio-serial"] peripheral-rpi = ["rppal"] +# Browser backend feature alias used by cfg(feature = "browser-native") +browser-native = ["dep:fantoccini"] +# Backward-compatible alias for older invocations +fantoccini = ["browser-native"] +# Sandbox feature aliases used by cfg(feature = "sandbox-*") +sandbox-landlock = ["dep:landlock"] +sandbox-bubblewrap = [] +# Backward-compatible alias for older invocations +landlock = ["sandbox-landlock"] # probe = probe-rs for Nucleo memory read (adds ~50 deps; optional) probe = ["dep:probe-rs"] # rag-pdf = PDF ingestion for datasheet RAG rag-pdf = ["dep:pdf-extract"] - [profile.release] opt-level = "z" # Optimize for size lto = "thin" # Lower memory use during release builds diff --git a/LICENSE b/LICENSE index 9d0e27e..349c342 100644 --- a/LICENSE +++ b/LICENSE @@ -1,197 +1,28 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ +MIT License - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +Copyright (c) 2025 ZeroClaw Labs - 1. Definitions. +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. +================================================================================ - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. +This product includes software developed by ZeroClaw Labs and contributors: +https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to the Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - Copyright 2025-2026 Argenis Delarosa - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. - You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. - - =============================================================================== - - This product includes software developed by ZeroClaw Labs and contributors: - https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors - - See NOTICE file for full contributor attribution. +See NOTICE file for full contributor attribution. diff --git a/README.md b/README.md index 9031482..ec87d47 100644 --- a/README.md +++ b/README.md @@ -10,14 +10,14 @@

- License: Apache 2.0 + License: MIT Contributors

Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything. ``` -~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything +~3.4MB binary · <10ms startup · 1,017 tests · 23+ providers · 8 traits · Pluggable everything ``` ### ✨ Features @@ -132,6 +132,9 @@ cd zeroclaw cargo build --release --locked cargo install --path . --force --locked +# Ensure ~/.cargo/bin is in your PATH +export PATH="$HOME/.cargo/bin:$PATH" + # Quick setup (no prompts) zeroclaw onboard --api-key sk-... --provider openrouter @@ -187,7 +190,7 @@ Every subsystem is a **trait** — swap implementations with a config change, ze | Subsystem | Trait | Ships with | Extend | |-----------|-------|------------|--------| -| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API | +| **AI Models** | `Provider` | 23+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, Astrai, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API | | **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API | | **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend | | **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability | @@ -287,6 +290,21 @@ rerun channel setup only: zeroclaw onboard --channels-only ``` +### Telegram media replies + +Telegram routing now replies to the source **chat ID** from incoming updates (instead of usernames), +which avoids `Bad Request: chat not found` failures. + +For non-text replies, ZeroClaw can send Telegram attachments when the assistant includes markers: + +- `[IMAGE:]` +- `[DOCUMENT:]` +- `[VIDEO:]` +- `[AUDIO:]` +- `[VOICE:]` + +Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs. + ### WhatsApp Business Cloud API Setup WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling): @@ -610,7 +628,7 @@ We're building in the open because the best ideas come from everywhere. If you'r ## License -Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution +MIT — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution ## Contributing @@ -624,7 +642,6 @@ See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR: - New `Tunnel` → `src/tunnel/` - New `Skill` → `~/.zeroclaw/workspace/skills//` - --- **ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀 diff --git a/dev/sandbox/Dockerfile b/dev/sandbox/Dockerfile index 59ddf05..6b81a7a 100644 --- a/dev/sandbox/Dockerfile +++ b/dev/sandbox/Dockerfile @@ -1,4 +1,4 @@ -FROM ubuntu:22.04 +FROM ubuntu:22.04@sha256:c7eb020043d8fc2ae0793fb35a37bff1cf33f156d4d4b12ccc7f3ef8706c38b1 # Prevent interactive prompts during package installation ENV DEBIAN_FRONTEND=noninteractive diff --git a/docs/ci-map.md b/docs/ci-map.md index 108a9d0..6a2260d 100644 --- a/docs/ci-map.md +++ b/docs/ci-map.md @@ -27,7 +27,7 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u ### Optional Repository Automation - `.github/workflows/labeler.yml` (`PR Labeler`) - - Purpose: scope/path labels + size/risk labels + fine-grained module labels (`:`) + - Purpose: scope/path labels + size/risk labels + fine-grained module labels (`: `) - Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule - Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`) - Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`) diff --git a/docs/pr-workflow.md b/docs/pr-workflow.md index 3c62711..2c154ef 100644 --- a/docs/pr-workflow.md +++ b/docs/pr-workflow.md @@ -244,7 +244,7 @@ Label discipline: - Path labels identify subsystem ownership quickly. - Size labels drive batching strategy. - Risk labels drive review depth (`risk: low/medium/high`). -- Module labels (`:`) improve reviewer routing for integration-specific changes and future newly-added modules. +- Module labels (`: `) improve reviewer routing for integration-specific changes and future newly-added modules. - `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context. - `no-stale` is reserved for accepted-but-blocked work. diff --git a/docs/reviewer-playbook.md b/docs/reviewer-playbook.md index bc42509..6f72fea 100644 --- a/docs/reviewer-playbook.md +++ b/docs/reviewer-playbook.md @@ -14,7 +14,7 @@ Use it to reduce review latency without reducing quality. For every new PR, do a fast intake pass: 1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`). -2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel:*`/`provider:*`/`tool:*`, and contributor tier labels when applicable) are present and plausible. +2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel: *`/`provider: *`/`tool: *`, and contributor tier labels when applicable) are present and plausible. 3. Confirm CI signal status (`CI Required Gate`). 4. Confirm scope is one concern (reject mixed mega-PRs unless justified). 5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied. diff --git a/examples/custom_channel.rs b/examples/custom_channel.rs index dd3fdf8..790762d 100644 --- a/examples/custom_channel.rs +++ b/examples/custom_channel.rs @@ -12,6 +12,8 @@ use tokio::sync::mpsc; pub struct ChannelMessage { pub id: String, pub sender: String, + /// Channel-specific reply address (e.g. Telegram chat_id, Discord channel_id). + pub reply_to: String, pub content: String, pub channel: String, pub timestamp: u64, @@ -90,9 +92,12 @@ impl Channel for TelegramChannel { continue; } + let chat_id = msg["chat"]["id"].to_string(); + let channel_msg = ChannelMessage { id: msg["message_id"].to_string(), sender, + reply_to: chat_id, content: msg["text"].as_str().unwrap_or("").to_string(), channel: "telegram".into(), timestamp: msg["date"].as_u64().unwrap_or(0), diff --git a/firmware/zeroclaw-esp32/.cargo/config.toml b/firmware/zeroclaw-esp32/.cargo/config.toml index 8746ad1..56dd71b 100644 --- a/firmware/zeroclaw-esp32/.cargo/config.toml +++ b/firmware/zeroclaw-esp32/.cargo/config.toml @@ -2,4 +2,10 @@ target = "riscv32imc-esp-espidf" [target.riscv32imc-esp-espidf] +linker = "ldproxy" runner = "espflash flash --monitor" +# ESP-IDF 5.x uses 64-bit time_t +rustflags = ["-C", "default-linker-libraries", "--cfg", "espidf_time64"] + +[unstable] +build-std = ["std", "panic_abort"] diff --git a/firmware/zeroclaw-esp32/Cargo.lock b/firmware/zeroclaw-esp32/Cargo.lock index 2580883..69e989b 100644 --- a/firmware/zeroclaw-esp32/Cargo.lock +++ b/firmware/zeroclaw-esp32/Cargo.lock @@ -58,24 +58,22 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" [[package]] name = "bindgen" -version = "0.63.0" +version = "0.71.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885" +checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3" dependencies = [ - "bitflags 1.3.2", + "bitflags 2.11.0", "cexpr", "clang-sys", - "lazy_static", - "lazycell", + "itertools", "log", - "peeking_take_while", + "prettyplease", "proc-macro2", "quote", "regex", "rustc-hash", "shlex", - "syn 1.0.109", - "which", + "syn 2.0.116", ] [[package]] @@ -374,14 +372,15 @@ checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01" [[package]] name = "embassy-sync" -version = "0.5.0" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd938f25c0798db4280fcd8026bf4c2f48789aebf8f77b6e5cf8a7693ba114ec" +checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b" dependencies = [ "cfg-if", "critical-section", "embedded-io-async", - "futures-util", + "futures-core", + "futures-sink", "heapless", ] @@ -446,16 +445,15 @@ dependencies = [ [[package]] name = "embedded-svc" -version = "0.27.1" +version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac6f87e7654f28018340aa55f933803017aefabaa5417820a3b2f808033c7bbc" +checksum = "a7770e30ab55cfbf954c00019522490d6ce26a3334bede05a732ba61010e98e0" dependencies = [ "defmt 0.3.100", "embedded-io", "embedded-io-async", "enumset", "heapless", - "no-std-net", "num_enum", "serde", "strum 0.25.0", @@ -463,9 +461,9 @@ dependencies = [ [[package]] name = "embuild" -version = "0.31.4" +version = "0.33.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4caa4f198bb9152a55c0103efb83fa4edfcbb8625f4c9e94ae8ec8e23827c563" +checksum = "e188ad2bbe82afa841ea4a29880651e53ab86815db036b2cb9f8de3ac32dad75" dependencies = [ "anyhow", "bindgen", @@ -475,6 +473,7 @@ dependencies = [ "globwalk", "home", "log", + "regex", "remove_dir_all", "serde", "serde_json", @@ -533,9 +532,8 @@ dependencies = [ [[package]] name = "esp-idf-hal" -version = "0.43.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f7adf3fb19a9ca016cbea1ab8a7b852ac69df8fcde4923c23d3b155efbc42a74" +version = "0.45.2" +source = "git+https://github.com/esp-rs/esp-idf-hal#bc48639bd626c72afc1e25e5d497b5c639161d30" dependencies = [ "atomic-waker", "embassy-sync", @@ -552,14 +550,12 @@ dependencies = [ "heapless", "log", "nb 1.1.0", - "num_enum", ] [[package]] name = "esp-idf-svc" -version = "0.48.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2180642ca122a7fec1ec417a9b1a77aa66aaa067fdf1daae683dd8caba84f26b" +version = "0.51.0" +source = "git+https://github.com/esp-rs/esp-idf-svc#dee202f146c7681e54eabbf118a216fc0195d203" dependencies = [ "embassy-futures", "embedded-hal-async", @@ -567,6 +563,7 @@ dependencies = [ "embuild", "enumset", "esp-idf-hal", + "futures-io", "heapless", "log", "num_enum", @@ -575,14 +572,13 @@ dependencies = [ [[package]] name = "esp-idf-sys" -version = "0.34.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2e148f97c04ed3e9181a08bcdc9560a515aad939b0ba7f50a0022e294665e0af" +version = "0.36.1" +source = "git+https://github.com/esp-rs/esp-idf-sys#64667a38fb8004e1fc3b032488af6857ca3cd849" dependencies = [ "anyhow", - "bindgen", "build-time", "cargo_metadata", + "cmake", "const_format", "embuild", "envy", @@ -649,21 +645,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" [[package]] -name = "futures-task" +name = "futures-io" version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" +checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718" [[package]] -name = "futures-util" +name = "futures-sink" version = "0.3.32" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" -dependencies = [ - "futures-core", - "futures-task", - "pin-project-lite", -] +checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893" [[package]] name = "getrandom" @@ -827,6 +818,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.17" @@ -843,18 +843,6 @@ dependencies = [ "wasm-bindgen", ] -[[package]] -name = "lazy_static" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" - -[[package]] -name = "lazycell" -version = "1.3.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" - [[package]] name = "leb128fmt" version = "0.1.0" @@ -945,12 +933,6 @@ dependencies = [ "libc", ] -[[package]] -name = "no-std-net" -version = "0.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bcece43b12349917e096cddfa66107277f123e6c96a5aea78711dc601a47152" - [[package]] name = "nom" version = "7.1.3" @@ -1007,18 +989,6 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" -[[package]] -name = "peeking_take_while" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099" - -[[package]] -name = "pin-project-lite" -version = "0.2.16" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b" - [[package]] name = "prettyplease" version = "0.2.37" @@ -1138,9 +1108,9 @@ dependencies = [ [[package]] name = "rustc-hash" -version = "1.1.0" +version = "2.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" [[package]] name = "rustix" diff --git a/firmware/zeroclaw-esp32/Cargo.toml b/firmware/zeroclaw-esp32/Cargo.toml index 70d2611..2ec056f 100644 --- a/firmware/zeroclaw-esp32/Cargo.toml +++ b/firmware/zeroclaw-esp32/Cargo.toml @@ -14,15 +14,21 @@ edition = "2021" license = "MIT" description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial" +[patch.crates-io] +# Use latest esp-rs crates to fix u8/i8 char pointer compatibility with ESP-IDF 5.x +esp-idf-sys = { git = "https://github.com/esp-rs/esp-idf-sys" } +esp-idf-hal = { git = "https://github.com/esp-rs/esp-idf-hal" } +esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" } + [dependencies] -esp-idf-svc = "0.48" +esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" } log = "0.4" anyhow = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" [build-dependencies] -embuild = "0.31" +embuild = { version = "0.33", features = ["espidf"] } [profile.release] opt-level = "s" diff --git a/firmware/zeroclaw-esp32/README.md b/firmware/zeroclaw-esp32/README.md index 804aaca..f4b2c08 100644 --- a/firmware/zeroclaw-esp32/README.md +++ b/firmware/zeroclaw-esp32/README.md @@ -2,8 +2,11 @@ Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial. +**New to this?** See [SETUP.md](SETUP.md) for step-by-step commands and troubleshooting. + ## Protocol + - **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n` - **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n` @@ -11,19 +14,44 @@ Commands: `gpio_read`, `gpio_write`. ## Prerequisites -1. **ESP toolchain** (espup): +1. **RISC-V ESP-IDF** (ESP32-C2/C3): Uses nightly Rust with `build-std`. + + **Python**: ESP-IDF requires Python 3.10–3.13 (not 3.14). If you have Python 3.14: + ```sh + brew install python@3.12 + ``` + + **virtualenv** (needed by ESP-IDF tools; PEP 668 workaround on macOS): + ```sh + /opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages + ``` + + **Rust tools**: + ```sh + cargo install espflash ldproxy + ``` + + The project's `rust-toolchain.toml` pins nightly + rust-src. `esp-idf-sys` downloads ESP-IDF automatically on first build. Use Python 3.12 for the build: + ```sh + export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" + ``` + +2. **Xtensa targets** (ESP32, ESP32-S2, ESP32-S3): Use espup instead: ```sh cargo install espup espflash espup install - source ~/export-esp.sh # or ~/export-esp.fish for Fish + source ~/export-esp.sh ``` - -2. **Target**: ESP32-C3 (RISC-V) by default. Edit `.cargo/config.toml` for other targets (e.g. `xtensa-esp32-espidf` for original ESP32). + Then edit `.cargo/config.toml` to change the target (e.g. `xtensa-esp32-espidf`). ## Build & Flash ```sh cd firmware/zeroclaw-esp32 +# Use Python 3.12 (required if you have 3.14) +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +# Optional: pin MCU (esp32c3 or esp32c2) +export MCU=esp32c3 cargo build --release espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor ``` diff --git a/firmware/zeroclaw-esp32/SETUP.md b/firmware/zeroclaw-esp32/SETUP.md new file mode 100644 index 0000000..0624f4d --- /dev/null +++ b/firmware/zeroclaw-esp32/SETUP.md @@ -0,0 +1,156 @@ +# ESP32 Firmware Setup Guide + +Step-by-step setup for building the ZeroClaw ESP32 firmware. Follow this if you run into issues. + +## Quick Start (copy-paste) + +```sh +# 1. Install Python 3.12 (ESP-IDF needs 3.10–3.13, not 3.14) +brew install python@3.12 + +# 2. Install virtualenv (PEP 668 workaround on macOS) +/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages + +# 3. Install Rust tools +cargo install espflash ldproxy + +# 4. Build +cd firmware/zeroclaw-esp32 +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +cargo build --release + +# 5. Flash (connect ESP32 via USB) +espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor +``` + +--- + +## Detailed Steps + +### 1. Python + +ESP-IDF requires Python 3.10–3.13. **Python 3.14 is not supported.** + +```sh +brew install python@3.12 +``` + +### 2. virtualenv + +ESP-IDF tools need `virtualenv`. On macOS with Homebrew Python, PEP 668 blocks `pip install`; use: + +```sh +/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages +``` + +### 3. Rust Tools + +```sh +cargo install espflash ldproxy +``` + +- **espflash**: flash and monitor +- **ldproxy**: linker for ESP-IDF builds + +### 4. Use Python 3.12 for Builds + +Before every build (or add to `~/.zshrc`): + +```sh +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +``` + +### 5. Build + +```sh +cd firmware/zeroclaw-esp32 +cargo build --release +``` + +First build downloads and compiles ESP-IDF (~5–15 min). + +### 6. Flash + +```sh +espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor +``` + +--- + +## Troubleshooting + +### "No space left on device" + +Free disk space. Common targets: + +```sh +# Cargo cache (often 5–20 GB) +rm -rf ~/.cargo/registry/cache ~/.cargo/registry/src + +# Unused Rust toolchains +rustup toolchain list +rustup toolchain uninstall + +# iOS Simulator runtimes (~35 GB) +xcrun simctl delete unavailable + +# Temp files +rm -rf /var/folders/*/T/cargo-install* +``` + +### "can't find crate for `core`" / "riscv32imc-esp-espidf target may not be installed" + +This project uses **nightly Rust with build-std**, not espup. Ensure: + +- `rust-toolchain.toml` exists (pins nightly + rust-src) +- You are **not** sourcing `~/export-esp.sh` (that's for Xtensa targets) +- Run `cargo build` from `firmware/zeroclaw-esp32` + +### "externally-managed-environment" / "No module named 'virtualenv'" + +Install virtualenv with the PEP 668 workaround: + +```sh +/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages +``` + +### "expected `i64`, found `i32`" (time_t mismatch) + +Already fixed in `.cargo/config.toml` with `espidf_time64` for ESP-IDF 5.x. If you use ESP-IDF 4.4, switch to `espidf_time32`. + +### "expected `*const u8`, found `*const i8`" (esp-idf-svc) + +Already fixed via `[patch.crates-io]` in `Cargo.toml` using esp-rs crates from git. Do not remove the patch. + +### 10,000+ files in `git status` + +The `.embuild/` directory (ESP-IDF cache) has ~100k+ files. It is in `.gitignore`. If you see them, ensure `.gitignore` contains: + +``` +.embuild/ +``` + +--- + +## Optional: Auto-load Python 3.12 + +Add to `~/.zshrc`: + +```sh +# ESP32 firmware build +export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH" +``` + +--- + +## Xtensa Targets (ESP32, ESP32-S2, ESP32-S3) + +For non–RISC-V chips, use espup instead: + +```sh +cargo install espup espflash +espup install +source ~/export-esp.sh +``` + +Then edit `.cargo/config.toml` to use `xtensa-esp32-espidf` (or the correct target). diff --git a/firmware/zeroclaw-esp32/rust-toolchain.toml b/firmware/zeroclaw-esp32/rust-toolchain.toml new file mode 100644 index 0000000..f70d225 --- /dev/null +++ b/firmware/zeroclaw-esp32/rust-toolchain.toml @@ -0,0 +1,3 @@ +[toolchain] +channel = "nightly" +components = ["rust-src"] diff --git a/firmware/zeroclaw-esp32/src/main.rs b/firmware/zeroclaw-esp32/src/main.rs index b1a487c..a85b67d 100644 --- a/firmware/zeroclaw-esp32/src/main.rs +++ b/firmware/zeroclaw-esp32/src/main.rs @@ -6,8 +6,9 @@ //! Protocol: same as STM32 — see docs/hardware-peripherals-design.md use esp_idf_svc::hal::gpio::PinDriver; -use esp_idf_svc::hal::prelude::*; -use esp_idf_svc::hal::uart::*; +use esp_idf_svc::hal::peripherals::Peripherals; +use esp_idf_svc::hal::uart::{UartConfig, UartDriver}; +use esp_idf_svc::hal::units::Hertz; use log::info; use serde::{Deserialize, Serialize}; @@ -36,9 +37,13 @@ fn main() -> anyhow::Result<()> { let peripherals = Peripherals::take()?; let pins = peripherals.pins; + // Create GPIO output drivers first (they take ownership of pins) + let mut gpio2 = PinDriver::output(pins.gpio2)?; + let mut gpio13 = PinDriver::output(pins.gpio13)?; + // UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board let config = UartConfig::new().baudrate(Hertz(115_200)); - let mut uart = UartDriver::new( + let uart = UartDriver::new( peripherals.uart0, pins.gpio21, pins.gpio20, @@ -60,7 +65,8 @@ fn main() -> anyhow::Result<()> { if b == b'\n' { if !line.is_empty() { if let Ok(line_str) = std::str::from_utf8(&line) { - if let Ok(resp) = handle_request(line_str, &peripherals) { + if let Ok(resp) = handle_request(line_str, &mut gpio2, &mut gpio13) + { let out = serde_json::to_string(&resp).unwrap_or_default(); let _ = uart.write(format!("{}\n", out).as_bytes()); } @@ -80,10 +86,15 @@ fn main() -> anyhow::Result<()> { } } -fn handle_request( +fn handle_request( line: &str, - peripherals: &esp_idf_svc::hal::peripherals::Peripherals, -) -> anyhow::Result { + gpio2: &mut PinDriver<'_, G2>, + gpio13: &mut PinDriver<'_, G13>, +) -> anyhow::Result +where + G2: esp_idf_svc::hal::gpio::OutputMode, + G13: esp_idf_svc::hal::gpio::OutputMode, +{ let req: Request = serde_json::from_str(line.trim())?; let id = req.id.clone(); @@ -98,13 +109,13 @@ fn handle_request( } "gpio_read" => { let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32; - let value = gpio_read(peripherals, pin_num)?; + let value = gpio_read(pin_num)?; Ok(value.to_string()) } "gpio_write" => { let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32; let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0); - gpio_write(peripherals, pin_num, value)?; + gpio_write(gpio2, gpio13, pin_num, value)?; Ok("done".into()) } _ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)), @@ -126,28 +137,26 @@ fn handle_request( } } -fn gpio_read(_peripherals: &esp_idf_svc::hal::peripherals::Peripherals, _pin: i32) -> anyhow::Result { +fn gpio_read(_pin: i32) -> anyhow::Result { // TODO: implement input pin read — requires storing InputPin drivers per pin Ok(0) } -fn gpio_write( - peripherals: &esp_idf_svc::hal::peripherals::Peripherals, +fn gpio_write( + gpio2: &mut PinDriver<'_, G2>, + gpio13: &mut PinDriver<'_, G13>, pin: i32, value: u64, -) -> anyhow::Result<()> { - let pins = peripherals.pins; - let level = value != 0; +) -> anyhow::Result<()> +where + G2: esp_idf_svc::hal::gpio::OutputMode, + G13: esp_idf_svc::hal::gpio::OutputMode, +{ + let level = esp_idf_svc::hal::gpio::Level::from(value != 0); match pin { - 2 => { - let mut out = PinDriver::output(pins.gpio2)?; - out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?; - } - 13 => { - let mut out = PinDriver::output(pins.gpio13)?; - out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?; - } + 2 => gpio2.set_level(level)?, + 13 => gpio13.set_level(level)?, _ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin), } Ok(()) diff --git a/scripts/recompute_contributor_tiers.sh b/scripts/recompute_contributor_tiers.sh new file mode 100755 index 0000000..6e3e528 --- /dev/null +++ b/scripts/recompute_contributor_tiers.sh @@ -0,0 +1,324 @@ +#!/usr/bin/env bash + +set -euo pipefail + +SCRIPT_NAME="$(basename "$0")" + +usage() { + cat < Target repository (default: current gh repo) + --kind + Target objects (default: both) + --state + State filter for listing objects (default: all) + --limit Limit processed objects after fetch (default: 0 = no limit) + --apply Apply label updates (default is dry-run) + --dry-run Preview only (default) + -h, --help Show this help + +Examples: + ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --limit 50 + ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --kind prs --state open --apply +USAGE +} + +die() { + echo "[$SCRIPT_NAME] ERROR: $*" >&2 + exit 1 +} + +require_cmd() { + if ! command -v "$1" >/dev/null 2>&1; then + die "Required command not found: $1" + fi +} + +urlencode() { + jq -nr --arg value "$1" '$value|@uri' +} + +select_contributor_tier() { + local merged_count="$1" + if (( merged_count >= 50 )); then + echo "distinguished contributor" + elif (( merged_count >= 20 )); then + echo "principal contributor" + elif (( merged_count >= 10 )); then + echo "experienced contributor" + elif (( merged_count >= 5 )); then + echo "trusted contributor" + else + echo "" + fi +} + +DRY_RUN=1 +KIND="both" +STATE="all" +LIMIT=0 +REPO="" + +while (($# > 0)); do + case "$1" in + --repo) + [[ $# -ge 2 ]] || die "Missing value for --repo" + REPO="$2" + shift 2 + ;; + --kind) + [[ $# -ge 2 ]] || die "Missing value for --kind" + KIND="$2" + shift 2 + ;; + --state) + [[ $# -ge 2 ]] || die "Missing value for --state" + STATE="$2" + shift 2 + ;; + --limit) + [[ $# -ge 2 ]] || die "Missing value for --limit" + LIMIT="$2" + shift 2 + ;; + --apply) + DRY_RUN=0 + shift + ;; + --dry-run) + DRY_RUN=1 + shift + ;; + -h|--help) + usage + exit 0 + ;; + *) + die "Unknown option: $1" + ;; + esac +done + +case "$KIND" in + both|prs|issues) ;; + *) die "--kind must be one of: both, prs, issues" ;; +esac + +case "$STATE" in + all|open|closed) ;; + *) die "--state must be one of: all, open, closed" ;; +esac + +if ! [[ "$LIMIT" =~ ^[0-9]+$ ]]; then + die "--limit must be a non-negative integer" +fi + +require_cmd gh +require_cmd jq + +if ! gh auth status >/dev/null 2>&1; then + die "gh CLI is not authenticated. Run: gh auth login" +fi + +if [[ -z "$REPO" ]]; then + REPO="$(gh repo view --json nameWithOwner --jq '.nameWithOwner' 2>/dev/null || true)" + [[ -n "$REPO" ]] || die "Unable to infer repo. Pass --repo ." +fi + +echo "[$SCRIPT_NAME] Repo: $REPO" +echo "[$SCRIPT_NAME] Mode: $([[ "$DRY_RUN" -eq 1 ]] && echo "dry-run" || echo "apply")" +echo "[$SCRIPT_NAME] Kind: $KIND | State: $STATE | Limit: $LIMIT" + +TIERS_JSON='["trusted contributor","experienced contributor","principal contributor","distinguished contributor"]' + +TMP_FILES=() +cleanup() { + if ((${#TMP_FILES[@]} > 0)); then + rm -f "${TMP_FILES[@]}" + fi +} +trap cleanup EXIT + +new_tmp_file() { + local tmp + tmp="$(mktemp)" + TMP_FILES+=("$tmp") + echo "$tmp" +} + +targets_file="$(new_tmp_file)" + +if [[ "$KIND" == "both" || "$KIND" == "prs" ]]; then + gh api --paginate "repos/$REPO/pulls?state=$STATE&per_page=100" \ + --jq '.[] | { + kind: "pr", + number: .number, + author: (.user.login // ""), + author_type: (.user.type // ""), + labels: [(.labels[]?.name // empty)] + }' >> "$targets_file" +fi + +if [[ "$KIND" == "both" || "$KIND" == "issues" ]]; then + gh api --paginate "repos/$REPO/issues?state=$STATE&per_page=100" \ + --jq '.[] | select(.pull_request | not) | { + kind: "issue", + number: .number, + author: (.user.login // ""), + author_type: (.user.type // ""), + labels: [(.labels[]?.name // empty)] + }' >> "$targets_file" +fi + +if [[ "$LIMIT" -gt 0 ]]; then + limited_file="$(new_tmp_file)" + head -n "$LIMIT" "$targets_file" > "$limited_file" + mv "$limited_file" "$targets_file" +fi + +target_count="$(wc -l < "$targets_file" | tr -d ' ')" +if [[ "$target_count" -eq 0 ]]; then + echo "[$SCRIPT_NAME] No targets found." + exit 0 +fi + +echo "[$SCRIPT_NAME] Targets fetched: $target_count" + +# Ensure tier labels exist (trusted contributor might be new). +label_color="" +for probe_label in "experienced contributor" "principal contributor" "distinguished contributor" "trusted contributor"; do + encoded_label="$(urlencode "$probe_label")" + if color_candidate="$(gh api "repos/$REPO/labels/$encoded_label" --jq '.color' 2>/dev/null || true)"; then + if [[ -n "$color_candidate" ]]; then + label_color="$(echo "$color_candidate" | tr '[:lower:]' '[:upper:]')" + break + fi + fi +done +[[ -n "$label_color" ]] || label_color="C5D7A2" + +while IFS= read -r tier_label; do + [[ -n "$tier_label" ]] || continue + encoded_label="$(urlencode "$tier_label")" + if gh api "repos/$REPO/labels/$encoded_label" >/dev/null 2>&1; then + continue + fi + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "[dry-run] Would create missing label: $tier_label (color=$label_color)" + else + gh api -X POST "repos/$REPO/labels" \ + -f name="$tier_label" \ + -f color="$label_color" >/dev/null + echo "[apply] Created missing label: $tier_label" + fi +done < <(jq -r '.[]' <<<"$TIERS_JSON") + +# Build merged PR count cache by unique human authors. +authors_file="$(new_tmp_file)" +jq -r 'select(.author != "" and .author_type != "Bot") | .author' "$targets_file" | sort -u > "$authors_file" +author_count="$(wc -l < "$authors_file" | tr -d ' ')" +echo "[$SCRIPT_NAME] Unique human authors: $author_count" + +author_counts_file="$(new_tmp_file)" +while IFS= read -r author; do + [[ -n "$author" ]] || continue + query="repo:$REPO is:pr is:merged author:$author" + merged_count="$(gh api search/issues -f q="$query" -F per_page=1 --jq '.total_count' 2>/dev/null || true)" + if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then + merged_count=0 + fi + printf '%s\t%s\n' "$author" "$merged_count" >> "$author_counts_file" +done < "$authors_file" + +updated=0 +unchanged=0 +skipped=0 +failed=0 + +while IFS= read -r target_json; do + [[ -n "$target_json" ]] || continue + + number="$(jq -r '.number' <<<"$target_json")" + kind="$(jq -r '.kind' <<<"$target_json")" + author="$(jq -r '.author' <<<"$target_json")" + author_type="$(jq -r '.author_type' <<<"$target_json")" + current_labels_json="$(jq -c '.labels // []' <<<"$target_json")" + + if [[ -z "$author" || "$author_type" == "Bot" ]]; then + skipped=$((skipped + 1)) + continue + fi + + merged_count="$(awk -F '\t' -v key="$author" '$1 == key { print $2; exit }' "$author_counts_file")" + if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then + merged_count=0 + fi + desired_tier="$(select_contributor_tier "$merged_count")" + + if ! current_tier="$(jq -r --argjson tiers "$TIERS_JSON" '[.[] | select(. as $label | ($tiers | index($label)) != null)][0] // ""' <<<"$current_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot parse current labels JSON" >&2 + failed=$((failed + 1)) + continue + fi + + if ! next_labels_json="$(jq -c --arg desired "$desired_tier" --argjson tiers "$TIERS_JSON" ' + (. // []) + | map(select(. as $label | ($tiers | index($label)) == null)) + | if $desired != "" then . + [$desired] else . end + | unique + ' <<<"$current_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot compute next labels" >&2 + failed=$((failed + 1)) + continue + fi + + if ! normalized_current="$(jq -c 'unique | sort' <<<"$current_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot normalize current labels" >&2 + failed=$((failed + 1)) + continue + fi + + if ! normalized_next="$(jq -c 'unique | sort' <<<"$next_labels_json" 2>/dev/null)"; then + echo "[warn] Skipping ${kind} #${number}: cannot normalize next labels" >&2 + failed=$((failed + 1)) + continue + fi + + if [[ "$normalized_current" == "$normalized_next" ]]; then + unchanged=$((unchanged + 1)) + continue + fi + + if [[ "$DRY_RUN" -eq 1 ]]; then + echo "[dry-run] ${kind} #${number} @${author} merged=${merged_count} tier: '${current_tier:-none}' -> '${desired_tier:-none}'" + updated=$((updated + 1)) + continue + fi + + payload="$(jq -cn --argjson labels "$next_labels_json" '{labels: $labels}')" + if gh api -X PUT "repos/$REPO/issues/$number/labels" --input - <<<"$payload" >/dev/null; then + echo "[apply] Updated ${kind} #${number} @${author} tier: '${current_tier:-none}' -> '${desired_tier:-none}'" + updated=$((updated + 1)) + else + echo "[apply] FAILED ${kind} #${number}" >&2 + failed=$((failed + 1)) + fi +done < "$targets_file" + +echo "" +echo "[$SCRIPT_NAME] Summary" +echo " Targets: $target_count" +echo " Updated: $updated" +echo " Unchanged: $unchanged" +echo " Skipped: $skipped" +echo " Failed: $failed" + +if [[ "$failed" -gt 0 ]]; then + exit 1 +fi diff --git a/src/agent/agent.rs b/src/agent/agent.rs index ca18e79..3e5693e 100644 --- a/src/agent/agent.rs +++ b/src/agent/agent.rs @@ -251,6 +251,7 @@ impl Agent { let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, &config.model_routes, &model_name, @@ -388,7 +389,7 @@ impl Agent { if self.auto_save { let _ = self .memory - .store("user_msg", user_message, MemoryCategory::Conversation) + .store("user_msg", user_message, MemoryCategory::Conversation, None) .await; } @@ -447,7 +448,7 @@ impl Agent { let summary = truncate_with_ellipsis(&final_text, 100); let _ = self .memory - .store("assistant_resp", &summary, MemoryCategory::Daily) + .store("assistant_resp", &summary, MemoryCategory::Daily, None) .await; } @@ -557,6 +558,7 @@ pub async fn run( agent.observer.record_event(&ObserverEvent::AgentEnd { duration: start.elapsed(), tokens_used: None, + cost_usd: None, }); Ok(()) diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs index 47d02a6..81882d6 100644 --- a/src/agent/loop_.rs +++ b/src/agent/loop_.rs @@ -7,14 +7,70 @@ use crate::security::SecurityPolicy; use crate::tools::{self, Tool}; use crate::util::truncate_with_ellipsis; use anyhow::Result; +use regex::{Regex, RegexSet}; use std::fmt::Write; use std::io::Write as _; -use std::sync::Arc; +use std::sync::{Arc, LazyLock}; use std::time::Instant; use uuid::Uuid; + /// Maximum agentic tool-use iterations per user message to prevent runaway loops. const MAX_TOOL_ITERATIONS: usize = 10; +static SENSITIVE_KEY_PATTERNS: LazyLock = LazyLock::new(|| { + RegexSet::new([ + r"(?i)token", + r"(?i)api[_-]?key", + r"(?i)password", + r"(?i)secret", + r"(?i)user[_-]?key", + r"(?i)bearer", + r"(?i)credential", + ]) + .unwrap() +}); + +static SENSITIVE_KV_REGEX: LazyLock = LazyLock::new(|| { + Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap() +}); + +/// Scrub credentials from tool output to prevent accidental exfiltration. +/// Replaces known credential patterns with a redacted placeholder while preserving +/// a small prefix for context. +fn scrub_credentials(input: &str) -> String { + SENSITIVE_KV_REGEX + .replace_all(input, |caps: ®ex::Captures| { + let full_match = &caps[0]; + let key = &caps[1]; + let val = caps + .get(2) + .or(caps.get(3)) + .or(caps.get(4)) + .map(|m| m.as_str()) + .unwrap_or(""); + + // Preserve first 4 chars for context, then redact + let prefix = if val.len() > 4 { &val[..4] } else { "" }; + + if full_match.contains(':') { + if full_match.contains('"') { + format!("\"{}\": \"{}*[REDACTED]\"", key, prefix) + } else { + format!("{}: {}*[REDACTED]", key, prefix) + } + } else if full_match.contains('=') { + if full_match.contains('"') { + format!("{}=\"{}*[REDACTED]\"", key, prefix) + } else { + format!("{}={}*[REDACTED]", key, prefix) + } + } else { + format!("{}: {}*[REDACTED]", key, prefix) + } + }) + .to_string() +} + /// Trigger auto-compaction when non-system message count exceeds this threshold. const MAX_HISTORY_MESSAGES: usize = 50; @@ -145,7 +201,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String { let mut context = String::new(); // Pull relevant memories for this message - if let Ok(entries) = mem.recall(user_msg, 5).await { + if let Ok(entries) = mem.recall(user_msg, 5, None).await { if !entries.is_empty() { context.push_str("[Memory context]\n"); for entry in &entries { @@ -436,6 +492,7 @@ struct ParsedToolCall { /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. /// When `silent` is true, suppresses stdout (for channel use). +#[allow(clippy::too_many_arguments)] pub(crate) async fn agent_turn( provider: &dyn Provider, history: &mut Vec, @@ -461,6 +518,7 @@ pub(crate) async fn agent_turn( /// Execute a single turn of the agent loop: send messages, parse tool calls, /// execute tools, and loop until the LLM produces a final text response. +#[allow(clippy::too_many_arguments)] pub(crate) async fn run_tool_call_loop( provider: &dyn Provider, history: &mut Vec, @@ -606,7 +664,7 @@ pub(crate) async fn run_tool_call_loop( success: r.success, }); if r.success { - r.output + scrub_credentials(&r.output) } else { format!("Error: {}", r.error.unwrap_or_else(|| r.output)) } @@ -749,6 +807,7 @@ pub async fn run( let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, &config.model_routes, model_name, @@ -912,7 +971,7 @@ pub async fn run( if config.memory.auto_save { let user_key = autosave_memory_key("user_msg"); let _ = mem - .store(&user_key, &msg, MemoryCategory::Conversation) + .store(&user_key, &msg, MemoryCategory::Conversation, None) .await; } @@ -955,7 +1014,7 @@ pub async fn run( let summary = truncate_with_ellipsis(&response, 100); let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store(&response_key, &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily, None) .await; } } else { @@ -978,7 +1037,7 @@ pub async fn run( if config.memory.auto_save { let user_key = autosave_memory_key("user_msg"); let _ = mem - .store(&user_key, &msg.content, MemoryCategory::Conversation) + .store(&user_key, &msg.content, MemoryCategory::Conversation, None) .await; } @@ -1036,7 +1095,7 @@ pub async fn run( let summary = truncate_with_ellipsis(&response, 100); let response_key = autosave_memory_key("assistant_resp"); let _ = mem - .store(&response_key, &summary, MemoryCategory::Daily) + .store(&response_key, &summary, MemoryCategory::Daily, None) .await; } } @@ -1048,6 +1107,7 @@ pub async fn run( observer.record_event(&ObserverEvent::AgentEnd { duration, tokens_used: None, + cost_usd: None, }); Ok(final_output) @@ -1104,6 +1164,7 @@ pub async fn process_message(config: Config, message: &str) -> Result { let provider: Box = providers::create_routed_provider( provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, &config.model_routes, &model_name, @@ -1217,6 +1278,25 @@ pub async fn process_message(config: Config, message: &str) -> Result { #[cfg(test)] mod tests { use super::*; + + #[test] + fn test_scrub_credentials() { + let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\""; + let scrubbed = scrub_credentials(input); + assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]")); + assert!(scrubbed.contains("token: 1234*[REDACTED]")); + assert!(scrubbed.contains("password=\"secr*[REDACTED]\"")); + assert!(!scrubbed.contains("abcdef")); + assert!(!scrubbed.contains("secret123456")); + } + + #[test] + fn test_scrub_credentials_json() { + let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#; + let scrubbed = scrub_credentials(input); + assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\"")); + assert!(scrubbed.contains("public")); + } use crate::memory::{Memory, MemoryCategory, SqliteMemory}; use tempfile::TempDir; @@ -1496,16 +1576,16 @@ I will now call the tool with this payload: let key1 = autosave_memory_key("user_msg"); let key2 = autosave_memory_key("user_msg"); - mem.store(&key1, "I'm Paul", MemoryCategory::Conversation) + mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None) .await .unwrap(); - mem.store(&key2, "I'm 45", MemoryCategory::Conversation) + mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None) .await .unwrap(); assert_eq!(mem.count().await.unwrap(), 2); - let recalled = mem.recall("45", 5).await.unwrap(); + let recalled = mem.recall("45", 5, None).await.unwrap(); assert!(recalled.iter().any(|entry| entry.content.contains("45"))); } diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs index f5733ec..0cc530f 100644 --- a/src/agent/memory_loader.rs +++ b/src/agent/memory_loader.rs @@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader { memory: &dyn Memory, user_message: &str, ) -> anyhow::Result { - let entries = memory.recall(user_message, self.limit).await?; + let entries = memory.recall(user_message, self.limit, None).await?; if entries.is_empty() { return Ok(String::new()); } @@ -61,11 +61,17 @@ mod tests { _key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } - async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { if limit == 0 { return Ok(vec![]); } @@ -87,6 +93,7 @@ mod tests { async fn list( &self, _category: Option<&MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(vec![]) } diff --git a/src/channels/cli.rs b/src/channels/cli.rs index 8b414fd..46ee474 100644 --- a/src/channels/cli.rs +++ b/src/channels/cli.rs @@ -40,6 +40,7 @@ impl Channel for CliChannel { let msg = ChannelMessage { id: Uuid::new_v4().to_string(), sender: "user".to_string(), + reply_target: "user".to_string(), content: line, channel: "cli".to_string(), timestamp: std::time::SystemTime::now() @@ -90,12 +91,14 @@ mod tests { let msg = ChannelMessage { id: "test-id".into(), sender: "user".into(), + reply_target: "user".into(), content: "hello".into(), channel: "cli".into(), timestamp: 1_234_567_890, }; assert_eq!(msg.id, "test-id"); assert_eq!(msg.sender, "user"); + assert_eq!(msg.reply_target, "user"); assert_eq!(msg.content, "hello"); assert_eq!(msg.channel, "cli"); assert_eq!(msg.timestamp, 1_234_567_890); @@ -106,6 +109,7 @@ mod tests { let msg = ChannelMessage { id: "id".into(), sender: "s".into(), + reply_target: "s".into(), content: "c".into(), channel: "ch".into(), timestamp: 0, diff --git a/src/channels/dingtalk.rs b/src/channels/dingtalk.rs index f55135a..7473bb3 100644 --- a/src/channels/dingtalk.rs +++ b/src/channels/dingtalk.rs @@ -7,7 +7,7 @@ use tokio::sync::RwLock; use tokio_tungstenite::tungstenite::Message; use uuid::Uuid; -/// DingTalk (钉钉) channel — connects via Stream Mode WebSocket for real-time messages. +/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages. /// Replies are sent through per-message session webhook URLs. pub struct DingTalkChannel { client_id: String, @@ -64,6 +64,18 @@ impl DingTalkChannel { let gw: GatewayResponse = resp.json().await?; Ok(gw) } + + fn resolve_reply_target( + sender_id: &str, + conversation_type: &str, + conversation_id: Option<&str>, + ) -> String { + if conversation_type == "1" { + sender_id.to_string() + } else { + conversation_id.unwrap_or(sender_id).to_string() + } + } } #[async_trait] @@ -193,14 +205,11 @@ impl Channel for DingTalkChannel { .unwrap_or("1"); // Private chat uses sender ID, group chat uses conversation ID - let chat_id = if conversation_type == "1" { - sender_id.to_string() - } else { - data.get("conversationId") - .and_then(|c| c.as_str()) - .unwrap_or(sender_id) - .to_string() - }; + let chat_id = Self::resolve_reply_target( + sender_id, + conversation_type, + data.get("conversationId").and_then(|c| c.as_str()), + ); // Store session webhook for later replies if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) { @@ -229,6 +238,7 @@ impl Channel for DingTalkChannel { let channel_msg = ChannelMessage { id: Uuid::new_v4().to_string(), sender: sender_id.to_string(), + reply_target: chat_id, content: content.to_string(), channel: "dingtalk".to_string(), timestamp: std::time::SystemTime::now() @@ -305,4 +315,22 @@ client_secret = "secret" let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap(); assert!(config.allowed_users.is_empty()); } + + #[test] + fn test_resolve_reply_target_private_chat_uses_sender_id() { + let target = DingTalkChannel::resolve_reply_target("staff_1", "1", Some("conv_1")); + assert_eq!(target, "staff_1"); + } + + #[test] + fn test_resolve_reply_target_group_chat_uses_conversation_id() { + let target = DingTalkChannel::resolve_reply_target("staff_1", "2", Some("conv_1")); + assert_eq!(target, "conv_1"); + } + + #[test] + fn test_resolve_reply_target_group_chat_falls_back_to_sender_id() { + let target = DingTalkChannel::resolve_reply_target("staff_1", "2", None); + assert_eq!(target, "staff_1"); + } } diff --git a/src/channels/discord.rs b/src/channels/discord.rs index 4e99f43..7eb7502 100644 --- a/src/channels/discord.rs +++ b/src/channels/discord.rs @@ -11,6 +11,7 @@ pub struct DiscordChannel { guild_id: Option, allowed_users: Vec, listen_to_bots: bool, + mention_only: bool, client: reqwest::Client, typing_handle: std::sync::Mutex>>, } @@ -21,12 +22,14 @@ impl DiscordChannel { guild_id: Option, allowed_users: Vec, listen_to_bots: bool, + mention_only: bool, ) -> Self { Self { bot_token, guild_id, allowed_users, listen_to_bots, + mention_only, client: reqwest::Client::new(), typing_handle: std::sync::Mutex::new(None), } @@ -343,6 +346,22 @@ impl Channel for DiscordChannel { continue; } + // Skip messages that don't @-mention the bot (when mention_only is enabled) + if self.mention_only { + let mention_tag = format!("<@{bot_user_id}>"); + if !content.contains(&mention_tag) { + continue; + } + } + + // Strip the bot mention from content so the agent sees clean text + let clean_content = if self.mention_only { + let mention_tag = format!("<@{bot_user_id}>"); + content.replace(&mention_tag, "").trim().to_string() + } else { + content.to_string() + }; + let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or(""); let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string(); @@ -353,6 +372,11 @@ impl Channel for DiscordChannel { format!("discord_{message_id}") }, sender: author_id.to_string(), + reply_target: if channel_id.is_empty() { + author_id.to_string() + } else { + channel_id + }, content: content.to_string(), channel: channel_id, timestamp: std::time::SystemTime::now() @@ -423,7 +447,7 @@ mod tests { #[test] fn discord_channel_name() { - let ch = DiscordChannel::new("fake".into(), None, vec![], false); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); assert_eq!(ch.name(), "discord"); } @@ -444,21 +468,27 @@ mod tests { #[test] fn empty_allowlist_denies_everyone() { - let ch = DiscordChannel::new("fake".into(), None, vec![], false); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); assert!(!ch.is_user_allowed("12345")); assert!(!ch.is_user_allowed("anyone")); } #[test] fn wildcard_allows_everyone() { - let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false); + let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false); assert!(ch.is_user_allowed("12345")); assert!(ch.is_user_allowed("anyone")); } #[test] fn specific_allowlist_filters() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false); + let ch = DiscordChannel::new( + "fake".into(), + None, + vec!["111".into(), "222".into()], + false, + false, + ); assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("222")); assert!(!ch.is_user_allowed("333")); @@ -467,7 +497,7 @@ mod tests { #[test] fn allowlist_is_exact_match_not_substring() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false); + let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false); assert!(!ch.is_user_allowed("1111")); assert!(!ch.is_user_allowed("11")); assert!(!ch.is_user_allowed("0111")); @@ -475,20 +505,26 @@ mod tests { #[test] fn allowlist_empty_string_user_id() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false); + let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false); assert!(!ch.is_user_allowed("")); } #[test] fn allowlist_with_wildcard_and_specific() { - let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false); + let ch = DiscordChannel::new( + "fake".into(), + None, + vec!["111".into(), "*".into()], + false, + false, + ); assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("anyone_else")); } #[test] fn allowlist_case_sensitive() { - let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false); + let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false); assert!(ch.is_user_allowed("ABC")); assert!(!ch.is_user_allowed("abc")); assert!(!ch.is_user_allowed("Abc")); @@ -663,14 +699,14 @@ mod tests { #[test] fn typing_handle_starts_as_none() { - let ch = DiscordChannel::new("fake".into(), None, vec![], false); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let guard = ch.typing_handle.lock().unwrap(); assert!(guard.is_none()); } #[tokio::test] async fn start_typing_sets_handle() { - let ch = DiscordChannel::new("fake".into(), None, vec![], false); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("123456").await; let guard = ch.typing_handle.lock().unwrap(); assert!(guard.is_some()); @@ -678,7 +714,7 @@ mod tests { #[tokio::test] async fn stop_typing_clears_handle() { - let ch = DiscordChannel::new("fake".into(), None, vec![], false); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("123456").await; let _ = ch.stop_typing("123456").await; let guard = ch.typing_handle.lock().unwrap(); @@ -687,14 +723,14 @@ mod tests { #[tokio::test] async fn stop_typing_is_idempotent() { - let ch = DiscordChannel::new("fake".into(), None, vec![], false); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); assert!(ch.stop_typing("123456").await.is_ok()); assert!(ch.stop_typing("123456").await.is_ok()); } #[tokio::test] async fn start_typing_replaces_existing_task() { - let ch = DiscordChannel::new("fake".into(), None, vec![], false); + let ch = DiscordChannel::new("fake".into(), None, vec![], false, false); let _ = ch.start_typing("111").await; let _ = ch.start_typing("222").await; let guard = ch.typing_handle.lock().unwrap(); diff --git a/src/channels/email_channel.rs b/src/channels/email_channel.rs index f1ea016..da3490d 100644 --- a/src/channels/email_channel.rs +++ b/src/channels/email_channel.rs @@ -10,6 +10,7 @@ use anyhow::{anyhow, Result}; use async_trait::async_trait; +use lettre::message::SinglePart; use lettre::transport::smtp::authentication::Credentials; use lettre::{Message, SmtpTransport, Transport}; use mail_parser::{MessageParser, MimeHeaders}; @@ -39,7 +40,7 @@ pub struct EmailConfig { pub imap_folder: String, /// SMTP server hostname pub smtp_host: String, - /// SMTP server port (default: 587 for STARTTLS) + /// SMTP server port (default: 465 for TLS) #[serde(default = "default_smtp_port")] pub smtp_port: u16, /// Use TLS for SMTP (default: true) @@ -63,7 +64,7 @@ fn default_imap_port() -> u16 { 993 } fn default_smtp_port() -> u16 { - 587 + 465 } fn default_imap_folder() -> String { "INBOX".into() @@ -389,7 +390,7 @@ impl Channel for EmailChannel { .from(self.config.from_address.parse()?) .to(recipient.parse()?) .subject(subject) - .body(body.to_string())?; + .singlepart(SinglePart::plain(body.to_string()))?; let transport = self.create_smtp_transport()?; transport.send(&email)?; @@ -427,6 +428,7 @@ impl Channel for EmailChannel { } // MutexGuard dropped before await let msg = ChannelMessage { id, + reply_target: sender.clone(), sender, content, channel: "email".to_string(), @@ -464,6 +466,18 @@ impl Channel for EmailChannel { mod tests { use super::*; + #[test] + fn default_smtp_port_uses_tls_port() { + assert_eq!(default_smtp_port(), 465); + } + + #[test] + fn email_config_default_uses_tls_smtp_defaults() { + let config = EmailConfig::default(); + assert_eq!(config.smtp_port, 465); + assert!(config.smtp_tls); + } + #[test] fn build_imap_tls_config_succeeds() { let tls_config = @@ -504,7 +518,7 @@ mod tests { assert_eq!(config.imap_port, 993); assert_eq!(config.imap_folder, "INBOX"); assert_eq!(config.smtp_host, ""); - assert_eq!(config.smtp_port, 587); + assert_eq!(config.smtp_port, 465); assert!(config.smtp_tls); assert_eq!(config.username, ""); assert_eq!(config.password, ""); @@ -765,8 +779,8 @@ mod tests { } #[test] - fn default_smtp_port_returns_587() { - assert_eq!(default_smtp_port(), 587); + fn default_smtp_port_returns_465() { + assert_eq!(default_smtp_port(), 465); } #[test] @@ -822,7 +836,7 @@ mod tests { let config: EmailConfig = serde_json::from_str(json).unwrap(); assert_eq!(config.imap_port, 993); // default - assert_eq!(config.smtp_port, 587); // default + assert_eq!(config.smtp_port, 465); // default assert!(config.smtp_tls); // default assert_eq!(config.poll_interval_secs, 60); // default } diff --git a/src/channels/imessage.rs b/src/channels/imessage.rs index f001c56..36bf72f 100644 --- a/src/channels/imessage.rs +++ b/src/channels/imessage.rs @@ -172,6 +172,7 @@ end tell"# let msg = ChannelMessage { id: rowid.to_string(), sender: sender.clone(), + reply_target: sender.clone(), content: text, channel: "imessage".to_string(), timestamp: std::time::SystemTime::now() diff --git a/src/channels/irc.rs b/src/channels/irc.rs index d63ad41..61a48cc 100644 --- a/src/channels/irc.rs +++ b/src/channels/irc.rs @@ -220,32 +220,34 @@ fn split_message(message: &str, max_bytes: usize) -> Vec { chunks } +/// Configuration for constructing an `IrcChannel`. +pub struct IrcChannelConfig { + pub server: String, + pub port: u16, + pub nickname: String, + pub username: Option, + pub channels: Vec, + pub allowed_users: Vec, + pub server_password: Option, + pub nickserv_password: Option, + pub sasl_password: Option, + pub verify_tls: bool, +} + impl IrcChannel { - #[allow(clippy::too_many_arguments)] - pub fn new( - server: String, - port: u16, - nickname: String, - username: Option, - channels: Vec, - allowed_users: Vec, - server_password: Option, - nickserv_password: Option, - sasl_password: Option, - verify_tls: bool, - ) -> Self { - let username = username.unwrap_or_else(|| nickname.clone()); + pub fn new(cfg: IrcChannelConfig) -> Self { + let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone()); Self { - server, - port, - nickname, + server: cfg.server, + port: cfg.port, + nickname: cfg.nickname, username, - channels, - allowed_users, - server_password, - nickserv_password, - sasl_password, - verify_tls, + channels: cfg.channels, + allowed_users: cfg.allowed_users, + server_password: cfg.server_password, + nickserv_password: cfg.nickserv_password, + sasl_password: cfg.sasl_password, + verify_tls: cfg.verify_tls, writer: Arc::new(Mutex::new(None)), } } @@ -563,7 +565,8 @@ impl Channel for IrcChannel { let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed); let channel_msg = ChannelMessage { id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()), - sender: reply_to, + sender: sender_nick.to_string(), + reply_target: reply_to, content, channel: "irc".to_string(), timestamp: std::time::SystemTime::now() @@ -807,18 +810,18 @@ mod tests { #[test] fn specific_user_allowed() { - let ch = IrcChannel::new( - "irc.test".into(), - 6697, - "bot".into(), - None, - vec![], - vec!["alice".into(), "bob".into()], - None, - None, - None, - true, - ); + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "bot".into(), + username: None, + channels: vec![], + allowed_users: vec!["alice".into(), "bob".into()], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); assert!(ch.is_user_allowed("alice")); assert!(ch.is_user_allowed("bob")); assert!(!ch.is_user_allowed("eve")); @@ -826,18 +829,18 @@ mod tests { #[test] fn allowlist_case_insensitive() { - let ch = IrcChannel::new( - "irc.test".into(), - 6697, - "bot".into(), - None, - vec![], - vec!["Alice".into()], - None, - None, - None, - true, - ); + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "bot".into(), + username: None, + channels: vec![], + allowed_users: vec!["Alice".into()], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); assert!(ch.is_user_allowed("alice")); assert!(ch.is_user_allowed("ALICE")); assert!(ch.is_user_allowed("Alice")); @@ -845,18 +848,18 @@ mod tests { #[test] fn empty_allowlist_denies_all() { - let ch = IrcChannel::new( - "irc.test".into(), - 6697, - "bot".into(), - None, - vec![], - vec![], - None, - None, - None, - true, - ); + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "bot".into(), + username: None, + channels: vec![], + allowed_users: vec![], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); assert!(!ch.is_user_allowed("anyone")); } @@ -864,35 +867,35 @@ mod tests { #[test] fn new_defaults_username_to_nickname() { - let ch = IrcChannel::new( - "irc.test".into(), - 6697, - "mybot".into(), - None, - vec![], - vec![], - None, - None, - None, - true, - ); + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "mybot".into(), + username: None, + channels: vec![], + allowed_users: vec![], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); assert_eq!(ch.username, "mybot"); } #[test] fn new_uses_explicit_username() { - let ch = IrcChannel::new( - "irc.test".into(), - 6697, - "mybot".into(), - Some("customuser".into()), - vec![], - vec![], - None, - None, - None, - true, - ); + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.test".into(), + port: 6697, + nickname: "mybot".into(), + username: Some("customuser".into()), + channels: vec![], + allowed_users: vec![], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }); assert_eq!(ch.username, "customuser"); assert_eq!(ch.nickname, "mybot"); } @@ -905,18 +908,18 @@ mod tests { #[test] fn new_stores_all_fields() { - let ch = IrcChannel::new( - "irc.example.com".into(), - 6697, - "zcbot".into(), - Some("zeroclaw".into()), - vec!["#test".into()], - vec!["alice".into()], - Some("serverpass".into()), - Some("nspass".into()), - Some("saslpass".into()), - false, - ); + let ch = IrcChannel::new(IrcChannelConfig { + server: "irc.example.com".into(), + port: 6697, + nickname: "zcbot".into(), + username: Some("zeroclaw".into()), + channels: vec!["#test".into()], + allowed_users: vec!["alice".into()], + server_password: Some("serverpass".into()), + nickserv_password: Some("nspass".into()), + sasl_password: Some("saslpass".into()), + verify_tls: false, + }); assert_eq!(ch.server, "irc.example.com"); assert_eq!(ch.port, 6697); assert_eq!(ch.nickname, "zcbot"); @@ -995,17 +998,17 @@ nickname = "bot" // ── Helpers ───────────────────────────────────────────── fn make_channel() -> IrcChannel { - IrcChannel::new( - "irc.example.com".into(), - 6697, - "zcbot".into(), - None, - vec!["#zeroclaw".into()], - vec!["*".into()], - None, - None, - None, - true, - ) + IrcChannel::new(IrcChannelConfig { + server: "irc.example.com".into(), + port: 6697, + nickname: "zcbot".into(), + username: None, + channels: vec!["#zeroclaw".into()], + allowed_users: vec!["*".into()], + server_password: None, + nickserv_password: None, + sasl_password: None, + verify_tls: true, + }) } } diff --git a/src/channels/lark.rs b/src/channels/lark.rs index 4e9e679..5f929f8 100644 --- a/src/channels/lark.rs +++ b/src/channels/lark.rs @@ -1,21 +1,152 @@ use super::traits::{Channel, ChannelMessage}; use async_trait::async_trait; +use futures_util::{SinkExt, StreamExt}; +use prost::Message as ProstMessage; +use std::collections::HashMap; use std::sync::Arc; +use std::time::{Duration, Instant}; use tokio::sync::RwLock; +use tokio_tungstenite::tungstenite::Message as WsMsg; use uuid::Uuid; const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis"; +const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn"; +const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis"; +const LARK_WS_BASE_URL: &str = "https://open.larksuite.com"; -/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API +// ───────────────────────────────────────────────────────────────────────────── +// Feishu WebSocket long-connection: pbbp2.proto frame codec +// ───────────────────────────────────────────────────────────────────────────── + +#[derive(Clone, PartialEq, prost::Message)] +struct PbHeader { + #[prost(string, tag = "1")] + pub key: String, + #[prost(string, tag = "2")] + pub value: String, +} + +/// Feishu WS frame (pbbp2.proto). +/// method=0 → CONTROL (ping/pong) method=1 → DATA (events) +#[derive(Clone, PartialEq, prost::Message)] +struct PbFrame { + #[prost(uint64, tag = "1")] + pub seq_id: u64, + #[prost(uint64, tag = "2")] + pub log_id: u64, + #[prost(int32, tag = "3")] + pub service: i32, + #[prost(int32, tag = "4")] + pub method: i32, + #[prost(message, repeated, tag = "5")] + pub headers: Vec, + #[prost(bytes = "vec", optional, tag = "8")] + pub payload: Option>, +} + +impl PbFrame { + fn header_value<'a>(&'a self, key: &str) -> &'a str { + self.headers + .iter() + .find(|h| h.key == key) + .map(|h| h.value.as_str()) + .unwrap_or("") + } +} + +/// Server-sent client config (parsed from pong payload) +#[derive(Debug, serde::Deserialize, Default, Clone)] +struct WsClientConfig { + #[serde(rename = "PingInterval")] + ping_interval: Option, +} + +/// POST /callback/ws/endpoint response +#[derive(Debug, serde::Deserialize)] +struct WsEndpointResp { + code: i32, + #[serde(default)] + msg: Option, + #[serde(default)] + data: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct WsEndpoint { + #[serde(rename = "URL")] + url: String, + #[serde(rename = "ClientConfig")] + client_config: Option, +} + +/// LarkEvent envelope (method=1 / type=event payload) +#[derive(Debug, serde::Deserialize)] +struct LarkEvent { + header: LarkEventHeader, + event: serde_json::Value, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkEventHeader { + event_type: String, + #[allow(dead_code)] + event_id: String, +} + +#[derive(Debug, serde::Deserialize)] +struct MsgReceivePayload { + sender: LarkSender, + message: LarkMessage, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkSender { + sender_id: LarkSenderId, + #[serde(default)] + sender_type: String, +} + +#[derive(Debug, serde::Deserialize, Default)] +struct LarkSenderId { + open_id: Option, +} + +#[derive(Debug, serde::Deserialize)] +struct LarkMessage { + message_id: String, + chat_id: String, + chat_type: String, + message_type: String, + #[serde(default)] + content: String, + #[serde(default)] + mentions: Vec, +} + +/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s). +/// If no binary frame (pong or event) is received within this window, reconnect. +const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300); + +/// Lark/Feishu channel. +/// +/// Supports two receive modes (configured via `receive_mode` in config): +/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed. +/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint. pub struct LarkChannel { app_id: String, app_secret: String, verification_token: String, - port: u16, + port: Option, allowed_users: Vec, + /// When true, use Feishu (CN) endpoints; when false, use Lark (international). + use_feishu: bool, + /// How to receive events: WebSocket long-connection or HTTP webhook. + receive_mode: crate::config::schema::LarkReceiveMode, client: reqwest::Client, /// Cached tenant access token tenant_token: Arc>>, + /// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch + ws_seen_ids: Arc>>, } impl LarkChannel { @@ -23,7 +154,7 @@ impl LarkChannel { app_id: String, app_secret: String, verification_token: String, - port: u16, + port: Option, allowed_users: Vec, ) -> Self { Self { @@ -32,11 +163,310 @@ impl LarkChannel { verification_token, port, allowed_users, + use_feishu: true, + receive_mode: crate::config::schema::LarkReceiveMode::default(), client: reqwest::Client::new(), tenant_token: Arc::new(RwLock::new(None)), + ws_seen_ids: Arc::new(RwLock::new(HashMap::new())), } } + /// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`). + pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self { + let mut ch = Self::new( + config.app_id.clone(), + config.app_secret.clone(), + config.verification_token.clone().unwrap_or_default(), + config.port, + config.allowed_users.clone(), + ); + ch.use_feishu = config.use_feishu; + ch.receive_mode = config.receive_mode.clone(); + ch + } + + fn api_base(&self) -> &'static str { + if self.use_feishu { + FEISHU_BASE_URL + } else { + LARK_BASE_URL + } + } + + fn ws_base(&self) -> &'static str { + if self.use_feishu { + FEISHU_WS_BASE_URL + } else { + LARK_WS_BASE_URL + } + } + + fn tenant_access_token_url(&self) -> String { + format!("{}/auth/v3/tenant_access_token/internal", self.api_base()) + } + + fn send_message_url(&self) -> String { + format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base()) + } + + /// POST /callback/ws/endpoint → (wss_url, client_config) + async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> { + let resp = self + .client + .post(format!("{}/callback/ws/endpoint", self.ws_base())) + .header("locale", if self.use_feishu { "zh" } else { "en" }) + .json(&serde_json::json!({ + "AppID": self.app_id, + "AppSecret": self.app_secret, + })) + .send() + .await? + .json::() + .await?; + if resp.code != 0 { + anyhow::bail!( + "Lark WS endpoint failed: code={} msg={}", + resp.code, + resp.msg.as_deref().unwrap_or("(none)") + ); + } + let ep = resp + .data + .ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?; + Ok((ep.url, ep.client_config.unwrap_or_default())) + } + + /// WS long-connection event loop. Returns Ok(()) when the connection closes + /// (the caller reconnects). + #[allow(clippy::too_many_lines)] + async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + let (wss_url, client_config) = self.get_ws_endpoint().await?; + let service_id = wss_url + .split('?') + .nth(1) + .and_then(|qs| { + qs.split('&') + .find(|kv| kv.starts_with("service_id=")) + .and_then(|kv| kv.split('=').nth(1)) + .and_then(|v| v.parse::().ok()) + }) + .unwrap_or(0); + tracing::info!("Lark: connecting to {wss_url}"); + + let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?; + let (mut write, mut read) = ws_stream.split(); + tracing::info!("Lark: WS connected (service_id={service_id})"); + + let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10); + let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs)); + let mut timeout_check = tokio::time::interval(Duration::from_secs(10)); + hb_interval.tick().await; // consume immediate tick + + let mut seq: u64 = 0; + let mut last_recv = Instant::now(); + + // Send initial ping immediately (like the official SDK) so the server + // starts responding with pongs and we can calibrate the ping_interval. + seq = seq.wrapping_add(1); + let initial_ping = PbFrame { + seq_id: seq, + log_id: 0, + service: service_id, + method: 0, + headers: vec![PbHeader { + key: "type".into(), + value: "ping".into(), + }], + payload: None, + }; + if write + .send(WsMsg::Binary(initial_ping.encode_to_vec())) + .await + .is_err() + { + anyhow::bail!("Lark: initial ping failed"); + } + // message_id → (fragment_slots, created_at) for multi-part reassembly + type FragEntry = (Vec>>, Instant); + let mut frag_cache: HashMap = HashMap::new(); + + loop { + tokio::select! { + biased; + + _ = hb_interval.tick() => { + seq = seq.wrapping_add(1); + let ping = PbFrame { + seq_id: seq, log_id: 0, service: service_id, method: 0, + headers: vec![PbHeader { key: "type".into(), value: "ping".into() }], + payload: None, + }; + if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() { + tracing::warn!("Lark: ping failed, reconnecting"); + break; + } + // GC stale fragments > 5 min + let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now()); + frag_cache.retain(|_, (_, ts)| *ts > cutoff); + } + + _ = timeout_check.tick() => { + if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT { + tracing::warn!("Lark: heartbeat timeout, reconnecting"); + break; + } + } + + msg = read.next() => { + let raw = match msg { + Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b } + Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; } + Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; } + Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; } + _ => continue, + }; + + let frame = match PbFrame::decode(&raw[..]) { + Ok(f) => f, + Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; } + }; + + // CONTROL frame + if frame.method == 0 { + if frame.header_value("type") == "pong" { + if let Some(p) = &frame.payload { + if let Ok(cfg) = serde_json::from_slice::(p) { + if let Some(secs) = cfg.ping_interval { + let secs = secs.max(10); + if secs != ping_secs { + ping_secs = secs; + hb_interval = tokio::time::interval(Duration::from_secs(ping_secs)); + tracing::info!("Lark: ping_interval → {ping_secs}s"); + } + } + } + } + } + continue; + } + + // DATA frame + let msg_type = frame.header_value("type").to_string(); + let msg_id = frame.header_value("message_id").to_string(); + let sum = frame.header_value("sum").parse::().unwrap_or(1); + let seq_num = frame.header_value("seq").parse::().unwrap_or(0); + + // ACK immediately (Feishu requires within 3 s) + { + let mut ack = frame.clone(); + ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec()); + ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() }); + let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await; + } + + // Fragment reassembly + let sum = if sum == 0 { 1 } else { sum }; + let payload: Vec = if sum == 1 || msg_id.is_empty() || seq_num >= sum { + frame.payload.clone().unwrap_or_default() + } else { + let entry = frag_cache.entry(msg_id.clone()) + .or_insert_with(|| (vec![None; sum], Instant::now())); + if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); } + entry.0[seq_num] = frame.payload.clone(); + if entry.0.iter().all(|s| s.is_some()) { + let full: Vec = entry.0.iter() + .flat_map(|s| s.as_deref().unwrap_or(&[])) + .copied().collect(); + frag_cache.remove(&msg_id); + full + } else { continue; } + }; + + if msg_type != "event" { continue; } + + let event: LarkEvent = match serde_json::from_slice(&payload) { + Ok(e) => e, + Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; } + }; + if event.header.event_type != "im.message.receive_v1" { continue; } + + let recv: MsgReceivePayload = match serde_json::from_value(event.event) { + Ok(r) => r, + Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; } + }; + + if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; } + + let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or(""); + if !self.is_user_allowed(sender_open_id) { + tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)"); + continue; + } + + let lark_msg = &recv.message; + + // Dedup + { + let now = Instant::now(); + let mut seen = self.ws_seen_ids.write().await; + // GC + seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60)); + if seen.contains_key(&lark_msg.message_id) { + tracing::debug!("Lark WS: dup {}", lark_msg.message_id); + continue; + } + seen.insert(lark_msg.message_id.clone(), now); + } + + // Decode content by type (mirrors clawdbot-feishu parsing) + let text = match lark_msg.message_type.as_str() { + "text" => { + let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) { + Ok(v) => v, + Err(_) => continue, + }; + match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) { + Some(t) => t.to_string(), + None => continue, + } + } + "post" => match parse_post_content(&lark_msg.content) { + Some(t) => t, + None => continue, + }, + _ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; } + }; + + // Strip @_user_N placeholders + let text = strip_at_placeholders(&text); + let text = text.trim().to_string(); + if text.is_empty() { continue; } + + // Group-chat: only respond when explicitly @-mentioned + if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) { + continue; + } + + let channel_msg = ChannelMessage { + id: Uuid::new_v4().to_string(), + sender: lark_msg.chat_id.clone(), + reply_target: lark_msg.chat_id.clone(), + content: text, + channel: "lark".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }; + + tracing::debug!("Lark WS: message in {}", lark_msg.chat_id); + if tx.send(channel_msg).await.is_err() { break; } + } + } + } + Ok(()) + } + /// Check if a user open_id is allowed fn is_user_allowed(&self, open_id: &str) -> bool { self.allowed_users.iter().any(|u| u == "*" || u == open_id) @@ -52,7 +482,7 @@ impl LarkChannel { } } - let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal"); + let url = self.tenant_access_token_url(); let body = serde_json::json!({ "app_id": self.app_id, "app_secret": self.app_secret, @@ -127,31 +557,41 @@ impl LarkChannel { return messages; } - // Extract message content (text only) + // Extract message content (text and post supported) let msg_type = event .pointer("/message/message_type") .and_then(|t| t.as_str()) .unwrap_or(""); - if msg_type != "text" { - tracing::debug!("Lark: skipping non-text message type: {msg_type}"); - return messages; - } - let content_str = event .pointer("/message/content") .and_then(|c| c.as_str()) .unwrap_or(""); - // content is a JSON string like "{\"text\":\"hello\"}" - let text = serde_json::from_str::(content_str) - .ok() - .and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from)) - .unwrap_or_default(); - - if text.is_empty() { - return messages; - } + let text: String = match msg_type { + "text" => { + let extracted = serde_json::from_str::(content_str) + .ok() + .and_then(|v| { + v.get("text") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()) + .map(String::from) + }); + match extracted { + Some(t) => t, + None => return messages, + } + } + "post" => match parse_post_content(content_str) { + Some(t) => t, + None => return messages, + }, + _ => { + tracing::debug!("Lark: skipping unsupported message type: {msg_type}"); + return messages; + } + }; let timestamp = event .pointer("/message/create_time") @@ -174,6 +614,7 @@ impl LarkChannel { messages.push(ChannelMessage { id: Uuid::new_v4().to_string(), sender: chat_id.to_string(), + reply_target: chat_id.to_string(), content: text, channel: "lark".to_string(), timestamp, @@ -191,7 +632,7 @@ impl Channel for LarkChannel { async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> { let token = self.get_tenant_access_token().await?; - let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id"); + let url = self.send_message_url(); let content = serde_json::json!({ "text": message }).to_string(); let body = serde_json::json!({ @@ -238,6 +679,25 @@ impl Channel for LarkChannel { } async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { + use crate::config::schema::LarkReceiveMode; + match self.receive_mode { + LarkReceiveMode::Websocket => self.listen_ws(tx).await, + LarkReceiveMode::Webhook => self.listen_http(tx).await, + } + } + + async fn health_check(&self) -> bool { + self.get_tenant_access_token().await.is_ok() + } +} + +impl LarkChannel { + /// HTTP callback server (legacy — requires a public endpoint). + /// Use `listen()` (WS long-connection) for new deployments. + pub async fn listen_http( + &self, + tx: tokio::sync::mpsc::Sender, + ) -> anyhow::Result<()> { use axum::{extract::State, routing::post, Json, Router}; #[derive(Clone)] @@ -282,13 +742,17 @@ impl Channel for LarkChannel { (StatusCode::OK, "ok").into_response() } + let port = self.port.ok_or_else(|| { + anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]") + })?; + let state = AppState { verification_token: self.verification_token.clone(), channel: Arc::new(LarkChannel::new( self.app_id.clone(), self.app_secret.clone(), self.verification_token.clone(), - self.port, + None, self.allowed_users.clone(), )), tx, @@ -298,7 +762,7 @@ impl Channel for LarkChannel { .route("/lark", post(handle_event)) .with_state(state); - let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port)); + let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port)); tracing::info!("Lark event callback server listening on {addr}"); let listener = tokio::net::TcpListener::bind(addr).await?; @@ -306,10 +770,110 @@ impl Channel for LarkChannel { Ok(()) } +} - async fn health_check(&self) -> bool { - self.get_tenant_access_token().await.is_ok() +// ───────────────────────────────────────────────────────────────────────────── +// WS helper functions +// ───────────────────────────────────────────────────────────────────────────── + +/// Flatten a Feishu `post` rich-text message to plain text. +/// +/// Returns `None` when the content cannot be parsed or yields no usable text, +/// so callers can simply `continue` rather than forwarding a meaningless +/// placeholder string to the agent. +fn parse_post_content(content: &str) -> Option { + let parsed = serde_json::from_str::(content).ok()?; + let locale = parsed + .get("zh_cn") + .or_else(|| parsed.get("en_us")) + .or_else(|| { + parsed + .as_object() + .and_then(|m| m.values().find(|v| v.is_object())) + })?; + + let mut text = String::new(); + + if let Some(title) = locale + .get("title") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()) + { + text.push_str(title); + text.push_str("\n\n"); } + + if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) { + for para in paragraphs { + if let Some(elements) = para.as_array() { + for el in elements { + match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") { + "text" => { + if let Some(t) = el.get("text").and_then(|t| t.as_str()) { + text.push_str(t); + } + } + "a" => { + text.push_str( + el.get("text") + .and_then(|t| t.as_str()) + .filter(|s| !s.is_empty()) + .or_else(|| el.get("href").and_then(|h| h.as_str())) + .unwrap_or(""), + ); + } + "at" => { + let n = el + .get("user_name") + .and_then(|n| n.as_str()) + .or_else(|| el.get("user_id").and_then(|i| i.as_str())) + .unwrap_or("user"); + text.push('@'); + text.push_str(n); + } + _ => {} + } + } + text.push('\n'); + } + } + } + + let result = text.trim().to_string(); + if result.is_empty() { + None + } else { + Some(result) + } +} + +/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats. +fn strip_at_placeholders(text: &str) -> String { + let mut result = String::with_capacity(text.len()); + let mut chars = text.char_indices().peekable(); + while let Some((_, ch)) = chars.next() { + if ch == '@' { + let rest: String = chars.clone().map(|(_, c)| c).collect(); + if let Some(after) = rest.strip_prefix("_user_") { + let skip = + "_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count(); + for _ in 0..=skip { + chars.next(); + } + if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) { + chars.next(); + } + continue; + } + } + result.push(ch); + } + result +} + +/// In group chats, only respond when the bot is explicitly @-mentioned. +fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool { + !mentions.is_empty() } #[cfg(test)] @@ -321,7 +885,7 @@ mod tests { "cli_test_app_id".into(), "test_app_secret".into(), "test_verification_token".into(), - 9898, + None, vec!["ou_testuser123".into()], ) } @@ -345,7 +909,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); assert!(ch.is_user_allowed("ou_anyone")); @@ -353,7 +917,7 @@ mod tests { #[test] fn lark_user_denied_empty() { - let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]); + let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]); assert!(!ch.is_user_allowed("ou_anyone")); } @@ -426,7 +990,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -451,7 +1015,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -488,7 +1052,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -512,7 +1076,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -550,7 +1114,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ @@ -571,7 +1135,7 @@ mod tests { #[test] fn lark_config_serde() { - use crate::config::schema::LarkConfig; + use crate::config::schema::{LarkConfig, LarkReceiveMode}; let lc = LarkConfig { app_id: "cli_app123".into(), app_secret: "secret456".into(), @@ -579,6 +1143,8 @@ mod tests { verification_token: Some("vtoken789".into()), allowed_users: vec!["ou_user1".into(), "ou_user2".into()], use_feishu: false, + receive_mode: LarkReceiveMode::default(), + port: None, }; let json = serde_json::to_string(&lc).unwrap(); let parsed: LarkConfig = serde_json::from_str(&json).unwrap(); @@ -590,7 +1156,7 @@ mod tests { #[test] fn lark_config_toml_roundtrip() { - use crate::config::schema::LarkConfig; + use crate::config::schema::{LarkConfig, LarkReceiveMode}; let lc = LarkConfig { app_id: "app".into(), app_secret: "secret".into(), @@ -598,6 +1164,8 @@ mod tests { verification_token: Some("tok".into()), allowed_users: vec!["*".into()], use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), }; let toml_str = toml::to_string(&lc).unwrap(); let parsed: LarkConfig = toml::from_str(&toml_str).unwrap(); @@ -608,11 +1176,36 @@ mod tests { #[test] fn lark_config_defaults_optional_fields() { - use crate::config::schema::LarkConfig; + use crate::config::schema::{LarkConfig, LarkReceiveMode}; let json = r#"{"app_id":"a","app_secret":"s"}"#; let parsed: LarkConfig = serde_json::from_str(json).unwrap(); assert!(parsed.verification_token.is_none()); assert!(parsed.allowed_users.is_empty()); + assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket); + assert!(parsed.port.is_none()); + } + + #[test] + fn lark_from_config_preserves_mode_and_region() { + use crate::config::schema::{LarkConfig, LarkReceiveMode}; + + let cfg = LarkConfig { + app_id: "cli_app123".into(), + app_secret: "secret456".into(), + encrypt_key: None, + verification_token: Some("vtoken789".into()), + allowed_users: vec!["*".into()], + use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), + }; + + let ch = LarkChannel::from_config(&cfg); + + assert_eq!(ch.api_base(), LARK_BASE_URL); + assert_eq!(ch.ws_base(), LARK_WS_BASE_URL); + assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook); + assert_eq!(ch.port, Some(9898)); } #[test] @@ -622,7 +1215,7 @@ mod tests { "id".into(), "secret".into(), "token".into(), - 9898, + None, vec!["*".into()], ); let payload = serde_json::json!({ diff --git a/src/channels/matrix.rs b/src/channels/matrix.rs index 9f8924c..4f34bcf 100644 --- a/src/channels/matrix.rs +++ b/src/channels/matrix.rs @@ -230,6 +230,7 @@ impl Channel for MatrixChannel { let msg = ChannelMessage { id: format!("mx_{}", chrono::Utc::now().timestamp_millis()), sender: event.sender.clone(), + reply_target: event.sender.clone(), content: body.clone(), channel: "matrix".to_string(), timestamp: std::time::SystemTime::now() diff --git a/src/channels/mod.rs b/src/channels/mod.rs index d8fd612..fc9a7d2 100644 --- a/src/channels/mod.rs +++ b/src/channels/mod.rs @@ -69,10 +69,19 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String { format!("{}_{}_{}", msg.channel, msg.sender, msg.id) } +fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> { + match channel_name { + "telegram" => Some( + "When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:], [DOCUMENT:], [VIDEO:], [AUDIO:], or [VOICE:]. Keep normal user-facing text outside markers and never wrap markers in code fences.", + ), + _ => None, + } +} + async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String { let mut context = String::new(); - if let Ok(entries) = mem.recall(user_msg, 5).await { + if let Ok(entries) = mem.recall(user_msg, 5, None).await { if !entries.is_empty() { context.push_str("[Memory context]\n"); for entry in &entries { @@ -158,6 +167,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C &autosave_key, &msg.content, crate::memory::MemoryCategory::Conversation, + None, ) .await; } @@ -171,7 +181,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C let target_channel = ctx.channels_by_name.get(&msg.channel).cloned(); if let Some(channel) = target_channel.as_ref() { - if let Err(e) = channel.start_typing(&msg.sender).await { + if let Err(e) = channel.start_typing(&msg.reply_target).await { tracing::debug!("Failed to start typing on {}: {e}", channel.name()); } } @@ -184,6 +194,10 @@ async fn process_channel_message(ctx: Arc, msg: traits::C ChatMessage::user(&enriched_message), ]; + if let Some(instructions) = channel_delivery_instructions(&msg.channel) { + history.push(ChatMessage::system(instructions)); + } + let llm_result = tokio::time::timeout( Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), run_tool_call_loop( @@ -200,7 +214,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C .await; if let Some(channel) = target_channel.as_ref() { - if let Err(e) = channel.stop_typing(&msg.sender).await { + if let Err(e) = channel.stop_typing(&msg.reply_target).await { tracing::debug!("Failed to stop typing on {}: {e}", channel.name()); } } @@ -224,7 +238,9 @@ async fn process_channel_message(ctx: Arc, msg: traits::C started_at.elapsed().as_millis() ); if let Some(channel) = target_channel.as_ref() { - let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await; + let _ = channel + .send(&format!("⚠️ Error: {e}"), &msg.reply_target) + .await; } } Err(_) => { @@ -241,7 +257,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C let _ = channel .send( "⚠️ Request timed out while waiting for the model. Please try again.", - &msg.sender, + &msg.reply_target, ) .await; } @@ -483,6 +499,16 @@ pub fn build_system_prompt( std::env::consts::OS, ); + // ── 8. Channel Capabilities ───────────────────────────────────── + prompt.push_str("## Channel Capabilities\n\n"); + prompt.push_str( + "- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n", + ); + prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n"); + prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n"); + prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n"); + prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n"); + if prompt.is_empty() { "You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string() } else { @@ -619,6 +645,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> { dc.guild_id.clone(), dc.allowed_users.clone(), dc.listen_to_bots, + dc.mention_only, )), )); } @@ -672,32 +699,23 @@ pub async fn doctor_channels(config: Config) -> Result<()> { if let Some(ref irc) = config.channels_config.irc { channels.push(( "IRC", - Arc::new(IrcChannel::new( - irc.server.clone(), - irc.port, - irc.nickname.clone(), - irc.username.clone(), - irc.channels.clone(), - irc.allowed_users.clone(), - irc.server_password.clone(), - irc.nickserv_password.clone(), - irc.sasl_password.clone(), - irc.verify_tls.unwrap_or(true), - )), + Arc::new(IrcChannel::new(irc::IrcChannelConfig { + server: irc.server.clone(), + port: irc.port, + nickname: irc.nickname.clone(), + username: irc.username.clone(), + channels: irc.channels.clone(), + allowed_users: irc.allowed_users.clone(), + server_password: irc.server_password.clone(), + nickserv_password: irc.nickserv_password.clone(), + sasl_password: irc.sasl_password.clone(), + verify_tls: irc.verify_tls.unwrap_or(true), + })), )); } if let Some(ref lk) = config.channels_config.lark { - channels.push(( - "Lark", - Arc::new(LarkChannel::new( - lk.app_id.clone(), - lk.app_secret.clone(), - lk.verification_token.clone().unwrap_or_default(), - 9898, - lk.allowed_users.clone(), - )), - )); + channels.push(("Lark", Arc::new(LarkChannel::from_config(lk)))); } if let Some(ref dt) = config.channels_config.dingtalk { @@ -762,6 +780,7 @@ pub async fn start_channels(config: Config) -> Result<()> { let provider: Arc = Arc::from(providers::create_resilient_provider( &provider_name, config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, )?); @@ -860,6 +879,10 @@ pub async fn start_channels(config: Config) -> Result<()> { "schedule", "Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.", )); + tool_descs.push(( + "pushover", + "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.", + )); if !config.agents.is_empty() { tool_descs.push(( "delegate", @@ -909,6 +932,7 @@ pub async fn start_channels(config: Config) -> Result<()> { dc.guild_id.clone(), dc.allowed_users.clone(), dc.listen_to_bots, + dc.mention_only, ))); } @@ -947,28 +971,22 @@ pub async fn start_channels(config: Config) -> Result<()> { } if let Some(ref irc) = config.channels_config.irc { - channels.push(Arc::new(IrcChannel::new( - irc.server.clone(), - irc.port, - irc.nickname.clone(), - irc.username.clone(), - irc.channels.clone(), - irc.allowed_users.clone(), - irc.server_password.clone(), - irc.nickserv_password.clone(), - irc.sasl_password.clone(), - irc.verify_tls.unwrap_or(true), - ))); + channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig { + server: irc.server.clone(), + port: irc.port, + nickname: irc.nickname.clone(), + username: irc.username.clone(), + channels: irc.channels.clone(), + allowed_users: irc.allowed_users.clone(), + server_password: irc.server_password.clone(), + nickserv_password: irc.nickserv_password.clone(), + sasl_password: irc.sasl_password.clone(), + verify_tls: irc.verify_tls.unwrap_or(true), + }))); } if let Some(ref lk) = config.channels_config.lark { - channels.push(Arc::new(LarkChannel::new( - lk.app_id.clone(), - lk.app_secret.clone(), - lk.verification_token.clone().unwrap_or_default(), - 9898, - lk.allowed_users.clone(), - ))); + channels.push(Arc::new(LarkChannel::from_config(lk))); } if let Some(ref dt) = config.channels_config.dingtalk { @@ -1242,6 +1260,7 @@ mod tests { traits::ChannelMessage { id: "msg-1".to_string(), sender: "alice".to_string(), + reply_target: "chat-42".to_string(), content: "What is the BTC price now?".to_string(), channel: "test-channel".to_string(), timestamp: 1, @@ -1251,6 +1270,7 @@ mod tests { let sent_messages = channel_impl.sent_messages.lock().await; assert_eq!(sent_messages.len(), 1); + assert!(sent_messages[0].starts_with("chat-42:")); assert!(sent_messages[0].contains("BTC is currently around")); assert!(!sent_messages[0].contains("\"tool_calls\"")); assert!(!sent_messages[0].contains("mock_price")); @@ -1269,6 +1289,7 @@ mod tests { _key: &str, _content: &str, _category: crate::memory::MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } @@ -1277,6 +1298,7 @@ mod tests { &self, _query: &str, _limit: usize, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } @@ -1288,6 +1310,7 @@ mod tests { async fn list( &self, _category: Option<&crate::memory::MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } @@ -1331,6 +1354,7 @@ mod tests { tx.send(traits::ChannelMessage { id: "1".to_string(), sender: "alice".to_string(), + reply_target: "alice".to_string(), content: "hello".to_string(), channel: "test-channel".to_string(), timestamp: 1, @@ -1340,6 +1364,7 @@ mod tests { tx.send(traits::ChannelMessage { id: "2".to_string(), sender: "bob".to_string(), + reply_target: "bob".to_string(), content: "world".to_string(), channel: "test-channel".to_string(), timestamp: 2, @@ -1570,6 +1595,25 @@ mod tests { assert!(truncated.is_char_boundary(truncated.len())); } + #[test] + fn prompt_contains_channel_capabilities() { + let ws = make_workspace(); + let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None); + + assert!( + prompt.contains("## Channel Capabilities"), + "missing Channel Capabilities section" + ); + assert!( + prompt.contains("running as a Discord bot"), + "missing Discord context" + ); + assert!( + prompt.contains("NEVER repeat, describe, or echo credentials"), + "missing security instruction" + ); + } + #[test] fn prompt_workspace_path() { let ws = make_workspace(); @@ -1583,6 +1627,7 @@ mod tests { let msg = traits::ChannelMessage { id: "msg_abc123".into(), sender: "U123".into(), + reply_target: "C456".into(), content: "hello".into(), channel: "slack".into(), timestamp: 1, @@ -1596,6 +1641,7 @@ mod tests { let msg1 = traits::ChannelMessage { id: "msg_1".into(), sender: "U123".into(), + reply_target: "C456".into(), content: "first".into(), channel: "slack".into(), timestamp: 1, @@ -1603,6 +1649,7 @@ mod tests { let msg2 = traits::ChannelMessage { id: "msg_2".into(), sender: "U123".into(), + reply_target: "C456".into(), content: "second".into(), channel: "slack".into(), timestamp: 2, @@ -1622,6 +1669,7 @@ mod tests { let msg1 = traits::ChannelMessage { id: "msg_1".into(), sender: "U123".into(), + reply_target: "C456".into(), content: "I'm Paul".into(), channel: "slack".into(), timestamp: 1, @@ -1629,6 +1677,7 @@ mod tests { let msg2 = traits::ChannelMessage { id: "msg_2".into(), sender: "U123".into(), + reply_target: "C456".into(), content: "I'm 45".into(), channel: "slack".into(), timestamp: 2, @@ -1638,6 +1687,7 @@ mod tests { &conversation_memory_key(&msg1), &msg1.content, MemoryCategory::Conversation, + None, ) .await .unwrap(); @@ -1645,13 +1695,14 @@ mod tests { &conversation_memory_key(&msg2), &msg2.content, MemoryCategory::Conversation, + None, ) .await .unwrap(); assert_eq!(mem.count().await.unwrap(), 2); - let recalled = mem.recall("45", 5).await.unwrap(); + let recalled = mem.recall("45", 5, None).await.unwrap(); assert!(recalled.iter().any(|entry| entry.content.contains("45"))); } @@ -1659,7 +1710,7 @@ mod tests { async fn build_memory_context_includes_recalled_entries() { let tmp = TempDir::new().unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("age_fact", "Age is 45", MemoryCategory::Conversation) + mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None) .await .unwrap(); diff --git a/src/channels/slack.rs b/src/channels/slack.rs index fd6b2f0..7f8ee51 100644 --- a/src/channels/slack.rs +++ b/src/channels/slack.rs @@ -161,6 +161,7 @@ impl Channel for SlackChannel { let channel_msg = ChannelMessage { id: format!("slack_{channel_id}_{ts}"), sender: user.to_string(), + reply_target: channel_id.clone(), content: text.to_string(), channel: "slack".to_string(), timestamp: std::time::SystemTime::now() diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs index bfe8dd6..5d25de1 100644 --- a/src/channels/telegram.rs +++ b/src/channels/telegram.rs @@ -51,6 +51,133 @@ fn split_message_for_telegram(message: &str) -> Vec { chunks } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum TelegramAttachmentKind { + Image, + Document, + Video, + Audio, + Voice, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +struct TelegramAttachment { + kind: TelegramAttachmentKind, + target: String, +} + +impl TelegramAttachmentKind { + fn from_marker(marker: &str) -> Option { + match marker.trim().to_ascii_uppercase().as_str() { + "IMAGE" | "PHOTO" => Some(Self::Image), + "DOCUMENT" | "FILE" => Some(Self::Document), + "VIDEO" => Some(Self::Video), + "AUDIO" => Some(Self::Audio), + "VOICE" => Some(Self::Voice), + _ => None, + } + } +} + +fn is_http_url(target: &str) -> bool { + target.starts_with("http://") || target.starts_with("https://") +} + +fn infer_attachment_kind_from_target(target: &str) -> Option { + let normalized = target + .split('?') + .next() + .unwrap_or(target) + .split('#') + .next() + .unwrap_or(target); + + let extension = Path::new(normalized) + .extension() + .and_then(|ext| ext.to_str())? + .to_ascii_lowercase(); + + match extension.as_str() { + "png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => Some(TelegramAttachmentKind::Image), + "mp4" | "mov" | "mkv" | "avi" | "webm" => Some(TelegramAttachmentKind::Video), + "mp3" | "m4a" | "wav" | "flac" => Some(TelegramAttachmentKind::Audio), + "ogg" | "oga" | "opus" => Some(TelegramAttachmentKind::Voice), + "pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx" | "xls" + | "xlsx" | "ppt" | "pptx" => Some(TelegramAttachmentKind::Document), + _ => None, + } +} + +fn parse_path_only_attachment(message: &str) -> Option { + let trimmed = message.trim(); + if trimmed.is_empty() || trimmed.contains('\n') { + return None; + } + + let candidate = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\'')); + if candidate.chars().any(char::is_whitespace) { + return None; + } + + let candidate = candidate.strip_prefix("file://").unwrap_or(candidate); + let kind = infer_attachment_kind_from_target(candidate)?; + + if !is_http_url(candidate) && !Path::new(candidate).exists() { + return None; + } + + Some(TelegramAttachment { + kind, + target: candidate.to_string(), + }) +} + +fn parse_attachment_markers(message: &str) -> (String, Vec) { + let mut cleaned = String::with_capacity(message.len()); + let mut attachments = Vec::new(); + let mut cursor = 0; + + while cursor < message.len() { + let Some(open_rel) = message[cursor..].find('[') else { + cleaned.push_str(&message[cursor..]); + break; + }; + + let open = cursor + open_rel; + cleaned.push_str(&message[cursor..open]); + + let Some(close_rel) = message[open..].find(']') else { + cleaned.push_str(&message[open..]); + break; + }; + + let close = open + close_rel; + let marker = &message[open + 1..close]; + + let parsed = marker.split_once(':').and_then(|(kind, target)| { + let kind = TelegramAttachmentKind::from_marker(kind)?; + let target = target.trim(); + if target.is_empty() { + return None; + } + Some(TelegramAttachment { + kind, + target: target.to_string(), + }) + }); + + if let Some(attachment) = parsed { + attachments.push(attachment); + } else { + cleaned.push_str(&message[open..=close]); + } + + cursor = close + 1; + } + + (cleaned.trim().to_string(), attachments) +} + /// Telegram channel — long-polls the Bot API for updates pub struct TelegramChannel { bot_token: String, @@ -82,6 +209,216 @@ impl TelegramChannel { identities.into_iter().any(|id| self.is_user_allowed(id)) } + fn parse_update_message(&self, update: &serde_json::Value) -> Option { + let message = update.get("message")?; + + let text = message.get("text").and_then(serde_json::Value::as_str)?; + + let username = message + .get("from") + .and_then(|from| from.get("username")) + .and_then(serde_json::Value::as_str) + .unwrap_or("unknown") + .to_string(); + + let user_id = message + .get("from") + .and_then(|from| from.get("id")) + .and_then(serde_json::Value::as_i64) + .map(|id| id.to_string()); + + let sender_identity = if username == "unknown" { + user_id.clone().unwrap_or_else(|| "unknown".to_string()) + } else { + username.clone() + }; + + let mut identities = vec![username.as_str()]; + if let Some(id) = user_id.as_deref() { + identities.push(id); + } + + if !self.is_any_user_allowed(identities.iter().copied()) { + tracing::warn!( + "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \ +Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.", + user_id.as_deref().unwrap_or("unknown") + ); + return None; + } + + let chat_id = message + .get("chat") + .and_then(|chat| chat.get("id")) + .and_then(serde_json::Value::as_i64) + .map(|id| id.to_string())?; + + let message_id = message + .get("message_id") + .and_then(serde_json::Value::as_i64) + .unwrap_or(0); + + Some(ChannelMessage { + id: format!("telegram_{chat_id}_{message_id}"), + sender: sender_identity, + reply_target: chat_id, + content: text.to_string(), + channel: "telegram".to_string(), + timestamp: std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs(), + }) + } + + async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> { + let chunks = split_message_for_telegram(message); + + for (index, chunk) in chunks.iter().enumerate() { + let text = if chunks.len() > 1 { + if index == 0 { + format!("{chunk}\n\n(continues...)") + } else if index == chunks.len() - 1 { + format!("(continued)\n\n{chunk}") + } else { + format!("(continued)\n\n{chunk}\n\n(continues...)") + } + } else { + chunk.to_string() + }; + + let markdown_body = serde_json::json!({ + "chat_id": chat_id, + "text": text, + "parse_mode": "Markdown" + }); + + let markdown_resp = self + .client + .post(self.api_url("sendMessage")) + .json(&markdown_body) + .send() + .await?; + + if markdown_resp.status().is_success() { + if index < chunks.len() - 1 { + tokio::time::sleep(Duration::from_millis(100)).await; + } + continue; + } + + let markdown_status = markdown_resp.status(); + let markdown_err = markdown_resp.text().await.unwrap_or_default(); + tracing::warn!( + status = ?markdown_status, + "Telegram sendMessage with Markdown failed; retrying without parse_mode" + ); + + let plain_body = serde_json::json!({ + "chat_id": chat_id, + "text": text, + }); + let plain_resp = self + .client + .post(self.api_url("sendMessage")) + .json(&plain_body) + .send() + .await?; + + if !plain_resp.status().is_success() { + let plain_status = plain_resp.status(); + let plain_err = plain_resp.text().await.unwrap_or_default(); + anyhow::bail!( + "Telegram sendMessage failed (markdown {}: {}; plain {}: {})", + markdown_status, + markdown_err, + plain_status, + plain_err + ); + } + + if index < chunks.len() - 1 { + tokio::time::sleep(Duration::from_millis(100)).await; + } + } + + Ok(()) + } + + async fn send_media_by_url( + &self, + method: &str, + media_field: &str, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + let mut body = serde_json::json!({ + "chat_id": chat_id, + }); + body[media_field] = serde_json::Value::String(url.to_string()); + + if let Some(cap) = caption { + body["caption"] = serde_json::Value::String(cap.to_string()); + } + + let resp = self + .client + .post(self.api_url(method)) + .json(&body) + .send() + .await?; + + if !resp.status().is_success() { + let err = resp.text().await?; + anyhow::bail!("Telegram {method} by URL failed: {err}"); + } + + tracing::info!("Telegram {method} sent to {chat_id}: {url}"); + Ok(()) + } + + async fn send_attachment( + &self, + chat_id: &str, + attachment: &TelegramAttachment, + ) -> anyhow::Result<()> { + let target = attachment.target.trim(); + + if is_http_url(target) { + return match attachment.kind { + TelegramAttachmentKind::Image => { + self.send_photo_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Document => { + self.send_document_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Video => { + self.send_video_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Audio => { + self.send_audio_by_url(chat_id, target, None).await + } + TelegramAttachmentKind::Voice => { + self.send_voice_by_url(chat_id, target, None).await + } + }; + } + + let path = Path::new(target); + if !path.exists() { + anyhow::bail!("Telegram attachment path not found: {target}"); + } + + match attachment.kind { + TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await, + TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await, + TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await, + TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await, + TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await, + } + } + /// Send a document/file to a Telegram chat pub async fn send_document( &self, @@ -408,6 +745,39 @@ impl TelegramChannel { tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}"); Ok(()) } + + /// Send a video by URL (Telegram will download it) + pub async fn send_video_by_url( + &self, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + self.send_media_by_url("sendVideo", "video", chat_id, url, caption) + .await + } + + /// Send an audio file by URL (Telegram will download it) + pub async fn send_audio_by_url( + &self, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + self.send_media_by_url("sendAudio", "audio", chat_id, url, caption) + .await + } + + /// Send a voice message by URL (Telegram will download it) + pub async fn send_voice_by_url( + &self, + chat_id: &str, + url: &str, + caption: Option<&str>, + ) -> anyhow::Result<()> { + self.send_media_by_url("sendVoice", "voice", chat_id, url, caption) + .await + } } #[async_trait] @@ -417,82 +787,27 @@ impl Channel for TelegramChannel { } async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> { - // Split message if it exceeds Telegram's 4096 character limit - let chunks = split_message_for_telegram(message); + let (text_without_markers, attachments) = parse_attachment_markers(message); - for (i, chunk) in chunks.iter().enumerate() { - // Add continuation marker for multi-part messages - let text = if chunks.len() > 1 { - if i == 0 { - format!("{chunk}\n\n(continues...)") - } else if i == chunks.len() - 1 { - format!("(continued)\n\n{chunk}") - } else { - format!("(continued)\n\n{chunk}\n\n(continues...)") - } - } else { - chunk.to_string() - }; - - let markdown_body = serde_json::json!({ - "chat_id": chat_id, - "text": text, - "parse_mode": "Markdown" - }); - - let markdown_resp = self - .client - .post(self.api_url("sendMessage")) - .json(&markdown_body) - .send() - .await?; - - if markdown_resp.status().is_success() { - // Small delay between chunks to avoid rate limiting - if i < chunks.len() - 1 { - tokio::time::sleep(Duration::from_millis(100)).await; - } - continue; + if !attachments.is_empty() { + if !text_without_markers.is_empty() { + self.send_text_chunks(&text_without_markers, chat_id) + .await?; } - let markdown_status = markdown_resp.status(); - let markdown_err = markdown_resp.text().await.unwrap_or_default(); - tracing::warn!( - status = ?markdown_status, - "Telegram sendMessage with Markdown failed; retrying without parse_mode" - ); - - // Retry without parse_mode as a compatibility fallback. - let plain_body = serde_json::json!({ - "chat_id": chat_id, - "text": text, - }); - let plain_resp = self - .client - .post(self.api_url("sendMessage")) - .json(&plain_body) - .send() - .await?; - - if !plain_resp.status().is_success() { - let plain_status = plain_resp.status(); - let plain_err = plain_resp.text().await.unwrap_or_default(); - anyhow::bail!( - "Telegram sendMessage failed (markdown {}: {}; plain {}: {})", - markdown_status, - markdown_err, - plain_status, - plain_err - ); + for attachment in &attachments { + self.send_attachment(chat_id, attachment).await?; } - // Small delay between chunks to avoid rate limiting - if i < chunks.len() - 1 { - tokio::time::sleep(Duration::from_millis(100)).await; - } + return Ok(()); } - Ok(()) + if let Some(attachment) = parse_path_only_attachment(message) { + self.send_attachment(chat_id, &attachment).await?; + return Ok(()); + } + + self.send_text_chunks(message, chat_id).await } async fn listen(&self, tx: tokio::sync::mpsc::Sender) -> anyhow::Result<()> { @@ -533,59 +848,13 @@ impl Channel for TelegramChannel { offset = uid + 1; } - let Some(message) = update.get("message") else { + let Some(msg) = self.parse_update_message(update) else { continue; }; - let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else { - continue; - }; - - let username_opt = message - .get("from") - .and_then(|f| f.get("username")) - .and_then(|u| u.as_str()); - let username = username_opt.unwrap_or("unknown"); - - let user_id = message - .get("from") - .and_then(|f| f.get("id")) - .and_then(serde_json::Value::as_i64); - let user_id_str = user_id.map(|id| id.to_string()); - - let mut identities = vec![username]; - if let Some(ref id) = user_id_str { - identities.push(id.as_str()); - } - - if !self.is_any_user_allowed(identities.iter().copied()) { - tracing::warn!( - "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \ -Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.", - user_id_str.as_deref().unwrap_or("unknown") - ); - continue; - } - - let chat_id = message - .get("chat") - .and_then(|c| c.get("id")) - .and_then(serde_json::Value::as_i64) - .map(|id| id.to_string()); - - let Some(chat_id) = chat_id else { - tracing::warn!("Telegram: missing chat_id in message, skipping"); - continue; - }; - - let message_id = message - .get("message_id") - .and_then(|v| v.as_i64()) - .unwrap_or(0); - // Send "typing" indicator immediately when we receive a message let typing_body = serde_json::json!({ - "chat_id": &chat_id, + "chat_id": &msg.reply_target, "action": "typing" }); let _ = self @@ -595,17 +864,6 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch .send() .await; // Ignore errors for typing indicator - let msg = ChannelMessage { - id: format!("telegram_{chat_id}_{message_id}"), - sender: username.to_string(), - content: text.to_string(), - channel: "telegram".to_string(), - timestamp: std::time::SystemTime::now() - .duration_since(std::time::UNIX_EPOCH) - .unwrap_or_default() - .as_secs(), - }; - if tx.send(msg).await.is_err() { return Ok(()); } @@ -716,6 +974,107 @@ mod tests { assert!(!ch.is_any_user_allowed(["unknown", "123456789"])); } + #[test] + fn parse_attachment_markers_extracts_multiple_types() { + let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]"; + let (cleaned, attachments) = parse_attachment_markers(message); + + assert_eq!(cleaned, "Here are files and"); + assert_eq!(attachments.len(), 2); + assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image); + assert_eq!(attachments[0].target, "/tmp/a.png"); + assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document); + assert_eq!(attachments[1].target, "https://example.com/a.pdf"); + } + + #[test] + fn parse_attachment_markers_keeps_invalid_markers_in_text() { + let message = "Report [UNKNOWN:/tmp/a.bin]"; + let (cleaned, attachments) = parse_attachment_markers(message); + + assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]"); + assert!(attachments.is_empty()); + } + + #[test] + fn parse_path_only_attachment_detects_existing_file() { + let dir = tempfile::tempdir().unwrap(); + let image_path = dir.path().join("snap.png"); + std::fs::write(&image_path, b"fake-png").unwrap(); + + let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref()) + .expect("expected attachment"); + + assert_eq!(parsed.kind, TelegramAttachmentKind::Image); + assert_eq!(parsed.target, image_path.to_string_lossy()); + } + + #[test] + fn parse_path_only_attachment_rejects_sentence_text() { + assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none()); + } + + #[test] + fn infer_attachment_kind_from_target_detects_document_extension() { + assert_eq!( + infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"), + Some(TelegramAttachmentKind::Document) + ); + } + + #[test] + fn parse_update_message_uses_chat_id_as_reply_target() { + let ch = TelegramChannel::new("token".into(), vec!["*".into()]); + let update = serde_json::json!({ + "update_id": 1, + "message": { + "message_id": 33, + "text": "hello", + "from": { + "id": 555, + "username": "alice" + }, + "chat": { + "id": -100200300 + } + } + }); + + let msg = ch + .parse_update_message(&update) + .expect("message should parse"); + + assert_eq!(msg.sender, "alice"); + assert_eq!(msg.reply_target, "-100200300"); + assert_eq!(msg.content, "hello"); + assert_eq!(msg.id, "telegram_-100200300_33"); + } + + #[test] + fn parse_update_message_allows_numeric_id_without_username() { + let ch = TelegramChannel::new("token".into(), vec!["555".into()]); + let update = serde_json::json!({ + "update_id": 2, + "message": { + "message_id": 9, + "text": "ping", + "from": { + "id": 555 + }, + "chat": { + "id": 12345 + } + } + }); + + let msg = ch + .parse_update_message(&update) + .expect("numeric allowlist should pass"); + + assert_eq!(msg.sender, "555"); + assert_eq!(msg.reply_target, "12345"); + } + // ── File sending API URL tests ────────────────────────────────── #[test] diff --git a/src/channels/traits.rs b/src/channels/traits.rs index 59b361e..1c44bf6 100644 --- a/src/channels/traits.rs +++ b/src/channels/traits.rs @@ -5,6 +5,7 @@ use async_trait::async_trait; pub struct ChannelMessage { pub id: String, pub sender: String, + pub reply_target: String, pub content: String, pub channel: String, pub timestamp: u64, @@ -62,6 +63,7 @@ mod tests { tx.send(ChannelMessage { id: "1".into(), sender: "tester".into(), + reply_target: "tester".into(), content: "hello".into(), channel: "dummy".into(), timestamp: 123, @@ -76,6 +78,7 @@ mod tests { let message = ChannelMessage { id: "42".into(), sender: "alice".into(), + reply_target: "alice".into(), content: "ping".into(), channel: "dummy".into(), timestamp: 999, @@ -84,6 +87,7 @@ mod tests { let cloned = message.clone(); assert_eq!(cloned.id, "42"); assert_eq!(cloned.sender, "alice"); + assert_eq!(cloned.reply_target, "alice"); assert_eq!(cloned.content, "ping"); assert_eq!(cloned.channel, "dummy"); assert_eq!(cloned.timestamp, 999); diff --git a/src/channels/whatsapp.rs b/src/channels/whatsapp.rs index 3e4c045..7825b96 100644 --- a/src/channels/whatsapp.rs +++ b/src/channels/whatsapp.rs @@ -10,7 +10,7 @@ use uuid::Uuid; /// happens in the gateway when Meta sends webhook events. pub struct WhatsAppChannel { access_token: String, - phone_number_id: String, + endpoint_id: String, verify_token: String, allowed_numbers: Vec, client: reqwest::Client, @@ -19,13 +19,13 @@ pub struct WhatsAppChannel { impl WhatsAppChannel { pub fn new( access_token: String, - phone_number_id: String, + endpoint_id: String, verify_token: String, allowed_numbers: Vec, ) -> Self { Self { access_token, - phone_number_id, + endpoint_id, verify_token, allowed_numbers, client: reqwest::Client::new(), @@ -119,6 +119,7 @@ impl WhatsAppChannel { messages.push(ChannelMessage { id: Uuid::new_v4().to_string(), + reply_target: normalized_from.clone(), sender: normalized_from, content, channel: "whatsapp".to_string(), @@ -142,7 +143,7 @@ impl Channel for WhatsAppChannel { // WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages let url = format!( "https://graph.facebook.com/v18.0/{}/messages", - self.phone_number_id + self.endpoint_id ); // Normalize recipient (remove leading + if present for API) @@ -162,7 +163,7 @@ impl Channel for WhatsAppChannel { let resp = self .client .post(&url) - .header("Authorization", format!("Bearer {}", self.access_token)) + .bearer_auth(&self.access_token) .header("Content-Type", "application/json") .json(&body) .send() @@ -195,11 +196,11 @@ impl Channel for WhatsAppChannel { async fn health_check(&self) -> bool { // Check if we can reach the WhatsApp API - let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id); + let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id); self.client .get(&url) - .header("Authorization", format!("Bearer {}", self.access_token)) + .bearer_auth(&self.access_token) .send() .await .map(|r| r.status().is_success()) diff --git a/src/config/mod.rs b/src/config/mod.rs index 4fec9ae..8e37cce 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -37,9 +37,22 @@ mod tests { guild_id: Some("123".into()), allowed_users: vec![], listen_to_bots: false, + mention_only: false, + }; + + let lark = LarkConfig { + app_id: "app-id".into(), + app_secret: "app-secret".into(), + encrypt_key: None, + verification_token: None, + allowed_users: vec![], + use_feishu: false, + receive_mode: crate::config::schema::LarkReceiveMode::Websocket, + port: None, }; assert_eq!(telegram.allowed_users.len(), 1); assert_eq!(discord.guild_id.as_deref(), Some("123")); + assert_eq!(lark.app_id, "app-id"); } } diff --git a/src/config/schema.rs b/src/config/schema.rs index 34be770..74f5d34 100644 --- a/src/config/schema.rs +++ b/src/config/schema.rs @@ -18,6 +18,8 @@ pub struct Config { #[serde(skip)] pub config_path: PathBuf, pub api_key: Option, + /// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama) + pub api_url: Option, pub default_provider: Option, pub default_model: Option, pub default_temperature: f64, @@ -1317,6 +1319,10 @@ pub struct DiscordConfig { /// The bot still ignores its own messages to prevent feedback loops. #[serde(default)] pub listen_to_bots: bool, + /// When true, only respond to messages that @-mention the bot. + /// Other messages in the guild are silently ignored. + #[serde(default)] + pub mention_only: bool, } #[derive(Debug, Clone, Serialize, Deserialize)] @@ -1395,8 +1401,20 @@ fn default_irc_port() -> u16 { 6697 } -/// Lark/Feishu configuration for messaging integration -/// Lark is the international version, Feishu is the Chinese version +/// How ZeroClaw receives events from Feishu / Lark. +/// +/// - `websocket` (default) — persistent WSS long-connection; no public URL required. +/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint. +#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)] +#[serde(rename_all = "lowercase")] +pub enum LarkReceiveMode { + #[default] + Websocket, + Webhook, +} + +/// Lark/Feishu configuration for messaging integration. +/// Lark is the international version; Feishu is the Chinese version. #[derive(Debug, Clone, Serialize, Deserialize)] pub struct LarkConfig { /// App ID from Lark/Feishu developer console @@ -1415,6 +1433,13 @@ pub struct LarkConfig { /// Whether to use the Feishu (Chinese) endpoint instead of Lark (International) #[serde(default)] pub use_feishu: bool, + /// Event receive mode: "websocket" (default) or "webhook" + #[serde(default)] + pub receive_mode: LarkReceiveMode, + /// HTTP port for webhook mode only. Must be set when receive_mode = "webhook". + /// Not required (and ignored) for websocket mode. + #[serde(default)] + pub port: Option, } // ── Security Config ───────────────────────────────────────────────── @@ -1594,6 +1619,7 @@ impl Default for Config { workspace_dir: zeroclaw_dir.join("workspace"), config_path: zeroclaw_dir.join("config.toml"), api_key: None, + api_url: None, default_provider: Some("openrouter".to_string()), default_model: Some("anthropic/claude-sonnet-4".to_string()), default_temperature: 0.7, @@ -1623,35 +1649,146 @@ impl Default for Config { } } -impl Config { - pub fn load_or_init() -> Result { - let home = UserDirs::new() - .map(|u| u.home_dir().to_path_buf()) - .context("Could not find home directory")?; - let zeroclaw_dir = home.join(".zeroclaw"); - let config_path = zeroclaw_dir.join("config.toml"); +fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> { + let home = UserDirs::new() + .map(|u| u.home_dir().to_path_buf()) + .context("Could not find home directory")?; + let config_dir = home.join(".zeroclaw"); + Ok((config_dir.clone(), config_dir.join("workspace"))) +} - if !zeroclaw_dir.exists() { - fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?; - fs::create_dir_all(zeroclaw_dir.join("workspace")) - .context("Failed to create workspace directory")?; +fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf { + let workspace_config_dir = workspace_dir.to_path_buf(); + if workspace_config_dir.join("config.toml").exists() { + return workspace_config_dir; + } + + let legacy_config_dir = workspace_dir + .parent() + .map(|parent| parent.join(".zeroclaw")); + if let Some(legacy_dir) = legacy_config_dir { + if legacy_dir.join("config.toml").exists() { + return legacy_dir; } + if workspace_dir + .file_name() + .is_some_and(|name| name == std::ffi::OsStr::new("workspace")) + { + return legacy_dir; + } + } + + workspace_config_dir +} + +fn decrypt_optional_secret( + store: &crate::security::SecretStore, + value: &mut Option, + field_name: &str, +) -> Result<()> { + if let Some(raw) = value.clone() { + if crate::security::SecretStore::is_encrypted(&raw) { + *value = Some( + store + .decrypt(&raw) + .with_context(|| format!("Failed to decrypt {field_name}"))?, + ); + } + } + Ok(()) +} + +fn encrypt_optional_secret( + store: &crate::security::SecretStore, + value: &mut Option, + field_name: &str, +) -> Result<()> { + if let Some(raw) = value.clone() { + if !crate::security::SecretStore::is_encrypted(&raw) { + *value = Some( + store + .encrypt(&raw) + .with_context(|| format!("Failed to encrypt {field_name}"))?, + ); + } + } + Ok(()) +} + +impl Config { + pub fn load_or_init() -> Result { + // Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE. + let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") { + Ok(custom_workspace) if !custom_workspace.is_empty() => { + let workspace = PathBuf::from(custom_workspace); + (resolve_config_dir_for_workspace(&workspace), workspace) + } + _ => default_config_and_workspace_dirs()?, + }; + + let config_path = zeroclaw_dir.join("config.toml"); + + fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?; + fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?; + if config_path.exists() { + // Warn if config file is world-readable (may contain API keys) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + if let Ok(meta) = fs::metadata(&config_path) { + if meta.permissions().mode() & 0o004 != 0 { + tracing::warn!( + "Config file {:?} is world-readable (mode {:o}). \ + Consider restricting with: chmod 600 {:?}", + config_path, + meta.permissions().mode() & 0o777, + config_path, + ); + } + } + } + let contents = fs::read_to_string(&config_path).context("Failed to read config file")?; let mut config: Config = toml::from_str(&contents).context("Failed to parse config file")?; // Set computed paths that are skipped during serialization config.config_path = config_path.clone(); - config.workspace_dir = zeroclaw_dir.join("workspace"); + config.workspace_dir = workspace_dir; + let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt); + decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?; + decrypt_optional_secret( + &store, + &mut config.composio.api_key, + "config.composio.api_key", + )?; + + decrypt_optional_secret( + &store, + &mut config.browser.computer_use.api_key, + "config.browser.computer_use.api_key", + )?; + + for agent in config.agents.values_mut() { + decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; + } config.apply_env_overrides(); Ok(config) } else { let mut config = Config::default(); config.config_path = config_path.clone(); - config.workspace_dir = zeroclaw_dir.join("workspace"); + config.workspace_dir = workspace_dir; config.save()?; + + // Restrict permissions on newly created config file (may contain API keys) + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600)); + } + config.apply_env_overrides(); Ok(config) } @@ -1732,23 +1869,29 @@ impl Config { } pub fn save(&self) -> Result<()> { - // Encrypt agent API keys before serialization + // Encrypt secrets before serialization let mut config_to_save = self.clone(); let zeroclaw_dir = self .config_path .parent() .context("Config path must have a parent directory")?; let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt); + + encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?; + encrypt_optional_secret( + &store, + &mut config_to_save.composio.api_key, + "config.composio.api_key", + )?; + + encrypt_optional_secret( + &store, + &mut config_to_save.browser.computer_use.api_key, + "config.browser.computer_use.api_key", + )?; + for agent in config_to_save.agents.values_mut() { - if let Some(ref plaintext_key) = agent.api_key { - if !crate::security::SecretStore::is_encrypted(plaintext_key) { - agent.api_key = Some( - store - .encrypt(plaintext_key) - .context("Failed to encrypt agent API key")?, - ); - } - } + encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?; } let toml_str = @@ -1949,6 +2092,7 @@ default_temperature = 0.7 workspace_dir: PathBuf::from("/tmp/test/workspace"), config_path: PathBuf::from("/tmp/test/config.toml"), api_key: Some("sk-test-key".into()), + api_url: None, default_provider: Some("openrouter".into()), default_model: Some("gpt-4o".into()), default_temperature: 0.5, @@ -2091,6 +2235,7 @@ tool_dispatcher = "xml" workspace_dir: dir.join("workspace"), config_path: config_path.clone(), api_key: Some("sk-roundtrip".into()), + api_url: None, default_provider: Some("openrouter".into()), default_model: Some("test-model".into()), default_temperature: 0.9, @@ -2123,13 +2268,82 @@ tool_dispatcher = "xml" let contents = fs::read_to_string(&config_path).unwrap(); let loaded: Config = toml::from_str(&contents).unwrap(); - assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip")); + assert!(loaded + .api_key + .as_deref() + .is_some_and(crate::security::SecretStore::is_encrypted)); + let store = crate::security::SecretStore::new(&dir, true); + let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap(); + assert_eq!(decrypted, "sk-roundtrip"); assert_eq!(loaded.default_model.as_deref(), Some("test-model")); assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON); let _ = fs::remove_dir_all(&dir); } + #[test] + fn config_save_encrypts_nested_credentials() { + let dir = std::env::temp_dir().join(format!( + "zeroclaw_test_nested_credentials_{}", + uuid::Uuid::new_v4() + )); + fs::create_dir_all(&dir).unwrap(); + + let mut config = Config::default(); + config.workspace_dir = dir.join("workspace"); + config.config_path = dir.join("config.toml"); + config.api_key = Some("root-credential".into()); + config.composio.api_key = Some("composio-credential".into()); + config.browser.computer_use.api_key = Some("browser-credential".into()); + + config.agents.insert( + "worker".into(), + DelegateAgentConfig { + provider: "openrouter".into(), + model: "model-test".into(), + system_prompt: None, + api_key: Some("agent-credential".into()), + temperature: None, + max_depth: 3, + }, + ); + + config.save().unwrap(); + + let contents = fs::read_to_string(config.config_path.clone()).unwrap(); + let stored: Config = toml::from_str(&contents).unwrap(); + let store = crate::security::SecretStore::new(&dir, true); + + let root_encrypted = stored.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted(root_encrypted)); + assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential"); + + let composio_encrypted = stored.composio.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted( + composio_encrypted + )); + assert_eq!( + store.decrypt(composio_encrypted).unwrap(), + "composio-credential" + ); + + let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted( + browser_encrypted + )); + assert_eq!( + store.decrypt(browser_encrypted).unwrap(), + "browser-credential" + ); + + let worker = stored.agents.get("worker").unwrap(); + let worker_encrypted = worker.api_key.as_deref().unwrap(); + assert!(crate::security::SecretStore::is_encrypted(worker_encrypted)); + assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential"); + + let _ = fs::remove_dir_all(&dir); + } + #[test] fn config_save_atomic_cleanup() { let dir = @@ -2182,6 +2396,7 @@ tool_dispatcher = "xml" guild_id: Some("12345".into()), allowed_users: vec![], listen_to_bots: false, + mention_only: false, }; let json = serde_json::to_string(&dc).unwrap(); let parsed: DiscordConfig = serde_json::from_str(&json).unwrap(); @@ -2196,6 +2411,7 @@ tool_dispatcher = "xml" guild_id: None, allowed_users: vec![], listen_to_bots: false, + mention_only: false, }; let json = serde_json::to_string(&dc).unwrap(); let parsed: DiscordConfig = serde_json::from_str(&json).unwrap(); @@ -2818,6 +3034,96 @@ default_temperature = 0.7 std::env::remove_var("ZEROCLAW_WORKSPACE"); } + #[test] + fn load_or_init_workspace_override_uses_workspace_root_for_config() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let workspace_dir = temp_home.join("profile-a"); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.workspace_dir, workspace_dir); + assert_eq!(config.config_path, workspace_dir.join("config.toml")); + assert!(workspace_dir.join("config.toml").exists()); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn load_or_init_workspace_suffix_uses_legacy_config_layout() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let workspace_dir = temp_home.join("workspace"); + let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml"); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.workspace_dir, workspace_dir); + assert_eq!(config.config_path, legacy_config_path); + assert!(config.config_path.exists()); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + + #[test] + fn load_or_init_workspace_override_keeps_existing_legacy_config() { + let _env_guard = env_override_test_guard(); + let temp_home = + std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4())); + let workspace_dir = temp_home.join("custom-workspace"); + let legacy_config_dir = temp_home.join(".zeroclaw"); + let legacy_config_path = legacy_config_dir.join("config.toml"); + + fs::create_dir_all(&legacy_config_dir).unwrap(); + fs::write( + &legacy_config_path, + r#"default_temperature = 0.7 +default_model = "legacy-model" +"#, + ) + .unwrap(); + + let original_home = std::env::var("HOME").ok(); + std::env::set_var("HOME", &temp_home); + std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir); + + let config = Config::load_or_init().unwrap(); + + assert_eq!(config.workspace_dir, workspace_dir); + assert_eq!(config.config_path, legacy_config_path); + assert_eq!(config.default_model.as_deref(), Some("legacy-model")); + + std::env::remove_var("ZEROCLAW_WORKSPACE"); + if let Some(home) = original_home { + std::env::set_var("HOME", home); + } else { + std::env::remove_var("HOME"); + } + let _ = fs::remove_dir_all(temp_home); + } + #[test] fn env_override_empty_values_ignored() { let _env_guard = env_override_test_guard(); @@ -2975,4 +3281,118 @@ default_temperature = 0.7 assert_eq!(parsed.boards[0].board, "nucleo-f401re"); assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0")); } + + #[test] + fn lark_config_serde() { + let lc = LarkConfig { + app_id: "cli_123456".into(), + app_secret: "secret_abc".into(), + encrypt_key: Some("encrypt_key".into()), + verification_token: Some("verify_token".into()), + allowed_users: vec!["user_123".into(), "user_456".into()], + use_feishu: true, + receive_mode: LarkReceiveMode::Websocket, + port: None, + }; + let json = serde_json::to_string(&lc).unwrap(); + let parsed: LarkConfig = serde_json::from_str(&json).unwrap(); + assert_eq!(parsed.app_id, "cli_123456"); + assert_eq!(parsed.app_secret, "secret_abc"); + assert_eq!(parsed.encrypt_key.as_deref(), Some("encrypt_key")); + assert_eq!(parsed.verification_token.as_deref(), Some("verify_token")); + assert_eq!(parsed.allowed_users.len(), 2); + assert!(parsed.use_feishu); + } + + #[test] + fn lark_config_toml_roundtrip() { + let lc = LarkConfig { + app_id: "cli_123456".into(), + app_secret: "secret_abc".into(), + encrypt_key: Some("encrypt_key".into()), + verification_token: Some("verify_token".into()), + allowed_users: vec!["*".into()], + use_feishu: false, + receive_mode: LarkReceiveMode::Webhook, + port: Some(9898), + }; + let toml_str = toml::to_string(&lc).unwrap(); + let parsed: LarkConfig = toml::from_str(&toml_str).unwrap(); + assert_eq!(parsed.app_id, "cli_123456"); + assert_eq!(parsed.app_secret, "secret_abc"); + assert!(!parsed.use_feishu); + } + + #[test] + fn lark_config_deserializes_without_optional_fields() { + let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert!(parsed.encrypt_key.is_none()); + assert!(parsed.verification_token.is_none()); + assert!(parsed.allowed_users.is_empty()); + assert!(!parsed.use_feishu); + } + + #[test] + fn lark_config_defaults_to_lark_endpoint() { + let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert!( + !parsed.use_feishu, + "use_feishu should default to false (Lark)" + ); + } + + #[test] + fn lark_config_with_wildcard_allowed_users() { + let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#; + let parsed: LarkConfig = serde_json::from_str(json).unwrap(); + assert_eq!(parsed.allowed_users, vec!["*"]); + } + + // ── Config file permission hardening (Unix only) ─────────────── + + #[cfg(unix)] + #[test] + fn new_config_file_has_restricted_permissions() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::TempDir::new().unwrap(); + let config_path = tmp.path().join("config.toml"); + + // Create a config and save it + let mut config = Config::default(); + config.config_path = config_path.clone(); + config.save().unwrap(); + + // Apply the same permission logic as load_or_init + let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600)); + + let meta = std::fs::metadata(&config_path).unwrap(); + let mode = meta.permissions().mode() & 0o777; + assert_eq!( + mode, 0o600, + "New config file should be owner-only (0600), got {mode:o}" + ); + } + + #[cfg(unix)] + #[test] + fn world_readable_config_is_detectable() { + use std::os::unix::fs::PermissionsExt; + + let tmp = tempfile::TempDir::new().unwrap(); + let config_path = tmp.path().join("config.toml"); + + // Create a config file with intentionally loose permissions + std::fs::write(&config_path, "# test config").unwrap(); + std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap(); + + let meta = std::fs::metadata(&config_path).unwrap(); + let mode = meta.permissions().mode(); + assert!( + mode & 0o004 != 0, + "Test setup: file should be world-readable (mode {mode:o})" + ); + } } diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs index df771d6..4562dba 100644 --- a/src/cron/scheduler.rs +++ b/src/cron/scheduler.rs @@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) -> dc.guild_id.clone(), dc.allowed_users.clone(), dc.listen_to_bots, + dc.mention_only, ); channel.send(output, target).await?; } diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs index c2f4487..a223597 100644 --- a/src/daemon/mod.rs +++ b/src/daemon/mod.rs @@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool { || config.channels_config.matrix.is_some() || config.channels_config.whatsapp.is_some() || config.channels_config.email.is_some() + || config.channels_config.lark.is_some() } #[cfg(test)] diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs index 719e8e7..b391a88 100644 --- a/src/gateway/mod.rs +++ b/src/gateway/mod.rs @@ -49,6 +49,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String format!("whatsapp_{}_{}", msg.sender, msg.id) } +fn hash_webhook_secret(value: &str) -> String { + use sha2::{Digest, Sha256}; + + let digest = Sha256::digest(value.as_bytes()); + hex::encode(digest) +} + /// How often the rate limiter sweeps stale IP entries from its map. const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes @@ -178,7 +185,8 @@ pub struct AppState { pub temperature: f64, pub mem: Arc, pub auto_save: bool, - pub webhook_secret: Option>, + /// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext. + pub webhook_secret_hash: Option>, pub pairing: Arc, pub rate_limiter: Arc, pub idempotency_store: Arc, @@ -208,6 +216,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { let provider: Arc = Arc::from(providers::create_resilient_provider( config.default_provider.as_deref().unwrap_or("openrouter"), config.api_key.as_deref(), + config.api_url.as_deref(), &config.reliability, )?); let model = config @@ -251,12 +260,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { &config, )); // Extract webhook secret for authentication - let webhook_secret: Option> = config - .channels_config - .webhook - .as_ref() - .and_then(|w| w.secret.as_deref()) - .map(Arc::from); + let webhook_secret_hash: Option> = + config.channels_config.webhook.as_ref().and_then(|webhook| { + webhook.secret.as_ref().and_then(|raw_secret| { + let trimmed_secret = raw_secret.trim(); + (!trimmed_secret.is_empty()) + .then(|| Arc::::from(hash_webhook_secret(trimmed_secret))) + }) + }); // WhatsApp channel (if configured) let whatsapp_channel: Option> = @@ -342,9 +353,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { } else { println!(" ⚠️ Pairing: DISABLED (all requests accepted)"); } - if webhook_secret.is_some() { - println!(" 🔒 Webhook secret: ENABLED"); - } println!(" Press Ctrl+C to stop.\n"); crate::health::mark_component_ok("gateway"); @@ -356,7 +364,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> { temperature, mem, auto_save: config.memory.auto_save, - webhook_secret, + webhook_secret_hash, pairing, rate_limiter, idempotency_store, @@ -482,12 +490,15 @@ async fn handle_webhook( } // ── Webhook secret auth (optional, additional layer) ── - if let Some(ref secret) = state.webhook_secret { - let header_val = headers + if let Some(ref secret_hash) = state.webhook_secret_hash { + let header_hash = headers .get("X-Webhook-Secret") - .and_then(|v| v.to_str().ok()); - match header_val { - Some(val) if constant_time_eq(val, secret.as_ref()) => {} + .and_then(|v| v.to_str().ok()) + .map(str::trim) + .filter(|value| !value.is_empty()) + .map(hash_webhook_secret); + match header_hash { + Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {} _ => { tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret"); let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"}); @@ -532,7 +543,7 @@ async fn handle_webhook( let key = webhook_memory_key(); let _ = state .mem - .store(&key, message, MemoryCategory::Conversation) + .store(&key, message, MemoryCategory::Conversation, None) .await; } @@ -685,7 +696,7 @@ async fn handle_whatsapp_message( let key = whatsapp_memory_key(msg); let _ = state .mem - .store(&key, &msg.content, MemoryCategory::Conversation) + .store(&key, &msg.content, MemoryCategory::Conversation, None) .await; } @@ -697,7 +708,7 @@ async fn handle_whatsapp_message( { Ok(response) => { // Send reply via WhatsApp - if let Err(e) = wa.send(&response, &msg.sender).await { + if let Err(e) = wa.send(&response, &msg.reply_target).await { tracing::error!("Failed to send WhatsApp reply: {e}"); } } @@ -706,7 +717,7 @@ async fn handle_whatsapp_message( let _ = wa .send( "Sorry, I couldn't process your message right now.", - &msg.sender, + &msg.reply_target, ) .await; } @@ -798,7 +809,9 @@ mod tests { .requests .lock() .unwrap_or_else(std::sync::PoisonError::into_inner); - guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1); + guard.1 = Instant::now() + .checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1)) + .unwrap(); // Clear timestamps for ip-2 and ip-3 to simulate stale entries guard.0.get_mut("ip-2").unwrap().clear(); guard.0.get_mut("ip-3").unwrap().clear(); @@ -848,6 +861,7 @@ mod tests { let msg = ChannelMessage { id: "wamid-123".into(), sender: "+1234567890".into(), + reply_target: "+1234567890".into(), content: "hello".into(), channel: "whatsapp".into(), timestamp: 1, @@ -871,11 +885,17 @@ mod tests { _key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } - async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -886,6 +906,7 @@ mod tests { async fn list( &self, _category: Option<&MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } @@ -938,6 +959,7 @@ mod tests { key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { self.keys .lock() @@ -946,7 +968,12 @@ mod tests { Ok(()) } - async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -957,6 +984,7 @@ mod tests { async fn list( &self, _category: Option<&MemoryCategory>, + _session_id: Option<&str>, ) -> anyhow::Result> { Ok(Vec::new()) } @@ -991,7 +1019,7 @@ mod tests { temperature: 0.0, mem: memory, auto_save: false, - webhook_secret: None, + webhook_secret_hash: None, pairing: Arc::new(PairingGuard::new(false, &[])), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), @@ -1039,7 +1067,7 @@ mod tests { temperature: 0.0, mem: memory, auto_save: true, - webhook_secret: None, + webhook_secret_hash: None, pairing: Arc::new(PairingGuard::new(false, &[])), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), @@ -1077,6 +1105,125 @@ mod tests { assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); } + #[test] + fn webhook_secret_hash_is_deterministic_and_nonempty() { + let one = hash_webhook_secret("secret-value"); + let two = hash_webhook_secret("secret-value"); + let other = hash_webhook_secret("other-value"); + + assert_eq!(one, two); + assert_ne!(one, other); + assert_eq!(one.len(), 64); + } + + #[tokio::test] + async fn webhook_secret_hash_rejects_missing_header() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let response = handle_webhook( + State(state), + HeaderMap::new(), + Ok(Json(WebhookBody { + message: "hello".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn webhook_secret_hash_rejects_invalid_header() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let mut headers = HeaderMap::new(); + headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret")); + + let response = handle_webhook( + State(state), + headers, + Ok(Json(WebhookBody { + message: "hello".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::UNAUTHORIZED); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0); + } + + #[tokio::test] + async fn webhook_secret_hash_accepts_valid_header() { + let provider_impl = Arc::new(MockProvider::default()); + let provider: Arc = provider_impl.clone(); + let memory: Arc = Arc::new(MockMemory); + + let state = AppState { + provider, + model: "test-model".into(), + temperature: 0.0, + mem: memory, + auto_save: false, + webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))), + pairing: Arc::new(PairingGuard::new(false, &[])), + rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), + idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), + whatsapp: None, + whatsapp_app_secret: None, + }; + + let mut headers = HeaderMap::new(); + headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret")); + + let response = handle_webhook( + State(state), + headers, + Ok(Json(WebhookBody { + message: "hello".into(), + })), + ) + .await + .into_response(); + + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1); + } + // ══════════════════════════════════════════════════════════ // WhatsApp Signature Verification Tests (CWE-345 Prevention) // ══════════════════════════════════════════════════════════ diff --git a/src/main.rs b/src/main.rs index dbc76ff..56cd579 100644 --- a/src/main.rs +++ b/src/main.rs @@ -34,8 +34,8 @@ use anyhow::{bail, Result}; use clap::{Parser, Subcommand}; -use tracing::{info, Level}; -use tracing_subscriber::FmtSubscriber; +use tracing::info; +use tracing_subscriber::{fmt, EnvFilter}; mod agent; mod channels; @@ -147,24 +147,24 @@ enum Commands { /// Start the gateway server (webhooks, websockets) Gateway { - /// Port to listen on (use 0 for random available port) - #[arg(short, long, default_value = "8080")] - port: u16, + /// Port to listen on (use 0 for random available port); defaults to config gateway.port + #[arg(short, long)] + port: Option, - /// Host to bind to - #[arg(long, default_value = "127.0.0.1")] - host: String, + /// Host to bind to; defaults to config gateway.host + #[arg(long)] + host: Option, }, /// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler) Daemon { - /// Port to listen on (use 0 for random available port) - #[arg(short, long, default_value = "8080")] - port: u16, + /// Port to listen on (use 0 for random available port); defaults to config gateway.port + #[arg(short, long)] + port: Option, - /// Host to bind to - #[arg(long, default_value = "127.0.0.1")] - host: String, + /// Host to bind to; defaults to config gateway.host + #[arg(long)] + host: Option, }, /// Manage OS service lifecycle (launchd/systemd user service) @@ -367,9 +367,11 @@ async fn main() -> Result<()> { let cli = Cli::parse(); - // Initialize logging - let subscriber = FmtSubscriber::builder() - .with_max_level(Level::INFO) + // Initialize logging - respects RUST_LOG env var, defaults to INFO + let subscriber = fmt::Subscriber::builder() + .with_env_filter( + EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")), + ) .finish(); tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); @@ -434,6 +436,8 @@ async fn main() -> Result<()> { .map(|_| ()), Commands::Gateway { port, host } => { + let port = port.unwrap_or(config.gateway.port); + let host = host.unwrap_or_else(|| config.gateway.host.clone()); if port == 0 { info!("🚀 Starting ZeroClaw Gateway on {host} (random port)"); } else { @@ -443,6 +447,8 @@ async fn main() -> Result<()> { } Commands::Daemon { port, host } => { + let port = port.unwrap_or(config.gateway.port); + let host = host.unwrap_or_else(|| config.gateway.host.clone()); if port == 0 { info!("🧠 Starting ZeroClaw Daemon on {host} (random port)"); } else { diff --git a/src/memory/backend.rs b/src/memory/backend.rs index 4de636a..8ba7ec3 100644 --- a/src/memory/backend.rs +++ b/src/memory/backend.rs @@ -7,6 +7,7 @@ pub enum MemoryBackendKind { Unknown, } +#[allow(clippy::struct_excessive_bools)] #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub struct MemoryBackendProfile { pub key: &'static str, diff --git a/src/memory/hygiene.rs b/src/memory/hygiene.rs index cf58e21..01054ce 100644 --- a/src/memory/hygiene.rs +++ b/src/memory/hygiene.rs @@ -502,10 +502,10 @@ mod tests { let workspace = tmp.path(); let mem = SqliteMemory::new(workspace).unwrap(); - mem.store("conv_old", "outdated", MemoryCategory::Conversation) + mem.store("conv_old", "outdated", MemoryCategory::Conversation, None) .await .unwrap(); - mem.store("core_keep", "durable", MemoryCategory::Core) + mem.store("core_keep", "durable", MemoryCategory::Core, None) .await .unwrap(); drop(mem); diff --git a/src/memory/lucid.rs b/src/memory/lucid.rs index 50cf9de..e1cb43a 100644 --- a/src/memory/lucid.rs +++ b/src/memory/lucid.rs @@ -24,7 +24,9 @@ pub struct LucidMemory { impl LucidMemory { const DEFAULT_LUCID_CMD: &'static str = "lucid"; const DEFAULT_TOKEN_BUDGET: usize = 200; - const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120; + // Lucid CLI cold start can exceed 120ms on slower machines, which causes + // avoidable fallback to local-only memory and premature cooldown. + const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500; const DEFAULT_STORE_TIMEOUT_MS: u64 = 800; const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3; const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000; @@ -74,6 +76,7 @@ impl LucidMemory { } #[cfg(test)] + #[allow(clippy::too_many_arguments)] fn with_options( workspace_dir: &Path, local: SqliteMemory, @@ -307,14 +310,22 @@ impl Memory for LucidMemory { key: &str, content: &str, category: MemoryCategory, + session_id: Option<&str>, ) -> anyhow::Result<()> { - self.local.store(key, content, category.clone()).await?; + self.local + .store(key, content, category.clone(), session_id) + .await?; self.sync_to_lucid_async(key, content, &category).await; Ok(()) } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { - let local_results = self.local.recall(query, limit).await?; + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result> { + let local_results = self.local.recall(query, limit, session_id).await?; if limit == 0 || local_results.len() >= limit || local_results.len() >= self.local_hit_threshold @@ -351,8 +362,12 @@ impl Memory for LucidMemory { self.local.get(key).await } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { - self.local.list(category).await + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result> { + self.local.list(category, session_id).await } async fn forget(&self, key: &str) -> anyhow::Result { @@ -396,6 +411,38 @@ EOF exit 0 fi +echo "unsupported command" >&2 +exit 1 +"#; + + fs::write(&script_path, script).unwrap(); + let mut perms = fs::metadata(&script_path).unwrap().permissions(); + perms.set_mode(0o755); + fs::set_permissions(&script_path, perms).unwrap(); + script_path.display().to_string() + } + + fn write_delayed_lucid_script(dir: &Path) -> String { + let script_path = dir.join("delayed-lucid.sh"); + let script = r#"#!/usr/bin/env bash +set -euo pipefail + +if [[ "${1:-}" == "store" ]]; then + echo '{"success":true,"id":"mem_1"}' + exit 0 +fi + +if [[ "${1:-}" == "context" ]]; then + # Simulate a cold start that is slower than 120ms but below the 500ms timeout. + sleep 0.2 + cat <<'EOF' + +- [decision] Delayed token refresh guidance + +EOF + exit 0 +fi + echo "unsupported command" >&2 exit 1 "#; @@ -449,7 +496,7 @@ exit 1 cmd, 200, 3, - Duration::from_millis(120), + Duration::from_millis(500), Duration::from_millis(400), Duration::from_secs(2), ) @@ -468,7 +515,7 @@ exit 1 let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string()); memory - .store("lang", "User prefers Rust", MemoryCategory::Core) + .store("lang", "User prefers Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -483,6 +530,30 @@ exit 1 let fake_cmd = write_fake_lucid_script(tmp.path()); let memory = test_memory(tmp.path(), fake_cmd); + memory + .store( + "local_note", + "Local sqlite auth fallback note", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + + let entries = memory.recall("auth", 5, None).await.unwrap(); + + assert!(entries + .iter() + .any(|e| e.content.contains("Local sqlite auth fallback note"))); + assert!(entries.iter().any(|e| e.content.contains("token refresh"))); + } + + #[tokio::test] + async fn recall_handles_lucid_cold_start_delay_within_timeout() { + let tmp = TempDir::new().unwrap(); + let delayed_cmd = write_delayed_lucid_script(tmp.path()); + let memory = test_memory(tmp.path(), delayed_cmd); + memory .store( "local_note", @@ -497,7 +568,9 @@ exit 1 assert!(entries .iter() .any(|e| e.content.contains("Local sqlite auth fallback note"))); - assert!(entries.iter().any(|e| e.content.contains("token refresh"))); + assert!(entries + .iter() + .any(|e| e.content.contains("Delayed token refresh guidance"))); } #[tokio::test] @@ -513,17 +586,22 @@ exit 1 probe_cmd, 200, 1, - Duration::from_millis(120), + Duration::from_millis(500), Duration::from_millis(400), Duration::from_secs(2), ); memory - .store("pref", "Rust should stay local-first", MemoryCategory::Core) + .store( + "pref", + "Rust should stay local-first", + MemoryCategory::Core, + None, + ) .await .unwrap(); - let entries = memory.recall("rust", 5).await.unwrap(); + let entries = memory.recall("rust", 5, None).await.unwrap(); assert!(entries .iter() .any(|e| e.content.contains("Rust should stay local-first"))); @@ -578,13 +656,13 @@ exit 1 failing_cmd, 200, 99, - Duration::from_millis(120), + Duration::from_millis(500), Duration::from_millis(400), Duration::from_secs(5), ); - let first = memory.recall("auth", 5).await.unwrap(); - let second = memory.recall("auth", 5).await.unwrap(); + let first = memory.recall("auth", 5, None).await.unwrap(); + let second = memory.recall("auth", 5, None).await.unwrap(); assert!(first.is_empty()); assert!(second.is_empty()); diff --git a/src/memory/markdown.rs b/src/memory/markdown.rs index 8dcd667..9038683 100644 --- a/src/memory/markdown.rs +++ b/src/memory/markdown.rs @@ -143,6 +143,7 @@ impl Memory for MarkdownMemory { key: &str, content: &str, category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { let entry = format!("- **{key}**: {content}"); let path = match category { @@ -152,7 +153,12 @@ impl Memory for MarkdownMemory { self.append_to_file(&path, &entry).await } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + query: &str, + limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { let all = self.read_all_entries().await?; let query_lower = query.to_lowercase(); let keywords: Vec<&str> = query_lower.split_whitespace().collect(); @@ -192,7 +198,11 @@ impl Memory for MarkdownMemory { .find(|e| e.key == key || e.content.contains(key))) } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { + async fn list( + &self, + category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { let all = self.read_all_entries().await?; match category { Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()), @@ -243,7 +253,7 @@ mod tests { #[tokio::test] async fn markdown_store_core() { let (_tmp, mem) = temp_workspace(); - mem.store("pref", "User likes Rust", MemoryCategory::Core) + mem.store("pref", "User likes Rust", MemoryCategory::Core, None) .await .unwrap(); let content = sync_fs::read_to_string(mem.core_path()).unwrap(); @@ -253,7 +263,7 @@ mod tests { #[tokio::test] async fn markdown_store_daily() { let (_tmp, mem) = temp_workspace(); - mem.store("note", "Finished tests", MemoryCategory::Daily) + mem.store("note", "Finished tests", MemoryCategory::Daily, None) .await .unwrap(); let path = mem.daily_path(); @@ -264,17 +274,17 @@ mod tests { #[tokio::test] async fn markdown_recall_keyword() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "Rust is fast", MemoryCategory::Core) + mem.store("a", "Rust is fast", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Python is slow", MemoryCategory::Core) + mem.store("b", "Python is slow", MemoryCategory::Core, None) .await .unwrap(); - mem.store("c", "Rust and safety", MemoryCategory::Core) + mem.store("c", "Rust and safety", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert!(results.len() >= 2); assert!(results .iter() @@ -284,18 +294,20 @@ mod tests { #[tokio::test] async fn markdown_recall_no_match() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "Rust is great", MemoryCategory::Core) + mem.store("a", "Rust is great", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("javascript", 10).await.unwrap(); + let results = mem.recall("javascript", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn markdown_count() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "first", MemoryCategory::Core).await.unwrap(); - mem.store("b", "second", MemoryCategory::Core) + mem.store("a", "first", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "second", MemoryCategory::Core, None) .await .unwrap(); let count = mem.count().await.unwrap(); @@ -305,24 +317,24 @@ mod tests { #[tokio::test] async fn markdown_list_by_category() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "core fact", MemoryCategory::Core) + mem.store("a", "core fact", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "daily note", MemoryCategory::Daily) + mem.store("b", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); + let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); assert!(core.iter().all(|e| e.category == MemoryCategory::Core)); - let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); + let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily)); } #[tokio::test] async fn markdown_forget_is_noop() { let (_tmp, mem) = temp_workspace(); - mem.store("a", "permanent", MemoryCategory::Core) + mem.store("a", "permanent", MemoryCategory::Core, None) .await .unwrap(); let removed = mem.forget("a").await.unwrap(); @@ -332,7 +344,7 @@ mod tests { #[tokio::test] async fn markdown_empty_recall() { let (_tmp, mem) = temp_workspace(); - let results = mem.recall("anything", 10).await.unwrap(); + let results = mem.recall("anything", 10, None).await.unwrap(); assert!(results.is_empty()); } diff --git a/src/memory/none.rs b/src/memory/none.rs index 6057ad0..4ccd2f8 100644 --- a/src/memory/none.rs +++ b/src/memory/none.rs @@ -25,11 +25,17 @@ impl Memory for NoneMemory { _key: &str, _content: &str, _category: MemoryCategory, + _session_id: Option<&str>, ) -> anyhow::Result<()> { Ok(()) } - async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> { + async fn recall( + &self, + _query: &str, + _limit: usize, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -37,7 +43,11 @@ impl Memory for NoneMemory { Ok(None) } - async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result> { + async fn list( + &self, + _category: Option<&MemoryCategory>, + _session_id: Option<&str>, + ) -> anyhow::Result> { Ok(Vec::new()) } @@ -62,11 +72,14 @@ mod tests { async fn none_memory_is_noop() { let memory = NoneMemory::new(); - memory.store("k", "v", MemoryCategory::Core).await.unwrap(); + memory + .store("k", "v", MemoryCategory::Core, None) + .await + .unwrap(); assert!(memory.get("k").await.unwrap().is_none()); - assert!(memory.recall("k", 10).await.unwrap().is_empty()); - assert!(memory.list(None).await.unwrap().is_empty()); + assert!(memory.recall("k", 10, None).await.unwrap().is_empty()); + assert!(memory.list(None, None).await.unwrap().is_empty()); assert!(!memory.forget("k").await.unwrap()); assert_eq!(memory.count().await.unwrap(), 0); assert!(memory.health_check().await); diff --git a/src/memory/response_cache.rs b/src/memory/response_cache.rs index 6baa5c7..a260aa7 100644 --- a/src/memory/response_cache.rs +++ b/src/memory/response_cache.rs @@ -157,7 +157,7 @@ impl ResponseCache { |row| row.get(0), )?; - #[allow(clippy::cast_sign_loss)] + #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] Ok((count as usize, hits as u64, tokens_saved as u64)) } diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs index 160487d..46a98db 100644 --- a/src/memory/sqlite.rs +++ b/src/memory/sqlite.rs @@ -124,6 +124,19 @@ impl SqliteMemory { ); CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);", )?; + + // Migration: add session_id column if not present (safe to run repeatedly) + let has_session_id: bool = conn + .prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")? + .query_row([], |row| row.get::<_, String>(0))? + .contains("session_id"); + if !has_session_id { + conn.execute_batch( + "ALTER TABLE memories ADD COLUMN session_id TEXT; + CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);", + )?; + } + Ok(()) } @@ -361,6 +374,7 @@ impl Memory for SqliteMemory { key: &str, content: &str, category: MemoryCategory, + session_id: Option<&str>, ) -> anyhow::Result<()> { // Compute embedding (async, before lock) let embedding_bytes = self @@ -377,20 +391,26 @@ impl Memory for SqliteMemory { let id = Uuid::new_v4().to_string(); conn.execute( - "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at) - VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) + "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id) + VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8) ON CONFLICT(key) DO UPDATE SET content = excluded.content, category = excluded.category, embedding = excluded.embedding, - updated_at = excluded.updated_at", - params![id, key, content, cat, embedding_bytes, now, now], + updated_at = excluded.updated_at, + session_id = excluded.session_id", + params![id, key, content, cat, embedding_bytes, now, now, session_id], )?; Ok(()) } - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> { + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result> { if query.trim().is_empty() { return Ok(Vec::new()); } @@ -439,7 +459,7 @@ impl Memory for SqliteMemory { let mut results = Vec::new(); for scored in &merged { let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories WHERE id = ?1", + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1", )?; if let Ok(entry) = stmt.query_row(params![scored.id], |row| { Ok(MemoryEntry { @@ -448,10 +468,16 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: Some(f64::from(scored.final_score)), }) }) { + // Filter by session_id if requested + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } results.push(entry); } } @@ -470,7 +496,7 @@ impl Memory for SqliteMemory { .collect(); let where_clause = conditions.join(" OR "); let sql = format!( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE {where_clause} ORDER BY updated_at DESC LIMIT ?{}", @@ -493,12 +519,18 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: Some(1.0), }) })?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } } @@ -514,7 +546,7 @@ impl Memory for SqliteMemory { .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories WHERE key = ?1", + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1", )?; let mut rows = stmt.query_map(params![key], |row| { @@ -524,7 +556,7 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: None, }) })?; @@ -535,7 +567,11 @@ impl Memory for SqliteMemory { } } - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> { + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result> { let conn = self .conn .lock() @@ -550,7 +586,7 @@ impl Memory for SqliteMemory { content: row.get(2)?, category: Self::str_to_category(&row.get::<_, String>(3)?), timestamp: row.get(4)?, - session_id: None, + session_id: row.get(5)?, score: None, }) }; @@ -558,21 +594,33 @@ impl Memory for SqliteMemory { if let Some(cat) = category { let cat_str = Self::category_to_str(cat); let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories WHERE category = ?1 ORDER BY updated_at DESC", )?; let rows = stmt.query_map(params![cat_str], row_mapper)?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } else { let mut stmt = conn.prepare( - "SELECT id, key, content, category, created_at FROM memories + "SELECT id, key, content, category, created_at, session_id FROM memories ORDER BY updated_at DESC", )?; let rows = stmt.query_map([], row_mapper)?; for row in rows { - results.push(row?); + let entry = row?; + if let Some(sid) = session_id { + if entry.session_id.as_deref() != Some(sid) { + continue; + } + } + results.push(entry); } } @@ -632,7 +680,7 @@ mod tests { #[tokio::test] async fn sqlite_store_and_get() { let (_tmp, mem) = temp_sqlite(); - mem.store("user_lang", "Prefers Rust", MemoryCategory::Core) + mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -647,10 +695,10 @@ mod tests { #[tokio::test] async fn sqlite_store_upsert() { let (_tmp, mem) = temp_sqlite(); - mem.store("pref", "likes Rust", MemoryCategory::Core) + mem.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - mem.store("pref", "loves Rust", MemoryCategory::Core) + mem.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -662,17 +710,22 @@ mod tests { #[tokio::test] async fn sqlite_recall_keyword() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust is fast and safe", MemoryCategory::Core) + mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Python is interpreted", MemoryCategory::Core) - .await - .unwrap(); - mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core) + mem.store("b", "Python is interpreted", MemoryCategory::Core, None) .await .unwrap(); + mem.store( + "c", + "Rust has zero-cost abstractions", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert_eq!(results.len(), 2); assert!(results .iter() @@ -682,14 +735,14 @@ mod tests { #[tokio::test] async fn sqlite_recall_multi_keyword() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust is fast", MemoryCategory::Core) + mem.store("a", "Rust is fast", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "Rust is safe and fast", MemoryCategory::Core) + mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("fast safe", 10).await.unwrap(); + let results = mem.recall("fast safe", 10, None).await.unwrap(); assert!(!results.is_empty()); // Entry with both keywords should score higher assert!(results[0].content.contains("safe") && results[0].content.contains("fast")); @@ -698,17 +751,17 @@ mod tests { #[tokio::test] async fn sqlite_recall_no_match() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "Rust rocks", MemoryCategory::Core) + mem.store("a", "Rust rocks", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("javascript", 10).await.unwrap(); + let results = mem.recall("javascript", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn sqlite_forget() { let (_tmp, mem) = temp_sqlite(); - mem.store("temp", "temporary data", MemoryCategory::Conversation) + mem.store("temp", "temporary data", MemoryCategory::Conversation, None) .await .unwrap(); assert_eq!(mem.count().await.unwrap(), 1); @@ -728,29 +781,37 @@ mod tests { #[tokio::test] async fn sqlite_list_all() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "one", MemoryCategory::Core).await.unwrap(); - mem.store("b", "two", MemoryCategory::Daily).await.unwrap(); - mem.store("c", "three", MemoryCategory::Conversation) + mem.store("a", "one", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "two", MemoryCategory::Daily, None) + .await + .unwrap(); + mem.store("c", "three", MemoryCategory::Conversation, None) .await .unwrap(); - let all = mem.list(None).await.unwrap(); + let all = mem.list(None, None).await.unwrap(); assert_eq!(all.len(), 3); } #[tokio::test] async fn sqlite_list_by_category() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "core1", MemoryCategory::Core).await.unwrap(); - mem.store("b", "core2", MemoryCategory::Core).await.unwrap(); - mem.store("c", "daily1", MemoryCategory::Daily) + mem.store("a", "core1", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("b", "core2", MemoryCategory::Core, None) + .await + .unwrap(); + mem.store("c", "daily1", MemoryCategory::Daily, None) .await .unwrap(); - let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); + let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap(); assert_eq!(core.len(), 2); - let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); + let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap(); assert_eq!(daily.len(), 1); } @@ -772,7 +833,7 @@ mod tests { { let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("persist", "I survive restarts", MemoryCategory::Core) + mem.store("persist", "I survive restarts", MemoryCategory::Core, None) .await .unwrap(); } @@ -795,7 +856,7 @@ mod tests { ]; for (i, cat) in categories.iter().enumerate() { - mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone()) + mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None) .await .unwrap(); } @@ -815,21 +876,28 @@ mod tests { "a", "Rust is a systems programming language", MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store( + "b", + "Python is great for scripting", + MemoryCategory::Core, + None, ) .await .unwrap(); - mem.store("b", "Python is great for scripting", MemoryCategory::Core) - .await - .unwrap(); mem.store( "c", "Rust and Rust and Rust everywhere", MemoryCategory::Core, + None, ) .await .unwrap(); - let results = mem.recall("Rust", 10).await.unwrap(); + let results = mem.recall("Rust", 10, None).await.unwrap(); assert!(results.len() >= 2); // All results should contain "Rust" for r in &results { @@ -844,17 +912,17 @@ mod tests { #[tokio::test] async fn fts5_multi_word_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "The quick brown fox jumps", MemoryCategory::Core) + mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "A lazy dog sleeps", MemoryCategory::Core) + mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None) .await .unwrap(); - mem.store("c", "The quick dog runs fast", MemoryCategory::Core) + mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("quick dog", 10).await.unwrap(); + let results = mem.recall("quick dog", 10, None).await.unwrap(); assert!(!results.is_empty()); // "The quick dog runs fast" matches both terms assert!(results[0].content.contains("quick")); @@ -863,16 +931,20 @@ mod tests { #[tokio::test] async fn recall_empty_query_returns_empty() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "data", MemoryCategory::Core).await.unwrap(); - let results = mem.recall("", 10).await.unwrap(); + mem.store("a", "data", MemoryCategory::Core, None) + .await + .unwrap(); + let results = mem.recall("", 10, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn recall_whitespace_query_returns_empty() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "data", MemoryCategory::Core).await.unwrap(); - let results = mem.recall(" ", 10).await.unwrap(); + mem.store("a", "data", MemoryCategory::Core, None) + .await + .unwrap(); + let results = mem.recall(" ", 10, None).await.unwrap(); assert!(results.is_empty()); } @@ -937,9 +1009,14 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_insert() { let (_tmp, mem) = temp_sqlite(); - mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "test_key", + "unique_searchterm_xyz", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); let conn = mem.conn.lock(); let count: i64 = conn @@ -955,9 +1032,14 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_delete() { let (_tmp, mem) = temp_sqlite(); - mem.store("del_key", "deletable_content_abc", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "del_key", + "deletable_content_abc", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); mem.forget("del_key").await.unwrap(); let conn = mem.conn.lock(); @@ -974,10 +1056,15 @@ mod tests { #[tokio::test] async fn fts5_syncs_on_update() { let (_tmp, mem) = temp_sqlite(); - mem.store("upd_key", "original_content_111", MemoryCategory::Core) - .await - .unwrap(); - mem.store("upd_key", "updated_content_222", MemoryCategory::Core) + mem.store( + "upd_key", + "original_content_111", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); + mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None) .await .unwrap(); @@ -1019,10 +1106,10 @@ mod tests { #[tokio::test] async fn reindex_rebuilds_fts() { let (_tmp, mem) = temp_sqlite(); - mem.store("r1", "reindex test alpha", MemoryCategory::Core) + mem.store("r1", "reindex test alpha", MemoryCategory::Core, None) .await .unwrap(); - mem.store("r2", "reindex test beta", MemoryCategory::Core) + mem.store("r2", "reindex test beta", MemoryCategory::Core, None) .await .unwrap(); @@ -1031,7 +1118,7 @@ mod tests { assert_eq!(count, 0); // FTS should still work after rebuild - let results = mem.recall("reindex", 10).await.unwrap(); + let results = mem.recall("reindex", 10, None).await.unwrap(); assert_eq!(results.len(), 2); } @@ -1045,12 +1132,13 @@ mod tests { &format!("k{i}"), &format!("common keyword item {i}"), MemoryCategory::Core, + None, ) .await .unwrap(); } - let results = mem.recall("common keyword", 5).await.unwrap(); + let results = mem.recall("common keyword", 5, None).await.unwrap(); assert!(results.len() <= 5); } @@ -1059,11 +1147,11 @@ mod tests { #[tokio::test] async fn recall_results_have_scores() { let (_tmp, mem) = temp_sqlite(); - mem.store("s1", "scored result test", MemoryCategory::Core) + mem.store("s1", "scored result test", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("scored", 10).await.unwrap(); + let results = mem.recall("scored", 10, None).await.unwrap(); assert!(!results.is_empty()); for r in &results { assert!(r.score.is_some(), "Expected score on result: {:?}", r.key); @@ -1075,11 +1163,11 @@ mod tests { #[tokio::test] async fn recall_with_quotes_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("q1", "He said hello world", MemoryCategory::Core) + mem.store("q1", "He said hello world", MemoryCategory::Core, None) .await .unwrap(); // Quotes in query should not crash FTS5 - let results = mem.recall("\"hello\"", 10).await.unwrap(); + let results = mem.recall("\"hello\"", 10, None).await.unwrap(); // May or may not match depending on FTS5 escaping, but must not error assert!(results.len() <= 10); } @@ -1087,31 +1175,34 @@ mod tests { #[tokio::test] async fn recall_with_asterisk_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a1", "wildcard test content", MemoryCategory::Core) + mem.store("a1", "wildcard test content", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("wild*", 10).await.unwrap(); + let results = mem.recall("wild*", 10, None).await.unwrap(); assert!(results.len() <= 10); } #[tokio::test] async fn recall_with_parentheses_in_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("p1", "function call test", MemoryCategory::Core) + mem.store("p1", "function call test", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("function()", 10).await.unwrap(); + let results = mem.recall("function()", 10, None).await.unwrap(); assert!(results.len() <= 10); } #[tokio::test] async fn recall_with_sql_injection_attempt() { let (_tmp, mem) = temp_sqlite(); - mem.store("safe", "normal content", MemoryCategory::Core) + mem.store("safe", "normal content", MemoryCategory::Core, None) .await .unwrap(); // Should not crash or leak data - let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap(); + let results = mem + .recall("'; DROP TABLE memories; --", 10, None) + .await + .unwrap(); assert!(results.len() <= 10); // Table should still exist assert_eq!(mem.count().await.unwrap(), 1); @@ -1122,7 +1213,9 @@ mod tests { #[tokio::test] async fn store_empty_content() { let (_tmp, mem) = temp_sqlite(); - mem.store("empty", "", MemoryCategory::Core).await.unwrap(); + mem.store("empty", "", MemoryCategory::Core, None) + .await + .unwrap(); let entry = mem.get("empty").await.unwrap().unwrap(); assert_eq!(entry.content, ""); } @@ -1130,7 +1223,7 @@ mod tests { #[tokio::test] async fn store_empty_key() { let (_tmp, mem) = temp_sqlite(); - mem.store("", "content for empty key", MemoryCategory::Core) + mem.store("", "content for empty key", MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("").await.unwrap().unwrap(); @@ -1141,7 +1234,7 @@ mod tests { async fn store_very_long_content() { let (_tmp, mem) = temp_sqlite(); let long_content = "x".repeat(100_000); - mem.store("long", &long_content, MemoryCategory::Core) + mem.store("long", &long_content, MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("long").await.unwrap().unwrap(); @@ -1151,9 +1244,14 @@ mod tests { #[tokio::test] async fn store_unicode_and_emoji() { let (_tmp, mem) = temp_sqlite(); - mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "emoji_key_🦀", + "こんにちは 🚀 Ñoño", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap(); assert_eq!(entry.content, "こんにちは 🚀 Ñoño"); } @@ -1162,7 +1260,7 @@ mod tests { async fn store_content_with_newlines_and_tabs() { let (_tmp, mem) = temp_sqlite(); let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph"; - mem.store("whitespace", content, MemoryCategory::Core) + mem.store("whitespace", content, MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("whitespace").await.unwrap().unwrap(); @@ -1174,11 +1272,11 @@ mod tests { #[tokio::test] async fn recall_single_character_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "x marks the spot", MemoryCategory::Core) + mem.store("a", "x marks the spot", MemoryCategory::Core, None) .await .unwrap(); // Single char may not match FTS5 but LIKE fallback should work - let results = mem.recall("x", 10).await.unwrap(); + let results = mem.recall("x", 10, None).await.unwrap(); // Should not crash; may or may not find results assert!(results.len() <= 10); } @@ -1186,23 +1284,23 @@ mod tests { #[tokio::test] async fn recall_limit_zero() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "some content", MemoryCategory::Core) + mem.store("a", "some content", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("some", 0).await.unwrap(); + let results = mem.recall("some", 0, None).await.unwrap(); assert!(results.is_empty()); } #[tokio::test] async fn recall_limit_one() { let (_tmp, mem) = temp_sqlite(); - mem.store("a", "matching content alpha", MemoryCategory::Core) + mem.store("a", "matching content alpha", MemoryCategory::Core, None) .await .unwrap(); - mem.store("b", "matching content beta", MemoryCategory::Core) + mem.store("b", "matching content beta", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("matching content", 1).await.unwrap(); + let results = mem.recall("matching content", 1, None).await.unwrap(); assert_eq!(results.len(), 1); } @@ -1213,21 +1311,22 @@ mod tests { "rust_preferences", "User likes systems programming", MemoryCategory::Core, + None, ) .await .unwrap(); // "rust" appears in key but not content — LIKE fallback checks key too - let results = mem.recall("rust", 10).await.unwrap(); + let results = mem.recall("rust", 10, None).await.unwrap(); assert!(!results.is_empty(), "Should match by key"); } #[tokio::test] async fn recall_unicode_query() { let (_tmp, mem) = temp_sqlite(); - mem.store("jp", "日本語のテスト", MemoryCategory::Core) + mem.store("jp", "日本語のテスト", MemoryCategory::Core, None) .await .unwrap(); - let results = mem.recall("日本語", 10).await.unwrap(); + let results = mem.recall("日本語", 10, None).await.unwrap(); assert!(!results.is_empty()); } @@ -1238,7 +1337,9 @@ mod tests { let tmp = TempDir::new().unwrap(); { let mem = SqliteMemory::new(tmp.path()).unwrap(); - mem.store("k1", "v1", MemoryCategory::Core).await.unwrap(); + mem.store("k1", "v1", MemoryCategory::Core, None) + .await + .unwrap(); } // Open again — init_schema runs again on existing DB let mem2 = SqliteMemory::new(tmp.path()).unwrap(); @@ -1246,7 +1347,9 @@ mod tests { assert!(entry.is_some()); assert_eq!(entry.unwrap().content, "v1"); // Store more data — should work fine - mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap(); + mem2.store("k2", "v2", MemoryCategory::Daily, None) + .await + .unwrap(); assert_eq!(mem2.count().await.unwrap(), 2); } @@ -1264,11 +1367,16 @@ mod tests { #[tokio::test] async fn forget_then_recall_no_ghost_results() { let (_tmp, mem) = temp_sqlite(); - mem.store("ghost", "phantom memory content", MemoryCategory::Core) - .await - .unwrap(); + mem.store( + "ghost", + "phantom memory content", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); mem.forget("ghost").await.unwrap(); - let results = mem.recall("phantom memory", 10).await.unwrap(); + let results = mem.recall("phantom memory", 10, None).await.unwrap(); assert!( results.is_empty(), "Deleted memory should not appear in recall" @@ -1278,11 +1386,11 @@ mod tests { #[tokio::test] async fn forget_and_re_store_same_key() { let (_tmp, mem) = temp_sqlite(); - mem.store("cycle", "version 1", MemoryCategory::Core) + mem.store("cycle", "version 1", MemoryCategory::Core, None) .await .unwrap(); mem.forget("cycle").await.unwrap(); - mem.store("cycle", "version 2", MemoryCategory::Core) + mem.store("cycle", "version 2", MemoryCategory::Core, None) .await .unwrap(); let entry = mem.get("cycle").await.unwrap().unwrap(); @@ -1302,14 +1410,14 @@ mod tests { #[tokio::test] async fn reindex_twice_is_safe() { let (_tmp, mem) = temp_sqlite(); - mem.store("r1", "reindex data", MemoryCategory::Core) + mem.store("r1", "reindex data", MemoryCategory::Core, None) .await .unwrap(); mem.reindex().await.unwrap(); let count = mem.reindex().await.unwrap(); assert_eq!(count, 0); // Noop embedder → nothing to re-embed // Data should still be intact - let results = mem.recall("reindex", 10).await.unwrap(); + let results = mem.recall("reindex", 10, None).await.unwrap(); assert_eq!(results.len(), 1); } @@ -1363,18 +1471,28 @@ mod tests { #[tokio::test] async fn list_custom_category() { let (_tmp, mem) = temp_sqlite(); - mem.store("c1", "custom1", MemoryCategory::Custom("project".into())) - .await - .unwrap(); - mem.store("c2", "custom2", MemoryCategory::Custom("project".into())) - .await - .unwrap(); - mem.store("c3", "other", MemoryCategory::Core) + mem.store( + "c1", + "custom1", + MemoryCategory::Custom("project".into()), + None, + ) + .await + .unwrap(); + mem.store( + "c2", + "custom2", + MemoryCategory::Custom("project".into()), + None, + ) + .await + .unwrap(); + mem.store("c3", "other", MemoryCategory::Core, None) .await .unwrap(); let project = mem - .list(Some(&MemoryCategory::Custom("project".into()))) + .list(Some(&MemoryCategory::Custom("project".into())), None) .await .unwrap(); assert_eq!(project.len(), 2); @@ -1383,7 +1501,122 @@ mod tests { #[tokio::test] async fn list_empty_db() { let (_tmp, mem) = temp_sqlite(); - let all = mem.list(None).await.unwrap(); + let all = mem.list(None, None).await.unwrap(); assert!(all.is_empty()); } + + // ── Session isolation ───────────────────────────────────────── + + #[tokio::test] + async fn store_and_recall_with_session_id() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k3", "no session fact", MemoryCategory::Core, None) + .await + .unwrap(); + + // Recall with session-a filter returns only session-a entry + let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + assert_eq!(results[0].session_id.as_deref(), Some("sess-a")); + } + + #[tokio::test] + async fn recall_no_session_filter_returns_all() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k3", "gamma fact", MemoryCategory::Core, None) + .await + .unwrap(); + + // Recall without session filter returns all matching entries + let results = mem.recall("fact", 10, None).await.unwrap(); + assert_eq!(results.len(), 3); + } + + #[tokio::test] + async fn cross_session_recall_isolation() { + let (_tmp, mem) = temp_sqlite(); + mem.store( + "secret", + "session A secret data", + MemoryCategory::Core, + Some("sess-a"), + ) + .await + .unwrap(); + + // Session B cannot see session A data + let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap(); + assert!(results.is_empty()); + + // Session A can see its own data + let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 1); + } + + #[tokio::test] + async fn list_with_session_filter() { + let (_tmp, mem) = temp_sqlite(); + mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a")) + .await + .unwrap(); + mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a")) + .await + .unwrap(); + mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b")) + .await + .unwrap(); + mem.store("k4", "none1", MemoryCategory::Core, None) + .await + .unwrap(); + + // List with session-a filter + let results = mem.list(None, Some("sess-a")).await.unwrap(); + assert_eq!(results.len(), 2); + assert!(results + .iter() + .all(|e| e.session_id.as_deref() == Some("sess-a"))); + + // List with session-a + category filter + let results = mem + .list(Some(&MemoryCategory::Core), Some("sess-a")) + .await + .unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + } + + #[tokio::test] + async fn schema_migration_idempotent_on_reopen() { + let tmp = TempDir::new().unwrap(); + + // First open: creates schema + migration + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x")) + .await + .unwrap(); + } + + // Second open: migration runs again but is idempotent + { + let mem = SqliteMemory::new(tmp.path()).unwrap(); + let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap(); + assert_eq!(results.len(), 1); + assert_eq!(results[0].key, "k1"); + assert_eq!(results[0].session_id.as_deref(), Some("sess-x")); + } + } } diff --git a/src/memory/traits.rs b/src/memory/traits.rs index 72e120e..bf8c021 100644 --- a/src/memory/traits.rs +++ b/src/memory/traits.rs @@ -44,18 +44,32 @@ pub trait Memory: Send + Sync { /// Backend name fn name(&self) -> &str; - /// Store a memory entry - async fn store(&self, key: &str, content: &str, category: MemoryCategory) - -> anyhow::Result<()>; + /// Store a memory entry, optionally scoped to a session + async fn store( + &self, + key: &str, + content: &str, + category: MemoryCategory, + session_id: Option<&str>, + ) -> anyhow::Result<()>; - /// Recall memories matching a query (keyword search) - async fn recall(&self, query: &str, limit: usize) -> anyhow::Result>; + /// Recall memories matching a query (keyword search), optionally scoped to a session + async fn recall( + &self, + query: &str, + limit: usize, + session_id: Option<&str>, + ) -> anyhow::Result>; /// Get a specific memory by key async fn get(&self, key: &str) -> anyhow::Result>; - /// List all memory keys, optionally filtered by category - async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result>; + /// List all memory keys, optionally filtered by category and/or session + async fn list( + &self, + category: Option<&MemoryCategory>, + session_id: Option<&str>, + ) -> anyhow::Result>; /// Remove a memory by key async fn forget(&self, key: &str) -> anyhow::Result; diff --git a/src/migration.rs b/src/migration.rs index f217030..8a83262 100644 --- a/src/migration.rs +++ b/src/migration.rs @@ -95,7 +95,9 @@ async fn migrate_openclaw_memory( stats.renamed_conflicts += 1; } - memory.store(&key, &entry.content, entry.category).await?; + memory + .store(&key, &entry.content, entry.category, None) + .await?; stats.imported += 1; } @@ -488,7 +490,7 @@ mod tests { // Existing target memory let target_mem = SqliteMemory::new(target.path()).unwrap(); target_mem - .store("k", "new value", MemoryCategory::Core) + .store("k", "new value", MemoryCategory::Core, None) .await .unwrap(); @@ -510,7 +512,7 @@ mod tests { .await .unwrap(); - let all = target_mem.list(None).await.unwrap(); + let all = target_mem.list(None, None).await.unwrap(); assert!(all.iter().any(|e| e.key == "k" && e.content == "new value")); assert!(all .iter() diff --git a/src/observability/log.rs b/src/observability/log.rs index 9e3d062..b932fe0 100644 --- a/src/observability/log.rs +++ b/src/observability/log.rs @@ -48,9 +48,10 @@ impl Observer for LogObserver { ObserverEvent::AgentEnd { duration, tokens_used, + cost_usd, } => { let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX); - info!(duration_ms = ms, tokens = ?tokens_used, "agent.end"); + info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end"); } ObserverEvent::ToolCallStart { tool } => { info!(tool = %tool, "tool.start"); @@ -133,10 +134,12 @@ mod tests { obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::from_millis(500), tokens_used: Some(100), + cost_usd: Some(0.0015), }); obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::ZERO, tokens_used: None, + cost_usd: None, }); obs.record_event(&ObserverEvent::ToolCallStart { tool: "shell".into(), diff --git a/src/observability/noop.rs b/src/observability/noop.rs index 1189490..004af21 100644 --- a/src/observability/noop.rs +++ b/src/observability/noop.rs @@ -48,10 +48,12 @@ mod tests { obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::from_millis(100), tokens_used: Some(42), + cost_usd: Some(0.001), }); obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::ZERO, tokens_used: None, + cost_usd: None, }); obs.record_event(&ObserverEvent::ToolCallStart { tool: "shell".into(), diff --git a/src/observability/otel.rs b/src/observability/otel.rs index 5e0c37e..ae4932d 100644 --- a/src/observability/otel.rs +++ b/src/observability/otel.rs @@ -227,6 +227,7 @@ impl Observer for OtelObserver { ObserverEvent::AgentEnd { duration, tokens_used, + cost_usd, } => { let secs = duration.as_secs_f64(); let start_time = SystemTime::now() @@ -243,6 +244,9 @@ impl Observer for OtelObserver { if let Some(t) = tokens_used { span.set_attribute(KeyValue::new("tokens_used", *t as i64)); } + if let Some(c) = cost_usd { + span.set_attribute(KeyValue::new("cost_usd", *c)); + } span.end(); self.agent_duration.record(secs, &[]); @@ -394,10 +398,12 @@ mod tests { obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::from_millis(500), tokens_used: Some(100), + cost_usd: Some(0.0015), }); obs.record_event(&ObserverEvent::AgentEnd { duration: Duration::ZERO, tokens_used: None, + cost_usd: None, }); obs.record_event(&ObserverEvent::ToolCallStart { tool: "shell".into(), diff --git a/src/observability/traits.rs b/src/observability/traits.rs index ca62caf..d978304 100644 --- a/src/observability/traits.rs +++ b/src/observability/traits.rs @@ -27,6 +27,7 @@ pub enum ObserverEvent { AgentEnd { duration: Duration, tokens_used: Option, + cost_usd: Option, }, /// A tool call is about to be executed. ToolCallStart { diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs index 20c3baa..0422e45 100644 --- a/src/onboard/wizard.rs +++ b/src/onboard/wizard.rs @@ -106,6 +106,7 @@ pub fn run_wizard() -> Result { } else { Some(api_key) }, + api_url: None, default_provider: Some(provider), default_model: Some(model), default_temperature: 0.7, @@ -284,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig { #[allow(clippy::too_many_lines)] pub fn run_quick_setup( - api_key: Option<&str>, + credential_override: Option<&str>, provider: Option<&str>, memory_backend: Option<&str>, ) -> Result { @@ -318,7 +319,8 @@ pub fn run_quick_setup( let config = Config { workspace_dir: workspace_dir.clone(), config_path: config_path.clone(), - api_key: api_key.map(String::from), + api_key: credential_override.map(String::from), + api_url: None, default_provider: Some(provider_name.clone()), default_model: Some(model.clone()), default_temperature: 0.7, @@ -377,7 +379,7 @@ pub fn run_quick_setup( println!( " {} API Key: {}", style("✓").green().bold(), - if api_key.is_some() { + if credential_override.is_some() { style("set").green() } else { style("not set (use --api-key or edit config.toml)").yellow() @@ -426,7 +428,7 @@ pub fn run_quick_setup( ); println!(); println!(" {}", style("Next steps:").white().bold()); - if api_key.is_none() { + if credential_override.is_none() { println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\""); println!(" 2. Or edit: ~/.zeroclaw/config.toml"); println!(" 3. Chat: zeroclaw agent -m \"Hello!\""); @@ -2269,14 +2271,11 @@ fn setup_memory() -> Result { let backend = backend_key_from_choice(choice); let profile = memory_backend_profile(backend); - let auto_save = if !profile.auto_save_default { - false - } else { - Confirm::new() + let auto_save = profile.auto_save_default + && Confirm::new() .with_prompt(" Auto-save conversations to memory?") .default(true) - .interact()? - }; + .interact()?; println!( " {} Memory: {} (auto-save: {})", @@ -2587,6 +2586,7 @@ fn setup_channels() -> Result { guild_id: if guild.is_empty() { None } else { Some(guild) }, allowed_users, listen_to_bots: false, + mention_only: false, }); } 2 => { @@ -2799,22 +2799,14 @@ fn setup_channels() -> Result { .header("Authorization", format!("Bearer {access_token_clone}")) .send()?; let ok = resp.status().is_success(); - let data: serde_json::Value = resp.json().unwrap_or_default(); - let user_id = data - .get("user_id") - .and_then(serde_json::Value::as_str) - .unwrap_or("unknown") - .to_string(); - Ok::<_, reqwest::Error>((ok, user_id)) + Ok::<_, reqwest::Error>(ok) }) .join(); match thread_result { - Ok(Ok((true, user_id))) => { - println!( - "\r {} Connected as {user_id} ", - style("✅").green().bold() - ); - } + Ok(Ok(true)) => println!( + "\r {} Connection verified ", + style("✅").green().bold() + ), _ => { println!( "\r {} Connection failed — check homeserver URL and token", @@ -3779,15 +3771,7 @@ fn print_summary(config: &Config) { ); // Secrets - println!( - " {} Secrets: {}", - style("🔒").cyan(), - if config.secrets.encrypt { - style("encrypted").green().to_string() - } else { - style("plaintext").yellow().to_string() - } - ); + println!(" {} Secrets: configured", style("🔒").cyan()); // Gateway println!( diff --git a/src/peripherals/arduino_flash.rs b/src/peripherals/arduino_flash.rs index 8aaf287..7bc53f5 100644 --- a/src/peripherals/arduino_flash.rs +++ b/src/peripherals/arduino_flash.rs @@ -38,6 +38,10 @@ pub fn ensure_arduino_cli() -> Result<()> { anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/"); } println!("arduino-cli installed."); + if !arduino_cli_available() { + anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH."); + } + return Ok(()); } #[cfg(target_os = "linux")] @@ -54,11 +58,6 @@ pub fn ensure_arduino_cli() -> Result<()> { println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/"); anyhow::bail!("arduino-cli not installed."); } - - if !arduino_cli_available() { - anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH."); - } - Ok(()) } /// Ensure arduino:avr core is installed. diff --git a/src/peripherals/serial.rs b/src/peripherals/serial.rs index 05d0bae..2bcec56 100644 --- a/src/peripherals/serial.rs +++ b/src/peripherals/serial.rs @@ -112,6 +112,7 @@ pub struct SerialPeripheral { impl SerialPeripheral { /// Create and connect to a serial peripheral. + #[allow(clippy::unused_async)] pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result { let path = config .path diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs index 4216853..1f45c7e 100644 --- a/src/providers/anthropic.rs +++ b/src/providers/anthropic.rs @@ -106,17 +106,17 @@ struct NativeContentIn { } impl AnthropicProvider { - pub fn new(api_key: Option<&str>) -> Self { - Self::with_base_url(api_key, None) + pub fn new(credential: Option<&str>) -> Self { + Self::with_base_url(credential, None) } - pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self { + pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self { let base_url = base_url .map(|u| u.trim_end_matches('/')) .unwrap_or("https://api.anthropic.com") .to_string(); Self { - credential: api_key + credential: credential .map(str::trim) .filter(|k| !k.is_empty()) .map(ToString::to_string), @@ -410,9 +410,9 @@ mod tests { #[test] fn creates_with_key() { - let p = AnthropicProvider::new(Some("sk-ant-test123")); + let p = AnthropicProvider::new(Some("anthropic-test-credential")); assert!(p.credential.is_some()); - assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); + assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential")); assert_eq!(p.base_url, "https://api.anthropic.com"); } @@ -431,17 +431,19 @@ mod tests { #[test] fn creates_with_whitespace_key() { - let p = AnthropicProvider::new(Some(" sk-ant-test123 ")); + let p = AnthropicProvider::new(Some(" anthropic-test-credential ")); assert!(p.credential.is_some()); - assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); + assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential")); } #[test] fn creates_with_custom_base_url() { - let p = - AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com")); + let p = AnthropicProvider::with_base_url( + Some("anthropic-credential"), + Some("https://api.example.com"), + ); assert_eq!(p.base_url, "https://api.example.com"); - assert_eq!(p.credential.as_deref(), Some("sk-ant-test")); + assert_eq!(p.credential.as_deref(), Some("anthropic-credential")); } #[test] diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs index ee1c588..d17d309 100644 --- a/src/providers/compatible.rs +++ b/src/providers/compatible.rs @@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize}; pub struct OpenAiCompatibleProvider { pub(crate) name: String, pub(crate) base_url: String, - pub(crate) api_key: Option, + pub(crate) credential: Option, pub(crate) auth_header: AuthStyle, /// When false, do not fall back to /v1/responses on chat completions 404. /// GLM/Zhipu does not support the responses API. @@ -37,11 +37,16 @@ pub enum AuthStyle { } impl OpenAiCompatibleProvider { - pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self { + pub fn new( + name: &str, + base_url: &str, + credential: Option<&str>, + auth_style: AuthStyle, + ) -> Self { Self { name: name.to_string(), base_url: base_url.trim_end_matches('/').to_string(), - api_key: api_key.map(ToString::to_string), + credential: credential.map(ToString::to_string), auth_header: auth_style, supports_responses_fallback: true, client: Client::builder() @@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider { pub fn new_no_responses_fallback( name: &str, base_url: &str, - api_key: Option<&str>, + credential: Option<&str>, auth_style: AuthStyle, ) -> Self { Self { name: name.to_string(), base_url: base_url.trim_end_matches('/').to_string(), - api_key: api_key.map(ToString::to_string), + credential: credential.map(ToString::to_string), auth_header: auth_style, supports_responses_fallback: false, client: Client::builder() @@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider { fn apply_auth_header( &self, req: reqwest::RequestBuilder, - api_key: &str, + credential: &str, ) -> reqwest::RequestBuilder { match &self.auth_header { - AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")), - AuthStyle::XApiKey => req.header("x-api-key", api_key), - AuthStyle::Custom(header) => req.header(header, api_key), + AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")), + AuthStyle::XApiKey => req.header("x-api-key", credential), + AuthStyle::Custom(header) => req.header(header, credential), } } async fn chat_via_responses( &self, - api_key: &str, + credential: &str, system_prompt: Option<&str>, message: &str, model: &str, @@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider { let url = self.responses_url(); let response = self - .apply_auth_header(self.client.post(&url).json(&request), api_key) + .apply_auth_header(self.client.post(&url).json(&request), credential) .send() .await?; @@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", self.name @@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider { let url = self.chat_completions_url(); let response = self - .apply_auth_header(self.client.post(&url).json(&request), api_key) + .apply_auth_header(self.client.post(&url).json(&request), credential) .send() .await?; @@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider { if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { return self - .chat_via_responses(api_key, system_prompt, message, model) + .chat_via_responses(credential, system_prompt, message, model) .await .map_err(|responses_err| { anyhow::anyhow!( @@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", self.name @@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider { let url = self.chat_completions_url(); let response = self - .apply_auth_header(self.client.post(&url).json(&request), api_key) + .apply_auth_header(self.client.post(&url).json(&request), credential) .send() .await?; @@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider { if let Some(user_msg) = last_user { return self .chat_via_responses( - api_key, + credential, system.map(|m| m.content.as_str()), &user_msg.content, model, @@ -791,16 +796,20 @@ mod tests { #[test] fn creates_with_key() { - let p = make_provider("venice", "https://api.venice.ai", Some("vn-key")); + let p = make_provider( + "venice", + "https://api.venice.ai", + Some("venice-test-credential"), + ); assert_eq!(p.name, "venice"); assert_eq!(p.base_url, "https://api.venice.ai"); - assert_eq!(p.api_key.as_deref(), Some("vn-key")); + assert_eq!(p.credential.as_deref(), Some("venice-test-credential")); } #[test] fn creates_without_key() { let p = make_provider("test", "https://example.com", None); - assert!(p.api_key.is_none()); + assert!(p.credential.is_none()); } #[test] @@ -894,6 +903,7 @@ mod tests { make_provider("Groq", "https://api.groq.com/openai", None), make_provider("Mistral", "https://api.mistral.ai", None), make_provider("xAI", "https://api.x.ai", None), + make_provider("Astrai", "https://as-trai.com/v1", None), ]; for p in providers { diff --git a/src/providers/copilot.rs b/src/providers/copilot.rs new file mode 100644 index 0000000..ab8eb3b --- /dev/null +++ b/src/providers/copilot.rs @@ -0,0 +1,705 @@ +//! GitHub Copilot provider with OAuth device-flow authentication. +//! +//! Authenticates via GitHub's device code flow (same as VS Code Copilot), +//! then exchanges the OAuth token for short-lived Copilot API keys. +//! Tokens are cached to disk and auto-refreshed. +//! +//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and +//! editor headers. This is the same approach used by LiteLLM, Codex CLI, +//! and other third-party Copilot integrations. The Copilot token endpoint is +//! private; there is no public OAuth scope or app registration for it. +//! GitHub could change or revoke this at any time, which would break all +//! third-party integrations simultaneously. + +use crate::providers::traits::{ + ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse, + Provider, ToolCall as ProviderToolCall, +}; +use crate::tools::ToolSpec; +use async_trait::async_trait; +use reqwest::Client; +use serde::{Deserialize, Serialize}; +use std::path::{Path, PathBuf}; +use std::sync::Arc; +use std::time::Duration; +use tokio::sync::Mutex; +use tracing::warn; + +/// GitHub OAuth client ID for Copilot (VS Code extension). +const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98"; +const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code"; +const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token"; +const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token"; +const DEFAULT_API: &str = "https://api.githubcopilot.com"; + +// ── Token types ────────────────────────────────────────────────── + +#[derive(Debug, Deserialize)] +struct DeviceCodeResponse { + device_code: String, + user_code: String, + verification_uri: String, + #[serde(default = "default_interval")] + interval: u64, + #[serde(default = "default_expires_in")] + expires_in: u64, +} + +fn default_interval() -> u64 { + 5 +} + +fn default_expires_in() -> u64 { + 900 +} + +#[derive(Debug, Deserialize)] +struct AccessTokenResponse { + access_token: Option, + error: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ApiKeyInfo { + token: String, + expires_at: i64, + #[serde(default)] + endpoints: Option, +} + +#[derive(Debug, Serialize, Deserialize)] +struct ApiEndpoints { + api: Option, +} + +struct CachedApiKey { + token: String, + api_endpoint: String, + expires_at: i64, +} + +// ── Chat completions types ─────────────────────────────────────── + +#[derive(Debug, Serialize)] +struct ApiChatRequest { + model: String, + messages: Vec, + temperature: f64, + #[serde(skip_serializing_if = "Option::is_none")] + tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + tool_choice: Option, +} + +#[derive(Debug, Serialize)] +struct ApiMessage { + role: String, + #[serde(skip_serializing_if = "Option::is_none")] + content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_call_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + tool_calls: Option>, +} + +#[derive(Debug, Serialize)] +struct NativeToolSpec { + #[serde(rename = "type")] + kind: String, + function: NativeToolFunctionSpec, +} + +#[derive(Debug, Serialize)] +struct NativeToolFunctionSpec { + name: String, + description: String, + parameters: serde_json::Value, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeToolCall { + #[serde(skip_serializing_if = "Option::is_none")] + id: Option, + #[serde(rename = "type", skip_serializing_if = "Option::is_none")] + kind: Option, + function: NativeFunctionCall, +} + +#[derive(Debug, Serialize, Deserialize)] +struct NativeFunctionCall { + name: String, + arguments: String, +} + +#[derive(Debug, Deserialize)] +struct ApiChatResponse { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + message: ResponseMessage, +} + +#[derive(Debug, Deserialize)] +struct ResponseMessage { + #[serde(default)] + content: Option, + #[serde(default)] + tool_calls: Option>, +} + +// ── Provider ───────────────────────────────────────────────────── + +/// GitHub Copilot provider with automatic OAuth and token refresh. +/// +/// On first use, prompts the user to visit github.com/login/device. +/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed +/// automatically. +pub struct CopilotProvider { + github_token: Option, + /// Mutex ensures only one caller refreshes tokens at a time, + /// preventing duplicate device flow prompts or redundant API calls. + refresh_lock: Arc>>, + http: Client, + token_dir: PathBuf, +} + +impl CopilotProvider { + pub fn new(github_token: Option<&str>) -> Self { + let token_dir = directories::ProjectDirs::from("", "", "zeroclaw") + .map(|dir| dir.config_dir().join("copilot")) + .unwrap_or_else(|| { + // Fall back to a user-specific temp directory to avoid + // shared-directory symlink attacks. + let user = std::env::var("USER") + .or_else(|_| std::env::var("USERNAME")) + .unwrap_or_else(|_| "unknown".to_string()); + std::env::temp_dir().join(format!("zeroclaw-copilot-{user}")) + }); + + if let Err(err) = std::fs::create_dir_all(&token_dir) { + warn!( + "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.", + token_dir + ); + } else { + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + + if let Err(err) = + std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700)) + { + warn!( + "Failed to set Copilot token directory permissions on {:?}: {err}", + token_dir + ); + } + } + } + + Self { + github_token: github_token + .filter(|token| !token.is_empty()) + .map(String::from), + refresh_lock: Arc::new(Mutex::new(None)), + http: Client::builder() + .timeout(Duration::from_secs(120)) + .connect_timeout(Duration::from_secs(10)) + .build() + .unwrap_or_else(|_| Client::new()), + token_dir, + } + } + + /// Required headers for Copilot API requests (editor identification). + const COPILOT_HEADERS: [(&str, &str); 4] = [ + ("Editor-Version", "vscode/1.85.1"), + ("Editor-Plugin-Version", "copilot/1.155.0"), + ("User-Agent", "GithubCopilot/1.155.0"), + ("Accept", "application/json"), + ]; + + fn convert_tools(tools: Option<&[ToolSpec]>) -> Option> { + tools.map(|items| { + items + .iter() + .map(|tool| NativeToolSpec { + kind: "function".to_string(), + function: NativeToolFunctionSpec { + name: tool.name.clone(), + description: tool.description.clone(), + parameters: tool.parameters.clone(), + }, + }) + .collect() + }) + } + + fn convert_messages(messages: &[ChatMessage]) -> Vec { + messages + .iter() + .map(|message| { + if message.role == "assistant" { + if let Ok(value) = serde_json::from_str::(&message.content) { + if let Some(tool_calls_value) = value.get("tool_calls") { + if let Ok(parsed_calls) = + serde_json::from_value::>(tool_calls_value.clone()) + { + let tool_calls = parsed_calls + .into_iter() + .map(|tool_call| NativeToolCall { + id: Some(tool_call.id), + kind: Some("function".to_string()), + function: NativeFunctionCall { + name: tool_call.name, + arguments: tool_call.arguments, + }, + }) + .collect::>(); + + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "assistant".to_string(), + content, + tool_call_id: None, + tool_calls: Some(tool_calls), + }; + } + } + } + } + + if message.role == "tool" { + if let Ok(value) = serde_json::from_str::(&message.content) { + let tool_call_id = value + .get("tool_call_id") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + let content = value + .get("content") + .and_then(serde_json::Value::as_str) + .map(ToString::to_string); + + return ApiMessage { + role: "tool".to_string(), + content, + tool_call_id, + tool_calls: None, + }; + } + } + + ApiMessage { + role: message.role.clone(), + content: Some(message.content.clone()), + tool_call_id: None, + tool_calls: None, + } + }) + .collect() + } + + /// Send a chat completions request with required Copilot headers. + async fn send_chat_request( + &self, + messages: Vec, + tools: Option<&[ToolSpec]>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let (token, endpoint) = self.get_api_key().await?; + let url = format!("{}/chat/completions", endpoint.trim_end_matches('/')); + + let native_tools = Self::convert_tools(tools); + let request = ApiChatRequest { + model: model.to_string(), + messages, + temperature, + tool_choice: native_tools.as_ref().map(|_| "auto".to_string()), + tools: native_tools, + }; + + let mut req = self + .http + .post(&url) + .header("Authorization", format!("Bearer {token}")) + .json(&request); + + for (header, value) in &Self::COPILOT_HEADERS { + req = req.header(*header, *value); + } + + let response = req.send().await?; + + if !response.status().is_success() { + return Err(super::api_error("GitHub Copilot", response).await); + } + + let api_response: ApiChatResponse = response.json().await?; + let choice = api_response + .choices + .into_iter() + .next() + .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?; + + let tool_calls = choice + .message + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tool_call| ProviderToolCall { + id: tool_call + .id + .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()), + name: tool_call.function.name, + arguments: tool_call.function.arguments, + }) + .collect(); + + Ok(ProviderChatResponse { + text: choice.message.content, + tool_calls, + }) + } + + /// Get a valid Copilot API key, refreshing or re-authenticating as needed. + /// Uses a Mutex to ensure only one caller refreshes at a time. + async fn get_api_key(&self) -> anyhow::Result<(String, String)> { + let mut cached = self.refresh_lock.lock().await; + + if let Some(cached_key) = cached.as_ref() { + if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at { + return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone())); + } + } + + if let Some(info) = self.load_api_key_from_disk().await { + if chrono::Utc::now().timestamp() + 120 < info.expires_at { + let endpoint = info + .endpoints + .as_ref() + .and_then(|e| e.api.clone()) + .unwrap_or_else(|| DEFAULT_API.to_string()); + let token = info.token; + + *cached = Some(CachedApiKey { + token: token.clone(), + api_endpoint: endpoint.clone(), + expires_at: info.expires_at, + }); + return Ok((token, endpoint)); + } + } + + let access_token = self.get_github_access_token().await?; + let api_key_info = self.exchange_for_api_key(&access_token).await?; + self.save_api_key_to_disk(&api_key_info).await; + + let endpoint = api_key_info + .endpoints + .as_ref() + .and_then(|e| e.api.clone()) + .unwrap_or_else(|| DEFAULT_API.to_string()); + + *cached = Some(CachedApiKey { + token: api_key_info.token.clone(), + api_endpoint: endpoint.clone(), + expires_at: api_key_info.expires_at, + }); + + Ok((api_key_info.token, endpoint)) + } + + /// Get a GitHub access token from config, cache, or device flow. + async fn get_github_access_token(&self) -> anyhow::Result { + if let Some(token) = &self.github_token { + return Ok(token.clone()); + } + + let access_token_path = self.token_dir.join("access-token"); + if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await { + let token = cached.trim(); + if !token.is_empty() { + return Ok(token.to_string()); + } + } + + let token = self.device_code_login().await?; + write_file_secure(&access_token_path, &token).await; + Ok(token) + } + + /// Run GitHub OAuth device code flow. + async fn device_code_login(&self) -> anyhow::Result { + let response: DeviceCodeResponse = self + .http + .post(GITHUB_DEVICE_CODE_URL) + .header("Accept", "application/json") + .json(&serde_json::json!({ + "client_id": GITHUB_CLIENT_ID, + "scope": "read:user" + })) + .send() + .await? + .error_for_status()? + .json() + .await?; + + let mut poll_interval = Duration::from_secs(response.interval.max(5)); + let expires_in = response.expires_in.max(1); + let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in); + + eprintln!( + "\nGitHub Copilot authentication is required.\n\ + Visit: {}\n\ + Code: {}\n\ + Waiting for authorization...\n", + response.verification_uri, response.user_code + ); + + while tokio::time::Instant::now() < expires_at { + tokio::time::sleep(poll_interval).await; + + let token_response: AccessTokenResponse = self + .http + .post(GITHUB_ACCESS_TOKEN_URL) + .header("Accept", "application/json") + .json(&serde_json::json!({ + "client_id": GITHUB_CLIENT_ID, + "device_code": response.device_code, + "grant_type": "urn:ietf:params:oauth:grant-type:device_code" + })) + .send() + .await? + .json() + .await?; + + if let Some(token) = token_response.access_token { + eprintln!("Authentication succeeded.\n"); + return Ok(token); + } + + match token_response.error.as_deref() { + Some("slow_down") => { + poll_interval += Duration::from_secs(5); + } + Some("authorization_pending") | None => {} + Some("expired_token") => { + anyhow::bail!("GitHub device authorization expired") + } + Some(error) => anyhow::bail!("GitHub auth failed: {error}"), + } + } + + anyhow::bail!("Timed out waiting for GitHub authorization") + } + + /// Exchange a GitHub access token for a Copilot API key. + async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result { + let mut request = self.http.get(GITHUB_API_KEY_URL); + for (header, value) in &Self::COPILOT_HEADERS { + request = request.header(*header, *value); + } + request = request.header("Authorization", format!("token {access_token}")); + + let response = request.send().await?; + + if !response.status().is_success() { + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + let sanitized = super::sanitize_api_error(&body); + + if status.as_u16() == 401 || status.as_u16() == 403 { + let access_token_path = self.token_dir.join("access-token"); + tokio::fs::remove_file(&access_token_path).await.ok(); + } + + anyhow::bail!( + "Failed to get Copilot API key ({status}): {sanitized}. \ + Ensure your GitHub account has an active Copilot subscription." + ); + } + + let info: ApiKeyInfo = response.json().await?; + Ok(info) + } + + async fn load_api_key_from_disk(&self) -> Option { + let path = self.token_dir.join("api-key.json"); + let data = tokio::fs::read_to_string(&path).await.ok()?; + serde_json::from_str(&data).ok() + } + + async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) { + let path = self.token_dir.join("api-key.json"); + if let Ok(json) = serde_json::to_string_pretty(info) { + write_file_secure(&path, &json).await; + } + } +} + +/// Write a file with 0600 permissions (owner read/write only). +/// Uses `spawn_blocking` to avoid blocking the async runtime. +async fn write_file_secure(path: &Path, content: &str) { + let path = path.to_path_buf(); + let content = content.to_string(); + + let result = tokio::task::spawn_blocking(move || { + #[cfg(unix)] + { + use std::io::Write; + use std::os::unix::fs::{OpenOptionsExt, PermissionsExt}; + + let mut file = std::fs::OpenOptions::new() + .write(true) + .create(true) + .truncate(true) + .mode(0o600) + .open(&path)?; + file.write_all(content.as_bytes())?; + + std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?; + Ok::<(), std::io::Error>(()) + } + #[cfg(not(unix))] + { + std::fs::write(&path, &content)?; + Ok::<(), std::io::Error>(()) + } + }) + .await; + + match result { + Ok(Ok(())) => {} + Ok(Err(err)) => warn!("Failed to write secure file: {err}"), + Err(err) => warn!("Failed to spawn blocking write: {err}"), + } +} + +#[async_trait] +impl Provider for CopilotProvider { + async fn chat_with_system( + &self, + system_prompt: Option<&str>, + message: &str, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let mut messages = Vec::new(); + if let Some(system) = system_prompt { + messages.push(ApiMessage { + role: "system".to_string(), + content: Some(system.to_string()), + tool_call_id: None, + tool_calls: None, + }); + } + messages.push(ApiMessage { + role: "user".to_string(), + content: Some(message.to_string()), + tool_call_id: None, + tool_calls: None, + }); + + let response = self + .send_chat_request(messages, None, model, temperature) + .await?; + Ok(response.text.unwrap_or_default()) + } + + async fn chat_with_history( + &self, + messages: &[ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let response = self + .send_chat_request(Self::convert_messages(messages), None, model, temperature) + .await?; + Ok(response.text.unwrap_or_default()) + } + + async fn chat( + &self, + request: ProviderChatRequest<'_>, + model: &str, + temperature: f64, + ) -> anyhow::Result { + self.send_chat_request( + Self::convert_messages(request.messages), + request.tools, + model, + temperature, + ) + .await + } + + fn supports_native_tools(&self) -> bool { + true + } + + async fn warmup(&self) -> anyhow::Result<()> { + let _ = self.get_api_key().await?; + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn new_without_token() { + let provider = CopilotProvider::new(None); + assert!(provider.github_token.is_none()); + } + + #[test] + fn new_with_token() { + let provider = CopilotProvider::new(Some("ghp_test")); + assert_eq!(provider.github_token.as_deref(), Some("ghp_test")); + } + + #[test] + fn empty_token_treated_as_none() { + let provider = CopilotProvider::new(Some("")); + assert!(provider.github_token.is_none()); + } + + #[tokio::test] + async fn cache_starts_empty() { + let provider = CopilotProvider::new(None); + let cached = provider.refresh_lock.lock().await; + assert!(cached.is_none()); + } + + #[test] + fn copilot_headers_include_required_fields() { + let headers = CopilotProvider::COPILOT_HEADERS; + assert!(headers + .iter() + .any(|(header, _)| *header == "Editor-Version")); + assert!(headers + .iter() + .any(|(header, _)| *header == "Editor-Plugin-Version")); + assert!(headers.iter().any(|(header, _)| *header == "User-Agent")); + } + + #[test] + fn default_interval_and_expiry() { + assert_eq!(default_interval(), 5); + assert_eq!(default_expires_in(), 900); + } + + #[test] + fn supports_native_tools() { + let provider = CopilotProvider::new(None); + assert!(provider.supports_native_tools()); + } +} diff --git a/src/providers/mod.rs b/src/providers/mod.rs index 86517d6..e18e789 100644 --- a/src/providers/mod.rs +++ b/src/providers/mod.rs @@ -1,5 +1,6 @@ pub mod anthropic; pub mod compatible; +pub mod copilot; pub mod gemini; pub mod ollama; pub mod openai; @@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize { /// Scrub known secret-like token prefixes from provider error strings. /// -/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`. +/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`, +/// `ghu_`, and `github_pat_`. pub fn scrub_secret_patterns(input: &str) -> String { - const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"]; + const PREFIXES: [&str; 7] = [ + "sk-", + "xoxb-", + "xoxp-", + "ghp_", + "gho_", + "ghu_", + "github_pat_", + ]; let mut scrubbed = input.to_string(); @@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E /// /// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens) /// followed by `ANTHROPIC_API_KEY` (for regular API keys). -fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option { - if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) { - return Some(key.to_string()); +fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option { + if let Some(raw_override) = credential_override { + let trimmed_override = raw_override.trim(); + if !trimmed_override.is_empty() { + return Some(trimmed_override.to_owned()); + } } let provider_env_candidates: Vec<&str> = match name { @@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option { "opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"], "vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"], "cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"], + "astrai" => vec!["ASTRAI_API_KEY"], _ => vec![], }; @@ -182,19 +196,28 @@ fn parse_custom_provider_url( } } -/// Factory: create the right provider from config -#[allow(clippy::too_many_lines)] +/// Factory: create the right provider from config (without custom URL) pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result> { - let resolved_key = resolve_api_key(name, api_key); - let key = resolved_key.as_deref(); + create_provider_with_url(name, api_key, None) +} + +/// Factory: create the right provider from config with optional custom base URL +#[allow(clippy::too_many_lines)] +pub fn create_provider_with_url( + name: &str, + api_key: Option<&str>, + api_url: Option<&str>, +) -> anyhow::Result> { + let resolved_credential = resolve_provider_credential(name, api_key); + #[allow(clippy::option_as_ref_deref)] + let key = resolved_credential.as_ref().map(String::as_str); match name { // ── Primary providers (custom implementations) ─────── "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))), "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))), "openai" => Ok(Box::new(openai::OpenAiProvider::new(key))), - // Ollama is a local service that doesn't use API keys. - // The api_key parameter is ignored to avoid it being misinterpreted as a base_url. - "ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))), + // Ollama uses api_url for custom base URL (e.g. remote Ollama instance) + "ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))), "gemini" | "google" | "google-gemini" => { Ok(Box::new(gemini::GeminiProvider::new(key))) } @@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( - "Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer, + "Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer, ))), "xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new( "xAI", "https://api.x.ai", key, AuthStyle::Bearer, @@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result Ok(Box::new(OpenAiCompatibleProvider::new( "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer, ))), - "copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new( - "GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer, - ))), - "nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new( - "NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer, + "copilot" | "github-copilot" => { + Ok(Box::new(copilot::CopilotProvider::new(api_key))) + }, + "lmstudio" | "lm-studio" => { + let lm_studio_key = api_key + .map(str::trim) + .filter(|value| !value.is_empty()) + .unwrap_or("lm-studio"); + Ok(Box::new(OpenAiCompatibleProvider::new( + "LM Studio", + "http://localhost:1234/v1", + Some(lm_studio_key), + AuthStyle::Bearer, + ))) + } + "nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new( + OpenAiCompatibleProvider::new( + "NVIDIA NIM", + "https://integrate.api.nvidia.com/v1", + key, + AuthStyle::Bearer, + ), + )), + + // ── AI inference routers ───────────────────────────── + "astrai" => Ok(Box::new(OpenAiCompatibleProvider::new( + "Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer, ))), // ── Bring Your Own Provider (custom URL) ─────────── @@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result, + api_url: Option<&str>, reliability: &crate::config::ReliabilityConfig, ) -> anyhow::Result> { let mut providers: Vec<(String, Box)> = Vec::new(); providers.push(( primary_name.to_string(), - create_provider(primary_name, api_key)?, + create_provider_with_url(primary_name, api_key, api_url)?, )); for fallback in &reliability.fallback_providers { @@ -340,21 +386,13 @@ pub fn create_resilient_provider( continue; } - if api_key.is_some() && fallback != "ollama" { - tracing::warn!( - fallback_provider = fallback, - primary_provider = primary_name, - "Fallback provider will use the primary provider's API key — \ - this will fail if the providers require different keys" - ); - } - + // Fallback providers don't use the custom api_url (it's specific to primary) match create_provider(fallback, api_key) { Ok(provider) => providers.push((fallback.clone(), provider)), - Err(e) => { + Err(_error) => { tracing::warn!( fallback_provider = fallback, - "Ignoring invalid fallback provider: {e}" + "Ignoring invalid fallback provider during initialization" ); } } @@ -377,12 +415,13 @@ pub fn create_resilient_provider( pub fn create_routed_provider( primary_name: &str, api_key: Option<&str>, + api_url: Option<&str>, reliability: &crate::config::ReliabilityConfig, model_routes: &[crate::config::ModelRouteConfig], default_model: &str, ) -> anyhow::Result> { if model_routes.is_empty() { - return create_resilient_provider(primary_name, api_key, reliability); + return create_resilient_provider(primary_name, api_key, api_url, reliability); } // Collect unique provider names needed @@ -396,12 +435,19 @@ pub fn create_routed_provider( // Create each provider (with its own resilience wrapper) let mut providers: Vec<(String, Box)> = Vec::new(); for name in &needed { - let key = model_routes + let routed_credential = model_routes .iter() .find(|r| &r.provider == name) - .and_then(|r| r.api_key.as_deref()) - .or(api_key); - match create_resilient_provider(name, key, reliability) { + .and_then(|r| { + r.api_key.as_ref().and_then(|raw_key| { + let trimmed_key = raw_key.trim(); + (!trimmed_key.is_empty()).then_some(trimmed_key) + }) + }); + let key = routed_credential.or(api_key); + // Only use api_url for the primary provider + let url = if name == primary_name { api_url } else { None }; + match create_resilient_provider(name, key, url, reliability) { Ok(provider) => providers.push((name.clone(), provider)), Err(e) => { if name == primary_name { @@ -409,7 +455,7 @@ pub fn create_routed_provider( } tracing::warn!( provider = name.as_str(), - "Ignoring routed provider that failed to create: {e}" + "Ignoring routed provider that failed to initialize" ); } } @@ -441,27 +487,27 @@ mod tests { use super::*; #[test] - fn resolve_api_key_prefers_explicit_argument() { - let resolved = resolve_api_key("openrouter", Some(" explicit-key ")); - assert_eq!(resolved.as_deref(), Some("explicit-key")); + fn resolve_provider_credential_prefers_explicit_argument() { + let resolved = resolve_provider_credential("openrouter", Some(" explicit-key ")); + assert_eq!(resolved, Some("explicit-key".to_string())); } // ── Primary providers ──────────────────────────────────── #[test] fn factory_openrouter() { - assert!(create_provider("openrouter", Some("sk-test")).is_ok()); + assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok()); assert!(create_provider("openrouter", None).is_ok()); } #[test] fn factory_anthropic() { - assert!(create_provider("anthropic", Some("sk-test")).is_ok()); + assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok()); } #[test] fn factory_openai() { - assert!(create_provider("openai", Some("sk-test")).is_ok()); + assert!(create_provider("openai", Some("provider-test-credential")).is_ok()); } #[test] @@ -556,6 +602,13 @@ mod tests { assert!(create_provider("dashscope-us", Some("key")).is_ok()); } + #[test] + fn factory_lmstudio() { + assert!(create_provider("lmstudio", Some("key")).is_ok()); + assert!(create_provider("lm-studio", Some("key")).is_ok()); + assert!(create_provider("lmstudio", None).is_ok()); + } + // ── Extended ecosystem ─────────────────────────────────── #[test] @@ -614,6 +667,13 @@ mod tests { assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok()); } + // ── AI inference routers ───────────────────────────────── + + #[test] + fn factory_astrai() { + assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok()); + } + // ── Custom / BYOP provider ───────────────────────────── #[test] @@ -761,17 +821,33 @@ mod tests { scheduler_retries: 2, }; - let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability); + let provider = create_resilient_provider( + "openrouter", + Some("provider-test-credential"), + None, + &reliability, + ); assert!(provider.is_ok()); } #[test] fn resilient_provider_errors_for_invalid_primary() { let reliability = crate::config::ReliabilityConfig::default(); - let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability); + let provider = create_resilient_provider( + "totally-invalid", + Some("provider-test-credential"), + None, + &reliability, + ); assert!(provider.is_err()); } + #[test] + fn ollama_with_custom_url() { + let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434")); + assert!(provider.is_ok()); + } + #[test] fn factory_all_providers_create_successfully() { let providers = [ @@ -794,6 +870,7 @@ mod tests { "qwen", "qwen-intl", "qwen-us", + "lmstudio", "groq", "mistral", "xai", @@ -888,7 +965,7 @@ mod tests { #[test] fn sanitize_preserves_unicode_boundaries() { - let input = format!("{} sk-abcdef123", "こんにちは".repeat(80)); + let input = format!("{} sk-abcdef123", "hello🙂".repeat(80)); let result = sanitize_api_error(&input); assert!(std::str::from_utf8(result.as_bytes()).is_ok()); assert!(!result.contains("sk-abcdef123")); @@ -900,4 +977,32 @@ mod tests { let result = sanitize_api_error(input); assert_eq!(result, input); } + + #[test] + fn scrub_github_personal_access_token() { + let input = "auth failed with token ghp_abc123def456"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "auth failed with token [REDACTED]"); + } + + #[test] + fn scrub_github_oauth_token() { + let input = "Bearer gho_1234567890abcdef"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "Bearer [REDACTED]"); + } + + #[test] + fn scrub_github_user_token() { + let input = "token ghu_sessiontoken123"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "token [REDACTED]"); + } + + #[test] + fn scrub_github_fine_grained_pat() { + let input = "failed: github_pat_11AABBC_xyzzy789"; + let result = scrub_secret_patterns(input); + assert_eq!(result, "failed: [REDACTED]"); + } } diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs index 8ecfb5a..e05f027 100644 --- a/src/providers/ollama.rs +++ b/src/providers/ollama.rs @@ -8,6 +8,8 @@ pub struct OllamaProvider { client: Client, } +// ─── Request Structures ─────────────────────────────────────────────────────── + #[derive(Debug, Serialize)] struct ChatRequest { model: String, @@ -27,6 +29,8 @@ struct Options { temperature: f64, } +// ─── Response Structures ────────────────────────────────────────────────────── + #[derive(Debug, Deserialize)] struct ApiChatResponse { message: ResponseMessage, @@ -34,9 +38,30 @@ struct ApiChatResponse { #[derive(Debug, Deserialize)] struct ResponseMessage { + #[serde(default)] content: String, + #[serde(default)] + tool_calls: Vec, + /// Some models return a "thinking" field with internal reasoning + #[serde(default)] + thinking: Option, } +#[derive(Debug, Deserialize)] +struct OllamaToolCall { + id: Option, + function: OllamaFunction, +} + +#[derive(Debug, Deserialize)] +struct OllamaFunction { + name: String, + #[serde(default)] + arguments: serde_json::Value, +} + +// ─── Implementation ─────────────────────────────────────────────────────────── + impl OllamaProvider { pub fn new(base_url: Option<&str>) -> Self { Self { @@ -45,12 +70,145 @@ impl OllamaProvider { .trim_end_matches('/') .to_string(), client: Client::builder() - .timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow + .timeout(std::time::Duration::from_secs(300)) .connect_timeout(std::time::Duration::from_secs(10)) .build() .unwrap_or_else(|_| Client::new()), } } + + /// Send a request to Ollama and get the parsed response + async fn send_request( + &self, + messages: Vec, + model: &str, + temperature: f64, + ) -> anyhow::Result { + let request = ChatRequest { + model: model.to_string(), + messages, + stream: false, + options: Options { temperature }, + }; + + let url = format!("{}/api/chat", self.base_url); + + tracing::debug!( + "Ollama request: url={} model={} message_count={} temperature={}", + url, + model, + request.messages.len(), + temperature + ); + + let response = self.client.post(&url).json(&request).send().await?; + let status = response.status(); + tracing::debug!("Ollama response status: {}", status); + + let body = response.bytes().await?; + tracing::debug!("Ollama response body length: {} bytes", body.len()); + + if !status.is_success() { + let raw = String::from_utf8_lossy(&body); + let sanitized = super::sanitize_api_error(&raw); + tracing::error!( + "Ollama error response: status={} body_excerpt={}", + status, + sanitized + ); + anyhow::bail!( + "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)", + status, + sanitized + ); + } + + let chat_response: ApiChatResponse = match serde_json::from_slice(&body) { + Ok(r) => r, + Err(e) => { + let raw = String::from_utf8_lossy(&body); + let sanitized = super::sanitize_api_error(&raw); + tracing::error!( + "Ollama response deserialization failed: {e}. body_excerpt={}", + sanitized + ); + anyhow::bail!("Failed to parse Ollama response: {e}"); + } + }; + + Ok(chat_response) + } + + /// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs + /// + /// Handles quirky model behavior where tool calls are wrapped: + /// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}` + /// - `{"name": "tool.shell", "arguments": {...}}` + fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String { + let formatted_calls: Vec = tool_calls + .iter() + .map(|tc| { + let (tool_name, tool_args) = self.extract_tool_name_and_args(tc); + + // Arguments must be a JSON string for parse_tool_calls compatibility + let args_str = + serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string()); + + serde_json::json!({ + "id": tc.id, + "type": "function", + "function": { + "name": tool_name, + "arguments": args_str + } + }) + }) + .collect(); + + serde_json::json!({ + "content": "", + "tool_calls": formatted_calls + }) + .to_string() + } + + /// Extract the actual tool name and arguments from potentially nested structures + fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) { + let name = &tc.function.name; + let args = &tc.function.arguments; + + // Pattern 1: Nested tool_call wrapper (various malformed versions) + // {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}} + // {"name": "tool_call>") + || name.starts_with("tool_call<") + { + if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) { + let nested_args = args + .get("arguments") + .cloned() + .unwrap_or(serde_json::json!({})); + tracing::debug!( + "Unwrapped nested tool call: {} -> {} with args {:?}", + name, + nested_name, + nested_args + ); + return (nested_name.to_string(), nested_args); + } + } + + // Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.) + if let Some(stripped) = name.strip_prefix("tool.") { + return (stripped.to_string(), args.clone()); + } + + // Pattern 3: Normal tool call + (name.clone(), args.clone()) + } } #[async_trait] @@ -76,27 +234,96 @@ impl Provider for OllamaProvider { content: message.to_string(), }); - let request = ChatRequest { - model: model.to_string(), - messages, - stream: false, - options: Options { temperature }, - }; + let response = self.send_request(messages, model, temperature).await?; - let url = format!("{}/api/chat", self.base_url); - - let response = self.client.post(&url).json(&request).send().await?; - - if !response.status().is_success() { - let err = super::api_error("Ollama", response).await; - anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)"); + // If model returned tool calls, format them for loop_.rs's parse_tool_calls + if !response.message.tool_calls.is_empty() { + tracing::debug!( + "Ollama returned {} tool call(s), formatting for loop parser", + response.message.tool_calls.len() + ); + return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); } - let chat_response: ApiChatResponse = response.json().await?; - Ok(chat_response.message.content) + // Plain text response + let content = response.message.content; + + // Handle edge case: model returned only "thinking" with no content or tool calls + if content.is_empty() { + if let Some(thinking) = &response.message.thinking { + tracing::warn!( + "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.", + if thinking.len() > 100 { &thinking[..100] } else { thinking } + ); + return Ok(format!( + "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", + if thinking.len() > 200 { &thinking[..200] } else { thinking } + )); + } + tracing::warn!("Ollama returned empty content with no tool calls"); + } + + Ok(content) + } + + async fn chat_with_history( + &self, + messages: &[crate::providers::ChatMessage], + model: &str, + temperature: f64, + ) -> anyhow::Result { + let api_messages: Vec = messages + .iter() + .map(|m| Message { + role: m.role.clone(), + content: m.content.clone(), + }) + .collect(); + + let response = self.send_request(api_messages, model, temperature).await?; + + // If model returned tool calls, format them for loop_.rs's parse_tool_calls + if !response.message.tool_calls.is_empty() { + tracing::debug!( + "Ollama returned {} tool call(s), formatting for loop parser", + response.message.tool_calls.len() + ); + return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls)); + } + + // Plain text response + let content = response.message.content; + + // Handle edge case: model returned only "thinking" with no content or tool calls + // This is a model quirk - it stopped after reasoning without producing output + if content.is_empty() { + if let Some(thinking) = &response.message.thinking { + tracing::warn!( + "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.", + if thinking.len() > 100 { &thinking[..100] } else { thinking } + ); + // Return a message indicating the model's thought process but no action + return Ok(format!( + "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?", + if thinking.len() > 200 { &thinking[..200] } else { thinking } + )); + } + tracing::warn!("Ollama returned empty content with no tool calls"); + } + + Ok(content) + } + + fn supports_native_tools(&self) -> bool { + // Return false since loop_.rs uses XML-style tool parsing via system prompt + // The model may return native tool_calls but we convert them to JSON format + // that parse_tool_calls() understands + false } } +// ─── Tests ──────────────────────────────────────────────────────────────────── + #[cfg(test)] mod tests { use super::*; @@ -125,46 +352,6 @@ mod tests { assert_eq!(p.base_url, ""); } - #[test] - fn request_serializes_with_system() { - let req = ChatRequest { - model: "llama3".to_string(), - messages: vec![ - Message { - role: "system".to_string(), - content: "You are ZeroClaw".to_string(), - }, - Message { - role: "user".to_string(), - content: "hello".to_string(), - }, - ], - stream: false, - options: Options { temperature: 0.7 }, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(json.contains("\"stream\":false")); - assert!(json.contains("llama3")); - assert!(json.contains("system")); - assert!(json.contains("\"temperature\":0.7")); - } - - #[test] - fn request_serializes_without_system() { - let req = ChatRequest { - model: "mistral".to_string(), - messages: vec![Message { - role: "user".to_string(), - content: "test".to_string(), - }], - stream: false, - options: Options { temperature: 0.0 }, - }; - let json = serde_json::to_string(&req).unwrap(); - assert!(!json.contains("\"role\":\"system\"")); - assert!(json.contains("mistral")); - } - #[test] fn response_deserializes() { let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; @@ -180,9 +367,98 @@ mod tests { } #[test] - fn response_with_multiline() { - let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#; + fn response_with_missing_content_defaults_to_empty() { + let json = r#"{"message":{"role":"assistant"}}"#; let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); - assert!(resp.message.content.contains("line1")); + assert!(resp.message.content.is_empty()); + } + + #[test] + fn response_with_thinking_field_extracts_content() { + let json = + r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert_eq!(resp.message.content, "hello"); + } + + #[test] + fn response_with_tool_calls_parses_correctly() { + let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#; + let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); + assert!(resp.message.content.is_empty()); + assert_eq!(resp.message.tool_calls.len(), 1); + assert_eq!(resp.message.tool_calls[0].function.name, "shell"); + } + + #[test] + fn extract_tool_name_handles_nested_tool_call() { + let provider = OllamaProvider::new(None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "tool_call".into(), + arguments: serde_json::json!({ + "name": "shell", + "arguments": {"command": "date"} + }), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "shell"); + assert_eq!(args.get("command").unwrap(), "date"); + } + + #[test] + fn extract_tool_name_handles_prefixed_name() { + let provider = OllamaProvider::new(None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "tool.shell".into(), + arguments: serde_json::json!({"command": "ls"}), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "shell"); + assert_eq!(args.get("command").unwrap(), "ls"); + } + + #[test] + fn extract_tool_name_handles_normal_call() { + let provider = OllamaProvider::new(None); + let tc = OllamaToolCall { + id: Some("call_123".into()), + function: OllamaFunction { + name: "file_read".into(), + arguments: serde_json::json!({"path": "/tmp/test"}), + }, + }; + let (name, args) = provider.extract_tool_name_and_args(&tc); + assert_eq!(name, "file_read"); + assert_eq!(args.get("path").unwrap(), "/tmp/test"); + } + + #[test] + fn format_tool_calls_produces_valid_json() { + let provider = OllamaProvider::new(None); + let tool_calls = vec![OllamaToolCall { + id: Some("call_abc".into()), + function: OllamaFunction { + name: "shell".into(), + arguments: serde_json::json!({"command": "date"}), + }, + }]; + + let formatted = provider.format_tool_calls_for_loop(&tool_calls); + let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap(); + + assert!(parsed.get("tool_calls").is_some()); + let calls = parsed.get("tool_calls").unwrap().as_array().unwrap(); + assert_eq!(calls.len(), 1); + + let func = calls[0].get("function").unwrap(); + assert_eq!(func.get("name").unwrap(), "shell"); + // arguments should be a string (JSON-encoded) + assert!(func.get("arguments").unwrap().is_string()); } } diff --git a/src/providers/openai.rs b/src/providers/openai.rs index ef67678..22b53ca 100644 --- a/src/providers/openai.rs +++ b/src/providers/openai.rs @@ -8,7 +8,7 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; pub struct OpenAiProvider { - api_key: Option, + credential: Option, client: Client, } @@ -110,9 +110,9 @@ struct NativeResponseMessage { } impl OpenAiProvider { - pub fn new(api_key: Option<&str>) -> Self { + pub fn new(credential: Option<&str>) -> Self { Self { - api_key: api_key.map(ToString::to_string), + credential: credential.map(ToString::to_string), client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -232,7 +232,7 @@ impl Provider for OpenAiProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") })?; @@ -259,7 +259,7 @@ impl Provider for OpenAiProvider { let response = self .client .post("https://api.openai.com/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .json(&request) .send() .await?; @@ -284,7 +284,7 @@ impl Provider for OpenAiProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") })?; @@ -300,7 +300,7 @@ impl Provider for OpenAiProvider { let response = self .client .post("https://api.openai.com/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .json(&native_request) .send() .await?; @@ -330,20 +330,20 @@ mod tests { #[test] fn creates_with_key() { - let p = OpenAiProvider::new(Some("sk-proj-abc123")); - assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123")); + let p = OpenAiProvider::new(Some("openai-test-credential")); + assert_eq!(p.credential.as_deref(), Some("openai-test-credential")); } #[test] fn creates_without_key() { let p = OpenAiProvider::new(None); - assert!(p.api_key.is_none()); + assert!(p.credential.is_none()); } #[test] fn creates_with_empty_key() { let p = OpenAiProvider::new(Some("")); - assert_eq!(p.api_key.as_deref(), Some("")); + assert_eq!(p.credential.as_deref(), Some("")); } #[tokio::test] diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs index 2896c07..b27bff4 100644 --- a/src/providers/openrouter.rs +++ b/src/providers/openrouter.rs @@ -8,7 +8,7 @@ use reqwest::Client; use serde::{Deserialize, Serialize}; pub struct OpenRouterProvider { - api_key: Option, + credential: Option, client: Client, } @@ -110,9 +110,9 @@ struct NativeResponseMessage { } impl OpenRouterProvider { - pub fn new(api_key: Option<&str>) -> Self { + pub fn new(credential: Option<&str>) -> Self { Self { - api_key: api_key.map(ToString::to_string), + credential: credential.map(ToString::to_string), client: Client::builder() .timeout(std::time::Duration::from_secs(120)) .connect_timeout(std::time::Duration::from_secs(10)) @@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider { async fn warmup(&self) -> anyhow::Result<()> { // Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool. // This prevents the first real chat request from timing out on cold start. - if let Some(api_key) = self.api_key.as_ref() { + if let Some(credential) = self.credential.as_ref() { self.client .get("https://openrouter.ai/api/v1/auth/key") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .send() .await? .error_for_status()?; @@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref() + let credential = self.credential.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; let mut messages = Vec::new(); @@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider { let response = self .client .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .header( "HTTP-Referer", "https://github.com/theonlyhennygod/zeroclaw", @@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref() + let credential = self.credential.as_ref() .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; let api_messages: Vec = messages @@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider { let response = self .client .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .header( "HTTP-Referer", "https://github.com/theonlyhennygod/zeroclaw", @@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." ) @@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider { let response = self .client .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .header( "HTTP-Referer", "https://github.com/theonlyhennygod/zeroclaw", @@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider { model: &str, temperature: f64, ) -> anyhow::Result { - let api_key = self.api_key.as_ref().ok_or_else(|| { + let credential = self.credential.as_ref().ok_or_else(|| { anyhow::anyhow!( "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." ) @@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider { let response = self .client .post("https://openrouter.ai/api/v1/chat/completions") - .header("Authorization", format!("Bearer {api_key}")) + .header("Authorization", format!("Bearer {credential}")) .header( "HTTP-Referer", "https://github.com/theonlyhennygod/zeroclaw", @@ -494,14 +494,17 @@ mod tests { #[test] fn creates_with_key() { - let provider = OpenRouterProvider::new(Some("sk-or-123")); - assert_eq!(provider.api_key.as_deref(), Some("sk-or-123")); + let provider = OpenRouterProvider::new(Some("openrouter-test-credential")); + assert_eq!( + provider.credential.as_deref(), + Some("openrouter-test-credential") + ); } #[test] fn creates_without_key() { let provider = OpenRouterProvider::new(None); - assert!(provider.api_key.is_none()); + assert!(provider.credential.is_none()); } #[tokio::test] diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs index 045f2c3..32cc0ca 100644 --- a/src/providers/reliable.rs +++ b/src/providers/reliable.rs @@ -144,8 +144,8 @@ impl Provider for ReliableProvider { async fn warmup(&self) -> anyhow::Result<()> { for (name, provider) in &self.providers { tracing::info!(provider = name, "Warming up provider connection pool"); - if let Err(e) = provider.warmup().await { - tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}"); + if provider.warmup().await.is_err() { + tracing::warn!(provider = name, "Warmup failed (non-fatal)"); } } Ok(()) @@ -186,8 +186,15 @@ impl Provider for ReliableProvider { let non_retryable = is_non_retryable(&e); let rate_limited = is_rate_limited(&e); + let failure_reason = if rate_limited { + "rate_limited" + } else if non_retryable { + "non_retryable" + } else { + "retryable" + }; failures.push(format!( - "{provider_name}/{current_model} attempt {}/{}: {e}", + "{provider_name}/{current_model} attempt {}/{}: {failure_reason}", attempt + 1, self.max_retries + 1 )); @@ -284,8 +291,15 @@ impl Provider for ReliableProvider { let non_retryable = is_non_retryable(&e); let rate_limited = is_rate_limited(&e); + let failure_reason = if rate_limited { + "rate_limited" + } else if non_retryable { + "non_retryable" + } else { + "retryable" + }; failures.push(format!( - "{provider_name}/{current_model} attempt {}/{}: {e}", + "{provider_name}/{current_model} attempt {}/{}: {failure_reason}", attempt + 1, self.max_retries + 1 )); diff --git a/src/providers/traits.rs b/src/providers/traits.rs index f43d099..380bbc5 100644 --- a/src/providers/traits.rs +++ b/src/providers/traits.rs @@ -193,6 +193,13 @@ pub enum StreamError { #[async_trait] pub trait Provider: Send + Sync { + /// Query provider capabilities. + /// + /// Default implementation returns minimal capabilities (no native tool calling). + /// Providers should override this to declare their actual capabilities. + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities::default() + } /// Simple one-shot chat (single user message, no explicit system prompt). /// /// This is the preferred API for non-agentic direct interactions. @@ -256,7 +263,7 @@ pub trait Provider: Send + Sync { /// Whether provider supports native tool calls over API. fn supports_native_tools(&self) -> bool { - false + self.capabilities().native_tool_calling } /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup). @@ -336,6 +343,27 @@ pub trait Provider: Send + Sync { mod tests { use super::*; + struct CapabilityMockProvider; + + #[async_trait] + impl Provider for CapabilityMockProvider { + fn capabilities(&self) -> ProviderCapabilities { + ProviderCapabilities { + native_tool_calling: true, + } + } + + async fn chat_with_system( + &self, + _system_prompt: Option<&str>, + _message: &str, + _model: &str, + _temperature: f64, + ) -> anyhow::Result { + Ok("ok".into()) + } + } + #[test] fn chat_message_constructors() { let sys = ChatMessage::system("Be helpful"); @@ -398,4 +426,32 @@ mod tests { let json = serde_json::to_string(&tool_result).unwrap(); assert!(json.contains("\"type\":\"ToolResults\"")); } + + #[test] + fn provider_capabilities_default() { + let caps = ProviderCapabilities::default(); + assert!(!caps.native_tool_calling); + } + + #[test] + fn provider_capabilities_equality() { + let caps1 = ProviderCapabilities { + native_tool_calling: true, + }; + let caps2 = ProviderCapabilities { + native_tool_calling: true, + }; + let caps3 = ProviderCapabilities { + native_tool_calling: false, + }; + + assert_eq!(caps1, caps2); + assert_ne!(caps1, caps3); + } + + #[test] + fn supports_native_tools_reflects_capabilities_default_mapping() { + let provider = CapabilityMockProvider; + assert!(provider.supports_native_tools()); + } } diff --git a/src/security/bubblewrap.rs b/src/security/bubblewrap.rs index 5c7106e..fca76e6 100644 --- a/src/security/bubblewrap.rs +++ b/src/security/bubblewrap.rs @@ -81,14 +81,17 @@ mod tests { #[test] fn bubblewrap_sandbox_name() { - assert_eq!(BubblewrapSandbox.name(), "bubblewrap"); + let sandbox = BubblewrapSandbox; + assert_eq!(sandbox.name(), "bubblewrap"); } #[test] fn bubblewrap_is_available_only_if_installed() { // Result depends on whether bwrap is installed - let available = BubblewrapSandbox::is_available(); + let sandbox = BubblewrapSandbox; + let _available = sandbox.is_available(); + // Either way, the name should still work - assert_eq!(BubblewrapSandbox.name(), "bubblewrap"); + assert_eq!(sandbox.name(), "bubblewrap"); } } diff --git a/src/security/pairing.rs b/src/security/pairing.rs index 806431b..2a828e1 100644 --- a/src/security/pairing.rs +++ b/src/security/pairing.rs @@ -184,7 +184,7 @@ fn generate_token() -> String { use rand::RngCore; let mut bytes = [0u8; 32]; rand::thread_rng().fill_bytes(&mut bytes); - format!("zc_{}", hex::encode(&bytes)) + format!("zc_{}", hex::encode(bytes)) } /// SHA-256 hash a bearer token for storage. Returns lowercase hex. diff --git a/src/security/policy.rs b/src/security/policy.rs index 9383f3a..e47947a 100644 --- a/src/security/policy.rs +++ b/src/security/policy.rs @@ -343,6 +343,7 @@ impl SecurityPolicy { /// validates each sub-command against the allowlist /// - Blocks single `&` background chaining (`&&` remains supported) /// - Blocks output redirections (`>`, `>>`) that could write outside workspace + /// - Blocks dangerous arguments (e.g. `find -exec`, `git config`) pub fn is_command_allowed(&self, command: &str) -> bool { if self.autonomy == AutonomyLevel::ReadOnly { return false; @@ -350,7 +351,12 @@ impl SecurityPolicy { // Block subshell/expansion operators — these allow hiding arbitrary // commands inside an allowed command (e.g. `echo $(rm -rf /)`) - if command.contains('`') || command.contains("$(") || command.contains("${") { + if command.contains('`') + || command.contains("$(") + || command.contains("${") + || command.contains("<(") + || command.contains(">(") + { return false; } @@ -359,6 +365,15 @@ impl SecurityPolicy { return false; } + // Block `tee` — it can write to arbitrary files, bypassing the + // redirect check above (e.g. `echo secret | tee /etc/crontab`) + if command + .split_whitespace() + .any(|w| w == "tee" || w.ends_with("/tee")) + { + return false; + } + // Block background command chaining (`&`), which can hide extra // sub-commands and outlive timeout expectations. Keep `&&` allowed. if contains_single_ampersand(command) { @@ -384,13 +399,9 @@ impl SecurityPolicy { // Strip leading env var assignments (e.g. FOO=bar cmd) let cmd_part = skip_env_assignments(segment); - let base_cmd = cmd_part - .split_whitespace() - .next() - .unwrap_or("") - .rsplit('/') - .next() - .unwrap_or(""); + let mut words = cmd_part.split_whitespace(); + let base_raw = words.next().unwrap_or(""); + let base_cmd = base_raw.rsplit('/').next().unwrap_or(""); if base_cmd.is_empty() { continue; @@ -403,6 +414,12 @@ impl SecurityPolicy { { return false; } + + // Validate arguments for the command + let args: Vec = words.map(|w| w.to_ascii_lowercase()).collect(); + if !self.is_args_safe(base_cmd, &args) { + return false; + } } // At least one command must be present @@ -414,6 +431,29 @@ impl SecurityPolicy { has_cmd } + /// Check for dangerous arguments that allow sub-command execution. + fn is_args_safe(&self, base: &str, args: &[String]) -> bool { + let base = base.to_ascii_lowercase(); + match base.as_str() { + "find" => { + // find -exec and find -ok allow arbitrary command execution + !args.iter().any(|arg| arg == "-exec" || arg == "-ok") + } + "git" => { + // git config, alias, and -c can be used to set dangerous options + // (e.g. git config core.editor "rm -rf /") + !args.iter().any(|arg| { + arg == "config" + || arg.starts_with("config.") + || arg == "alias" + || arg.starts_with("alias.") + || arg == "-c" + }) + } + _ => true, + } + } + /// Check if a file path is allowed (no path traversal, within workspace) pub fn is_path_allowed(&self, path: &str) -> bool { // Block null bytes (can truncate paths in C-backed syscalls) @@ -982,12 +1022,43 @@ mod tests { assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt")); } + #[test] + fn command_argument_injection_blocked() { + let p = default_policy(); + // find -exec is a common bypass + assert!(!p.is_command_allowed("find . -exec rm -rf {} +")); + assert!(!p.is_command_allowed("find / -ok cat {} \\;")); + // git config/alias can execute commands + assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\"")); + assert!(!p.is_command_allowed("git alias.st status")); + assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit")); + // Legitimate commands should still work + assert!(p.is_command_allowed("find . -name '*.txt'")); + assert!(p.is_command_allowed("git status")); + assert!(p.is_command_allowed("git add .")); + } + #[test] fn command_injection_dollar_brace_blocked() { let p = default_policy(); assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd")); } + #[test] + fn command_injection_tee_blocked() { + let p = default_policy(); + assert!(!p.is_command_allowed("echo secret | tee /etc/crontab")); + assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile")); + assert!(!p.is_command_allowed("tee file.txt")); + } + + #[test] + fn command_injection_process_substitution_blocked() { + let p = default_policy(); + assert!(!p.is_command_allowed("cat <(echo pwned)")); + assert!(!p.is_command_allowed("ls >(cat /etc/passwd)")); + } + #[test] fn command_env_var_prefix_with_allowed_cmd() { let p = default_policy(); diff --git a/src/tools/browser.rs b/src/tools/browser.rs index fe3be26..4e3d59e 100644 --- a/src/tools/browser.rs +++ b/src/tools/browser.rs @@ -854,7 +854,6 @@ impl BrowserTool { } } -#[allow(clippy::too_many_lines)] #[async_trait] impl Tool for BrowserTool { fn name(&self) -> &str { @@ -1031,165 +1030,21 @@ impl Tool for BrowserTool { return self.execute_computer_use_action(action_str, &args).await; } - let action = match action_str { - "open" => { - let url = args - .get("url") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?; - BrowserAction::Open { url: url.into() } - } - "snapshot" => BrowserAction::Snapshot { - interactive_only: args - .get("interactive_only") - .and_then(serde_json::Value::as_bool) - .unwrap_or(true), // Default to interactive for AI - compact: args - .get("compact") - .and_then(serde_json::Value::as_bool) - .unwrap_or(true), - depth: args - .get("depth") - .and_then(serde_json::Value::as_u64) - .map(|d| u32::try_from(d).unwrap_or(u32::MAX)), - }, - "click" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?; - BrowserAction::Click { - selector: selector.into(), - } - } - "fill" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?; - let value = args - .get("value") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?; - BrowserAction::Fill { - selector: selector.into(), - value: value.into(), - } - } - "type" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?; - let text = args - .get("text") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?; - BrowserAction::Type { - selector: selector.into(), - text: text.into(), - } - } - "get_text" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?; - BrowserAction::GetText { - selector: selector.into(), - } - } - "get_title" => BrowserAction::GetTitle, - "get_url" => BrowserAction::GetUrl, - "screenshot" => BrowserAction::Screenshot { - path: args.get("path").and_then(|v| v.as_str()).map(String::from), - full_page: args - .get("full_page") - .and_then(serde_json::Value::as_bool) - .unwrap_or(false), - }, - "wait" => BrowserAction::Wait { - selector: args - .get("selector") - .and_then(|v| v.as_str()) - .map(String::from), - ms: args.get("ms").and_then(serde_json::Value::as_u64), - text: args.get("text").and_then(|v| v.as_str()).map(String::from), - }, - "press" => { - let key = args - .get("key") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?; - BrowserAction::Press { key: key.into() } - } - "hover" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?; - BrowserAction::Hover { - selector: selector.into(), - } - } - "scroll" => { - let direction = args - .get("direction") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?; - BrowserAction::Scroll { - direction: direction.into(), - pixels: args - .get("pixels") - .and_then(serde_json::Value::as_u64) - .map(|p| u32::try_from(p).unwrap_or(u32::MAX)), - } - } - "is_visible" => { - let selector = args - .get("selector") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?; - BrowserAction::IsVisible { - selector: selector.into(), - } - } - "close" => BrowserAction::Close, - "find" => { - let by = args - .get("by") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?; - let value = args - .get("value") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?; - let action = args - .get("find_action") - .and_then(|v| v.as_str()) - .ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?; - BrowserAction::Find { - by: by.into(), - value: value.into(), - action: action.into(), - fill_value: args - .get("fill_value") - .and_then(|v| v.as_str()) - .map(String::from), - } - } - _ => { + if is_computer_use_only_action(action_str) { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(unavailable_action_for_backend_error(action_str, backend)), + }); + } + + let action = match parse_browser_action(action_str, &args) { + Ok(a) => a, + Err(e) => { return Ok(ToolResult { success: false, output: String::new(), - error: Some(format!( - "Action '{action_str}' is unavailable for backend '{}'", - match backend { - ResolvedBackend::AgentBrowser => "agent_browser", - ResolvedBackend::RustNative => "rust_native", - ResolvedBackend::ComputerUse => "computer_use", - } - )), + error: Some(e.to_string()), }); } }; @@ -1871,6 +1726,161 @@ mod native_backend { } } +// ── Action parsing ────────────────────────────────────────────── + +/// Parse a JSON `args` object into a typed `BrowserAction`. +fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result { + match action_str { + "open" => { + let url = args + .get("url") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?; + Ok(BrowserAction::Open { url: url.into() }) + } + "snapshot" => Ok(BrowserAction::Snapshot { + interactive_only: args + .get("interactive_only") + .and_then(serde_json::Value::as_bool) + .unwrap_or(true), + compact: args + .get("compact") + .and_then(serde_json::Value::as_bool) + .unwrap_or(true), + depth: args + .get("depth") + .and_then(serde_json::Value::as_u64) + .map(|d| u32::try_from(d).unwrap_or(u32::MAX)), + }), + "click" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?; + Ok(BrowserAction::Click { + selector: selector.into(), + }) + } + "fill" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?; + let value = args + .get("value") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?; + Ok(BrowserAction::Fill { + selector: selector.into(), + value: value.into(), + }) + } + "type" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?; + let text = args + .get("text") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?; + Ok(BrowserAction::Type { + selector: selector.into(), + text: text.into(), + }) + } + "get_text" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?; + Ok(BrowserAction::GetText { + selector: selector.into(), + }) + } + "get_title" => Ok(BrowserAction::GetTitle), + "get_url" => Ok(BrowserAction::GetUrl), + "screenshot" => Ok(BrowserAction::Screenshot { + path: args.get("path").and_then(|v| v.as_str()).map(String::from), + full_page: args + .get("full_page") + .and_then(serde_json::Value::as_bool) + .unwrap_or(false), + }), + "wait" => Ok(BrowserAction::Wait { + selector: args + .get("selector") + .and_then(|v| v.as_str()) + .map(String::from), + ms: args.get("ms").and_then(serde_json::Value::as_u64), + text: args.get("text").and_then(|v| v.as_str()).map(String::from), + }), + "press" => { + let key = args + .get("key") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?; + Ok(BrowserAction::Press { key: key.into() }) + } + "hover" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?; + Ok(BrowserAction::Hover { + selector: selector.into(), + }) + } + "scroll" => { + let direction = args + .get("direction") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?; + Ok(BrowserAction::Scroll { + direction: direction.into(), + pixels: args + .get("pixels") + .and_then(serde_json::Value::as_u64) + .map(|p| u32::try_from(p).unwrap_or(u32::MAX)), + }) + } + "is_visible" => { + let selector = args + .get("selector") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?; + Ok(BrowserAction::IsVisible { + selector: selector.into(), + }) + } + "close" => Ok(BrowserAction::Close), + "find" => { + let by = args + .get("by") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?; + let value = args + .get("value") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?; + let action = args + .get("find_action") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?; + Ok(BrowserAction::Find { + by: by.into(), + value: value.into(), + action: action.into(), + fill_value: args + .get("fill_value") + .and_then(|v| v.as_str()) + .map(String::from), + }) + } + other => anyhow::bail!("Unsupported browser action: {other}"), + } +} + // ── Helper functions ───────────────────────────────────────────── fn is_supported_browser_action(action: &str) -> bool { @@ -1901,6 +1911,28 @@ fn is_supported_browser_action(action: &str) -> bool { ) } +fn is_computer_use_only_action(action: &str) -> bool { + matches!( + action, + "mouse_move" | "mouse_click" | "mouse_drag" | "key_type" | "key_press" | "screen_capture" + ) +} + +fn backend_name(backend: ResolvedBackend) -> &'static str { + match backend { + ResolvedBackend::AgentBrowser => "agent_browser", + ResolvedBackend::RustNative => "rust_native", + ResolvedBackend::ComputerUse => "computer_use", + } +} + +fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String { + format!( + "Action '{action}' is unavailable for backend '{}'", + backend_name(backend) + ) +} + fn normalize_domains(domains: Vec) -> Vec { domains .into_iter() @@ -2342,4 +2374,28 @@ mod tests { let tool = BrowserTool::new(security, vec![], None); assert!(tool.validate_url("https://example.com").is_err()); } + + #[test] + fn computer_use_only_action_detection_is_correct() { + assert!(is_computer_use_only_action("mouse_move")); + assert!(is_computer_use_only_action("mouse_click")); + assert!(is_computer_use_only_action("mouse_drag")); + assert!(is_computer_use_only_action("key_type")); + assert!(is_computer_use_only_action("key_press")); + assert!(is_computer_use_only_action("screen_capture")); + assert!(!is_computer_use_only_action("open")); + assert!(!is_computer_use_only_action("snapshot")); + } + + #[test] + fn unavailable_action_error_preserves_backend_context() { + assert_eq!( + unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser), + "Action 'mouse_move' is unavailable for backend 'agent_browser'" + ); + assert_eq!( + unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative), + "Action 'mouse_move' is unavailable for backend 'rust_native'" + ); + } } diff --git a/src/tools/composio.rs b/src/tools/composio.rs index 4e608cb..65f128e 100644 --- a/src/tools/composio.rs +++ b/src/tools/composio.rs @@ -112,12 +112,12 @@ impl ComposioTool { action_name: &str, params: serde_json::Value, entity_id: Option<&str>, - connected_account_id: Option<&str>, + connected_account_ref: Option<&str>, ) -> anyhow::Result { let tool_slug = normalize_tool_slug(action_name); match self - .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id) + .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref) .await { Ok(result) => Ok(result), @@ -130,21 +130,17 @@ impl ComposioTool { } } - async fn execute_action_v3( - &self, + fn build_execute_action_v3_request( tool_slug: &str, params: serde_json::Value, entity_id: Option<&str>, - connected_account_id: Option<&str>, - ) -> anyhow::Result { - let url = if let Some(connected_account_id) = connected_account_id - .map(str::trim) - .filter(|id| !id.is_empty()) - { - format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}") - } else { - format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute") - }; + connected_account_ref: Option<&str>, + ) -> (String, serde_json::Value) { + let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute"); + let account_ref = connected_account_ref.and_then(|candidate| { + let trimmed_candidate = candidate.trim(); + (!trimmed_candidate.is_empty()).then_some(trimmed_candidate) + }); let mut body = json!({ "arguments": params, @@ -153,6 +149,26 @@ impl ComposioTool { if let Some(entity) = entity_id { body["user_id"] = json!(entity); } + if let Some(account_ref) = account_ref { + body["connected_account_id"] = json!(account_ref); + } + + (url, body) + } + + async fn execute_action_v3( + &self, + tool_slug: &str, + params: serde_json::Value, + entity_id: Option<&str>, + connected_account_ref: Option<&str>, + ) -> anyhow::Result { + let (url, body) = Self::build_execute_action_v3_request( + tool_slug, + params, + entity_id, + connected_account_ref, + ); let resp = self .client @@ -474,11 +490,11 @@ impl Tool for ComposioTool { })?; let params = args.get("params").cloned().unwrap_or(json!({})); - let connected_account_id = + let connected_account_ref = args.get("connected_account_id").and_then(|v| v.as_str()); match self - .execute_action(action_name, params, Some(entity_id), connected_account_id) + .execute_action(action_name, params, Some(entity_id), connected_account_ref) .await { Ok(result) => { @@ -594,9 +610,38 @@ async fn response_error(resp: reqwest::Response) -> String { } if let Some(api_error) = extract_api_error_message(&body) { - format!("HTTP {}: {api_error}", status.as_u16()) + return format!( + "HTTP {}: {}", + status.as_u16(), + sanitize_error_message(&api_error) + ); + } + + format!("HTTP {}", status.as_u16()) +} + +fn sanitize_error_message(message: &str) -> String { + let mut sanitized = message.replace('\n', " "); + for marker in [ + "connected_account_id", + "connectedAccountId", + "entity_id", + "entityId", + "user_id", + "userId", + ] { + sanitized = sanitized.replace(marker, "[redacted]"); + } + + let max_chars = 240; + if sanitized.chars().count() <= max_chars { + sanitized } else { - format!("HTTP {}: {body}", status.as_u16()) + let mut end = max_chars; + while end > 0 && !sanitized.is_char_boundary(end) { + end -= 1; + } + format!("{}...", &sanitized[..end]) } } @@ -948,4 +993,40 @@ mod tests { fn composio_api_base_url_is_v3() { assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3"); } + + #[test] + fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() { + let (url, body) = ComposioTool::build_execute_action_v3_request( + "gmail-send-email", + json!({"to": "test@example.com"}), + Some("workspace-user"), + Some("account-42"), + ); + + assert_eq!( + url, + "https://backend.composio.dev/api/v3/tools/gmail-send-email/execute" + ); + assert_eq!(body["arguments"]["to"], json!("test@example.com")); + assert_eq!(body["user_id"], json!("workspace-user")); + assert_eq!(body["connected_account_id"], json!("account-42")); + } + + #[test] + fn build_execute_action_v3_request_drops_blank_optional_fields() { + let (url, body) = ComposioTool::build_execute_action_v3_request( + "github-list-repos", + json!({}), + None, + Some(" "), + ); + + assert_eq!( + url, + "https://backend.composio.dev/api/v3/tools/github-list-repos/execute" + ); + assert_eq!(body["arguments"], json!({})); + assert!(body.get("connected_account_id").is_none()); + assert!(body.get("user_id").is_none()); + } } diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs index 7f30b64..3de7872 100644 --- a/src/tools/delegate.rs +++ b/src/tools/delegate.rs @@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120; /// summarization) to purpose-built sub-agents. pub struct DelegateTool { agents: Arc>, - /// Global API key fallback (from config.api_key) - fallback_api_key: Option, + /// Global credential fallback (from config.api_key) + fallback_credential: Option, /// Depth at which this tool instance lives in the delegation chain. depth: u32, } @@ -25,11 +25,11 @@ pub struct DelegateTool { impl DelegateTool { pub fn new( agents: HashMap, - fallback_api_key: Option, + fallback_credential: Option, ) -> Self { Self { agents: Arc::new(agents), - fallback_api_key, + fallback_credential, depth: 0, } } @@ -39,12 +39,12 @@ impl DelegateTool { /// their DelegateTool via this method with `depth: parent.depth + 1`. pub fn with_depth( agents: HashMap, - fallback_api_key: Option, + fallback_credential: Option, depth: u32, ) -> Self { Self { agents: Arc::new(agents), - fallback_api_key, + fallback_credential, depth, } } @@ -165,13 +165,15 @@ impl Tool for DelegateTool { } // Create provider for this agent - let api_key = agent_config + let provider_credential_owned = agent_config .api_key - .as_deref() - .or(self.fallback_api_key.as_deref()); + .clone() + .or_else(|| self.fallback_credential.clone()); + #[allow(clippy::option_as_ref_deref)] + let provider_credential = provider_credential_owned.as_ref().map(String::as_str); let provider: Box = - match providers::create_provider(&agent_config.provider, api_key) { + match providers::create_provider(&agent_config.provider, provider_credential) { Ok(p) => p, Err(e) => { return Ok(ToolResult { @@ -268,7 +270,7 @@ mod tests { provider: "openrouter".to_string(), model: "anthropic/claude-sonnet-4-20250514".to_string(), system_prompt: None, - api_key: Some("sk-test".to_string()), + api_key: Some("delegate-test-credential".to_string()), temperature: None, max_depth: 2, }, diff --git a/src/tools/git_operations.rs b/src/tools/git_operations.rs index 9fcb453..21440ba 100644 --- a/src/tools/git_operations.rs +++ b/src/tools/git_operations.rs @@ -28,13 +28,22 @@ impl GitOperationsTool { if arg_lower.starts_with("--exec=") || arg_lower.starts_with("--upload-pack=") || arg_lower.starts_with("--receive-pack=") + || arg_lower.starts_with("--pager=") + || arg_lower.starts_with("--editor=") + || arg_lower == "--no-verify" || arg_lower.contains("$(") || arg_lower.contains('`') || arg.contains('|') || arg.contains(';') + || arg.contains('>') { anyhow::bail!("Blocked potentially dangerous git argument: {arg}"); } + // Block `-c` config injection (exact match or `-c=...` prefix). + // This must not false-positive on `--cached` or `-cached`. + if arg_lower == "-c" || arg_lower.starts_with("-c=") { + anyhow::bail!("Blocked potentially dangerous git argument: {arg}"); + } result.push(arg.to_string()); } Ok(result) @@ -129,6 +138,9 @@ impl GitOperationsTool { .and_then(|v| v.as_bool()) .unwrap_or(false); + // Validate files argument against injection patterns + self.sanitize_git_args(files)?; + let mut git_args = vec!["diff", "--unified=3"]; if cached { git_args.push("--cached"); @@ -267,6 +279,14 @@ impl GitOperationsTool { }) } + fn truncate_commit_message(message: &str) -> String { + if message.chars().count() > 2000 { + format!("{}...", message.chars().take(1997).collect::()) + } else { + message.to_string() + } + } + async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result { let message = args .get("message") @@ -286,11 +306,7 @@ impl GitOperationsTool { } // Limit message length - let message = if sanitized.len() > 2000 { - format!("{}...", &sanitized[..1997]) - } else { - sanitized - }; + let message = Self::truncate_commit_message(&sanitized); let output = self.run_git_command(&["commit", "-m", &message]).await; @@ -314,6 +330,9 @@ impl GitOperationsTool { .and_then(|v| v.as_str()) .ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?; + // Validate paths against injection patterns + self.sanitize_git_args(paths)?; + let output = self.run_git_command(&["add", "--", paths]).await; match output { @@ -574,6 +593,52 @@ mod tests { assert!(tool.sanitize_git_args("arg; rm file").is_err()); } + #[test] + fn sanitize_git_blocks_pager_editor_injection() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.sanitize_git_args("--pager=less").is_err()); + assert!(tool.sanitize_git_args("--editor=vim").is_err()); + } + + #[test] + fn sanitize_git_blocks_config_injection() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + // Exact `-c` flag (config injection) + assert!(tool.sanitize_git_args("-c core.sshCommand=evil").is_err()); + assert!(tool.sanitize_git_args("-c=core.pager=less").is_err()); + } + + #[test] + fn sanitize_git_blocks_no_verify() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.sanitize_git_args("--no-verify").is_err()); + } + + #[test] + fn sanitize_git_blocks_redirect_in_args() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + assert!(tool.sanitize_git_args("file.txt > /tmp/out").is_err()); + } + + #[test] + fn sanitize_git_cached_not_blocked() { + let tmp = TempDir::new().unwrap(); + let tool = test_tool(tmp.path()); + + // --cached must NOT be blocked by the `-c` check + assert!(tool.sanitize_git_args("--cached").is_ok()); + // Other safe flags starting with -c prefix + assert!(tool.sanitize_git_args("-cached").is_ok()); + } + #[test] fn sanitize_git_allows_safe() { let tmp = TempDir::new().unwrap(); @@ -583,6 +648,8 @@ mod tests { assert!(tool.sanitize_git_args("main").is_ok()); assert!(tool.sanitize_git_args("feature/test-branch").is_ok()); assert!(tool.sanitize_git_args("--cached").is_ok()); + assert!(tool.sanitize_git_args("src/main.rs").is_ok()); + assert!(tool.sanitize_git_args(".").is_ok()); } #[test] @@ -691,4 +758,12 @@ mod tests { .unwrap_or("") .contains("Unknown operation")); } + + #[test] + fn truncates_multibyte_commit_message_without_panicking() { + let long = "🦀".repeat(2500); + let truncated = GitOperationsTool::truncate_commit_message(&long); + + assert_eq!(truncated.chars().count(), 2000); + } } diff --git a/src/tools/hardware_board_info.rs b/src/tools/hardware_board_info.rs index f7af262..73b30fc 100644 --- a/src/tools/hardware_board_info.rs +++ b/src/tools/hardware_board_info.rs @@ -124,10 +124,11 @@ impl Tool for HardwareBoardInfoTool { }); } Err(e) => { - output.push_str(&format!( - "probe-rs attach failed: {}. Using static info.\n\n", - e - )); + use std::fmt::Write; + let _ = write!( + output, + "probe-rs attach failed: {e}. Using static info.\n\n" + ); } } } @@ -135,13 +136,15 @@ impl Tool for HardwareBoardInfoTool { if let Some(info) = self.static_info_for_board(board) { output.push_str(&info); if let Some(mem) = memory_map_static(board) { - output.push_str(&format!("\n\n**Memory map:**\n{}", mem)); + use std::fmt::Write; + let _ = write!(output, "\n\n**Memory map:**\n{mem}"); } } else { - output.push_str(&format!( - "Board '{}' configured. No static info available.", - board - )); + use std::fmt::Write; + let _ = write!( + output, + "Board '{board}' configured. No static info available." + ); } Ok(ToolResult { diff --git a/src/tools/hardware_memory_map.rs b/src/tools/hardware_memory_map.rs index bdb4f96..41fd07b 100644 --- a/src/tools/hardware_memory_map.rs +++ b/src/tools/hardware_memory_map.rs @@ -122,14 +122,16 @@ impl Tool for HardwareMemoryMapTool { if !probe_ok { if let Some(map) = self.static_map_for_board(board) { - output.push_str(&format!("**{}** (from datasheet):\n{}", board, map)); + use std::fmt::Write; + let _ = write!(output, "**{board}** (from datasheet):\n{map}"); } else { + use std::fmt::Write; let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect(); - output.push_str(&format!( - "No memory map for board '{}'. Known boards: {}", - board, + let _ = write!( + output, + "No memory map for board '{board}'. Known boards: {}", known.join(", ") - )); + ); } } diff --git a/src/tools/hardware_memory_read.rs b/src/tools/hardware_memory_read.rs index 4cc42d5..3232c78 100644 --- a/src/tools/hardware_memory_read.rs +++ b/src/tools/hardware_memory_read.rs @@ -94,14 +94,16 @@ impl Tool for HardwareMemoryReadTool { .get("address") .and_then(|v| v.as_str()) .unwrap_or("0x20000000"); - let address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE); + let _address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE); - let length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128) as usize; - let length = length.min(256).max(1); + let requested_length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128); + let _length = usize::try_from(requested_length) + .unwrap_or(256) + .clamp(1, 256); #[cfg(feature = "probe")] { - match probe_read_memory(chip.unwrap(), address, length) { + match probe_read_memory(chip.unwrap(), _address, _length) { Ok(output) => { return Ok(ToolResult { success: true, diff --git a/src/tools/http_request.rs b/src/tools/http_request.rs index 0701f95..1d00253 100644 --- a/src/tools/http_request.rs +++ b/src/tools/http_request.rs @@ -749,4 +749,54 @@ mod tests { let _ = HttpRequestTool::redact_headers_for_display(&headers); assert_eq!(headers[0].1, "Bearer real-token"); } + + // ── SSRF: alternate IP notation bypass defense-in-depth ───────── + // + // Rust's IpAddr::parse() rejects non-standard notations (octal, hex, + // decimal integer, zero-padded). These tests document that property + // so regressions are caught if the parsing strategy ever changes. + + #[test] + fn ssrf_octal_loopback_not_parsed_as_ip() { + // 0177.0.0.1 is octal for 127.0.0.1 in some languages, but + // Rust's IpAddr rejects it — it falls through as a hostname. + assert!(!is_private_or_local_host("0177.0.0.1")); + } + + #[test] + fn ssrf_hex_loopback_not_parsed_as_ip() { + // 0x7f000001 is hex for 127.0.0.1 in some languages. + assert!(!is_private_or_local_host("0x7f000001")); + } + + #[test] + fn ssrf_decimal_loopback_not_parsed_as_ip() { + // 2130706433 is decimal for 127.0.0.1 in some languages. + assert!(!is_private_or_local_host("2130706433")); + } + + #[test] + fn ssrf_zero_padded_loopback_not_parsed_as_ip() { + // 127.000.000.001 uses zero-padded octets. + assert!(!is_private_or_local_host("127.000.000.001")); + } + + #[test] + fn ssrf_alternate_notations_rejected_by_validate_url() { + // Even if is_private_or_local_host doesn't flag these, they + // fail the allowlist because they're treated as hostnames. + let tool = test_tool(vec!["example.com"]); + for notation in [ + "http://0177.0.0.1", + "http://0x7f000001", + "http://2130706433", + "http://127.000.000.001", + ] { + let err = tool.validate_url(notation).unwrap_err().to_string(); + assert!( + err.contains("allowed_domains"), + "Expected allowlist rejection for {notation}, got: {err}" + ); + } + } } diff --git a/src/tools/memory_forget.rs b/src/tools/memory_forget.rs index 16b2b8a..a53885e 100644 --- a/src/tools/memory_forget.rs +++ b/src/tools/memory_forget.rs @@ -87,7 +87,7 @@ mod tests { #[tokio::test] async fn forget_existing() { let (_tmp, mem) = test_mem(); - mem.store("temp", "temporary", MemoryCategory::Conversation) + mem.store("temp", "temporary", MemoryCategory::Conversation, None) .await .unwrap(); diff --git a/src/tools/memory_recall.rs b/src/tools/memory_recall.rs index ff1385a..fada306 100644 --- a/src/tools/memory_recall.rs +++ b/src/tools/memory_recall.rs @@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool { .and_then(serde_json::Value::as_u64) .map_or(5, |v| v as usize); - match self.memory.recall(query, limit).await { + match self.memory.recall(query, limit, None).await { Ok(entries) if entries.is_empty() => Ok(ToolResult { success: true, output: "No memories found matching that query.".into(), @@ -112,10 +112,10 @@ mod tests { #[tokio::test] async fn recall_finds_match() { let (_tmp, mem) = seeded_mem(); - mem.store("lang", "User prefers Rust", MemoryCategory::Core) + mem.store("lang", "User prefers Rust", MemoryCategory::Core, None) .await .unwrap(); - mem.store("tz", "Timezone is EST", MemoryCategory::Core) + mem.store("tz", "Timezone is EST", MemoryCategory::Core, None) .await .unwrap(); @@ -134,6 +134,7 @@ mod tests { &format!("k{i}"), &format!("Rust fact {i}"), MemoryCategory::Core, + None, ) .await .unwrap(); diff --git a/src/tools/memory_store.rs b/src/tools/memory_store.rs index b90222c..d2aad40 100644 --- a/src/tools/memory_store.rs +++ b/src/tools/memory_store.rs @@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool { _ => MemoryCategory::Core, }; - match self.memory.store(key, content, category).await { + match self.memory.store(key, content, category, None).await { Ok(()) => Ok(ToolResult { success: true, output: format!("Stored memory: {key}"), diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 07f29d8..b541736 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -19,7 +19,9 @@ pub mod image_info; pub mod memory_forget; pub mod memory_recall; pub mod memory_store; +pub mod pushover; pub mod schedule; +pub mod schema; pub mod screenshot; pub mod shell; pub mod traits; @@ -45,7 +47,9 @@ pub use image_info::ImageInfoTool; pub use memory_forget::MemoryForgetTool; pub use memory_recall::MemoryRecallTool; pub use memory_store::MemoryStoreTool; +pub use pushover::PushoverTool; pub use schedule::ScheduleTool; +pub use schema::{CleaningStrategy, SchemaCleanr}; pub use screenshot::ScreenshotTool; pub use shell::ShellTool; pub use traits::Tool; @@ -141,6 +145,10 @@ pub fn all_tools_with_runtime( security.clone(), workspace_dir.to_path_buf(), )), + Box::new(PushoverTool::new( + security.clone(), + workspace_dir.to_path_buf(), + )), ]; if browser_config.enabled { @@ -195,9 +203,13 @@ pub fn all_tools_with_runtime( .iter() .map(|(name, cfg)| (name.clone(), cfg.clone())) .collect(); + let delegate_fallback_credential = fallback_api_key.and_then(|value| { + let trimmed_value = value.trim(); + (!trimmed_value.is_empty()).then(|| trimmed_value.to_owned()) + }); tools.push(Box::new(DelegateTool::new( delegate_agents, - fallback_api_key.map(String::from), + delegate_fallback_credential, ))); } @@ -261,6 +273,7 @@ mod tests { let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(!names.contains(&"browser_open")); assert!(names.contains(&"schedule")); + assert!(names.contains(&"pushover")); } #[test] @@ -298,6 +311,7 @@ mod tests { ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); assert!(names.contains(&"browser_open")); + assert!(names.contains(&"pushover")); } #[test] @@ -432,7 +446,7 @@ mod tests { &http, tmp.path(), &agents, - Some("sk-test"), + Some("delegate-test-credential"), &cfg, ); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); diff --git a/src/tools/pushover.rs b/src/tools/pushover.rs new file mode 100644 index 0000000..ad1d385 --- /dev/null +++ b/src/tools/pushover.rs @@ -0,0 +1,442 @@ +use super::traits::{Tool, ToolResult}; +use crate::security::SecurityPolicy; +use async_trait::async_trait; +use reqwest::Client; +use serde_json::json; +use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; + +const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json"; +const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15; + +pub struct PushoverTool { + client: Client, + security: Arc, + workspace_dir: PathBuf, +} + +impl PushoverTool { + pub fn new(security: Arc, workspace_dir: PathBuf) -> Self { + let client = Client::builder() + .timeout(Duration::from_secs(PUSHOVER_REQUEST_TIMEOUT_SECS)) + .build() + .unwrap_or_else(|_| Client::new()); + + Self { + client, + security, + workspace_dir, + } + } + + fn parse_env_value(raw: &str) -> String { + let raw = raw.trim(); + + let unquoted = if raw.len() >= 2 + && ((raw.starts_with('"') && raw.ends_with('"')) + || (raw.starts_with('\'') && raw.ends_with('\''))) + { + &raw[1..raw.len() - 1] + } else { + raw + }; + + // Keep support for inline comments in unquoted values: + // KEY=value # comment + unquoted.split_once(" #").map_or_else( + || unquoted.trim().to_string(), + |(value, _)| value.trim().to_string(), + ) + } + + fn get_credentials(&self) -> anyhow::Result<(String, String)> { + let env_path = self.workspace_dir.join(".env"); + let content = std::fs::read_to_string(&env_path) + .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?; + + let mut token = None; + let mut user_key = None; + + for line in content.lines() { + let line = line.trim(); + if line.starts_with('#') || line.is_empty() { + continue; + } + let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line); + if let Some((key, value)) = line.split_once('=') { + let key = key.trim(); + let value = Self::parse_env_value(value); + + if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") { + token = Some(value); + } else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") { + user_key = Some(value); + } + } + } + + let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?; + let user_key = + user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?; + + Ok((token, user_key)) + } +} + +#[async_trait] +impl Tool for PushoverTool { + fn name(&self) -> &str { + "pushover" + } + + fn description(&self) -> &str { + "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file." + } + + fn parameters_schema(&self) -> serde_json::Value { + json!({ + "type": "object", + "properties": { + "message": { + "type": "string", + "description": "The notification message to send" + }, + "title": { + "type": "string", + "description": "Optional notification title" + }, + "priority": { + "type": "integer", + "enum": [-2, -1, 0, 1, 2], + "description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)" + }, + "sound": { + "type": "string", + "description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)" + } + }, + "required": ["message"] + }) + } + + async fn execute(&self, args: serde_json::Value) -> anyhow::Result { + if !self.security.can_act() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: autonomy is read-only".into()), + }); + } + + if !self.security.record_action() { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some("Action blocked: rate limit exceeded".into()), + }); + } + + let message = args + .get("message") + .and_then(|v| v.as_str()) + .map(str::trim) + .filter(|v| !v.is_empty()) + .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))? + .to_string(); + + let title = args.get("title").and_then(|v| v.as_str()).map(String::from); + + let priority = match args.get("priority").and_then(|v| v.as_i64()) { + Some(value) if (-2..=2).contains(&value) => Some(value), + Some(value) => { + return Ok(ToolResult { + success: false, + output: String::new(), + error: Some(format!( + "Invalid 'priority': {value}. Expected integer in range -2..=2" + )), + }) + } + None => None, + }; + + let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from); + + let (token, user_key) = self.get_credentials()?; + + let mut form = reqwest::multipart::Form::new() + .text("token", token) + .text("user", user_key) + .text("message", message); + + if let Some(title) = title { + form = form.text("title", title); + } + + if let Some(priority) = priority { + form = form.text("priority", priority.to_string()); + } + + if let Some(sound) = sound { + form = form.text("sound", sound); + } + + let response = self + .client + .post(PUSHOVER_API_URL) + .multipart(form) + .send() + .await?; + + let status = response.status(); + let body = response.text().await.unwrap_or_default(); + + if !status.is_success() { + return Ok(ToolResult { + success: false, + output: body, + error: Some(format!("Pushover API returned status {}", status)), + }); + } + + let api_status = serde_json::from_str::(&body) + .ok() + .and_then(|json| json.get("status").and_then(|value| value.as_i64())); + + if api_status == Some(1) { + Ok(ToolResult { + success: true, + output: format!( + "Pushover notification sent successfully. Response: {}", + body + ), + error: None, + }) + } else { + Ok(ToolResult { + success: false, + output: body, + error: Some("Pushover API returned an application-level error".into()), + }) + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::security::AutonomyLevel; + use std::fs; + use tempfile::TempDir; + + fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc { + Arc::new(SecurityPolicy { + autonomy: level, + max_actions_per_hour, + workspace_dir: std::env::temp_dir(), + ..SecurityPolicy::default() + }) + } + + #[test] + fn pushover_tool_name() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + assert_eq!(tool.name(), "pushover"); + } + + #[test] + fn pushover_tool_description() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + assert!(!tool.description().is_empty()); + } + + #[test] + fn pushover_tool_has_parameters_schema() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + assert_eq!(schema["type"], "object"); + assert!(schema["properties"].get("message").is_some()); + } + + #[test] + fn pushover_tool_requires_message() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + let required = schema["required"].as_array().unwrap(); + assert!(required.contains(&serde_json::Value::String("message".to_string()))); + } + + #[test] + fn credentials_parsed_from_env_file() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write( + &env_path, + "PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n", + ) + .unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_ok()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "testtoken123"); + assert_eq!(user_key, "userkey456"); + } + + #[test] + fn credentials_fail_without_env_file() { + let tmp = TempDir::new().unwrap(); + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_err()); + } + + #[test] + fn credentials_fail_without_token() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_err()); + } + + #[test] + fn credentials_fail_without_user_key() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_err()); + } + + #[test] + fn credentials_ignore_comments() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_ok()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "realtoken"); + assert_eq!(user_key, "realuser"); + } + + #[test] + fn pushover_tool_supports_priority() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + assert!(schema["properties"].get("priority").is_some()); + } + + #[test] + fn pushover_tool_supports_sound() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + let schema = tool.parameters_schema(); + assert!(schema["properties"].get("sound").is_some()); + } + + #[test] + fn credentials_support_export_and_quoted_values() { + let tmp = TempDir::new().unwrap(); + let env_path = tmp.path().join(".env"); + fs::write( + &env_path, + "export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n", + ) + .unwrap(); + + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + tmp.path().to_path_buf(), + ); + let result = tool.get_credentials(); + + assert!(result.is_ok()); + let (token, user_key) = result.unwrap(); + assert_eq!(token, "quotedtoken"); + assert_eq!(user_key, "quoteduser"); + } + + #[tokio::test] + async fn execute_blocks_readonly_mode() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::ReadOnly, 100), + PathBuf::from("/tmp"), + ); + + let result = tool.execute(json!({"message": "hello"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("read-only")); + } + + #[tokio::test] + async fn execute_blocks_rate_limit() { + let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp")); + + let result = tool.execute(json!({"message": "hello"})).await.unwrap(); + assert!(!result.success); + assert!(result.error.unwrap().contains("rate limit")); + } + + #[tokio::test] + async fn execute_rejects_priority_out_of_range() { + let tool = PushoverTool::new( + test_security(AutonomyLevel::Full, 100), + PathBuf::from("/tmp"), + ); + + let result = tool + .execute(json!({"message": "hello", "priority": 5})) + .await + .unwrap(); + + assert!(!result.success); + assert!(result.error.unwrap().contains("-2..=2")); + } +} diff --git a/src/tools/schema.rs b/src/tools/schema.rs new file mode 100644 index 0000000..b9a22f4 --- /dev/null +++ b/src/tools/schema.rs @@ -0,0 +1,838 @@ +//! JSON Schema cleaning and validation for LLM tool-calling compatibility. +//! +//! Different providers support different subsets of JSON Schema. This module +//! normalizes tool schemas to improve cross-provider compatibility while +//! preserving semantic intent. +//! +//! ## What this module does +//! +//! 1. Removes unsupported keywords per provider strategy +//! 2. Resolves local `$ref` entries from `$defs` and `definitions` +//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum` +//! 4. Strips nullable variants from unions and `type` arrays +//! 5. Converts `const` to single-value `enum` +//! 6. Detects circular references and stops recursion safely +//! +//! # Example +//! +//! ```rust +//! use serde_json::json; +//! use zeroclaw::tools::schema::SchemaCleanr; +//! +//! let dirty_schema = json!({ +//! "type": "object", +//! "properties": { +//! "name": { +//! "type": "string", +//! "minLength": 1, // Gemini rejects this +//! "pattern": "^[a-z]+$" // Gemini rejects this +//! }, +//! "age": { +//! "$ref": "#/$defs/Age" // Needs resolution +//! } +//! }, +//! "$defs": { +//! "Age": { +//! "type": "integer", +//! "minimum": 0 // Gemini rejects this +//! } +//! } +//! }); +//! +//! let cleaned = SchemaCleanr::clean_for_gemini(dirty_schema); +//! +//! // Result: +//! // { +//! // "type": "object", +//! // "properties": { +//! // "name": { "type": "string" }, +//! // "age": { "type": "integer" } +//! // } +//! // } +//! ``` +//! +use serde_json::{json, Map, Value}; +use std::collections::{HashMap, HashSet}; + +/// Keywords that Gemini rejects for tool schemas. +pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[ + // Schema composition + "$ref", + "$schema", + "$id", + "$defs", + "definitions", + // Property constraints + "additionalProperties", + "patternProperties", + // String constraints + "minLength", + "maxLength", + "pattern", + "format", + // Number constraints + "minimum", + "maximum", + "multipleOf", + // Array constraints + "minItems", + "maxItems", + "uniqueItems", + // Object constraints + "minProperties", + "maxProperties", + // Non-standard + "examples", // OpenAPI keyword, not JSON Schema +]; + +/// Keywords that should be preserved during cleaning (metadata). +const SCHEMA_META_KEYS: &[&str] = &["description", "title", "default"]; + +/// Schema cleaning strategies for different LLM providers. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum CleaningStrategy { + /// Gemini (Google AI / Vertex AI) - Most restrictive + Gemini, + /// Anthropic Claude - Moderately permissive + Anthropic, + /// OpenAI GPT - Most permissive + OpenAI, + /// Conservative: Remove only universally unsupported keywords + Conservative, +} + +impl CleaningStrategy { + /// Get the list of unsupported keywords for this strategy. + pub fn unsupported_keywords(&self) -> &'static [&'static str] { + match self { + Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS, + Self::Anthropic => &["$ref", "$defs", "definitions"], // Anthropic doesn't resolve refs + Self::OpenAI => &[], // OpenAI is most permissive + Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"], + } + } +} + +/// JSON Schema cleaner optimized for LLM tool calling. +pub struct SchemaCleanr; + +impl SchemaCleanr { + /// Clean schema for Gemini compatibility (strictest). + /// + /// This is the most aggressive cleaning strategy, removing all keywords + /// that Gemini's API rejects. + pub fn clean_for_gemini(schema: Value) -> Value { + Self::clean(schema, CleaningStrategy::Gemini) + } + + /// Clean schema for Anthropic compatibility. + pub fn clean_for_anthropic(schema: Value) -> Value { + Self::clean(schema, CleaningStrategy::Anthropic) + } + + /// Clean schema for OpenAI compatibility (most permissive). + pub fn clean_for_openai(schema: Value) -> Value { + Self::clean(schema, CleaningStrategy::OpenAI) + } + + /// Clean schema with specified strategy. + pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value { + // Extract $defs for reference resolution + let defs = if let Some(obj) = schema.as_object() { + Self::extract_defs(obj) + } else { + HashMap::new() + }; + + Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new()) + } + + /// Validate that a schema is suitable for LLM tool calling. + /// + /// Returns an error if the schema is invalid or missing required fields. + pub fn validate(schema: &Value) -> anyhow::Result<()> { + let obj = schema + .as_object() + .ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?; + + // Must have 'type' field + if !obj.contains_key("type") { + anyhow::bail!("Schema missing required 'type' field"); + } + + // If type is 'object', should have 'properties' + if let Some(Value::String(t)) = obj.get("type") { + if t == "object" && !obj.contains_key("properties") { + tracing::warn!("Object schema without 'properties' field may cause issues"); + } + } + + Ok(()) + } + + // -------------------------------------------------------------------- + // Internal implementation + // -------------------------------------------------------------------- + + /// Extract $defs and definitions into a flat map for reference resolution. + fn extract_defs(obj: &Map) -> HashMap { + let mut defs = HashMap::new(); + + // Extract from $defs (JSON Schema 2019-09+) + if let Some(Value::Object(defs_obj)) = obj.get("$defs") { + for (key, value) in defs_obj { + defs.insert(key.clone(), value.clone()); + } + } + + // Extract from definitions (JSON Schema draft-07) + if let Some(Value::Object(defs_obj)) = obj.get("definitions") { + for (key, value) in defs_obj { + defs.insert(key.clone(), value.clone()); + } + } + + defs + } + + /// Recursively clean a schema value. + fn clean_with_defs( + schema: Value, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + match schema { + Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack), + Value::Array(arr) => Value::Array( + arr.into_iter() + .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack)) + .collect(), + ), + other => other, + } + } + + /// Clean an object schema. + fn clean_object( + obj: Map, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + // Handle $ref resolution + if let Some(Value::String(ref_value)) = obj.get("$ref") { + return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack); + } + + // Handle anyOf/oneOf simplification + if obj.contains_key("anyOf") || obj.contains_key("oneOf") { + if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) { + return simplified; + } + } + + // Build cleaned object + let mut cleaned = Map::new(); + let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect(); + let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf"); + + for (key, value) in obj { + // Skip unsupported keywords + if unsupported.contains(key.as_str()) { + continue; + } + + // Special handling for specific keys + match key.as_str() { + // Convert const to enum + "const" => { + cleaned.insert("enum".to_string(), json!([value])); + } + // Skip type if we have anyOf/oneOf (they define the type) + "type" if has_union => { + // Skip + } + // Handle type arrays (remove null) + "type" if matches!(value, Value::Array(_)) => { + let cleaned_value = Self::clean_type_array(value); + cleaned.insert(key, cleaned_value); + } + // Recursively clean nested schemas + "properties" => { + let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack); + cleaned.insert(key, cleaned_value); + } + "items" => { + let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack); + cleaned.insert(key, cleaned_value); + } + "anyOf" | "oneOf" | "allOf" => { + let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack); + cleaned.insert(key, cleaned_value); + } + // Keep all other keys, cleaning nested objects/arrays recursively. + _ => { + let cleaned_value = match value { + Value::Object(_) | Value::Array(_) => { + Self::clean_with_defs(value, defs, strategy, ref_stack) + } + other => other, + }; + cleaned.insert(key, cleaned_value); + } + } + } + + Value::Object(cleaned) + } + + /// Resolve a $ref to its definition. + fn resolve_ref( + ref_value: &str, + obj: &Map, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + // Prevent circular references + if ref_stack.contains(ref_value) { + tracing::warn!("Circular $ref detected: {}", ref_value); + return Self::preserve_meta(obj, Value::Object(Map::new())); + } + + // Try to resolve local ref (#/$defs/Name or #/definitions/Name) + if let Some(def_name) = Self::parse_local_ref(ref_value) { + if let Some(definition) = defs.get(def_name.as_str()) { + ref_stack.insert(ref_value.to_string()); + let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack); + ref_stack.remove(ref_value); + return Self::preserve_meta(obj, cleaned); + } + } + + // Can't resolve: return empty object with metadata + tracing::warn!("Cannot resolve $ref: {}", ref_value); + Self::preserve_meta(obj, Value::Object(Map::new())) + } + + /// Parse a local JSON Pointer ref (#/$defs/Name). + fn parse_local_ref(ref_value: &str) -> Option { + ref_value + .strip_prefix("#/$defs/") + .or_else(|| ref_value.strip_prefix("#/definitions/")) + .map(Self::decode_json_pointer) + } + + /// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`). + fn decode_json_pointer(segment: &str) -> String { + if !segment.contains('~') { + return segment.to_string(); + } + + let mut decoded = String::with_capacity(segment.len()); + let mut chars = segment.chars().peekable(); + + while let Some(ch) = chars.next() { + if ch == '~' { + match chars.peek().copied() { + Some('0') => { + chars.next(); + decoded.push('~'); + } + Some('1') => { + chars.next(); + decoded.push('/'); + } + _ => decoded.push('~'), + } + } else { + decoded.push(ch); + } + } + + decoded + } + + /// Try to simplify anyOf/oneOf to a simpler form. + fn try_simplify_union( + obj: &Map, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Option { + let union_key = if obj.contains_key("anyOf") { + "anyOf" + } else if obj.contains_key("oneOf") { + "oneOf" + } else { + return None; + }; + + let variants = obj.get(union_key)?.as_array()?; + + // Clean all variants first + let cleaned_variants: Vec = variants + .iter() + .map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack)) + .collect(); + + // Strip null variants + let non_null: Vec = cleaned_variants + .into_iter() + .filter(|v| !Self::is_null_schema(v)) + .collect(); + + // If only one variant remains after stripping nulls, return it + if non_null.len() == 1 { + return Some(Self::preserve_meta(obj, non_null[0].clone())); + } + + // Try to flatten to enum if all variants are literals + if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) { + return Some(Self::preserve_meta(obj, enum_value)); + } + + None + } + + /// Check if a schema represents null type. + fn is_null_schema(value: &Value) -> bool { + if let Some(obj) = value.as_object() { + // { const: null } + if let Some(Value::Null) = obj.get("const") { + return true; + } + // { enum: [null] } + if let Some(Value::Array(arr)) = obj.get("enum") { + if arr.len() == 1 && matches!(arr[0], Value::Null) { + return true; + } + } + // { type: "null" } + if let Some(Value::String(t)) = obj.get("type") { + if t == "null" { + return true; + } + } + } + false + } + + /// Try to flatten anyOf/oneOf with only literal values to enum. + /// + /// Example: `anyOf: [{const: "a"}, {const: "b"}]` -> `{type: "string", enum: ["a", "b"]}` + fn try_flatten_literal_union(variants: &[Value]) -> Option { + if variants.is_empty() { + return None; + } + + let mut all_values = Vec::new(); + let mut common_type: Option = None; + + for variant in variants { + let obj = variant.as_object()?; + + // Extract literal value from const or single-item enum + let literal_value = if let Some(const_val) = obj.get("const") { + const_val.clone() + } else if let Some(Value::Array(arr)) = obj.get("enum") { + if arr.len() == 1 { + arr[0].clone() + } else { + return None; + } + } else { + return None; + }; + + // Check type consistency + let variant_type = obj.get("type")?.as_str()?; + match &common_type { + None => common_type = Some(variant_type.to_string()), + Some(t) if t != variant_type => return None, + _ => {} + } + + all_values.push(literal_value); + } + + common_type.map(|t| { + json!({ + "type": t, + "enum": all_values + }) + }) + } + + /// Clean type array, removing null. + fn clean_type_array(value: Value) -> Value { + if let Value::Array(types) = value { + let non_null: Vec = types + .into_iter() + .filter(|v| v.as_str() != Some("null")) + .collect(); + + match non_null.len() { + 0 => Value::String("null".to_string()), + 1 => non_null + .into_iter() + .next() + .unwrap_or(Value::String("null".to_string())), + _ => Value::Array(non_null), + } + } else { + value + } + } + + /// Clean properties object. + fn clean_properties( + value: Value, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + if let Value::Object(props) = value { + let cleaned: Map = props + .into_iter() + .map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack))) + .collect(); + Value::Object(cleaned) + } else { + value + } + } + + /// Clean union (anyOf/oneOf/allOf). + fn clean_union( + value: Value, + defs: &HashMap, + strategy: CleaningStrategy, + ref_stack: &mut HashSet, + ) -> Value { + if let Value::Array(variants) = value { + let cleaned: Vec = variants + .into_iter() + .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack)) + .collect(); + Value::Array(cleaned) + } else { + value + } + } + + /// Preserve metadata (description, title, default) from source to target. + fn preserve_meta(source: &Map, mut target: Value) -> Value { + if let Value::Object(target_obj) = &mut target { + for &key in SCHEMA_META_KEYS { + if let Some(value) = source.get(key) { + target_obj.insert(key.to_string(), value.clone()); + } + } + } + target + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_remove_unsupported_keywords() { + let schema = json!({ + "type": "string", + "minLength": 1, + "maxLength": 100, + "pattern": "^[a-z]+$", + "description": "A lowercase string" + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + assert_eq!(cleaned["description"], "A lowercase string"); + assert!(cleaned.get("minLength").is_none()); + assert!(cleaned.get("maxLength").is_none()); + assert!(cleaned.get("pattern").is_none()); + } + + #[test] + fn test_resolve_ref() { + let schema = json!({ + "type": "object", + "properties": { + "age": { + "$ref": "#/$defs/Age" + } + }, + "$defs": { + "Age": { + "type": "integer", + "minimum": 0 + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["properties"]["age"]["type"], "integer"); + assert!(cleaned["properties"]["age"].get("minimum").is_none()); // Stripped by Gemini strategy + assert!(cleaned.get("$defs").is_none()); + } + + #[test] + fn test_flatten_literal_union() { + let schema = json!({ + "anyOf": [ + { "const": "admin", "type": "string" }, + { "const": "user", "type": "string" }, + { "const": "guest", "type": "string" } + ] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + assert!(cleaned["enum"].is_array()); + let enum_values = cleaned["enum"].as_array().unwrap(); + assert_eq!(enum_values.len(), 3); + assert!(enum_values.contains(&json!("admin"))); + assert!(enum_values.contains(&json!("user"))); + assert!(enum_values.contains(&json!("guest"))); + } + + #[test] + fn test_strip_null_from_union() { + let schema = json!({ + "oneOf": [ + { "type": "string" }, + { "type": "null" } + ] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + // Should simplify to just { type: "string" } + assert_eq!(cleaned["type"], "string"); + assert!(cleaned.get("oneOf").is_none()); + } + + #[test] + fn test_const_to_enum() { + let schema = json!({ + "const": "fixed_value", + "description": "A constant" + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["enum"], json!(["fixed_value"])); + assert_eq!(cleaned["description"], "A constant"); + assert!(cleaned.get("const").is_none()); + } + + #[test] + fn test_preserve_metadata() { + let schema = json!({ + "$ref": "#/$defs/Name", + "description": "User's name", + "title": "Name Field", + "default": "Anonymous", + "$defs": { + "Name": { + "type": "string" + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + assert_eq!(cleaned["description"], "User's name"); + assert_eq!(cleaned["title"], "Name Field"); + assert_eq!(cleaned["default"], "Anonymous"); + } + + #[test] + fn test_circular_ref_prevention() { + let schema = json!({ + "type": "object", + "properties": { + "parent": { + "$ref": "#/$defs/Node" + } + }, + "$defs": { + "Node": { + "type": "object", + "properties": { + "child": { + "$ref": "#/$defs/Node" + } + } + } + } + }); + + // Should not panic on circular reference + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["properties"]["parent"]["type"], "object"); + // Circular reference should be broken + } + + #[test] + fn test_validate_schema() { + let valid = json!({ + "type": "object", + "properties": { + "name": { "type": "string" } + } + }); + + assert!(SchemaCleanr::validate(&valid).is_ok()); + + let invalid = json!({ + "properties": { + "name": { "type": "string" } + } + }); + + assert!(SchemaCleanr::validate(&invalid).is_err()); + } + + #[test] + fn test_strategy_differences() { + let schema = json!({ + "type": "string", + "minLength": 1, + "description": "A string field" + }); + + // Gemini: Most restrictive (removes minLength) + let gemini = SchemaCleanr::clean_for_gemini(schema.clone()); + assert!(gemini.get("minLength").is_none()); + assert_eq!(gemini["type"], "string"); + assert_eq!(gemini["description"], "A string field"); + + // OpenAI: Most permissive (keeps minLength) + let openai = SchemaCleanr::clean_for_openai(schema.clone()); + assert_eq!(openai["minLength"], 1); // OpenAI allows validation keywords + assert_eq!(openai["type"], "string"); + } + + #[test] + fn test_nested_properties() { + let schema = json!({ + "type": "object", + "properties": { + "user": { + "type": "object", + "properties": { + "name": { + "type": "string", + "minLength": 1 + } + }, + "additionalProperties": false + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert!(cleaned["properties"]["user"]["properties"]["name"] + .get("minLength") + .is_none()); + assert!(cleaned["properties"]["user"] + .get("additionalProperties") + .is_none()); + } + + #[test] + fn test_type_array_null_removal() { + let schema = json!({ + "type": ["string", "null"] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + // Should simplify to just "string" + assert_eq!(cleaned["type"], "string"); + } + + #[test] + fn test_type_array_only_null_preserved() { + let schema = json!({ + "type": ["null"] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "null"); + } + + #[test] + fn test_ref_with_json_pointer_escape() { + let schema = json!({ + "$ref": "#/$defs/Foo~1Bar", + "$defs": { + "Foo/Bar": { + "type": "string" + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["type"], "string"); + } + + #[test] + fn test_skip_type_when_non_simplifiable_union_exists() { + let schema = json!({ + "type": "object", + "oneOf": [ + { + "type": "object", + "properties": { + "a": { "type": "string" } + } + }, + { + "type": "object", + "properties": { + "b": { "type": "number" } + } + } + ] + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert!(cleaned.get("type").is_none()); + assert!(cleaned.get("oneOf").is_some()); + } + + #[test] + fn test_clean_nested_unknown_schema_keyword() { + let schema = json!({ + "not": { + "$ref": "#/$defs/Age" + }, + "$defs": { + "Age": { + "type": "integer", + "minimum": 0 + } + } + }); + + let cleaned = SchemaCleanr::clean_for_gemini(schema); + + assert_eq!(cleaned["not"]["type"], "integer"); + assert!(cleaned["not"].get("minimum").is_none()); + } +} diff --git a/tests/memory_comparison.rs b/tests/memory_comparison.rs index 8e0f4d6..2523829 100644 --- a/tests/memory_comparison.rs +++ b/tests/memory_comparison.rs @@ -36,6 +36,7 @@ async fn compare_store_speed() { &format!("key_{i}"), &format!("Memory entry number {i} about Rust programming"), MemoryCategory::Core, + None, ) .await .unwrap(); @@ -49,6 +50,7 @@ async fn compare_store_speed() { &format!("key_{i}"), &format!("Memory entry number {i} about Rust programming"), MemoryCategory::Core, + None, ) .await .unwrap(); @@ -127,8 +129,8 @@ async fn compare_recall_quality() { ]; for (key, content, cat) in &entries { - sq.store(key, content, cat.clone()).await.unwrap(); - md.store(key, content, cat.clone()).await.unwrap(); + sq.store(key, content, cat.clone(), None).await.unwrap(); + md.store(key, content, cat.clone(), None).await.unwrap(); } // Test queries and compare results @@ -145,8 +147,8 @@ async fn compare_recall_quality() { println!("RECALL QUALITY (10 entries seeded):\n"); for (query, desc) in &queries { - let sq_results = sq.recall(query, 10).await.unwrap(); - let md_results = md.recall(query, 10).await.unwrap(); + let sq_results = sq.recall(query, 10, None).await.unwrap(); + let md_results = md.recall(query, 10, None).await.unwrap(); println!(" Query: \"{query}\" — {desc}"); println!(" SQLite: {} results", sq_results.len()); @@ -190,21 +192,21 @@ async fn compare_recall_speed() { } else { format!("TypeScript powers modern web apps, entry {i}") }; - sq.store(&format!("e{i}"), &content, MemoryCategory::Core) + sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None) .await .unwrap(); - md.store(&format!("e{i}"), &content, MemoryCategory::Daily) + md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None) .await .unwrap(); } // Benchmark recall let start = Instant::now(); - let sq_results = sq.recall("Rust systems", 10).await.unwrap(); + let sq_results = sq.recall("Rust systems", 10, None).await.unwrap(); let sq_dur = start.elapsed(); let start = Instant::now(); - let md_results = md.recall("Rust systems", 10).await.unwrap(); + let md_results = md.recall("Rust systems", 10, None).await.unwrap(); let md_dur = start.elapsed(); println!("\n============================================================"); @@ -227,15 +229,25 @@ async fn compare_persistence() { // Store in both, then drop and re-open { let sq = sqlite_backend(tmp_sq.path()); - sq.store("persist_test", "I should survive", MemoryCategory::Core) - .await - .unwrap(); + sq.store( + "persist_test", + "I should survive", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); } { let md = markdown_backend(tmp_md.path()); - md.store("persist_test", "I should survive", MemoryCategory::Core) - .await - .unwrap(); + md.store( + "persist_test", + "I should survive", + MemoryCategory::Core, + None, + ) + .await + .unwrap(); } // Re-open @@ -282,17 +294,17 @@ async fn compare_upsert() { let md = markdown_backend(tmp_md.path()); // Store twice with same key, different content - sq.store("pref", "likes Rust", MemoryCategory::Core) + sq.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - sq.store("pref", "loves Rust", MemoryCategory::Core) + sq.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); - md.store("pref", "likes Rust", MemoryCategory::Core) + md.store("pref", "likes Rust", MemoryCategory::Core, None) .await .unwrap(); - md.store("pref", "loves Rust", MemoryCategory::Core) + md.store("pref", "loves Rust", MemoryCategory::Core, None) .await .unwrap(); @@ -300,7 +312,7 @@ async fn compare_upsert() { let md_count = md.count().await.unwrap(); let sq_entry = sq.get("pref").await.unwrap(); - let md_results = md.recall("loves Rust", 5).await.unwrap(); + let md_results = md.recall("loves Rust", 5, None).await.unwrap(); println!("\n============================================================"); println!("UPSERT (store same key twice):"); @@ -328,10 +340,10 @@ async fn compare_forget() { let sq = sqlite_backend(tmp_sq.path()); let md = markdown_backend(tmp_md.path()); - sq.store("secret", "API key: sk-1234", MemoryCategory::Core) + sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None) .await .unwrap(); - md.store("secret", "API key: sk-1234", MemoryCategory::Core) + md.store("secret", "API key: sk-1234", MemoryCategory::Core, None) .await .unwrap(); @@ -372,37 +384,40 @@ async fn compare_category_filter() { let md = markdown_backend(tmp_md.path()); // Mix of categories - sq.store("a", "core fact 1", MemoryCategory::Core) + sq.store("a", "core fact 1", MemoryCategory::Core, None) .await .unwrap(); - sq.store("b", "core fact 2", MemoryCategory::Core) + sq.store("b", "core fact 2", MemoryCategory::Core, None) .await .unwrap(); - sq.store("c", "daily note", MemoryCategory::Daily) + sq.store("c", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - sq.store("d", "convo msg", MemoryCategory::Conversation) + sq.store("d", "convo msg", MemoryCategory::Conversation, None) .await .unwrap(); - md.store("a", "core fact 1", MemoryCategory::Core) + md.store("a", "core fact 1", MemoryCategory::Core, None) .await .unwrap(); - md.store("b", "core fact 2", MemoryCategory::Core) + md.store("b", "core fact 2", MemoryCategory::Core, None) .await .unwrap(); - md.store("c", "daily note", MemoryCategory::Daily) + md.store("c", "daily note", MemoryCategory::Daily, None) .await .unwrap(); - let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap(); - let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap(); - let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap(); - let sq_all = sq.list(None).await.unwrap(); + let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap(); + let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + let sq_conv = sq + .list(Some(&MemoryCategory::Conversation), None) + .await + .unwrap(); + let sq_all = sq.list(None, None).await.unwrap(); - let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap(); - let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap(); - let md_all = md.list(None).await.unwrap(); + let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap(); + let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap(); + let md_all = md.list(None, None).await.unwrap(); println!("\n============================================================"); println!("CATEGORY FILTERING:");