diff --git a/.env.example b/.env.example
index 17686d3..7a2c253 100644
--- a/.env.example
+++ b/.env.example
@@ -1,25 +1,69 @@
# ZeroClaw Environment Variables
-# Copy this file to .env and fill in your values.
-# NEVER commit .env — it is listed in .gitignore.
+# Copy this file to `.env` and fill in your local values.
+# Never commit `.env` or any real secrets.
-# ── Required ──────────────────────────────────────────────────
-# Your LLM provider API key
-# ZEROCLAW_API_KEY=sk-your-key-here
+# ── Core Runtime ──────────────────────────────────────────────
+# Provider key resolution at runtime:
+# 1) explicit key passed from config/CLI
+# 2) provider-specific env var (OPENROUTER_API_KEY, OPENAI_API_KEY, ...)
+# 3) generic fallback env vars below
+
+# Generic fallback API key (used when provider-specific key is absent)
API_KEY=your-api-key-here
+# ZEROCLAW_API_KEY=your-api-key-here
-# ── Provider & Model ─────────────────────────────────────────
-# LLM provider: openrouter, openai, anthropic, ollama, glm
+# Default provider/model (can be overridden by CLI flags)
PROVIDER=openrouter
+# ZEROCLAW_PROVIDER=openrouter
# ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514
# ZEROCLAW_TEMPERATURE=0.7
+# Workspace directory override
+# ZEROCLAW_WORKSPACE=/path/to/workspace
+
+# ── Provider-Specific API Keys ────────────────────────────────
+# OpenRouter
+# OPENROUTER_API_KEY=sk-or-v1-...
+
+# Anthropic
+# ANTHROPIC_OAUTH_TOKEN=...
+# ANTHROPIC_API_KEY=sk-ant-...
+
+# OpenAI / Gemini
+# OPENAI_API_KEY=sk-...
+# GEMINI_API_KEY=...
+# GOOGLE_API_KEY=...
+
+# Other supported providers
+# VENICE_API_KEY=...
+# GROQ_API_KEY=...
+# MISTRAL_API_KEY=...
+# DEEPSEEK_API_KEY=...
+# XAI_API_KEY=...
+# TOGETHER_API_KEY=...
+# FIREWORKS_API_KEY=...
+# PERPLEXITY_API_KEY=...
+# COHERE_API_KEY=...
+# MOONSHOT_API_KEY=...
+# GLM_API_KEY=...
+# MINIMAX_API_KEY=...
+# QIANFAN_API_KEY=...
+# DASHSCOPE_API_KEY=...
+# ZAI_API_KEY=...
+# SYNTHETIC_API_KEY=...
+# OPENCODE_API_KEY=...
+# VERCEL_API_KEY=...
+# CLOUDFLARE_API_KEY=...
+
# ── Gateway ──────────────────────────────────────────────────
# ZEROCLAW_GATEWAY_PORT=3000
# ZEROCLAW_GATEWAY_HOST=127.0.0.1
# ZEROCLAW_ALLOW_PUBLIC_BIND=false
-# ── Workspace ────────────────────────────────────────────────
-# ZEROCLAW_WORKSPACE=/path/to/workspace
+# ── Optional Integrations ────────────────────────────────────
+# Pushover notifications (`pushover` tool)
+# PUSHOVER_TOKEN=your-pushover-app-token
+# PUSHOVER_USER_KEY=your-pushover-user-key
# ── Docker Compose ───────────────────────────────────────────
# Host port mapping (used by docker-compose.yml)
diff --git a/.githooks/pre-commit b/.githooks/pre-commit
new file mode 100755
index 0000000..d162ba3
--- /dev/null
+++ b/.githooks/pre-commit
@@ -0,0 +1,8 @@
+#!/usr/bin/env bash
+set -euo pipefail
+
+if command -v gitleaks >/dev/null 2>&1; then
+ gitleaks protect --staged --redact
+else
+ echo "warning: gitleaks not found; skipping staged secret scan" >&2
+fi
diff --git a/.github/pull_request_template.md b/.github/pull_request_template.md
index 550bd95..7c9e601 100644
--- a/.github/pull_request_template.md
+++ b/.github/pull_request_template.md
@@ -12,7 +12,7 @@ Describe this PR in 2-5 bullets:
- Risk label (`risk: low|medium|high`):
- Size label (`size: XS|S|M|L|XL`, auto-managed/read-only):
- Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated):
-- Module labels (`:`, for example `channel:telegram`, `provider:kimi`, `tool:shell`):
+- Module labels (`: `, for example `channel: telegram`, `provider: kimi`, `tool: shell`):
- Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50):
- If any auto-label is incorrect, note requested correction:
diff --git a/.github/workflows/auto-response.yml b/.github/workflows/auto-response.yml
index 4398085..753bb52 100644
--- a/.github/workflows/auto-response.yml
+++ b/.github/workflows/auto-response.yml
@@ -18,6 +18,7 @@ jobs:
runs-on: blacksmith-2vcpu-ubuntu-2404
permissions:
issues: write
+ pull-requests: write
steps:
- name: Apply contributor tier label for issue author
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
index 63ea2ad..67005c6 100644
--- a/.github/workflows/docker.yml
+++ b/.github/workflows/docker.yml
@@ -35,7 +35,7 @@ jobs:
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- name: Setup Blacksmith Builder
- uses: useblacksmith/setup-docker-builder@v1
+ uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
- name: Extract metadata (tags, labels)
id: meta
@@ -46,7 +46,7 @@ jobs:
type=ref,event=pr
- name: Build smoke image
- uses: useblacksmith/build-push-action@v2
+ uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
with:
context: .
push: false
@@ -71,7 +71,7 @@ jobs:
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- name: Setup Blacksmith Builder
- uses: useblacksmith/setup-docker-builder@v1
+ uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
- name: Log in to Container Registry
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
@@ -102,7 +102,7 @@ jobs:
echo "tags=${TAGS}" >> "$GITHUB_OUTPUT"
- name: Build and push Docker image
- uses: useblacksmith/build-push-action@v2
+ uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
with:
context: .
push: true
diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml
index d629a1f..10d8bfb 100644
--- a/.github/workflows/labeler.yml
+++ b/.github/workflows/labeler.yml
@@ -325,13 +325,18 @@ jobs:
return pattern.test(text);
}
+ function formatModuleLabel(prefix, segment) {
+ return `${prefix}: ${segment}`;
+ }
+
function parseModuleLabel(label) {
- const separatorIndex = label.indexOf(":");
- if (separatorIndex <= 0 || separatorIndex >= label.length - 1) return null;
- return {
- prefix: label.slice(0, separatorIndex),
- segment: label.slice(separatorIndex + 1),
- };
+ if (typeof label !== "string") return null;
+ const match = label.match(/^([^:]+):\s*(.+)$/);
+ if (!match) return null;
+ const prefix = match[1].trim().toLowerCase();
+ const segment = (match[2] || "").trim().toLowerCase();
+ if (!prefix || !segment) return null;
+ return { prefix, segment };
}
function sortByPriority(labels, priorityIndex) {
@@ -389,7 +394,7 @@ jobs:
for (const [prefix, segments] of segmentsByPrefix) {
const hasSpecificSegment = [...segments].some((segment) => segment !== "core");
if (hasSpecificSegment) {
- refined.delete(`${prefix}:core`);
+ refined.delete(formatModuleLabel(prefix, "core"));
}
}
@@ -418,7 +423,7 @@ jobs:
if (uniqueSegments.length === 0) continue;
if (uniqueSegments.length === 1) {
- compactedModuleLabels.add(`${prefix}:${uniqueSegments[0]}`);
+ compactedModuleLabels.add(formatModuleLabel(prefix, uniqueSegments[0]));
} else {
forcePathPrefixes.add(prefix);
}
@@ -609,7 +614,7 @@ jobs:
segment = normalizeLabelSegment(segment);
if (!segment) continue;
- detectedModuleLabels.add(`${rule.prefix}:${segment}`);
+ detectedModuleLabels.add(formatModuleLabel(rule.prefix, segment));
}
}
@@ -635,7 +640,7 @@ jobs:
for (const keyword of providerKeywordHints) {
if (containsKeyword(searchableText, keyword)) {
- detectedModuleLabels.add(`provider:${keyword}`);
+ detectedModuleLabels.add(formatModuleLabel("provider", keyword));
}
}
}
@@ -661,7 +666,7 @@ jobs:
for (const keyword of channelKeywordHints) {
if (containsKeyword(searchableText, keyword)) {
- detectedModuleLabels.add(`channel:${keyword}`);
+ detectedModuleLabels.add(formatModuleLabel("channel", keyword));
}
}
}
diff --git a/.gitignore b/.gitignore
index 49980c2..9440b79 100644
--- a/.gitignore
+++ b/.gitignore
@@ -4,6 +4,26 @@ firmware/*/target
*.db-journal
.DS_Store
.wt-pr37/
-.env
__pycache__/
*.pyc
+docker-compose.override.yml
+
+# Environment files (may contain secrets)
+.env
+
+# Python virtual environments
+
+.venv/
+venv/
+
+# ESP32 build cache (esp-idf-sys managed)
+
+.embuild/
+.env.local
+.env.*.local
+
+# Secret keys and credentials
+.secret_key
+*.key
+*.pem
+credentials.json
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index a25ad4e..d98a2ce 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -79,6 +79,94 @@ git push --no-verify
> **Note:** CI runs the same checks, so skipped hooks will be caught on the PR.
+## Local Secret Management (Required)
+
+ZeroClaw supports layered secret management for local development and CI hygiene.
+
+### Secret Storage Options
+
+1. **Environment variables** (recommended for local development)
+ - Copy `.env.example` to `.env` and fill in values
+ - `.env` files are Git-ignored and should stay local
+ - Best for temporary/local API keys
+
+2. **Config file** (`~/.zeroclaw/config.toml`)
+ - Persistent setup for long-term use
+ - When `secrets.encrypt = true` (default), secret values are encrypted before save
+ - Secret key is stored at `~/.zeroclaw/.secret_key` with restricted permissions
+ - Use `zeroclaw onboard` for guided setup
+
+### Runtime Resolution Rules
+
+API key resolution follows this order:
+
+1. Explicit key passed from config/CLI
+2. Provider-specific env vars (`OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, ...)
+3. Generic env vars (`ZEROCLAW_API_KEY`, `API_KEY`)
+
+Provider/model config overrides:
+
+- `ZEROCLAW_PROVIDER` / `PROVIDER`
+- `ZEROCLAW_MODEL`
+
+See `.env.example` for practical examples and currently supported provider key env vars.
+
+### Pre-Commit Secret Hygiene (Mandatory)
+
+Before every commit, verify:
+
+- [ ] No `.env` files are staged (`.env.example` only)
+- [ ] No raw API keys/tokens in code, tests, fixtures, examples, logs, or commit messages
+- [ ] No credentials in debug output or error payloads
+- [ ] `git diff --cached` has no accidental secret-like strings
+
+Quick local audit:
+
+```bash
+# Search staged diff for common secret markers
+git diff --cached | grep -iE '(api[_-]?key|secret|token|password|bearer|sk-)'
+
+# Confirm no .env file is staged
+git diff --cached --name-only | grep -E '(^|/)\.env$'
+```
+
+### Optional Local Secret Scanning
+
+For extra guardrails, install one of:
+
+- **gitleaks**: [GitHub - gitleaks/gitleaks](https://github.com/gitleaks/gitleaks)
+- **truffleHog**: [GitHub - trufflesecurity/trufflehog](https://github.com/trufflesecurity/trufflehog)
+- **git-secrets**: [GitHub - awslabs/git-secrets](https://github.com/awslabs/git-secrets)
+
+This repo includes `.githooks/pre-commit` to run `gitleaks protect --staged --redact` when gitleaks is installed.
+
+Enable hooks with:
+
+```bash
+git config core.hooksPath .githooks
+```
+
+If gitleaks is not installed, the pre-commit hook prints a warning and continues.
+
+### What Must Never Be Committed
+
+- `.env` files (use `.env.example` only)
+- API keys, tokens, passwords, or credentials (plain or encrypted)
+- OAuth tokens or session identifiers
+- Webhook signing secrets
+- `~/.zeroclaw/.secret_key` or similar key files
+- Personal identifiers or real user data in tests/fixtures
+
+### If a Secret Is Committed Accidentally
+
+1. Revoke/rotate the credential immediately
+2. Do not rely only on `git revert` (history still contains the secret)
+3. Purge history with `git filter-repo` or BFG
+4. Force-push cleaned history (coordinate with maintainers)
+5. Ensure the leaked value is removed from PR/issue/discussion/comment history
+
+Reference: [GitHub guide: removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository)
+
## Collaboration Tracks (Risk-Based)
To keep review throughput high without lowering quality, every PR should map to one track:
diff --git a/Cargo.lock b/Cargo.lock
index d940f9f..e19c5c9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -209,6 +209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
dependencies = [
"axum-core",
+ "base64",
"bytes",
"form_urlencoded",
"futures-util",
@@ -227,8 +228,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
+ "sha1",
"sync_wrapper",
"tokio",
+ "tokio-tungstenite 0.28.0",
"tower",
"tower-layer",
"tower-service",
@@ -2057,6 +2060,15 @@ dependencies = [
"hashify",
]
+[[package]]
+name = "matchers"
+version = "0.2.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
+dependencies = [
+ "regex-automata",
+]
+
[[package]]
name = "matchit"
version = "0.8.4"
@@ -3747,10 +3759,22 @@ dependencies = [
"rustls-pki-types",
"tokio",
"tokio-rustls",
- "tungstenite",
+ "tungstenite 0.24.0",
"webpki-roots 0.26.11",
]
+[[package]]
+name = "tokio-tungstenite"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
+dependencies = [
+ "futures-util",
+ "log",
+ "tokio",
+ "tungstenite 0.28.0",
+]
+
[[package]]
name = "tokio-util"
version = "0.7.18"
@@ -3940,9 +3964,13 @@ version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
dependencies = [
+ "matchers",
"nu-ansi-term",
+ "once_cell",
+ "regex-automata",
"sharded-slab",
"thread_local",
+ "tracing",
"tracing-core",
]
@@ -3978,6 +4006,23 @@ dependencies = [
"utf-8",
]
+[[package]]
+name = "tungstenite"
+version = "0.28.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442"
+dependencies = [
+ "bytes",
+ "data-encoding",
+ "http 1.4.0",
+ "httparse",
+ "log",
+ "rand 0.9.2",
+ "sha1",
+ "thiserror 2.0.18",
+ "utf-8",
+]
+
[[package]]
name = "twox-hash"
version = "2.1.2"
@@ -4880,7 +4925,9 @@ dependencies = [
"pdf-extract",
"probe-rs",
"prometheus",
+ "prost",
"rand 0.8.5",
+ "regex",
"reqwest",
"rppal",
"rusqlite",
@@ -4896,7 +4943,7 @@ dependencies = [
"tokio-rustls",
"tokio-serial",
"tokio-test",
- "tokio-tungstenite",
+ "tokio-tungstenite 0.24.0",
"toml",
"tower",
"tower-http",
diff --git a/Cargo.toml b/Cargo.toml
index 79dcdfe..15d4665 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -1,3 +1,7 @@
+[workspace]
+members = ["."]
+resolver = "2"
+
[package]
name = "zeroclaw"
version = "0.1.0"
@@ -31,7 +35,7 @@ shellexpand = "3.1"
# Logging - minimal
tracing = { version = "0.1", default-features = false }
-tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
+tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
# Observability - Prometheus metrics
prometheus = { version = "0.14", default-features = false }
@@ -63,12 +67,12 @@ rand = "0.8"
# Fast mutexes that don't poison on panic
parking_lot = "0.12"
-# Landlock (Linux sandbox) - optional dependency
-landlock = { version = "0.4", optional = true }
-
# Async traits
async-trait = "0.1"
+# Protobuf encode/decode (Feishu WS long-connection frame codec)
+prost = { version = "0.14", default-features = false }
+
# Memory / persistence
rusqlite = { version = "0.38", features = ["bundled"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
@@ -86,6 +90,7 @@ glob = "0.3"
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
futures = "0.3"
+regex = "1.10"
hostname = "0.4.2"
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
mail-parser = "0.11.2"
@@ -95,7 +100,7 @@ tokio-rustls = "0.26.4"
webpki-roots = "1.0.6"
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
-axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query"] }
+axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] }
tower = { version = "0.5", default-features = false }
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
http-body-util = "0.1"
@@ -117,19 +122,28 @@ probe-rs = { version = "0.30", optional = true }
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
pdf-extract = { version = "0.10", optional = true }
-# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS
+# Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS
[target.'cfg(target_os = "linux")'.dependencies]
rppal = { version = "0.14", optional = true }
+landlock = { version = "0.4", optional = true }
[features]
default = ["hardware"]
hardware = ["nusb", "tokio-serial"]
peripheral-rpi = ["rppal"]
+# Browser backend feature alias used by cfg(feature = "browser-native")
+browser-native = ["dep:fantoccini"]
+# Backward-compatible alias for older invocations
+fantoccini = ["browser-native"]
+# Sandbox feature aliases used by cfg(feature = "sandbox-*")
+sandbox-landlock = ["dep:landlock"]
+sandbox-bubblewrap = []
+# Backward-compatible alias for older invocations
+landlock = ["sandbox-landlock"]
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
probe = ["dep:probe-rs"]
# rag-pdf = PDF ingestion for datasheet RAG
rag-pdf = ["dep:pdf-extract"]
-
[profile.release]
opt-level = "z" # Optimize for size
lto = "thin" # Lower memory use during release builds
diff --git a/LICENSE b/LICENSE
index 9d0e27e..349c342 100644
--- a/LICENSE
+++ b/LICENSE
@@ -1,197 +1,28 @@
- Apache License
- Version 2.0, January 2004
- http://www.apache.org/licenses/
+MIT License
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+Copyright (c) 2025 ZeroClaw Labs
- 1. Definitions.
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
- "License" shall mean the terms and conditions for use, reproduction,
- and distribution as defined by Sections 1 through 9 of this document.
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
- "Licensor" shall mean the copyright owner or entity authorized by
- the copyright owner that is granting the License.
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
- "Legal Entity" shall mean the union of the acting entity and all
- other entities that control, are controlled by, or are under common
- control with that entity. For the purposes of this definition,
- "control" means (i) the power, direct or indirect, to cause the
- direction or management of such entity, whether by contract or
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
- outstanding shares, or (iii) beneficial ownership of such entity.
+================================================================================
- "You" (or "Your") shall mean an individual or Legal Entity
- exercising permissions granted by this License.
+This product includes software developed by ZeroClaw Labs and contributors:
+https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
- "Source" form shall mean the preferred form for making modifications,
- including but not limited to software source code, documentation
- source, and configuration files.
-
- "Object" form shall mean any form resulting from mechanical
- transformation or translation of a Source form, including but
- not limited to compiled object code, generated documentation,
- and conversions to other media types.
-
- "Work" shall mean the work of authorship, whether in Source or
- Object form, made available under the License, as indicated by a
- copyright notice that is included in or attached to the work
- (an example is provided in the Appendix below).
-
- "Derivative Works" shall mean any work, whether in Source or Object
- form, that is based on (or derived from) the Work and for which the
- editorial revisions, annotations, elaborations, or other modifications
- represent, as a whole, an original work of authorship. For the purposes
- of this License, Derivative Works shall not include works that remain
- separable from, or merely link (or bind by name) to the interfaces of,
- the Work and Derivative Works thereof.
-
- "Contribution" shall mean any work of authorship, including
- the original version of the Work and any modifications or additions
- to that Work or Derivative Works thereof, that is intentionally
- submitted to the Licensor for inclusion in the Work by the copyright owner
- or by an individual or Legal Entity authorized to submit on behalf of
- the copyright owner. For the purposes of this definition, "submitted"
- means any form of electronic, verbal, or written communication sent
- to the Licensor or its representatives, including but not limited to
- communication on electronic mailing lists, source code control systems,
- and issue tracking systems that are managed by, or on behalf of, the
- Licensor for the purpose of discussing and improving the Work, but
- excluding communication that is conspicuously marked or otherwise
- designated in writing by the copyright owner as "Not a Contribution."
-
- "Contributor" shall mean Licensor and any individual or Legal Entity
- on behalf of whom a Contribution has been received by Licensor and
- subsequently incorporated within the Work.
-
- 2. Grant of Copyright License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- copyright license to reproduce, prepare Derivative Works of,
- publicly display, publicly perform, sublicense, and distribute the
- Work and such Derivative Works in Source or Object form.
-
- 3. Grant of Patent License. Subject to the terms and conditions of
- this License, each Contributor hereby grants to You a perpetual,
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
- (except as stated in this section) patent license to make, have made,
- use, offer to sell, sell, import, and otherwise transfer the Work,
- where such license applies only to those patent claims licensable
- by such Contributor that are necessarily infringed by their
- Contribution(s) alone or by combination of their Contribution(s)
- with the Work to which such Contribution(s) was submitted. If You
- institute patent litigation against any entity (including a
- cross-claim or counterclaim in a lawsuit) alleging that the Work
- or a Contribution incorporated within the Work constitutes direct
- or contributory patent infringement, then any patent licenses
- granted to You under this License for that Work shall terminate
- as of the date such litigation is filed.
-
- 4. Redistribution. You may reproduce and distribute copies of the
- Work or Derivative Works thereof in any medium, with or without
- modifications, and in Source or Object form, provided that You
- meet the following conditions:
-
- (a) You must give any other recipients of the Work or
- Derivative Works a copy of this License; and
-
- (b) You must cause any modified files to carry prominent notices
- stating that You changed the files; and
-
- (c) You must retain, in the Source form of any Derivative Works
- that You distribute, all copyright, patent, trademark, and
- attribution notices from the Source form of the Work,
- excluding those notices that do not pertain to any part of
- the Derivative Works; and
-
- (d) If the Work includes a "NOTICE" text file as part of its
- distribution, then any Derivative Works that You distribute must
- include a readable copy of the attribution notices contained
- within such NOTICE file, excluding those notices that do not
- pertain to any part of the Derivative Works, in at least one
- of the following places: within a NOTICE text file distributed
- as part of the Derivative Works; within the Source form or
- documentation, if provided along with the Derivative Works; or,
- within a display generated by the Derivative Works, if and
- wherever such third-party notices normally appear. The contents
- of the NOTICE file are for informational purposes only and
- do not modify the License. You may add Your own attribution
- notices within Derivative Works that You distribute, alongside
- or as an addendum to the NOTICE text from the Work, provided
- that such additional attribution notices cannot be construed
- as modifying the License.
-
- You may add Your own copyright statement to Your modifications and
- may provide additional or different license terms and conditions
- for use, reproduction, or distribution of Your modifications, or
- for any such Derivative Works as a whole, provided Your use,
- reproduction, and distribution of the Work otherwise complies with
- the conditions stated in this License.
-
- 5. Submission of Contributions. Unless You explicitly state otherwise,
- any Contribution intentionally submitted for inclusion in the Work
- by You to the Licensor shall be under the terms and conditions of
- this License, without any additional terms or conditions.
- Notwithstanding the above, nothing herein shall supersede or modify
- the terms of any separate license agreement you may have executed
- with Licensor regarding such Contributions.
-
- 6. Trademarks. This License does not grant permission to use the trade
- names, trademarks, service marks, or product names of the Licensor,
- except as required for reasonable and customary use in describing the
- origin of the Work and reproducing the content of the NOTICE file.
-
- 7. Disclaimer of Warranty. Unless required by applicable law or
- agreed to in writing, Licensor provides the Work (and each
- Contributor provides its Contributions) on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
- implied, including, without limitation, any warranties or conditions
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
- PARTICULAR PURPOSE. You are solely responsible for determining the
- appropriateness of using or redistributing the Work and assume any
- risks associated with Your exercise of permissions under this License.
-
- 8. Limitation of Liability. In no event and under no legal theory,
- whether in tort (including negligence), contract, or otherwise,
- unless required by applicable law (such as deliberate and grossly
- negligent acts) or agreed to in writing, shall any Contributor be
- liable to You for damages, including any direct, indirect, special,
- incidental, or consequential damages of any character arising as a
- result of this License or out of the use or inability to use the
- Work (including but not limited to damages for loss of goodwill,
- work stoppage, computer failure or malfunction, or any and all
- other commercial damages or losses), even if such Contributor
- has been advised of the possibility of such damages.
-
- 9. Accepting Warranty or Additional Liability. While redistributing
- the Work or Derivative Works thereof, You may choose to offer,
- and charge a fee for, acceptance of support, warranty, indemnity,
- or other liability obligations and/or rights consistent with this
- License. However, in accepting such obligations, You may act only
- on Your own behalf and on Your sole responsibility, not on behalf
- of any other Contributor, and only if You agree to indemnify,
- defend, and hold each Contributor harmless for any liability
- incurred by, or claims asserted against, such Contributor by reason
- of your accepting any such warranty or additional liability.
-
- END OF TERMS AND CONDITIONS
-
- Copyright 2025-2026 Argenis Delarosa
-
- Licensed under the Apache License, Version 2.0 (the "License");
- you may not use this file except in compliance with the License.
- You may obtain a copy of the License at
-
- http://www.apache.org/licenses/LICENSE-2.0
-
- Unless required by applicable law or agreed to in writing, software
- distributed under the License is distributed on an "AS IS" BASIS,
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- See the License for the specific language governing permissions and
- limitations under the License.
-
- ===============================================================================
-
- This product includes software developed by ZeroClaw Labs and contributors:
- https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
-
- See NOTICE file for full contributor attribution.
+See NOTICE file for full contributor attribution.
diff --git a/README.md b/README.md
index 9031482..ec87d47 100644
--- a/README.md
+++ b/README.md
@@ -10,14 +10,14 @@
-
+
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
```
-~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything
+~3.4MB binary · <10ms startup · 1,017 tests · 23+ providers · 8 traits · Pluggable everything
```
### ✨ Features
@@ -132,6 +132,9 @@ cd zeroclaw
cargo build --release --locked
cargo install --path . --force --locked
+# Ensure ~/.cargo/bin is in your PATH
+export PATH="$HOME/.cargo/bin:$PATH"
+
# Quick setup (no prompts)
zeroclaw onboard --api-key sk-... --provider openrouter
@@ -187,7 +190,7 @@ Every subsystem is a **trait** — swap implementations with a config change, ze
| Subsystem | Trait | Ships with | Extend |
|-----------|-------|------------|--------|
-| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
+| **AI Models** | `Provider` | 23+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, Astrai, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend |
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability |
@@ -287,6 +290,21 @@ rerun channel setup only:
zeroclaw onboard --channels-only
```
+### Telegram media replies
+
+Telegram routing now replies to the source **chat ID** from incoming updates (instead of usernames),
+which avoids `Bad Request: chat not found` failures.
+
+For non-text replies, ZeroClaw can send Telegram attachments when the assistant includes markers:
+
+- `[IMAGE:<path>]`
+- `[DOCUMENT:<path>]`
+- `[VIDEO:<path>]`
+- `[AUDIO:<path>]`
+- `[VOICE:<path>]`
+
+Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs.
+
### WhatsApp Business Cloud API Setup
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
@@ -610,7 +628,7 @@ We're building in the open because the best ideas come from everywhere. If you'r
## License
-Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
+MIT — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
## Contributing
@@ -624,7 +642,6 @@ See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR:
- New `Tunnel` → `src/tunnel/`
- New `Skill` → `~/.zeroclaw/workspace/skills//`
-
---
**ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀
diff --git a/dev/sandbox/Dockerfile b/dev/sandbox/Dockerfile
index 59ddf05..6b81a7a 100644
--- a/dev/sandbox/Dockerfile
+++ b/dev/sandbox/Dockerfile
@@ -1,4 +1,4 @@
-FROM ubuntu:22.04
+FROM ubuntu:22.04@sha256:c7eb020043d8fc2ae0793fb35a37bff1cf33f156d4d4b12ccc7f3ef8706c38b1
# Prevent interactive prompts during package installation
ENV DEBIAN_FRONTEND=noninteractive
diff --git a/docs/ci-map.md b/docs/ci-map.md
index 108a9d0..6a2260d 100644
--- a/docs/ci-map.md
+++ b/docs/ci-map.md
@@ -27,7 +27,7 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
### Optional Repository Automation
- `.github/workflows/labeler.yml` (`PR Labeler`)
- - Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>:<name>`)
+ - Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>: <name>`)
- Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule
- Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`)
- Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`)
diff --git a/docs/pr-workflow.md b/docs/pr-workflow.md
index 3c62711..2c154ef 100644
--- a/docs/pr-workflow.md
+++ b/docs/pr-workflow.md
@@ -244,7 +244,7 @@ Label discipline:
- Path labels identify subsystem ownership quickly.
- Size labels drive batching strategy.
- Risk labels drive review depth (`risk: low/medium/high`).
-- Module labels (`<module>:<name>`) improve reviewer routing for integration-specific changes and future newly-added modules.
+- Module labels (`<module>: <name>`) improve reviewer routing for integration-specific changes and future newly-added modules.
- `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context.
- `no-stale` is reserved for accepted-but-blocked work.
diff --git a/docs/reviewer-playbook.md b/docs/reviewer-playbook.md
index bc42509..6f72fea 100644
--- a/docs/reviewer-playbook.md
+++ b/docs/reviewer-playbook.md
@@ -14,7 +14,7 @@ Use it to reduce review latency without reducing quality.
For every new PR, do a fast intake pass:
1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`).
-2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel:*`/`provider:*`/`tool:*`, and contributor tier labels when applicable) are present and plausible.
+2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel: *`/`provider: *`/`tool: *`, and contributor tier labels when applicable) are present and plausible.
3. Confirm CI signal status (`CI Required Gate`).
4. Confirm scope is one concern (reject mixed mega-PRs unless justified).
5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied.
diff --git a/examples/custom_channel.rs b/examples/custom_channel.rs
index dd3fdf8..790762d 100644
--- a/examples/custom_channel.rs
+++ b/examples/custom_channel.rs
@@ -12,6 +12,8 @@ use tokio::sync::mpsc;
pub struct ChannelMessage {
pub id: String,
pub sender: String,
+ /// Channel-specific reply address (e.g. Telegram chat_id, Discord channel_id).
+ pub reply_to: String,
pub content: String,
pub channel: String,
pub timestamp: u64,
@@ -90,9 +92,12 @@ impl Channel for TelegramChannel {
continue;
}
+ let chat_id = msg["chat"]["id"].to_string();
+
let channel_msg = ChannelMessage {
id: msg["message_id"].to_string(),
sender,
+ reply_to: chat_id,
content: msg["text"].as_str().unwrap_or("").to_string(),
channel: "telegram".into(),
timestamp: msg["date"].as_u64().unwrap_or(0),
diff --git a/firmware/zeroclaw-esp32/.cargo/config.toml b/firmware/zeroclaw-esp32/.cargo/config.toml
index 8746ad1..56dd71b 100644
--- a/firmware/zeroclaw-esp32/.cargo/config.toml
+++ b/firmware/zeroclaw-esp32/.cargo/config.toml
@@ -2,4 +2,10 @@
target = "riscv32imc-esp-espidf"
[target.riscv32imc-esp-espidf]
+linker = "ldproxy"
runner = "espflash flash --monitor"
+# ESP-IDF 5.x uses 64-bit time_t
+rustflags = ["-C", "default-linker-libraries", "--cfg", "espidf_time64"]
+
+[unstable]
+build-std = ["std", "panic_abort"]
diff --git a/firmware/zeroclaw-esp32/Cargo.lock b/firmware/zeroclaw-esp32/Cargo.lock
index 2580883..69e989b 100644
--- a/firmware/zeroclaw-esp32/Cargo.lock
+++ b/firmware/zeroclaw-esp32/Cargo.lock
@@ -58,24 +58,22 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "bindgen"
-version = "0.63.0"
+version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885"
+checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [
- "bitflags 1.3.2",
+ "bitflags 2.11.0",
"cexpr",
"clang-sys",
- "lazy_static",
- "lazycell",
+ "itertools",
"log",
- "peeking_take_while",
+ "prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
- "syn 1.0.109",
- "which",
+ "syn 2.0.116",
]
[[package]]
@@ -374,14 +372,15 @@ checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01"
[[package]]
name = "embassy-sync"
-version = "0.5.0"
+version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "dd938f25c0798db4280fcd8026bf4c2f48789aebf8f77b6e5cf8a7693ba114ec"
+checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b"
dependencies = [
"cfg-if",
"critical-section",
"embedded-io-async",
- "futures-util",
+ "futures-core",
+ "futures-sink",
"heapless",
]
@@ -446,16 +445,15 @@ dependencies = [
[[package]]
name = "embedded-svc"
-version = "0.27.1"
+version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "ac6f87e7654f28018340aa55f933803017aefabaa5417820a3b2f808033c7bbc"
+checksum = "a7770e30ab55cfbf954c00019522490d6ce26a3334bede05a732ba61010e98e0"
dependencies = [
"defmt 0.3.100",
"embedded-io",
"embedded-io-async",
"enumset",
"heapless",
- "no-std-net",
"num_enum",
"serde",
"strum 0.25.0",
@@ -463,9 +461,9 @@ dependencies = [
[[package]]
name = "embuild"
-version = "0.31.4"
+version = "0.33.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "4caa4f198bb9152a55c0103efb83fa4edfcbb8625f4c9e94ae8ec8e23827c563"
+checksum = "e188ad2bbe82afa841ea4a29880651e53ab86815db036b2cb9f8de3ac32dad75"
dependencies = [
"anyhow",
"bindgen",
@@ -475,6 +473,7 @@ dependencies = [
"globwalk",
"home",
"log",
+ "regex",
"remove_dir_all",
"serde",
"serde_json",
@@ -533,9 +532,8 @@ dependencies = [
[[package]]
name = "esp-idf-hal"
-version = "0.43.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "f7adf3fb19a9ca016cbea1ab8a7b852ac69df8fcde4923c23d3b155efbc42a74"
+version = "0.45.2"
+source = "git+https://github.com/esp-rs/esp-idf-hal#bc48639bd626c72afc1e25e5d497b5c639161d30"
dependencies = [
"atomic-waker",
"embassy-sync",
@@ -552,14 +550,12 @@ dependencies = [
"heapless",
"log",
"nb 1.1.0",
- "num_enum",
]
[[package]]
name = "esp-idf-svc"
-version = "0.48.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2180642ca122a7fec1ec417a9b1a77aa66aaa067fdf1daae683dd8caba84f26b"
+version = "0.51.0"
+source = "git+https://github.com/esp-rs/esp-idf-svc#dee202f146c7681e54eabbf118a216fc0195d203"
dependencies = [
"embassy-futures",
"embedded-hal-async",
@@ -567,6 +563,7 @@ dependencies = [
"embuild",
"enumset",
"esp-idf-hal",
+ "futures-io",
"heapless",
"log",
"num_enum",
@@ -575,14 +572,13 @@ dependencies = [
[[package]]
name = "esp-idf-sys"
-version = "0.34.1"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "2e148f97c04ed3e9181a08bcdc9560a515aad939b0ba7f50a0022e294665e0af"
+version = "0.36.1"
+source = "git+https://github.com/esp-rs/esp-idf-sys#64667a38fb8004e1fc3b032488af6857ca3cd849"
dependencies = [
"anyhow",
- "bindgen",
"build-time",
"cargo_metadata",
+ "cmake",
"const_format",
"embuild",
"envy",
@@ -649,21 +645,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
[[package]]
-name = "futures-task"
+name = "futures-io"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
+checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
-name = "futures-util"
+name = "futures-sink"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
-dependencies = [
- "futures-core",
- "futures-task",
- "pin-project-lite",
-]
+checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
[[package]]
name = "getrandom"
@@ -827,6 +818,15 @@ dependencies = [
"serde_core",
]
+[[package]]
+name = "itertools"
+version = "0.13.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
+dependencies = [
+ "either",
+]
+
[[package]]
name = "itoa"
version = "1.0.17"
@@ -843,18 +843,6 @@ dependencies = [
"wasm-bindgen",
]
-[[package]]
-name = "lazy_static"
-version = "1.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
-
-[[package]]
-name = "lazycell"
-version = "1.3.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
-
[[package]]
name = "leb128fmt"
version = "0.1.0"
@@ -945,12 +933,6 @@ dependencies = [
"libc",
]
-[[package]]
-name = "no-std-net"
-version = "0.5.0"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "1bcece43b12349917e096cddfa66107277f123e6c96a5aea78711dc601a47152"
-
[[package]]
name = "nom"
version = "7.1.3"
@@ -1007,18 +989,6 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
-[[package]]
-name = "peeking_take_while"
-version = "0.1.2"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
-
-[[package]]
-name = "pin-project-lite"
-version = "0.2.16"
-source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
-
[[package]]
name = "prettyplease"
version = "0.2.37"
@@ -1138,9 +1108,9 @@ dependencies = [
[[package]]
name = "rustc-hash"
-version = "1.1.0"
+version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
-checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
+checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustix"
diff --git a/firmware/zeroclaw-esp32/Cargo.toml b/firmware/zeroclaw-esp32/Cargo.toml
index 70d2611..2ec056f 100644
--- a/firmware/zeroclaw-esp32/Cargo.toml
+++ b/firmware/zeroclaw-esp32/Cargo.toml
@@ -14,15 +14,21 @@ edition = "2021"
license = "MIT"
description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial"
+[patch.crates-io]
+# Use latest esp-rs crates to fix u8/i8 char pointer compatibility with ESP-IDF 5.x
+esp-idf-sys = { git = "https://github.com/esp-rs/esp-idf-sys" }
+esp-idf-hal = { git = "https://github.com/esp-rs/esp-idf-hal" }
+esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
+
[dependencies]
-esp-idf-svc = "0.48"
+esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
log = "0.4"
anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
[build-dependencies]
-embuild = "0.31"
+embuild = { version = "0.33", features = ["espidf"] }
[profile.release]
opt-level = "s"
diff --git a/firmware/zeroclaw-esp32/README.md b/firmware/zeroclaw-esp32/README.md
index 804aaca..f4b2c08 100644
--- a/firmware/zeroclaw-esp32/README.md
+++ b/firmware/zeroclaw-esp32/README.md
@@ -2,8 +2,11 @@
Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial.
+**New to this?** See [SETUP.md](SETUP.md) for step-by-step commands and troubleshooting.
+
## Protocol
+
- **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n`
- **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n`
@@ -11,19 +14,44 @@ Commands: `gpio_read`, `gpio_write`.
## Prerequisites
-1. **ESP toolchain** (espup):
+1. **RISC-V ESP-IDF** (ESP32-C2/C3): Uses nightly Rust with `build-std`.
+
+ **Python**: ESP-IDF requires Python 3.10–3.13 (not 3.14). If you have Python 3.14:
+ ```sh
+ brew install python@3.12
+ ```
+
+ **virtualenv** (needed by ESP-IDF tools; PEP 668 workaround on macOS):
+ ```sh
+ /opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
+ ```
+
+ **Rust tools**:
+ ```sh
+ cargo install espflash ldproxy
+ ```
+
+ The project's `rust-toolchain.toml` pins nightly + rust-src. `esp-idf-sys` downloads ESP-IDF automatically on first build. Use Python 3.12 for the build:
+ ```sh
+ export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
+ ```
+
+2. **Xtensa targets** (ESP32, ESP32-S2, ESP32-S3): Use espup instead:
```sh
cargo install espup espflash
espup install
- source ~/export-esp.sh # or ~/export-esp.fish for Fish
+ source ~/export-esp.sh
```
-
-2. **Target**: ESP32-C3 (RISC-V) by default. Edit `.cargo/config.toml` for other targets (e.g. `xtensa-esp32-espidf` for original ESP32).
+ Then edit `.cargo/config.toml` to change the target (e.g. `xtensa-esp32-espidf`).
## Build & Flash
```sh
cd firmware/zeroclaw-esp32
+# Use Python 3.12 (required if you have 3.14)
+export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
+# Optional: pin MCU (esp32c3 or esp32c2)
+export MCU=esp32c3
cargo build --release
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
```
diff --git a/firmware/zeroclaw-esp32/SETUP.md b/firmware/zeroclaw-esp32/SETUP.md
new file mode 100644
index 0000000..0624f4d
--- /dev/null
+++ b/firmware/zeroclaw-esp32/SETUP.md
@@ -0,0 +1,156 @@
+# ESP32 Firmware Setup Guide
+
+Step-by-step setup for building the ZeroClaw ESP32 firmware. Follow this if you run into issues.
+
+## Quick Start (copy-paste)
+
+```sh
+# 1. Install Python 3.12 (ESP-IDF needs 3.10–3.13, not 3.14)
+brew install python@3.12
+
+# 2. Install virtualenv (PEP 668 workaround on macOS)
+/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
+
+# 3. Install Rust tools
+cargo install espflash ldproxy
+
+# 4. Build
+cd firmware/zeroclaw-esp32
+export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
+cargo build --release
+
+# 5. Flash (connect ESP32 via USB)
+espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
+```
+
+---
+
+## Detailed Steps
+
+### 1. Python
+
+ESP-IDF requires Python 3.10–3.13. **Python 3.14 is not supported.**
+
+```sh
+brew install python@3.12
+```
+
+### 2. virtualenv
+
+ESP-IDF tools need `virtualenv`. On macOS with Homebrew Python, PEP 668 blocks `pip install`; use:
+
+```sh
+/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
+```
+
+### 3. Rust Tools
+
+```sh
+cargo install espflash ldproxy
+```
+
+- **espflash**: flash and monitor
+- **ldproxy**: linker for ESP-IDF builds
+
+### 4. Use Python 3.12 for Builds
+
+Before every build (or add to `~/.zshrc`):
+
+```sh
+export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
+```
+
+### 5. Build
+
+```sh
+cd firmware/zeroclaw-esp32
+cargo build --release
+```
+
+First build downloads and compiles ESP-IDF (~5–15 min).
+
+### 6. Flash
+
+```sh
+espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
+```
+
+---
+
+## Troubleshooting
+
+### "No space left on device"
+
+Free disk space. Common targets:
+
+```sh
+# Cargo cache (often 5–20 GB)
+rm -rf ~/.cargo/registry/cache ~/.cargo/registry/src
+
+# Unused Rust toolchains
+rustup toolchain list
+rustup toolchain uninstall <toolchain>
+
+# iOS Simulator runtimes (~35 GB)
+xcrun simctl delete unavailable
+
+# Temp files
+rm -rf /var/folders/*/T/cargo-install*
+```
+
+### "can't find crate for `core`" / "riscv32imc-esp-espidf target may not be installed"
+
+This project uses **nightly Rust with build-std**, not espup. Ensure:
+
+- `rust-toolchain.toml` exists (pins nightly + rust-src)
+- You are **not** sourcing `~/export-esp.sh` (that's for Xtensa targets)
+- Run `cargo build` from `firmware/zeroclaw-esp32`
+
+### "externally-managed-environment" / "No module named 'virtualenv'"
+
+Install virtualenv with the PEP 668 workaround:
+
+```sh
+/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
+```
+
+### "expected `i64`, found `i32`" (time_t mismatch)
+
+Already fixed in `.cargo/config.toml` with `espidf_time64` for ESP-IDF 5.x. If you use ESP-IDF 4.4, switch to `espidf_time32`.
+
+### "expected `*const u8`, found `*const i8`" (esp-idf-svc)
+
+Already fixed via `[patch.crates-io]` in `Cargo.toml` using esp-rs crates from git. Do not remove the patch.
+
+### 10,000+ files in `git status`
+
+The `.embuild/` directory (ESP-IDF cache) has ~100k+ files. It is in `.gitignore`. If you see them, ensure `.gitignore` contains:
+
+```
+.embuild/
+```
+
+---
+
+## Optional: Auto-load Python 3.12
+
+Add to `~/.zshrc`:
+
+```sh
+# ESP32 firmware build
+export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
+```
+
+---
+
+## Xtensa Targets (ESP32, ESP32-S2, ESP32-S3)
+
+For non–RISC-V chips, use espup instead:
+
+```sh
+cargo install espup espflash
+espup install
+source ~/export-esp.sh
+```
+
+Then edit `.cargo/config.toml` to use `xtensa-esp32-espidf` (or the correct target).
diff --git a/firmware/zeroclaw-esp32/rust-toolchain.toml b/firmware/zeroclaw-esp32/rust-toolchain.toml
new file mode 100644
index 0000000..f70d225
--- /dev/null
+++ b/firmware/zeroclaw-esp32/rust-toolchain.toml
@@ -0,0 +1,3 @@
+[toolchain]
+channel = "nightly"
+components = ["rust-src"]
diff --git a/firmware/zeroclaw-esp32/src/main.rs b/firmware/zeroclaw-esp32/src/main.rs
index b1a487c..a85b67d 100644
--- a/firmware/zeroclaw-esp32/src/main.rs
+++ b/firmware/zeroclaw-esp32/src/main.rs
@@ -6,8 +6,9 @@
//! Protocol: same as STM32 — see docs/hardware-peripherals-design.md
use esp_idf_svc::hal::gpio::PinDriver;
-use esp_idf_svc::hal::prelude::*;
-use esp_idf_svc::hal::uart::*;
+use esp_idf_svc::hal::peripherals::Peripherals;
+use esp_idf_svc::hal::uart::{UartConfig, UartDriver};
+use esp_idf_svc::hal::units::Hertz;
use log::info;
use serde::{Deserialize, Serialize};
@@ -36,9 +37,13 @@ fn main() -> anyhow::Result<()> {
let peripherals = Peripherals::take()?;
let pins = peripherals.pins;
+ // Create GPIO output drivers first (they take ownership of pins)
+ let mut gpio2 = PinDriver::output(pins.gpio2)?;
+ let mut gpio13 = PinDriver::output(pins.gpio13)?;
+
// UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board
let config = UartConfig::new().baudrate(Hertz(115_200));
- let mut uart = UartDriver::new(
+ let uart = UartDriver::new(
peripherals.uart0,
pins.gpio21,
pins.gpio20,
@@ -60,7 +65,8 @@ fn main() -> anyhow::Result<()> {
if b == b'\n' {
if !line.is_empty() {
if let Ok(line_str) = std::str::from_utf8(&line) {
- if let Ok(resp) = handle_request(line_str, &peripherals) {
+ if let Ok(resp) = handle_request(line_str, &mut gpio2, &mut gpio13)
+ {
let out = serde_json::to_string(&resp).unwrap_or_default();
let _ = uart.write(format!("{}\n", out).as_bytes());
}
@@ -80,10 +86,15 @@ fn main() -> anyhow::Result<()> {
}
}
-fn handle_request(
+fn handle_request<G2, G13>(
line: &str,
- peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
-) -> anyhow::Result<Response> {
+ gpio2: &mut PinDriver<'_, G2>,
+ gpio13: &mut PinDriver<'_, G13>,
+) -> anyhow::Result<Response>
+where
+ G2: esp_idf_svc::hal::gpio::OutputMode,
+ G13: esp_idf_svc::hal::gpio::OutputMode,
+{
let req: Request = serde_json::from_str(line.trim())?;
let id = req.id.clone();
@@ -98,13 +109,13 @@ fn handle_request(
}
"gpio_read" => {
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
- let value = gpio_read(peripherals, pin_num)?;
+ let value = gpio_read(pin_num)?;
Ok(value.to_string())
}
"gpio_write" => {
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0);
- gpio_write(peripherals, pin_num, value)?;
+ gpio_write(gpio2, gpio13, pin_num, value)?;
Ok("done".into())
}
_ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)),
@@ -126,28 +137,26 @@ fn handle_request(
}
}
-fn gpio_read(_peripherals: &esp_idf_svc::hal::peripherals::Peripherals, _pin: i32) -> anyhow::Result<u8> {
+fn gpio_read(_pin: i32) -> anyhow::Result<u8> {
// TODO: implement input pin read — requires storing InputPin drivers per pin
Ok(0)
}
-fn gpio_write(
- peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
+fn gpio_write<G2, G13>(
+ gpio2: &mut PinDriver<'_, G2>,
+ gpio13: &mut PinDriver<'_, G13>,
pin: i32,
value: u64,
-) -> anyhow::Result<()> {
- let pins = peripherals.pins;
- let level = value != 0;
+) -> anyhow::Result<()>
+where
+ G2: esp_idf_svc::hal::gpio::OutputMode,
+ G13: esp_idf_svc::hal::gpio::OutputMode,
+{
+ let level = esp_idf_svc::hal::gpio::Level::from(value != 0);
match pin {
- 2 => {
- let mut out = PinDriver::output(pins.gpio2)?;
- out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
- }
- 13 => {
- let mut out = PinDriver::output(pins.gpio13)?;
- out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
- }
+ 2 => gpio2.set_level(level)?,
+ 13 => gpio13.set_level(level)?,
_ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin),
}
Ok(())
diff --git a/scripts/recompute_contributor_tiers.sh b/scripts/recompute_contributor_tiers.sh
new file mode 100755
index 0000000..6e3e528
--- /dev/null
+++ b/scripts/recompute_contributor_tiers.sh
@@ -0,0 +1,324 @@
+#!/usr/bin/env bash
+
+set -euo pipefail
+
+SCRIPT_NAME="$(basename "$0")"
+
+usage() {
+ cat <<USAGE
+Usage: ./$SCRIPT_NAME [options]
+
+Options:
+  --repo <owner/repo>        Target repository (default: current gh repo)
+  --kind <both|prs|issues>   Target objects (default: both)
+  --state <all|open|closed>  State filter for listing objects (default: all)
+  --limit <n>                Limit processed objects after fetch (default: 0 = no limit)
+  --apply                    Apply label updates (default is dry-run)
+  --dry-run                  Preview only (default)
+  -h, --help                 Show this help
+
+Examples:
+  ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --limit 50
+  ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --kind prs --state open --apply
+USAGE
+}
+
+die() {
+ echo "[$SCRIPT_NAME] ERROR: $*" >&2
+ exit 1
+}
+
+require_cmd() {
+ if ! command -v "$1" >/dev/null 2>&1; then
+ die "Required command not found: $1"
+ fi
+}
+
+urlencode() {
+ jq -nr --arg value "$1" '$value|@uri'
+}
+
+select_contributor_tier() {
+ local merged_count="$1"
+ if (( merged_count >= 50 )); then
+ echo "distinguished contributor"
+ elif (( merged_count >= 20 )); then
+ echo "principal contributor"
+ elif (( merged_count >= 10 )); then
+ echo "experienced contributor"
+ elif (( merged_count >= 5 )); then
+ echo "trusted contributor"
+ else
+ echo ""
+ fi
+}
+
+DRY_RUN=1
+KIND="both"
+STATE="all"
+LIMIT=0
+REPO=""
+
+while (($# > 0)); do
+ case "$1" in
+ --repo)
+ [[ $# -ge 2 ]] || die "Missing value for --repo"
+ REPO="$2"
+ shift 2
+ ;;
+ --kind)
+ [[ $# -ge 2 ]] || die "Missing value for --kind"
+ KIND="$2"
+ shift 2
+ ;;
+ --state)
+ [[ $# -ge 2 ]] || die "Missing value for --state"
+ STATE="$2"
+ shift 2
+ ;;
+ --limit)
+ [[ $# -ge 2 ]] || die "Missing value for --limit"
+ LIMIT="$2"
+ shift 2
+ ;;
+ --apply)
+ DRY_RUN=0
+ shift
+ ;;
+ --dry-run)
+ DRY_RUN=1
+ shift
+ ;;
+ -h|--help)
+ usage
+ exit 0
+ ;;
+ *)
+ die "Unknown option: $1"
+ ;;
+ esac
+done
+
+case "$KIND" in
+ both|prs|issues) ;;
+ *) die "--kind must be one of: both, prs, issues" ;;
+esac
+
+case "$STATE" in
+ all|open|closed) ;;
+ *) die "--state must be one of: all, open, closed" ;;
+esac
+
+if ! [[ "$LIMIT" =~ ^[0-9]+$ ]]; then
+ die "--limit must be a non-negative integer"
+fi
+
+require_cmd gh
+require_cmd jq
+
+if ! gh auth status >/dev/null 2>&1; then
+ die "gh CLI is not authenticated. Run: gh auth login"
+fi
+
+if [[ -z "$REPO" ]]; then
+ REPO="$(gh repo view --json nameWithOwner --jq '.nameWithOwner' 2>/dev/null || true)"
+ [[ -n "$REPO" ]] || die "Unable to infer repo. Pass --repo <owner>/<repo>."
+fi
+
+echo "[$SCRIPT_NAME] Repo: $REPO"
+echo "[$SCRIPT_NAME] Mode: $([[ "$DRY_RUN" -eq 1 ]] && echo "dry-run" || echo "apply")"
+echo "[$SCRIPT_NAME] Kind: $KIND | State: $STATE | Limit: $LIMIT"
+
+TIERS_JSON='["trusted contributor","experienced contributor","principal contributor","distinguished contributor"]'
+
+TMP_FILES=()
+cleanup() {
+ if ((${#TMP_FILES[@]} > 0)); then
+ rm -f "${TMP_FILES[@]}"
+ fi
+}
+trap cleanup EXIT
+
+new_tmp_file() {
+ local tmp
+ tmp="$(mktemp)"
+ TMP_FILES+=("$tmp")
+ echo "$tmp"
+}
+
+targets_file="$(new_tmp_file)"
+
+if [[ "$KIND" == "both" || "$KIND" == "prs" ]]; then
+ gh api --paginate "repos/$REPO/pulls?state=$STATE&per_page=100" \
+ --jq '.[] | {
+ kind: "pr",
+ number: .number,
+ author: (.user.login // ""),
+ author_type: (.user.type // ""),
+ labels: [(.labels[]?.name // empty)]
+ }' >> "$targets_file"
+fi
+
+if [[ "$KIND" == "both" || "$KIND" == "issues" ]]; then
+ gh api --paginate "repos/$REPO/issues?state=$STATE&per_page=100" \
+ --jq '.[] | select(.pull_request | not) | {
+ kind: "issue",
+ number: .number,
+ author: (.user.login // ""),
+ author_type: (.user.type // ""),
+ labels: [(.labels[]?.name // empty)]
+ }' >> "$targets_file"
+fi
+
+if [[ "$LIMIT" -gt 0 ]]; then
+ limited_file="$(new_tmp_file)"
+ head -n "$LIMIT" "$targets_file" > "$limited_file"
+ mv "$limited_file" "$targets_file"
+fi
+
+target_count="$(wc -l < "$targets_file" | tr -d ' ')"
+if [[ "$target_count" -eq 0 ]]; then
+ echo "[$SCRIPT_NAME] No targets found."
+ exit 0
+fi
+
+echo "[$SCRIPT_NAME] Targets fetched: $target_count"
+
+# Ensure tier labels exist (trusted contributor might be new).
+label_color=""
+for probe_label in "experienced contributor" "principal contributor" "distinguished contributor" "trusted contributor"; do
+ encoded_label="$(urlencode "$probe_label")"
+ if color_candidate="$(gh api "repos/$REPO/labels/$encoded_label" --jq '.color' 2>/dev/null || true)"; then
+ if [[ -n "$color_candidate" ]]; then
+ label_color="$(echo "$color_candidate" | tr '[:lower:]' '[:upper:]')"
+ break
+ fi
+ fi
+done
+[[ -n "$label_color" ]] || label_color="C5D7A2"
+
+while IFS= read -r tier_label; do
+ [[ -n "$tier_label" ]] || continue
+ encoded_label="$(urlencode "$tier_label")"
+ if gh api "repos/$REPO/labels/$encoded_label" >/dev/null 2>&1; then
+ continue
+ fi
+
+ if [[ "$DRY_RUN" -eq 1 ]]; then
+ echo "[dry-run] Would create missing label: $tier_label (color=$label_color)"
+ else
+ gh api -X POST "repos/$REPO/labels" \
+ -f name="$tier_label" \
+ -f color="$label_color" >/dev/null
+ echo "[apply] Created missing label: $tier_label"
+ fi
+done < <(jq -r '.[]' <<<"$TIERS_JSON")
+
+# Build merged PR count cache by unique human authors.
+authors_file="$(new_tmp_file)"
+jq -r 'select(.author != "" and .author_type != "Bot") | .author' "$targets_file" | sort -u > "$authors_file"
+author_count="$(wc -l < "$authors_file" | tr -d ' ')"
+echo "[$SCRIPT_NAME] Unique human authors: $author_count"
+
+author_counts_file="$(new_tmp_file)"
+while IFS= read -r author; do
+ [[ -n "$author" ]] || continue
+ query="repo:$REPO is:pr is:merged author:$author"
+ merged_count="$(gh api search/issues -f q="$query" -F per_page=1 --jq '.total_count' 2>/dev/null || true)"
+ if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
+ merged_count=0
+ fi
+ printf '%s\t%s\n' "$author" "$merged_count" >> "$author_counts_file"
+done < "$authors_file"
+
+updated=0
+unchanged=0
+skipped=0
+failed=0
+
+while IFS= read -r target_json; do
+ [[ -n "$target_json" ]] || continue
+
+ number="$(jq -r '.number' <<<"$target_json")"
+ kind="$(jq -r '.kind' <<<"$target_json")"
+ author="$(jq -r '.author' <<<"$target_json")"
+ author_type="$(jq -r '.author_type' <<<"$target_json")"
+ current_labels_json="$(jq -c '.labels // []' <<<"$target_json")"
+
+ if [[ -z "$author" || "$author_type" == "Bot" ]]; then
+ skipped=$((skipped + 1))
+ continue
+ fi
+
+ merged_count="$(awk -F '\t' -v key="$author" '$1 == key { print $2; exit }' "$author_counts_file")"
+ if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
+ merged_count=0
+ fi
+ desired_tier="$(select_contributor_tier "$merged_count")"
+
+ if ! current_tier="$(jq -r --argjson tiers "$TIERS_JSON" '[.[] | select(. as $label | ($tiers | index($label)) != null)][0] // ""' <<<"$current_labels_json" 2>/dev/null)"; then
+ echo "[warn] Skipping ${kind} #${number}: cannot parse current labels JSON" >&2
+ failed=$((failed + 1))
+ continue
+ fi
+
+ if ! next_labels_json="$(jq -c --arg desired "$desired_tier" --argjson tiers "$TIERS_JSON" '
+ (. // [])
+ | map(select(. as $label | ($tiers | index($label)) == null))
+ | if $desired != "" then . + [$desired] else . end
+ | unique
+ ' <<<"$current_labels_json" 2>/dev/null)"; then
+ echo "[warn] Skipping ${kind} #${number}: cannot compute next labels" >&2
+ failed=$((failed + 1))
+ continue
+ fi
+
+ if ! normalized_current="$(jq -c 'unique | sort' <<<"$current_labels_json" 2>/dev/null)"; then
+ echo "[warn] Skipping ${kind} #${number}: cannot normalize current labels" >&2
+ failed=$((failed + 1))
+ continue
+ fi
+
+ if ! normalized_next="$(jq -c 'unique | sort' <<<"$next_labels_json" 2>/dev/null)"; then
+ echo "[warn] Skipping ${kind} #${number}: cannot normalize next labels" >&2
+ failed=$((failed + 1))
+ continue
+ fi
+
+ if [[ "$normalized_current" == "$normalized_next" ]]; then
+ unchanged=$((unchanged + 1))
+ continue
+ fi
+
+ if [[ "$DRY_RUN" -eq 1 ]]; then
+ echo "[dry-run] ${kind} #${number} @${author} merged=${merged_count} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
+ updated=$((updated + 1))
+ continue
+ fi
+
+ payload="$(jq -cn --argjson labels "$next_labels_json" '{labels: $labels}')"
+ if gh api -X PUT "repos/$REPO/issues/$number/labels" --input - <<<"$payload" >/dev/null; then
+ echo "[apply] Updated ${kind} #${number} @${author} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
+ updated=$((updated + 1))
+ else
+ echo "[apply] FAILED ${kind} #${number}" >&2
+ failed=$((failed + 1))
+ fi
+done < "$targets_file"
+
+echo ""
+echo "[$SCRIPT_NAME] Summary"
+echo " Targets: $target_count"
+echo " Updated: $updated"
+echo " Unchanged: $unchanged"
+echo " Skipped: $skipped"
+echo " Failed: $failed"
+
+if [[ "$failed" -gt 0 ]]; then
+ exit 1
+fi
diff --git a/src/agent/agent.rs b/src/agent/agent.rs
index ca18e79..3e5693e 100644
--- a/src/agent/agent.rs
+++ b/src/agent/agent.rs
@@ -251,6 +251,7 @@ impl Agent {
         let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
config.api_key.as_deref(),
+ config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
&model_name,
@@ -388,7 +389,7 @@ impl Agent {
if self.auto_save {
let _ = self
.memory
- .store("user_msg", user_message, MemoryCategory::Conversation)
+ .store("user_msg", user_message, MemoryCategory::Conversation, None)
.await;
}
@@ -447,7 +448,7 @@ impl Agent {
let summary = truncate_with_ellipsis(&final_text, 100);
let _ = self
.memory
- .store("assistant_resp", &summary, MemoryCategory::Daily)
+ .store("assistant_resp", &summary, MemoryCategory::Daily, None)
.await;
}
@@ -557,6 +558,7 @@ pub async fn run(
agent.observer.record_event(&ObserverEvent::AgentEnd {
duration: start.elapsed(),
tokens_used: None,
+ cost_usd: None,
});
Ok(())
diff --git a/src/agent/loop_.rs b/src/agent/loop_.rs
index 47d02a6..81882d6 100644
--- a/src/agent/loop_.rs
+++ b/src/agent/loop_.rs
@@ -7,14 +7,70 @@ use crate::security::SecurityPolicy;
use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
+use regex::{Regex, RegexSet};
use std::fmt::Write;
use std::io::Write as _;
-use std::sync::Arc;
+use std::sync::{Arc, LazyLock};
use std::time::Instant;
use uuid::Uuid;
+
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
const MAX_TOOL_ITERATIONS: usize = 10;
+static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
+ RegexSet::new([
+ r"(?i)token",
+ r"(?i)api[_-]?key",
+ r"(?i)password",
+ r"(?i)secret",
+ r"(?i)user[_-]?key",
+ r"(?i)bearer",
+ r"(?i)credential",
+ ])
+ .unwrap()
+});
+
+static SENSITIVE_KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
+ Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap()
+});
+
+/// Scrub credentials from tool output to prevent accidental exfiltration.
+/// Replaces known credential patterns with a redacted placeholder while preserving
+/// a small prefix for context.
+fn scrub_credentials(input: &str) -> String {
+ SENSITIVE_KV_REGEX
+        .replace_all(input, |caps: &regex::Captures| {
+ let full_match = &caps[0];
+ let key = &caps[1];
+ let val = caps
+ .get(2)
+ .or(caps.get(3))
+ .or(caps.get(4))
+ .map(|m| m.as_str())
+ .unwrap_or("");
+
+ // Preserve first 4 chars for context, then redact
+ let prefix = if val.len() > 4 { &val[..4] } else { "" };
+
+ if full_match.contains(':') {
+ if full_match.contains('"') {
+ format!("\"{}\": \"{}*[REDACTED]\"", key, prefix)
+ } else {
+ format!("{}: {}*[REDACTED]", key, prefix)
+ }
+ } else if full_match.contains('=') {
+ if full_match.contains('"') {
+ format!("{}=\"{}*[REDACTED]\"", key, prefix)
+ } else {
+ format!("{}={}*[REDACTED]", key, prefix)
+ }
+ } else {
+ format!("{}: {}*[REDACTED]", key, prefix)
+ }
+ })
+ .to_string()
+}
+
/// Trigger auto-compaction when non-system message count exceeds this threshold.
const MAX_HISTORY_MESSAGES: usize = 50;
@@ -145,7 +201,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new();
// Pull relevant memories for this message
- if let Ok(entries) = mem.recall(user_msg, 5).await {
+ if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if !entries.is_empty() {
context.push_str("[Memory context]\n");
for entry in &entries {
@@ -436,6 +492,7 @@ struct ParsedToolCall {
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
/// When `silent` is true, suppresses stdout (for channel use).
+#[allow(clippy::too_many_arguments)]
pub(crate) async fn agent_turn(
provider: &dyn Provider,
history: &mut Vec,
@@ -461,6 +518,7 @@ pub(crate) async fn agent_turn(
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
+#[allow(clippy::too_many_arguments)]
pub(crate) async fn run_tool_call_loop(
provider: &dyn Provider,
history: &mut Vec,
@@ -606,7 +664,7 @@ pub(crate) async fn run_tool_call_loop(
success: r.success,
});
if r.success {
- r.output
+ scrub_credentials(&r.output)
} else {
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
}
@@ -749,6 +807,7 @@ pub async fn run(
     let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
config.api_key.as_deref(),
+ config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
model_name,
@@ -912,7 +971,7 @@ pub async fn run(
if config.memory.auto_save {
let user_key = autosave_memory_key("user_msg");
let _ = mem
- .store(&user_key, &msg, MemoryCategory::Conversation)
+ .store(&user_key, &msg, MemoryCategory::Conversation, None)
.await;
}
@@ -955,7 +1014,7 @@ pub async fn run(
let summary = truncate_with_ellipsis(&response, 100);
let response_key = autosave_memory_key("assistant_resp");
let _ = mem
- .store(&response_key, &summary, MemoryCategory::Daily)
+ .store(&response_key, &summary, MemoryCategory::Daily, None)
.await;
}
} else {
@@ -978,7 +1037,7 @@ pub async fn run(
if config.memory.auto_save {
let user_key = autosave_memory_key("user_msg");
let _ = mem
- .store(&user_key, &msg.content, MemoryCategory::Conversation)
+ .store(&user_key, &msg.content, MemoryCategory::Conversation, None)
.await;
}
@@ -1036,7 +1095,7 @@ pub async fn run(
let summary = truncate_with_ellipsis(&response, 100);
let response_key = autosave_memory_key("assistant_resp");
let _ = mem
- .store(&response_key, &summary, MemoryCategory::Daily)
+ .store(&response_key, &summary, MemoryCategory::Daily, None)
.await;
}
}
@@ -1048,6 +1107,7 @@ pub async fn run(
observer.record_event(&ObserverEvent::AgentEnd {
duration,
tokens_used: None,
+ cost_usd: None,
});
Ok(final_output)
@@ -1104,6 +1164,7 @@ pub async fn process_message(config: Config, message: &str) -> Result {
     let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
config.api_key.as_deref(),
+ config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
&model_name,
@@ -1217,6 +1278,25 @@ pub async fn process_message(config: Config, message: &str) -> Result {
#[cfg(test)]
mod tests {
use super::*;
+
+ #[test]
+ fn test_scrub_credentials() {
+ let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
+ let scrubbed = scrub_credentials(input);
+ assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
+ assert!(scrubbed.contains("token: 1234*[REDACTED]"));
+ assert!(scrubbed.contains("password=\"secr*[REDACTED]\""));
+ assert!(!scrubbed.contains("abcdef"));
+ assert!(!scrubbed.contains("secret123456"));
+ }
+
+ #[test]
+ fn test_scrub_credentials_json() {
+ let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
+ let scrubbed = scrub_credentials(input);
+ assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
+ assert!(scrubbed.contains("public"));
+ }
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use tempfile::TempDir;
@@ -1496,16 +1576,16 @@ I will now call the tool with this payload:
let key1 = autosave_memory_key("user_msg");
let key2 = autosave_memory_key("user_msg");
- mem.store(&key1, "I'm Paul", MemoryCategory::Conversation)
+ mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None)
.await
.unwrap();
- mem.store(&key2, "I'm 45", MemoryCategory::Conversation)
+ mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 2);
- let recalled = mem.recall("45", 5).await.unwrap();
+ let recalled = mem.recall("45", 5, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
}
diff --git a/src/agent/memory_loader.rs b/src/agent/memory_loader.rs
index f5733ec..0cc530f 100644
--- a/src/agent/memory_loader.rs
+++ b/src/agent/memory_loader.rs
@@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader {
memory: &dyn Memory,
user_message: &str,
     ) -> anyhow::Result<String> {
- let entries = memory.recall(user_message, self.limit).await?;
+ let entries = memory.recall(user_message, self.limit, None).await?;
if entries.is_empty() {
return Ok(String::new());
}
@@ -61,11 +61,17 @@ mod tests {
_key: &str,
_content: &str,
_category: MemoryCategory,
+ _session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
-    async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
+    async fn recall(
+        &self,
+        _query: &str,
+        limit: usize,
+        _session_id: Option<&str>,
+    ) -> anyhow::Result<Vec<MemoryEntry>> {
if limit == 0 {
return Ok(vec![]);
}
@@ -87,6 +93,7 @@ mod tests {
async fn list(
&self,
_category: Option<&MemoryCategory>,
+ _session_id: Option<&str>,
     ) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(vec![])
}
diff --git a/src/channels/cli.rs b/src/channels/cli.rs
index 8b414fd..46ee474 100644
--- a/src/channels/cli.rs
+++ b/src/channels/cli.rs
@@ -40,6 +40,7 @@ impl Channel for CliChannel {
let msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: "user".to_string(),
+ reply_target: "user".to_string(),
content: line,
channel: "cli".to_string(),
timestamp: std::time::SystemTime::now()
@@ -90,12 +91,14 @@ mod tests {
let msg = ChannelMessage {
id: "test-id".into(),
sender: "user".into(),
+ reply_target: "user".into(),
content: "hello".into(),
channel: "cli".into(),
timestamp: 1_234_567_890,
};
assert_eq!(msg.id, "test-id");
assert_eq!(msg.sender, "user");
+ assert_eq!(msg.reply_target, "user");
assert_eq!(msg.content, "hello");
assert_eq!(msg.channel, "cli");
assert_eq!(msg.timestamp, 1_234_567_890);
@@ -106,6 +109,7 @@ mod tests {
let msg = ChannelMessage {
id: "id".into(),
sender: "s".into(),
+ reply_target: "s".into(),
content: "c".into(),
channel: "ch".into(),
timestamp: 0,
diff --git a/src/channels/dingtalk.rs b/src/channels/dingtalk.rs
index f55135a..7473bb3 100644
--- a/src/channels/dingtalk.rs
+++ b/src/channels/dingtalk.rs
@@ -7,7 +7,7 @@ use tokio::sync::RwLock;
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;
-/// DingTalk (钉钉) channel — connects via Stream Mode WebSocket for real-time messages.
+/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages.
/// Replies are sent through per-message session webhook URLs.
pub struct DingTalkChannel {
client_id: String,
@@ -64,6 +64,18 @@ impl DingTalkChannel {
let gw: GatewayResponse = resp.json().await?;
Ok(gw)
}
+
+ fn resolve_reply_target(
+ sender_id: &str,
+ conversation_type: &str,
+ conversation_id: Option<&str>,
+ ) -> String {
+ if conversation_type == "1" {
+ sender_id.to_string()
+ } else {
+ conversation_id.unwrap_or(sender_id).to_string()
+ }
+ }
}
#[async_trait]
@@ -193,14 +205,11 @@ impl Channel for DingTalkChannel {
.unwrap_or("1");
// Private chat uses sender ID, group chat uses conversation ID
- let chat_id = if conversation_type == "1" {
- sender_id.to_string()
- } else {
- data.get("conversationId")
- .and_then(|c| c.as_str())
- .unwrap_or(sender_id)
- .to_string()
- };
+ let chat_id = Self::resolve_reply_target(
+ sender_id,
+ conversation_type,
+ data.get("conversationId").and_then(|c| c.as_str()),
+ );
// Store session webhook for later replies
if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
@@ -229,6 +238,7 @@ impl Channel for DingTalkChannel {
let channel_msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: sender_id.to_string(),
+ reply_target: chat_id,
content: content.to_string(),
channel: "dingtalk".to_string(),
timestamp: std::time::SystemTime::now()
@@ -305,4 +315,22 @@ client_secret = "secret"
let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
assert!(config.allowed_users.is_empty());
}
+
+ #[test]
+ fn test_resolve_reply_target_private_chat_uses_sender_id() {
+ let target = DingTalkChannel::resolve_reply_target("staff_1", "1", Some("conv_1"));
+ assert_eq!(target, "staff_1");
+ }
+
+ #[test]
+ fn test_resolve_reply_target_group_chat_uses_conversation_id() {
+ let target = DingTalkChannel::resolve_reply_target("staff_1", "2", Some("conv_1"));
+ assert_eq!(target, "conv_1");
+ }
+
+ #[test]
+ fn test_resolve_reply_target_group_chat_falls_back_to_sender_id() {
+ let target = DingTalkChannel::resolve_reply_target("staff_1", "2", None);
+ assert_eq!(target, "staff_1");
+ }
}
diff --git a/src/channels/discord.rs b/src/channels/discord.rs
index 4e99f43..7eb7502 100644
--- a/src/channels/discord.rs
+++ b/src/channels/discord.rs
@@ -11,6 +11,7 @@ pub struct DiscordChannel {
     guild_id: Option<String>,
     allowed_users: Vec<String>,
listen_to_bots: bool,
+ mention_only: bool,
client: reqwest::Client,
     typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}
@@ -21,12 +22,14 @@ impl DiscordChannel {
         guild_id: Option<String>,
         allowed_users: Vec<String>,
listen_to_bots: bool,
+ mention_only: bool,
) -> Self {
Self {
bot_token,
guild_id,
allowed_users,
listen_to_bots,
+ mention_only,
client: reqwest::Client::new(),
typing_handle: std::sync::Mutex::new(None),
}
@@ -343,6 +346,22 @@ impl Channel for DiscordChannel {
continue;
}
+ // Skip messages that don't @-mention the bot (when mention_only is enabled)
+ if self.mention_only {
+ let mention_tag = format!("<@{bot_user_id}>");
+ if !content.contains(&mention_tag) {
+ continue;
+ }
+ }
+
+ // Strip the bot mention from content so the agent sees clean text
+ let clean_content = if self.mention_only {
+ let mention_tag = format!("<@{bot_user_id}>");
+ content.replace(&mention_tag, "").trim().to_string()
+ } else {
+ content.to_string()
+ };
+
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
@@ -353,6 +372,11 @@ impl Channel for DiscordChannel {
format!("discord_{message_id}")
},
sender: author_id.to_string(),
+ reply_target: if channel_id.is_empty() {
+ author_id.to_string()
+ } else {
+ channel_id
+ },
content: content.to_string(),
channel: channel_id,
timestamp: std::time::SystemTime::now()
@@ -423,7 +447,7 @@ mod tests {
#[test]
fn discord_channel_name() {
- let ch = DiscordChannel::new("fake".into(), None, vec![], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert_eq!(ch.name(), "discord");
}
@@ -444,21 +468,27 @@ mod tests {
#[test]
fn empty_allowlist_denies_everyone() {
- let ch = DiscordChannel::new("fake".into(), None, vec![], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert!(!ch.is_user_allowed("12345"));
assert!(!ch.is_user_allowed("anyone"));
}
#[test]
fn wildcard_allows_everyone() {
- let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false);
assert!(ch.is_user_allowed("12345"));
assert!(ch.is_user_allowed("anyone"));
}
#[test]
fn specific_allowlist_filters() {
- let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false);
+ let ch = DiscordChannel::new(
+ "fake".into(),
+ None,
+ vec!["111".into(), "222".into()],
+ false,
+ false,
+ );
assert!(ch.is_user_allowed("111"));
assert!(ch.is_user_allowed("222"));
assert!(!ch.is_user_allowed("333"));
@@ -467,7 +497,7 @@ mod tests {
#[test]
fn allowlist_is_exact_match_not_substring() {
- let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
assert!(!ch.is_user_allowed("1111"));
assert!(!ch.is_user_allowed("11"));
assert!(!ch.is_user_allowed("0111"));
@@ -475,20 +505,26 @@ mod tests {
#[test]
fn allowlist_empty_string_user_id() {
- let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
assert!(!ch.is_user_allowed(""));
}
#[test]
fn allowlist_with_wildcard_and_specific() {
- let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false);
+ let ch = DiscordChannel::new(
+ "fake".into(),
+ None,
+ vec!["111".into(), "*".into()],
+ false,
+ false,
+ );
assert!(ch.is_user_allowed("111"));
assert!(ch.is_user_allowed("anyone_else"));
}
#[test]
fn allowlist_case_sensitive() {
- let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false);
assert!(ch.is_user_allowed("ABC"));
assert!(!ch.is_user_allowed("abc"));
assert!(!ch.is_user_allowed("Abc"));
@@ -663,14 +699,14 @@ mod tests {
#[test]
fn typing_handle_starts_as_none() {
- let ch = DiscordChannel::new("fake".into(), None, vec![], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let guard = ch.typing_handle.lock().unwrap();
assert!(guard.is_none());
}
#[tokio::test]
async fn start_typing_sets_handle() {
- let ch = DiscordChannel::new("fake".into(), None, vec![], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("123456").await;
let guard = ch.typing_handle.lock().unwrap();
assert!(guard.is_some());
@@ -678,7 +714,7 @@ mod tests {
#[tokio::test]
async fn stop_typing_clears_handle() {
- let ch = DiscordChannel::new("fake".into(), None, vec![], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("123456").await;
let _ = ch.stop_typing("123456").await;
let guard = ch.typing_handle.lock().unwrap();
@@ -687,14 +723,14 @@ mod tests {
#[tokio::test]
async fn stop_typing_is_idempotent() {
- let ch = DiscordChannel::new("fake".into(), None, vec![], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert!(ch.stop_typing("123456").await.is_ok());
assert!(ch.stop_typing("123456").await.is_ok());
}
#[tokio::test]
async fn start_typing_replaces_existing_task() {
- let ch = DiscordChannel::new("fake".into(), None, vec![], false);
+ let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("111").await;
let _ = ch.start_typing("222").await;
let guard = ch.typing_handle.lock().unwrap();
diff --git a/src/channels/email_channel.rs b/src/channels/email_channel.rs
index f1ea016..da3490d 100644
--- a/src/channels/email_channel.rs
+++ b/src/channels/email_channel.rs
@@ -10,6 +10,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
+use lettre::message::SinglePart;
use lettre::transport::smtp::authentication::Credentials;
use lettre::{Message, SmtpTransport, Transport};
use mail_parser::{MessageParser, MimeHeaders};
@@ -39,7 +40,7 @@ pub struct EmailConfig {
pub imap_folder: String,
/// SMTP server hostname
pub smtp_host: String,
- /// SMTP server port (default: 587 for STARTTLS)
+ /// SMTP server port (default: 465 for TLS)
#[serde(default = "default_smtp_port")]
pub smtp_port: u16,
/// Use TLS for SMTP (default: true)
@@ -63,7 +64,7 @@ fn default_imap_port() -> u16 {
993
}
fn default_smtp_port() -> u16 {
- 587
+ 465
}
fn default_imap_folder() -> String {
"INBOX".into()
@@ -389,7 +390,7 @@ impl Channel for EmailChannel {
.from(self.config.from_address.parse()?)
.to(recipient.parse()?)
.subject(subject)
- .body(body.to_string())?;
+ .singlepart(SinglePart::plain(body.to_string()))?;
let transport = self.create_smtp_transport()?;
transport.send(&email)?;
@@ -427,6 +428,7 @@ impl Channel for EmailChannel {
} // MutexGuard dropped before await
let msg = ChannelMessage {
id,
+ reply_target: sender.clone(),
sender,
content,
channel: "email".to_string(),
@@ -464,6 +466,18 @@ impl Channel for EmailChannel {
mod tests {
use super::*;
+ #[test]
+ fn default_smtp_port_uses_tls_port() {
+ assert_eq!(default_smtp_port(), 465);
+ }
+
+ #[test]
+ fn email_config_default_uses_tls_smtp_defaults() {
+ let config = EmailConfig::default();
+ assert_eq!(config.smtp_port, 465);
+ assert!(config.smtp_tls);
+ }
+
#[test]
fn build_imap_tls_config_succeeds() {
let tls_config =
@@ -504,7 +518,7 @@ mod tests {
assert_eq!(config.imap_port, 993);
assert_eq!(config.imap_folder, "INBOX");
assert_eq!(config.smtp_host, "");
- assert_eq!(config.smtp_port, 587);
+ assert_eq!(config.smtp_port, 465);
assert!(config.smtp_tls);
assert_eq!(config.username, "");
assert_eq!(config.password, "");
@@ -765,8 +779,8 @@ mod tests {
}
#[test]
- fn default_smtp_port_returns_587() {
- assert_eq!(default_smtp_port(), 587);
+ fn default_smtp_port_returns_465() {
+ assert_eq!(default_smtp_port(), 465);
}
#[test]
@@ -822,7 +836,7 @@ mod tests {
let config: EmailConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.imap_port, 993); // default
- assert_eq!(config.smtp_port, 587); // default
+ assert_eq!(config.smtp_port, 465); // default
assert!(config.smtp_tls); // default
assert_eq!(config.poll_interval_secs, 60); // default
}
diff --git a/src/channels/imessage.rs b/src/channels/imessage.rs
index f001c56..36bf72f 100644
--- a/src/channels/imessage.rs
+++ b/src/channels/imessage.rs
@@ -172,6 +172,7 @@ end tell"#
let msg = ChannelMessage {
id: rowid.to_string(),
sender: sender.clone(),
+ reply_target: sender.clone(),
content: text,
channel: "imessage".to_string(),
timestamp: std::time::SystemTime::now()
diff --git a/src/channels/irc.rs b/src/channels/irc.rs
index d63ad41..61a48cc 100644
--- a/src/channels/irc.rs
+++ b/src/channels/irc.rs
@@ -220,32 +220,34 @@ fn split_message(message: &str, max_bytes: usize) -> Vec {
chunks
}
+/// Configuration for constructing an `IrcChannel`.
+pub struct IrcChannelConfig {
+ pub server: String,
+ pub port: u16,
+ pub nickname: String,
+    pub username: Option<String>,
+    pub channels: Vec<String>,
+    pub allowed_users: Vec<String>,
+    pub server_password: Option<String>,
+    pub nickserv_password: Option<String>,
+    pub sasl_password: Option<String>,
+ pub verify_tls: bool,
+}
+
impl IrcChannel {
- #[allow(clippy::too_many_arguments)]
- pub fn new(
- server: String,
- port: u16,
- nickname: String,
-        username: Option<String>,
-        channels: Vec<String>,
-        allowed_users: Vec<String>,
-        server_password: Option<String>,
-        nickserv_password: Option<String>,
-        sasl_password: Option<String>,
- verify_tls: bool,
- ) -> Self {
- let username = username.unwrap_or_else(|| nickname.clone());
+ pub fn new(cfg: IrcChannelConfig) -> Self {
+ let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone());
Self {
- server,
- port,
- nickname,
+ server: cfg.server,
+ port: cfg.port,
+ nickname: cfg.nickname,
username,
- channels,
- allowed_users,
- server_password,
- nickserv_password,
- sasl_password,
- verify_tls,
+ channels: cfg.channels,
+ allowed_users: cfg.allowed_users,
+ server_password: cfg.server_password,
+ nickserv_password: cfg.nickserv_password,
+ sasl_password: cfg.sasl_password,
+ verify_tls: cfg.verify_tls,
writer: Arc::new(Mutex::new(None)),
}
}
@@ -563,7 +565,8 @@ impl Channel for IrcChannel {
let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed);
let channel_msg = ChannelMessage {
id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()),
- sender: reply_to,
+ sender: sender_nick.to_string(),
+ reply_target: reply_to,
content,
channel: "irc".to_string(),
timestamp: std::time::SystemTime::now()
@@ -807,18 +810,18 @@ mod tests {
#[test]
fn specific_user_allowed() {
- let ch = IrcChannel::new(
- "irc.test".into(),
- 6697,
- "bot".into(),
- None,
- vec![],
- vec!["alice".into(), "bob".into()],
- None,
- None,
- None,
- true,
- );
+ let ch = IrcChannel::new(IrcChannelConfig {
+ server: "irc.test".into(),
+ port: 6697,
+ nickname: "bot".into(),
+ username: None,
+ channels: vec![],
+ allowed_users: vec!["alice".into(), "bob".into()],
+ server_password: None,
+ nickserv_password: None,
+ sasl_password: None,
+ verify_tls: true,
+ });
assert!(ch.is_user_allowed("alice"));
assert!(ch.is_user_allowed("bob"));
assert!(!ch.is_user_allowed("eve"));
@@ -826,18 +829,18 @@ mod tests {
#[test]
fn allowlist_case_insensitive() {
- let ch = IrcChannel::new(
- "irc.test".into(),
- 6697,
- "bot".into(),
- None,
- vec![],
- vec!["Alice".into()],
- None,
- None,
- None,
- true,
- );
+ let ch = IrcChannel::new(IrcChannelConfig {
+ server: "irc.test".into(),
+ port: 6697,
+ nickname: "bot".into(),
+ username: None,
+ channels: vec![],
+ allowed_users: vec!["Alice".into()],
+ server_password: None,
+ nickserv_password: None,
+ sasl_password: None,
+ verify_tls: true,
+ });
assert!(ch.is_user_allowed("alice"));
assert!(ch.is_user_allowed("ALICE"));
assert!(ch.is_user_allowed("Alice"));
@@ -845,18 +848,18 @@ mod tests {
#[test]
fn empty_allowlist_denies_all() {
- let ch = IrcChannel::new(
- "irc.test".into(),
- 6697,
- "bot".into(),
- None,
- vec![],
- vec![],
- None,
- None,
- None,
- true,
- );
+ let ch = IrcChannel::new(IrcChannelConfig {
+ server: "irc.test".into(),
+ port: 6697,
+ nickname: "bot".into(),
+ username: None,
+ channels: vec![],
+ allowed_users: vec![],
+ server_password: None,
+ nickserv_password: None,
+ sasl_password: None,
+ verify_tls: true,
+ });
assert!(!ch.is_user_allowed("anyone"));
}
@@ -864,35 +867,35 @@ mod tests {
#[test]
fn new_defaults_username_to_nickname() {
- let ch = IrcChannel::new(
- "irc.test".into(),
- 6697,
- "mybot".into(),
- None,
- vec![],
- vec![],
- None,
- None,
- None,
- true,
- );
+ let ch = IrcChannel::new(IrcChannelConfig {
+ server: "irc.test".into(),
+ port: 6697,
+ nickname: "mybot".into(),
+ username: None,
+ channels: vec![],
+ allowed_users: vec![],
+ server_password: None,
+ nickserv_password: None,
+ sasl_password: None,
+ verify_tls: true,
+ });
assert_eq!(ch.username, "mybot");
}
#[test]
fn new_uses_explicit_username() {
- let ch = IrcChannel::new(
- "irc.test".into(),
- 6697,
- "mybot".into(),
- Some("customuser".into()),
- vec![],
- vec![],
- None,
- None,
- None,
- true,
- );
+ let ch = IrcChannel::new(IrcChannelConfig {
+ server: "irc.test".into(),
+ port: 6697,
+ nickname: "mybot".into(),
+ username: Some("customuser".into()),
+ channels: vec![],
+ allowed_users: vec![],
+ server_password: None,
+ nickserv_password: None,
+ sasl_password: None,
+ verify_tls: true,
+ });
assert_eq!(ch.username, "customuser");
assert_eq!(ch.nickname, "mybot");
}
@@ -905,18 +908,18 @@ mod tests {
#[test]
fn new_stores_all_fields() {
- let ch = IrcChannel::new(
- "irc.example.com".into(),
- 6697,
- "zcbot".into(),
- Some("zeroclaw".into()),
- vec!["#test".into()],
- vec!["alice".into()],
- Some("serverpass".into()),
- Some("nspass".into()),
- Some("saslpass".into()),
- false,
- );
+ let ch = IrcChannel::new(IrcChannelConfig {
+ server: "irc.example.com".into(),
+ port: 6697,
+ nickname: "zcbot".into(),
+ username: Some("zeroclaw".into()),
+ channels: vec!["#test".into()],
+ allowed_users: vec!["alice".into()],
+ server_password: Some("serverpass".into()),
+ nickserv_password: Some("nspass".into()),
+ sasl_password: Some("saslpass".into()),
+ verify_tls: false,
+ });
assert_eq!(ch.server, "irc.example.com");
assert_eq!(ch.port, 6697);
assert_eq!(ch.nickname, "zcbot");
@@ -995,17 +998,17 @@ nickname = "bot"
// ── Helpers ─────────────────────────────────────────────
fn make_channel() -> IrcChannel {
- IrcChannel::new(
- "irc.example.com".into(),
- 6697,
- "zcbot".into(),
- None,
- vec!["#zeroclaw".into()],
- vec!["*".into()],
- None,
- None,
- None,
- true,
- )
+ IrcChannel::new(IrcChannelConfig {
+ server: "irc.example.com".into(),
+ port: 6697,
+ nickname: "zcbot".into(),
+ username: None,
+ channels: vec!["#zeroclaw".into()],
+ allowed_users: vec!["*".into()],
+ server_password: None,
+ nickserv_password: None,
+ sasl_password: None,
+ verify_tls: true,
+ })
}
}
diff --git a/src/channels/lark.rs b/src/channels/lark.rs
index 4e9e679..5f929f8 100644
--- a/src/channels/lark.rs
+++ b/src/channels/lark.rs
@@ -1,21 +1,152 @@
use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
+use futures_util::{SinkExt, StreamExt};
+use prost::Message as ProstMessage;
+use std::collections::HashMap;
use std::sync::Arc;
+use std::time::{Duration, Instant};
use tokio::sync::RwLock;
+use tokio_tungstenite::tungstenite::Message as WsMsg;
use uuid::Uuid;
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
+const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn";
+const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis";
+const LARK_WS_BASE_URL: &str = "https://open.larksuite.com";
-/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API
+// ─────────────────────────────────────────────────────────────────────────────
+// Feishu WebSocket long-connection: pbbp2.proto frame codec
+// ─────────────────────────────────────────────────────────────────────────────
+
+#[derive(Clone, PartialEq, prost::Message)]
+struct PbHeader {
+ #[prost(string, tag = "1")]
+ pub key: String,
+ #[prost(string, tag = "2")]
+ pub value: String,
+}
+
+/// Feishu WS frame (pbbp2.proto).
+/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
+#[derive(Clone, PartialEq, prost::Message)]
+struct PbFrame {
+ #[prost(uint64, tag = "1")]
+ pub seq_id: u64,
+ #[prost(uint64, tag = "2")]
+ pub log_id: u64,
+ #[prost(int32, tag = "3")]
+ pub service: i32,
+ #[prost(int32, tag = "4")]
+ pub method: i32,
+ #[prost(message, repeated, tag = "5")]
+ pub headers: Vec,
+ #[prost(bytes = "vec", optional, tag = "8")]
+    pub payload: Option<Vec<u8>>,
+}
+
+impl PbFrame {
+ fn header_value<'a>(&'a self, key: &str) -> &'a str {
+ self.headers
+ .iter()
+ .find(|h| h.key == key)
+ .map(|h| h.value.as_str())
+ .unwrap_or("")
+ }
+}
+
+/// Server-sent client config (parsed from pong payload)
+#[derive(Debug, serde::Deserialize, Default, Clone)]
+struct WsClientConfig {
+ #[serde(rename = "PingInterval")]
+ ping_interval: Option,
+}
+
+/// POST /callback/ws/endpoint response
+#[derive(Debug, serde::Deserialize)]
+struct WsEndpointResp {
+ code: i32,
+ #[serde(default)]
+    msg: Option<String>,
+    #[serde(default)]
+    data: Option<WsEndpoint>,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct WsEndpoint {
+ #[serde(rename = "URL")]
+ url: String,
+ #[serde(rename = "ClientConfig")]
+    client_config: Option<WsClientConfig>,
+}
+
+/// LarkEvent envelope (method=1 / type=event payload)
+#[derive(Debug, serde::Deserialize)]
+struct LarkEvent {
+ header: LarkEventHeader,
+ event: serde_json::Value,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct LarkEventHeader {
+ event_type: String,
+ #[allow(dead_code)]
+ event_id: String,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct MsgReceivePayload {
+ sender: LarkSender,
+ message: LarkMessage,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct LarkSender {
+ sender_id: LarkSenderId,
+ #[serde(default)]
+ sender_type: String,
+}
+
+#[derive(Debug, serde::Deserialize, Default)]
+struct LarkSenderId {
+    open_id: Option<String>,
+}
+
+#[derive(Debug, serde::Deserialize)]
+struct LarkMessage {
+ message_id: String,
+ chat_id: String,
+ chat_type: String,
+ message_type: String,
+ #[serde(default)]
+ content: String,
+ #[serde(default)]
+ mentions: Vec,
+}
+
+/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
+/// If no binary frame (pong or event) is received within this window, reconnect.
+const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
+
+/// Lark/Feishu channel.
+///
+/// Supports two receive modes (configured via `receive_mode` in config):
+/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed.
+/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint.
pub struct LarkChannel {
app_id: String,
app_secret: String,
verification_token: String,
- port: u16,
+    port: Option<u16>,
     allowed_users: Vec<String>,
+ /// When true, use Feishu (CN) endpoints; when false, use Lark (international).
+ use_feishu: bool,
+ /// How to receive events: WebSocket long-connection or HTTP webhook.
+ receive_mode: crate::config::schema::LarkReceiveMode,
client: reqwest::Client,
/// Cached tenant access token
     tenant_token: Arc<RwLock<Option<String>>>,
+    /// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
+    ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
}
impl LarkChannel {
@@ -23,7 +154,7 @@ impl LarkChannel {
app_id: String,
app_secret: String,
verification_token: String,
- port: u16,
+        port: Option<u16>,
         allowed_users: Vec<String>,
) -> Self {
Self {
@@ -32,11 +163,310 @@ impl LarkChannel {
verification_token,
port,
allowed_users,
+ use_feishu: true,
+ receive_mode: crate::config::schema::LarkReceiveMode::default(),
client: reqwest::Client::new(),
tenant_token: Arc::new(RwLock::new(None)),
+ ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
}
}
+ /// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`).
+ pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self {
+ let mut ch = Self::new(
+ config.app_id.clone(),
+ config.app_secret.clone(),
+ config.verification_token.clone().unwrap_or_default(),
+ config.port,
+ config.allowed_users.clone(),
+ );
+ ch.use_feishu = config.use_feishu;
+ ch.receive_mode = config.receive_mode.clone();
+ ch
+ }
+
+ fn api_base(&self) -> &'static str {
+ if self.use_feishu {
+ FEISHU_BASE_URL
+ } else {
+ LARK_BASE_URL
+ }
+ }
+
+ fn ws_base(&self) -> &'static str {
+ if self.use_feishu {
+ FEISHU_WS_BASE_URL
+ } else {
+ LARK_WS_BASE_URL
+ }
+ }
+
+ fn tenant_access_token_url(&self) -> String {
+ format!("{}/auth/v3/tenant_access_token/internal", self.api_base())
+ }
+
+ fn send_message_url(&self) -> String {
+ format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base())
+ }
+
+ /// POST /callback/ws/endpoint → (wss_url, client_config)
+ async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> {
+ let resp = self
+ .client
+ .post(format!("{}/callback/ws/endpoint", self.ws_base()))
+ .header("locale", if self.use_feishu { "zh" } else { "en" })
+ .json(&serde_json::json!({
+ "AppID": self.app_id,
+ "AppSecret": self.app_secret,
+ }))
+ .send()
+ .await?
+ .json::()
+ .await?;
+ if resp.code != 0 {
+ anyhow::bail!(
+ "Lark WS endpoint failed: code={} msg={}",
+ resp.code,
+ resp.msg.as_deref().unwrap_or("(none)")
+ );
+ }
+ let ep = resp
+ .data
+ .ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?;
+ Ok((ep.url, ep.client_config.unwrap_or_default()))
+ }
+
+ /// WS long-connection event loop. Returns Ok(()) when the connection closes
+ /// (the caller reconnects).
+ #[allow(clippy::too_many_lines)]
+ async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
+ let (wss_url, client_config) = self.get_ws_endpoint().await?;
+ let service_id = wss_url
+ .split('?')
+ .nth(1)
+ .and_then(|qs| {
+ qs.split('&')
+ .find(|kv| kv.starts_with("service_id="))
+ .and_then(|kv| kv.split('=').nth(1))
+ .and_then(|v| v.parse::().ok())
+ })
+ .unwrap_or(0);
+ tracing::info!("Lark: connecting to {wss_url}");
+
+ let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?;
+ let (mut write, mut read) = ws_stream.split();
+ tracing::info!("Lark: WS connected (service_id={service_id})");
+
+ let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10);
+ let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
+ let mut timeout_check = tokio::time::interval(Duration::from_secs(10));
+ hb_interval.tick().await; // consume immediate tick
+
+ let mut seq: u64 = 0;
+ let mut last_recv = Instant::now();
+
+ // Send initial ping immediately (like the official SDK) so the server
+ // starts responding with pongs and we can calibrate the ping_interval.
+ seq = seq.wrapping_add(1);
+ let initial_ping = PbFrame {
+ seq_id: seq,
+ log_id: 0,
+ service: service_id,
+ method: 0,
+ headers: vec![PbHeader {
+ key: "type".into(),
+ value: "ping".into(),
+ }],
+ payload: None,
+ };
+ if write
+ .send(WsMsg::Binary(initial_ping.encode_to_vec()))
+ .await
+ .is_err()
+ {
+ anyhow::bail!("Lark: initial ping failed");
+ }
+ // message_id → (fragment_slots, created_at) for multi-part reassembly
+ type FragEntry = (Vec<Option<Vec<u8>>>, Instant);
+ let mut frag_cache: HashMap<String, FragEntry> = HashMap::new();
+
+ loop {
+ tokio::select! {
+ biased;
+
+ _ = hb_interval.tick() => {
+ seq = seq.wrapping_add(1);
+ let ping = PbFrame {
+ seq_id: seq, log_id: 0, service: service_id, method: 0,
+ headers: vec![PbHeader { key: "type".into(), value: "ping".into() }],
+ payload: None,
+ };
+ if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() {
+ tracing::warn!("Lark: ping failed, reconnecting");
+ break;
+ }
+ // GC stale fragments > 5 min
+ let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now());
+ frag_cache.retain(|_, (_, ts)| *ts > cutoff);
+ }
+
+ _ = timeout_check.tick() => {
+ if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
+ tracing::warn!("Lark: heartbeat timeout, reconnecting");
+ break;
+ }
+ }
+
+ msg = read.next() => {
+ let raw = match msg {
+ Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b }
+ Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; }
+ Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; }
+ Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; }
+ _ => continue,
+ };
+
+ let frame = match PbFrame::decode(&raw[..]) {
+ Ok(f) => f,
+ Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; }
+ };
+
+ // CONTROL frame
+ if frame.method == 0 {
+ if frame.header_value("type") == "pong" {
+ if let Some(p) = &frame.payload {
+ if let Ok(cfg) = serde_json::from_slice::<WsClientConfig>(p) {
+ if let Some(secs) = cfg.ping_interval {
+ let secs = secs.max(10);
+ if secs != ping_secs {
+ ping_secs = secs;
+ hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
+ tracing::info!("Lark: ping_interval → {ping_secs}s");
+ }
+ }
+ }
+ }
+ }
+ continue;
+ }
+
+ // DATA frame
+ let msg_type = frame.header_value("type").to_string();
+ let msg_id = frame.header_value("message_id").to_string();
+ let sum = frame.header_value("sum").parse::<usize>().unwrap_or(1);
+ let seq_num = frame.header_value("seq").parse::<usize>().unwrap_or(0);
+
+ // ACK immediately (Feishu requires within 3 s)
+ {
+ let mut ack = frame.clone();
+ ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
+ ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() });
+ let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await;
+ }
+
+ // Fragment reassembly
+ let sum = if sum == 0 { 1 } else { sum };
+ let payload: Vec<u8> = if sum == 1 || msg_id.is_empty() || seq_num >= sum {
+ frame.payload.clone().unwrap_or_default()
+ } else {
+ let entry = frag_cache.entry(msg_id.clone())
+ .or_insert_with(|| (vec![None; sum], Instant::now()));
+ if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); }
+ entry.0[seq_num] = frame.payload.clone();
+ if entry.0.iter().all(|s| s.is_some()) {
+ let full: Vec<u8> = entry.0.iter()
+ .flat_map(|s| s.as_deref().unwrap_or(&[]))
+ .copied().collect();
+ frag_cache.remove(&msg_id);
+ full
+ } else { continue; }
+ };
+
+ if msg_type != "event" { continue; }
+
+ let event: LarkEvent = match serde_json::from_slice(&payload) {
+ Ok(e) => e,
+ Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; }
+ };
+ if event.header.event_type != "im.message.receive_v1" { continue; }
+
+ let recv: MsgReceivePayload = match serde_json::from_value(event.event) {
+ Ok(r) => r,
+ Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; }
+ };
+
+ if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; }
+
+ let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or("");
+ if !self.is_user_allowed(sender_open_id) {
+ tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)");
+ continue;
+ }
+
+ let lark_msg = &recv.message;
+
+ // Dedup
+ {
+ let now = Instant::now();
+ let mut seen = self.ws_seen_ids.write().await;
+ // GC
+ seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
+ if seen.contains_key(&lark_msg.message_id) {
+ tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
+ continue;
+ }
+ seen.insert(lark_msg.message_id.clone(), now);
+ }
+
+ // Decode content by type (mirrors clawdbot-feishu parsing)
+ let text = match lark_msg.message_type.as_str() {
+ "text" => {
+ let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
+ Ok(v) => v,
+ Err(_) => continue,
+ };
+ match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) {
+ Some(t) => t.to_string(),
+ None => continue,
+ }
+ }
+ "post" => match parse_post_content(&lark_msg.content) {
+ Some(t) => t,
+ None => continue,
+ },
+ _ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
+ };
+
+ // Strip @_user_N placeholders
+ let text = strip_at_placeholders(&text);
+ let text = text.trim().to_string();
+ if text.is_empty() { continue; }
+
+ // Group-chat: only respond when explicitly @-mentioned
+ if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) {
+ continue;
+ }
+
+ let channel_msg = ChannelMessage {
+ id: Uuid::new_v4().to_string(),
+ sender: lark_msg.chat_id.clone(),
+ reply_target: lark_msg.chat_id.clone(),
+ content: text,
+ channel: "lark".to_string(),
+ timestamp: std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap_or_default()
+ .as_secs(),
+ };
+
+ tracing::debug!("Lark WS: message in {}", lark_msg.chat_id);
+ if tx.send(channel_msg).await.is_err() { break; }
+ }
+ }
+ }
+ Ok(())
+ }
+
/// Check if a user open_id is allowed
fn is_user_allowed(&self, open_id: &str) -> bool {
self.allowed_users.iter().any(|u| u == "*" || u == open_id)
@@ -52,7 +482,7 @@ impl LarkChannel {
}
}
- let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal");
+ let url = self.tenant_access_token_url();
let body = serde_json::json!({
"app_id": self.app_id,
"app_secret": self.app_secret,
@@ -127,31 +557,41 @@ impl LarkChannel {
return messages;
}
- // Extract message content (text only)
+ // Extract message content (text and post supported)
let msg_type = event
.pointer("/message/message_type")
.and_then(|t| t.as_str())
.unwrap_or("");
- if msg_type != "text" {
- tracing::debug!("Lark: skipping non-text message type: {msg_type}");
- return messages;
- }
-
let content_str = event
.pointer("/message/content")
.and_then(|c| c.as_str())
.unwrap_or("");
- // content is a JSON string like "{\"text\":\"hello\"}"
- let text = serde_json::from_str::(content_str)
- .ok()
- .and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from))
- .unwrap_or_default();
-
- if text.is_empty() {
- return messages;
- }
+ let text: String = match msg_type {
+ "text" => {
+ let extracted = serde_json::from_str::(content_str)
+ .ok()
+ .and_then(|v| {
+ v.get("text")
+ .and_then(|t| t.as_str())
+ .filter(|s| !s.is_empty())
+ .map(String::from)
+ });
+ match extracted {
+ Some(t) => t,
+ None => return messages,
+ }
+ }
+ "post" => match parse_post_content(content_str) {
+ Some(t) => t,
+ None => return messages,
+ },
+ _ => {
+ tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
+ return messages;
+ }
+ };
let timestamp = event
.pointer("/message/create_time")
@@ -174,6 +614,7 @@ impl LarkChannel {
messages.push(ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: chat_id.to_string(),
+ reply_target: chat_id.to_string(),
content: text,
channel: "lark".to_string(),
timestamp,
@@ -191,7 +632,7 @@ impl Channel for LarkChannel {
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
let token = self.get_tenant_access_token().await?;
- let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id");
+ let url = self.send_message_url();
let content = serde_json::json!({ "text": message }).to_string();
let body = serde_json::json!({
@@ -238,6 +679,25 @@ impl Channel for LarkChannel {
}
 async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
+ use crate::config::schema::LarkReceiveMode;
+ match self.receive_mode {
+ LarkReceiveMode::Websocket => self.listen_ws(tx).await,
+ LarkReceiveMode::Webhook => self.listen_http(tx).await,
+ }
+ }
+
+ async fn health_check(&self) -> bool {
+ self.get_tenant_access_token().await.is_ok()
+ }
+}
+
+impl LarkChannel {
+ /// HTTP callback server (legacy — requires a public endpoint).
+ /// Use `listen()` (WS long-connection) for new deployments.
+ pub async fn listen_http(
+ &self,
+ tx: tokio::sync::mpsc::Sender<ChannelMessage>,
+ ) -> anyhow::Result<()> {
use axum::{extract::State, routing::post, Json, Router};
#[derive(Clone)]
@@ -282,13 +742,17 @@ impl Channel for LarkChannel {
(StatusCode::OK, "ok").into_response()
}
+ let port = self.port.ok_or_else(|| {
+ anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]")
+ })?;
+
let state = AppState {
verification_token: self.verification_token.clone(),
channel: Arc::new(LarkChannel::new(
self.app_id.clone(),
self.app_secret.clone(),
self.verification_token.clone(),
- self.port,
+ None,
self.allowed_users.clone(),
)),
tx,
@@ -298,7 +762,7 @@ impl Channel for LarkChannel {
.route("/lark", post(handle_event))
.with_state(state);
- let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
+ let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("Lark event callback server listening on {addr}");
let listener = tokio::net::TcpListener::bind(addr).await?;
@@ -306,10 +770,110 @@ impl Channel for LarkChannel {
Ok(())
}
+}
- async fn health_check(&self) -> bool {
- self.get_tenant_access_token().await.is_ok()
+// ─────────────────────────────────────────────────────────────────────────────
+// WS helper functions
+// ─────────────────────────────────────────────────────────────────────────────
+
+/// Flatten a Feishu `post` rich-text message to plain text.
+///
+/// Returns `None` when the content cannot be parsed or yields no usable text,
+/// so callers can simply `continue` rather than forwarding a meaningless
+/// placeholder string to the agent.
+fn parse_post_content(content: &str) -> Option<String> {
+ let parsed = serde_json::from_str::(content).ok()?;
+ let locale = parsed
+ .get("zh_cn")
+ .or_else(|| parsed.get("en_us"))
+ .or_else(|| {
+ parsed
+ .as_object()
+ .and_then(|m| m.values().find(|v| v.is_object()))
+ })?;
+
+ let mut text = String::new();
+
+ if let Some(title) = locale
+ .get("title")
+ .and_then(|t| t.as_str())
+ .filter(|s| !s.is_empty())
+ {
+ text.push_str(title);
+ text.push_str("\n\n");
}
+
+ if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) {
+ for para in paragraphs {
+ if let Some(elements) = para.as_array() {
+ for el in elements {
+ match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") {
+ "text" => {
+ if let Some(t) = el.get("text").and_then(|t| t.as_str()) {
+ text.push_str(t);
+ }
+ }
+ "a" => {
+ text.push_str(
+ el.get("text")
+ .and_then(|t| t.as_str())
+ .filter(|s| !s.is_empty())
+ .or_else(|| el.get("href").and_then(|h| h.as_str()))
+ .unwrap_or(""),
+ );
+ }
+ "at" => {
+ let n = el
+ .get("user_name")
+ .and_then(|n| n.as_str())
+ .or_else(|| el.get("user_id").and_then(|i| i.as_str()))
+ .unwrap_or("user");
+ text.push('@');
+ text.push_str(n);
+ }
+ _ => {}
+ }
+ }
+ text.push('\n');
+ }
+ }
+ }
+
+ let result = text.trim().to_string();
+ if result.is_empty() {
+ None
+ } else {
+ Some(result)
+ }
+}
+
+/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats.
+fn strip_at_placeholders(text: &str) -> String {
+ let mut result = String::with_capacity(text.len());
+ let mut chars = text.char_indices().peekable();
+ while let Some((_, ch)) = chars.next() {
+ if ch == '@' {
+ let rest: String = chars.clone().map(|(_, c)| c).collect();
+ if let Some(after) = rest.strip_prefix("_user_") {
+ let skip =
+ "_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count();
+ for _ in 0..=skip {
+ chars.next();
+ }
+ if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) {
+ chars.next();
+ }
+ continue;
+ }
+ }
+ result.push(ch);
+ }
+ result
+}
+
+/// In group chats, only respond when the bot is explicitly @-mentioned.
+fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool {
+ !mentions.is_empty()
}
#[cfg(test)]
@@ -321,7 +885,7 @@ mod tests {
"cli_test_app_id".into(),
"test_app_secret".into(),
"test_verification_token".into(),
- 9898,
+ None,
vec!["ou_testuser123".into()],
)
}
@@ -345,7 +909,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
- 9898,
+ None,
vec!["*".into()],
);
assert!(ch.is_user_allowed("ou_anyone"));
@@ -353,7 +917,7 @@ mod tests {
#[test]
fn lark_user_denied_empty() {
- let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]);
+ let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]);
assert!(!ch.is_user_allowed("ou_anyone"));
}
@@ -426,7 +990,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
- 9898,
+ None,
vec!["*".into()],
);
let payload = serde_json::json!({
@@ -451,7 +1015,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
- 9898,
+ None,
vec!["*".into()],
);
let payload = serde_json::json!({
@@ -488,7 +1052,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
- 9898,
+ None,
vec!["*".into()],
);
let payload = serde_json::json!({
@@ -512,7 +1076,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
- 9898,
+ None,
vec!["*".into()],
);
let payload = serde_json::json!({
@@ -550,7 +1114,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
- 9898,
+ None,
vec!["*".into()],
);
let payload = serde_json::json!({
@@ -571,7 +1135,7 @@ mod tests {
#[test]
fn lark_config_serde() {
- use crate::config::schema::LarkConfig;
+ use crate::config::schema::{LarkConfig, LarkReceiveMode};
let lc = LarkConfig {
app_id: "cli_app123".into(),
app_secret: "secret456".into(),
@@ -579,6 +1143,8 @@ mod tests {
verification_token: Some("vtoken789".into()),
allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
use_feishu: false,
+ receive_mode: LarkReceiveMode::default(),
+ port: None,
};
let json = serde_json::to_string(&lc).unwrap();
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
@@ -590,7 +1156,7 @@ mod tests {
#[test]
fn lark_config_toml_roundtrip() {
- use crate::config::schema::LarkConfig;
+ use crate::config::schema::{LarkConfig, LarkReceiveMode};
let lc = LarkConfig {
app_id: "app".into(),
app_secret: "secret".into(),
@@ -598,6 +1164,8 @@ mod tests {
verification_token: Some("tok".into()),
allowed_users: vec!["*".into()],
use_feishu: false,
+ receive_mode: LarkReceiveMode::Webhook,
+ port: Some(9898),
};
let toml_str = toml::to_string(&lc).unwrap();
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
@@ -608,11 +1176,36 @@ mod tests {
#[test]
fn lark_config_defaults_optional_fields() {
- use crate::config::schema::LarkConfig;
+ use crate::config::schema::{LarkConfig, LarkReceiveMode};
let json = r#"{"app_id":"a","app_secret":"s"}"#;
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
assert!(parsed.verification_token.is_none());
assert!(parsed.allowed_users.is_empty());
+ assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket);
+ assert!(parsed.port.is_none());
+ }
+
+ #[test]
+ fn lark_from_config_preserves_mode_and_region() {
+ use crate::config::schema::{LarkConfig, LarkReceiveMode};
+
+ let cfg = LarkConfig {
+ app_id: "cli_app123".into(),
+ app_secret: "secret456".into(),
+ encrypt_key: None,
+ verification_token: Some("vtoken789".into()),
+ allowed_users: vec!["*".into()],
+ use_feishu: false,
+ receive_mode: LarkReceiveMode::Webhook,
+ port: Some(9898),
+ };
+
+ let ch = LarkChannel::from_config(&cfg);
+
+ assert_eq!(ch.api_base(), LARK_BASE_URL);
+ assert_eq!(ch.ws_base(), LARK_WS_BASE_URL);
+ assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook);
+ assert_eq!(ch.port, Some(9898));
}
#[test]
@@ -622,7 +1215,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
- 9898,
+ None,
vec!["*".into()],
);
let payload = serde_json::json!({
diff --git a/src/channels/matrix.rs b/src/channels/matrix.rs
index 9f8924c..4f34bcf 100644
--- a/src/channels/matrix.rs
+++ b/src/channels/matrix.rs
@@ -230,6 +230,7 @@ impl Channel for MatrixChannel {
let msg = ChannelMessage {
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
sender: event.sender.clone(),
+ reply_target: event.sender.clone(),
content: body.clone(),
channel: "matrix".to_string(),
timestamp: std::time::SystemTime::now()
diff --git a/src/channels/mod.rs b/src/channels/mod.rs
index d8fd612..fc9a7d2 100644
--- a/src/channels/mod.rs
+++ b/src/channels/mod.rs
@@ -69,10 +69,19 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
}
+fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
+ match channel_name {
+ "telegram" => Some(
+ "When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.",
+ ),
+ _ => None,
+ }
+}
+
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new();
- if let Ok(entries) = mem.recall(user_msg, 5).await {
+ if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if !entries.is_empty() {
context.push_str("[Memory context]\n");
for entry in &entries {
@@ -158,6 +167,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C
&autosave_key,
&msg.content,
crate::memory::MemoryCategory::Conversation,
+ None,
)
.await;
}
@@ -171,7 +181,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
if let Some(channel) = target_channel.as_ref() {
- if let Err(e) = channel.start_typing(&msg.sender).await {
+ if let Err(e) = channel.start_typing(&msg.reply_target).await {
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
}
}
@@ -184,6 +194,10 @@ async fn process_channel_message(ctx: Arc, msg: traits::C
ChatMessage::user(&enriched_message),
];
+ if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
+ history.push(ChatMessage::system(instructions));
+ }
+
let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
run_tool_call_loop(
@@ -200,7 +214,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C
.await;
if let Some(channel) = target_channel.as_ref() {
- if let Err(e) = channel.stop_typing(&msg.sender).await {
+ if let Err(e) = channel.stop_typing(&msg.reply_target).await {
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
}
}
@@ -224,7 +238,9 @@ async fn process_channel_message(ctx: Arc, msg: traits::C
started_at.elapsed().as_millis()
);
if let Some(channel) = target_channel.as_ref() {
- let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
+ let _ = channel
+ .send(&format!("⚠️ Error: {e}"), &msg.reply_target)
+ .await;
}
}
Err(_) => {
@@ -241,7 +257,7 @@ async fn process_channel_message(ctx: Arc, msg: traits::C
let _ = channel
.send(
"⚠️ Request timed out while waiting for the model. Please try again.",
- &msg.sender,
+ &msg.reply_target,
)
.await;
}
@@ -483,6 +499,16 @@ pub fn build_system_prompt(
std::env::consts::OS,
);
+ // ── 8. Channel Capabilities ─────────────────────────────────────
+ prompt.push_str("## Channel Capabilities\n\n");
+ prompt.push_str(
+ "- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n",
+ );
+ prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n");
+ prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
+ prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
+ prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
+
if prompt.is_empty() {
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
} else {
@@ -619,6 +645,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
+ dc.mention_only,
)),
));
}
@@ -672,32 +699,23 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
if let Some(ref irc) = config.channels_config.irc {
channels.push((
"IRC",
- Arc::new(IrcChannel::new(
- irc.server.clone(),
- irc.port,
- irc.nickname.clone(),
- irc.username.clone(),
- irc.channels.clone(),
- irc.allowed_users.clone(),
- irc.server_password.clone(),
- irc.nickserv_password.clone(),
- irc.sasl_password.clone(),
- irc.verify_tls.unwrap_or(true),
- )),
+ Arc::new(IrcChannel::new(irc::IrcChannelConfig {
+ server: irc.server.clone(),
+ port: irc.port,
+ nickname: irc.nickname.clone(),
+ username: irc.username.clone(),
+ channels: irc.channels.clone(),
+ allowed_users: irc.allowed_users.clone(),
+ server_password: irc.server_password.clone(),
+ nickserv_password: irc.nickserv_password.clone(),
+ sasl_password: irc.sasl_password.clone(),
+ verify_tls: irc.verify_tls.unwrap_or(true),
+ })),
));
}
if let Some(ref lk) = config.channels_config.lark {
- channels.push((
- "Lark",
- Arc::new(LarkChannel::new(
- lk.app_id.clone(),
- lk.app_secret.clone(),
- lk.verification_token.clone().unwrap_or_default(),
- 9898,
- lk.allowed_users.clone(),
- )),
- ));
+ channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
}
if let Some(ref dt) = config.channels_config.dingtalk {
@@ -762,6 +780,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
let provider: Arc = Arc::from(providers::create_resilient_provider(
&provider_name,
config.api_key.as_deref(),
+ config.api_url.as_deref(),
&config.reliability,
)?);
@@ -860,6 +879,10 @@ pub async fn start_channels(config: Config) -> Result<()> {
"schedule",
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
));
+ tool_descs.push((
+ "pushover",
+ "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
+ ));
if !config.agents.is_empty() {
tool_descs.push((
"delegate",
@@ -909,6 +932,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
+ dc.mention_only,
)));
}
@@ -947,28 +971,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
}
if let Some(ref irc) = config.channels_config.irc {
- channels.push(Arc::new(IrcChannel::new(
- irc.server.clone(),
- irc.port,
- irc.nickname.clone(),
- irc.username.clone(),
- irc.channels.clone(),
- irc.allowed_users.clone(),
- irc.server_password.clone(),
- irc.nickserv_password.clone(),
- irc.sasl_password.clone(),
- irc.verify_tls.unwrap_or(true),
- )));
+ channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
+ server: irc.server.clone(),
+ port: irc.port,
+ nickname: irc.nickname.clone(),
+ username: irc.username.clone(),
+ channels: irc.channels.clone(),
+ allowed_users: irc.allowed_users.clone(),
+ server_password: irc.server_password.clone(),
+ nickserv_password: irc.nickserv_password.clone(),
+ sasl_password: irc.sasl_password.clone(),
+ verify_tls: irc.verify_tls.unwrap_or(true),
+ })));
}
if let Some(ref lk) = config.channels_config.lark {
- channels.push(Arc::new(LarkChannel::new(
- lk.app_id.clone(),
- lk.app_secret.clone(),
- lk.verification_token.clone().unwrap_or_default(),
- 9898,
- lk.allowed_users.clone(),
- )));
+ channels.push(Arc::new(LarkChannel::from_config(lk)));
}
if let Some(ref dt) = config.channels_config.dingtalk {
@@ -1242,6 +1260,7 @@ mod tests {
traits::ChannelMessage {
id: "msg-1".to_string(),
sender: "alice".to_string(),
+ reply_target: "chat-42".to_string(),
content: "What is the BTC price now?".to_string(),
channel: "test-channel".to_string(),
timestamp: 1,
@@ -1251,6 +1270,7 @@ mod tests {
let sent_messages = channel_impl.sent_messages.lock().await;
assert_eq!(sent_messages.len(), 1);
+ assert!(sent_messages[0].starts_with("chat-42:"));
assert!(sent_messages[0].contains("BTC is currently around"));
assert!(!sent_messages[0].contains("\"tool_calls\""));
assert!(!sent_messages[0].contains("mock_price"));
@@ -1269,6 +1289,7 @@ mod tests {
_key: &str,
_content: &str,
_category: crate::memory::MemoryCategory,
+ _session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
@@ -1277,6 +1298,7 @@ mod tests {
&self,
_query: &str,
_limit: usize,
+ _session_id: Option<&str>,
) -> anyhow::Result> {
Ok(Vec::new())
}
@@ -1288,6 +1310,7 @@ mod tests {
async fn list(
&self,
_category: Option<&crate::memory::MemoryCategory>,
+ _session_id: Option<&str>,
) -> anyhow::Result> {
Ok(Vec::new())
}
@@ -1331,6 +1354,7 @@ mod tests {
tx.send(traits::ChannelMessage {
id: "1".to_string(),
sender: "alice".to_string(),
+ reply_target: "alice".to_string(),
content: "hello".to_string(),
channel: "test-channel".to_string(),
timestamp: 1,
@@ -1340,6 +1364,7 @@ mod tests {
tx.send(traits::ChannelMessage {
id: "2".to_string(),
sender: "bob".to_string(),
+ reply_target: "bob".to_string(),
content: "world".to_string(),
channel: "test-channel".to_string(),
timestamp: 2,
@@ -1570,6 +1595,25 @@ mod tests {
assert!(truncated.is_char_boundary(truncated.len()));
}
+ #[test]
+ fn prompt_contains_channel_capabilities() {
+ let ws = make_workspace();
+ let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);
+
+ assert!(
+ prompt.contains("## Channel Capabilities"),
+ "missing Channel Capabilities section"
+ );
+ assert!(
+ prompt.contains("running as a Discord bot"),
+ "missing Discord context"
+ );
+ assert!(
+ prompt.contains("NEVER repeat, describe, or echo credentials"),
+ "missing security instruction"
+ );
+ }
+
#[test]
fn prompt_workspace_path() {
let ws = make_workspace();
@@ -1583,6 +1627,7 @@ mod tests {
let msg = traits::ChannelMessage {
id: "msg_abc123".into(),
sender: "U123".into(),
+ reply_target: "C456".into(),
content: "hello".into(),
channel: "slack".into(),
timestamp: 1,
@@ -1596,6 +1641,7 @@ mod tests {
let msg1 = traits::ChannelMessage {
id: "msg_1".into(),
sender: "U123".into(),
+ reply_target: "C456".into(),
content: "first".into(),
channel: "slack".into(),
timestamp: 1,
@@ -1603,6 +1649,7 @@ mod tests {
let msg2 = traits::ChannelMessage {
id: "msg_2".into(),
sender: "U123".into(),
+ reply_target: "C456".into(),
content: "second".into(),
channel: "slack".into(),
timestamp: 2,
@@ -1622,6 +1669,7 @@ mod tests {
let msg1 = traits::ChannelMessage {
id: "msg_1".into(),
sender: "U123".into(),
+ reply_target: "C456".into(),
content: "I'm Paul".into(),
channel: "slack".into(),
timestamp: 1,
@@ -1629,6 +1677,7 @@ mod tests {
let msg2 = traits::ChannelMessage {
id: "msg_2".into(),
sender: "U123".into(),
+ reply_target: "C456".into(),
content: "I'm 45".into(),
channel: "slack".into(),
timestamp: 2,
@@ -1638,6 +1687,7 @@ mod tests {
&conversation_memory_key(&msg1),
&msg1.content,
MemoryCategory::Conversation,
+ None,
)
.await
.unwrap();
@@ -1645,13 +1695,14 @@ mod tests {
&conversation_memory_key(&msg2),
&msg2.content,
MemoryCategory::Conversation,
+ None,
)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 2);
- let recalled = mem.recall("45", 5).await.unwrap();
+ let recalled = mem.recall("45", 5, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
}
@@ -1659,7 +1710,7 @@ mod tests {
async fn build_memory_context_includes_recalled_entries() {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
- mem.store("age_fact", "Age is 45", MemoryCategory::Conversation)
+ mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
.await
.unwrap();
diff --git a/src/channels/slack.rs b/src/channels/slack.rs
index fd6b2f0..7f8ee51 100644
--- a/src/channels/slack.rs
+++ b/src/channels/slack.rs
@@ -161,6 +161,7 @@ impl Channel for SlackChannel {
let channel_msg = ChannelMessage {
id: format!("slack_{channel_id}_{ts}"),
sender: user.to_string(),
+ reply_target: channel_id.clone(),
content: text.to_string(),
channel: "slack".to_string(),
timestamp: std::time::SystemTime::now()
diff --git a/src/channels/telegram.rs b/src/channels/telegram.rs
index bfe8dd6..5d25de1 100644
--- a/src/channels/telegram.rs
+++ b/src/channels/telegram.rs
@@ -51,6 +51,133 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
chunks
}
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+enum TelegramAttachmentKind {
+ Image,
+ Document,
+ Video,
+ Audio,
+ Voice,
+}
+
+#[derive(Debug, Clone, PartialEq, Eq)]
+struct TelegramAttachment {
+ kind: TelegramAttachmentKind,
+ target: String,
+}
+
+impl TelegramAttachmentKind {
+ fn from_marker(marker: &str) -> Option<Self> {
+ match marker.trim().to_ascii_uppercase().as_str() {
+ "IMAGE" | "PHOTO" => Some(Self::Image),
+ "DOCUMENT" | "FILE" => Some(Self::Document),
+ "VIDEO" => Some(Self::Video),
+ "AUDIO" => Some(Self::Audio),
+ "VOICE" => Some(Self::Voice),
+ _ => None,
+ }
+ }
+}
+
+fn is_http_url(target: &str) -> bool {
+ target.starts_with("http://") || target.starts_with("https://")
+}
+
+fn infer_attachment_kind_from_target(target: &str) -> Option<TelegramAttachmentKind> {
+ let normalized = target
+ .split('?')
+ .next()
+ .unwrap_or(target)
+ .split('#')
+ .next()
+ .unwrap_or(target);
+
+ let extension = Path::new(normalized)
+ .extension()
+ .and_then(|ext| ext.to_str())?
+ .to_ascii_lowercase();
+
+ match extension.as_str() {
+ "png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => Some(TelegramAttachmentKind::Image),
+ "mp4" | "mov" | "mkv" | "avi" | "webm" => Some(TelegramAttachmentKind::Video),
+ "mp3" | "m4a" | "wav" | "flac" => Some(TelegramAttachmentKind::Audio),
+ "ogg" | "oga" | "opus" => Some(TelegramAttachmentKind::Voice),
+ "pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx" | "xls"
+ | "xlsx" | "ppt" | "pptx" => Some(TelegramAttachmentKind::Document),
+ _ => None,
+ }
+}
+
+fn parse_path_only_attachment(message: &str) -> Option<TelegramAttachment> {
+ let trimmed = message.trim();
+ if trimmed.is_empty() || trimmed.contains('\n') {
+ return None;
+ }
+
+ let candidate = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\''));
+ if candidate.chars().any(char::is_whitespace) {
+ return None;
+ }
+
+ let candidate = candidate.strip_prefix("file://").unwrap_or(candidate);
+ let kind = infer_attachment_kind_from_target(candidate)?;
+
+ if !is_http_url(candidate) && !Path::new(candidate).exists() {
+ return None;
+ }
+
+ Some(TelegramAttachment {
+ kind,
+ target: candidate.to_string(),
+ })
+}
+
+fn parse_attachment_markers(message: &str) -> (String, Vec<TelegramAttachment>) {
+ let mut cleaned = String::with_capacity(message.len());
+ let mut attachments = Vec::new();
+ let mut cursor = 0;
+
+ while cursor < message.len() {
+ let Some(open_rel) = message[cursor..].find('[') else {
+ cleaned.push_str(&message[cursor..]);
+ break;
+ };
+
+ let open = cursor + open_rel;
+ cleaned.push_str(&message[cursor..open]);
+
+ let Some(close_rel) = message[open..].find(']') else {
+ cleaned.push_str(&message[open..]);
+ break;
+ };
+
+ let close = open + close_rel;
+ let marker = &message[open + 1..close];
+
+ let parsed = marker.split_once(':').and_then(|(kind, target)| {
+ let kind = TelegramAttachmentKind::from_marker(kind)?;
+ let target = target.trim();
+ if target.is_empty() {
+ return None;
+ }
+ Some(TelegramAttachment {
+ kind,
+ target: target.to_string(),
+ })
+ });
+
+ if let Some(attachment) = parsed {
+ attachments.push(attachment);
+ } else {
+ cleaned.push_str(&message[open..=close]);
+ }
+
+ cursor = close + 1;
+ }
+
+ (cleaned.trim().to_string(), attachments)
+}
+
/// Telegram channel — long-polls the Bot API for updates
pub struct TelegramChannel {
bot_token: String,
@@ -82,6 +209,216 @@ impl TelegramChannel {
identities.into_iter().any(|id| self.is_user_allowed(id))
}
+ fn parse_update_message(&self, update: &serde_json::Value) -> Option<ChannelMessage> {
+ let message = update.get("message")?;
+
+ let text = message.get("text").and_then(serde_json::Value::as_str)?;
+
+ let username = message
+ .get("from")
+ .and_then(|from| from.get("username"))
+ .and_then(serde_json::Value::as_str)
+ .unwrap_or("unknown")
+ .to_string();
+
+ let user_id = message
+ .get("from")
+ .and_then(|from| from.get("id"))
+ .and_then(serde_json::Value::as_i64)
+ .map(|id| id.to_string());
+
+ let sender_identity = if username == "unknown" {
+ user_id.clone().unwrap_or_else(|| "unknown".to_string())
+ } else {
+ username.clone()
+ };
+
+ let mut identities = vec![username.as_str()];
+ if let Some(id) = user_id.as_deref() {
+ identities.push(id);
+ }
+
+ if !self.is_any_user_allowed(identities.iter().copied()) {
+ tracing::warn!(
+ "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
+Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
+ user_id.as_deref().unwrap_or("unknown")
+ );
+ return None;
+ }
+
+ let chat_id = message
+ .get("chat")
+ .and_then(|chat| chat.get("id"))
+ .and_then(serde_json::Value::as_i64)
+ .map(|id| id.to_string())?;
+
+ let message_id = message
+ .get("message_id")
+ .and_then(serde_json::Value::as_i64)
+ .unwrap_or(0);
+
+ Some(ChannelMessage {
+ id: format!("telegram_{chat_id}_{message_id}"),
+ sender: sender_identity,
+ reply_target: chat_id,
+ content: text.to_string(),
+ channel: "telegram".to_string(),
+ timestamp: std::time::SystemTime::now()
+ .duration_since(std::time::UNIX_EPOCH)
+ .unwrap_or_default()
+ .as_secs(),
+ })
+ }
+
+ async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
+ let chunks = split_message_for_telegram(message);
+
+ for (index, chunk) in chunks.iter().enumerate() {
+ let text = if chunks.len() > 1 {
+ if index == 0 {
+ format!("{chunk}\n\n(continues...)")
+ } else if index == chunks.len() - 1 {
+ format!("(continued)\n\n{chunk}")
+ } else {
+ format!("(continued)\n\n{chunk}\n\n(continues...)")
+ }
+ } else {
+ chunk.to_string()
+ };
+
+ let markdown_body = serde_json::json!({
+ "chat_id": chat_id,
+ "text": text,
+ "parse_mode": "Markdown"
+ });
+
+ let markdown_resp = self
+ .client
+ .post(self.api_url("sendMessage"))
+ .json(&markdown_body)
+ .send()
+ .await?;
+
+ if markdown_resp.status().is_success() {
+ if index < chunks.len() - 1 {
+ tokio::time::sleep(Duration::from_millis(100)).await;
+ }
+ continue;
+ }
+
+ let markdown_status = markdown_resp.status();
+ let markdown_err = markdown_resp.text().await.unwrap_or_default();
+ tracing::warn!(
+ status = ?markdown_status,
+ "Telegram sendMessage with Markdown failed; retrying without parse_mode"
+ );
+
+ let plain_body = serde_json::json!({
+ "chat_id": chat_id,
+ "text": text,
+ });
+ let plain_resp = self
+ .client
+ .post(self.api_url("sendMessage"))
+ .json(&plain_body)
+ .send()
+ .await?;
+
+ if !plain_resp.status().is_success() {
+ let plain_status = plain_resp.status();
+ let plain_err = plain_resp.text().await.unwrap_or_default();
+ anyhow::bail!(
+ "Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
+ markdown_status,
+ markdown_err,
+ plain_status,
+ plain_err
+ );
+ }
+
+ if index < chunks.len() - 1 {
+ tokio::time::sleep(Duration::from_millis(100)).await;
+ }
+ }
+
+ Ok(())
+ }
+
+ async fn send_media_by_url(
+ &self,
+ method: &str,
+ media_field: &str,
+ chat_id: &str,
+ url: &str,
+ caption: Option<&str>,
+ ) -> anyhow::Result<()> {
+ let mut body = serde_json::json!({
+ "chat_id": chat_id,
+ });
+ body[media_field] = serde_json::Value::String(url.to_string());
+
+ if let Some(cap) = caption {
+ body["caption"] = serde_json::Value::String(cap.to_string());
+ }
+
+ let resp = self
+ .client
+ .post(self.api_url(method))
+ .json(&body)
+ .send()
+ .await?;
+
+ if !resp.status().is_success() {
+ let err = resp.text().await?;
+ anyhow::bail!("Telegram {method} by URL failed: {err}");
+ }
+
+ tracing::info!("Telegram {method} sent to {chat_id}: {url}");
+ Ok(())
+ }
+
+ async fn send_attachment(
+ &self,
+ chat_id: &str,
+ attachment: &TelegramAttachment,
+ ) -> anyhow::Result<()> {
+ let target = attachment.target.trim();
+
+ if is_http_url(target) {
+ return match attachment.kind {
+ TelegramAttachmentKind::Image => {
+ self.send_photo_by_url(chat_id, target, None).await
+ }
+ TelegramAttachmentKind::Document => {
+ self.send_document_by_url(chat_id, target, None).await
+ }
+ TelegramAttachmentKind::Video => {
+ self.send_video_by_url(chat_id, target, None).await
+ }
+ TelegramAttachmentKind::Audio => {
+ self.send_audio_by_url(chat_id, target, None).await
+ }
+ TelegramAttachmentKind::Voice => {
+ self.send_voice_by_url(chat_id, target, None).await
+ }
+ };
+ }
+
+ let path = Path::new(target);
+ if !path.exists() {
+ anyhow::bail!("Telegram attachment path not found: {target}");
+ }
+
+ match attachment.kind {
+ TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await,
+ TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await,
+ TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await,
+ TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await,
+ TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await,
+ }
+ }
+
/// Send a document/file to a Telegram chat
pub async fn send_document(
&self,
@@ -408,6 +745,39 @@ impl TelegramChannel {
tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}");
Ok(())
}
+
+ /// Send a video by URL (Telegram will download it)
+ pub async fn send_video_by_url(
+ &self,
+ chat_id: &str,
+ url: &str,
+ caption: Option<&str>,
+ ) -> anyhow::Result<()> {
+ self.send_media_by_url("sendVideo", "video", chat_id, url, caption)
+ .await
+ }
+
+ /// Send an audio file by URL (Telegram will download it)
+ pub async fn send_audio_by_url(
+ &self,
+ chat_id: &str,
+ url: &str,
+ caption: Option<&str>,
+ ) -> anyhow::Result<()> {
+ self.send_media_by_url("sendAudio", "audio", chat_id, url, caption)
+ .await
+ }
+
+ /// Send a voice message by URL (Telegram will download it)
+ pub async fn send_voice_by_url(
+ &self,
+ chat_id: &str,
+ url: &str,
+ caption: Option<&str>,
+ ) -> anyhow::Result<()> {
+ self.send_media_by_url("sendVoice", "voice", chat_id, url, caption)
+ .await
+ }
}
#[async_trait]
@@ -417,82 +787,27 @@ impl Channel for TelegramChannel {
}
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
- // Split message if it exceeds Telegram's 4096 character limit
- let chunks = split_message_for_telegram(message);
+ let (text_without_markers, attachments) = parse_attachment_markers(message);
- for (i, chunk) in chunks.iter().enumerate() {
- // Add continuation marker for multi-part messages
- let text = if chunks.len() > 1 {
- if i == 0 {
- format!("{chunk}\n\n(continues...)")
- } else if i == chunks.len() - 1 {
- format!("(continued)\n\n{chunk}")
- } else {
- format!("(continued)\n\n{chunk}\n\n(continues...)")
- }
- } else {
- chunk.to_string()
- };
-
- let markdown_body = serde_json::json!({
- "chat_id": chat_id,
- "text": text,
- "parse_mode": "Markdown"
- });
-
- let markdown_resp = self
- .client
- .post(self.api_url("sendMessage"))
- .json(&markdown_body)
- .send()
- .await?;
-
- if markdown_resp.status().is_success() {
- // Small delay between chunks to avoid rate limiting
- if i < chunks.len() - 1 {
- tokio::time::sleep(Duration::from_millis(100)).await;
- }
- continue;
+ if !attachments.is_empty() {
+ if !text_without_markers.is_empty() {
+ self.send_text_chunks(&text_without_markers, chat_id)
+ .await?;
}
- let markdown_status = markdown_resp.status();
- let markdown_err = markdown_resp.text().await.unwrap_or_default();
- tracing::warn!(
- status = ?markdown_status,
- "Telegram sendMessage with Markdown failed; retrying without parse_mode"
- );
-
- // Retry without parse_mode as a compatibility fallback.
- let plain_body = serde_json::json!({
- "chat_id": chat_id,
- "text": text,
- });
- let plain_resp = self
- .client
- .post(self.api_url("sendMessage"))
- .json(&plain_body)
- .send()
- .await?;
-
- if !plain_resp.status().is_success() {
- let plain_status = plain_resp.status();
- let plain_err = plain_resp.text().await.unwrap_or_default();
- anyhow::bail!(
- "Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
- markdown_status,
- markdown_err,
- plain_status,
- plain_err
- );
+ for attachment in &attachments {
+ self.send_attachment(chat_id, attachment).await?;
}
- // Small delay between chunks to avoid rate limiting
- if i < chunks.len() - 1 {
- tokio::time::sleep(Duration::from_millis(100)).await;
- }
+ return Ok(());
}
- Ok(())
+ if let Some(attachment) = parse_path_only_attachment(message) {
+ self.send_attachment(chat_id, &attachment).await?;
+ return Ok(());
+ }
+
+ self.send_text_chunks(message, chat_id).await
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
@@ -533,59 +848,13 @@ impl Channel for TelegramChannel {
offset = uid + 1;
}
- let Some(message) = update.get("message") else {
+ let Some(msg) = self.parse_update_message(update) else {
continue;
};
- let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
- continue;
- };
-
- let username_opt = message
- .get("from")
- .and_then(|f| f.get("username"))
- .and_then(|u| u.as_str());
- let username = username_opt.unwrap_or("unknown");
-
- let user_id = message
- .get("from")
- .and_then(|f| f.get("id"))
- .and_then(serde_json::Value::as_i64);
- let user_id_str = user_id.map(|id| id.to_string());
-
- let mut identities = vec![username];
- if let Some(ref id) = user_id_str {
- identities.push(id.as_str());
- }
-
- if !self.is_any_user_allowed(identities.iter().copied()) {
- tracing::warn!(
- "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
-Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
- user_id_str.as_deref().unwrap_or("unknown")
- );
- continue;
- }
-
- let chat_id = message
- .get("chat")
- .and_then(|c| c.get("id"))
- .and_then(serde_json::Value::as_i64)
- .map(|id| id.to_string());
-
- let Some(chat_id) = chat_id else {
- tracing::warn!("Telegram: missing chat_id in message, skipping");
- continue;
- };
-
- let message_id = message
- .get("message_id")
- .and_then(|v| v.as_i64())
- .unwrap_or(0);
-
// Send "typing" indicator immediately when we receive a message
let typing_body = serde_json::json!({
- "chat_id": &chat_id,
+ "chat_id": &msg.reply_target,
"action": "typing"
});
let _ = self
@@ -595,17 +864,6 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch
.send()
.await; // Ignore errors for typing indicator
- let msg = ChannelMessage {
- id: format!("telegram_{chat_id}_{message_id}"),
- sender: username.to_string(),
- content: text.to_string(),
- channel: "telegram".to_string(),
- timestamp: std::time::SystemTime::now()
- .duration_since(std::time::UNIX_EPOCH)
- .unwrap_or_default()
- .as_secs(),
- };
-
if tx.send(msg).await.is_err() {
return Ok(());
}
@@ -716,6 +974,107 @@ mod tests {
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
}
+ #[test]
+ fn parse_attachment_markers_extracts_multiple_types() {
+ let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]";
+ let (cleaned, attachments) = parse_attachment_markers(message);
+
+ assert_eq!(cleaned, "Here are files and");
+ assert_eq!(attachments.len(), 2);
+ assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image);
+ assert_eq!(attachments[0].target, "/tmp/a.png");
+ assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document);
+ assert_eq!(attachments[1].target, "https://example.com/a.pdf");
+ }
+
+ #[test]
+ fn parse_attachment_markers_keeps_invalid_markers_in_text() {
+ let message = "Report [UNKNOWN:/tmp/a.bin]";
+ let (cleaned, attachments) = parse_attachment_markers(message);
+
+ assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]");
+ assert!(attachments.is_empty());
+ }
+
+ #[test]
+ fn parse_path_only_attachment_detects_existing_file() {
+ let dir = tempfile::tempdir().unwrap();
+ let image_path = dir.path().join("snap.png");
+ std::fs::write(&image_path, b"fake-png").unwrap();
+
+ let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref())
+ .expect("expected attachment");
+
+ assert_eq!(parsed.kind, TelegramAttachmentKind::Image);
+ assert_eq!(parsed.target, image_path.to_string_lossy());
+ }
+
+ #[test]
+ fn parse_path_only_attachment_rejects_sentence_text() {
+ assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none());
+ }
+
+ #[test]
+ fn infer_attachment_kind_from_target_detects_document_extension() {
+ assert_eq!(
+ infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"),
+ Some(TelegramAttachmentKind::Document)
+ );
+ }
+
+ #[test]
+ fn parse_update_message_uses_chat_id_as_reply_target() {
+ let ch = TelegramChannel::new("token".into(), vec!["*".into()]);
+ let update = serde_json::json!({
+ "update_id": 1,
+ "message": {
+ "message_id": 33,
+ "text": "hello",
+ "from": {
+ "id": 555,
+ "username": "alice"
+ },
+ "chat": {
+ "id": -100200300
+ }
+ }
+ });
+
+ let msg = ch
+ .parse_update_message(&update)
+ .expect("message should parse");
+
+ assert_eq!(msg.sender, "alice");
+ assert_eq!(msg.reply_target, "-100200300");
+ assert_eq!(msg.content, "hello");
+ assert_eq!(msg.id, "telegram_-100200300_33");
+ }
+
+ #[test]
+ fn parse_update_message_allows_numeric_id_without_username() {
+ let ch = TelegramChannel::new("token".into(), vec!["555".into()]);
+ let update = serde_json::json!({
+ "update_id": 2,
+ "message": {
+ "message_id": 9,
+ "text": "ping",
+ "from": {
+ "id": 555
+ },
+ "chat": {
+ "id": 12345
+ }
+ }
+ });
+
+ let msg = ch
+ .parse_update_message(&update)
+ .expect("numeric allowlist should pass");
+
+ assert_eq!(msg.sender, "555");
+ assert_eq!(msg.reply_target, "12345");
+ }
+
// ── File sending API URL tests ──────────────────────────────────
#[test]
diff --git a/src/channels/traits.rs b/src/channels/traits.rs
index 59b361e..1c44bf6 100644
--- a/src/channels/traits.rs
+++ b/src/channels/traits.rs
@@ -5,6 +5,7 @@ use async_trait::async_trait;
pub struct ChannelMessage {
pub id: String,
pub sender: String,
+ pub reply_target: String,
pub content: String,
pub channel: String,
pub timestamp: u64,
@@ -62,6 +63,7 @@ mod tests {
tx.send(ChannelMessage {
id: "1".into(),
sender: "tester".into(),
+ reply_target: "tester".into(),
content: "hello".into(),
channel: "dummy".into(),
timestamp: 123,
@@ -76,6 +78,7 @@ mod tests {
let message = ChannelMessage {
id: "42".into(),
sender: "alice".into(),
+ reply_target: "alice".into(),
content: "ping".into(),
channel: "dummy".into(),
timestamp: 999,
@@ -84,6 +87,7 @@ mod tests {
let cloned = message.clone();
assert_eq!(cloned.id, "42");
assert_eq!(cloned.sender, "alice");
+ assert_eq!(cloned.reply_target, "alice");
assert_eq!(cloned.content, "ping");
assert_eq!(cloned.channel, "dummy");
assert_eq!(cloned.timestamp, 999);
diff --git a/src/channels/whatsapp.rs b/src/channels/whatsapp.rs
index 3e4c045..7825b96 100644
--- a/src/channels/whatsapp.rs
+++ b/src/channels/whatsapp.rs
@@ -10,7 +10,7 @@ use uuid::Uuid;
/// happens in the gateway when Meta sends webhook events.
pub struct WhatsAppChannel {
access_token: String,
- phone_number_id: String,
+ endpoint_id: String,
verify_token: String,
allowed_numbers: Vec<String>,
client: reqwest::Client,
@@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
impl WhatsAppChannel {
pub fn new(
access_token: String,
- phone_number_id: String,
+ endpoint_id: String,
verify_token: String,
allowed_numbers: Vec<String>,
) -> Self {
Self {
access_token,
- phone_number_id,
+ endpoint_id,
verify_token,
allowed_numbers,
client: reqwest::Client::new(),
@@ -119,6 +119,7 @@ impl WhatsAppChannel {
messages.push(ChannelMessage {
id: Uuid::new_v4().to_string(),
+ reply_target: normalized_from.clone(),
sender: normalized_from,
content,
channel: "whatsapp".to_string(),
@@ -142,7 +143,7 @@ impl Channel for WhatsAppChannel {
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
let url = format!(
"https://graph.facebook.com/v18.0/{}/messages",
- self.phone_number_id
+ self.endpoint_id
);
// Normalize recipient (remove leading + if present for API)
@@ -162,7 +163,7 @@ impl Channel for WhatsAppChannel {
let resp = self
.client
.post(&url)
- .header("Authorization", format!("Bearer {}", self.access_token))
+ .bearer_auth(&self.access_token)
.header("Content-Type", "application/json")
.json(&body)
.send()
@@ -195,11 +196,11 @@ impl Channel for WhatsAppChannel {
async fn health_check(&self) -> bool {
// Check if we can reach the WhatsApp API
- let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id);
+ let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
self.client
.get(&url)
- .header("Authorization", format!("Bearer {}", self.access_token))
+ .bearer_auth(&self.access_token)
.send()
.await
.map(|r| r.status().is_success())
diff --git a/src/config/mod.rs b/src/config/mod.rs
index 4fec9ae..8e37cce 100644
--- a/src/config/mod.rs
+++ b/src/config/mod.rs
@@ -37,9 +37,22 @@ mod tests {
guild_id: Some("123".into()),
allowed_users: vec![],
listen_to_bots: false,
+ mention_only: false,
+ };
+
+ let lark = LarkConfig {
+ app_id: "app-id".into(),
+ app_secret: "app-secret".into(),
+ encrypt_key: None,
+ verification_token: None,
+ allowed_users: vec![],
+ use_feishu: false,
+ receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
+ port: None,
};
assert_eq!(telegram.allowed_users.len(), 1);
assert_eq!(discord.guild_id.as_deref(), Some("123"));
+ assert_eq!(lark.app_id, "app-id");
}
}
diff --git a/src/config/schema.rs b/src/config/schema.rs
index 34be770..74f5d34 100644
--- a/src/config/schema.rs
+++ b/src/config/schema.rs
@@ -18,6 +18,8 @@ pub struct Config {
#[serde(skip)]
pub config_path: PathBuf,
pub api_key: Option<String>,
+ /// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama)
+ pub api_url: Option<String>,
pub default_provider: Option<String>,
pub default_model: Option<String>,
pub default_temperature: f64,
@@ -1317,6 +1319,10 @@ pub struct DiscordConfig {
/// The bot still ignores its own messages to prevent feedback loops.
#[serde(default)]
pub listen_to_bots: bool,
+ /// When true, only respond to messages that @-mention the bot.
+ /// Other messages in the guild are silently ignored.
+ #[serde(default)]
+ pub mention_only: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -1395,8 +1401,20 @@ fn default_irc_port() -> u16 {
6697
}
-/// Lark/Feishu configuration for messaging integration
-/// Lark is the international version, Feishu is the Chinese version
+/// How ZeroClaw receives events from Feishu / Lark.
+///
+/// - `websocket` (default) — persistent WSS long-connection; no public URL required.
+/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint.
+#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
+#[serde(rename_all = "lowercase")]
+pub enum LarkReceiveMode {
+ #[default]
+ Websocket,
+ Webhook,
+}
+
+/// Lark/Feishu configuration for messaging integration.
+/// Lark is the international version; Feishu is the Chinese version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LarkConfig {
/// App ID from Lark/Feishu developer console
@@ -1415,6 +1433,13 @@ pub struct LarkConfig {
/// Whether to use the Feishu (Chinese) endpoint instead of Lark (International)
#[serde(default)]
pub use_feishu: bool,
+ /// Event receive mode: "websocket" (default) or "webhook"
+ #[serde(default)]
+ pub receive_mode: LarkReceiveMode,
+ /// HTTP port for webhook mode only. Must be set when receive_mode = "webhook".
+ /// Not required (and ignored) for websocket mode.
+ #[serde(default)]
+ pub port: Option<u16>,
}
// ── Security Config ─────────────────────────────────────────────────
@@ -1594,6 +1619,7 @@ impl Default for Config {
workspace_dir: zeroclaw_dir.join("workspace"),
config_path: zeroclaw_dir.join("config.toml"),
api_key: None,
+ api_url: None,
default_provider: Some("openrouter".to_string()),
default_model: Some("anthropic/claude-sonnet-4".to_string()),
default_temperature: 0.7,
@@ -1623,35 +1649,146 @@ impl Default for Config {
}
}
-impl Config {
- pub fn load_or_init() -> Result<Self> {
- let home = UserDirs::new()
- .map(|u| u.home_dir().to_path_buf())
- .context("Could not find home directory")?;
- let zeroclaw_dir = home.join(".zeroclaw");
- let config_path = zeroclaw_dir.join("config.toml");
+fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> {
+ let home = UserDirs::new()
+ .map(|u| u.home_dir().to_path_buf())
+ .context("Could not find home directory")?;
+ let config_dir = home.join(".zeroclaw");
+ Ok((config_dir.clone(), config_dir.join("workspace")))
+}
- if !zeroclaw_dir.exists() {
- fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?;
- fs::create_dir_all(zeroclaw_dir.join("workspace"))
- .context("Failed to create workspace directory")?;
+fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
+ let workspace_config_dir = workspace_dir.to_path_buf();
+ if workspace_config_dir.join("config.toml").exists() {
+ return workspace_config_dir;
+ }
+
+ let legacy_config_dir = workspace_dir
+ .parent()
+ .map(|parent| parent.join(".zeroclaw"));
+ if let Some(legacy_dir) = legacy_config_dir {
+ if legacy_dir.join("config.toml").exists() {
+ return legacy_dir;
}
+ if workspace_dir
+ .file_name()
+ .is_some_and(|name| name == std::ffi::OsStr::new("workspace"))
+ {
+ return legacy_dir;
+ }
+ }
+
+ workspace_config_dir
+}
+
+fn decrypt_optional_secret(
+ store: &crate::security::SecretStore,
+ value: &mut Option<String>,
+ field_name: &str,
+) -> Result<()> {
+ if let Some(raw) = value.clone() {
+ if crate::security::SecretStore::is_encrypted(&raw) {
+ *value = Some(
+ store
+ .decrypt(&raw)
+ .with_context(|| format!("Failed to decrypt {field_name}"))?,
+ );
+ }
+ }
+ Ok(())
+}
+
+fn encrypt_optional_secret(
+ store: &crate::security::SecretStore,
+ value: &mut Option<String>,
+ field_name: &str,
+) -> Result<()> {
+ if let Some(raw) = value.clone() {
+ if !crate::security::SecretStore::is_encrypted(&raw) {
+ *value = Some(
+ store
+ .encrypt(&raw)
+ .with_context(|| format!("Failed to encrypt {field_name}"))?,
+ );
+ }
+ }
+ Ok(())
+}
+
+impl Config {
+ pub fn load_or_init() -> Result<Self> {
+ // Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
+ let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") {
+ Ok(custom_workspace) if !custom_workspace.is_empty() => {
+ let workspace = PathBuf::from(custom_workspace);
+ (resolve_config_dir_for_workspace(&workspace), workspace)
+ }
+ _ => default_config_and_workspace_dirs()?,
+ };
+
+ let config_path = zeroclaw_dir.join("config.toml");
+
+ fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?;
+ fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?;
+
if config_path.exists() {
+ // Warn if config file is world-readable (may contain API keys)
+ #[cfg(unix)]
+ {
+ use std::os::unix::fs::PermissionsExt;
+ if let Ok(meta) = fs::metadata(&config_path) {
+ if meta.permissions().mode() & 0o004 != 0 {
+ tracing::warn!(
+ "Config file {:?} is world-readable (mode {:o}). \
+ Consider restricting with: chmod 600 {:?}",
+ config_path,
+ meta.permissions().mode() & 0o777,
+ config_path,
+ );
+ }
+ }
+ }
+
let contents =
fs::read_to_string(&config_path).context("Failed to read config file")?;
let mut config: Config =
toml::from_str(&contents).context("Failed to parse config file")?;
// Set computed paths that are skipped during serialization
config.config_path = config_path.clone();
- config.workspace_dir = zeroclaw_dir.join("workspace");
+ config.workspace_dir = workspace_dir;
+ let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
+ decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
+ decrypt_optional_secret(
+ &store,
+ &mut config.composio.api_key,
+ "config.composio.api_key",
+ )?;
+
+ decrypt_optional_secret(
+ &store,
+ &mut config.browser.computer_use.api_key,
+ "config.browser.computer_use.api_key",
+ )?;
+
+ for agent in config.agents.values_mut() {
+ decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
+ }
config.apply_env_overrides();
Ok(config)
} else {
let mut config = Config::default();
config.config_path = config_path.clone();
- config.workspace_dir = zeroclaw_dir.join("workspace");
+ config.workspace_dir = workspace_dir;
config.save()?;
+
+ // Restrict permissions on newly created config file (may contain API keys)
+ #[cfg(unix)]
+ {
+ use std::os::unix::fs::PermissionsExt;
+ let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600));
+ }
+
config.apply_env_overrides();
Ok(config)
}
@@ -1732,23 +1869,29 @@ impl Config {
}
pub fn save(&self) -> Result<()> {
- // Encrypt agent API keys before serialization
+ // Encrypt secrets before serialization
let mut config_to_save = self.clone();
let zeroclaw_dir = self
.config_path
.parent()
.context("Config path must have a parent directory")?;
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
+
+ encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
+ encrypt_optional_secret(
+ &store,
+ &mut config_to_save.composio.api_key,
+ "config.composio.api_key",
+ )?;
+
+ encrypt_optional_secret(
+ &store,
+ &mut config_to_save.browser.computer_use.api_key,
+ "config.browser.computer_use.api_key",
+ )?;
+
for agent in config_to_save.agents.values_mut() {
- if let Some(ref plaintext_key) = agent.api_key {
- if !crate::security::SecretStore::is_encrypted(plaintext_key) {
- agent.api_key = Some(
- store
- .encrypt(plaintext_key)
- .context("Failed to encrypt agent API key")?,
- );
- }
- }
+ encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
}
let toml_str =
@@ -1949,6 +2092,7 @@ default_temperature = 0.7
workspace_dir: PathBuf::from("/tmp/test/workspace"),
config_path: PathBuf::from("/tmp/test/config.toml"),
api_key: Some("sk-test-key".into()),
+ api_url: None,
default_provider: Some("openrouter".into()),
default_model: Some("gpt-4o".into()),
default_temperature: 0.5,
@@ -2091,6 +2235,7 @@ tool_dispatcher = "xml"
workspace_dir: dir.join("workspace"),
config_path: config_path.clone(),
api_key: Some("sk-roundtrip".into()),
+ api_url: None,
default_provider: Some("openrouter".into()),
default_model: Some("test-model".into()),
default_temperature: 0.9,
@@ -2123,13 +2268,82 @@ tool_dispatcher = "xml"
let contents = fs::read_to_string(&config_path).unwrap();
let loaded: Config = toml::from_str(&contents).unwrap();
- assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
+ assert!(loaded
+ .api_key
+ .as_deref()
+ .is_some_and(crate::security::SecretStore::is_encrypted));
+ let store = crate::security::SecretStore::new(&dir, true);
+ let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
+ assert_eq!(decrypted, "sk-roundtrip");
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
let _ = fs::remove_dir_all(&dir);
}
+ #[test]
+ fn config_save_encrypts_nested_credentials() {
+ let dir = std::env::temp_dir().join(format!(
+ "zeroclaw_test_nested_credentials_{}",
+ uuid::Uuid::new_v4()
+ ));
+ fs::create_dir_all(&dir).unwrap();
+
+ let mut config = Config::default();
+ config.workspace_dir = dir.join("workspace");
+ config.config_path = dir.join("config.toml");
+ config.api_key = Some("root-credential".into());
+ config.composio.api_key = Some("composio-credential".into());
+ config.browser.computer_use.api_key = Some("browser-credential".into());
+
+ config.agents.insert(
+ "worker".into(),
+ DelegateAgentConfig {
+ provider: "openrouter".into(),
+ model: "model-test".into(),
+ system_prompt: None,
+ api_key: Some("agent-credential".into()),
+ temperature: None,
+ max_depth: 3,
+ },
+ );
+
+ config.save().unwrap();
+
+ let contents = fs::read_to_string(config.config_path.clone()).unwrap();
+ let stored: Config = toml::from_str(&contents).unwrap();
+ let store = crate::security::SecretStore::new(&dir, true);
+
+ let root_encrypted = stored.api_key.as_deref().unwrap();
+ assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
+ assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
+
+ let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
+ assert!(crate::security::SecretStore::is_encrypted(
+ composio_encrypted
+ ));
+ assert_eq!(
+ store.decrypt(composio_encrypted).unwrap(),
+ "composio-credential"
+ );
+
+ let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
+ assert!(crate::security::SecretStore::is_encrypted(
+ browser_encrypted
+ ));
+ assert_eq!(
+ store.decrypt(browser_encrypted).unwrap(),
+ "browser-credential"
+ );
+
+ let worker = stored.agents.get("worker").unwrap();
+ let worker_encrypted = worker.api_key.as_deref().unwrap();
+ assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
+ assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
+
+ let _ = fs::remove_dir_all(&dir);
+ }
+
#[test]
fn config_save_atomic_cleanup() {
let dir =
@@ -2182,6 +2396,7 @@ tool_dispatcher = "xml"
guild_id: Some("12345".into()),
allowed_users: vec![],
listen_to_bots: false,
+ mention_only: false,
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@@ -2196,6 +2411,7 @@ tool_dispatcher = "xml"
guild_id: None,
allowed_users: vec![],
listen_to_bots: false,
+ mention_only: false,
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@@ -2818,6 +3034,96 @@ default_temperature = 0.7
std::env::remove_var("ZEROCLAW_WORKSPACE");
}
#[test]
fn load_or_init_workspace_override_uses_workspace_root_for_config() {
    let _env_guard = env_override_test_guard();

    // Restores `HOME` and clears `ZEROCLAW_WORKSPACE` on drop so the env
    // mutations below cannot leak into sibling tests even if an assertion
    // or `load_or_init` panics mid-test. The original restored env only on
    // the success path.
    struct RestoreEnv {
        home: Option<String>,
    }
    impl Drop for RestoreEnv {
        fn drop(&mut self) {
            std::env::remove_var("ZEROCLAW_WORKSPACE");
            match self.home.take() {
                Some(home) => std::env::set_var("HOME", home),
                None => std::env::remove_var("HOME"),
            }
        }
    }

    let temp_home =
        std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
    let workspace_dir = temp_home.join("profile-a");

    let _restore = RestoreEnv {
        home: std::env::var("HOME").ok(),
    };
    std::env::set_var("HOME", &temp_home);
    std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);

    let config = Config::load_or_init().unwrap();

    // A directory name other than "workspace" is treated as a workspace
    // root: the config file lives directly inside it.
    assert_eq!(config.workspace_dir, workspace_dir);
    assert_eq!(config.config_path, workspace_dir.join("config.toml"));
    assert!(workspace_dir.join("config.toml").exists());

    let _ = fs::remove_dir_all(temp_home);
}
+
#[test]
fn load_or_init_workspace_suffix_uses_legacy_config_layout() {
    let _env_guard = env_override_test_guard();

    // Restores `HOME` and clears `ZEROCLAW_WORKSPACE` on drop so the env
    // mutations below cannot leak into sibling tests even if an assertion
    // or `load_or_init` panics mid-test. The original restored env only on
    // the success path.
    struct RestoreEnv {
        home: Option<String>,
    }
    impl Drop for RestoreEnv {
        fn drop(&mut self) {
            std::env::remove_var("ZEROCLAW_WORKSPACE");
            match self.home.take() {
                Some(home) => std::env::set_var("HOME", home),
                None => std::env::remove_var("HOME"),
            }
        }
    }

    let temp_home =
        std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
    let workspace_dir = temp_home.join("workspace");
    let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml");

    let _restore = RestoreEnv {
        home: std::env::var("HOME").ok(),
    };
    std::env::set_var("HOME", &temp_home);
    std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);

    let config = Config::load_or_init().unwrap();

    // A workspace path ending in "workspace" keeps the legacy layout:
    // config stays under ~/.zeroclaw/ rather than inside the workspace.
    assert_eq!(config.workspace_dir, workspace_dir);
    assert_eq!(config.config_path, legacy_config_path);
    assert!(config.config_path.exists());

    let _ = fs::remove_dir_all(temp_home);
}
+
#[test]
fn load_or_init_workspace_override_keeps_existing_legacy_config() {
    let _env_guard = env_override_test_guard();

    // Restores `HOME` and clears `ZEROCLAW_WORKSPACE` on drop so the env
    // mutations below cannot leak into sibling tests even if an assertion
    // or `load_or_init` panics mid-test. The original restored env only on
    // the success path.
    struct RestoreEnv {
        home: Option<String>,
    }
    impl Drop for RestoreEnv {
        fn drop(&mut self) {
            std::env::remove_var("ZEROCLAW_WORKSPACE");
            match self.home.take() {
                Some(home) => std::env::set_var("HOME", home),
                None => std::env::remove_var("HOME"),
            }
        }
    }

    let temp_home =
        std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
    let workspace_dir = temp_home.join("custom-workspace");
    let legacy_config_dir = temp_home.join(".zeroclaw");
    let legacy_config_path = legacy_config_dir.join("config.toml");

    // Pre-existing legacy config must win over creating a fresh one in the
    // overridden workspace.
    fs::create_dir_all(&legacy_config_dir).unwrap();
    fs::write(
        &legacy_config_path,
        r#"default_temperature = 0.7
default_model = "legacy-model"
"#,
    )
    .unwrap();

    let _restore = RestoreEnv {
        home: std::env::var("HOME").ok(),
    };
    std::env::set_var("HOME", &temp_home);
    std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);

    let config = Config::load_or_init().unwrap();

    assert_eq!(config.workspace_dir, workspace_dir);
    assert_eq!(config.config_path, legacy_config_path);
    assert_eq!(config.default_model.as_deref(), Some("legacy-model"));

    let _ = fs::remove_dir_all(temp_home);
}
+
#[test]
fn env_override_empty_values_ignored() {
let _env_guard = env_override_test_guard();
@@ -2975,4 +3281,118 @@ default_temperature = 0.7
assert_eq!(parsed.boards[0].board, "nucleo-f401re");
assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0"));
}
+
#[test]
fn lark_config_serde() {
    // JSON round-trip must preserve every populated field of `LarkConfig`.
    let original = LarkConfig {
        app_id: "cli_123456".into(),
        app_secret: "secret_abc".into(),
        encrypt_key: Some("encrypt_key".into()),
        verification_token: Some("verify_token".into()),
        allowed_users: vec!["user_123".into(), "user_456".into()],
        use_feishu: true,
        receive_mode: LarkReceiveMode::Websocket,
        port: None,
    };

    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: LarkConfig = serde_json::from_str(&encoded).unwrap();

    assert_eq!(decoded.app_id, "cli_123456");
    assert_eq!(decoded.app_secret, "secret_abc");
    assert_eq!(decoded.encrypt_key.as_deref(), Some("encrypt_key"));
    assert_eq!(decoded.verification_token.as_deref(), Some("verify_token"));
    assert_eq!(decoded.allowed_users.len(), 2);
    assert!(decoded.use_feishu);
}
+
#[test]
fn lark_config_toml_roundtrip() {
    // TOML round-trip for webhook receive mode with an explicit port.
    let original = LarkConfig {
        app_id: "cli_123456".into(),
        app_secret: "secret_abc".into(),
        encrypt_key: Some("encrypt_key".into()),
        verification_token: Some("verify_token".into()),
        allowed_users: vec!["*".into()],
        use_feishu: false,
        receive_mode: LarkReceiveMode::Webhook,
        port: Some(9898),
    };

    let encoded = toml::to_string(&original).unwrap();
    let decoded: LarkConfig = toml::from_str(&encoded).unwrap();

    assert_eq!(decoded.app_id, "cli_123456");
    assert_eq!(decoded.app_secret, "secret_abc");
    assert!(!decoded.use_feishu);
}
+
#[test]
fn lark_config_deserializes_without_optional_fields() {
    // Only the two required fields — every optional field takes its default.
    let payload = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
    let config: LarkConfig = serde_json::from_str(payload).unwrap();

    assert!(config.encrypt_key.is_none());
    assert!(config.verification_token.is_none());
    assert!(config.allowed_users.is_empty());
    assert!(!config.use_feishu);
}
+
#[test]
fn lark_config_defaults_to_lark_endpoint() {
    // Without an explicit flag the config must target Lark, not Feishu.
    let payload = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
    let config: LarkConfig = serde_json::from_str(payload).unwrap();

    assert!(
        !config.use_feishu,
        "use_feishu should default to false (Lark)"
    );
}
+
#[test]
fn lark_config_with_wildcard_allowed_users() {
    // The "*" wildcard entry must survive deserialization intact.
    let payload = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#;
    let config: LarkConfig = serde_json::from_str(payload).unwrap();

    assert_eq!(config.allowed_users, vec!["*"]);
}
+
+ // ── Config file permission hardening (Unix only) ───────────────
+
#[cfg(unix)]
#[test]
fn new_config_file_has_restricted_permissions() {
    use std::os::unix::fs::PermissionsExt;

    let tmp = tempfile::TempDir::new().unwrap();
    let config_path = tmp.path().join("config.toml");

    // Persist a default config at the temp path.
    let mut config = Config::default();
    config.config_path = config_path.clone();
    config.save().unwrap();

    // Mirror the chmod performed by `load_or_init`, then verify it stuck.
    // NOTE(review): this exercises the chmod pattern, not `load_or_init`
    // itself — the assertion holds only because we applied the mode here.
    let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600));

    let mode = std::fs::metadata(&config_path)
        .unwrap()
        .permissions()
        .mode()
        & 0o777;
    assert_eq!(
        mode, 0o600,
        "New config file should be owner-only (0600), got {mode:o}"
    );
}
+
#[cfg(unix)]
#[test]
fn world_readable_config_is_detectable() {
    use std::os::unix::fs::PermissionsExt;

    let tmp = tempfile::TempDir::new().unwrap();
    let config_path = tmp.path().join("config.toml");

    // Deliberately create a config file with loose (world-readable) bits.
    std::fs::write(&config_path, "# test config").unwrap();
    std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap();

    // The world-read bit must be observable via metadata — this is the
    // signal the permission-hardening warning relies on.
    let mode = std::fs::metadata(&config_path).unwrap().permissions().mode();
    assert!(
        mode & 0o004 != 0,
        "Test setup: file should be world-readable (mode {mode:o})"
    );
}
}
diff --git a/src/cron/scheduler.rs b/src/cron/scheduler.rs
index df771d6..4562dba 100644
--- a/src/cron/scheduler.rs
+++ b/src/cron/scheduler.rs
@@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
+ dc.mention_only,
);
channel.send(output, target).await?;
}
diff --git a/src/daemon/mod.rs b/src/daemon/mod.rs
index c2f4487..a223597 100644
--- a/src/daemon/mod.rs
+++ b/src/daemon/mod.rs
@@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool {
|| config.channels_config.matrix.is_some()
|| config.channels_config.whatsapp.is_some()
|| config.channels_config.email.is_some()
+ || config.channels_config.lark.is_some()
}
#[cfg(test)]
diff --git a/src/gateway/mod.rs b/src/gateway/mod.rs
index 719e8e7..b391a88 100644
--- a/src/gateway/mod.rs
+++ b/src/gateway/mod.rs
@@ -49,6 +49,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
format!("whatsapp_{}_{}", msg.sender, msg.id)
}
+fn hash_webhook_secret(value: &str) -> String {
+ use sha2::{Digest, Sha256};
+
+ let digest = Sha256::digest(value.as_bytes());
+ hex::encode(digest)
+}
+
/// How often the rate limiter sweeps stale IP entries from its map.
/// Periodic sweeping bounds memory growth from one-off client IPs.
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
@@ -178,7 +185,8 @@ pub struct AppState {
pub temperature: f64,
pub mem: Arc,
pub auto_save: bool,
- pub webhook_secret: Option>,
+ /// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
+ pub webhook_secret_hash: Option>,
pub pairing: Arc,
pub rate_limiter: Arc,
pub idempotency_store: Arc,
@@ -208,6 +216,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let provider: Arc = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
+ config.api_url.as_deref(),
&config.reliability,
)?);
let model = config
@@ -251,12 +260,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
&config,
));
// Extract webhook secret for authentication
- let webhook_secret: Option> = config
- .channels_config
- .webhook
- .as_ref()
- .and_then(|w| w.secret.as_deref())
- .map(Arc::from);
+ let webhook_secret_hash: Option> =
+ config.channels_config.webhook.as_ref().and_then(|webhook| {
+ webhook.secret.as_ref().and_then(|raw_secret| {
+ let trimmed_secret = raw_secret.trim();
+ (!trimmed_secret.is_empty())
+ .then(|| Arc::::from(hash_webhook_secret(trimmed_secret)))
+ })
+ });
// WhatsApp channel (if configured)
let whatsapp_channel: Option> =
@@ -342,9 +353,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
} else {
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
}
- if webhook_secret.is_some() {
- println!(" 🔒 Webhook secret: ENABLED");
- }
println!(" Press Ctrl+C to stop.\n");
crate::health::mark_component_ok("gateway");
@@ -356,7 +364,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
temperature,
mem,
auto_save: config.memory.auto_save,
- webhook_secret,
+ webhook_secret_hash,
pairing,
rate_limiter,
idempotency_store,
@@ -482,12 +490,15 @@ async fn handle_webhook(
}
// ── Webhook secret auth (optional, additional layer) ──
- if let Some(ref secret) = state.webhook_secret {
- let header_val = headers
+ if let Some(ref secret_hash) = state.webhook_secret_hash {
+ let header_hash = headers
.get("X-Webhook-Secret")
- .and_then(|v| v.to_str().ok());
- match header_val {
- Some(val) if constant_time_eq(val, secret.as_ref()) => {}
+ .and_then(|v| v.to_str().ok())
+ .map(str::trim)
+ .filter(|value| !value.is_empty())
+ .map(hash_webhook_secret);
+ match header_hash {
+ Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
_ => {
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
@@ -532,7 +543,7 @@ async fn handle_webhook(
let key = webhook_memory_key();
let _ = state
.mem
- .store(&key, message, MemoryCategory::Conversation)
+ .store(&key, message, MemoryCategory::Conversation, None)
.await;
}
@@ -685,7 +696,7 @@ async fn handle_whatsapp_message(
let key = whatsapp_memory_key(msg);
let _ = state
.mem
- .store(&key, &msg.content, MemoryCategory::Conversation)
+ .store(&key, &msg.content, MemoryCategory::Conversation, None)
.await;
}
@@ -697,7 +708,7 @@ async fn handle_whatsapp_message(
{
Ok(response) => {
// Send reply via WhatsApp
- if let Err(e) = wa.send(&response, &msg.sender).await {
+ if let Err(e) = wa.send(&response, &msg.reply_target).await {
tracing::error!("Failed to send WhatsApp reply: {e}");
}
}
@@ -706,7 +717,7 @@ async fn handle_whatsapp_message(
let _ = wa
.send(
"Sorry, I couldn't process your message right now.",
- &msg.sender,
+ &msg.reply_target,
)
.await;
}
@@ -798,7 +809,9 @@ mod tests {
.requests
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
- guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1);
+ guard.1 = Instant::now()
+ .checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
+ .unwrap();
// Clear timestamps for ip-2 and ip-3 to simulate stale entries
guard.0.get_mut("ip-2").unwrap().clear();
guard.0.get_mut("ip-3").unwrap().clear();
@@ -848,6 +861,7 @@ mod tests {
let msg = ChannelMessage {
id: "wamid-123".into(),
sender: "+1234567890".into(),
+ reply_target: "+1234567890".into(),
content: "hello".into(),
channel: "whatsapp".into(),
timestamp: 1,
@@ -871,11 +885,17 @@ mod tests {
_key: &str,
_content: &str,
_category: MemoryCategory,
+ _session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
- async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> {
+ async fn recall(
+ &self,
+ _query: &str,
+ _limit: usize,
+ _session_id: Option<&str>,
+ ) -> anyhow::Result> {
Ok(Vec::new())
}
@@ -886,6 +906,7 @@ mod tests {
async fn list(
&self,
_category: Option<&MemoryCategory>,
+ _session_id: Option<&str>,
) -> anyhow::Result> {
Ok(Vec::new())
}
@@ -938,6 +959,7 @@ mod tests {
key: &str,
_content: &str,
_category: MemoryCategory,
+ _session_id: Option<&str>,
) -> anyhow::Result<()> {
self.keys
.lock()
@@ -946,7 +968,12 @@ mod tests {
Ok(())
}
- async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result> {
+ async fn recall(
+ &self,
+ _query: &str,
+ _limit: usize,
+ _session_id: Option<&str>,
+ ) -> anyhow::Result> {
Ok(Vec::new())
}
@@ -957,6 +984,7 @@ mod tests {
async fn list(
&self,
_category: Option<&MemoryCategory>,
+ _session_id: Option<&str>,
) -> anyhow::Result> {
Ok(Vec::new())
}
@@ -991,7 +1019,7 @@ mod tests {
temperature: 0.0,
mem: memory,
auto_save: false,
- webhook_secret: None,
+ webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@@ -1039,7 +1067,7 @@ mod tests {
temperature: 0.0,
mem: memory,
auto_save: true,
- webhook_secret: None,
+ webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@@ -1077,6 +1105,125 @@ mod tests {
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
}
#[test]
fn webhook_secret_hash_is_deterministic_and_nonempty() {
    // Same input → same digest; different input → different digest.
    let first = hash_webhook_secret("secret-value");
    assert_eq!(first, hash_webhook_secret("secret-value"));
    assert_ne!(first, hash_webhook_secret("other-value"));

    // SHA-256 hex encoding is always 64 characters.
    assert_eq!(first.len(), 64);
}
+
+ #[tokio::test]
+ async fn webhook_secret_hash_rejects_missing_header() {
+ let provider_impl = Arc::new(MockProvider::default());
+ let provider: Arc = provider_impl.clone();
+ let memory: Arc = Arc::new(MockMemory);
+
+ let state = AppState {
+ provider,
+ model: "test-model".into(),
+ temperature: 0.0,
+ mem: memory,
+ auto_save: false,
+ webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
+ pairing: Arc::new(PairingGuard::new(false, &[])),
+ rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
+ idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
+ whatsapp: None,
+ whatsapp_app_secret: None,
+ };
+
+ let response = handle_webhook(
+ State(state),
+ HeaderMap::new(),
+ Ok(Json(WebhookBody {
+ message: "hello".into(),
+ })),
+ )
+ .await
+ .into_response();
+
+ assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
+ assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
+ }
+
+ #[tokio::test]
+ async fn webhook_secret_hash_rejects_invalid_header() {
+ let provider_impl = Arc::new(MockProvider::default());
+ let provider: Arc = provider_impl.clone();
+ let memory: Arc = Arc::new(MockMemory);
+
+ let state = AppState {
+ provider,
+ model: "test-model".into(),
+ temperature: 0.0,
+ mem: memory,
+ auto_save: false,
+ webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
+ pairing: Arc::new(PairingGuard::new(false, &[])),
+ rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
+ idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
+ whatsapp: None,
+ whatsapp_app_secret: None,
+ };
+
+ let mut headers = HeaderMap::new();
+ headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
+
+ let response = handle_webhook(
+ State(state),
+ headers,
+ Ok(Json(WebhookBody {
+ message: "hello".into(),
+ })),
+ )
+ .await
+ .into_response();
+
+ assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
+ assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
+ }
+
+ #[tokio::test]
+ async fn webhook_secret_hash_accepts_valid_header() {
+ let provider_impl = Arc::new(MockProvider::default());
+ let provider: Arc = provider_impl.clone();
+ let memory: Arc = Arc::new(MockMemory);
+
+ let state = AppState {
+ provider,
+ model: "test-model".into(),
+ temperature: 0.0,
+ mem: memory,
+ auto_save: false,
+ webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
+ pairing: Arc::new(PairingGuard::new(false, &[])),
+ rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
+ idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
+ whatsapp: None,
+ whatsapp_app_secret: None,
+ };
+
+ let mut headers = HeaderMap::new();
+ headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
+
+ let response = handle_webhook(
+ State(state),
+ headers,
+ Ok(Json(WebhookBody {
+ message: "hello".into(),
+ })),
+ )
+ .await
+ .into_response();
+
+ assert_eq!(response.status(), StatusCode::OK);
+ assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
+ }
+
// ══════════════════════════════════════════════════════════
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
// ══════════════════════════════════════════════════════════
diff --git a/src/main.rs b/src/main.rs
index dbc76ff..56cd579 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -34,8 +34,8 @@
use anyhow::{bail, Result};
use clap::{Parser, Subcommand};
-use tracing::{info, Level};
-use tracing_subscriber::FmtSubscriber;
+use tracing::info;
+use tracing_subscriber::{fmt, EnvFilter};
mod agent;
mod channels;
@@ -147,24 +147,24 @@ enum Commands {
/// Start the gateway server (webhooks, websockets)
Gateway {
- /// Port to listen on (use 0 for random available port)
- #[arg(short, long, default_value = "8080")]
- port: u16,
+ /// Port to listen on (use 0 for random available port); defaults to config gateway.port
+ #[arg(short, long)]
+ port: Option,
- /// Host to bind to
- #[arg(long, default_value = "127.0.0.1")]
- host: String,
+ /// Host to bind to; defaults to config gateway.host
+ #[arg(long)]
+ host: Option,
},
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
Daemon {
- /// Port to listen on (use 0 for random available port)
- #[arg(short, long, default_value = "8080")]
- port: u16,
+ /// Port to listen on (use 0 for random available port); defaults to config gateway.port
+ #[arg(short, long)]
+ port: Option,
- /// Host to bind to
- #[arg(long, default_value = "127.0.0.1")]
- host: String,
+ /// Host to bind to; defaults to config gateway.host
+ #[arg(long)]
+ host: Option,
},
/// Manage OS service lifecycle (launchd/systemd user service)
@@ -367,9 +367,11 @@ async fn main() -> Result<()> {
let cli = Cli::parse();
- // Initialize logging
- let subscriber = FmtSubscriber::builder()
- .with_max_level(Level::INFO)
+ // Initialize logging - respects RUST_LOG env var, defaults to INFO
+ let subscriber = fmt::Subscriber::builder()
+ .with_env_filter(
+ EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
+ )
.finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
@@ -434,6 +436,8 @@ async fn main() -> Result<()> {
.map(|_| ()),
Commands::Gateway { port, host } => {
+ let port = port.unwrap_or(config.gateway.port);
+ let host = host.unwrap_or_else(|| config.gateway.host.clone());
if port == 0 {
info!("🚀 Starting ZeroClaw Gateway on {host} (random port)");
} else {
@@ -443,6 +447,8 @@ async fn main() -> Result<()> {
}
Commands::Daemon { port, host } => {
+ let port = port.unwrap_or(config.gateway.port);
+ let host = host.unwrap_or_else(|| config.gateway.host.clone());
if port == 0 {
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
} else {
diff --git a/src/memory/backend.rs b/src/memory/backend.rs
index 4de636a..8ba7ec3 100644
--- a/src/memory/backend.rs
+++ b/src/memory/backend.rs
@@ -7,6 +7,7 @@ pub enum MemoryBackendKind {
Unknown,
}
+#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct MemoryBackendProfile {
pub key: &'static str,
diff --git a/src/memory/hygiene.rs b/src/memory/hygiene.rs
index cf58e21..01054ce 100644
--- a/src/memory/hygiene.rs
+++ b/src/memory/hygiene.rs
@@ -502,10 +502,10 @@ mod tests {
let workspace = tmp.path();
let mem = SqliteMemory::new(workspace).unwrap();
- mem.store("conv_old", "outdated", MemoryCategory::Conversation)
+ mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
.await
.unwrap();
- mem.store("core_keep", "durable", MemoryCategory::Core)
+ mem.store("core_keep", "durable", MemoryCategory::Core, None)
.await
.unwrap();
drop(mem);
diff --git a/src/memory/lucid.rs b/src/memory/lucid.rs
index 50cf9de..e1cb43a 100644
--- a/src/memory/lucid.rs
+++ b/src/memory/lucid.rs
@@ -24,7 +24,9 @@ pub struct LucidMemory {
impl LucidMemory {
const DEFAULT_LUCID_CMD: &'static str = "lucid";
const DEFAULT_TOKEN_BUDGET: usize = 200;
- const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120;
+ // Lucid CLI cold start can exceed 120ms on slower machines, which causes
+ // avoidable fallback to local-only memory and premature cooldown.
+ const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500;
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
@@ -74,6 +76,7 @@ impl LucidMemory {
}
#[cfg(test)]
+ #[allow(clippy::too_many_arguments)]
fn with_options(
workspace_dir: &Path,
local: SqliteMemory,
@@ -307,14 +310,22 @@ impl Memory for LucidMemory {
key: &str,
content: &str,
category: MemoryCategory,
+ session_id: Option<&str>,
) -> anyhow::Result<()> {
- self.local.store(key, content, category.clone()).await?;
+ self.local
+ .store(key, content, category.clone(), session_id)
+ .await?;
self.sync_to_lucid_async(key, content, &category).await;
Ok(())
}
- async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> {
- let local_results = self.local.recall(query, limit).await?;
+ async fn recall(
+ &self,
+ query: &str,
+ limit: usize,
+ session_id: Option<&str>,
+ ) -> anyhow::Result> {
+ let local_results = self.local.recall(query, limit, session_id).await?;
if limit == 0
|| local_results.len() >= limit
|| local_results.len() >= self.local_hit_threshold
@@ -351,8 +362,12 @@ impl Memory for LucidMemory {
self.local.get(key).await
}
- async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> {
- self.local.list(category).await
+ async fn list(
+ &self,
+ category: Option<&MemoryCategory>,
+ session_id: Option<&str>,
+ ) -> anyhow::Result> {
+ self.local.list(category, session_id).await
}
async fn forget(&self, key: &str) -> anyhow::Result {
@@ -396,6 +411,38 @@ EOF
exit 0
fi
+echo "unsupported command" >&2
+exit 1
+"#;
+
+ fs::write(&script_path, script).unwrap();
+ let mut perms = fs::metadata(&script_path).unwrap().permissions();
+ perms.set_mode(0o755);
+ fs::set_permissions(&script_path, perms).unwrap();
+ script_path.display().to_string()
+ }
+
+ fn write_delayed_lucid_script(dir: &Path) -> String {
+ let script_path = dir.join("delayed-lucid.sh");
+ let script = r#"#!/usr/bin/env bash
+set -euo pipefail
+
+if [[ "${1:-}" == "store" ]]; then
+ echo '{"success":true,"id":"mem_1"}'
+ exit 0
+fi
+
+if [[ "${1:-}" == "context" ]]; then
+ # Simulate a cold start that is slower than 120ms but below the 500ms timeout.
+ sleep 0.2
+ cat <<'EOF'
+
+- [decision] Delayed token refresh guidance
+
+EOF
+ exit 0
+fi
+
echo "unsupported command" >&2
exit 1
"#;
@@ -449,7 +496,7 @@ exit 1
cmd,
200,
3,
- Duration::from_millis(120),
+ Duration::from_millis(500),
Duration::from_millis(400),
Duration::from_secs(2),
)
@@ -468,7 +515,7 @@ exit 1
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
memory
- .store("lang", "User prefers Rust", MemoryCategory::Core)
+ .store("lang", "User prefers Rust", MemoryCategory::Core, None)
.await
.unwrap();
@@ -483,6 +530,30 @@ exit 1
let fake_cmd = write_fake_lucid_script(tmp.path());
let memory = test_memory(tmp.path(), fake_cmd);
+ memory
+ .store(
+ "local_note",
+ "Local sqlite auth fallback note",
+ MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
+
+ let entries = memory.recall("auth", 5, None).await.unwrap();
+
+ assert!(entries
+ .iter()
+ .any(|e| e.content.contains("Local sqlite auth fallback note")));
+ assert!(entries.iter().any(|e| e.content.contains("token refresh")));
+ }
+
+ #[tokio::test]
+ async fn recall_handles_lucid_cold_start_delay_within_timeout() {
+ let tmp = TempDir::new().unwrap();
+ let delayed_cmd = write_delayed_lucid_script(tmp.path());
+ let memory = test_memory(tmp.path(), delayed_cmd);
+
memory
.store(
"local_note",
@@ -497,7 +568,9 @@ exit 1
assert!(entries
.iter()
.any(|e| e.content.contains("Local sqlite auth fallback note")));
- assert!(entries.iter().any(|e| e.content.contains("token refresh")));
+ assert!(entries
+ .iter()
+ .any(|e| e.content.contains("Delayed token refresh guidance")));
}
#[tokio::test]
@@ -513,17 +586,22 @@ exit 1
probe_cmd,
200,
1,
- Duration::from_millis(120),
+ Duration::from_millis(500),
Duration::from_millis(400),
Duration::from_secs(2),
);
memory
- .store("pref", "Rust should stay local-first", MemoryCategory::Core)
+ .store(
+ "pref",
+ "Rust should stay local-first",
+ MemoryCategory::Core,
+ None,
+ )
.await
.unwrap();
- let entries = memory.recall("rust", 5).await.unwrap();
+ let entries = memory.recall("rust", 5, None).await.unwrap();
assert!(entries
.iter()
.any(|e| e.content.contains("Rust should stay local-first")));
@@ -578,13 +656,13 @@ exit 1
failing_cmd,
200,
99,
- Duration::from_millis(120),
+ Duration::from_millis(500),
Duration::from_millis(400),
Duration::from_secs(5),
);
- let first = memory.recall("auth", 5).await.unwrap();
- let second = memory.recall("auth", 5).await.unwrap();
+ let first = memory.recall("auth", 5, None).await.unwrap();
+ let second = memory.recall("auth", 5, None).await.unwrap();
assert!(first.is_empty());
assert!(second.is_empty());
diff --git a/src/memory/markdown.rs b/src/memory/markdown.rs
index 8dcd667..9038683 100644
--- a/src/memory/markdown.rs
+++ b/src/memory/markdown.rs
@@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
key: &str,
content: &str,
category: MemoryCategory,
+ _session_id: Option<&str>,
) -> anyhow::Result<()> {
let entry = format!("- **{key}**: {content}");
let path = match category {
@@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
self.append_to_file(&path, &entry).await
}
- async fn recall(&self, query: &str, limit: usize) -> anyhow::Result> {
+ async fn recall(
+ &self,
+ query: &str,
+ limit: usize,
+ _session_id: Option<&str>,
+ ) -> anyhow::Result> {
let all = self.read_all_entries().await?;
let query_lower = query.to_lowercase();
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
@@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
.find(|e| e.key == key || e.content.contains(key)))
}
- async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result> {
+ async fn list(
+ &self,
+ category: Option<&MemoryCategory>,
+ _session_id: Option<&str>,
+ ) -> anyhow::Result> {
let all = self.read_all_entries().await?;
match category {
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
@@ -243,7 +253,7 @@ mod tests {
#[tokio::test]
async fn markdown_store_core() {
let (_tmp, mem) = temp_workspace();
- mem.store("pref", "User likes Rust", MemoryCategory::Core)
+ mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
.await
.unwrap();
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
@@ -253,7 +263,7 @@ mod tests {
#[tokio::test]
async fn markdown_store_daily() {
let (_tmp, mem) = temp_workspace();
- mem.store("note", "Finished tests", MemoryCategory::Daily)
+ mem.store("note", "Finished tests", MemoryCategory::Daily, None)
.await
.unwrap();
let path = mem.daily_path();
@@ -264,17 +274,17 @@ mod tests {
#[tokio::test]
async fn markdown_recall_keyword() {
let (_tmp, mem) = temp_workspace();
- mem.store("a", "Rust is fast", MemoryCategory::Core)
+ mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("b", "Python is slow", MemoryCategory::Core)
+ mem.store("b", "Python is slow", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("c", "Rust and safety", MemoryCategory::Core)
+ mem.store("c", "Rust and safety", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("Rust", 10).await.unwrap();
+ let results = mem.recall("Rust", 10, None).await.unwrap();
assert!(results.len() >= 2);
assert!(results
.iter()
@@ -284,18 +294,20 @@ mod tests {
#[tokio::test]
async fn markdown_recall_no_match() {
    // A query sharing no keywords with stored entries must yield nothing.
    let (_tmp, mem) = temp_workspace();

    mem.store("a", "Rust is great", MemoryCategory::Core, None)
        .await
        .unwrap();

    let hits = mem.recall("javascript", 10, None).await.unwrap();
    assert!(hits.is_empty());
}
#[tokio::test]
async fn markdown_count() {
let (_tmp, mem) = temp_workspace();
- mem.store("a", "first", MemoryCategory::Core).await.unwrap();
- mem.store("b", "second", MemoryCategory::Core)
+ mem.store("a", "first", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+ mem.store("b", "second", MemoryCategory::Core, None)
.await
.unwrap();
let count = mem.count().await.unwrap();
@@ -305,24 +317,24 @@ mod tests {
#[tokio::test]
async fn markdown_list_by_category() {
    let (_tmp, mem) = temp_workspace();

    mem.store("a", "core fact", MemoryCategory::Core, None)
        .await
        .unwrap();
    mem.store("b", "daily note", MemoryCategory::Daily, None)
        .await
        .unwrap();

    // Filtering by a category must return entries of that category only.
    for category in [MemoryCategory::Core, MemoryCategory::Daily] {
        let entries = mem.list(Some(&category), None).await.unwrap();
        assert!(entries.iter().all(|e| e.category == category));
    }
}
#[tokio::test]
async fn markdown_forget_is_noop() {
let (_tmp, mem) = temp_workspace();
- mem.store("a", "permanent", MemoryCategory::Core)
+ mem.store("a", "permanent", MemoryCategory::Core, None)
.await
.unwrap();
let removed = mem.forget("a").await.unwrap();
@@ -332,7 +344,7 @@ mod tests {
#[tokio::test]
async fn markdown_empty_recall() {
    // Recall on a fresh workspace with nothing stored returns no entries.
    let (_tmp, mem) = temp_workspace();

    let hits = mem.recall("anything", 10, None).await.unwrap();
    assert!(hits.is_empty());
}
diff --git a/src/memory/none.rs b/src/memory/none.rs
index 6057ad0..4ccd2f8 100644
--- a/src/memory/none.rs
+++ b/src/memory/none.rs
@@ -25,11 +25,17 @@ impl Memory for NoneMemory {
_key: &str,
_content: &str,
_category: MemoryCategory,
+ _session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
- async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
+ async fn recall(
+ &self,
+ _query: &str,
+ _limit: usize,
+ _session_id: Option<&str>,
+ ) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@@ -37,7 +43,11 @@ impl Memory for NoneMemory {
Ok(None)
}
- async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
+ async fn list(
+ &self,
+ _category: Option<&MemoryCategory>,
+ _session_id: Option<&str>,
+ ) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@@ -62,11 +72,14 @@ mod tests {
async fn none_memory_is_noop() {
let memory = NoneMemory::new();
- memory.store("k", "v", MemoryCategory::Core).await.unwrap();
+ memory
+ .store("k", "v", MemoryCategory::Core, None)
+ .await
+ .unwrap();
assert!(memory.get("k").await.unwrap().is_none());
- assert!(memory.recall("k", 10).await.unwrap().is_empty());
- assert!(memory.list(None).await.unwrap().is_empty());
+ assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
+ assert!(memory.list(None, None).await.unwrap().is_empty());
assert!(!memory.forget("k").await.unwrap());
assert_eq!(memory.count().await.unwrap(), 0);
assert!(memory.health_check().await);
diff --git a/src/memory/response_cache.rs b/src/memory/response_cache.rs
index 6baa5c7..a260aa7 100644
--- a/src/memory/response_cache.rs
+++ b/src/memory/response_cache.rs
@@ -157,7 +157,7 @@ impl ResponseCache {
|row| row.get(0),
)?;
- #[allow(clippy::cast_sign_loss)]
+ #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok((count as usize, hits as u64, tokens_saved as u64))
}
diff --git a/src/memory/sqlite.rs b/src/memory/sqlite.rs
index 160487d..46a98db 100644
--- a/src/memory/sqlite.rs
+++ b/src/memory/sqlite.rs
@@ -124,6 +124,19 @@ impl SqliteMemory {
);
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
)?;
+
+ // Migration: add session_id column if not present (safe to run repeatedly)
+ let has_session_id: bool = conn
+ .prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
+ .query_row([], |row| row.get::<_, String>(0))?
+ .contains("session_id");
+ if !has_session_id {
+ conn.execute_batch(
+ "ALTER TABLE memories ADD COLUMN session_id TEXT;
+ CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
+ )?;
+ }
+
Ok(())
}
@@ -361,6 +374,7 @@ impl Memory for SqliteMemory {
key: &str,
content: &str,
category: MemoryCategory,
+ session_id: Option<&str>,
) -> anyhow::Result<()> {
// Compute embedding (async, before lock)
let embedding_bytes = self
@@ -377,20 +391,26 @@ impl Memory for SqliteMemory {
let id = Uuid::new_v4().to_string();
conn.execute(
- "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
- VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
+ "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
+ VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
embedding = excluded.embedding,
- updated_at = excluded.updated_at",
- params![id, key, content, cat, embedding_bytes, now, now],
+ updated_at = excluded.updated_at,
+ session_id = excluded.session_id",
+ params![id, key, content, cat, embedding_bytes, now, now, session_id],
)?;
Ok(())
}
- async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
+ async fn recall(
+ &self,
+ query: &str,
+ limit: usize,
+ session_id: Option<&str>,
+ ) -> anyhow::Result<Vec<MemoryEntry>> {
if query.trim().is_empty() {
return Ok(Vec::new());
}
@@ -439,7 +459,7 @@ impl Memory for SqliteMemory {
let mut results = Vec::new();
for scored in &merged {
let mut stmt = conn.prepare(
- "SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
+ "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
)?;
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
Ok(MemoryEntry {
@@ -448,10 +468,16 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
- session_id: None,
+ session_id: row.get(5)?,
score: Some(f64::from(scored.final_score)),
})
}) {
+ // Filter by session_id if requested
+ if let Some(sid) = session_id {
+ if entry.session_id.as_deref() != Some(sid) {
+ continue;
+ }
+ }
results.push(entry);
}
}
@@ -470,7 +496,7 @@ impl Memory for SqliteMemory {
.collect();
let where_clause = conditions.join(" OR ");
let sql = format!(
- "SELECT id, key, content, category, created_at FROM memories
+ "SELECT id, key, content, category, created_at, session_id FROM memories
WHERE {where_clause}
ORDER BY updated_at DESC
LIMIT ?{}",
@@ -493,12 +519,18 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
- session_id: None,
+ session_id: row.get(5)?,
score: Some(1.0),
})
})?;
for row in rows {
- results.push(row?);
+ let entry = row?;
+ if let Some(sid) = session_id {
+ if entry.session_id.as_deref() != Some(sid) {
+ continue;
+ }
+ }
+ results.push(entry);
}
}
}
@@ -514,7 +546,7 @@ impl Memory for SqliteMemory {
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut stmt = conn.prepare(
- "SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
+ "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
)?;
let mut rows = stmt.query_map(params![key], |row| {
@@ -524,7 +556,7 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
- session_id: None,
+ session_id: row.get(5)?,
score: None,
})
})?;
@@ -535,7 +567,11 @@ impl Memory for SqliteMemory {
}
}
- async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
+ async fn list(
+ &self,
+ category: Option<&MemoryCategory>,
+ session_id: Option<&str>,
+ ) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self
.conn
.lock()
@@ -550,7 +586,7 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
- session_id: None,
+ session_id: row.get(5)?,
score: None,
})
};
@@ -558,21 +594,33 @@ impl Memory for SqliteMemory {
if let Some(cat) = category {
let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare(
- "SELECT id, key, content, category, created_at FROM memories
+ "SELECT id, key, content, category, created_at, session_id FROM memories
WHERE category = ?1 ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map(params![cat_str], row_mapper)?;
for row in rows {
- results.push(row?);
+ let entry = row?;
+ if let Some(sid) = session_id {
+ if entry.session_id.as_deref() != Some(sid) {
+ continue;
+ }
+ }
+ results.push(entry);
}
} else {
let mut stmt = conn.prepare(
- "SELECT id, key, content, category, created_at FROM memories
+ "SELECT id, key, content, category, created_at, session_id FROM memories
ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map([], row_mapper)?;
for row in rows {
- results.push(row?);
+ let entry = row?;
+ if let Some(sid) = session_id {
+ if entry.session_id.as_deref() != Some(sid) {
+ continue;
+ }
+ }
+ results.push(entry);
}
}
@@ -632,7 +680,7 @@ mod tests {
#[tokio::test]
async fn sqlite_store_and_get() {
let (_tmp, mem) = temp_sqlite();
- mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
+ mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
.await
.unwrap();
@@ -647,10 +695,10 @@ mod tests {
#[tokio::test]
async fn sqlite_store_upsert() {
let (_tmp, mem) = temp_sqlite();
- mem.store("pref", "likes Rust", MemoryCategory::Core)
+ mem.store("pref", "likes Rust", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("pref", "loves Rust", MemoryCategory::Core)
+ mem.store("pref", "loves Rust", MemoryCategory::Core, None)
.await
.unwrap();
@@ -662,17 +710,22 @@ mod tests {
#[tokio::test]
async fn sqlite_recall_keyword() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
+ mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("b", "Python is interpreted", MemoryCategory::Core)
- .await
- .unwrap();
- mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
+ mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
.await
.unwrap();
+ mem.store(
+ "c",
+ "Rust has zero-cost abstractions",
+ MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
- let results = mem.recall("Rust", 10).await.unwrap();
+ let results = mem.recall("Rust", 10, None).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
@@ -682,14 +735,14 @@ mod tests {
#[tokio::test]
async fn sqlite_recall_multi_keyword() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "Rust is fast", MemoryCategory::Core)
+ mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
+ mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("fast safe", 10).await.unwrap();
+ let results = mem.recall("fast safe", 10, None).await.unwrap();
assert!(!results.is_empty());
// Entry with both keywords should score higher
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
@@ -698,17 +751,17 @@ mod tests {
#[tokio::test]
async fn sqlite_recall_no_match() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "Rust rocks", MemoryCategory::Core)
+ mem.store("a", "Rust rocks", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("javascript", 10).await.unwrap();
+ let results = mem.recall("javascript", 10, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn sqlite_forget() {
let (_tmp, mem) = temp_sqlite();
- mem.store("temp", "temporary data", MemoryCategory::Conversation)
+ mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 1);
@@ -728,29 +781,37 @@ mod tests {
#[tokio::test]
async fn sqlite_list_all() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "one", MemoryCategory::Core).await.unwrap();
- mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
- mem.store("c", "three", MemoryCategory::Conversation)
+ mem.store("a", "one", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+ mem.store("b", "two", MemoryCategory::Daily, None)
+ .await
+ .unwrap();
+ mem.store("c", "three", MemoryCategory::Conversation, None)
.await
.unwrap();
- let all = mem.list(None).await.unwrap();
+ let all = mem.list(None, None).await.unwrap();
assert_eq!(all.len(), 3);
}
#[tokio::test]
async fn sqlite_list_by_category() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
- mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
- mem.store("c", "daily1", MemoryCategory::Daily)
+ mem.store("a", "core1", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+ mem.store("b", "core2", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+ mem.store("c", "daily1", MemoryCategory::Daily, None)
.await
.unwrap();
- let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
+ let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
assert_eq!(core.len(), 2);
- let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
+ let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
assert_eq!(daily.len(), 1);
}
@@ -772,7 +833,7 @@ mod tests {
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
- mem.store("persist", "I survive restarts", MemoryCategory::Core)
+ mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
.await
.unwrap();
}
@@ -795,7 +856,7 @@ mod tests {
];
for (i, cat) in categories.iter().enumerate() {
- mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
+ mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
.await
.unwrap();
}
@@ -815,21 +876,28 @@ mod tests {
"a",
"Rust is a systems programming language",
MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
+ mem.store(
+ "b",
+ "Python is great for scripting",
+ MemoryCategory::Core,
+ None,
)
.await
.unwrap();
- mem.store("b", "Python is great for scripting", MemoryCategory::Core)
- .await
- .unwrap();
mem.store(
"c",
"Rust and Rust and Rust everywhere",
MemoryCategory::Core,
+ None,
)
.await
.unwrap();
- let results = mem.recall("Rust", 10).await.unwrap();
+ let results = mem.recall("Rust", 10, None).await.unwrap();
assert!(results.len() >= 2);
// All results should contain "Rust"
for r in &results {
@@ -844,17 +912,17 @@ mod tests {
#[tokio::test]
async fn fts5_multi_word_query() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
+ mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
+ mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
+ mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("quick dog", 10).await.unwrap();
+ let results = mem.recall("quick dog", 10, None).await.unwrap();
assert!(!results.is_empty());
// "The quick dog runs fast" matches both terms
assert!(results[0].content.contains("quick"));
@@ -863,16 +931,20 @@ mod tests {
#[tokio::test]
async fn recall_empty_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "data", MemoryCategory::Core).await.unwrap();
- let results = mem.recall("", 10).await.unwrap();
+ mem.store("a", "data", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+ let results = mem.recall("", 10, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_whitespace_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "data", MemoryCategory::Core).await.unwrap();
- let results = mem.recall(" ", 10).await.unwrap();
+ mem.store("a", "data", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+ let results = mem.recall(" ", 10, None).await.unwrap();
assert!(results.is_empty());
}
@@ -937,9 +1009,14 @@ mod tests {
#[tokio::test]
async fn fts5_syncs_on_insert() {
let (_tmp, mem) = temp_sqlite();
- mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
- .await
- .unwrap();
+ mem.store(
+ "test_key",
+ "unique_searchterm_xyz",
+ MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
let conn = mem.conn.lock();
let count: i64 = conn
@@ -955,9 +1032,14 @@ mod tests {
#[tokio::test]
async fn fts5_syncs_on_delete() {
let (_tmp, mem) = temp_sqlite();
- mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
- .await
- .unwrap();
+ mem.store(
+ "del_key",
+ "deletable_content_abc",
+ MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
mem.forget("del_key").await.unwrap();
let conn = mem.conn.lock();
@@ -974,10 +1056,15 @@ mod tests {
#[tokio::test]
async fn fts5_syncs_on_update() {
let (_tmp, mem) = temp_sqlite();
- mem.store("upd_key", "original_content_111", MemoryCategory::Core)
- .await
- .unwrap();
- mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
+ mem.store(
+ "upd_key",
+ "original_content_111",
+ MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
+ mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
.await
.unwrap();
@@ -1019,10 +1106,10 @@ mod tests {
#[tokio::test]
async fn reindex_rebuilds_fts() {
let (_tmp, mem) = temp_sqlite();
- mem.store("r1", "reindex test alpha", MemoryCategory::Core)
+ mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("r2", "reindex test beta", MemoryCategory::Core)
+ mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
.await
.unwrap();
@@ -1031,7 +1118,7 @@ mod tests {
assert_eq!(count, 0);
// FTS should still work after rebuild
- let results = mem.recall("reindex", 10).await.unwrap();
+ let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 2);
}
@@ -1045,12 +1132,13 @@ mod tests {
&format!("k{i}"),
&format!("common keyword item {i}"),
MemoryCategory::Core,
+ None,
)
.await
.unwrap();
}
- let results = mem.recall("common keyword", 5).await.unwrap();
+ let results = mem.recall("common keyword", 5, None).await.unwrap();
assert!(results.len() <= 5);
}
@@ -1059,11 +1147,11 @@ mod tests {
#[tokio::test]
async fn recall_results_have_scores() {
let (_tmp, mem) = temp_sqlite();
- mem.store("s1", "scored result test", MemoryCategory::Core)
+ mem.store("s1", "scored result test", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("scored", 10).await.unwrap();
+ let results = mem.recall("scored", 10, None).await.unwrap();
assert!(!results.is_empty());
for r in &results {
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
@@ -1075,11 +1163,11 @@ mod tests {
#[tokio::test]
async fn recall_with_quotes_in_query() {
let (_tmp, mem) = temp_sqlite();
- mem.store("q1", "He said hello world", MemoryCategory::Core)
+ mem.store("q1", "He said hello world", MemoryCategory::Core, None)
.await
.unwrap();
// Quotes in query should not crash FTS5
- let results = mem.recall("\"hello\"", 10).await.unwrap();
+ let results = mem.recall("\"hello\"", 10, None).await.unwrap();
// May or may not match depending on FTS5 escaping, but must not error
assert!(results.len() <= 10);
}
@@ -1087,31 +1175,34 @@ mod tests {
#[tokio::test]
async fn recall_with_asterisk_in_query() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a1", "wildcard test content", MemoryCategory::Core)
+ mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("wild*", 10).await.unwrap();
+ let results = mem.recall("wild*", 10, None).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_parentheses_in_query() {
let (_tmp, mem) = temp_sqlite();
- mem.store("p1", "function call test", MemoryCategory::Core)
+ mem.store("p1", "function call test", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("function()", 10).await.unwrap();
+ let results = mem.recall("function()", 10, None).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_sql_injection_attempt() {
let (_tmp, mem) = temp_sqlite();
- mem.store("safe", "normal content", MemoryCategory::Core)
+ mem.store("safe", "normal content", MemoryCategory::Core, None)
.await
.unwrap();
// Should not crash or leak data
- let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
+ let results = mem
+ .recall("'; DROP TABLE memories; --", 10, None)
+ .await
+ .unwrap();
assert!(results.len() <= 10);
// Table should still exist
assert_eq!(mem.count().await.unwrap(), 1);
@@ -1122,7 +1213,9 @@ mod tests {
#[tokio::test]
async fn store_empty_content() {
let (_tmp, mem) = temp_sqlite();
- mem.store("empty", "", MemoryCategory::Core).await.unwrap();
+ mem.store("empty", "", MemoryCategory::Core, None)
+ .await
+ .unwrap();
let entry = mem.get("empty").await.unwrap().unwrap();
assert_eq!(entry.content, "");
}
@@ -1130,7 +1223,7 @@ mod tests {
#[tokio::test]
async fn store_empty_key() {
let (_tmp, mem) = temp_sqlite();
- mem.store("", "content for empty key", MemoryCategory::Core)
+ mem.store("", "content for empty key", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("").await.unwrap().unwrap();
@@ -1141,7 +1234,7 @@ mod tests {
async fn store_very_long_content() {
let (_tmp, mem) = temp_sqlite();
let long_content = "x".repeat(100_000);
- mem.store("long", &long_content, MemoryCategory::Core)
+ mem.store("long", &long_content, MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("long").await.unwrap().unwrap();
@@ -1151,9 +1244,14 @@ mod tests {
#[tokio::test]
async fn store_unicode_and_emoji() {
let (_tmp, mem) = temp_sqlite();
- mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
- .await
- .unwrap();
+ mem.store(
+ "emoji_key_🦀",
+ "こんにちは 🚀 Ñoño",
+ MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
}
@@ -1162,7 +1260,7 @@ mod tests {
async fn store_content_with_newlines_and_tabs() {
let (_tmp, mem) = temp_sqlite();
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
- mem.store("whitespace", content, MemoryCategory::Core)
+ mem.store("whitespace", content, MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("whitespace").await.unwrap().unwrap();
@@ -1174,11 +1272,11 @@ mod tests {
#[tokio::test]
async fn recall_single_character_query() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "x marks the spot", MemoryCategory::Core)
+ mem.store("a", "x marks the spot", MemoryCategory::Core, None)
.await
.unwrap();
// Single char may not match FTS5 but LIKE fallback should work
- let results = mem.recall("x", 10).await.unwrap();
+ let results = mem.recall("x", 10, None).await.unwrap();
// Should not crash; may or may not find results
assert!(results.len() <= 10);
}
@@ -1186,23 +1284,23 @@ mod tests {
#[tokio::test]
async fn recall_limit_zero() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "some content", MemoryCategory::Core)
+ mem.store("a", "some content", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("some", 0).await.unwrap();
+ let results = mem.recall("some", 0, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_limit_one() {
let (_tmp, mem) = temp_sqlite();
- mem.store("a", "matching content alpha", MemoryCategory::Core)
+ mem.store("a", "matching content alpha", MemoryCategory::Core, None)
.await
.unwrap();
- mem.store("b", "matching content beta", MemoryCategory::Core)
+ mem.store("b", "matching content beta", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("matching content", 1).await.unwrap();
+ let results = mem.recall("matching content", 1, None).await.unwrap();
assert_eq!(results.len(), 1);
}
@@ -1213,21 +1311,22 @@ mod tests {
"rust_preferences",
"User likes systems programming",
MemoryCategory::Core,
+ None,
)
.await
.unwrap();
// "rust" appears in key but not content — LIKE fallback checks key too
- let results = mem.recall("rust", 10).await.unwrap();
+ let results = mem.recall("rust", 10, None).await.unwrap();
assert!(!results.is_empty(), "Should match by key");
}
#[tokio::test]
async fn recall_unicode_query() {
let (_tmp, mem) = temp_sqlite();
- mem.store("jp", "日本語のテスト", MemoryCategory::Core)
+ mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
.await
.unwrap();
- let results = mem.recall("日本語", 10).await.unwrap();
+ let results = mem.recall("日本語", 10, None).await.unwrap();
assert!(!results.is_empty());
}
@@ -1238,7 +1337,9 @@ mod tests {
let tmp = TempDir::new().unwrap();
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
- mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
+ mem.store("k1", "v1", MemoryCategory::Core, None)
+ .await
+ .unwrap();
}
// Open again — init_schema runs again on existing DB
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
@@ -1246,7 +1347,9 @@ mod tests {
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "v1");
// Store more data — should work fine
- mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
+ mem2.store("k2", "v2", MemoryCategory::Daily, None)
+ .await
+ .unwrap();
assert_eq!(mem2.count().await.unwrap(), 2);
}
@@ -1264,11 +1367,16 @@ mod tests {
#[tokio::test]
async fn forget_then_recall_no_ghost_results() {
let (_tmp, mem) = temp_sqlite();
- mem.store("ghost", "phantom memory content", MemoryCategory::Core)
- .await
- .unwrap();
+ mem.store(
+ "ghost",
+ "phantom memory content",
+ MemoryCategory::Core,
+ None,
+ )
+ .await
+ .unwrap();
mem.forget("ghost").await.unwrap();
- let results = mem.recall("phantom memory", 10).await.unwrap();
+ let results = mem.recall("phantom memory", 10, None).await.unwrap();
assert!(
results.is_empty(),
"Deleted memory should not appear in recall"
@@ -1278,11 +1386,11 @@ mod tests {
#[tokio::test]
async fn forget_and_re_store_same_key() {
let (_tmp, mem) = temp_sqlite();
- mem.store("cycle", "version 1", MemoryCategory::Core)
+ mem.store("cycle", "version 1", MemoryCategory::Core, None)
.await
.unwrap();
mem.forget("cycle").await.unwrap();
- mem.store("cycle", "version 2", MemoryCategory::Core)
+ mem.store("cycle", "version 2", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("cycle").await.unwrap().unwrap();
@@ -1302,14 +1410,14 @@ mod tests {
#[tokio::test]
async fn reindex_twice_is_safe() {
let (_tmp, mem) = temp_sqlite();
- mem.store("r1", "reindex data", MemoryCategory::Core)
+ mem.store("r1", "reindex data", MemoryCategory::Core, None)
.await
.unwrap();
mem.reindex().await.unwrap();
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
// Data should still be intact
- let results = mem.recall("reindex", 10).await.unwrap();
+ let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 1);
}
@@ -1363,18 +1471,28 @@ mod tests {
#[tokio::test]
async fn list_custom_category() {
let (_tmp, mem) = temp_sqlite();
- mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
- .await
- .unwrap();
- mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
- .await
- .unwrap();
- mem.store("c3", "other", MemoryCategory::Core)
+ mem.store(
+ "c1",
+ "custom1",
+ MemoryCategory::Custom("project".into()),
+ None,
+ )
+ .await
+ .unwrap();
+ mem.store(
+ "c2",
+ "custom2",
+ MemoryCategory::Custom("project".into()),
+ None,
+ )
+ .await
+ .unwrap();
+ mem.store("c3", "other", MemoryCategory::Core, None)
.await
.unwrap();
let project = mem
- .list(Some(&MemoryCategory::Custom("project".into())))
+ .list(Some(&MemoryCategory::Custom("project".into())), None)
.await
.unwrap();
assert_eq!(project.len(), 2);
@@ -1383,7 +1501,122 @@ mod tests {
#[tokio::test]
async fn list_empty_db() {
let (_tmp, mem) = temp_sqlite();
- let all = mem.list(None).await.unwrap();
+ let all = mem.list(None, None).await.unwrap();
assert!(all.is_empty());
}
+
+ // ── Session isolation ─────────────────────────────────────────
+
+ #[tokio::test]
+ async fn store_and_recall_with_session_id() {
+ let (_tmp, mem) = temp_sqlite();
+ mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
+ .await
+ .unwrap();
+ mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
+ .await
+ .unwrap();
+ mem.store("k3", "no session fact", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+
+ // Recall with session-a filter returns only session-a entry
+ let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
+ assert_eq!(results.len(), 1);
+ assert_eq!(results[0].key, "k1");
+ assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
+ }
+
+ #[tokio::test]
+ async fn recall_no_session_filter_returns_all() {
+ let (_tmp, mem) = temp_sqlite();
+ mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
+ .await
+ .unwrap();
+ mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
+ .await
+ .unwrap();
+ mem.store("k3", "gamma fact", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+
+ // Recall without session filter returns all matching entries
+ let results = mem.recall("fact", 10, None).await.unwrap();
+ assert_eq!(results.len(), 3);
+ }
+
+ #[tokio::test]
+ async fn cross_session_recall_isolation() {
+ let (_tmp, mem) = temp_sqlite();
+ mem.store(
+ "secret",
+ "session A secret data",
+ MemoryCategory::Core,
+ Some("sess-a"),
+ )
+ .await
+ .unwrap();
+
+ // Session B cannot see session A data
+ let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
+ assert!(results.is_empty());
+
+ // Session A can see its own data
+ let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
+ assert_eq!(results.len(), 1);
+ }
+
+ #[tokio::test]
+ async fn list_with_session_filter() {
+ let (_tmp, mem) = temp_sqlite();
+ mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
+ .await
+ .unwrap();
+ mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
+ .await
+ .unwrap();
+ mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
+ .await
+ .unwrap();
+ mem.store("k4", "none1", MemoryCategory::Core, None)
+ .await
+ .unwrap();
+
+ // List with session-a filter
+ let results = mem.list(None, Some("sess-a")).await.unwrap();
+ assert_eq!(results.len(), 2);
+ assert!(results
+ .iter()
+ .all(|e| e.session_id.as_deref() == Some("sess-a")));
+
+ // List with session-a + category filter
+ let results = mem
+ .list(Some(&MemoryCategory::Core), Some("sess-a"))
+ .await
+ .unwrap();
+ assert_eq!(results.len(), 1);
+ assert_eq!(results[0].key, "k1");
+ }
+
+ #[tokio::test]
+ async fn schema_migration_idempotent_on_reopen() {
+ let tmp = TempDir::new().unwrap();
+
+ // First open: creates schema + migration
+ {
+ let mem = SqliteMemory::new(tmp.path()).unwrap();
+ mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
+ .await
+ .unwrap();
+ }
+
+ // Second open: migration runs again but is idempotent
+ {
+ let mem = SqliteMemory::new(tmp.path()).unwrap();
+ let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
+ assert_eq!(results.len(), 1);
+ assert_eq!(results[0].key, "k1");
+ assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
+ }
+ }
}
diff --git a/src/memory/traits.rs b/src/memory/traits.rs
index 72e120e..bf8c021 100644
--- a/src/memory/traits.rs
+++ b/src/memory/traits.rs
@@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
/// Backend name
fn name(&self) -> &str;
- /// Store a memory entry
- async fn store(&self, key: &str, content: &str, category: MemoryCategory)
- -> anyhow::Result<()>;
+ /// Store a memory entry, optionally scoped to a session
+ async fn store(
+ &self,
+ key: &str,
+ content: &str,
+ category: MemoryCategory,
+ session_id: Option<&str>,
+ ) -> anyhow::Result<()>;
- /// Recall memories matching a query (keyword search)
- async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
+ /// Recall memories matching a query (keyword search), optionally scoped to a session
+ async fn recall(
+ &self,
+ query: &str,
+ limit: usize,
+ session_id: Option<&str>,
+ ) -> anyhow::Result<Vec<MemoryEntry>>;
/// Get a specific memory by key
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
- /// List all memory keys, optionally filtered by category
- async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result>;
+ /// List all memory keys, optionally filtered by category and/or session
+ async fn list(
+ &self,
+ category: Option<&MemoryCategory>,
+ session_id: Option<&str>,
+ ) -> anyhow::Result<Vec<MemoryEntry>>;
/// Remove a memory by key
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
diff --git a/src/migration.rs b/src/migration.rs
index f217030..8a83262 100644
--- a/src/migration.rs
+++ b/src/migration.rs
@@ -95,7 +95,9 @@ async fn migrate_openclaw_memory(
stats.renamed_conflicts += 1;
}
- memory.store(&key, &entry.content, entry.category).await?;
+ memory
+ .store(&key, &entry.content, entry.category, None)
+ .await?;
stats.imported += 1;
}
@@ -488,7 +490,7 @@ mod tests {
// Existing target memory
let target_mem = SqliteMemory::new(target.path()).unwrap();
target_mem
- .store("k", "new value", MemoryCategory::Core)
+ .store("k", "new value", MemoryCategory::Core, None)
.await
.unwrap();
@@ -510,7 +512,7 @@ mod tests {
.await
.unwrap();
- let all = target_mem.list(None).await.unwrap();
+ let all = target_mem.list(None, None).await.unwrap();
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
assert!(all
.iter()
diff --git a/src/observability/log.rs b/src/observability/log.rs
index 9e3d062..b932fe0 100644
--- a/src/observability/log.rs
+++ b/src/observability/log.rs
@@ -48,9 +48,10 @@ impl Observer for LogObserver {
ObserverEvent::AgentEnd {
duration,
tokens_used,
+ cost_usd,
} => {
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
- info!(duration_ms = ms, tokens = ?tokens_used, "agent.end");
+ info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end");
}
ObserverEvent::ToolCallStart { tool } => {
info!(tool = %tool, "tool.start");
@@ -133,10 +134,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(500),
tokens_used: Some(100),
+ cost_usd: Some(0.0015),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
+ cost_usd: None,
});
obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(),
diff --git a/src/observability/noop.rs b/src/observability/noop.rs
index 1189490..004af21 100644
--- a/src/observability/noop.rs
+++ b/src/observability/noop.rs
@@ -48,10 +48,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(100),
tokens_used: Some(42),
+ cost_usd: Some(0.001),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
+ cost_usd: None,
});
obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(),
diff --git a/src/observability/otel.rs b/src/observability/otel.rs
index 5e0c37e..ae4932d 100644
--- a/src/observability/otel.rs
+++ b/src/observability/otel.rs
@@ -227,6 +227,7 @@ impl Observer for OtelObserver {
ObserverEvent::AgentEnd {
duration,
tokens_used,
+ cost_usd,
} => {
let secs = duration.as_secs_f64();
let start_time = SystemTime::now()
@@ -243,6 +244,9 @@ impl Observer for OtelObserver {
if let Some(t) = tokens_used {
span.set_attribute(KeyValue::new("tokens_used", *t as i64));
}
+ if let Some(c) = cost_usd {
+ span.set_attribute(KeyValue::new("cost_usd", *c));
+ }
span.end();
self.agent_duration.record(secs, &[]);
@@ -394,10 +398,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(500),
tokens_used: Some(100),
+ cost_usd: Some(0.0015),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
+ cost_usd: None,
});
obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(),
diff --git a/src/observability/traits.rs b/src/observability/traits.rs
index ca62caf..d978304 100644
--- a/src/observability/traits.rs
+++ b/src/observability/traits.rs
@@ -27,6 +27,7 @@ pub enum ObserverEvent {
AgentEnd {
duration: Duration,
tokens_used: Option<u32>,
+ cost_usd: Option<f64>,
},
/// A tool call is about to be executed.
ToolCallStart {
diff --git a/src/onboard/wizard.rs b/src/onboard/wizard.rs
index 20c3baa..0422e45 100644
--- a/src/onboard/wizard.rs
+++ b/src/onboard/wizard.rs
@@ -106,6 +106,7 @@ pub fn run_wizard() -> Result<Config> {
} else {
Some(api_key)
},
+ api_url: None,
default_provider: Some(provider),
default_model: Some(model),
default_temperature: 0.7,
@@ -284,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
#[allow(clippy::too_many_lines)]
pub fn run_quick_setup(
- api_key: Option<&str>,
+ credential_override: Option<&str>,
provider: Option<&str>,
memory_backend: Option<&str>,
) -> Result<Config> {
@@ -318,7 +319,8 @@ pub fn run_quick_setup(
let config = Config {
workspace_dir: workspace_dir.clone(),
config_path: config_path.clone(),
- api_key: api_key.map(String::from),
+ api_key: credential_override.map(String::from),
+ api_url: None,
default_provider: Some(provider_name.clone()),
default_model: Some(model.clone()),
default_temperature: 0.7,
@@ -377,7 +379,7 @@ pub fn run_quick_setup(
println!(
" {} API Key: {}",
style("✓").green().bold(),
- if api_key.is_some() {
+ if credential_override.is_some() {
style("set").green()
} else {
style("not set (use --api-key or edit config.toml)").yellow()
@@ -426,7 +428,7 @@ pub fn run_quick_setup(
);
println!();
println!(" {}", style("Next steps:").white().bold());
- if api_key.is_none() {
+ if credential_override.is_none() {
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
@@ -2269,14 +2271,11 @@ fn setup_memory() -> Result<MemoryConfig> {
let backend = backend_key_from_choice(choice);
let profile = memory_backend_profile(backend);
- let auto_save = if !profile.auto_save_default {
- false
- } else {
- Confirm::new()
+ let auto_save = profile.auto_save_default
+ && Confirm::new()
.with_prompt(" Auto-save conversations to memory?")
.default(true)
- .interact()?
- };
+ .interact()?;
println!(
" {} Memory: {} (auto-save: {})",
@@ -2587,6 +2586,7 @@ fn setup_channels() -> Result {
guild_id: if guild.is_empty() { None } else { Some(guild) },
allowed_users,
listen_to_bots: false,
+ mention_only: false,
});
}
2 => {
@@ -2799,22 +2799,14 @@ fn setup_channels() -> Result {
.header("Authorization", format!("Bearer {access_token_clone}"))
.send()?;
let ok = resp.status().is_success();
- let data: serde_json::Value = resp.json().unwrap_or_default();
- let user_id = data
- .get("user_id")
- .and_then(serde_json::Value::as_str)
- .unwrap_or("unknown")
- .to_string();
- Ok::<_, reqwest::Error>((ok, user_id))
+ Ok::<_, reqwest::Error>(ok)
})
.join();
match thread_result {
- Ok(Ok((true, user_id))) => {
- println!(
- "\r {} Connected as {user_id} ",
- style("✅").green().bold()
- );
- }
+ Ok(Ok(true)) => println!(
+ "\r {} Connection verified ",
+ style("✅").green().bold()
+ ),
_ => {
println!(
"\r {} Connection failed — check homeserver URL and token",
@@ -3779,15 +3771,7 @@ fn print_summary(config: &Config) {
);
// Secrets
- println!(
- " {} Secrets: {}",
- style("🔒").cyan(),
- if config.secrets.encrypt {
- style("encrypted").green().to_string()
- } else {
- style("plaintext").yellow().to_string()
- }
- );
+ println!(" {} Secrets: configured", style("🔒").cyan());
// Gateway
println!(
diff --git a/src/peripherals/arduino_flash.rs b/src/peripherals/arduino_flash.rs
index 8aaf287..7bc53f5 100644
--- a/src/peripherals/arduino_flash.rs
+++ b/src/peripherals/arduino_flash.rs
@@ -38,6 +38,10 @@ pub fn ensure_arduino_cli() -> Result<()> {
anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/");
}
println!("arduino-cli installed.");
+ if !arduino_cli_available() {
+ anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
+ }
+ return Ok(());
}
#[cfg(target_os = "linux")]
@@ -54,11 +58,6 @@ pub fn ensure_arduino_cli() -> Result<()> {
println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/");
anyhow::bail!("arduino-cli not installed.");
}
-
- if !arduino_cli_available() {
- anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
- }
- Ok(())
}
/// Ensure arduino:avr core is installed.
diff --git a/src/peripherals/serial.rs b/src/peripherals/serial.rs
index 05d0bae..2bcec56 100644
--- a/src/peripherals/serial.rs
+++ b/src/peripherals/serial.rs
@@ -112,6 +112,7 @@ pub struct SerialPeripheral {
impl SerialPeripheral {
/// Create and connect to a serial peripheral.
+ #[allow(clippy::unused_async)]
pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> {
let path = config
.path
diff --git a/src/providers/anthropic.rs b/src/providers/anthropic.rs
index 4216853..1f45c7e 100644
--- a/src/providers/anthropic.rs
+++ b/src/providers/anthropic.rs
@@ -106,17 +106,17 @@ struct NativeContentIn {
}
impl AnthropicProvider {
- pub fn new(api_key: Option<&str>) -> Self {
- Self::with_base_url(api_key, None)
+ pub fn new(credential: Option<&str>) -> Self {
+ Self::with_base_url(credential, None)
}
- pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
+ pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
let base_url = base_url
.map(|u| u.trim_end_matches('/'))
.unwrap_or("https://api.anthropic.com")
.to_string();
Self {
- credential: api_key
+ credential: credential
.map(str::trim)
.filter(|k| !k.is_empty())
.map(ToString::to_string),
@@ -410,9 +410,9 @@ mod tests {
#[test]
fn creates_with_key() {
- let p = AnthropicProvider::new(Some("sk-ant-test123"));
+ let p = AnthropicProvider::new(Some("anthropic-test-credential"));
assert!(p.credential.is_some());
- assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
+ assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
assert_eq!(p.base_url, "https://api.anthropic.com");
}
@@ -431,17 +431,19 @@ mod tests {
#[test]
fn creates_with_whitespace_key() {
- let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
+ let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
assert!(p.credential.is_some());
- assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
+ assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
}
#[test]
fn creates_with_custom_base_url() {
- let p =
- AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
+ let p = AnthropicProvider::with_base_url(
+ Some("anthropic-credential"),
+ Some("https://api.example.com"),
+ );
assert_eq!(p.base_url, "https://api.example.com");
- assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
+ assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
}
#[test]
diff --git a/src/providers/compatible.rs b/src/providers/compatible.rs
index ee1c588..d17d309 100644
--- a/src/providers/compatible.rs
+++ b/src/providers/compatible.rs
@@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
pub struct OpenAiCompatibleProvider {
pub(crate) name: String,
pub(crate) base_url: String,
- pub(crate) api_key: Option<String>,
+ pub(crate) credential: Option<String>,
pub(crate) auth_header: AuthStyle,
/// When false, do not fall back to /v1/responses on chat completions 404.
/// GLM/Zhipu does not support the responses API.
@@ -37,11 +37,16 @@ pub enum AuthStyle {
}
impl OpenAiCompatibleProvider {
- pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
+ pub fn new(
+ name: &str,
+ base_url: &str,
+ credential: Option<&str>,
+ auth_style: AuthStyle,
+ ) -> Self {
Self {
name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
- api_key: api_key.map(ToString::to_string),
+ credential: credential.map(ToString::to_string),
auth_header: auth_style,
supports_responses_fallback: true,
client: Client::builder()
@@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
pub fn new_no_responses_fallback(
name: &str,
base_url: &str,
- api_key: Option<&str>,
+ credential: Option<&str>,
auth_style: AuthStyle,
) -> Self {
Self {
name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
- api_key: api_key.map(ToString::to_string),
+ credential: credential.map(ToString::to_string),
auth_header: auth_style,
supports_responses_fallback: false,
client: Client::builder()
@@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
fn apply_auth_header(
&self,
req: reqwest::RequestBuilder,
- api_key: &str,
+ credential: &str,
) -> reqwest::RequestBuilder {
match &self.auth_header {
- AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
- AuthStyle::XApiKey => req.header("x-api-key", api_key),
- AuthStyle::Custom(header) => req.header(header, api_key),
+ AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
+ AuthStyle::XApiKey => req.header("x-api-key", credential),
+ AuthStyle::Custom(header) => req.header(header, credential),
}
}
async fn chat_via_responses(
&self,
- api_key: &str,
+ credential: &str,
system_prompt: Option<&str>,
message: &str,
model: &str,
@@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
let url = self.responses_url();
let response = self
- .apply_auth_header(self.client.post(&url).json(&request), api_key)
+ .apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref().ok_or_else(|| {
+ let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
@@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url();
let response = self
- .apply_auth_header(self.client.post(&url).json(&request), api_key)
+ .apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
return self
- .chat_via_responses(api_key, system_prompt, message, model)
+ .chat_via_responses(credential, system_prompt, message, model)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
@@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref().ok_or_else(|| {
+ let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
@@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url();
let response = self
- .apply_auth_header(self.client.post(&url).json(&request), api_key)
+ .apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
if let Some(user_msg) = last_user {
return self
.chat_via_responses(
- api_key,
+ credential,
system.map(|m| m.content.as_str()),
&user_msg.content,
model,
@@ -791,16 +796,20 @@ mod tests {
#[test]
fn creates_with_key() {
- let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
+ let p = make_provider(
+ "venice",
+ "https://api.venice.ai",
+ Some("venice-test-credential"),
+ );
assert_eq!(p.name, "venice");
assert_eq!(p.base_url, "https://api.venice.ai");
- assert_eq!(p.api_key.as_deref(), Some("vn-key"));
+ assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
}
#[test]
fn creates_without_key() {
let p = make_provider("test", "https://example.com", None);
- assert!(p.api_key.is_none());
+ assert!(p.credential.is_none());
}
#[test]
@@ -894,6 +903,7 @@ mod tests {
make_provider("Groq", "https://api.groq.com/openai", None),
make_provider("Mistral", "https://api.mistral.ai", None),
make_provider("xAI", "https://api.x.ai", None),
+ make_provider("Astrai", "https://as-trai.com/v1", None),
];
for p in providers {
diff --git a/src/providers/copilot.rs b/src/providers/copilot.rs
new file mode 100644
index 0000000..ab8eb3b
--- /dev/null
+++ b/src/providers/copilot.rs
@@ -0,0 +1,705 @@
+//! GitHub Copilot provider with OAuth device-flow authentication.
+//!
+//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
+//! then exchanges the OAuth token for short-lived Copilot API keys.
+//! Tokens are cached to disk and auto-refreshed.
+//!
+//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
+//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
+//! and other third-party Copilot integrations. The Copilot token endpoint is
+//! private; there is no public OAuth scope or app registration for it.
+//! GitHub could change or revoke this at any time, which would break all
+//! third-party integrations simultaneously.
+
+use crate::providers::traits::{
+ ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
+ Provider, ToolCall as ProviderToolCall,
+};
+use crate::tools::ToolSpec;
+use async_trait::async_trait;
+use reqwest::Client;
+use serde::{Deserialize, Serialize};
+use std::path::{Path, PathBuf};
+use std::sync::Arc;
+use std::time::Duration;
+use tokio::sync::Mutex;
+use tracing::warn;
+
+/// GitHub OAuth client ID for Copilot (VS Code extension).
+const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
+const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
+const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
+const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
+const DEFAULT_API: &str = "https://api.githubcopilot.com";
+
+// ── Token types ──────────────────────────────────────────────────
+
+#[derive(Debug, Deserialize)]
+struct DeviceCodeResponse {
+ device_code: String,
+ user_code: String,
+ verification_uri: String,
+ #[serde(default = "default_interval")]
+ interval: u64,
+ #[serde(default = "default_expires_in")]
+ expires_in: u64,
+}
+
+fn default_interval() -> u64 {
+ 5
+}
+
+fn default_expires_in() -> u64 {
+ 900
+}
+
+#[derive(Debug, Deserialize)]
+struct AccessTokenResponse {
+ access_token: Option<String>,
+ error: Option<String>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ApiKeyInfo {
+ token: String,
+ expires_at: i64,
+ #[serde(default)]
+ endpoints: Option<ApiEndpoints>,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct ApiEndpoints {
+ api: Option<String>,
+}
+
+struct CachedApiKey {
+ token: String,
+ api_endpoint: String,
+ expires_at: i64,
+}
+
+// ── Chat completions types ───────────────────────────────────────
+
+#[derive(Debug, Serialize)]
+struct ApiChatRequest {
+ model: String,
+ messages: Vec<ApiMessage>,
+ temperature: f64,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ tools: Option<Vec<NativeToolSpec>>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ tool_choice: Option<String>,
+}
+
+#[derive(Debug, Serialize)]
+struct ApiMessage {
+ role: String,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ content: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ tool_call_id: Option<String>,
+ #[serde(skip_serializing_if = "Option::is_none")]
+ tool_calls: Option<Vec<NativeToolCall>>,
+}
+
+#[derive(Debug, Serialize)]
+struct NativeToolSpec {
+ #[serde(rename = "type")]
+ kind: String,
+ function: NativeToolFunctionSpec,
+}
+
+#[derive(Debug, Serialize)]
+struct NativeToolFunctionSpec {
+ name: String,
+ description: String,
+ parameters: serde_json::Value,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct NativeToolCall {
+ #[serde(skip_serializing_if = "Option::is_none")]
+ id: Option<String>,
+ #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
+ kind: Option<String>,
+ function: NativeFunctionCall,
+}
+
+#[derive(Debug, Serialize, Deserialize)]
+struct NativeFunctionCall {
+ name: String,
+ arguments: String,
+}
+
+#[derive(Debug, Deserialize)]
+struct ApiChatResponse {
+ choices: Vec<Choice>,
+}
+
+#[derive(Debug, Deserialize)]
+struct Choice {
+ message: ResponseMessage,
+}
+
+#[derive(Debug, Deserialize)]
+struct ResponseMessage {
+ #[serde(default)]
+ content: Option<String>,
+ #[serde(default)]
+ tool_calls: Option<Vec<NativeToolCall>>,
+}
+
+// ── Provider ─────────────────────────────────────────────────────
+
+/// GitHub Copilot provider with automatic OAuth and token refresh.
+///
+/// On first use, prompts the user to visit github.com/login/device.
+/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
+/// automatically.
+pub struct CopilotProvider {
+ github_token: Option<String>,
+ /// Mutex ensures only one caller refreshes tokens at a time,
+ /// preventing duplicate device flow prompts or redundant API calls.
+ refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
+ http: Client,
+ token_dir: PathBuf,
+}
+
+impl CopilotProvider {
+ pub fn new(github_token: Option<&str>) -> Self {
+ let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
+ .map(|dir| dir.config_dir().join("copilot"))
+ .unwrap_or_else(|| {
+ // Fall back to a user-specific temp directory to avoid
+ // shared-directory symlink attacks.
+ let user = std::env::var("USER")
+ .or_else(|_| std::env::var("USERNAME"))
+ .unwrap_or_else(|_| "unknown".to_string());
+ std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
+ });
+
+ if let Err(err) = std::fs::create_dir_all(&token_dir) {
+ warn!(
+ "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
+ token_dir
+ );
+ } else {
+ #[cfg(unix)]
+ {
+ use std::os::unix::fs::PermissionsExt;
+
+ if let Err(err) =
+ std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
+ {
+ warn!(
+ "Failed to set Copilot token directory permissions on {:?}: {err}",
+ token_dir
+ );
+ }
+ }
+ }
+
+ Self {
+ github_token: github_token
+ .filter(|token| !token.is_empty())
+ .map(String::from),
+ refresh_lock: Arc::new(Mutex::new(None)),
+ http: Client::builder()
+ .timeout(Duration::from_secs(120))
+ .connect_timeout(Duration::from_secs(10))
+ .build()
+ .unwrap_or_else(|_| Client::new()),
+ token_dir,
+ }
+ }
+
+ /// Required headers for Copilot API requests (editor identification).
+ const COPILOT_HEADERS: [(&str, &str); 4] = [
+ ("Editor-Version", "vscode/1.85.1"),
+ ("Editor-Plugin-Version", "copilot/1.155.0"),
+ ("User-Agent", "GithubCopilot/1.155.0"),
+ ("Accept", "application/json"),
+ ];
+
+ fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
+ tools.map(|items| {
+ items
+ .iter()
+ .map(|tool| NativeToolSpec {
+ kind: "function".to_string(),
+ function: NativeToolFunctionSpec {
+ name: tool.name.clone(),
+ description: tool.description.clone(),
+ parameters: tool.parameters.clone(),
+ },
+ })
+ .collect()
+ })
+ }
+
+ fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
+ messages
+ .iter()
+ .map(|message| {
+ if message.role == "assistant" {
+ if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
+ if let Some(tool_calls_value) = value.get("tool_calls") {
+ if let Ok(parsed_calls) =
+ serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
+ {
+ let tool_calls = parsed_calls
+ .into_iter()
+ .map(|tool_call| NativeToolCall {
+ id: Some(tool_call.id),
+ kind: Some("function".to_string()),
+ function: NativeFunctionCall {
+ name: tool_call.name,
+ arguments: tool_call.arguments,
+ },
+ })
+ .collect::<Vec<_>>();
+
+ let content = value
+ .get("content")
+ .and_then(serde_json::Value::as_str)
+ .map(ToString::to_string);
+
+ return ApiMessage {
+ role: "assistant".to_string(),
+ content,
+ tool_call_id: None,
+ tool_calls: Some(tool_calls),
+ };
+ }
+ }
+ }
+ }
+
+ if message.role == "tool" {
+ if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
+ let tool_call_id = value
+ .get("tool_call_id")
+ .and_then(serde_json::Value::as_str)
+ .map(ToString::to_string);
+ let content = value
+ .get("content")
+ .and_then(serde_json::Value::as_str)
+ .map(ToString::to_string);
+
+ return ApiMessage {
+ role: "tool".to_string(),
+ content,
+ tool_call_id,
+ tool_calls: None,
+ };
+ }
+ }
+
+ ApiMessage {
+ role: message.role.clone(),
+ content: Some(message.content.clone()),
+ tool_call_id: None,
+ tool_calls: None,
+ }
+ })
+ .collect()
+ }
+
+ /// Send a chat completions request with required Copilot headers.
+ async fn send_chat_request(
+ &self,
+ messages: Vec<ApiMessage>,
+ tools: Option<&[ToolSpec]>,
+ model: &str,
+ temperature: f64,
+ ) -> anyhow::Result<ProviderChatResponse> {
+ let (token, endpoint) = self.get_api_key().await?;
+ let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
+
+ let native_tools = Self::convert_tools(tools);
+ let request = ApiChatRequest {
+ model: model.to_string(),
+ messages,
+ temperature,
+ tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
+ tools: native_tools,
+ };
+
+ let mut req = self
+ .http
+ .post(&url)
+ .header("Authorization", format!("Bearer {token}"))
+ .json(&request);
+
+ for (header, value) in &Self::COPILOT_HEADERS {
+ req = req.header(*header, *value);
+ }
+
+ let response = req.send().await?;
+
+ if !response.status().is_success() {
+ return Err(super::api_error("GitHub Copilot", response).await);
+ }
+
+ let api_response: ApiChatResponse = response.json().await?;
+ let choice = api_response
+ .choices
+ .into_iter()
+ .next()
+ .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
+
+ let tool_calls = choice
+ .message
+ .tool_calls
+ .unwrap_or_default()
+ .into_iter()
+ .map(|tool_call| ProviderToolCall {
+ id: tool_call
+ .id
+ .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
+ name: tool_call.function.name,
+ arguments: tool_call.function.arguments,
+ })
+ .collect();
+
+ Ok(ProviderChatResponse {
+ text: choice.message.content,
+ tool_calls,
+ })
+ }
+
+ /// Get a valid Copilot API key, refreshing or re-authenticating as needed.
+ /// Uses a Mutex to ensure only one caller refreshes at a time.
+ async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
+ let mut cached = self.refresh_lock.lock().await;
+
+ if let Some(cached_key) = cached.as_ref() {
+ if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
+ return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
+ }
+ }
+
+ if let Some(info) = self.load_api_key_from_disk().await {
+ if chrono::Utc::now().timestamp() + 120 < info.expires_at {
+ let endpoint = info
+ .endpoints
+ .as_ref()
+ .and_then(|e| e.api.clone())
+ .unwrap_or_else(|| DEFAULT_API.to_string());
+ let token = info.token;
+
+ *cached = Some(CachedApiKey {
+ token: token.clone(),
+ api_endpoint: endpoint.clone(),
+ expires_at: info.expires_at,
+ });
+ return Ok((token, endpoint));
+ }
+ }
+
+ let access_token = self.get_github_access_token().await?;
+ let api_key_info = self.exchange_for_api_key(&access_token).await?;
+ self.save_api_key_to_disk(&api_key_info).await;
+
+ let endpoint = api_key_info
+ .endpoints
+ .as_ref()
+ .and_then(|e| e.api.clone())
+ .unwrap_or_else(|| DEFAULT_API.to_string());
+
+ *cached = Some(CachedApiKey {
+ token: api_key_info.token.clone(),
+ api_endpoint: endpoint.clone(),
+ expires_at: api_key_info.expires_at,
+ });
+
+ Ok((api_key_info.token, endpoint))
+ }
+
+ /// Get a GitHub access token from config, cache, or device flow.
+ async fn get_github_access_token(&self) -> anyhow::Result<String> {
+ if let Some(token) = &self.github_token {
+ return Ok(token.clone());
+ }
+
+ let access_token_path = self.token_dir.join("access-token");
+ if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
+ let token = cached.trim();
+ if !token.is_empty() {
+ return Ok(token.to_string());
+ }
+ }
+
+ let token = self.device_code_login().await?;
+ write_file_secure(&access_token_path, &token).await;
+ Ok(token)
+ }
+
+ /// Run GitHub OAuth device code flow.
+ async fn device_code_login(&self) -> anyhow::Result<String> {
+ let response: DeviceCodeResponse = self
+ .http
+ .post(GITHUB_DEVICE_CODE_URL)
+ .header("Accept", "application/json")
+ .json(&serde_json::json!({
+ "client_id": GITHUB_CLIENT_ID,
+ "scope": "read:user"
+ }))
+ .send()
+ .await?
+ .error_for_status()?
+ .json()
+ .await?;
+
+ let mut poll_interval = Duration::from_secs(response.interval.max(5));
+ let expires_in = response.expires_in.max(1);
+ let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
+
+ eprintln!(
+ "\nGitHub Copilot authentication is required.\n\
+ Visit: {}\n\
+ Code: {}\n\
+ Waiting for authorization...\n",
+ response.verification_uri, response.user_code
+ );
+
+ while tokio::time::Instant::now() < expires_at {
+ tokio::time::sleep(poll_interval).await;
+
+ let token_response: AccessTokenResponse = self
+ .http
+ .post(GITHUB_ACCESS_TOKEN_URL)
+ .header("Accept", "application/json")
+ .json(&serde_json::json!({
+ "client_id": GITHUB_CLIENT_ID,
+ "device_code": response.device_code,
+ "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
+ }))
+ .send()
+ .await?
+ .json()
+ .await?;
+
+ if let Some(token) = token_response.access_token {
+ eprintln!("Authentication succeeded.\n");
+ return Ok(token);
+ }
+
+ match token_response.error.as_deref() {
+ Some("slow_down") => {
+ poll_interval += Duration::from_secs(5);
+ }
+ Some("authorization_pending") | None => {}
+ Some("expired_token") => {
+ anyhow::bail!("GitHub device authorization expired")
+ }
+ Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
+ }
+ }
+
+ anyhow::bail!("Timed out waiting for GitHub authorization")
+ }
+
+ /// Exchange a GitHub access token for a Copilot API key.
+ async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
+ let mut request = self.http.get(GITHUB_API_KEY_URL);
+ for (header, value) in &Self::COPILOT_HEADERS {
+ request = request.header(*header, *value);
+ }
+ request = request.header("Authorization", format!("token {access_token}"));
+
+ let response = request.send().await?;
+
+ if !response.status().is_success() {
+ let status = response.status();
+ let body = response.text().await.unwrap_or_default();
+ let sanitized = super::sanitize_api_error(&body);
+
+ if status.as_u16() == 401 || status.as_u16() == 403 {
+ let access_token_path = self.token_dir.join("access-token");
+ tokio::fs::remove_file(&access_token_path).await.ok();
+ }
+
+ anyhow::bail!(
+ "Failed to get Copilot API key ({status}): {sanitized}. \
+ Ensure your GitHub account has an active Copilot subscription."
+ );
+ }
+
+ let info: ApiKeyInfo = response.json().await?;
+ Ok(info)
+ }
+
+ async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
+ let path = self.token_dir.join("api-key.json");
+ let data = tokio::fs::read_to_string(&path).await.ok()?;
+ serde_json::from_str(&data).ok()
+ }
+
+ async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
+ let path = self.token_dir.join("api-key.json");
+ if let Ok(json) = serde_json::to_string_pretty(info) {
+ write_file_secure(&path, &json).await;
+ }
+ }
+}
+
+/// Write a file with 0600 permissions (owner read/write only).
+/// Uses `spawn_blocking` to avoid blocking the async runtime.
+async fn write_file_secure(path: &Path, content: &str) {
+ let path = path.to_path_buf();
+ let content = content.to_string();
+
+ let result = tokio::task::spawn_blocking(move || {
+ #[cfg(unix)]
+ {
+ use std::io::Write;
+ use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
+
+ let mut file = std::fs::OpenOptions::new()
+ .write(true)
+ .create(true)
+ .truncate(true)
+ .mode(0o600)
+ .open(&path)?;
+ file.write_all(content.as_bytes())?;
+
+ std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
+ Ok::<(), std::io::Error>(())
+ }
+ #[cfg(not(unix))]
+ {
+ std::fs::write(&path, &content)?;
+ Ok::<(), std::io::Error>(())
+ }
+ })
+ .await;
+
+ match result {
+ Ok(Ok(())) => {}
+ Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
+ Err(err) => warn!("Failed to spawn blocking write: {err}"),
+ }
+}
+
+#[async_trait]
+impl Provider for CopilotProvider {
+ async fn chat_with_system(
+ &self,
+ system_prompt: Option<&str>,
+ message: &str,
+ model: &str,
+ temperature: f64,
+ ) -> anyhow::Result<String> {
+ let mut messages = Vec::new();
+ if let Some(system) = system_prompt {
+ messages.push(ApiMessage {
+ role: "system".to_string(),
+ content: Some(system.to_string()),
+ tool_call_id: None,
+ tool_calls: None,
+ });
+ }
+ messages.push(ApiMessage {
+ role: "user".to_string(),
+ content: Some(message.to_string()),
+ tool_call_id: None,
+ tool_calls: None,
+ });
+
+ let response = self
+ .send_chat_request(messages, None, model, temperature)
+ .await?;
+ Ok(response.text.unwrap_or_default())
+ }
+
+ async fn chat_with_history(
+ &self,
+ messages: &[ChatMessage],
+ model: &str,
+ temperature: f64,
+ ) -> anyhow::Result<String> {
+ let response = self
+ .send_chat_request(Self::convert_messages(messages), None, model, temperature)
+ .await?;
+ Ok(response.text.unwrap_or_default())
+ }
+
+ async fn chat(
+ &self,
+ request: ProviderChatRequest<'_>,
+ model: &str,
+ temperature: f64,
+ ) -> anyhow::Result<ProviderChatResponse> {
+ self.send_chat_request(
+ Self::convert_messages(request.messages),
+ request.tools,
+ model,
+ temperature,
+ )
+ .await
+ }
+
+ fn supports_native_tools(&self) -> bool {
+ true
+ }
+
+ async fn warmup(&self) -> anyhow::Result<()> {
+ let _ = self.get_api_key().await?;
+ Ok(())
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn new_without_token() {
+ let provider = CopilotProvider::new(None);
+ assert!(provider.github_token.is_none());
+ }
+
+ #[test]
+ fn new_with_token() {
+ let provider = CopilotProvider::new(Some("ghp_test"));
+ assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
+ }
+
+ #[test]
+ fn empty_token_treated_as_none() {
+ let provider = CopilotProvider::new(Some(""));
+ assert!(provider.github_token.is_none());
+ }
+
+ #[tokio::test]
+ async fn cache_starts_empty() {
+ let provider = CopilotProvider::new(None);
+ let cached = provider.refresh_lock.lock().await;
+ assert!(cached.is_none());
+ }
+
+ #[test]
+ fn copilot_headers_include_required_fields() {
+ let headers = CopilotProvider::COPILOT_HEADERS;
+ assert!(headers
+ .iter()
+ .any(|(header, _)| *header == "Editor-Version"));
+ assert!(headers
+ .iter()
+ .any(|(header, _)| *header == "Editor-Plugin-Version"));
+ assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
+ }
+
+ #[test]
+ fn default_interval_and_expiry() {
+ assert_eq!(default_interval(), 5);
+ assert_eq!(default_expires_in(), 900);
+ }
+
+ #[test]
+ fn supports_native_tools() {
+ let provider = CopilotProvider::new(None);
+ assert!(provider.supports_native_tools());
+ }
+}
diff --git a/src/providers/mod.rs b/src/providers/mod.rs
index 86517d6..e18e789 100644
--- a/src/providers/mod.rs
+++ b/src/providers/mod.rs
@@ -1,5 +1,6 @@
pub mod anthropic;
pub mod compatible;
+pub mod copilot;
pub mod gemini;
pub mod ollama;
pub mod openai;
@@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
/// Scrub known secret-like token prefixes from provider error strings.
///
-/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`.
+/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
+/// `ghu_`, and `github_pat_`.
pub fn scrub_secret_patterns(input: &str) -> String {
- const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
+ const PREFIXES: [&str; 7] = [
+ "sk-",
+ "xoxb-",
+ "xoxp-",
+ "ghp_",
+ "gho_",
+ "ghu_",
+ "github_pat_",
+ ];
let mut scrubbed = input.to_string();
@@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::Error {
///
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
-fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
- if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
- return Some(key.to_string());
+fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
+ if let Some(raw_override) = credential_override {
+ let trimmed_override = raw_override.trim();
+ if !trimmed_override.is_empty() {
+ return Some(trimmed_override.to_owned());
+ }
}
let provider_env_candidates: Vec<&str> = match name {
@@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
+ "astrai" => vec!["ASTRAI_API_KEY"],
_ => vec![],
};
@@ -182,19 +196,28 @@ fn parse_custom_provider_url(
}
}
-/// Factory: create the right provider from config
-#[allow(clippy::too_many_lines)]
+/// Factory: create the right provider from config (without custom URL)
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
- let resolved_key = resolve_api_key(name, api_key);
- let key = resolved_key.as_deref();
+ create_provider_with_url(name, api_key, None)
+}
+
+/// Factory: create the right provider from config with optional custom base URL
+#[allow(clippy::too_many_lines)]
+pub fn create_provider_with_url(
+ name: &str,
+ api_key: Option<&str>,
+ api_url: Option<&str>,
+) -> anyhow::Result<Box<dyn Provider>> {
+ let resolved_credential = resolve_provider_credential(name, api_key);
+ #[allow(clippy::option_as_ref_deref)]
+ let key = resolved_credential.as_ref().map(String::as_str);
match name {
// ── Primary providers (custom implementations) ───────
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
- // Ollama is a local service that doesn't use API keys.
- // The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
- "ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
+ // Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
+ "ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
"gemini" | "google" | "google-gemini" => {
Ok(Box::new(gemini::GeminiProvider::new(key)))
}
@@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
        "mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
- "Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
+ "Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
))),
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
@@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
        "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
))),
- "copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
- "GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
- ))),
- "nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new(
- "NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer,
+ "copilot" | "github-copilot" => {
+ Ok(Box::new(copilot::CopilotProvider::new(api_key)))
+ },
+ "lmstudio" | "lm-studio" => {
+ let lm_studio_key = api_key
+ .map(str::trim)
+ .filter(|value| !value.is_empty())
+ .unwrap_or("lm-studio");
+ Ok(Box::new(OpenAiCompatibleProvider::new(
+ "LM Studio",
+ "http://localhost:1234/v1",
+ Some(lm_studio_key),
+ AuthStyle::Bearer,
+ )))
+ }
+ "nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
+ OpenAiCompatibleProvider::new(
+ "NVIDIA NIM",
+ "https://integrate.api.nvidia.com/v1",
+ key,
+ AuthStyle::Bearer,
+ ),
+ )),
+
+ // ── AI inference routers ─────────────────────────────
+ "astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
+ "Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
))),
// ── Bring Your Own Provider (custom URL) ───────────
@@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
pub fn create_resilient_provider(
    primary_name: &str,
    api_key: Option<&str>,
+ api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
) -> anyhow::Result<Box<dyn Provider>> {
    let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
providers.push((
primary_name.to_string(),
- create_provider(primary_name, api_key)?,
+ create_provider_with_url(primary_name, api_key, api_url)?,
));
for fallback in &reliability.fallback_providers {
@@ -340,21 +386,13 @@ pub fn create_resilient_provider(
continue;
}
- if api_key.is_some() && fallback != "ollama" {
- tracing::warn!(
- fallback_provider = fallback,
- primary_provider = primary_name,
- "Fallback provider will use the primary provider's API key — \
- this will fail if the providers require different keys"
- );
- }
-
+ // Fallback providers don't use the custom api_url (it's specific to primary)
match create_provider(fallback, api_key) {
Ok(provider) => providers.push((fallback.clone(), provider)),
- Err(e) => {
+ Err(_error) => {
tracing::warn!(
fallback_provider = fallback,
- "Ignoring invalid fallback provider: {e}"
+ "Ignoring invalid fallback provider during initialization"
);
}
}
@@ -377,12 +415,13 @@ pub fn create_resilient_provider(
pub fn create_routed_provider(
primary_name: &str,
api_key: Option<&str>,
+ api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
model_routes: &[crate::config::ModelRouteConfig],
default_model: &str,
) -> anyhow::Result<Box<dyn Provider>> {
if model_routes.is_empty() {
- return create_resilient_provider(primary_name, api_key, reliability);
+ return create_resilient_provider(primary_name, api_key, api_url, reliability);
}
// Collect unique provider names needed
@@ -396,12 +435,19 @@ pub fn create_routed_provider(
// Create each provider (with its own resilience wrapper)
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
for name in &needed {
- let key = model_routes
+ let routed_credential = model_routes
.iter()
.find(|r| &r.provider == name)
- .and_then(|r| r.api_key.as_deref())
- .or(api_key);
- match create_resilient_provider(name, key, reliability) {
+ .and_then(|r| {
+ r.api_key.as_ref().and_then(|raw_key| {
+ let trimmed_key = raw_key.trim();
+ (!trimmed_key.is_empty()).then_some(trimmed_key)
+ })
+ });
+ let key = routed_credential.or(api_key);
+ // Only use api_url for the primary provider
+ let url = if name == primary_name { api_url } else { None };
+ match create_resilient_provider(name, key, url, reliability) {
Ok(provider) => providers.push((name.clone(), provider)),
Err(e) => {
if name == primary_name {
@@ -409,7 +455,7 @@ pub fn create_routed_provider(
}
tracing::warn!(
provider = name.as_str(),
- "Ignoring routed provider that failed to create: {e}"
+ "Ignoring routed provider that failed to initialize"
);
}
}
@@ -441,27 +487,27 @@ mod tests {
use super::*;
#[test]
- fn resolve_api_key_prefers_explicit_argument() {
- let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
- assert_eq!(resolved.as_deref(), Some("explicit-key"));
+ fn resolve_provider_credential_prefers_explicit_argument() {
+ let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
+ assert_eq!(resolved, Some("explicit-key".to_string()));
}
// ── Primary providers ────────────────────────────────────
#[test]
fn factory_openrouter() {
- assert!(create_provider("openrouter", Some("sk-test")).is_ok());
+ assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
assert!(create_provider("openrouter", None).is_ok());
}
#[test]
fn factory_anthropic() {
- assert!(create_provider("anthropic", Some("sk-test")).is_ok());
+ assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
}
#[test]
fn factory_openai() {
- assert!(create_provider("openai", Some("sk-test")).is_ok());
+ assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
}
#[test]
@@ -556,6 +602,13 @@ mod tests {
assert!(create_provider("dashscope-us", Some("key")).is_ok());
}
+ #[test]
+ fn factory_lmstudio() {
+ assert!(create_provider("lmstudio", Some("key")).is_ok());
+ assert!(create_provider("lm-studio", Some("key")).is_ok());
+ assert!(create_provider("lmstudio", None).is_ok());
+ }
+
// ── Extended ecosystem ───────────────────────────────────
#[test]
@@ -614,6 +667,13 @@ mod tests {
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
}
+ // ── AI inference routers ─────────────────────────────────
+
+ #[test]
+ fn factory_astrai() {
+ assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok());
+ }
+
// ── Custom / BYOP provider ─────────────────────────────
#[test]
@@ -761,17 +821,33 @@ mod tests {
scheduler_retries: 2,
};
- let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
+ let provider = create_resilient_provider(
+ "openrouter",
+ Some("provider-test-credential"),
+ None,
+ &reliability,
+ );
assert!(provider.is_ok());
}
#[test]
fn resilient_provider_errors_for_invalid_primary() {
let reliability = crate::config::ReliabilityConfig::default();
- let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
+ let provider = create_resilient_provider(
+ "totally-invalid",
+ Some("provider-test-credential"),
+ None,
+ &reliability,
+ );
assert!(provider.is_err());
}
+ #[test]
+ fn ollama_with_custom_url() {
+ let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434"));
+ assert!(provider.is_ok());
+ }
+
#[test]
fn factory_all_providers_create_successfully() {
let providers = [
@@ -794,6 +870,7 @@ mod tests {
"qwen",
"qwen-intl",
"qwen-us",
+ "lmstudio",
"groq",
"mistral",
"xai",
@@ -888,7 +965,7 @@ mod tests {
#[test]
fn sanitize_preserves_unicode_boundaries() {
- let input = format!("{} sk-abcdef123", "こんにちは".repeat(80));
+ let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
let result = sanitize_api_error(&input);
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
assert!(!result.contains("sk-abcdef123"));
@@ -900,4 +977,32 @@ mod tests {
let result = sanitize_api_error(input);
assert_eq!(result, input);
}
+
+ #[test]
+ fn scrub_github_personal_access_token() {
+ let input = "auth failed with token ghp_abc123def456";
+ let result = scrub_secret_patterns(input);
+ assert_eq!(result, "auth failed with token [REDACTED]");
+ }
+
+ #[test]
+ fn scrub_github_oauth_token() {
+ let input = "Bearer gho_1234567890abcdef";
+ let result = scrub_secret_patterns(input);
+ assert_eq!(result, "Bearer [REDACTED]");
+ }
+
+ #[test]
+ fn scrub_github_user_token() {
+ let input = "token ghu_sessiontoken123";
+ let result = scrub_secret_patterns(input);
+ assert_eq!(result, "token [REDACTED]");
+ }
+
+ #[test]
+ fn scrub_github_fine_grained_pat() {
+ let input = "failed: github_pat_11AABBC_xyzzy789";
+ let result = scrub_secret_patterns(input);
+ assert_eq!(result, "failed: [REDACTED]");
+ }
}
diff --git a/src/providers/ollama.rs b/src/providers/ollama.rs
index 8ecfb5a..e05f027 100644
--- a/src/providers/ollama.rs
+++ b/src/providers/ollama.rs
@@ -8,6 +8,8 @@ pub struct OllamaProvider {
client: Client,
}
+// ─── Request Structures ───────────────────────────────────────────────────────
+
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
@@ -27,6 +29,8 @@ struct Options {
temperature: f64,
}
+// ─── Response Structures ──────────────────────────────────────────────────────
+
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
message: ResponseMessage,
@@ -34,9 +38,30 @@ struct ApiChatResponse {
#[derive(Debug, Deserialize)]
struct ResponseMessage {
+ #[serde(default)]
content: String,
+ #[serde(default)]
+    tool_calls: Vec<OllamaToolCall>,
+ /// Some models return a "thinking" field with internal reasoning
+ #[serde(default)]
+    thinking: Option<String>,
}
+#[derive(Debug, Deserialize)]
+struct OllamaToolCall {
+    id: Option<String>,
+ function: OllamaFunction,
+}
+
+#[derive(Debug, Deserialize)]
+struct OllamaFunction {
+ name: String,
+ #[serde(default)]
+ arguments: serde_json::Value,
+}
+
+// ─── Implementation ───────────────────────────────────────────────────────────
+
impl OllamaProvider {
pub fn new(base_url: Option<&str>) -> Self {
Self {
@@ -45,12 +70,145 @@ impl OllamaProvider {
.trim_end_matches('/')
.to_string(),
client: Client::builder()
- .timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
+ .timeout(std::time::Duration::from_secs(300))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| Client::new()),
}
}
+
+ /// Send a request to Ollama and get the parsed response
+ async fn send_request(
+ &self,
+        messages: Vec<Message>,
+ model: &str,
+ temperature: f64,
+    ) -> anyhow::Result<ApiChatResponse> {
+ let request = ChatRequest {
+ model: model.to_string(),
+ messages,
+ stream: false,
+ options: Options { temperature },
+ };
+
+ let url = format!("{}/api/chat", self.base_url);
+
+ tracing::debug!(
+ "Ollama request: url={} model={} message_count={} temperature={}",
+ url,
+ model,
+ request.messages.len(),
+ temperature
+ );
+
+ let response = self.client.post(&url).json(&request).send().await?;
+ let status = response.status();
+ tracing::debug!("Ollama response status: {}", status);
+
+ let body = response.bytes().await?;
+ tracing::debug!("Ollama response body length: {} bytes", body.len());
+
+ if !status.is_success() {
+ let raw = String::from_utf8_lossy(&body);
+ let sanitized = super::sanitize_api_error(&raw);
+ tracing::error!(
+ "Ollama error response: status={} body_excerpt={}",
+ status,
+ sanitized
+ );
+ anyhow::bail!(
+ "Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
+ status,
+ sanitized
+ );
+ }
+
+ let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
+ Ok(r) => r,
+ Err(e) => {
+ let raw = String::from_utf8_lossy(&body);
+ let sanitized = super::sanitize_api_error(&raw);
+ tracing::error!(
+ "Ollama response deserialization failed: {e}. body_excerpt={}",
+ sanitized
+ );
+ anyhow::bail!("Failed to parse Ollama response: {e}");
+ }
+ };
+
+ Ok(chat_response)
+ }
+
+ /// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
+ ///
+ /// Handles quirky model behavior where tool calls are wrapped:
+ /// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
+ /// - `{"name": "tool.shell", "arguments": {...}}`
+ fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
+        let formatted_calls: Vec<serde_json::Value> = tool_calls
+ .iter()
+ .map(|tc| {
+ let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
+
+ // Arguments must be a JSON string for parse_tool_calls compatibility
+ let args_str =
+ serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
+
+ serde_json::json!({
+ "id": tc.id,
+ "type": "function",
+ "function": {
+ "name": tool_name,
+ "arguments": args_str
+ }
+ })
+ })
+ .collect();
+
+ serde_json::json!({
+ "content": "",
+ "tool_calls": formatted_calls
+ })
+ .to_string()
+ }
+
+ /// Extract the actual tool name and arguments from potentially nested structures
+ fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
+ let name = &tc.function.name;
+ let args = &tc.function.arguments;
+
+ // Pattern 1: Nested tool_call wrapper (various malformed versions)
+ // {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
+        //   {"name": "tool_call<...>", "arguments": {...}}
+        if name == "tool_call"
+            || name.starts_with("tool_call>")
+            || name.starts_with("tool_call<")
+        {
+ if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
+ let nested_args = args
+ .get("arguments")
+ .cloned()
+ .unwrap_or(serde_json::json!({}));
+ tracing::debug!(
+ "Unwrapped nested tool call: {} -> {} with args {:?}",
+ name,
+ nested_name,
+ nested_args
+ );
+ return (nested_name.to_string(), nested_args);
+ }
+ }
+
+ // Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
+ if let Some(stripped) = name.strip_prefix("tool.") {
+ return (stripped.to_string(), args.clone());
+ }
+
+ // Pattern 3: Normal tool call
+ (name.clone(), args.clone())
+ }
}
#[async_trait]
@@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
content: message.to_string(),
});
- let request = ChatRequest {
- model: model.to_string(),
- messages,
- stream: false,
- options: Options { temperature },
- };
+ let response = self.send_request(messages, model, temperature).await?;
- let url = format!("{}/api/chat", self.base_url);
-
- let response = self.client.post(&url).json(&request).send().await?;
-
- if !response.status().is_success() {
- let err = super::api_error("Ollama", response).await;
- anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)");
+ // If model returned tool calls, format them for loop_.rs's parse_tool_calls
+ if !response.message.tool_calls.is_empty() {
+ tracing::debug!(
+ "Ollama returned {} tool call(s), formatting for loop parser",
+ response.message.tool_calls.len()
+ );
+ return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
- let chat_response: ApiChatResponse = response.json().await?;
- Ok(chat_response.message.content)
+ // Plain text response
+ let content = response.message.content;
+
+ // Handle edge case: model returned only "thinking" with no content or tool calls
+ if content.is_empty() {
+ if let Some(thinking) = &response.message.thinking {
+ tracing::warn!(
+ "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
+ if thinking.len() > 100 { &thinking[..100] } else { thinking }
+ );
+ return Ok(format!(
+ "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
+ if thinking.len() > 200 { &thinking[..200] } else { thinking }
+ ));
+ }
+ tracing::warn!("Ollama returned empty content with no tool calls");
+ }
+
+ Ok(content)
+ }
+
+ async fn chat_with_history(
+ &self,
+ messages: &[crate::providers::ChatMessage],
+ model: &str,
+ temperature: f64,
+ ) -> anyhow::Result {
+        let api_messages: Vec<Message> = messages
+ .iter()
+ .map(|m| Message {
+ role: m.role.clone(),
+ content: m.content.clone(),
+ })
+ .collect();
+
+ let response = self.send_request(api_messages, model, temperature).await?;
+
+ // If model returned tool calls, format them for loop_.rs's parse_tool_calls
+ if !response.message.tool_calls.is_empty() {
+ tracing::debug!(
+ "Ollama returned {} tool call(s), formatting for loop parser",
+ response.message.tool_calls.len()
+ );
+ return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
+ }
+
+ // Plain text response
+ let content = response.message.content;
+
+ // Handle edge case: model returned only "thinking" with no content or tool calls
+ // This is a model quirk - it stopped after reasoning without producing output
+ if content.is_empty() {
+ if let Some(thinking) = &response.message.thinking {
+ tracing::warn!(
+ "Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
+ if thinking.len() > 100 { &thinking[..100] } else { thinking }
+ );
+ // Return a message indicating the model's thought process but no action
+ return Ok(format!(
+ "I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
+ if thinking.len() > 200 { &thinking[..200] } else { thinking }
+ ));
+ }
+ tracing::warn!("Ollama returned empty content with no tool calls");
+ }
+
+ Ok(content)
+ }
+
+ fn supports_native_tools(&self) -> bool {
+ // Return false since loop_.rs uses XML-style tool parsing via system prompt
+ // The model may return native tool_calls but we convert them to JSON format
+ // that parse_tool_calls() understands
+ false
}
}
+// ─── Tests ────────────────────────────────────────────────────────────────────
+
#[cfg(test)]
mod tests {
use super::*;
@@ -125,46 +352,6 @@ mod tests {
assert_eq!(p.base_url, "");
}
- #[test]
- fn request_serializes_with_system() {
- let req = ChatRequest {
- model: "llama3".to_string(),
- messages: vec![
- Message {
- role: "system".to_string(),
- content: "You are ZeroClaw".to_string(),
- },
- Message {
- role: "user".to_string(),
- content: "hello".to_string(),
- },
- ],
- stream: false,
- options: Options { temperature: 0.7 },
- };
- let json = serde_json::to_string(&req).unwrap();
- assert!(json.contains("\"stream\":false"));
- assert!(json.contains("llama3"));
- assert!(json.contains("system"));
- assert!(json.contains("\"temperature\":0.7"));
- }
-
- #[test]
- fn request_serializes_without_system() {
- let req = ChatRequest {
- model: "mistral".to_string(),
- messages: vec![Message {
- role: "user".to_string(),
- content: "test".to_string(),
- }],
- stream: false,
- options: Options { temperature: 0.0 },
- };
- let json = serde_json::to_string(&req).unwrap();
- assert!(!json.contains("\"role\":\"system\""));
- assert!(json.contains("mistral"));
- }
-
#[test]
fn response_deserializes() {
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
@@ -180,9 +367,98 @@ mod tests {
}
#[test]
- fn response_with_multiline() {
- let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
+ fn response_with_missing_content_defaults_to_empty() {
+ let json = r#"{"message":{"role":"assistant"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
- assert!(resp.message.content.contains("line1"));
+ assert!(resp.message.content.is_empty());
+ }
+
+ #[test]
+ fn response_with_thinking_field_extracts_content() {
+ let json =
+ r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
+ let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
+ assert_eq!(resp.message.content, "hello");
+ }
+
+ #[test]
+ fn response_with_tool_calls_parses_correctly() {
+ let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
+ let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
+ assert!(resp.message.content.is_empty());
+ assert_eq!(resp.message.tool_calls.len(), 1);
+ assert_eq!(resp.message.tool_calls[0].function.name, "shell");
+ }
+
+ #[test]
+ fn extract_tool_name_handles_nested_tool_call() {
+ let provider = OllamaProvider::new(None);
+ let tc = OllamaToolCall {
+ id: Some("call_123".into()),
+ function: OllamaFunction {
+ name: "tool_call".into(),
+ arguments: serde_json::json!({
+ "name": "shell",
+ "arguments": {"command": "date"}
+ }),
+ },
+ };
+ let (name, args) = provider.extract_tool_name_and_args(&tc);
+ assert_eq!(name, "shell");
+ assert_eq!(args.get("command").unwrap(), "date");
+ }
+
+ #[test]
+ fn extract_tool_name_handles_prefixed_name() {
+ let provider = OllamaProvider::new(None);
+ let tc = OllamaToolCall {
+ id: Some("call_123".into()),
+ function: OllamaFunction {
+ name: "tool.shell".into(),
+ arguments: serde_json::json!({"command": "ls"}),
+ },
+ };
+ let (name, args) = provider.extract_tool_name_and_args(&tc);
+ assert_eq!(name, "shell");
+ assert_eq!(args.get("command").unwrap(), "ls");
+ }
+
+ #[test]
+ fn extract_tool_name_handles_normal_call() {
+ let provider = OllamaProvider::new(None);
+ let tc = OllamaToolCall {
+ id: Some("call_123".into()),
+ function: OllamaFunction {
+ name: "file_read".into(),
+ arguments: serde_json::json!({"path": "/tmp/test"}),
+ },
+ };
+ let (name, args) = provider.extract_tool_name_and_args(&tc);
+ assert_eq!(name, "file_read");
+ assert_eq!(args.get("path").unwrap(), "/tmp/test");
+ }
+
+ #[test]
+ fn format_tool_calls_produces_valid_json() {
+ let provider = OllamaProvider::new(None);
+ let tool_calls = vec![OllamaToolCall {
+ id: Some("call_abc".into()),
+ function: OllamaFunction {
+ name: "shell".into(),
+ arguments: serde_json::json!({"command": "date"}),
+ },
+ }];
+
+ let formatted = provider.format_tool_calls_for_loop(&tool_calls);
+ let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
+
+ assert!(parsed.get("tool_calls").is_some());
+ let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
+ assert_eq!(calls.len(), 1);
+
+ let func = calls[0].get("function").unwrap();
+ assert_eq!(func.get("name").unwrap(), "shell");
+ // arguments should be a string (JSON-encoded)
+ assert!(func.get("arguments").unwrap().is_string());
}
}
diff --git a/src/providers/openai.rs b/src/providers/openai.rs
index ef67678..22b53ca 100644
--- a/src/providers/openai.rs
+++ b/src/providers/openai.rs
@@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenAiProvider {
-    api_key: Option<String>,
+    credential: Option<String>,
client: Client,
}
@@ -110,9 +110,9 @@ struct NativeResponseMessage {
}
impl OpenAiProvider {
- pub fn new(api_key: Option<&str>) -> Self {
+ pub fn new(credential: Option<&str>) -> Self {
Self {
- api_key: api_key.map(ToString::to_string),
+ credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref().ok_or_else(|| {
+ let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
@@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
- .header("Authorization", format!("Bearer {api_key}"))
+ .header("Authorization", format!("Bearer {credential}"))
.json(&request)
.send()
.await?;
@@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref().ok_or_else(|| {
+ let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
@@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
- .header("Authorization", format!("Bearer {api_key}"))
+ .header("Authorization", format!("Bearer {credential}"))
.json(&native_request)
.send()
.await?;
@@ -330,20 +330,20 @@ mod tests {
#[test]
fn creates_with_key() {
- let p = OpenAiProvider::new(Some("sk-proj-abc123"));
- assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
+ let p = OpenAiProvider::new(Some("openai-test-credential"));
+ assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
}
#[test]
fn creates_without_key() {
let p = OpenAiProvider::new(None);
- assert!(p.api_key.is_none());
+ assert!(p.credential.is_none());
}
#[test]
fn creates_with_empty_key() {
let p = OpenAiProvider::new(Some(""));
- assert_eq!(p.api_key.as_deref(), Some(""));
+ assert_eq!(p.credential.as_deref(), Some(""));
}
#[tokio::test]
diff --git a/src/providers/openrouter.rs b/src/providers/openrouter.rs
index 2896c07..b27bff4 100644
--- a/src/providers/openrouter.rs
+++ b/src/providers/openrouter.rs
@@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenRouterProvider {
-    api_key: Option<String>,
+    credential: Option<String>,
client: Client,
}
@@ -110,9 +110,9 @@ struct NativeResponseMessage {
}
impl OpenRouterProvider {
- pub fn new(api_key: Option<&str>) -> Self {
+ pub fn new(credential: Option<&str>) -> Self {
Self {
- api_key: api_key.map(ToString::to_string),
+ credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
async fn warmup(&self) -> anyhow::Result<()> {
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
// This prevents the first real chat request from timing out on cold start.
- if let Some(api_key) = self.api_key.as_ref() {
+ if let Some(credential) = self.credential.as_ref() {
self.client
.get("https://openrouter.ai/api/v1/auth/key")
- .header("Authorization", format!("Bearer {api_key}"))
+ .header("Authorization", format!("Bearer {credential}"))
.send()
.await?
.error_for_status()?;
@@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref()
+ let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let mut messages = Vec::new();
@@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
- .header("Authorization", format!("Bearer {api_key}"))
+ .header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref()
+ let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let api_messages: Vec<Message> = messages
@@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
- .header("Authorization", format!("Bearer {api_key}"))
+ .header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref().ok_or_else(|| {
+ let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
@@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
- .header("Authorization", format!("Bearer {api_key}"))
+ .header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
- let api_key = self.api_key.as_ref().ok_or_else(|| {
+ let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
@@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
- .header("Authorization", format!("Bearer {api_key}"))
+ .header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@@ -494,14 +494,17 @@ mod tests {
#[test]
fn creates_with_key() {
- let provider = OpenRouterProvider::new(Some("sk-or-123"));
- assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
+ let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
+ assert_eq!(
+ provider.credential.as_deref(),
+ Some("openrouter-test-credential")
+ );
}
#[test]
fn creates_without_key() {
let provider = OpenRouterProvider::new(None);
- assert!(provider.api_key.is_none());
+ assert!(provider.credential.is_none());
}
#[tokio::test]
diff --git a/src/providers/reliable.rs b/src/providers/reliable.rs
index 045f2c3..32cc0ca 100644
--- a/src/providers/reliable.rs
+++ b/src/providers/reliable.rs
@@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
async fn warmup(&self) -> anyhow::Result<()> {
for (name, provider) in &self.providers {
tracing::info!(provider = name, "Warming up provider connection pool");
- if let Err(e) = provider.warmup().await {
- tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
+ if provider.warmup().await.is_err() {
+ tracing::warn!(provider = name, "Warmup failed (non-fatal)");
}
}
Ok(())
@@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
+ let failure_reason = if rate_limited {
+ "rate_limited"
+ } else if non_retryable {
+ "non_retryable"
+ } else {
+ "retryable"
+ };
failures.push(format!(
- "{provider_name}/{current_model} attempt {}/{}: {e}",
+ "{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));
@@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
+ let failure_reason = if rate_limited {
+ "rate_limited"
+ } else if non_retryable {
+ "non_retryable"
+ } else {
+ "retryable"
+ };
failures.push(format!(
- "{provider_name}/{current_model} attempt {}/{}: {e}",
+ "{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));
diff --git a/src/providers/traits.rs b/src/providers/traits.rs
index f43d099..380bbc5 100644
--- a/src/providers/traits.rs
+++ b/src/providers/traits.rs
@@ -193,6 +193,13 @@ pub enum StreamError {
#[async_trait]
pub trait Provider: Send + Sync {
+ /// Query provider capabilities.
+ ///
+ /// Default implementation returns minimal capabilities (no native tool calling).
+ /// Providers should override this to declare their actual capabilities.
+ fn capabilities(&self) -> ProviderCapabilities {
+ ProviderCapabilities::default()
+ }
/// Simple one-shot chat (single user message, no explicit system prompt).
///
/// This is the preferred API for non-agentic direct interactions.
@@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
/// Whether provider supports native tool calls over API.
fn supports_native_tools(&self) -> bool {
- false
+ self.capabilities().native_tool_calling
}
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
@@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
mod tests {
use super::*;
+ struct CapabilityMockProvider;
+
+ #[async_trait]
+ impl Provider for CapabilityMockProvider {
+ fn capabilities(&self) -> ProviderCapabilities {
+ ProviderCapabilities {
+ native_tool_calling: true,
+ }
+ }
+
+ async fn chat_with_system(
+ &self,
+ _system_prompt: Option<&str>,
+ _message: &str,
+ _model: &str,
+ _temperature: f64,
+ ) -> anyhow::Result<String> {
+ Ok("ok".into())
+ }
+ }
+
#[test]
fn chat_message_constructors() {
let sys = ChatMessage::system("Be helpful");
@@ -398,4 +426,32 @@ mod tests {
let json = serde_json::to_string(&tool_result).unwrap();
assert!(json.contains("\"type\":\"ToolResults\""));
}
+
+ #[test]
+ fn provider_capabilities_default() {
+ let caps = ProviderCapabilities::default();
+ assert!(!caps.native_tool_calling);
+ }
+
+ #[test]
+ fn provider_capabilities_equality() {
+ let caps1 = ProviderCapabilities {
+ native_tool_calling: true,
+ };
+ let caps2 = ProviderCapabilities {
+ native_tool_calling: true,
+ };
+ let caps3 = ProviderCapabilities {
+ native_tool_calling: false,
+ };
+
+ assert_eq!(caps1, caps2);
+ assert_ne!(caps1, caps3);
+ }
+
+ #[test]
+ fn supports_native_tools_reflects_capabilities_default_mapping() {
+ let provider = CapabilityMockProvider;
+ assert!(provider.supports_native_tools());
+ }
}
diff --git a/src/security/bubblewrap.rs b/src/security/bubblewrap.rs
index 5c7106e..fca76e6 100644
--- a/src/security/bubblewrap.rs
+++ b/src/security/bubblewrap.rs
@@ -81,14 +81,17 @@ mod tests {
#[test]
fn bubblewrap_sandbox_name() {
- assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
+ let sandbox = BubblewrapSandbox;
+ assert_eq!(sandbox.name(), "bubblewrap");
}
#[test]
fn bubblewrap_is_available_only_if_installed() {
// Result depends on whether bwrap is installed
- let available = BubblewrapSandbox::is_available();
+ let sandbox = BubblewrapSandbox;
+ let _available = sandbox.is_available();
+
// Either way, the name should still work
- assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
+ assert_eq!(sandbox.name(), "bubblewrap");
}
}
diff --git a/src/security/pairing.rs b/src/security/pairing.rs
index 806431b..2a828e1 100644
--- a/src/security/pairing.rs
+++ b/src/security/pairing.rs
@@ -184,7 +184,7 @@ fn generate_token() -> String {
use rand::RngCore;
let mut bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut bytes);
- format!("zc_{}", hex::encode(&bytes))
+ format!("zc_{}", hex::encode(bytes))
}
/// SHA-256 hash a bearer token for storage. Returns lowercase hex.
diff --git a/src/security/policy.rs b/src/security/policy.rs
index 9383f3a..e47947a 100644
--- a/src/security/policy.rs
+++ b/src/security/policy.rs
@@ -343,6 +343,7 @@ impl SecurityPolicy {
/// validates each sub-command against the allowlist
/// - Blocks single `&` background chaining (`&&` remains supported)
/// - Blocks output redirections (`>`, `>>`) that could write outside workspace
+ /// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
pub fn is_command_allowed(&self, command: &str) -> bool {
if self.autonomy == AutonomyLevel::ReadOnly {
return false;
@@ -350,7 +351,12 @@ impl SecurityPolicy {
// Block subshell/expansion operators — these allow hiding arbitrary
// commands inside an allowed command (e.g. `echo $(rm -rf /)`)
- if command.contains('`') || command.contains("$(") || command.contains("${") {
+ if command.contains('`')
+ || command.contains("$(")
+ || command.contains("${")
+ || command.contains("<(")
+ || command.contains(">(")
+ {
return false;
}
@@ -359,6 +365,15 @@ impl SecurityPolicy {
return false;
}
+ // Block `tee` — it can write to arbitrary files, bypassing the
+ // redirect check above (e.g. `echo secret | tee /etc/crontab`)
+ if command
+ .split_whitespace()
+ .any(|w| w == "tee" || w.ends_with("/tee"))
+ {
+ return false;
+ }
+
// Block background command chaining (`&`), which can hide extra
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
if contains_single_ampersand(command) {
@@ -384,13 +399,9 @@ impl SecurityPolicy {
// Strip leading env var assignments (e.g. FOO=bar cmd)
let cmd_part = skip_env_assignments(segment);
- let base_cmd = cmd_part
- .split_whitespace()
- .next()
- .unwrap_or("")
- .rsplit('/')
- .next()
- .unwrap_or("");
+ let mut words = cmd_part.split_whitespace();
+ let base_raw = words.next().unwrap_or("");
+ let base_cmd = base_raw.rsplit('/').next().unwrap_or("");
if base_cmd.is_empty() {
continue;
@@ -403,6 +414,12 @@ impl SecurityPolicy {
{
return false;
}
+
+ // Validate arguments for the command
+ let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
+ if !self.is_args_safe(base_cmd, &args) {
+ return false;
+ }
}
// At least one command must be present
@@ -414,6 +431,29 @@ impl SecurityPolicy {
has_cmd
}
+ /// Check for dangerous arguments that allow sub-command execution.
+ fn is_args_safe(&self, base: &str, args: &[String]) -> bool {
+ let base = base.to_ascii_lowercase();
+ match base.as_str() {
+ "find" => {
+ // find -exec and find -ok allow arbitrary command execution
+ !args.iter().any(|arg| arg == "-exec" || arg == "-ok")
+ }
+ "git" => {
+ // git config, alias, and -c can be used to set dangerous options
+ // (e.g. git config core.editor "rm -rf /")
+ !args.iter().any(|arg| {
+ arg == "config"
+ || arg.starts_with("config.")
+ || arg == "alias"
+ || arg.starts_with("alias.")
+ || arg == "-c"
+ })
+ }
+ _ => true,
+ }
+ }
+
/// Check if a file path is allowed (no path traversal, within workspace)
pub fn is_path_allowed(&self, path: &str) -> bool {
// Block null bytes (can truncate paths in C-backed syscalls)
@@ -982,12 +1022,43 @@ mod tests {
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
}
+ #[test]
+ fn command_argument_injection_blocked() {
+ let p = default_policy();
+ // find -exec is a common bypass
+ assert!(!p.is_command_allowed("find . -exec rm -rf {} +"));
+ assert!(!p.is_command_allowed("find / -ok cat {} \\;"));
+ // git config/alias can execute commands
+ assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\""));
+ assert!(!p.is_command_allowed("git alias.st status"));
+ assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit"));
+ // Legitimate commands should still work
+ assert!(p.is_command_allowed("find . -name '*.txt'"));
+ assert!(p.is_command_allowed("git status"));
+ assert!(p.is_command_allowed("git add ."));
+ }
+
#[test]
fn command_injection_dollar_brace_blocked() {
let p = default_policy();
assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd"));
}
+ #[test]
+ fn command_injection_tee_blocked() {
+ let p = default_policy();
+ assert!(!p.is_command_allowed("echo secret | tee /etc/crontab"));
+ assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile"));
+ assert!(!p.is_command_allowed("tee file.txt"));
+ }
+
+ #[test]
+ fn command_injection_process_substitution_blocked() {
+ let p = default_policy();
+ assert!(!p.is_command_allowed("cat <(echo pwned)"));
+ assert!(!p.is_command_allowed("ls >(cat /etc/passwd)"));
+ }
+
#[test]
fn command_env_var_prefix_with_allowed_cmd() {
let p = default_policy();
diff --git a/src/tools/browser.rs b/src/tools/browser.rs
index fe3be26..4e3d59e 100644
--- a/src/tools/browser.rs
+++ b/src/tools/browser.rs
@@ -854,7 +854,6 @@ impl BrowserTool {
}
}
-#[allow(clippy::too_many_lines)]
#[async_trait]
impl Tool for BrowserTool {
fn name(&self) -> &str {
@@ -1031,165 +1030,21 @@ impl Tool for BrowserTool {
return self.execute_computer_use_action(action_str, &args).await;
}
- let action = match action_str {
- "open" => {
- let url = args
- .get("url")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
- BrowserAction::Open { url: url.into() }
- }
- "snapshot" => BrowserAction::Snapshot {
- interactive_only: args
- .get("interactive_only")
- .and_then(serde_json::Value::as_bool)
- .unwrap_or(true), // Default to interactive for AI
- compact: args
- .get("compact")
- .and_then(serde_json::Value::as_bool)
- .unwrap_or(true),
- depth: args
- .get("depth")
- .and_then(serde_json::Value::as_u64)
- .map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
- },
- "click" => {
- let selector = args
- .get("selector")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
- BrowserAction::Click {
- selector: selector.into(),
- }
- }
- "fill" => {
- let selector = args
- .get("selector")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
- let value = args
- .get("value")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
- BrowserAction::Fill {
- selector: selector.into(),
- value: value.into(),
- }
- }
- "type" => {
- let selector = args
- .get("selector")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
- let text = args
- .get("text")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
- BrowserAction::Type {
- selector: selector.into(),
- text: text.into(),
- }
- }
- "get_text" => {
- let selector = args
- .get("selector")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
- BrowserAction::GetText {
- selector: selector.into(),
- }
- }
- "get_title" => BrowserAction::GetTitle,
- "get_url" => BrowserAction::GetUrl,
- "screenshot" => BrowserAction::Screenshot {
- path: args.get("path").and_then(|v| v.as_str()).map(String::from),
- full_page: args
- .get("full_page")
- .and_then(serde_json::Value::as_bool)
- .unwrap_or(false),
- },
- "wait" => BrowserAction::Wait {
- selector: args
- .get("selector")
- .and_then(|v| v.as_str())
- .map(String::from),
- ms: args.get("ms").and_then(serde_json::Value::as_u64),
- text: args.get("text").and_then(|v| v.as_str()).map(String::from),
- },
- "press" => {
- let key = args
- .get("key")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
- BrowserAction::Press { key: key.into() }
- }
- "hover" => {
- let selector = args
- .get("selector")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
- BrowserAction::Hover {
- selector: selector.into(),
- }
- }
- "scroll" => {
- let direction = args
- .get("direction")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
- BrowserAction::Scroll {
- direction: direction.into(),
- pixels: args
- .get("pixels")
- .and_then(serde_json::Value::as_u64)
- .map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
- }
- }
- "is_visible" => {
- let selector = args
- .get("selector")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
- BrowserAction::IsVisible {
- selector: selector.into(),
- }
- }
- "close" => BrowserAction::Close,
- "find" => {
- let by = args
- .get("by")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
- let value = args
- .get("value")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
- let action = args
- .get("find_action")
- .and_then(|v| v.as_str())
- .ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
- BrowserAction::Find {
- by: by.into(),
- value: value.into(),
- action: action.into(),
- fill_value: args
- .get("fill_value")
- .and_then(|v| v.as_str())
- .map(String::from),
- }
- }
- _ => {
+ if is_computer_use_only_action(action_str) {
+ return Ok(ToolResult {
+ success: false,
+ output: String::new(),
+ error: Some(unavailable_action_for_backend_error(action_str, backend)),
+ });
+ }
+
+ let action = match parse_browser_action(action_str, &args) {
+ Ok(a) => a,
+ Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
- error: Some(format!(
- "Action '{action_str}' is unavailable for backend '{}'",
- match backend {
- ResolvedBackend::AgentBrowser => "agent_browser",
- ResolvedBackend::RustNative => "rust_native",
- ResolvedBackend::ComputerUse => "computer_use",
- }
- )),
+ error: Some(e.to_string()),
});
}
};
@@ -1871,6 +1726,161 @@ mod native_backend {
}
}
+// ── Action parsing ──────────────────────────────────────────────
+
+/// Parse a JSON `args` object into a typed `BrowserAction`.
+fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<BrowserAction> {
+ match action_str {
+ "open" => {
+ let url = args
+ .get("url")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
+ Ok(BrowserAction::Open { url: url.into() })
+ }
+ "snapshot" => Ok(BrowserAction::Snapshot {
+ interactive_only: args
+ .get("interactive_only")
+ .and_then(serde_json::Value::as_bool)
+ .unwrap_or(true),
+ compact: args
+ .get("compact")
+ .and_then(serde_json::Value::as_bool)
+ .unwrap_or(true),
+ depth: args
+ .get("depth")
+ .and_then(serde_json::Value::as_u64)
+ .map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
+ }),
+ "click" => {
+ let selector = args
+ .get("selector")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
+ Ok(BrowserAction::Click {
+ selector: selector.into(),
+ })
+ }
+ "fill" => {
+ let selector = args
+ .get("selector")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
+ let value = args
+ .get("value")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
+ Ok(BrowserAction::Fill {
+ selector: selector.into(),
+ value: value.into(),
+ })
+ }
+ "type" => {
+ let selector = args
+ .get("selector")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
+ let text = args
+ .get("text")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
+ Ok(BrowserAction::Type {
+ selector: selector.into(),
+ text: text.into(),
+ })
+ }
+ "get_text" => {
+ let selector = args
+ .get("selector")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
+ Ok(BrowserAction::GetText {
+ selector: selector.into(),
+ })
+ }
+ "get_title" => Ok(BrowserAction::GetTitle),
+ "get_url" => Ok(BrowserAction::GetUrl),
+ "screenshot" => Ok(BrowserAction::Screenshot {
+ path: args.get("path").and_then(|v| v.as_str()).map(String::from),
+ full_page: args
+ .get("full_page")
+ .and_then(serde_json::Value::as_bool)
+ .unwrap_or(false),
+ }),
+ "wait" => Ok(BrowserAction::Wait {
+ selector: args
+ .get("selector")
+ .and_then(|v| v.as_str())
+ .map(String::from),
+ ms: args.get("ms").and_then(serde_json::Value::as_u64),
+ text: args.get("text").and_then(|v| v.as_str()).map(String::from),
+ }),
+ "press" => {
+ let key = args
+ .get("key")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
+ Ok(BrowserAction::Press { key: key.into() })
+ }
+ "hover" => {
+ let selector = args
+ .get("selector")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
+ Ok(BrowserAction::Hover {
+ selector: selector.into(),
+ })
+ }
+ "scroll" => {
+ let direction = args
+ .get("direction")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
+ Ok(BrowserAction::Scroll {
+ direction: direction.into(),
+ pixels: args
+ .get("pixels")
+ .and_then(serde_json::Value::as_u64)
+ .map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
+ })
+ }
+ "is_visible" => {
+ let selector = args
+ .get("selector")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
+ Ok(BrowserAction::IsVisible {
+ selector: selector.into(),
+ })
+ }
+ "close" => Ok(BrowserAction::Close),
+ "find" => {
+ let by = args
+ .get("by")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
+ let value = args
+ .get("value")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
+ let action = args
+ .get("find_action")
+ .and_then(|v| v.as_str())
+ .ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
+ Ok(BrowserAction::Find {
+ by: by.into(),
+ value: value.into(),
+ action: action.into(),
+ fill_value: args
+ .get("fill_value")
+ .and_then(|v| v.as_str())
+ .map(String::from),
+ })
+ }
+ other => anyhow::bail!("Unsupported browser action: {other}"),
+ }
+}
+
// ── Helper functions ─────────────────────────────────────────────
fn is_supported_browser_action(action: &str) -> bool {
@@ -1901,6 +1911,28 @@ fn is_supported_browser_action(action: &str) -> bool {
)
}
+fn is_computer_use_only_action(action: &str) -> bool {
+ matches!(
+ action,
+ "mouse_move" | "mouse_click" | "mouse_drag" | "key_type" | "key_press" | "screen_capture"
+ )
+}
+
+fn backend_name(backend: ResolvedBackend) -> &'static str {
+ match backend {
+ ResolvedBackend::AgentBrowser => "agent_browser",
+ ResolvedBackend::RustNative => "rust_native",
+ ResolvedBackend::ComputerUse => "computer_use",
+ }
+}
+
+fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String {
+ format!(
+ "Action '{action}' is unavailable for backend '{}'",
+ backend_name(backend)
+ )
+}
+
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
domains
.into_iter()
@@ -2342,4 +2374,28 @@ mod tests {
let tool = BrowserTool::new(security, vec![], None);
assert!(tool.validate_url("https://example.com").is_err());
}
+
+ #[test]
+ fn computer_use_only_action_detection_is_correct() {
+ assert!(is_computer_use_only_action("mouse_move"));
+ assert!(is_computer_use_only_action("mouse_click"));
+ assert!(is_computer_use_only_action("mouse_drag"));
+ assert!(is_computer_use_only_action("key_type"));
+ assert!(is_computer_use_only_action("key_press"));
+ assert!(is_computer_use_only_action("screen_capture"));
+ assert!(!is_computer_use_only_action("open"));
+ assert!(!is_computer_use_only_action("snapshot"));
+ }
+
+ #[test]
+ fn unavailable_action_error_preserves_backend_context() {
+ assert_eq!(
+ unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser),
+ "Action 'mouse_move' is unavailable for backend 'agent_browser'"
+ );
+ assert_eq!(
+ unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative),
+ "Action 'mouse_move' is unavailable for backend 'rust_native'"
+ );
+ }
}
diff --git a/src/tools/composio.rs b/src/tools/composio.rs
index 4e608cb..65f128e 100644
--- a/src/tools/composio.rs
+++ b/src/tools/composio.rs
@@ -112,12 +112,12 @@ impl ComposioTool {
action_name: &str,
params: serde_json::Value,
entity_id: Option<&str>,
- connected_account_id: Option<&str>,
+ connected_account_ref: Option<&str>,
) -> anyhow::Result {
let tool_slug = normalize_tool_slug(action_name);
match self
- .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id)
+ .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
.await
{
Ok(result) => Ok(result),
@@ -130,21 +130,17 @@ impl ComposioTool {
}
}
- async fn execute_action_v3(
- &self,
+ fn build_execute_action_v3_request(
tool_slug: &str,
params: serde_json::Value,
entity_id: Option<&str>,
- connected_account_id: Option<&str>,
- ) -> anyhow::Result {
- let url = if let Some(connected_account_id) = connected_account_id
- .map(str::trim)
- .filter(|id| !id.is_empty())
- {
- format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}")
- } else {
- format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
- };
+ connected_account_ref: Option<&str>,
+ ) -> (String, serde_json::Value) {
+ let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
+ let account_ref = connected_account_ref.and_then(|candidate| {
+ let trimmed_candidate = candidate.trim();
+ (!trimmed_candidate.is_empty()).then_some(trimmed_candidate)
+ });
let mut body = json!({
"arguments": params,
@@ -153,6 +149,26 @@ impl ComposioTool {
if let Some(entity) = entity_id {
body["user_id"] = json!(entity);
}
+ if let Some(account_ref) = account_ref {
+ body["connected_account_id"] = json!(account_ref);
+ }
+
+ (url, body)
+ }
+
+ async fn execute_action_v3(
+ &self,
+ tool_slug: &str,
+ params: serde_json::Value,
+ entity_id: Option<&str>,
+ connected_account_ref: Option<&str>,
+ ) -> anyhow::Result {
+ let (url, body) = Self::build_execute_action_v3_request(
+ tool_slug,
+ params,
+ entity_id,
+ connected_account_ref,
+ );
let resp = self
.client
@@ -474,11 +490,11 @@ impl Tool for ComposioTool {
})?;
let params = args.get("params").cloned().unwrap_or(json!({}));
- let connected_account_id =
+ let connected_account_ref =
args.get("connected_account_id").and_then(|v| v.as_str());
match self
- .execute_action(action_name, params, Some(entity_id), connected_account_id)
+ .execute_action(action_name, params, Some(entity_id), connected_account_ref)
.await
{
Ok(result) => {
@@ -594,9 +610,38 @@ async fn response_error(resp: reqwest::Response) -> String {
}
if let Some(api_error) = extract_api_error_message(&body) {
- format!("HTTP {}: {api_error}", status.as_u16())
+ return format!(
+ "HTTP {}: {}",
+ status.as_u16(),
+ sanitize_error_message(&api_error)
+ );
+ }
+
+ format!("HTTP {}", status.as_u16())
+}
+
+fn sanitize_error_message(message: &str) -> String {
+ let mut sanitized = message.replace('\n', " ");
+ for marker in [
+ "connected_account_id",
+ "connectedAccountId",
+ "entity_id",
+ "entityId",
+ "user_id",
+ "userId",
+ ] {
+ sanitized = sanitized.replace(marker, "[redacted]");
+ }
+
+ let max_chars = 240;
+ if sanitized.chars().count() <= max_chars {
+ sanitized
} else {
- format!("HTTP {}: {body}", status.as_u16())
+ let mut end = max_chars;
+ while end > 0 && !sanitized.is_char_boundary(end) {
+ end -= 1;
+ }
+ format!("{}...", &sanitized[..end])
}
}
@@ -948,4 +993,40 @@ mod tests {
fn composio_api_base_url_is_v3() {
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
}
+
+ #[test]
+ fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
+ let (url, body) = ComposioTool::build_execute_action_v3_request(
+ "gmail-send-email",
+ json!({"to": "test@example.com"}),
+ Some("workspace-user"),
+ Some("account-42"),
+ );
+
+ assert_eq!(
+ url,
+ "https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
+ );
+ assert_eq!(body["arguments"]["to"], json!("test@example.com"));
+ assert_eq!(body["user_id"], json!("workspace-user"));
+ assert_eq!(body["connected_account_id"], json!("account-42"));
+ }
+
+ #[test]
+ fn build_execute_action_v3_request_drops_blank_optional_fields() {
+ let (url, body) = ComposioTool::build_execute_action_v3_request(
+ "github-list-repos",
+ json!({}),
+ None,
+ Some(" "),
+ );
+
+ assert_eq!(
+ url,
+ "https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
+ );
+ assert_eq!(body["arguments"], json!({}));
+ assert!(body.get("connected_account_id").is_none());
+ assert!(body.get("user_id").is_none());
+ }
}
diff --git a/src/tools/delegate.rs b/src/tools/delegate.rs
index 7f30b64..3de7872 100644
--- a/src/tools/delegate.rs
+++ b/src/tools/delegate.rs
@@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
/// summarization) to purpose-built sub-agents.
pub struct DelegateTool {
agents: Arc