Merge branch 'main' into pr-484-clean

This commit is contained in:
Will Sarg 2026-02-17 08:54:24 -05:00 committed by GitHub
commit ee05d62ce4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
90 changed files with 6937 additions and 1403 deletions

View file

@ -1,25 +1,69 @@
# ZeroClaw Environment Variables # ZeroClaw Environment Variables
# Copy this file to .env and fill in your values. # Copy this file to `.env` and fill in your local values.
# NEVER commit .env — it is listed in .gitignore. # Never commit `.env` or any real secrets.
# ── Required ────────────────────────────────────────────────── # ── Core Runtime ──────────────────────────────────────────────
# Your LLM provider API key # Provider key resolution at runtime:
# ZEROCLAW_API_KEY=sk-your-key-here # 1) explicit key passed from config/CLI
# 2) provider-specific env var (OPENROUTER_API_KEY, OPENAI_API_KEY, ...)
# 3) generic fallback env vars below
# Generic fallback API key (used when provider-specific key is absent)
API_KEY=your-api-key-here API_KEY=your-api-key-here
# ZEROCLAW_API_KEY=your-api-key-here
# ── Provider & Model ───────────────────────────────────────── # Default provider/model (can be overridden by CLI flags)
# LLM provider: openrouter, openai, anthropic, ollama, glm
PROVIDER=openrouter PROVIDER=openrouter
# ZEROCLAW_PROVIDER=openrouter
# ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514 # ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514
# ZEROCLAW_TEMPERATURE=0.7 # ZEROCLAW_TEMPERATURE=0.7
# Workspace directory override
# ZEROCLAW_WORKSPACE=/path/to/workspace
# ── Provider-Specific API Keys ────────────────────────────────
# OpenRouter
# OPENROUTER_API_KEY=sk-or-v1-...
# Anthropic
# ANTHROPIC_OAUTH_TOKEN=...
# ANTHROPIC_API_KEY=sk-ant-...
# OpenAI / Gemini
# OPENAI_API_KEY=sk-...
# GEMINI_API_KEY=...
# GOOGLE_API_KEY=...
# Other supported providers
# VENICE_API_KEY=...
# GROQ_API_KEY=...
# MISTRAL_API_KEY=...
# DEEPSEEK_API_KEY=...
# XAI_API_KEY=...
# TOGETHER_API_KEY=...
# FIREWORKS_API_KEY=...
# PERPLEXITY_API_KEY=...
# COHERE_API_KEY=...
# MOONSHOT_API_KEY=...
# GLM_API_KEY=...
# MINIMAX_API_KEY=...
# QIANFAN_API_KEY=...
# DASHSCOPE_API_KEY=...
# ZAI_API_KEY=...
# SYNTHETIC_API_KEY=...
# OPENCODE_API_KEY=...
# VERCEL_API_KEY=...
# CLOUDFLARE_API_KEY=...
# ── Gateway ────────────────────────────────────────────────── # ── Gateway ──────────────────────────────────────────────────
# ZEROCLAW_GATEWAY_PORT=3000 # ZEROCLAW_GATEWAY_PORT=3000
# ZEROCLAW_GATEWAY_HOST=127.0.0.1 # ZEROCLAW_GATEWAY_HOST=127.0.0.1
# ZEROCLAW_ALLOW_PUBLIC_BIND=false # ZEROCLAW_ALLOW_PUBLIC_BIND=false
# ── Workspace ──────────────────────────────────────────────── # ── Optional Integrations ────────────────────────────────────
# ZEROCLAW_WORKSPACE=/path/to/workspace # Pushover notifications (`pushover` tool)
# PUSHOVER_TOKEN=your-pushover-app-token
# PUSHOVER_USER_KEY=your-pushover-user-key
# ── Docker Compose ─────────────────────────────────────────── # ── Docker Compose ───────────────────────────────────────────
# Host port mapping (used by docker-compose.yml) # Host port mapping (used by docker-compose.yml)

8
.githooks/pre-commit Executable file
View file

@ -0,0 +1,8 @@
#!/usr/bin/env bash
# Pre-commit hook: scan staged changes for secrets before they enter history.
# Enable with: git config core.hooksPath .githooks
# Fail fast: -e aborts on any command error (a gitleaks finding blocks the
# commit), -u rejects unset variables, pipefail surfaces pipeline failures.
set -euo pipefail
# Run gitleaks only if it is installed; --staged limits the scan to the index,
# --redact keeps matched secret values out of the hook's own output.
if command -v gitleaks >/dev/null 2>&1; then
gitleaks protect --staged --redact
else
# Soft-fail when gitleaks is absent so contributors without it can still
# commit; CI-side scanning remains the backstop.
echo "warning: gitleaks not found; skipping staged secret scan" >&2
fi

View file

@ -12,7 +12,11 @@ Describe this PR in 2-5 bullets:
- Risk label (`risk: low|medium|high`): - Risk label (`risk: low|medium|high`):
- Size label (`size: XS|S|M|L|XL`, auto-managed/read-only): - Size label (`size: XS|S|M|L|XL`, auto-managed/read-only):
- Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated): - Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated):
- Module labels (`<module>: <component>`, for example `channel: telegram`, `provider: kimi`, `tool: shell`):
- Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50): - Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50):
- If any auto-label is incorrect, note requested correction: - If any auto-label is incorrect, note requested correction:

View file

@ -18,6 +18,7 @@ jobs:
runs-on: blacksmith-2vcpu-ubuntu-2404 runs-on: blacksmith-2vcpu-ubuntu-2404
permissions: permissions:
issues: write issues: write
pull-requests: write
steps: steps:
- name: Apply contributor tier label for issue author - name: Apply contributor tier label for issue author
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8 uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8

View file

@ -35,7 +35,7 @@ jobs:
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- name: Setup Blacksmith Builder - name: Setup Blacksmith Builder
uses: useblacksmith/setup-docker-builder@v1 uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
- name: Extract metadata (tags, labels) - name: Extract metadata (tags, labels)
id: meta id: meta
@ -46,7 +46,7 @@ jobs:
type=ref,event=pr type=ref,event=pr
- name: Build smoke image - name: Build smoke image
uses: useblacksmith/build-push-action@v2 uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
with: with:
context: . context: .
push: false push: false
@ -71,7 +71,7 @@ jobs:
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4 uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- name: Setup Blacksmith Builder - name: Setup Blacksmith Builder
uses: useblacksmith/setup-docker-builder@v1 uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
- name: Log in to Container Registry - name: Log in to Container Registry
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3 uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
@ -102,7 +102,7 @@ jobs:
echo "tags=${TAGS}" >> "$GITHUB_OUTPUT" echo "tags=${TAGS}" >> "$GITHUB_OUTPUT"
- name: Build and push Docker image - name: Build and push Docker image
uses: useblacksmith/build-push-action@v2 uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
with: with:
context: . context: .
push: true push: true

View file

@ -325,13 +325,18 @@ jobs:
return pattern.test(text); return pattern.test(text);
} }
function formatModuleLabel(prefix, segment) {
return `${prefix}: ${segment}`;
}
function parseModuleLabel(label) { function parseModuleLabel(label) {
const separatorIndex = label.indexOf(":"); if (typeof label !== "string") return null;
if (separatorIndex <= 0 || separatorIndex >= label.length - 1) return null; const match = label.match(/^([^:]+):\s*(.+)$/);
return { if (!match) return null;
prefix: label.slice(0, separatorIndex), const prefix = match[1].trim().toLowerCase();
segment: label.slice(separatorIndex + 1), const segment = (match[2] || "").trim().toLowerCase();
}; if (!prefix || !segment) return null;
return { prefix, segment };
} }
function sortByPriority(labels, priorityIndex) { function sortByPriority(labels, priorityIndex) {
@ -389,7 +394,7 @@ jobs:
for (const [prefix, segments] of segmentsByPrefix) { for (const [prefix, segments] of segmentsByPrefix) {
const hasSpecificSegment = [...segments].some((segment) => segment !== "core"); const hasSpecificSegment = [...segments].some((segment) => segment !== "core");
if (hasSpecificSegment) { if (hasSpecificSegment) {
refined.delete(`${prefix}:core`); refined.delete(formatModuleLabel(prefix, "core"));
} }
} }
@ -418,7 +423,7 @@ jobs:
if (uniqueSegments.length === 0) continue; if (uniqueSegments.length === 0) continue;
if (uniqueSegments.length === 1) { if (uniqueSegments.length === 1) {
compactedModuleLabels.add(`${prefix}:${uniqueSegments[0]}`); compactedModuleLabels.add(formatModuleLabel(prefix, uniqueSegments[0]));
} else { } else {
forcePathPrefixes.add(prefix); forcePathPrefixes.add(prefix);
} }
@ -609,7 +614,7 @@ jobs:
segment = normalizeLabelSegment(segment); segment = normalizeLabelSegment(segment);
if (!segment) continue; if (!segment) continue;
detectedModuleLabels.add(`${rule.prefix}:${segment}`); detectedModuleLabels.add(formatModuleLabel(rule.prefix, segment));
} }
} }
@ -635,7 +640,7 @@ jobs:
for (const keyword of providerKeywordHints) { for (const keyword of providerKeywordHints) {
if (containsKeyword(searchableText, keyword)) { if (containsKeyword(searchableText, keyword)) {
detectedModuleLabels.add(`provider:${keyword}`); detectedModuleLabels.add(formatModuleLabel("provider", keyword));
} }
} }
} }
@ -661,7 +666,7 @@ jobs:
for (const keyword of channelKeywordHints) { for (const keyword of channelKeywordHints) {
if (containsKeyword(searchableText, keyword)) { if (containsKeyword(searchableText, keyword)) {
detectedModuleLabels.add(`channel:${keyword}`); detectedModuleLabels.add(formatModuleLabel("channel", keyword));
} }
} }
} }

22
.gitignore vendored
View file

@ -4,6 +4,26 @@ firmware/*/target
*.db-journal *.db-journal
.DS_Store .DS_Store
.wt-pr37/ .wt-pr37/
.env
__pycache__/ __pycache__/
*.pyc *.pyc
docker-compose.override.yml
# Environment files (may contain secrets)
.env
# Python virtual environments
.venv/
venv/
# ESP32 build cache (esp-idf-sys managed)
.embuild/
.env.local
.env.*.local
# Secret keys and credentials
.secret_key
*.key
*.pem
credentials.json

View file

@ -79,6 +79,94 @@ git push --no-verify
> **Note:** CI runs the same checks, so skipped hooks will be caught on the PR. > **Note:** CI runs the same checks, so skipped hooks will be caught on the PR.
## Local Secret Management (Required)
ZeroClaw supports layered secret management for local development and CI hygiene.
### Secret Storage Options
1. **Environment variables** (recommended for local development)
- Copy `.env.example` to `.env` and fill in values
- `.env` files are Git-ignored and should stay local
- Best for temporary/local API keys
2. **Config file** (`~/.zeroclaw/config.toml`)
- Persistent setup for long-term use
- When `secrets.encrypt = true` (default), secret values are encrypted before save
- Secret key is stored at `~/.zeroclaw/.secret_key` with restricted permissions
- Use `zeroclaw onboard` for guided setup
### Runtime Resolution Rules
API key resolution follows this order:
1. Explicit key passed from config/CLI
2. Provider-specific env vars (`OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, ...)
3. Generic env vars (`ZEROCLAW_API_KEY`, `API_KEY`)
Provider/model config overrides:
- `ZEROCLAW_PROVIDER` / `PROVIDER`
- `ZEROCLAW_MODEL`
See `.env.example` for practical examples and currently supported provider key env vars.
### Pre-Commit Secret Hygiene (Mandatory)
Before every commit, verify:
- [ ] No `.env` files are staged (`.env.example` only)
- [ ] No raw API keys/tokens in code, tests, fixtures, examples, logs, or commit messages
- [ ] No credentials in debug output or error payloads
- [ ] `git diff --cached` has no accidental secret-like strings
Quick local audit:
```bash
# Search staged diff for common secret markers
git diff --cached | grep -iE '(api[_-]?key|secret|token|password|bearer|sk-)'
# Confirm no .env file is staged
git status --short | grep -E '\.env$'
```
### Optional Local Secret Scanning
For extra guardrails, install one of:
- **gitleaks**: [GitHub - gitleaks/gitleaks](https://github.com/gitleaks/gitleaks)
- **truffleHog**: [GitHub - trufflesecurity/trufflehog](https://github.com/trufflesecurity/trufflehog)
- **git-secrets**: [GitHub - awslabs/git-secrets](https://github.com/awslabs/git-secrets)
This repo includes `.githooks/pre-commit` to run `gitleaks protect --staged --redact` when gitleaks is installed.
Enable hooks with:
```bash
git config core.hooksPath .githooks
```
If gitleaks is not installed, the pre-commit hook prints a warning and continues.
### What Must Never Be Committed
- `.env` files (use `.env.example` only)
- API keys, tokens, passwords, or credentials (plain or encrypted)
- OAuth tokens or session identifiers
- Webhook signing secrets
- `~/.zeroclaw/.secret_key` or similar key files
- Personal identifiers or real user data in tests/fixtures
### If a Secret Is Committed Accidentally
1. Revoke/rotate the credential immediately
2. Do not rely only on `git revert` (history still contains the secret)
3. Purge history with `git filter-repo` or BFG
4. Force-push cleaned history (coordinate with maintainers)
5. Ensure the leaked value is removed from PR/issue/discussion/comment history
Reference: [GitHub guide: removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository)
## Collaboration Tracks (Risk-Based) ## Collaboration Tracks (Risk-Based)
To keep review throughput high without lowering quality, every PR should map to one track: To keep review throughput high without lowering quality, every PR should map to one track:

51
Cargo.lock generated
View file

@ -209,6 +209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8" checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
dependencies = [ dependencies = [
"axum-core", "axum-core",
"base64",
"bytes", "bytes",
"form_urlencoded", "form_urlencoded",
"futures-util", "futures-util",
@ -227,8 +228,10 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite 0.28.0",
"tower", "tower",
"tower-layer", "tower-layer",
"tower-service", "tower-service",
@ -2057,6 +2060,15 @@ dependencies = [
"hashify", "hashify",
] ]
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]] [[package]]
name = "matchit" name = "matchit"
version = "0.8.4" version = "0.8.4"
@ -3747,10 +3759,22 @@ dependencies = [
"rustls-pki-types", "rustls-pki-types",
"tokio", "tokio",
"tokio-rustls", "tokio-rustls",
"tungstenite", "tungstenite 0.24.0",
"webpki-roots 0.26.11", "webpki-roots 0.26.11",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.28.0",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.7.18" version = "0.7.18"
@ -3940,9 +3964,13 @@ version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e" checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
dependencies = [ dependencies = [
"matchers",
"nu-ansi-term", "nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab", "sharded-slab",
"thread_local", "thread_local",
"tracing",
"tracing-core", "tracing-core",
] ]
@ -3978,6 +4006,23 @@ dependencies = [
"utf-8", "utf-8",
] ]
[[package]]
name = "tungstenite"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442"
dependencies = [
"bytes",
"data-encoding",
"http 1.4.0",
"httparse",
"log",
"rand 0.9.2",
"sha1",
"thiserror 2.0.18",
"utf-8",
]
[[package]] [[package]]
name = "twox-hash" name = "twox-hash"
version = "2.1.2" version = "2.1.2"
@ -4880,7 +4925,9 @@ dependencies = [
"pdf-extract", "pdf-extract",
"probe-rs", "probe-rs",
"prometheus", "prometheus",
"prost",
"rand 0.8.5", "rand 0.8.5",
"regex",
"reqwest", "reqwest",
"rppal", "rppal",
"rusqlite", "rusqlite",
@ -4896,7 +4943,7 @@ dependencies = [
"tokio-rustls", "tokio-rustls",
"tokio-serial", "tokio-serial",
"tokio-test", "tokio-test",
"tokio-tungstenite", "tokio-tungstenite 0.24.0",
"toml", "toml",
"tower", "tower",
"tower-http", "tower-http",

View file

@ -1,3 +1,7 @@
[workspace]
members = ["."]
resolver = "2"
[package] [package]
name = "zeroclaw" name = "zeroclaw"
version = "0.1.0" version = "0.1.0"
@ -31,7 +35,7 @@ shellexpand = "3.1"
# Logging - minimal # Logging - minimal
tracing = { version = "0.1", default-features = false } tracing = { version = "0.1", default-features = false }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] } tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
# Observability - Prometheus metrics # Observability - Prometheus metrics
prometheus = { version = "0.14", default-features = false } prometheus = { version = "0.14", default-features = false }
@ -63,12 +67,12 @@ rand = "0.8"
# Fast mutexes that don't poison on panic # Fast mutexes that don't poison on panic
parking_lot = "0.12" parking_lot = "0.12"
# Landlock (Linux sandbox) - optional dependency
landlock = { version = "0.4", optional = true }
# Async traits # Async traits
async-trait = "0.1" async-trait = "0.1"
# Protobuf encode/decode (Feishu WS long-connection frame codec)
prost = { version = "0.14", default-features = false }
# Memory / persistence # Memory / persistence
rusqlite = { version = "0.38", features = ["bundled"] } rusqlite = { version = "0.38", features = ["bundled"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] } chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
@ -86,6 +90,7 @@ glob = "0.3"
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] } tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
futures-util = { version = "0.3", default-features = false, features = ["sink"] } futures-util = { version = "0.3", default-features = false, features = ["sink"] }
futures = "0.3" futures = "0.3"
regex = "1.10"
hostname = "0.4.2" hostname = "0.4.2"
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] } lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
mail-parser = "0.11.2" mail-parser = "0.11.2"
@ -95,7 +100,7 @@ tokio-rustls = "0.26.4"
webpki-roots = "1.0.6" webpki-roots = "1.0.6"
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance # HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query"] } axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] }
tower = { version = "0.5", default-features = false } tower = { version = "0.5", default-features = false }
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] } tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
http-body-util = "0.1" http-body-util = "0.1"
@ -117,19 +122,28 @@ probe-rs = { version = "0.30", optional = true }
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf) # PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
pdf-extract = { version = "0.10", optional = true } pdf-extract = { version = "0.10", optional = true }
# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS # Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS
[target.'cfg(target_os = "linux")'.dependencies] [target.'cfg(target_os = "linux")'.dependencies]
rppal = { version = "0.14", optional = true } rppal = { version = "0.14", optional = true }
landlock = { version = "0.4", optional = true }
[features] [features]
default = ["hardware"] default = ["hardware"]
hardware = ["nusb", "tokio-serial"] hardware = ["nusb", "tokio-serial"]
peripheral-rpi = ["rppal"] peripheral-rpi = ["rppal"]
# Browser backend feature alias used by cfg(feature = "browser-native")
browser-native = ["dep:fantoccini"]
# Backward-compatible alias for older invocations
fantoccini = ["browser-native"]
# Sandbox feature aliases used by cfg(feature = "sandbox-*")
sandbox-landlock = ["dep:landlock"]
sandbox-bubblewrap = []
# Backward-compatible alias for older invocations
landlock = ["sandbox-landlock"]
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional) # probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
probe = ["dep:probe-rs"] probe = ["dep:probe-rs"]
# rag-pdf = PDF ingestion for datasheet RAG # rag-pdf = PDF ingestion for datasheet RAG
rag-pdf = ["dep:pdf-extract"] rag-pdf = ["dep:pdf-extract"]
[profile.release] [profile.release]
opt-level = "z" # Optimize for size opt-level = "z" # Optimize for size
lto = "thin" # Lower memory use during release builds lto = "thin" # Lower memory use during release builds

211
LICENSE
View file

@ -1,197 +1,28 @@
Apache License MIT License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION Copyright (c) 2025 ZeroClaw Labs
1. Definitions. Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
"License" shall mean the terms and conditions for use, reproduction, The above copyright notice and this permission notice shall be included in all
and distribution as defined by Sections 1 through 9 of this document. copies or substantial portions of the Software.
"Licensor" shall mean the copyright owner or entity authorized by THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
the copyright owner that is granting the License. IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"Legal Entity" shall mean the union of the acting entity and all ================================================================================
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity This product includes software developed by ZeroClaw Labs and contributors:
exercising permissions granted by this License. https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
"Source" form shall mean the preferred form for making modifications, See NOTICE file for full contributor attribution.
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to the Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Copyright 2025-2026 Argenis Delarosa
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
===============================================================================
This product includes software developed by ZeroClaw Labs and contributors:
https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
See NOTICE file for full contributor attribution.

View file

@ -10,14 +10,14 @@
</p> </p>
<p align="center"> <p align="center">
<a href="LICENSE"><img src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" alt="License: Apache 2.0" /></a> <a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a> <a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
</p> </p>
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything. Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
``` ```
~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything ~3.4MB binary · <10ms startup · 1,017 tests · 23+ providers · 8 traits · Pluggable everything
``` ```
### ✨ Features ### ✨ Features
@ -132,6 +132,9 @@ cd zeroclaw
cargo build --release --locked cargo build --release --locked
cargo install --path . --force --locked cargo install --path . --force --locked
# Ensure ~/.cargo/bin is in your PATH
export PATH="$HOME/.cargo/bin:$PATH"
# Quick setup (no prompts) # Quick setup (no prompts)
zeroclaw onboard --api-key sk-... --provider openrouter zeroclaw onboard --api-key sk-... --provider openrouter
@ -187,7 +190,7 @@ Every subsystem is a **trait** — swap implementations with a config change, ze
| Subsystem | Trait | Ships with | Extend | | Subsystem | Trait | Ships with | Extend |
|-----------|-------|------------|--------| |-----------|-------|------------|--------|
| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API | | **AI Models** | `Provider` | 23+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, Astrai, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API | | **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend | | **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend |
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability | | **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability |
@ -287,6 +290,21 @@ rerun channel setup only:
zeroclaw onboard --channels-only zeroclaw onboard --channels-only
``` ```
### Telegram media replies
Telegram routing now replies to the source **chat ID** from incoming updates (instead of usernames),
which avoids `Bad Request: chat not found` failures.
For non-text replies, ZeroClaw can send Telegram attachments when the assistant includes markers:
- `[IMAGE:<path-or-url>]`
- `[DOCUMENT:<path-or-url>]`
- `[VIDEO:<path-or-url>]`
- `[AUDIO:<path-or-url>]`
- `[VOICE:<path-or-url>]`
Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs.
### WhatsApp Business Cloud API Setup ### WhatsApp Business Cloud API Setup
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling): WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
@ -610,7 +628,7 @@ We're building in the open because the best ideas come from everywhere. If you'r
## License ## License
Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution MIT — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
## Contributing ## Contributing
@ -624,7 +642,6 @@ See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR:
- New `Tunnel``src/tunnel/` - New `Tunnel``src/tunnel/`
- New `Skill``~/.zeroclaw/workspace/skills/<name>/` - New `Skill``~/.zeroclaw/workspace/skills/<name>/`
--- ---
**ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀 **ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀

View file

@ -1,4 +1,4 @@
FROM ubuntu:22.04 FROM ubuntu:22.04@sha256:c7eb020043d8fc2ae0793fb35a37bff1cf33f156d4d4b12ccc7f3ef8706c38b1
# Prevent interactive prompts during package installation # Prevent interactive prompts during package installation
ENV DEBIAN_FRONTEND=noninteractive ENV DEBIAN_FRONTEND=noninteractive

View file

@ -27,7 +27,7 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
### Optional Repository Automation ### Optional Repository Automation
- `.github/workflows/labeler.yml` (`PR Labeler`) - `.github/workflows/labeler.yml` (`PR Labeler`)
- Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>:<component>`) - Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>: <component>`)
- Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule - Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule
- Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`) - Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`)
- Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`) - Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`)

View file

@ -244,7 +244,7 @@ Label discipline:
- Path labels identify subsystem ownership quickly. - Path labels identify subsystem ownership quickly.
- Size labels drive batching strategy. - Size labels drive batching strategy.
- Risk labels drive review depth (`risk: low/medium/high`). - Risk labels drive review depth (`risk: low/medium/high`).
- Module labels (`<module>:<component>`) improve reviewer routing for integration-specific changes and future newly-added modules. - Module labels (`<module>: <component>`) improve reviewer routing for integration-specific changes and future newly-added modules.
- `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context. - `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context.
- `no-stale` is reserved for accepted-but-blocked work. - `no-stale` is reserved for accepted-but-blocked work.

View file

@ -14,7 +14,7 @@ Use it to reduce review latency without reducing quality.
For every new PR, do a fast intake pass: For every new PR, do a fast intake pass:
1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`). 1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`).
2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel:*`/`provider:*`/`tool:*`, and contributor tier labels when applicable) are present and plausible. 2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel: *`/`provider: *`/`tool: *`, and contributor tier labels when applicable) are present and plausible.
3. Confirm CI signal status (`CI Required Gate`). 3. Confirm CI signal status (`CI Required Gate`).
4. Confirm scope is one concern (reject mixed mega-PRs unless justified). 4. Confirm scope is one concern (reject mixed mega-PRs unless justified).
5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied. 5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied.

View file

@ -12,6 +12,8 @@ use tokio::sync::mpsc;
pub struct ChannelMessage { pub struct ChannelMessage {
pub id: String, pub id: String,
pub sender: String, pub sender: String,
/// Channel-specific reply address (e.g. Telegram chat_id, Discord channel_id).
pub reply_to: String,
pub content: String, pub content: String,
pub channel: String, pub channel: String,
pub timestamp: u64, pub timestamp: u64,
@ -90,9 +92,12 @@ impl Channel for TelegramChannel {
continue; continue;
} }
let chat_id = msg["chat"]["id"].to_string();
let channel_msg = ChannelMessage { let channel_msg = ChannelMessage {
id: msg["message_id"].to_string(), id: msg["message_id"].to_string(),
sender, sender,
reply_to: chat_id,
content: msg["text"].as_str().unwrap_or("").to_string(), content: msg["text"].as_str().unwrap_or("").to_string(),
channel: "telegram".into(), channel: "telegram".into(),
timestamp: msg["date"].as_u64().unwrap_or(0), timestamp: msg["date"].as_u64().unwrap_or(0),

View file

@ -2,4 +2,10 @@
target = "riscv32imc-esp-espidf" target = "riscv32imc-esp-espidf"
[target.riscv32imc-esp-espidf] [target.riscv32imc-esp-espidf]
linker = "ldproxy"
runner = "espflash flash --monitor" runner = "espflash flash --monitor"
# ESP-IDF 5.x uses 64-bit time_t
rustflags = ["-C", "default-linker-libraries", "--cfg", "espidf_time64"]
[unstable]
build-std = ["std", "panic_abort"]

View file

@ -58,24 +58,22 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]] [[package]]
name = "bindgen" name = "bindgen"
version = "0.63.0" version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885" checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [ dependencies = [
"bitflags 1.3.2", "bitflags 2.11.0",
"cexpr", "cexpr",
"clang-sys", "clang-sys",
"lazy_static", "itertools",
"lazycell",
"log", "log",
"peeking_take_while", "prettyplease",
"proc-macro2", "proc-macro2",
"quote", "quote",
"regex", "regex",
"rustc-hash", "rustc-hash",
"shlex", "shlex",
"syn 1.0.109", "syn 2.0.116",
"which",
] ]
[[package]] [[package]]
@ -374,14 +372,15 @@ checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01"
[[package]] [[package]]
name = "embassy-sync" name = "embassy-sync"
version = "0.5.0" version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd938f25c0798db4280fcd8026bf4c2f48789aebf8f77b6e5cf8a7693ba114ec" checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b"
dependencies = [ dependencies = [
"cfg-if", "cfg-if",
"critical-section", "critical-section",
"embedded-io-async", "embedded-io-async",
"futures-util", "futures-core",
"futures-sink",
"heapless", "heapless",
] ]
@ -446,16 +445,15 @@ dependencies = [
[[package]] [[package]]
name = "embedded-svc" name = "embedded-svc"
version = "0.27.1" version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6f87e7654f28018340aa55f933803017aefabaa5417820a3b2f808033c7bbc" checksum = "a7770e30ab55cfbf954c00019522490d6ce26a3334bede05a732ba61010e98e0"
dependencies = [ dependencies = [
"defmt 0.3.100", "defmt 0.3.100",
"embedded-io", "embedded-io",
"embedded-io-async", "embedded-io-async",
"enumset", "enumset",
"heapless", "heapless",
"no-std-net",
"num_enum", "num_enum",
"serde", "serde",
"strum 0.25.0", "strum 0.25.0",
@ -463,9 +461,9 @@ dependencies = [
[[package]] [[package]]
name = "embuild" name = "embuild"
version = "0.31.4" version = "0.33.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4caa4f198bb9152a55c0103efb83fa4edfcbb8625f4c9e94ae8ec8e23827c563" checksum = "e188ad2bbe82afa841ea4a29880651e53ab86815db036b2cb9f8de3ac32dad75"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen", "bindgen",
@ -475,6 +473,7 @@ dependencies = [
"globwalk", "globwalk",
"home", "home",
"log", "log",
"regex",
"remove_dir_all", "remove_dir_all",
"serde", "serde",
"serde_json", "serde_json",
@ -533,9 +532,8 @@ dependencies = [
[[package]] [[package]]
name = "esp-idf-hal" name = "esp-idf-hal"
version = "0.43.1" version = "0.45.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/esp-rs/esp-idf-hal#bc48639bd626c72afc1e25e5d497b5c639161d30"
checksum = "f7adf3fb19a9ca016cbea1ab8a7b852ac69df8fcde4923c23d3b155efbc42a74"
dependencies = [ dependencies = [
"atomic-waker", "atomic-waker",
"embassy-sync", "embassy-sync",
@ -552,14 +550,12 @@ dependencies = [
"heapless", "heapless",
"log", "log",
"nb 1.1.0", "nb 1.1.0",
"num_enum",
] ]
[[package]] [[package]]
name = "esp-idf-svc" name = "esp-idf-svc"
version = "0.48.1" version = "0.51.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/esp-rs/esp-idf-svc#dee202f146c7681e54eabbf118a216fc0195d203"
checksum = "2180642ca122a7fec1ec417a9b1a77aa66aaa067fdf1daae683dd8caba84f26b"
dependencies = [ dependencies = [
"embassy-futures", "embassy-futures",
"embedded-hal-async", "embedded-hal-async",
@ -567,6 +563,7 @@ dependencies = [
"embuild", "embuild",
"enumset", "enumset",
"esp-idf-hal", "esp-idf-hal",
"futures-io",
"heapless", "heapless",
"log", "log",
"num_enum", "num_enum",
@ -575,14 +572,13 @@ dependencies = [
[[package]] [[package]]
name = "esp-idf-sys" name = "esp-idf-sys"
version = "0.34.1" version = "0.36.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/esp-rs/esp-idf-sys#64667a38fb8004e1fc3b032488af6857ca3cd849"
checksum = "2e148f97c04ed3e9181a08bcdc9560a515aad939b0ba7f50a0022e294665e0af"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"bindgen",
"build-time", "build-time",
"cargo_metadata", "cargo_metadata",
"cmake",
"const_format", "const_format",
"embuild", "embuild",
"envy", "envy",
@ -649,21 +645,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d" checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
[[package]] [[package]]
name = "futures-task" name = "futures-io"
version = "0.3.32" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393" checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]] [[package]]
name = "futures-util" name = "futures-sink"
version = "0.3.32" version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6" checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
dependencies = [
"futures-core",
"futures-task",
"pin-project-lite",
]
[[package]] [[package]]
name = "getrandom" name = "getrandom"
@ -827,6 +818,15 @@ dependencies = [
"serde_core", "serde_core",
] ]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]] [[package]]
name = "itoa" name = "itoa"
version = "1.0.17" version = "1.0.17"
@ -843,18 +843,6 @@ dependencies = [
"wasm-bindgen", "wasm-bindgen",
] ]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]] [[package]]
name = "leb128fmt" name = "leb128fmt"
version = "0.1.0" version = "0.1.0"
@ -945,12 +933,6 @@ dependencies = [
"libc", "libc",
] ]
[[package]]
name = "no-std-net"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bcece43b12349917e096cddfa66107277f123e6c96a5aea78711dc601a47152"
[[package]] [[package]]
name = "nom" name = "nom"
version = "7.1.3" version = "7.1.3"
@ -1007,18 +989,6 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]] [[package]]
name = "prettyplease" name = "prettyplease"
version = "0.2.37" version = "0.2.37"
@ -1138,9 +1108,9 @@ dependencies = [
[[package]] [[package]]
name = "rustc-hash" name = "rustc-hash"
version = "1.1.0" version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]] [[package]]
name = "rustix" name = "rustix"

View file

@ -14,15 +14,21 @@ edition = "2021"
license = "MIT" license = "MIT"
description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial" description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial"
[patch.crates-io]
# Use latest esp-rs crates to fix u8/i8 char pointer compatibility with ESP-IDF 5.x
esp-idf-sys = { git = "https://github.com/esp-rs/esp-idf-sys" }
esp-idf-hal = { git = "https://github.com/esp-rs/esp-idf-hal" }
esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
[dependencies] [dependencies]
esp-idf-svc = "0.48" esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
log = "0.4" log = "0.4"
anyhow = "1.0" anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] } serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
[build-dependencies] [build-dependencies]
embuild = "0.31" embuild = { version = "0.33", features = ["espidf"] }
[profile.release] [profile.release]
opt-level = "s" opt-level = "s"

View file

@ -2,8 +2,11 @@
Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial. Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial.
**New to this?** See [SETUP.md](SETUP.md) for step-by-step commands and troubleshooting.
## Protocol ## Protocol
- **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n` - **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n`
- **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n` - **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n`
@ -11,19 +14,44 @@ Commands: `gpio_read`, `gpio_write`.
## Prerequisites ## Prerequisites
1. **ESP toolchain** (espup): 1. **RISC-V ESP-IDF** (ESP32-C2/C3): Uses nightly Rust with `build-std`.
**Python**: ESP-IDF requires Python 3.103.13 (not 3.14). If you have Python 3.14:
```sh
brew install python@3.12
```
**virtualenv** (needed by ESP-IDF tools; PEP 668 workaround on macOS):
```sh
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
```
**Rust tools**:
```sh
cargo install espflash ldproxy
```
The project's `rust-toolchain.toml` pins nightly + rust-src. `esp-idf-sys` downloads ESP-IDF automatically on first build. Use Python 3.12 for the build:
```sh
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
```
2. **Xtensa targets** (ESP32, ESP32-S2, ESP32-S3): Use espup instead:
```sh ```sh
cargo install espup espflash cargo install espup espflash
espup install espup install
source ~/export-esp.sh # or ~/export-esp.fish for Fish source ~/export-esp.sh
``` ```
Then edit `.cargo/config.toml` to change the target (e.g. `xtensa-esp32-espidf`).
2. **Target**: ESP32-C3 (RISC-V) by default. Edit `.cargo/config.toml` for other targets (e.g. `xtensa-esp32-espidf` for original ESP32).
## Build & Flash ## Build & Flash
```sh ```sh
cd firmware/zeroclaw-esp32 cd firmware/zeroclaw-esp32
# Use Python 3.12 (required if you have 3.14)
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
# Optional: pin MCU (esp32c3 or esp32c2)
export MCU=esp32c3
cargo build --release cargo build --release
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
``` ```

View file

@ -0,0 +1,156 @@
# ESP32 Firmware Setup Guide
Step-by-step setup for building the ZeroClaw ESP32 firmware. Follow this if you run into issues.
## Quick Start (copy-paste)
```sh
# 1. Install Python 3.12 (ESP-IDF needs 3.10–3.13, not 3.14)
brew install python@3.12
# 2. Install virtualenv (PEP 668 workaround on macOS)
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
# 3. Install Rust tools
cargo install espflash ldproxy
# 4. Build
cd firmware/zeroclaw-esp32
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
cargo build --release
# 5. Flash (connect ESP32 via USB)
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
```
---
## Detailed Steps
### 1. Python
ESP-IDF requires Python 3.10–3.13. **Python 3.14 is not supported.**
```sh
brew install python@3.12
```
### 2. virtualenv
ESP-IDF tools need `virtualenv`. On macOS with Homebrew Python, PEP 668 blocks `pip install`; use:
```sh
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
```
### 3. Rust Tools
```sh
cargo install espflash ldproxy
```
- **espflash**: flash and monitor
- **ldproxy**: linker for ESP-IDF builds
### 4. Use Python 3.12 for Builds
Before every build (or add to `~/.zshrc`):
```sh
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
```
### 5. Build
```sh
cd firmware/zeroclaw-esp32
cargo build --release
```
First build downloads and compiles ESP-IDF (~5–15 min).
### 6. Flash
```sh
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
```
---
## Troubleshooting
### "No space left on device"
Free disk space. Common targets:
```sh
# Cargo cache (often 5–20 GB)
rm -rf ~/.cargo/registry/cache ~/.cargo/registry/src
# Unused Rust toolchains
rustup toolchain list
rustup toolchain uninstall <name>
# iOS Simulator runtimes (~3–5 GB)
xcrun simctl delete unavailable
# Temp files
rm -rf /var/folders/*/T/cargo-install*
```
### "can't find crate for `core`" / "riscv32imc-esp-espidf target may not be installed"
This project uses **nightly Rust with build-std**, not espup. Ensure:
- `rust-toolchain.toml` exists (pins nightly + rust-src)
- You are **not** sourcing `~/export-esp.sh` (that's for Xtensa targets)
- Run `cargo build` from `firmware/zeroclaw-esp32`
### "externally-managed-environment" / "No module named 'virtualenv'"
Install virtualenv with the PEP 668 workaround:
```sh
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
```
### "expected `i64`, found `i32`" (time_t mismatch)
Already fixed in `.cargo/config.toml` with `espidf_time64` for ESP-IDF 5.x. If you use ESP-IDF 4.4, switch to `espidf_time32`.
### "expected `*const u8`, found `*const i8`" (esp-idf-svc)
Already fixed via `[patch.crates-io]` in `Cargo.toml` using esp-rs crates from git. Do not remove the patch.
### 10,000+ files in `git status`
The `.embuild/` directory (ESP-IDF cache) has ~100k+ files. It is in `.gitignore`. If you see them, ensure `.gitignore` contains:
```
.embuild/
```
---
## Optional: Auto-load Python 3.12
Add to `~/.zshrc`:
```sh
# ESP32 firmware build
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
```
---
## Xtensa Targets (ESP32, ESP32-S2, ESP32-S3)
For non–RISC-V chips, use espup instead:
```sh
cargo install espup espflash
espup install
source ~/export-esp.sh
```
Then edit `.cargo/config.toml` to use `xtensa-esp32-espidf` (or the correct target).

View file

@ -0,0 +1,3 @@
# Pin nightly Rust with the rust-src component: this firmware builds the
# standard library from source (build-std in .cargo/config.toml), which
# requires a nightly toolchain and vendored std sources.
[toolchain]
channel = "nightly"
components = ["rust-src"]

View file

@ -6,8 +6,9 @@
//! Protocol: same as STM32 — see docs/hardware-peripherals-design.md //! Protocol: same as STM32 — see docs/hardware-peripherals-design.md
use esp_idf_svc::hal::gpio::PinDriver; use esp_idf_svc::hal::gpio::PinDriver;
use esp_idf_svc::hal::prelude::*; use esp_idf_svc::hal::peripherals::Peripherals;
use esp_idf_svc::hal::uart::*; use esp_idf_svc::hal::uart::{UartConfig, UartDriver};
use esp_idf_svc::hal::units::Hertz;
use log::info; use log::info;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@ -36,9 +37,13 @@ fn main() -> anyhow::Result<()> {
let peripherals = Peripherals::take()?; let peripherals = Peripherals::take()?;
let pins = peripherals.pins; let pins = peripherals.pins;
// Create GPIO output drivers first (they take ownership of pins)
let mut gpio2 = PinDriver::output(pins.gpio2)?;
let mut gpio13 = PinDriver::output(pins.gpio13)?;
// UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board // UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board
let config = UartConfig::new().baudrate(Hertz(115_200)); let config = UartConfig::new().baudrate(Hertz(115_200));
let mut uart = UartDriver::new( let uart = UartDriver::new(
peripherals.uart0, peripherals.uart0,
pins.gpio21, pins.gpio21,
pins.gpio20, pins.gpio20,
@ -60,7 +65,8 @@ fn main() -> anyhow::Result<()> {
if b == b'\n' { if b == b'\n' {
if !line.is_empty() { if !line.is_empty() {
if let Ok(line_str) = std::str::from_utf8(&line) { if let Ok(line_str) = std::str::from_utf8(&line) {
if let Ok(resp) = handle_request(line_str, &peripherals) { if let Ok(resp) = handle_request(line_str, &mut gpio2, &mut gpio13)
{
let out = serde_json::to_string(&resp).unwrap_or_default(); let out = serde_json::to_string(&resp).unwrap_or_default();
let _ = uart.write(format!("{}\n", out).as_bytes()); let _ = uart.write(format!("{}\n", out).as_bytes());
} }
@ -80,10 +86,15 @@ fn main() -> anyhow::Result<()> {
} }
} }
fn handle_request( fn handle_request<G2, G13>(
line: &str, line: &str,
peripherals: &esp_idf_svc::hal::peripherals::Peripherals, gpio2: &mut PinDriver<'_, G2>,
) -> anyhow::Result<Response> { gpio13: &mut PinDriver<'_, G13>,
) -> anyhow::Result<Response>
where
G2: esp_idf_svc::hal::gpio::OutputMode,
G13: esp_idf_svc::hal::gpio::OutputMode,
{
let req: Request = serde_json::from_str(line.trim())?; let req: Request = serde_json::from_str(line.trim())?;
let id = req.id.clone(); let id = req.id.clone();
@ -98,13 +109,13 @@ fn handle_request(
} }
"gpio_read" => { "gpio_read" => {
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32; let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
let value = gpio_read(peripherals, pin_num)?; let value = gpio_read(pin_num)?;
Ok(value.to_string()) Ok(value.to_string())
} }
"gpio_write" => { "gpio_write" => {
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32; let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0); let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0);
gpio_write(peripherals, pin_num, value)?; gpio_write(gpio2, gpio13, pin_num, value)?;
Ok("done".into()) Ok("done".into())
} }
_ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)), _ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)),
@ -126,28 +137,26 @@ fn handle_request(
} }
} }
fn gpio_read(_peripherals: &esp_idf_svc::hal::peripherals::Peripherals, _pin: i32) -> anyhow::Result<u8> { fn gpio_read(_pin: i32) -> anyhow::Result<u8> {
// TODO: implement input pin read — requires storing InputPin drivers per pin // TODO: implement input pin read — requires storing InputPin drivers per pin
Ok(0) Ok(0)
} }
fn gpio_write( fn gpio_write<G2, G13>(
peripherals: &esp_idf_svc::hal::peripherals::Peripherals, gpio2: &mut PinDriver<'_, G2>,
gpio13: &mut PinDriver<'_, G13>,
pin: i32, pin: i32,
value: u64, value: u64,
) -> anyhow::Result<()> { ) -> anyhow::Result<()>
let pins = peripherals.pins; where
let level = value != 0; G2: esp_idf_svc::hal::gpio::OutputMode,
G13: esp_idf_svc::hal::gpio::OutputMode,
{
let level = esp_idf_svc::hal::gpio::Level::from(value != 0);
match pin { match pin {
2 => { 2 => gpio2.set_level(level)?,
let mut out = PinDriver::output(pins.gpio2)?; 13 => gpio13.set_level(level)?,
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
}
13 => {
let mut out = PinDriver::output(pins.gpio13)?;
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
}
_ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin), _ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin),
} }
Ok(()) Ok(())

View file

@ -0,0 +1,324 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_NAME="$(basename "$0")"
usage() {
  # Print CLI help to stdout. The heredoc body is user-facing output —
  # keep its wording in sync with the option parsing below.
  cat <<USAGE
Recompute contributor tier labels for historical PRs/issues.
Usage:
  ./$SCRIPT_NAME [options]
Options:
  --repo <owner/repo> Target repository (default: current gh repo)
  --kind <both|prs|issues>
  Target objects (default: both)
  --state <all|open|closed>
  State filter for listing objects (default: all)
  --limit <N> Limit processed objects after fetch (default: 0 = no limit)
  --apply Apply label updates (default is dry-run)
  --dry-run Preview only (default)
  -h, --help Show this help
Examples:
  ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --limit 50
  ./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --kind prs --state open --apply
USAGE
}
die() {
  # Report a fatal error (tagged with the script name) on stderr and abort
  # the whole script with a non-zero status.
  printf '[%s] ERROR: %s\n' "$SCRIPT_NAME" "$*" >&2
  exit 1
}
require_cmd() {
  # Guard clause: abort via die unless the named command is on PATH.
  command -v "$1" >/dev/null 2>&1 || die "Required command not found: $1"
}
urlencode() {
  # Percent-encode an arbitrary string (e.g. a label name containing
  # spaces) for use in a URL path, delegating to jq's @uri filter.
  jq --null-input --raw-output --arg s "$1" '$s | @uri'
}
select_contributor_tier() {
  # Map a merged-PR count to its contributor tier label.
  # Thresholds are scanned highest-first; counts below 5 print an
  # empty line (meaning: no tier label applies).
  local count="$1" entry threshold label
  for entry in \
    "50:distinguished contributor" \
    "20:principal contributor" \
    "10:experienced contributor" \
    "5:trusted contributor"; do
    threshold="${entry%%:*}"
    label="${entry#*:}"
    if (( count >= threshold )); then
      echo "$label"
      return 0
    fi
  done
  echo ""
}
# ── Defaults ──────────────────────────────────────────────────
DRY_RUN=1
KIND="both"
STATE="all"
LIMIT=0
REPO=""

# ── Flag parsing ──────────────────────────────────────────────
# Value-taking options consume two positional args; toggles consume one.
while (($# > 0)); do
  opt="$1"
  case "$opt" in
    --repo | --kind | --state | --limit)
      [[ $# -ge 2 ]] || die "Missing value for $opt"
      case "$opt" in
        --repo)  REPO="$2" ;;
        --kind)  KIND="$2" ;;
        --state) STATE="$2" ;;
        --limit) LIMIT="$2" ;;
      esac
      shift 2
      ;;
    --apply)
      DRY_RUN=0
      shift
      ;;
    --dry-run)
      DRY_RUN=1
      shift
      ;;
    -h | --help)
      usage
      exit 0
      ;;
    *)
      die "Unknown option: $opt"
      ;;
  esac
done
# ── Validate parsed flags ─────────────────────────────────────
if [[ "$KIND" != "both" && "$KIND" != "prs" && "$KIND" != "issues" ]]; then
  die "--kind must be one of: both, prs, issues"
fi
if [[ "$STATE" != "all" && "$STATE" != "open" && "$STATE" != "closed" ]]; then
  die "--state must be one of: all, open, closed"
fi
[[ "$LIMIT" =~ ^[0-9]+$ ]] || die "--limit must be a non-negative integer"

# ── Environment checks ────────────────────────────────────────
require_cmd gh
require_cmd jq
gh auth status >/dev/null 2>&1 || die "gh CLI is not authenticated. Run: gh auth login"

# Infer the target repo from the current directory when --repo was omitted.
if [[ -z "$REPO" ]]; then
  REPO="$(gh repo view --json nameWithOwner --jq '.nameWithOwner' 2>/dev/null || true)"
  [[ -n "$REPO" ]] || die "Unable to infer repo. Pass --repo <owner/repo>."
fi

# ── Run banner ────────────────────────────────────────────────
if [[ "$DRY_RUN" -eq 1 ]]; then
  mode_label="dry-run"
else
  mode_label="apply"
fi
echo "[$SCRIPT_NAME] Repo: $REPO"
echo "[$SCRIPT_NAME] Mode: $mode_label"
echo "[$SCRIPT_NAME] Kind: $KIND | State: $STATE | Limit: $LIMIT"
# Canonical contributor tier labels (lowest → highest) as a JSON array,
# consumed by the jq invocations below.
TIERS_JSON='["trusted contributor","experienced contributor","principal contributor","distinguished contributor"]'

# Temp-file lifecycle.
#
# BUG FIX: new_tmp_file is only ever invoked inside command substitutions
# ("file=$(new_tmp_file)"), which run in a subshell — so appending each
# path to a parent-shell array there was lost, and the EXIT-trap cleanup
# never removed any temp file. Record created paths in a registry *file*
# instead: file writes survive the subshell, so cleanup can reliably
# delete everything on exit.
TMP_REGISTRY="$(mktemp)"
cleanup() {
  # Remove every temp file recorded in the registry, then the registry itself.
  if [[ -n "${TMP_REGISTRY:-}" && -f "$TMP_REGISTRY" ]]; then
    while IFS= read -r tmp_path; do
      if [[ -n "$tmp_path" ]]; then
        rm -f "$tmp_path"
      fi
    done < "$TMP_REGISTRY"
    rm -f "$TMP_REGISTRY"
  fi
}
trap cleanup EXIT
new_tmp_file() {
  # Create a temp file, record it for cleanup, and print its path.
  local tmp
  tmp="$(mktemp)"
  printf '%s\n' "$tmp" >> "$TMP_REGISTRY"
  echo "$tmp"
}
# Collect target PRs/issues as one compact JSON object per line.
targets_file="$(new_tmp_file)"

# PR listing (the pulls endpoint returns pull requests only).
if [[ "$KIND" == "both" || "$KIND" == "prs" ]]; then
  gh api --paginate "repos/$REPO/pulls?state=$STATE&per_page=100" \
    --jq '.[] | {
      kind: "pr",
      number: .number,
      author: (.user.login // ""),
      author_type: (.user.type // ""),
      labels: [(.labels[]?.name // empty)]
    }' >> "$targets_file"
fi

# Issue listing; the select(.pull_request | not) filter drops PRs that the
# issues endpoint also reports.
if [[ "$KIND" == "both" || "$KIND" == "issues" ]]; then
  gh api --paginate "repos/$REPO/issues?state=$STATE&per_page=100" \
    --jq '.[] | select(.pull_request | not) | {
      kind: "issue",
      number: .number,
      author: (.user.login // ""),
      author_type: (.user.type // ""),
      labels: [(.labels[]?.name // empty)]
    }' >> "$targets_file"
fi

# Apply --limit after fetching: keep only the first N collected lines.
if [[ "$LIMIT" -gt 0 ]]; then
  limited_file="$(new_tmp_file)"
  head -n "$LIMIT" "$targets_file" > "$limited_file"
  mv "$limited_file" "$targets_file"
fi

# One JSON object per line, so line count == target count.
target_count="$(wc -l < "$targets_file" | tr -d ' ')"
if [[ "$target_count" -eq 0 ]]; then
  echo "[$SCRIPT_NAME] No targets found."
  exit 0
fi
echo "[$SCRIPT_NAME] Targets fetched: $target_count"
# Ensure tier labels exist (trusted contributor might be new).
# Probe existing tier labels to reuse their color so any label we create
# matches visually; fall back to a fixed default when none is found.
label_color=""
for probe_label in "experienced contributor" "principal contributor" "distinguished contributor" "trusted contributor"; do
  encoded_label="$(urlencode "$probe_label")"
  # NOTE(review): the `|| true` makes this `if` condition always succeed;
  # the inner -n test is what actually filters out missing labels.
  if color_candidate="$(gh api "repos/$REPO/labels/$encoded_label" --jq '.color' 2>/dev/null || true)"; then
    if [[ -n "$color_candidate" ]]; then
      # Normalize hex color to uppercase.
      label_color="$(echo "$color_candidate" | tr '[:lower:]' '[:upper:]')"
      break
    fi
  fi
done
[[ -n "$label_color" ]] || label_color="C5D7A2"

# Create any tier label that does not yet exist in the repo.
while IFS= read -r tier_label; do
  [[ -n "$tier_label" ]] || continue
  encoded_label="$(urlencode "$tier_label")"
  # Label already present — nothing to do.
  if gh api "repos/$REPO/labels/$encoded_label" >/dev/null 2>&1; then
    continue
  fi
  if [[ "$DRY_RUN" -eq 1 ]]; then
    echo "[dry-run] Would create missing label: $tier_label (color=$label_color)"
  else
    gh api -X POST "repos/$REPO/labels" \
      -f name="$tier_label" \
      -f color="$label_color" >/dev/null
    echo "[apply] Created missing label: $tier_label"
  fi
done < <(jq -r '.[]' <<<"$TIERS_JSON")
# Build merged PR count cache by unique human authors.
authors_file="$(new_tmp_file)"
jq -r 'select(.author != "" and .author_type != "Bot") | .author' "$targets_file" | sort -u > "$authors_file"
author_count="$(wc -l < "$authors_file" | tr -d ' ')"
echo "[$SCRIPT_NAME] Unique human authors: $author_count"
author_counts_file="$(new_tmp_file)"
while IFS= read -r author; do
[[ -n "$author" ]] || continue
query="repo:$REPO is:pr is:merged author:$author"
merged_count="$(gh api search/issues -f q="$query" -F per_page=1 --jq '.total_count' 2>/dev/null || true)"
if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
merged_count=0
fi
printf '%s\t%s\n' "$author" "$merged_count" >> "$author_counts_file"
done < "$authors_file"
# Counters for the final summary.
updated=0
unchanged=0
skipped=0
failed=0

# Main pass: recompute the tier label for each target and sync when changed.
while IFS= read -r target_json; do
  [[ -n "$target_json" ]] || continue
  number="$(jq -r '.number' <<<"$target_json")"
  kind="$(jq -r '.kind' <<<"$target_json")"
  author="$(jq -r '.author' <<<"$target_json")"
  author_type="$(jq -r '.author_type' <<<"$target_json")"
  current_labels_json="$(jq -c '.labels // []' <<<"$target_json")"

  # Bots and authorless objects never receive tier labels.
  if [[ -z "$author" || "$author_type" == "Bot" ]]; then
    skipped=$((skipped + 1))
    continue
  fi

  # Look up the cached merged-PR count (tab-separated author<TAB>count file);
  # a missing or malformed entry counts as zero.
  merged_count="$(awk -F '\t' -v key="$author" '$1 == key { print $2; exit }' "$author_counts_file")"
  if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
    merged_count=0
  fi
  desired_tier="$(select_contributor_tier "$merged_count")"

  # Current tier = first label on the object that appears in TIERS_JSON.
  if ! current_tier="$(jq -r --argjson tiers "$TIERS_JSON" '[.[] | select(. as $label | ($tiers | index($label)) != null)][0] // ""' <<<"$current_labels_json" 2>/dev/null)"; then
    echo "[warn] Skipping ${kind} #${number}: cannot parse current labels JSON" >&2
    failed=$((failed + 1))
    continue
  fi

  # Next label set: strip all tier labels, then add the desired one (if any).
  if ! next_labels_json="$(jq -c --arg desired "$desired_tier" --argjson tiers "$TIERS_JSON" '
    (. // [])
    | map(select(. as $label | ($tiers | index($label)) == null))
    | if $desired != "" then . + [$desired] else . end
    | unique
  ' <<<"$current_labels_json" 2>/dev/null)"; then
    echo "[warn] Skipping ${kind} #${number}: cannot compute next labels" >&2
    failed=$((failed + 1))
    continue
  fi

  # Normalize both sides (deduped + sorted) for an order-insensitive compare.
  if ! normalized_current="$(jq -c 'unique | sort' <<<"$current_labels_json" 2>/dev/null)"; then
    echo "[warn] Skipping ${kind} #${number}: cannot normalize current labels" >&2
    failed=$((failed + 1))
    continue
  fi
  if ! normalized_next="$(jq -c 'unique | sort' <<<"$next_labels_json" 2>/dev/null)"; then
    echo "[warn] Skipping ${kind} #${number}: cannot normalize next labels" >&2
    failed=$((failed + 1))
    continue
  fi
  # No effective change — nothing to write.
  if [[ "$normalized_current" == "$normalized_next" ]]; then
    unchanged=$((unchanged + 1))
    continue
  fi

  if [[ "$DRY_RUN" -eq 1 ]]; then
    echo "[dry-run] ${kind} #${number} @${author} merged=${merged_count} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
    updated=$((updated + 1))
    continue
  fi

  # PUT with a full labels array replaces the object's label set;
  # the issues endpoint is used for both PRs and issues here.
  payload="$(jq -cn --argjson labels "$next_labels_json" '{labels: $labels}')"
  if gh api -X PUT "repos/$REPO/issues/$number/labels" --input - <<<"$payload" >/dev/null; then
    echo "[apply] Updated ${kind} #${number} @${author} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
    updated=$((updated + 1))
  else
    echo "[apply] FAILED ${kind} #${number}" >&2
    failed=$((failed + 1))
  fi
done < "$targets_file"

# ── Summary ───────────────────────────────────────────────────
echo ""
echo "[$SCRIPT_NAME] Summary"
echo " Targets: $target_count"
echo " Updated: $updated"
echo " Unchanged: $unchanged"
echo " Skipped: $skipped"
echo " Failed: $failed"

# Non-zero exit when any per-object operation failed.
if [[ "$failed" -gt 0 ]]; then
  exit 1
fi

View file

@ -251,6 +251,7 @@ impl Agent {
let provider: Box<dyn Provider> = providers::create_routed_provider( let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name, provider_name,
config.api_key.as_deref(), config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability, &config.reliability,
&config.model_routes, &config.model_routes,
&model_name, &model_name,
@ -388,7 +389,7 @@ impl Agent {
if self.auto_save { if self.auto_save {
let _ = self let _ = self
.memory .memory
.store("user_msg", user_message, MemoryCategory::Conversation) .store("user_msg", user_message, MemoryCategory::Conversation, None)
.await; .await;
} }
@ -447,7 +448,7 @@ impl Agent {
let summary = truncate_with_ellipsis(&final_text, 100); let summary = truncate_with_ellipsis(&final_text, 100);
let _ = self let _ = self
.memory .memory
.store("assistant_resp", &summary, MemoryCategory::Daily) .store("assistant_resp", &summary, MemoryCategory::Daily, None)
.await; .await;
} }
@ -557,6 +558,7 @@ pub async fn run(
agent.observer.record_event(&ObserverEvent::AgentEnd { agent.observer.record_event(&ObserverEvent::AgentEnd {
duration: start.elapsed(), duration: start.elapsed(),
tokens_used: None, tokens_used: None,
cost_usd: None,
}); });
Ok(()) Ok(())

View file

@ -7,14 +7,70 @@ use crate::security::SecurityPolicy;
use crate::tools::{self, Tool}; use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis; use crate::util::truncate_with_ellipsis;
use anyhow::Result; use anyhow::Result;
use regex::{Regex, RegexSet};
use std::fmt::Write; use std::fmt::Write;
use std::io::Write as _; use std::io::Write as _;
use std::sync::Arc; use std::sync::{Arc, LazyLock};
use std::time::Instant; use std::time::Instant;
use uuid::Uuid; use uuid::Uuid;
/// Maximum agentic tool-use iterations per user message to prevent runaway loops. /// Maximum agentic tool-use iterations per user message to prevent runaway loops.
const MAX_TOOL_ITERATIONS: usize = 10; const MAX_TOOL_ITERATIONS: usize = 10;
static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
RegexSet::new([
r"(?i)token",
r"(?i)api[_-]?key",
r"(?i)password",
r"(?i)secret",
r"(?i)user[_-]?key",
r"(?i)bearer",
r"(?i)credential",
])
.unwrap()
});
static SENSITIVE_KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap()
});
/// Scrub credentials from tool output to prevent accidental exfiltration.
/// Replaces known credential patterns with a redacted placeholder while preserving
/// a small prefix for context.
fn scrub_credentials(input: &str) -> String {
SENSITIVE_KV_REGEX
.replace_all(input, |caps: &regex::Captures| {
let full_match = &caps[0];
let key = &caps[1];
let val = caps
.get(2)
.or(caps.get(3))
.or(caps.get(4))
.map(|m| m.as_str())
.unwrap_or("");
// Preserve first 4 chars for context, then redact
let prefix = if val.len() > 4 { &val[..4] } else { "" };
if full_match.contains(':') {
if full_match.contains('"') {
format!("\"{}\": \"{}*[REDACTED]\"", key, prefix)
} else {
format!("{}: {}*[REDACTED]", key, prefix)
}
} else if full_match.contains('=') {
if full_match.contains('"') {
format!("{}=\"{}*[REDACTED]\"", key, prefix)
} else {
format!("{}={}*[REDACTED]", key, prefix)
}
} else {
format!("{}: {}*[REDACTED]", key, prefix)
}
})
.to_string()
}
/// Trigger auto-compaction when non-system message count exceeds this threshold. /// Trigger auto-compaction when non-system message count exceeds this threshold.
const MAX_HISTORY_MESSAGES: usize = 50; const MAX_HISTORY_MESSAGES: usize = 50;
@ -145,7 +201,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new(); let mut context = String::new();
// Pull relevant memories for this message // Pull relevant memories for this message
if let Ok(entries) = mem.recall(user_msg, 5).await { if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if !entries.is_empty() { if !entries.is_empty() {
context.push_str("[Memory context]\n"); context.push_str("[Memory context]\n");
for entry in &entries { for entry in &entries {
@ -436,6 +492,7 @@ struct ParsedToolCall {
/// Execute a single turn of the agent loop: send messages, parse tool calls, /// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response. /// execute tools, and loop until the LLM produces a final text response.
/// When `silent` is true, suppresses stdout (for channel use). /// When `silent` is true, suppresses stdout (for channel use).
#[allow(clippy::too_many_arguments)]
pub(crate) async fn agent_turn( pub(crate) async fn agent_turn(
provider: &dyn Provider, provider: &dyn Provider,
history: &mut Vec<ChatMessage>, history: &mut Vec<ChatMessage>,
@ -461,6 +518,7 @@ pub(crate) async fn agent_turn(
/// Execute a single turn of the agent loop: send messages, parse tool calls, /// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response. /// execute tools, and loop until the LLM produces a final text response.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn run_tool_call_loop( pub(crate) async fn run_tool_call_loop(
provider: &dyn Provider, provider: &dyn Provider,
history: &mut Vec<ChatMessage>, history: &mut Vec<ChatMessage>,
@ -606,7 +664,7 @@ pub(crate) async fn run_tool_call_loop(
success: r.success, success: r.success,
}); });
if r.success { if r.success {
r.output scrub_credentials(&r.output)
} else { } else {
format!("Error: {}", r.error.unwrap_or_else(|| r.output)) format!("Error: {}", r.error.unwrap_or_else(|| r.output))
} }
@ -749,6 +807,7 @@ pub async fn run(
let provider: Box<dyn Provider> = providers::create_routed_provider( let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name, provider_name,
config.api_key.as_deref(), config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability, &config.reliability,
&config.model_routes, &config.model_routes,
model_name, model_name,
@ -912,7 +971,7 @@ pub async fn run(
if config.memory.auto_save { if config.memory.auto_save {
let user_key = autosave_memory_key("user_msg"); let user_key = autosave_memory_key("user_msg");
let _ = mem let _ = mem
.store(&user_key, &msg, MemoryCategory::Conversation) .store(&user_key, &msg, MemoryCategory::Conversation, None)
.await; .await;
} }
@ -955,7 +1014,7 @@ pub async fn run(
let summary = truncate_with_ellipsis(&response, 100); let summary = truncate_with_ellipsis(&response, 100);
let response_key = autosave_memory_key("assistant_resp"); let response_key = autosave_memory_key("assistant_resp");
let _ = mem let _ = mem
.store(&response_key, &summary, MemoryCategory::Daily) .store(&response_key, &summary, MemoryCategory::Daily, None)
.await; .await;
} }
} else { } else {
@ -978,7 +1037,7 @@ pub async fn run(
if config.memory.auto_save { if config.memory.auto_save {
let user_key = autosave_memory_key("user_msg"); let user_key = autosave_memory_key("user_msg");
let _ = mem let _ = mem
.store(&user_key, &msg.content, MemoryCategory::Conversation) .store(&user_key, &msg.content, MemoryCategory::Conversation, None)
.await; .await;
} }
@ -1036,7 +1095,7 @@ pub async fn run(
let summary = truncate_with_ellipsis(&response, 100); let summary = truncate_with_ellipsis(&response, 100);
let response_key = autosave_memory_key("assistant_resp"); let response_key = autosave_memory_key("assistant_resp");
let _ = mem let _ = mem
.store(&response_key, &summary, MemoryCategory::Daily) .store(&response_key, &summary, MemoryCategory::Daily, None)
.await; .await;
} }
} }
@ -1048,6 +1107,7 @@ pub async fn run(
observer.record_event(&ObserverEvent::AgentEnd { observer.record_event(&ObserverEvent::AgentEnd {
duration, duration,
tokens_used: None, tokens_used: None,
cost_usd: None,
}); });
Ok(final_output) Ok(final_output)
@ -1104,6 +1164,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
let provider: Box<dyn Provider> = providers::create_routed_provider( let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name, provider_name,
config.api_key.as_deref(), config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability, &config.reliability,
&config.model_routes, &config.model_routes,
&model_name, &model_name,
@ -1217,6 +1278,25 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn test_scrub_credentials() {
let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
let scrubbed = scrub_credentials(input);
assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
assert!(scrubbed.contains("token: 1234*[REDACTED]"));
assert!(scrubbed.contains("password=\"secr*[REDACTED]\""));
assert!(!scrubbed.contains("abcdef"));
assert!(!scrubbed.contains("secret123456"));
}
#[test]
fn test_scrub_credentials_json() {
let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
let scrubbed = scrub_credentials(input);
assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
assert!(scrubbed.contains("public"));
}
use crate::memory::{Memory, MemoryCategory, SqliteMemory}; use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use tempfile::TempDir; use tempfile::TempDir;
@ -1496,16 +1576,16 @@ I will now call the tool with this payload:
let key1 = autosave_memory_key("user_msg"); let key1 = autosave_memory_key("user_msg");
let key2 = autosave_memory_key("user_msg"); let key2 = autosave_memory_key("user_msg");
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation) mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();
mem.store(&key2, "I'm 45", MemoryCategory::Conversation) mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();
assert_eq!(mem.count().await.unwrap(), 2); assert_eq!(mem.count().await.unwrap(), 2);
let recalled = mem.recall("45", 5).await.unwrap(); let recalled = mem.recall("45", 5, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45"))); assert!(recalled.iter().any(|entry| entry.content.contains("45")));
} }

View file

@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader {
memory: &dyn Memory, memory: &dyn Memory,
user_message: &str, user_message: &str,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let entries = memory.recall(user_message, self.limit).await?; let entries = memory.recall(user_message, self.limit, None).await?;
if entries.is_empty() { if entries.is_empty() {
return Ok(String::new()); return Ok(String::new());
} }
@ -61,11 +61,17 @@ mod tests {
_key: &str, _key: &str,
_content: &str, _content: &str,
_category: MemoryCategory, _category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> { async fn recall(
&self,
_query: &str,
limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
if limit == 0 { if limit == 0 {
return Ok(vec![]); return Ok(vec![]);
} }
@ -87,6 +93,7 @@ mod tests {
async fn list( async fn list(
&self, &self,
_category: Option<&MemoryCategory>, _category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> { ) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(vec![]) Ok(vec![])
} }

View file

@ -40,6 +40,7 @@ impl Channel for CliChannel {
let msg = ChannelMessage { let msg = ChannelMessage {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
sender: "user".to_string(), sender: "user".to_string(),
reply_target: "user".to_string(),
content: line, content: line,
channel: "cli".to_string(), channel: "cli".to_string(),
timestamp: std::time::SystemTime::now() timestamp: std::time::SystemTime::now()
@ -90,12 +91,14 @@ mod tests {
let msg = ChannelMessage { let msg = ChannelMessage {
id: "test-id".into(), id: "test-id".into(),
sender: "user".into(), sender: "user".into(),
reply_target: "user".into(),
content: "hello".into(), content: "hello".into(),
channel: "cli".into(), channel: "cli".into(),
timestamp: 1_234_567_890, timestamp: 1_234_567_890,
}; };
assert_eq!(msg.id, "test-id"); assert_eq!(msg.id, "test-id");
assert_eq!(msg.sender, "user"); assert_eq!(msg.sender, "user");
assert_eq!(msg.reply_target, "user");
assert_eq!(msg.content, "hello"); assert_eq!(msg.content, "hello");
assert_eq!(msg.channel, "cli"); assert_eq!(msg.channel, "cli");
assert_eq!(msg.timestamp, 1_234_567_890); assert_eq!(msg.timestamp, 1_234_567_890);
@ -106,6 +109,7 @@ mod tests {
let msg = ChannelMessage { let msg = ChannelMessage {
id: "id".into(), id: "id".into(),
sender: "s".into(), sender: "s".into(),
reply_target: "s".into(),
content: "c".into(), content: "c".into(),
channel: "ch".into(), channel: "ch".into(),
timestamp: 0, timestamp: 0,

View file

@ -7,7 +7,7 @@ use tokio::sync::RwLock;
use tokio_tungstenite::tungstenite::Message; use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid; use uuid::Uuid;
/// DingTalk (钉钉) channel — connects via Stream Mode WebSocket for real-time messages. /// DingTalk channel — connects via Stream Mode WebSocket for real-time messages.
/// Replies are sent through per-message session webhook URLs. /// Replies are sent through per-message session webhook URLs.
pub struct DingTalkChannel { pub struct DingTalkChannel {
client_id: String, client_id: String,
@ -64,6 +64,18 @@ impl DingTalkChannel {
let gw: GatewayResponse = resp.json().await?; let gw: GatewayResponse = resp.json().await?;
Ok(gw) Ok(gw)
} }
fn resolve_reply_target(
sender_id: &str,
conversation_type: &str,
conversation_id: Option<&str>,
) -> String {
if conversation_type == "1" {
sender_id.to_string()
} else {
conversation_id.unwrap_or(sender_id).to_string()
}
}
} }
#[async_trait] #[async_trait]
@ -193,14 +205,11 @@ impl Channel for DingTalkChannel {
.unwrap_or("1"); .unwrap_or("1");
// Private chat uses sender ID, group chat uses conversation ID // Private chat uses sender ID, group chat uses conversation ID
let chat_id = if conversation_type == "1" { let chat_id = Self::resolve_reply_target(
sender_id.to_string() sender_id,
} else { conversation_type,
data.get("conversationId") data.get("conversationId").and_then(|c| c.as_str()),
.and_then(|c| c.as_str()) );
.unwrap_or(sender_id)
.to_string()
};
// Store session webhook for later replies // Store session webhook for later replies
if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) { if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
@ -229,6 +238,7 @@ impl Channel for DingTalkChannel {
let channel_msg = ChannelMessage { let channel_msg = ChannelMessage {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
sender: sender_id.to_string(), sender: sender_id.to_string(),
reply_target: chat_id,
content: content.to_string(), content: content.to_string(),
channel: "dingtalk".to_string(), channel: "dingtalk".to_string(),
timestamp: std::time::SystemTime::now() timestamp: std::time::SystemTime::now()
@ -305,4 +315,22 @@ client_secret = "secret"
let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap(); let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
assert!(config.allowed_users.is_empty()); assert!(config.allowed_users.is_empty());
} }
#[test]
fn test_resolve_reply_target_private_chat_uses_sender_id() {
let target = DingTalkChannel::resolve_reply_target("staff_1", "1", Some("conv_1"));
assert_eq!(target, "staff_1");
}
#[test]
fn test_resolve_reply_target_group_chat_uses_conversation_id() {
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", Some("conv_1"));
assert_eq!(target, "conv_1");
}
#[test]
fn test_resolve_reply_target_group_chat_falls_back_to_sender_id() {
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", None);
assert_eq!(target, "staff_1");
}
} }

View file

@ -11,6 +11,7 @@ pub struct DiscordChannel {
guild_id: Option<String>, guild_id: Option<String>,
allowed_users: Vec<String>, allowed_users: Vec<String>,
listen_to_bots: bool, listen_to_bots: bool,
mention_only: bool,
client: reqwest::Client, client: reqwest::Client,
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>, typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
} }
@ -21,12 +22,14 @@ impl DiscordChannel {
guild_id: Option<String>, guild_id: Option<String>,
allowed_users: Vec<String>, allowed_users: Vec<String>,
listen_to_bots: bool, listen_to_bots: bool,
mention_only: bool,
) -> Self { ) -> Self {
Self { Self {
bot_token, bot_token,
guild_id, guild_id,
allowed_users, allowed_users,
listen_to_bots, listen_to_bots,
mention_only,
client: reqwest::Client::new(), client: reqwest::Client::new(),
typing_handle: std::sync::Mutex::new(None), typing_handle: std::sync::Mutex::new(None),
} }
@ -343,6 +346,22 @@ impl Channel for DiscordChannel {
continue; continue;
} }
// Skip messages that don't @-mention the bot (when mention_only is enabled)
if self.mention_only {
let mention_tag = format!("<@{bot_user_id}>");
if !content.contains(&mention_tag) {
continue;
}
}
// Strip the bot mention from content so the agent sees clean text
let clean_content = if self.mention_only {
let mention_tag = format!("<@{bot_user_id}>");
content.replace(&mention_tag, "").trim().to_string()
} else {
content.to_string()
};
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or(""); let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string(); let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
@ -353,6 +372,11 @@ impl Channel for DiscordChannel {
format!("discord_{message_id}") format!("discord_{message_id}")
}, },
sender: author_id.to_string(), sender: author_id.to_string(),
reply_target: if channel_id.is_empty() {
author_id.to_string()
} else {
channel_id
},
content: content.to_string(), content: content.to_string(),
channel: channel_id, channel: channel_id,
timestamp: std::time::SystemTime::now() timestamp: std::time::SystemTime::now()
@ -423,7 +447,7 @@ mod tests {
#[test] #[test]
fn discord_channel_name() { fn discord_channel_name() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false); let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert_eq!(ch.name(), "discord"); assert_eq!(ch.name(), "discord");
} }
@ -444,21 +468,27 @@ mod tests {
#[test] #[test]
fn empty_allowlist_denies_everyone() { fn empty_allowlist_denies_everyone() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false); let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert!(!ch.is_user_allowed("12345")); assert!(!ch.is_user_allowed("12345"));
assert!(!ch.is_user_allowed("anyone")); assert!(!ch.is_user_allowed("anyone"));
} }
#[test] #[test]
fn wildcard_allows_everyone() { fn wildcard_allows_everyone() {
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false); let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false);
assert!(ch.is_user_allowed("12345")); assert!(ch.is_user_allowed("12345"));
assert!(ch.is_user_allowed("anyone")); assert!(ch.is_user_allowed("anyone"));
} }
#[test] #[test]
fn specific_allowlist_filters() { fn specific_allowlist_filters() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false); let ch = DiscordChannel::new(
"fake".into(),
None,
vec!["111".into(), "222".into()],
false,
false,
);
assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("111"));
assert!(ch.is_user_allowed("222")); assert!(ch.is_user_allowed("222"));
assert!(!ch.is_user_allowed("333")); assert!(!ch.is_user_allowed("333"));
@ -467,7 +497,7 @@ mod tests {
#[test] #[test]
fn allowlist_is_exact_match_not_substring() { fn allowlist_is_exact_match_not_substring() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false); let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
assert!(!ch.is_user_allowed("1111")); assert!(!ch.is_user_allowed("1111"));
assert!(!ch.is_user_allowed("11")); assert!(!ch.is_user_allowed("11"));
assert!(!ch.is_user_allowed("0111")); assert!(!ch.is_user_allowed("0111"));
@ -475,20 +505,26 @@ mod tests {
#[test] #[test]
fn allowlist_empty_string_user_id() { fn allowlist_empty_string_user_id() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false); let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
assert!(!ch.is_user_allowed("")); assert!(!ch.is_user_allowed(""));
} }
#[test] #[test]
fn allowlist_with_wildcard_and_specific() { fn allowlist_with_wildcard_and_specific() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false); let ch = DiscordChannel::new(
"fake".into(),
None,
vec!["111".into(), "*".into()],
false,
false,
);
assert!(ch.is_user_allowed("111")); assert!(ch.is_user_allowed("111"));
assert!(ch.is_user_allowed("anyone_else")); assert!(ch.is_user_allowed("anyone_else"));
} }
#[test] #[test]
fn allowlist_case_sensitive() { fn allowlist_case_sensitive() {
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false); let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false);
assert!(ch.is_user_allowed("ABC")); assert!(ch.is_user_allowed("ABC"));
assert!(!ch.is_user_allowed("abc")); assert!(!ch.is_user_allowed("abc"));
assert!(!ch.is_user_allowed("Abc")); assert!(!ch.is_user_allowed("Abc"));
@ -663,14 +699,14 @@ mod tests {
#[test] #[test]
fn typing_handle_starts_as_none() { fn typing_handle_starts_as_none() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false); let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let guard = ch.typing_handle.lock().unwrap(); let guard = ch.typing_handle.lock().unwrap();
assert!(guard.is_none()); assert!(guard.is_none());
} }
#[tokio::test] #[tokio::test]
async fn start_typing_sets_handle() { async fn start_typing_sets_handle() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false); let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("123456").await; let _ = ch.start_typing("123456").await;
let guard = ch.typing_handle.lock().unwrap(); let guard = ch.typing_handle.lock().unwrap();
assert!(guard.is_some()); assert!(guard.is_some());
@ -678,7 +714,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn stop_typing_clears_handle() { async fn stop_typing_clears_handle() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false); let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("123456").await; let _ = ch.start_typing("123456").await;
let _ = ch.stop_typing("123456").await; let _ = ch.stop_typing("123456").await;
let guard = ch.typing_handle.lock().unwrap(); let guard = ch.typing_handle.lock().unwrap();
@ -687,14 +723,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn stop_typing_is_idempotent() { async fn stop_typing_is_idempotent() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false); let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert!(ch.stop_typing("123456").await.is_ok()); assert!(ch.stop_typing("123456").await.is_ok());
assert!(ch.stop_typing("123456").await.is_ok()); assert!(ch.stop_typing("123456").await.is_ok());
} }
#[tokio::test] #[tokio::test]
async fn start_typing_replaces_existing_task() { async fn start_typing_replaces_existing_task() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false); let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("111").await; let _ = ch.start_typing("111").await;
let _ = ch.start_typing("222").await; let _ = ch.start_typing("222").await;
let guard = ch.typing_handle.lock().unwrap(); let guard = ch.typing_handle.lock().unwrap();

View file

@ -10,6 +10,7 @@
use anyhow::{anyhow, Result}; use anyhow::{anyhow, Result};
use async_trait::async_trait; use async_trait::async_trait;
use lettre::message::SinglePart;
use lettre::transport::smtp::authentication::Credentials; use lettre::transport::smtp::authentication::Credentials;
use lettre::{Message, SmtpTransport, Transport}; use lettre::{Message, SmtpTransport, Transport};
use mail_parser::{MessageParser, MimeHeaders}; use mail_parser::{MessageParser, MimeHeaders};
@ -39,7 +40,7 @@ pub struct EmailConfig {
pub imap_folder: String, pub imap_folder: String,
/// SMTP server hostname /// SMTP server hostname
pub smtp_host: String, pub smtp_host: String,
/// SMTP server port (default: 587 for STARTTLS) /// SMTP server port (default: 465 for TLS)
#[serde(default = "default_smtp_port")] #[serde(default = "default_smtp_port")]
pub smtp_port: u16, pub smtp_port: u16,
/// Use TLS for SMTP (default: true) /// Use TLS for SMTP (default: true)
@ -63,7 +64,7 @@ fn default_imap_port() -> u16 {
993 993
} }
fn default_smtp_port() -> u16 { fn default_smtp_port() -> u16 {
587 465
} }
fn default_imap_folder() -> String { fn default_imap_folder() -> String {
"INBOX".into() "INBOX".into()
@ -389,7 +390,7 @@ impl Channel for EmailChannel {
.from(self.config.from_address.parse()?) .from(self.config.from_address.parse()?)
.to(recipient.parse()?) .to(recipient.parse()?)
.subject(subject) .subject(subject)
.body(body.to_string())?; .singlepart(SinglePart::plain(body.to_string()))?;
let transport = self.create_smtp_transport()?; let transport = self.create_smtp_transport()?;
transport.send(&email)?; transport.send(&email)?;
@ -427,6 +428,7 @@ impl Channel for EmailChannel {
} // MutexGuard dropped before await } // MutexGuard dropped before await
let msg = ChannelMessage { let msg = ChannelMessage {
id, id,
reply_target: sender.clone(),
sender, sender,
content, content,
channel: "email".to_string(), channel: "email".to_string(),
@ -464,6 +466,18 @@ impl Channel for EmailChannel {
mod tests { mod tests {
use super::*; use super::*;
#[test]
fn default_smtp_port_uses_tls_port() {
assert_eq!(default_smtp_port(), 465);
}
#[test]
fn email_config_default_uses_tls_smtp_defaults() {
let config = EmailConfig::default();
assert_eq!(config.smtp_port, 465);
assert!(config.smtp_tls);
}
#[test] #[test]
fn build_imap_tls_config_succeeds() { fn build_imap_tls_config_succeeds() {
let tls_config = let tls_config =
@ -504,7 +518,7 @@ mod tests {
assert_eq!(config.imap_port, 993); assert_eq!(config.imap_port, 993);
assert_eq!(config.imap_folder, "INBOX"); assert_eq!(config.imap_folder, "INBOX");
assert_eq!(config.smtp_host, ""); assert_eq!(config.smtp_host, "");
assert_eq!(config.smtp_port, 587); assert_eq!(config.smtp_port, 465);
assert!(config.smtp_tls); assert!(config.smtp_tls);
assert_eq!(config.username, ""); assert_eq!(config.username, "");
assert_eq!(config.password, ""); assert_eq!(config.password, "");
@ -765,8 +779,8 @@ mod tests {
} }
#[test] #[test]
fn default_smtp_port_returns_587() { fn default_smtp_port_returns_465() {
assert_eq!(default_smtp_port(), 587); assert_eq!(default_smtp_port(), 465);
} }
#[test] #[test]
@ -822,7 +836,7 @@ mod tests {
let config: EmailConfig = serde_json::from_str(json).unwrap(); let config: EmailConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.imap_port, 993); // default assert_eq!(config.imap_port, 993); // default
assert_eq!(config.smtp_port, 587); // default assert_eq!(config.smtp_port, 465); // default
assert!(config.smtp_tls); // default assert!(config.smtp_tls); // default
assert_eq!(config.poll_interval_secs, 60); // default assert_eq!(config.poll_interval_secs, 60); // default
} }

View file

@ -172,6 +172,7 @@ end tell"#
let msg = ChannelMessage { let msg = ChannelMessage {
id: rowid.to_string(), id: rowid.to_string(),
sender: sender.clone(), sender: sender.clone(),
reply_target: sender.clone(),
content: text, content: text,
channel: "imessage".to_string(), channel: "imessage".to_string(),
timestamp: std::time::SystemTime::now() timestamp: std::time::SystemTime::now()

View file

@ -220,32 +220,34 @@ fn split_message(message: &str, max_bytes: usize) -> Vec<String> {
chunks chunks
} }
/// Configuration for constructing an `IrcChannel`.
pub struct IrcChannelConfig {
pub server: String,
pub port: u16,
pub nickname: String,
pub username: Option<String>,
pub channels: Vec<String>,
pub allowed_users: Vec<String>,
pub server_password: Option<String>,
pub nickserv_password: Option<String>,
pub sasl_password: Option<String>,
pub verify_tls: bool,
}
impl IrcChannel { impl IrcChannel {
#[allow(clippy::too_many_arguments)] pub fn new(cfg: IrcChannelConfig) -> Self {
pub fn new( let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone());
server: String,
port: u16,
nickname: String,
username: Option<String>,
channels: Vec<String>,
allowed_users: Vec<String>,
server_password: Option<String>,
nickserv_password: Option<String>,
sasl_password: Option<String>,
verify_tls: bool,
) -> Self {
let username = username.unwrap_or_else(|| nickname.clone());
Self { Self {
server, server: cfg.server,
port, port: cfg.port,
nickname, nickname: cfg.nickname,
username, username,
channels, channels: cfg.channels,
allowed_users, allowed_users: cfg.allowed_users,
server_password, server_password: cfg.server_password,
nickserv_password, nickserv_password: cfg.nickserv_password,
sasl_password, sasl_password: cfg.sasl_password,
verify_tls, verify_tls: cfg.verify_tls,
writer: Arc::new(Mutex::new(None)), writer: Arc::new(Mutex::new(None)),
} }
} }
@ -563,7 +565,8 @@ impl Channel for IrcChannel {
let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed); let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed);
let channel_msg = ChannelMessage { let channel_msg = ChannelMessage {
id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()), id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()),
sender: reply_to, sender: sender_nick.to_string(),
reply_target: reply_to,
content, content,
channel: "irc".to_string(), channel: "irc".to_string(),
timestamp: std::time::SystemTime::now() timestamp: std::time::SystemTime::now()
@ -807,18 +810,18 @@ mod tests {
#[test] #[test]
fn specific_user_allowed() { fn specific_user_allowed() {
let ch = IrcChannel::new( let ch = IrcChannel::new(IrcChannelConfig {
"irc.test".into(), server: "irc.test".into(),
6697, port: 6697,
"bot".into(), nickname: "bot".into(),
None, username: None,
vec![], channels: vec![],
vec!["alice".into(), "bob".into()], allowed_users: vec!["alice".into(), "bob".into()],
None, server_password: None,
None, nickserv_password: None,
None, sasl_password: None,
true, verify_tls: true,
); });
assert!(ch.is_user_allowed("alice")); assert!(ch.is_user_allowed("alice"));
assert!(ch.is_user_allowed("bob")); assert!(ch.is_user_allowed("bob"));
assert!(!ch.is_user_allowed("eve")); assert!(!ch.is_user_allowed("eve"));
@ -826,18 +829,18 @@ mod tests {
#[test] #[test]
fn allowlist_case_insensitive() { fn allowlist_case_insensitive() {
let ch = IrcChannel::new( let ch = IrcChannel::new(IrcChannelConfig {
"irc.test".into(), server: "irc.test".into(),
6697, port: 6697,
"bot".into(), nickname: "bot".into(),
None, username: None,
vec![], channels: vec![],
vec!["Alice".into()], allowed_users: vec!["Alice".into()],
None, server_password: None,
None, nickserv_password: None,
None, sasl_password: None,
true, verify_tls: true,
); });
assert!(ch.is_user_allowed("alice")); assert!(ch.is_user_allowed("alice"));
assert!(ch.is_user_allowed("ALICE")); assert!(ch.is_user_allowed("ALICE"));
assert!(ch.is_user_allowed("Alice")); assert!(ch.is_user_allowed("Alice"));
@ -845,18 +848,18 @@ mod tests {
#[test] #[test]
fn empty_allowlist_denies_all() { fn empty_allowlist_denies_all() {
let ch = IrcChannel::new( let ch = IrcChannel::new(IrcChannelConfig {
"irc.test".into(), server: "irc.test".into(),
6697, port: 6697,
"bot".into(), nickname: "bot".into(),
None, username: None,
vec![], channels: vec![],
vec![], allowed_users: vec![],
None, server_password: None,
None, nickserv_password: None,
None, sasl_password: None,
true, verify_tls: true,
); });
assert!(!ch.is_user_allowed("anyone")); assert!(!ch.is_user_allowed("anyone"));
} }
@ -864,35 +867,35 @@ mod tests {
#[test] #[test]
fn new_defaults_username_to_nickname() { fn new_defaults_username_to_nickname() {
let ch = IrcChannel::new( let ch = IrcChannel::new(IrcChannelConfig {
"irc.test".into(), server: "irc.test".into(),
6697, port: 6697,
"mybot".into(), nickname: "mybot".into(),
None, username: None,
vec![], channels: vec![],
vec![], allowed_users: vec![],
None, server_password: None,
None, nickserv_password: None,
None, sasl_password: None,
true, verify_tls: true,
); });
assert_eq!(ch.username, "mybot"); assert_eq!(ch.username, "mybot");
} }
#[test] #[test]
fn new_uses_explicit_username() { fn new_uses_explicit_username() {
let ch = IrcChannel::new( let ch = IrcChannel::new(IrcChannelConfig {
"irc.test".into(), server: "irc.test".into(),
6697, port: 6697,
"mybot".into(), nickname: "mybot".into(),
Some("customuser".into()), username: Some("customuser".into()),
vec![], channels: vec![],
vec![], allowed_users: vec![],
None, server_password: None,
None, nickserv_password: None,
None, sasl_password: None,
true, verify_tls: true,
); });
assert_eq!(ch.username, "customuser"); assert_eq!(ch.username, "customuser");
assert_eq!(ch.nickname, "mybot"); assert_eq!(ch.nickname, "mybot");
} }
@ -905,18 +908,18 @@ mod tests {
#[test] #[test]
fn new_stores_all_fields() { fn new_stores_all_fields() {
let ch = IrcChannel::new( let ch = IrcChannel::new(IrcChannelConfig {
"irc.example.com".into(), server: "irc.example.com".into(),
6697, port: 6697,
"zcbot".into(), nickname: "zcbot".into(),
Some("zeroclaw".into()), username: Some("zeroclaw".into()),
vec!["#test".into()], channels: vec!["#test".into()],
vec!["alice".into()], allowed_users: vec!["alice".into()],
Some("serverpass".into()), server_password: Some("serverpass".into()),
Some("nspass".into()), nickserv_password: Some("nspass".into()),
Some("saslpass".into()), sasl_password: Some("saslpass".into()),
false, verify_tls: false,
); });
assert_eq!(ch.server, "irc.example.com"); assert_eq!(ch.server, "irc.example.com");
assert_eq!(ch.port, 6697); assert_eq!(ch.port, 6697);
assert_eq!(ch.nickname, "zcbot"); assert_eq!(ch.nickname, "zcbot");
@ -995,17 +998,17 @@ nickname = "bot"
// ── Helpers ───────────────────────────────────────────── // ── Helpers ─────────────────────────────────────────────
fn make_channel() -> IrcChannel { fn make_channel() -> IrcChannel {
IrcChannel::new( IrcChannel::new(IrcChannelConfig {
"irc.example.com".into(), server: "irc.example.com".into(),
6697, port: 6697,
"zcbot".into(), nickname: "zcbot".into(),
None, username: None,
vec!["#zeroclaw".into()], channels: vec!["#zeroclaw".into()],
vec!["*".into()], allowed_users: vec!["*".into()],
None, server_password: None,
None, nickserv_password: None,
None, sasl_password: None,
true, verify_tls: true,
) })
} }
} }

View file

@ -1,21 +1,152 @@
use super::traits::{Channel, ChannelMessage}; use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait; use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use prost::Message as ProstMessage;
use std::collections::HashMap;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tokio_tungstenite::tungstenite::Message as WsMsg;
use uuid::Uuid; use uuid::Uuid;
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis"; const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn";
const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis";
const LARK_WS_BASE_URL: &str = "https://open.larksuite.com";
/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API // ─────────────────────────────────────────────────────────────────────────────
// Feishu WebSocket long-connection: pbbp2.proto frame codec
// ─────────────────────────────────────────────────────────────────────────────
#[derive(Clone, PartialEq, prost::Message)]
struct PbHeader {
#[prost(string, tag = "1")]
pub key: String,
#[prost(string, tag = "2")]
pub value: String,
}
/// Feishu WS frame (pbbp2.proto).
/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
#[derive(Clone, PartialEq, prost::Message)]
struct PbFrame {
#[prost(uint64, tag = "1")]
pub seq_id: u64,
#[prost(uint64, tag = "2")]
pub log_id: u64,
#[prost(int32, tag = "3")]
pub service: i32,
#[prost(int32, tag = "4")]
pub method: i32,
#[prost(message, repeated, tag = "5")]
pub headers: Vec<PbHeader>,
#[prost(bytes = "vec", optional, tag = "8")]
pub payload: Option<Vec<u8>>,
}
impl PbFrame {
fn header_value<'a>(&'a self, key: &str) -> &'a str {
self.headers
.iter()
.find(|h| h.key == key)
.map(|h| h.value.as_str())
.unwrap_or("")
}
}
/// Server-sent client config (parsed from pong payload)
#[derive(Debug, serde::Deserialize, Default, Clone)]
struct WsClientConfig {
#[serde(rename = "PingInterval")]
ping_interval: Option<u64>,
}
/// POST /callback/ws/endpoint response
#[derive(Debug, serde::Deserialize)]
struct WsEndpointResp {
code: i32,
#[serde(default)]
msg: Option<String>,
#[serde(default)]
data: Option<WsEndpoint>,
}
#[derive(Debug, serde::Deserialize)]
struct WsEndpoint {
#[serde(rename = "URL")]
url: String,
#[serde(rename = "ClientConfig")]
client_config: Option<WsClientConfig>,
}
/// LarkEvent envelope (method=1 / type=event payload)
#[derive(Debug, serde::Deserialize)]
struct LarkEvent {
header: LarkEventHeader,
event: serde_json::Value,
}
#[derive(Debug, serde::Deserialize)]
struct LarkEventHeader {
event_type: String,
#[allow(dead_code)]
event_id: String,
}
#[derive(Debug, serde::Deserialize)]
struct MsgReceivePayload {
sender: LarkSender,
message: LarkMessage,
}
#[derive(Debug, serde::Deserialize)]
struct LarkSender {
sender_id: LarkSenderId,
#[serde(default)]
sender_type: String,
}
#[derive(Debug, serde::Deserialize, Default)]
struct LarkSenderId {
open_id: Option<String>,
}
#[derive(Debug, serde::Deserialize)]
struct LarkMessage {
message_id: String,
chat_id: String,
chat_type: String,
message_type: String,
#[serde(default)]
content: String,
#[serde(default)]
mentions: Vec<serde_json::Value>,
}
/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
/// If no binary frame (pong or event) is received within this window, reconnect.
const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
/// Lark/Feishu channel.
///
/// Supports two receive modes (configured via `receive_mode` in config):
/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed.
/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint.
pub struct LarkChannel { pub struct LarkChannel {
app_id: String, app_id: String,
app_secret: String, app_secret: String,
verification_token: String, verification_token: String,
port: u16, port: Option<u16>,
allowed_users: Vec<String>, allowed_users: Vec<String>,
/// When true, use Feishu (CN) endpoints; when false, use Lark (international).
use_feishu: bool,
/// How to receive events: WebSocket long-connection or HTTP webhook.
receive_mode: crate::config::schema::LarkReceiveMode,
client: reqwest::Client, client: reqwest::Client,
/// Cached tenant access token /// Cached tenant access token
tenant_token: Arc<RwLock<Option<String>>>, tenant_token: Arc<RwLock<Option<String>>>,
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
} }
impl LarkChannel { impl LarkChannel {
@ -23,7 +154,7 @@ impl LarkChannel {
app_id: String, app_id: String,
app_secret: String, app_secret: String,
verification_token: String, verification_token: String,
port: u16, port: Option<u16>,
allowed_users: Vec<String>, allowed_users: Vec<String>,
) -> Self { ) -> Self {
Self { Self {
@ -32,11 +163,310 @@ impl LarkChannel {
verification_token, verification_token,
port, port,
allowed_users, allowed_users,
use_feishu: true,
receive_mode: crate::config::schema::LarkReceiveMode::default(),
client: reqwest::Client::new(), client: reqwest::Client::new(),
tenant_token: Arc::new(RwLock::new(None)), tenant_token: Arc::new(RwLock::new(None)),
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
} }
} }
/// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`).
pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self {
let mut ch = Self::new(
config.app_id.clone(),
config.app_secret.clone(),
config.verification_token.clone().unwrap_or_default(),
config.port,
config.allowed_users.clone(),
);
ch.use_feishu = config.use_feishu;
ch.receive_mode = config.receive_mode.clone();
ch
}
fn api_base(&self) -> &'static str {
if self.use_feishu {
FEISHU_BASE_URL
} else {
LARK_BASE_URL
}
}
fn ws_base(&self) -> &'static str {
if self.use_feishu {
FEISHU_WS_BASE_URL
} else {
LARK_WS_BASE_URL
}
}
fn tenant_access_token_url(&self) -> String {
format!("{}/auth/v3/tenant_access_token/internal", self.api_base())
}
fn send_message_url(&self) -> String {
format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base())
}
/// POST /callback/ws/endpoint → (wss_url, client_config)
async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> {
let resp = self
.client
.post(format!("{}/callback/ws/endpoint", self.ws_base()))
.header("locale", if self.use_feishu { "zh" } else { "en" })
.json(&serde_json::json!({
"AppID": self.app_id,
"AppSecret": self.app_secret,
}))
.send()
.await?
.json::<WsEndpointResp>()
.await?;
if resp.code != 0 {
anyhow::bail!(
"Lark WS endpoint failed: code={} msg={}",
resp.code,
resp.msg.as_deref().unwrap_or("(none)")
);
}
let ep = resp
.data
.ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?;
Ok((ep.url, ep.client_config.unwrap_or_default()))
}
/// WS long-connection event loop. Returns Ok(()) when the connection closes
/// (the caller reconnects).
#[allow(clippy::too_many_lines)]
async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let (wss_url, client_config) = self.get_ws_endpoint().await?;
let service_id = wss_url
.split('?')
.nth(1)
.and_then(|qs| {
qs.split('&')
.find(|kv| kv.starts_with("service_id="))
.and_then(|kv| kv.split('=').nth(1))
.and_then(|v| v.parse::<i32>().ok())
})
.unwrap_or(0);
tracing::info!("Lark: connecting to {wss_url}");
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?;
let (mut write, mut read) = ws_stream.split();
tracing::info!("Lark: WS connected (service_id={service_id})");
let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10);
let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
let mut timeout_check = tokio::time::interval(Duration::from_secs(10));
hb_interval.tick().await; // consume immediate tick
let mut seq: u64 = 0;
let mut last_recv = Instant::now();
// Send initial ping immediately (like the official SDK) so the server
// starts responding with pongs and we can calibrate the ping_interval.
seq = seq.wrapping_add(1);
let initial_ping = PbFrame {
seq_id: seq,
log_id: 0,
service: service_id,
method: 0,
headers: vec![PbHeader {
key: "type".into(),
value: "ping".into(),
}],
payload: None,
};
if write
.send(WsMsg::Binary(initial_ping.encode_to_vec()))
.await
.is_err()
{
anyhow::bail!("Lark: initial ping failed");
}
// message_id → (fragment_slots, created_at) for multi-part reassembly
type FragEntry = (Vec<Option<Vec<u8>>>, Instant);
let mut frag_cache: HashMap<String, FragEntry> = HashMap::new();
loop {
tokio::select! {
biased;
_ = hb_interval.tick() => {
seq = seq.wrapping_add(1);
let ping = PbFrame {
seq_id: seq, log_id: 0, service: service_id, method: 0,
headers: vec![PbHeader { key: "type".into(), value: "ping".into() }],
payload: None,
};
if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() {
tracing::warn!("Lark: ping failed, reconnecting");
break;
}
// GC stale fragments > 5 min
let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now());
frag_cache.retain(|_, (_, ts)| *ts > cutoff);
}
_ = timeout_check.tick() => {
if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
tracing::warn!("Lark: heartbeat timeout, reconnecting");
break;
}
}
msg = read.next() => {
let raw = match msg {
Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b }
Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; }
Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; }
Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; }
_ => continue,
};
let frame = match PbFrame::decode(&raw[..]) {
Ok(f) => f,
Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; }
};
// CONTROL frame
if frame.method == 0 {
if frame.header_value("type") == "pong" {
if let Some(p) = &frame.payload {
if let Ok(cfg) = serde_json::from_slice::<WsClientConfig>(p) {
if let Some(secs) = cfg.ping_interval {
let secs = secs.max(10);
if secs != ping_secs {
ping_secs = secs;
hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
tracing::info!("Lark: ping_interval → {ping_secs}s");
}
}
}
}
}
continue;
}
// DATA frame
let msg_type = frame.header_value("type").to_string();
let msg_id = frame.header_value("message_id").to_string();
let sum = frame.header_value("sum").parse::<usize>().unwrap_or(1);
let seq_num = frame.header_value("seq").parse::<usize>().unwrap_or(0);
// ACK immediately (Feishu requires within 3 s)
{
let mut ack = frame.clone();
ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() });
let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await;
}
// Fragment reassembly
let sum = if sum == 0 { 1 } else { sum };
let payload: Vec<u8> = if sum == 1 || msg_id.is_empty() || seq_num >= sum {
frame.payload.clone().unwrap_or_default()
} else {
let entry = frag_cache.entry(msg_id.clone())
.or_insert_with(|| (vec![None; sum], Instant::now()));
if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); }
entry.0[seq_num] = frame.payload.clone();
if entry.0.iter().all(|s| s.is_some()) {
let full: Vec<u8> = entry.0.iter()
.flat_map(|s| s.as_deref().unwrap_or(&[]))
.copied().collect();
frag_cache.remove(&msg_id);
full
} else { continue; }
};
if msg_type != "event" { continue; }
let event: LarkEvent = match serde_json::from_slice(&payload) {
Ok(e) => e,
Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; }
};
if event.header.event_type != "im.message.receive_v1" { continue; }
let recv: MsgReceivePayload = match serde_json::from_value(event.event) {
Ok(r) => r,
Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; }
};
if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; }
let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or("");
if !self.is_user_allowed(sender_open_id) {
tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)");
continue;
}
let lark_msg = &recv.message;
// Dedup
{
let now = Instant::now();
let mut seen = self.ws_seen_ids.write().await;
// GC
seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
if seen.contains_key(&lark_msg.message_id) {
tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
continue;
}
seen.insert(lark_msg.message_id.clone(), now);
}
// Decode content by type (mirrors clawdbot-feishu parsing)
let text = match lark_msg.message_type.as_str() {
"text" => {
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
Ok(v) => v,
Err(_) => continue,
};
match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) {
Some(t) => t.to_string(),
None => continue,
}
}
"post" => match parse_post_content(&lark_msg.content) {
Some(t) => t,
None => continue,
},
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
};
// Strip @_user_N placeholders
let text = strip_at_placeholders(&text);
let text = text.trim().to_string();
if text.is_empty() { continue; }
// Group-chat: only respond when explicitly @-mentioned
if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) {
continue;
}
let channel_msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: lark_msg.chat_id.clone(),
reply_target: lark_msg.chat_id.clone(),
content: text,
channel: "lark".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
tracing::debug!("Lark WS: message in {}", lark_msg.chat_id);
if tx.send(channel_msg).await.is_err() { break; }
}
}
}
Ok(())
}
/// Check if a user open_id is allowed /// Check if a user open_id is allowed
fn is_user_allowed(&self, open_id: &str) -> bool { fn is_user_allowed(&self, open_id: &str) -> bool {
self.allowed_users.iter().any(|u| u == "*" || u == open_id) self.allowed_users.iter().any(|u| u == "*" || u == open_id)
@ -52,7 +482,7 @@ impl LarkChannel {
} }
} }
let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal"); let url = self.tenant_access_token_url();
let body = serde_json::json!({ let body = serde_json::json!({
"app_id": self.app_id, "app_id": self.app_id,
"app_secret": self.app_secret, "app_secret": self.app_secret,
@ -127,31 +557,41 @@ impl LarkChannel {
return messages; return messages;
} }
// Extract message content (text only) // Extract message content (text and post supported)
let msg_type = event let msg_type = event
.pointer("/message/message_type") .pointer("/message/message_type")
.and_then(|t| t.as_str()) .and_then(|t| t.as_str())
.unwrap_or(""); .unwrap_or("");
if msg_type != "text" {
tracing::debug!("Lark: skipping non-text message type: {msg_type}");
return messages;
}
let content_str = event let content_str = event
.pointer("/message/content") .pointer("/message/content")
.and_then(|c| c.as_str()) .and_then(|c| c.as_str())
.unwrap_or(""); .unwrap_or("");
// content is a JSON string like "{\"text\":\"hello\"}" let text: String = match msg_type {
let text = serde_json::from_str::<serde_json::Value>(content_str) "text" => {
.ok() let extracted = serde_json::from_str::<serde_json::Value>(content_str)
.and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from)) .ok()
.unwrap_or_default(); .and_then(|v| {
v.get("text")
if text.is_empty() { .and_then(|t| t.as_str())
return messages; .filter(|s| !s.is_empty())
} .map(String::from)
});
match extracted {
Some(t) => t,
None => return messages,
}
}
"post" => match parse_post_content(content_str) {
Some(t) => t,
None => return messages,
},
_ => {
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
return messages;
}
};
let timestamp = event let timestamp = event
.pointer("/message/create_time") .pointer("/message/create_time")
@ -174,6 +614,7 @@ impl LarkChannel {
messages.push(ChannelMessage { messages.push(ChannelMessage {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
sender: chat_id.to_string(), sender: chat_id.to_string(),
reply_target: chat_id.to_string(),
content: text, content: text,
channel: "lark".to_string(), channel: "lark".to_string(),
timestamp, timestamp,
@ -191,7 +632,7 @@ impl Channel for LarkChannel {
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> { async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
let token = self.get_tenant_access_token().await?; let token = self.get_tenant_access_token().await?;
let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id"); let url = self.send_message_url();
let content = serde_json::json!({ "text": message }).to_string(); let content = serde_json::json!({ "text": message }).to_string();
let body = serde_json::json!({ let body = serde_json::json!({
@ -238,6 +679,25 @@ impl Channel for LarkChannel {
} }
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> { async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
use crate::config::schema::LarkReceiveMode;
match self.receive_mode {
LarkReceiveMode::Websocket => self.listen_ws(tx).await,
LarkReceiveMode::Webhook => self.listen_http(tx).await,
}
}
async fn health_check(&self) -> bool {
self.get_tenant_access_token().await.is_ok()
}
}
impl LarkChannel {
/// HTTP callback server (legacy — requires a public endpoint).
/// Use `listen()` (WS long-connection) for new deployments.
pub async fn listen_http(
&self,
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
) -> anyhow::Result<()> {
use axum::{extract::State, routing::post, Json, Router}; use axum::{extract::State, routing::post, Json, Router};
#[derive(Clone)] #[derive(Clone)]
@ -282,13 +742,17 @@ impl Channel for LarkChannel {
(StatusCode::OK, "ok").into_response() (StatusCode::OK, "ok").into_response()
} }
let port = self.port.ok_or_else(|| {
anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]")
})?;
let state = AppState { let state = AppState {
verification_token: self.verification_token.clone(), verification_token: self.verification_token.clone(),
channel: Arc::new(LarkChannel::new( channel: Arc::new(LarkChannel::new(
self.app_id.clone(), self.app_id.clone(),
self.app_secret.clone(), self.app_secret.clone(),
self.verification_token.clone(), self.verification_token.clone(),
self.port, None,
self.allowed_users.clone(), self.allowed_users.clone(),
)), )),
tx, tx,
@ -298,7 +762,7 @@ impl Channel for LarkChannel {
.route("/lark", post(handle_event)) .route("/lark", post(handle_event))
.with_state(state); .with_state(state);
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port)); let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("Lark event callback server listening on {addr}"); tracing::info!("Lark event callback server listening on {addr}");
let listener = tokio::net::TcpListener::bind(addr).await?; let listener = tokio::net::TcpListener::bind(addr).await?;
@ -306,10 +770,110 @@ impl Channel for LarkChannel {
Ok(()) Ok(())
} }
}
async fn health_check(&self) -> bool { // ─────────────────────────────────────────────────────────────────────────────
self.get_tenant_access_token().await.is_ok() // WS helper functions
// ─────────────────────────────────────────────────────────────────────────────
/// Flatten a Feishu `post` rich-text message to plain text.
///
/// Returns `None` when the content cannot be parsed or yields no usable text,
/// so callers can simply `continue` rather than forwarding a meaningless
/// placeholder string to the agent.
fn parse_post_content(content: &str) -> Option<String> {
let parsed = serde_json::from_str::<serde_json::Value>(content).ok()?;
let locale = parsed
.get("zh_cn")
.or_else(|| parsed.get("en_us"))
.or_else(|| {
parsed
.as_object()
.and_then(|m| m.values().find(|v| v.is_object()))
})?;
let mut text = String::new();
if let Some(title) = locale
.get("title")
.and_then(|t| t.as_str())
.filter(|s| !s.is_empty())
{
text.push_str(title);
text.push_str("\n\n");
} }
if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) {
for para in paragraphs {
if let Some(elements) = para.as_array() {
for el in elements {
match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") {
"text" => {
if let Some(t) = el.get("text").and_then(|t| t.as_str()) {
text.push_str(t);
}
}
"a" => {
text.push_str(
el.get("text")
.and_then(|t| t.as_str())
.filter(|s| !s.is_empty())
.or_else(|| el.get("href").and_then(|h| h.as_str()))
.unwrap_or(""),
);
}
"at" => {
let n = el
.get("user_name")
.and_then(|n| n.as_str())
.or_else(|| el.get("user_id").and_then(|i| i.as_str()))
.unwrap_or("user");
text.push('@');
text.push_str(n);
}
_ => {}
}
}
text.push('\n');
}
}
}
let result = text.trim().to_string();
if result.is_empty() {
None
} else {
Some(result)
}
}
/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats.
///
/// Each `@_user_<digits>` token is removed together with at most one
/// following space, so `"@_user_1 hello"` becomes `"hello"`. Any other
/// `@` text (e-mail addresses, plain mentions) is left untouched.
fn strip_at_placeholders(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    let mut chars = text.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch == '@' {
            let rest: String = chars.clone().collect();
            if let Some(after) = rest.strip_prefix("_user_") {
                // Consume exactly the placeholder body: "_user_" plus its digits.
                // (Consuming one extra character here would eat punctuation such
                // as the comma in "@_user_1, hi".)
                let skip =
                    "_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count();
                for _ in 0..skip {
                    chars.next();
                }
                // Drop a single separating space after the placeholder, if any.
                if chars.peek().copied() == Some(' ') {
                    chars.next();
                }
                continue;
            }
        }
        result.push(ch);
    }
    result
}
/// In group chats, only respond when the bot is explicitly @-mentioned.
fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool {
!mentions.is_empty()
} }
#[cfg(test)] #[cfg(test)]
@ -321,7 +885,7 @@ mod tests {
"cli_test_app_id".into(), "cli_test_app_id".into(),
"test_app_secret".into(), "test_app_secret".into(),
"test_verification_token".into(), "test_verification_token".into(),
9898, None,
vec!["ou_testuser123".into()], vec!["ou_testuser123".into()],
) )
} }
@ -345,7 +909,7 @@ mod tests {
"id".into(), "id".into(),
"secret".into(), "secret".into(),
"token".into(), "token".into(),
9898, None,
vec!["*".into()], vec!["*".into()],
); );
assert!(ch.is_user_allowed("ou_anyone")); assert!(ch.is_user_allowed("ou_anyone"));
@ -353,7 +917,7 @@ mod tests {
#[test] #[test]
fn lark_user_denied_empty() { fn lark_user_denied_empty() {
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]); let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]);
assert!(!ch.is_user_allowed("ou_anyone")); assert!(!ch.is_user_allowed("ou_anyone"));
} }
@ -426,7 +990,7 @@ mod tests {
"id".into(), "id".into(),
"secret".into(), "secret".into(),
"token".into(), "token".into(),
9898, None,
vec!["*".into()], vec!["*".into()],
); );
let payload = serde_json::json!({ let payload = serde_json::json!({
@ -451,7 +1015,7 @@ mod tests {
"id".into(), "id".into(),
"secret".into(), "secret".into(),
"token".into(), "token".into(),
9898, None,
vec!["*".into()], vec!["*".into()],
); );
let payload = serde_json::json!({ let payload = serde_json::json!({
@ -488,7 +1052,7 @@ mod tests {
"id".into(), "id".into(),
"secret".into(), "secret".into(),
"token".into(), "token".into(),
9898, None,
vec!["*".into()], vec!["*".into()],
); );
let payload = serde_json::json!({ let payload = serde_json::json!({
@ -512,7 +1076,7 @@ mod tests {
"id".into(), "id".into(),
"secret".into(), "secret".into(),
"token".into(), "token".into(),
9898, None,
vec!["*".into()], vec!["*".into()],
); );
let payload = serde_json::json!({ let payload = serde_json::json!({
@ -550,7 +1114,7 @@ mod tests {
"id".into(), "id".into(),
"secret".into(), "secret".into(),
"token".into(), "token".into(),
9898, None,
vec!["*".into()], vec!["*".into()],
); );
let payload = serde_json::json!({ let payload = serde_json::json!({
@ -571,7 +1135,7 @@ mod tests {
#[test] #[test]
fn lark_config_serde() { fn lark_config_serde() {
use crate::config::schema::LarkConfig; use crate::config::schema::{LarkConfig, LarkReceiveMode};
let lc = LarkConfig { let lc = LarkConfig {
app_id: "cli_app123".into(), app_id: "cli_app123".into(),
app_secret: "secret456".into(), app_secret: "secret456".into(),
@ -579,6 +1143,8 @@ mod tests {
verification_token: Some("vtoken789".into()), verification_token: Some("vtoken789".into()),
allowed_users: vec!["ou_user1".into(), "ou_user2".into()], allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
use_feishu: false, use_feishu: false,
receive_mode: LarkReceiveMode::default(),
port: None,
}; };
let json = serde_json::to_string(&lc).unwrap(); let json = serde_json::to_string(&lc).unwrap();
let parsed: LarkConfig = serde_json::from_str(&json).unwrap(); let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
@ -590,7 +1156,7 @@ mod tests {
#[test] #[test]
fn lark_config_toml_roundtrip() { fn lark_config_toml_roundtrip() {
use crate::config::schema::LarkConfig; use crate::config::schema::{LarkConfig, LarkReceiveMode};
let lc = LarkConfig { let lc = LarkConfig {
app_id: "app".into(), app_id: "app".into(),
app_secret: "secret".into(), app_secret: "secret".into(),
@ -598,6 +1164,8 @@ mod tests {
verification_token: Some("tok".into()), verification_token: Some("tok".into()),
allowed_users: vec!["*".into()], allowed_users: vec!["*".into()],
use_feishu: false, use_feishu: false,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
}; };
let toml_str = toml::to_string(&lc).unwrap(); let toml_str = toml::to_string(&lc).unwrap();
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap(); let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
@ -608,11 +1176,36 @@ mod tests {
#[test] #[test]
fn lark_config_defaults_optional_fields() { fn lark_config_defaults_optional_fields() {
use crate::config::schema::LarkConfig; use crate::config::schema::{LarkConfig, LarkReceiveMode};
let json = r#"{"app_id":"a","app_secret":"s"}"#; let json = r#"{"app_id":"a","app_secret":"s"}"#;
let parsed: LarkConfig = serde_json::from_str(json).unwrap(); let parsed: LarkConfig = serde_json::from_str(json).unwrap();
assert!(parsed.verification_token.is_none()); assert!(parsed.verification_token.is_none());
assert!(parsed.allowed_users.is_empty()); assert!(parsed.allowed_users.is_empty());
assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket);
assert!(parsed.port.is_none());
}
#[test]
fn lark_from_config_preserves_mode_and_region() {
use crate::config::schema::{LarkConfig, LarkReceiveMode};
let cfg = LarkConfig {
app_id: "cli_app123".into(),
app_secret: "secret456".into(),
encrypt_key: None,
verification_token: Some("vtoken789".into()),
allowed_users: vec!["*".into()],
use_feishu: false,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
};
let ch = LarkChannel::from_config(&cfg);
assert_eq!(ch.api_base(), LARK_BASE_URL);
assert_eq!(ch.ws_base(), LARK_WS_BASE_URL);
assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook);
assert_eq!(ch.port, Some(9898));
} }
#[test] #[test]
@ -622,7 +1215,7 @@ mod tests {
"id".into(), "id".into(),
"secret".into(), "secret".into(),
"token".into(), "token".into(),
9898, None,
vec!["*".into()], vec!["*".into()],
); );
let payload = serde_json::json!({ let payload = serde_json::json!({

View file

@ -230,6 +230,7 @@ impl Channel for MatrixChannel {
let msg = ChannelMessage { let msg = ChannelMessage {
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()), id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
sender: event.sender.clone(), sender: event.sender.clone(),
reply_target: event.sender.clone(),
content: body.clone(), content: body.clone(),
channel: "matrix".to_string(), channel: "matrix".to_string(),
timestamp: std::time::SystemTime::now() timestamp: std::time::SystemTime::now()

View file

@ -69,10 +69,19 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
format!("{}_{}_{}", msg.channel, msg.sender, msg.id) format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
} }
/// Extra system-prompt instructions describing how replies must be formatted
/// for delivery on a given channel; `None` when the channel needs nothing special.
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
    // Telegram is currently the only channel with special delivery syntax
    // (attachment markers consumed by the Telegram sender).
    (channel_name == "telegram").then_some(
        "When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.",
    )
}
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String { async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new(); let mut context = String::new();
if let Ok(entries) = mem.recall(user_msg, 5).await { if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if !entries.is_empty() { if !entries.is_empty() {
context.push_str("[Memory context]\n"); context.push_str("[Memory context]\n");
for entry in &entries { for entry in &entries {
@ -158,6 +167,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
&autosave_key, &autosave_key,
&msg.content, &msg.content,
crate::memory::MemoryCategory::Conversation, crate::memory::MemoryCategory::Conversation,
None,
) )
.await; .await;
} }
@ -171,7 +181,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned(); let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
if let Some(channel) = target_channel.as_ref() { if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.start_typing(&msg.sender).await { if let Err(e) = channel.start_typing(&msg.reply_target).await {
tracing::debug!("Failed to start typing on {}: {e}", channel.name()); tracing::debug!("Failed to start typing on {}: {e}", channel.name());
} }
} }
@ -184,6 +194,10 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
ChatMessage::user(&enriched_message), ChatMessage::user(&enriched_message),
]; ];
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
history.push(ChatMessage::system(instructions));
}
let llm_result = tokio::time::timeout( let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS), Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
run_tool_call_loop( run_tool_call_loop(
@ -200,7 +214,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
.await; .await;
if let Some(channel) = target_channel.as_ref() { if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.stop_typing(&msg.sender).await { if let Err(e) = channel.stop_typing(&msg.reply_target).await {
tracing::debug!("Failed to stop typing on {}: {e}", channel.name()); tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
} }
} }
@ -224,7 +238,9 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
started_at.elapsed().as_millis() started_at.elapsed().as_millis()
); );
if let Some(channel) = target_channel.as_ref() { if let Some(channel) = target_channel.as_ref() {
let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await; let _ = channel
.send(&format!("⚠️ Error: {e}"), &msg.reply_target)
.await;
} }
} }
Err(_) => { Err(_) => {
@ -241,7 +257,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
let _ = channel let _ = channel
.send( .send(
"⚠️ Request timed out while waiting for the model. Please try again.", "⚠️ Request timed out while waiting for the model. Please try again.",
&msg.sender, &msg.reply_target,
) )
.await; .await;
} }
@ -483,6 +499,16 @@ pub fn build_system_prompt(
std::env::consts::OS, std::env::consts::OS,
); );
// ── 8. Channel Capabilities ─────────────────────────────────────
prompt.push_str("## Channel Capabilities\n\n");
prompt.push_str(
"- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n",
);
prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n");
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
if prompt.is_empty() { if prompt.is_empty() {
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string() "You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
} else { } else {
@ -619,6 +645,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
dc.guild_id.clone(), dc.guild_id.clone(),
dc.allowed_users.clone(), dc.allowed_users.clone(),
dc.listen_to_bots, dc.listen_to_bots,
dc.mention_only,
)), )),
)); ));
} }
@ -672,32 +699,23 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
if let Some(ref irc) = config.channels_config.irc { if let Some(ref irc) = config.channels_config.irc {
channels.push(( channels.push((
"IRC", "IRC",
Arc::new(IrcChannel::new( Arc::new(IrcChannel::new(irc::IrcChannelConfig {
irc.server.clone(), server: irc.server.clone(),
irc.port, port: irc.port,
irc.nickname.clone(), nickname: irc.nickname.clone(),
irc.username.clone(), username: irc.username.clone(),
irc.channels.clone(), channels: irc.channels.clone(),
irc.allowed_users.clone(), allowed_users: irc.allowed_users.clone(),
irc.server_password.clone(), server_password: irc.server_password.clone(),
irc.nickserv_password.clone(), nickserv_password: irc.nickserv_password.clone(),
irc.sasl_password.clone(), sasl_password: irc.sasl_password.clone(),
irc.verify_tls.unwrap_or(true), verify_tls: irc.verify_tls.unwrap_or(true),
)), })),
)); ));
} }
if let Some(ref lk) = config.channels_config.lark { if let Some(ref lk) = config.channels_config.lark {
channels.push(( channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
"Lark",
Arc::new(LarkChannel::new(
lk.app_id.clone(),
lk.app_secret.clone(),
lk.verification_token.clone().unwrap_or_default(),
9898,
lk.allowed_users.clone(),
)),
));
} }
if let Some(ref dt) = config.channels_config.dingtalk { if let Some(ref dt) = config.channels_config.dingtalk {
@ -762,6 +780,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider( let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
&provider_name, &provider_name,
config.api_key.as_deref(), config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability, &config.reliability,
)?); )?);
@ -860,6 +879,10 @@ pub async fn start_channels(config: Config) -> Result<()> {
"schedule", "schedule",
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.", "Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
)); ));
tool_descs.push((
"pushover",
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
));
if !config.agents.is_empty() { if !config.agents.is_empty() {
tool_descs.push(( tool_descs.push((
"delegate", "delegate",
@ -909,6 +932,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
dc.guild_id.clone(), dc.guild_id.clone(),
dc.allowed_users.clone(), dc.allowed_users.clone(),
dc.listen_to_bots, dc.listen_to_bots,
dc.mention_only,
))); )));
} }
@ -947,28 +971,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
} }
if let Some(ref irc) = config.channels_config.irc { if let Some(ref irc) = config.channels_config.irc {
channels.push(Arc::new(IrcChannel::new( channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
irc.server.clone(), server: irc.server.clone(),
irc.port, port: irc.port,
irc.nickname.clone(), nickname: irc.nickname.clone(),
irc.username.clone(), username: irc.username.clone(),
irc.channels.clone(), channels: irc.channels.clone(),
irc.allowed_users.clone(), allowed_users: irc.allowed_users.clone(),
irc.server_password.clone(), server_password: irc.server_password.clone(),
irc.nickserv_password.clone(), nickserv_password: irc.nickserv_password.clone(),
irc.sasl_password.clone(), sasl_password: irc.sasl_password.clone(),
irc.verify_tls.unwrap_or(true), verify_tls: irc.verify_tls.unwrap_or(true),
))); })));
} }
if let Some(ref lk) = config.channels_config.lark { if let Some(ref lk) = config.channels_config.lark {
channels.push(Arc::new(LarkChannel::new( channels.push(Arc::new(LarkChannel::from_config(lk)));
lk.app_id.clone(),
lk.app_secret.clone(),
lk.verification_token.clone().unwrap_or_default(),
9898,
lk.allowed_users.clone(),
)));
} }
if let Some(ref dt) = config.channels_config.dingtalk { if let Some(ref dt) = config.channels_config.dingtalk {
@ -1242,6 +1260,7 @@ mod tests {
traits::ChannelMessage { traits::ChannelMessage {
id: "msg-1".to_string(), id: "msg-1".to_string(),
sender: "alice".to_string(), sender: "alice".to_string(),
reply_target: "chat-42".to_string(),
content: "What is the BTC price now?".to_string(), content: "What is the BTC price now?".to_string(),
channel: "test-channel".to_string(), channel: "test-channel".to_string(),
timestamp: 1, timestamp: 1,
@ -1251,6 +1270,7 @@ mod tests {
let sent_messages = channel_impl.sent_messages.lock().await; let sent_messages = channel_impl.sent_messages.lock().await;
assert_eq!(sent_messages.len(), 1); assert_eq!(sent_messages.len(), 1);
assert!(sent_messages[0].starts_with("chat-42:"));
assert!(sent_messages[0].contains("BTC is currently around")); assert!(sent_messages[0].contains("BTC is currently around"));
assert!(!sent_messages[0].contains("\"tool_calls\"")); assert!(!sent_messages[0].contains("\"tool_calls\""));
assert!(!sent_messages[0].contains("mock_price")); assert!(!sent_messages[0].contains("mock_price"));
@ -1269,6 +1289,7 @@ mod tests {
_key: &str, _key: &str,
_content: &str, _content: &str,
_category: crate::memory::MemoryCategory, _category: crate::memory::MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
@ -1277,6 +1298,7 @@ mod tests {
&self, &self,
_query: &str, _query: &str,
_limit: usize, _limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> { ) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -1288,6 +1310,7 @@ mod tests {
async fn list( async fn list(
&self, &self,
_category: Option<&crate::memory::MemoryCategory>, _category: Option<&crate::memory::MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> { ) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -1331,6 +1354,7 @@ mod tests {
tx.send(traits::ChannelMessage { tx.send(traits::ChannelMessage {
id: "1".to_string(), id: "1".to_string(),
sender: "alice".to_string(), sender: "alice".to_string(),
reply_target: "alice".to_string(),
content: "hello".to_string(), content: "hello".to_string(),
channel: "test-channel".to_string(), channel: "test-channel".to_string(),
timestamp: 1, timestamp: 1,
@ -1340,6 +1364,7 @@ mod tests {
tx.send(traits::ChannelMessage { tx.send(traits::ChannelMessage {
id: "2".to_string(), id: "2".to_string(),
sender: "bob".to_string(), sender: "bob".to_string(),
reply_target: "bob".to_string(),
content: "world".to_string(), content: "world".to_string(),
channel: "test-channel".to_string(), channel: "test-channel".to_string(),
timestamp: 2, timestamp: 2,
@ -1570,6 +1595,25 @@ mod tests {
assert!(truncated.is_char_boundary(truncated.len())); assert!(truncated.is_char_boundary(truncated.len()));
} }
#[test]
fn prompt_contains_channel_capabilities() {
let ws = make_workspace();
let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);
assert!(
prompt.contains("## Channel Capabilities"),
"missing Channel Capabilities section"
);
assert!(
prompt.contains("running as a Discord bot"),
"missing Discord context"
);
assert!(
prompt.contains("NEVER repeat, describe, or echo credentials"),
"missing security instruction"
);
}
#[test] #[test]
fn prompt_workspace_path() { fn prompt_workspace_path() {
let ws = make_workspace(); let ws = make_workspace();
@ -1583,6 +1627,7 @@ mod tests {
let msg = traits::ChannelMessage { let msg = traits::ChannelMessage {
id: "msg_abc123".into(), id: "msg_abc123".into(),
sender: "U123".into(), sender: "U123".into(),
reply_target: "C456".into(),
content: "hello".into(), content: "hello".into(),
channel: "slack".into(), channel: "slack".into(),
timestamp: 1, timestamp: 1,
@ -1596,6 +1641,7 @@ mod tests {
let msg1 = traits::ChannelMessage { let msg1 = traits::ChannelMessage {
id: "msg_1".into(), id: "msg_1".into(),
sender: "U123".into(), sender: "U123".into(),
reply_target: "C456".into(),
content: "first".into(), content: "first".into(),
channel: "slack".into(), channel: "slack".into(),
timestamp: 1, timestamp: 1,
@ -1603,6 +1649,7 @@ mod tests {
let msg2 = traits::ChannelMessage { let msg2 = traits::ChannelMessage {
id: "msg_2".into(), id: "msg_2".into(),
sender: "U123".into(), sender: "U123".into(),
reply_target: "C456".into(),
content: "second".into(), content: "second".into(),
channel: "slack".into(), channel: "slack".into(),
timestamp: 2, timestamp: 2,
@ -1622,6 +1669,7 @@ mod tests {
let msg1 = traits::ChannelMessage { let msg1 = traits::ChannelMessage {
id: "msg_1".into(), id: "msg_1".into(),
sender: "U123".into(), sender: "U123".into(),
reply_target: "C456".into(),
content: "I'm Paul".into(), content: "I'm Paul".into(),
channel: "slack".into(), channel: "slack".into(),
timestamp: 1, timestamp: 1,
@ -1629,6 +1677,7 @@ mod tests {
let msg2 = traits::ChannelMessage { let msg2 = traits::ChannelMessage {
id: "msg_2".into(), id: "msg_2".into(),
sender: "U123".into(), sender: "U123".into(),
reply_target: "C456".into(),
content: "I'm 45".into(), content: "I'm 45".into(),
channel: "slack".into(), channel: "slack".into(),
timestamp: 2, timestamp: 2,
@ -1638,6 +1687,7 @@ mod tests {
&conversation_memory_key(&msg1), &conversation_memory_key(&msg1),
&msg1.content, &msg1.content,
MemoryCategory::Conversation, MemoryCategory::Conversation,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -1645,13 +1695,14 @@ mod tests {
&conversation_memory_key(&msg2), &conversation_memory_key(&msg2),
&msg2.content, &msg2.content,
MemoryCategory::Conversation, MemoryCategory::Conversation,
None,
) )
.await .await
.unwrap(); .unwrap();
assert_eq!(mem.count().await.unwrap(), 2); assert_eq!(mem.count().await.unwrap(), 2);
let recalled = mem.recall("45", 5).await.unwrap(); let recalled = mem.recall("45", 5, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45"))); assert!(recalled.iter().any(|entry| entry.content.contains("45")));
} }
@ -1659,7 +1710,7 @@ mod tests {
async fn build_memory_context_includes_recalled_entries() { async fn build_memory_context_includes_recalled_entries() {
let tmp = TempDir::new().unwrap(); let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation) mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();

View file

@ -161,6 +161,7 @@ impl Channel for SlackChannel {
let channel_msg = ChannelMessage { let channel_msg = ChannelMessage {
id: format!("slack_{channel_id}_{ts}"), id: format!("slack_{channel_id}_{ts}"),
sender: user.to_string(), sender: user.to_string(),
reply_target: channel_id.clone(),
content: text.to_string(), content: text.to_string(),
channel: "slack".to_string(), channel: "slack".to_string(),
timestamp: std::time::SystemTime::now() timestamp: std::time::SystemTime::now()

View file

@ -51,6 +51,133 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
chunks chunks
} }
/// Kind of outgoing Telegram attachment; selects which Bot API send method
/// (photo/document/video/audio/voice) is used for delivery.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TelegramAttachmentKind {
    Image,
    Document,
    Video,
    Audio,
    Voice,
}
/// One attachment parsed from a `[KIND:target]` marker in an outgoing message.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TelegramAttachment {
    // How the attachment should be sent.
    kind: TelegramAttachmentKind,
    // Local filesystem path or http(s) URL, whitespace-trimmed.
    target: String,
}
impl TelegramAttachmentKind {
    /// Map a marker keyword (case-insensitive, surrounding whitespace ignored)
    /// to its attachment kind; unknown keywords yield `None`.
    fn from_marker(marker: &str) -> Option<Self> {
        let keyword = marker.trim().to_ascii_uppercase();
        match keyword.as_str() {
            "IMAGE" | "PHOTO" => Some(Self::Image),
            "DOCUMENT" | "FILE" => Some(Self::Document),
            "VIDEO" => Some(Self::Video),
            "AUDIO" => Some(Self::Audio),
            "VOICE" => Some(Self::Voice),
            _ => None,
        }
    }
}
/// True when `target` is an absolute `http://` or `https://` URL.
fn is_http_url(target: &str) -> bool {
    ["http://", "https://"]
        .iter()
        .any(|scheme| target.starts_with(scheme))
}
/// Guess the attachment kind from the file extension of a path or URL.
///
/// Query strings and URL fragments are removed first, so
/// `https://host/pic.png?size=2` is still recognized as an image.
/// Unknown or missing extensions yield `None`.
fn infer_attachment_kind_from_target(target: &str) -> Option<TelegramAttachmentKind> {
    // Strip `?query` then `#fragment` so URL targets resolve cleanly.
    let without_query = target.split('?').next().unwrap_or(target);
    let normalized = without_query.split('#').next().unwrap_or(without_query);
    let extension = Path::new(normalized)
        .extension()
        .and_then(|ext| ext.to_str())?
        .to_ascii_lowercase();
    let kind = match extension.as_str() {
        "png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => TelegramAttachmentKind::Image,
        "mp4" | "mov" | "mkv" | "avi" | "webm" => TelegramAttachmentKind::Video,
        "mp3" | "m4a" | "wav" | "flac" => TelegramAttachmentKind::Audio,
        "ogg" | "oga" | "opus" => TelegramAttachmentKind::Voice,
        "pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx" | "xls"
        | "xlsx" | "ppt" | "pptx" => TelegramAttachmentKind::Document,
        _ => return None,
    };
    Some(kind)
}
/// Detect a message that consists solely of a single file path or URL and
/// turn it into an attachment.
///
/// Returns `None` for multi-line or multi-word messages, unknown extensions,
/// and local paths that do not exist on disk.
fn parse_path_only_attachment(message: &str) -> Option<TelegramAttachment> {
    let trimmed = message.trim();
    // Must be a single non-empty line.
    if trimmed.is_empty() || trimmed.contains('\n') {
        return None;
    }
    // Strip wrapping backticks/quotes, then require a single token.
    let unquoted = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\''));
    if unquoted.chars().any(char::is_whitespace) {
        return None;
    }
    let target = unquoted.strip_prefix("file://").unwrap_or(unquoted);
    let kind = infer_attachment_kind_from_target(target)?;
    // URLs are accepted as-is; local paths must actually exist.
    let deliverable = is_http_url(target) || Path::new(target).exists();
    deliverable.then(|| TelegramAttachment {
        kind,
        target: target.to_string(),
    })
}
/// Extract `[KIND:target]` attachment markers from an outgoing message.
///
/// Returns the message text with every recognized marker removed (and the
/// remainder trimmed), together with the parsed attachments in order of
/// appearance. Bracketed spans that are not valid markers — unknown kind,
/// missing `:`, or empty target — are kept verbatim in the text.
fn parse_attachment_markers(message: &str) -> (String, Vec<TelegramAttachment>) {
    let mut cleaned = String::with_capacity(message.len());
    let mut attachments = Vec::new();
    // Byte cursor into `message`; '[' and ']' are ASCII, so the +1 arithmetic
    // below always lands on a char boundary.
    let mut cursor = 0;
    while cursor < message.len() {
        // Everything before the next '[' is plain text.
        let Some(open_rel) = message[cursor..].find('[') else {
            cleaned.push_str(&message[cursor..]);
            break;
        };
        let open = cursor + open_rel;
        cleaned.push_str(&message[cursor..open]);
        // An unmatched '[' means no more markers; keep the tail as-is.
        let Some(close_rel) = message[open..].find(']') else {
            cleaned.push_str(&message[open..]);
            break;
        };
        let close = open + close_rel;
        let marker = &message[open + 1..close];
        // A valid marker is "KIND:target" with a known kind and non-empty target.
        let parsed = marker.split_once(':').and_then(|(kind, target)| {
            let kind = TelegramAttachmentKind::from_marker(kind)?;
            let target = target.trim();
            if target.is_empty() {
                return None;
            }
            Some(TelegramAttachment {
                kind,
                target: target.to_string(),
            })
        });
        if let Some(attachment) = parsed {
            attachments.push(attachment);
        } else {
            // Not a marker — keep the bracketed span in the visible text.
            cleaned.push_str(&message[open..=close]);
        }
        cursor = close + 1;
    }
    (cleaned.trim().to_string(), attachments)
}
/// Telegram channel — long-polls the Bot API for updates /// Telegram channel — long-polls the Bot API for updates
pub struct TelegramChannel { pub struct TelegramChannel {
bot_token: String, bot_token: String,
@ -82,6 +209,216 @@ impl TelegramChannel {
identities.into_iter().any(|id| self.is_user_allowed(id)) identities.into_iter().any(|id| self.is_user_allowed(id))
} }
    /// Convert a raw Bot API update into a [`ChannelMessage`].
    ///
    /// Returns `None` for non-text updates, for senders that are not on the
    /// allowlist, and for updates missing a chat id.
    fn parse_update_message(&self, update: &serde_json::Value) -> Option<ChannelMessage> {
        let message = update.get("message")?;
        // Only plain text messages are handled here.
        let text = message.get("text").and_then(serde_json::Value::as_str)?;

        let username = message
            .get("from")
            .and_then(|from| from.get("username"))
            .and_then(serde_json::Value::as_str)
            .unwrap_or("unknown")
            .to_string();
        let user_id = message
            .get("from")
            .and_then(|from| from.get("id"))
            .and_then(serde_json::Value::as_i64)
            .map(|id| id.to_string());

        // Prefer the @username as the sender identity; fall back to the
        // numeric user id when no username is set.
        let sender_identity = if username == "unknown" {
            user_id.clone().unwrap_or_else(|| "unknown".to_string())
        } else {
            username.clone()
        };

        // Authorization may match either the username or the numeric user id.
        let mut identities = vec![username.as_str()];
        if let Some(id) = user_id.as_deref() {
            identities.push(id);
        }
        if !self.is_any_user_allowed(identities.iter().copied()) {
            tracing::warn!(
                "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
                Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
                user_id.as_deref().unwrap_or("unknown")
            );
            return None;
        }

        // Replies are addressed to the chat the message arrived in, which may
        // be a group rather than the sender's DM.
        let chat_id = message
            .get("chat")
            .and_then(|chat| chat.get("id"))
            .and_then(serde_json::Value::as_i64)
            .map(|id| id.to_string())?;

        let message_id = message
            .get("message_id")
            .and_then(serde_json::Value::as_i64)
            .unwrap_or(0);

        Some(ChannelMessage {
            id: format!("telegram_{chat_id}_{message_id}"),
            sender: sender_identity,
            reply_target: chat_id,
            content: text.to_string(),
            channel: "telegram".to_string(),
            // Local receive time, not Telegram's message date.
            timestamp: std::time::SystemTime::now()
                .duration_since(std::time::UNIX_EPOCH)
                .unwrap_or_default()
                .as_secs(),
        })
    }
    /// Send `message` to `chat_id`, splitting it into Telegram-sized chunks.
    ///
    /// Each chunk is first attempted with Markdown formatting; if Telegram
    /// rejects the Markdown payload, the same text is retried without a
    /// `parse_mode`. An error is returned only when both attempts for a
    /// chunk fail.
    async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
        let chunks = split_message_for_telegram(message);

        for (index, chunk) in chunks.iter().enumerate() {
            // Mark continuation boundaries so readers can follow multi-part replies.
            let text = if chunks.len() > 1 {
                if index == 0 {
                    format!("{chunk}\n\n(continues...)")
                } else if index == chunks.len() - 1 {
                    format!("(continued)\n\n{chunk}")
                } else {
                    format!("(continued)\n\n{chunk}\n\n(continues...)")
                }
            } else {
                chunk.to_string()
            };

            let markdown_body = serde_json::json!({
                "chat_id": chat_id,
                "text": text,
                "parse_mode": "Markdown"
            });

            let markdown_resp = self
                .client
                .post(self.api_url("sendMessage"))
                .json(&markdown_body)
                .send()
                .await?;

            if markdown_resp.status().is_success() {
                // Brief pause between chunks before sending the next one.
                if index < chunks.len() - 1 {
                    tokio::time::sleep(Duration::from_millis(100)).await;
                }
                continue;
            }

            let markdown_status = markdown_resp.status();
            let markdown_err = markdown_resp.text().await.unwrap_or_default();
            tracing::warn!(
                status = ?markdown_status,
                "Telegram sendMessage with Markdown failed; retrying without parse_mode"
            );

            // Fallback: resend the same text as plain text (no parse_mode).
            let plain_body = serde_json::json!({
                "chat_id": chat_id,
                "text": text,
            });
            let plain_resp = self
                .client
                .post(self.api_url("sendMessage"))
                .json(&plain_body)
                .send()
                .await?;

            if !plain_resp.status().is_success() {
                let plain_status = plain_resp.status();
                let plain_err = plain_resp.text().await.unwrap_or_default();
                anyhow::bail!(
                    "Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
                    markdown_status,
                    markdown_err,
                    plain_status,
                    plain_err
                );
            }

            if index < chunks.len() - 1 {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
        }

        Ok(())
    }
    /// Send media to `chat_id` by URL via the given Bot API `method`
    /// (e.g. `sendVideo`), placing the URL in `media_field` so Telegram
    /// downloads the file server-side.
    ///
    /// # Errors
    /// Fails when the HTTP request fails or Telegram returns a non-success
    /// status.
    async fn send_media_by_url(
        &self,
        method: &str,
        media_field: &str,
        chat_id: &str,
        url: &str,
        caption: Option<&str>,
    ) -> anyhow::Result<()> {
        let mut body = serde_json::json!({
            "chat_id": chat_id,
        });
        // The media field name differs per method ("video", "audio", "voice").
        body[media_field] = serde_json::Value::String(url.to_string());
        if let Some(cap) = caption {
            body["caption"] = serde_json::Value::String(cap.to_string());
        }

        let resp = self
            .client
            .post(self.api_url(method))
            .json(&body)
            .send()
            .await?;

        if !resp.status().is_success() {
            let err = resp.text().await?;
            anyhow::bail!("Telegram {method} by URL failed: {err}");
        }

        tracing::info!("Telegram {method} sent to {chat_id}: {url}");
        Ok(())
    }
async fn send_attachment(
&self,
chat_id: &str,
attachment: &TelegramAttachment,
) -> anyhow::Result<()> {
let target = attachment.target.trim();
if is_http_url(target) {
return match attachment.kind {
TelegramAttachmentKind::Image => {
self.send_photo_by_url(chat_id, target, None).await
}
TelegramAttachmentKind::Document => {
self.send_document_by_url(chat_id, target, None).await
}
TelegramAttachmentKind::Video => {
self.send_video_by_url(chat_id, target, None).await
}
TelegramAttachmentKind::Audio => {
self.send_audio_by_url(chat_id, target, None).await
}
TelegramAttachmentKind::Voice => {
self.send_voice_by_url(chat_id, target, None).await
}
};
}
let path = Path::new(target);
if !path.exists() {
anyhow::bail!("Telegram attachment path not found: {target}");
}
match attachment.kind {
TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await,
TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await,
TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await,
TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await,
TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await,
}
}
/// Send a document/file to a Telegram chat /// Send a document/file to a Telegram chat
pub async fn send_document( pub async fn send_document(
&self, &self,
@ -408,6 +745,39 @@ impl TelegramChannel {
tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}"); tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}");
Ok(()) Ok(())
} }
/// Send a video by URL (Telegram will download it)
pub async fn send_video_by_url(
&self,
chat_id: &str,
url: &str,
caption: Option<&str>,
) -> anyhow::Result<()> {
self.send_media_by_url("sendVideo", "video", chat_id, url, caption)
.await
}
/// Send an audio file by URL (Telegram will download it)
pub async fn send_audio_by_url(
&self,
chat_id: &str,
url: &str,
caption: Option<&str>,
) -> anyhow::Result<()> {
self.send_media_by_url("sendAudio", "audio", chat_id, url, caption)
.await
}
/// Send a voice message by URL (Telegram will download it)
pub async fn send_voice_by_url(
&self,
chat_id: &str,
url: &str,
caption: Option<&str>,
) -> anyhow::Result<()> {
self.send_media_by_url("sendVoice", "voice", chat_id, url, caption)
.await
}
} }
#[async_trait] #[async_trait]
@ -417,82 +787,27 @@ impl Channel for TelegramChannel {
} }
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> { async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
// Split message if it exceeds Telegram's 4096 character limit let (text_without_markers, attachments) = parse_attachment_markers(message);
let chunks = split_message_for_telegram(message);
for (i, chunk) in chunks.iter().enumerate() { if !attachments.is_empty() {
// Add continuation marker for multi-part messages if !text_without_markers.is_empty() {
let text = if chunks.len() > 1 { self.send_text_chunks(&text_without_markers, chat_id)
if i == 0 { .await?;
format!("{chunk}\n\n(continues...)")
} else if i == chunks.len() - 1 {
format!("(continued)\n\n{chunk}")
} else {
format!("(continued)\n\n{chunk}\n\n(continues...)")
}
} else {
chunk.to_string()
};
let markdown_body = serde_json::json!({
"chat_id": chat_id,
"text": text,
"parse_mode": "Markdown"
});
let markdown_resp = self
.client
.post(self.api_url("sendMessage"))
.json(&markdown_body)
.send()
.await?;
if markdown_resp.status().is_success() {
// Small delay between chunks to avoid rate limiting
if i < chunks.len() - 1 {
tokio::time::sleep(Duration::from_millis(100)).await;
}
continue;
} }
let markdown_status = markdown_resp.status(); for attachment in &attachments {
let markdown_err = markdown_resp.text().await.unwrap_or_default(); self.send_attachment(chat_id, attachment).await?;
tracing::warn!(
status = ?markdown_status,
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
);
// Retry without parse_mode as a compatibility fallback.
let plain_body = serde_json::json!({
"chat_id": chat_id,
"text": text,
});
let plain_resp = self
.client
.post(self.api_url("sendMessage"))
.json(&plain_body)
.send()
.await?;
if !plain_resp.status().is_success() {
let plain_status = plain_resp.status();
let plain_err = plain_resp.text().await.unwrap_or_default();
anyhow::bail!(
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
markdown_status,
markdown_err,
plain_status,
plain_err
);
} }
// Small delay between chunks to avoid rate limiting return Ok(());
if i < chunks.len() - 1 {
tokio::time::sleep(Duration::from_millis(100)).await;
}
} }
Ok(()) if let Some(attachment) = parse_path_only_attachment(message) {
self.send_attachment(chat_id, &attachment).await?;
return Ok(());
}
self.send_text_chunks(message, chat_id).await
} }
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> { async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
@ -533,59 +848,13 @@ impl Channel for TelegramChannel {
offset = uid + 1; offset = uid + 1;
} }
let Some(message) = update.get("message") else { let Some(msg) = self.parse_update_message(update) else {
continue; continue;
}; };
let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
continue;
};
let username_opt = message
.get("from")
.and_then(|f| f.get("username"))
.and_then(|u| u.as_str());
let username = username_opt.unwrap_or("unknown");
let user_id = message
.get("from")
.and_then(|f| f.get("id"))
.and_then(serde_json::Value::as_i64);
let user_id_str = user_id.map(|id| id.to_string());
let mut identities = vec![username];
if let Some(ref id) = user_id_str {
identities.push(id.as_str());
}
if !self.is_any_user_allowed(identities.iter().copied()) {
tracing::warn!(
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
user_id_str.as_deref().unwrap_or("unknown")
);
continue;
}
let chat_id = message
.get("chat")
.and_then(|c| c.get("id"))
.and_then(serde_json::Value::as_i64)
.map(|id| id.to_string());
let Some(chat_id) = chat_id else {
tracing::warn!("Telegram: missing chat_id in message, skipping");
continue;
};
let message_id = message
.get("message_id")
.and_then(|v| v.as_i64())
.unwrap_or(0);
// Send "typing" indicator immediately when we receive a message // Send "typing" indicator immediately when we receive a message
let typing_body = serde_json::json!({ let typing_body = serde_json::json!({
"chat_id": &chat_id, "chat_id": &msg.reply_target,
"action": "typing" "action": "typing"
}); });
let _ = self let _ = self
@ -595,17 +864,6 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch
.send() .send()
.await; // Ignore errors for typing indicator .await; // Ignore errors for typing indicator
let msg = ChannelMessage {
id: format!("telegram_{chat_id}_{message_id}"),
sender: username.to_string(),
content: text.to_string(),
channel: "telegram".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(msg).await.is_err() { if tx.send(msg).await.is_err() {
return Ok(()); return Ok(());
} }
@ -716,6 +974,107 @@ mod tests {
assert!(!ch.is_any_user_allowed(["unknown", "123456789"])); assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
} }
#[test]
fn parse_attachment_markers_extracts_multiple_types() {
let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]";
let (cleaned, attachments) = parse_attachment_markers(message);
assert_eq!(cleaned, "Here are files and");
assert_eq!(attachments.len(), 2);
assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image);
assert_eq!(attachments[0].target, "/tmp/a.png");
assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document);
assert_eq!(attachments[1].target, "https://example.com/a.pdf");
}
#[test]
fn parse_attachment_markers_keeps_invalid_markers_in_text() {
let message = "Report [UNKNOWN:/tmp/a.bin]";
let (cleaned, attachments) = parse_attachment_markers(message);
assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]");
assert!(attachments.is_empty());
}
#[test]
fn parse_path_only_attachment_detects_existing_file() {
let dir = tempfile::tempdir().unwrap();
let image_path = dir.path().join("snap.png");
std::fs::write(&image_path, b"fake-png").unwrap();
let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref())
.expect("expected attachment");
assert_eq!(parsed.kind, TelegramAttachmentKind::Image);
assert_eq!(parsed.target, image_path.to_string_lossy());
}
#[test]
fn parse_path_only_attachment_rejects_sentence_text() {
assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none());
}
#[test]
fn infer_attachment_kind_from_target_detects_document_extension() {
assert_eq!(
infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"),
Some(TelegramAttachmentKind::Document)
);
}
#[test]
fn parse_update_message_uses_chat_id_as_reply_target() {
let ch = TelegramChannel::new("token".into(), vec!["*".into()]);
let update = serde_json::json!({
"update_id": 1,
"message": {
"message_id": 33,
"text": "hello",
"from": {
"id": 555,
"username": "alice"
},
"chat": {
"id": -100200300
}
}
});
let msg = ch
.parse_update_message(&update)
.expect("message should parse");
assert_eq!(msg.sender, "alice");
assert_eq!(msg.reply_target, "-100200300");
assert_eq!(msg.content, "hello");
assert_eq!(msg.id, "telegram_-100200300_33");
}
#[test]
fn parse_update_message_allows_numeric_id_without_username() {
let ch = TelegramChannel::new("token".into(), vec!["555".into()]);
let update = serde_json::json!({
"update_id": 2,
"message": {
"message_id": 9,
"text": "ping",
"from": {
"id": 555
},
"chat": {
"id": 12345
}
}
});
let msg = ch
.parse_update_message(&update)
.expect("numeric allowlist should pass");
assert_eq!(msg.sender, "555");
assert_eq!(msg.reply_target, "12345");
}
// ── File sending API URL tests ────────────────────────────────── // ── File sending API URL tests ──────────────────────────────────
#[test] #[test]

View file

@ -5,6 +5,7 @@ use async_trait::async_trait;
pub struct ChannelMessage { pub struct ChannelMessage {
pub id: String, pub id: String,
pub sender: String, pub sender: String,
pub reply_target: String,
pub content: String, pub content: String,
pub channel: String, pub channel: String,
pub timestamp: u64, pub timestamp: u64,
@ -62,6 +63,7 @@ mod tests {
tx.send(ChannelMessage { tx.send(ChannelMessage {
id: "1".into(), id: "1".into(),
sender: "tester".into(), sender: "tester".into(),
reply_target: "tester".into(),
content: "hello".into(), content: "hello".into(),
channel: "dummy".into(), channel: "dummy".into(),
timestamp: 123, timestamp: 123,
@ -76,6 +78,7 @@ mod tests {
let message = ChannelMessage { let message = ChannelMessage {
id: "42".into(), id: "42".into(),
sender: "alice".into(), sender: "alice".into(),
reply_target: "alice".into(),
content: "ping".into(), content: "ping".into(),
channel: "dummy".into(), channel: "dummy".into(),
timestamp: 999, timestamp: 999,
@ -84,6 +87,7 @@ mod tests {
let cloned = message.clone(); let cloned = message.clone();
assert_eq!(cloned.id, "42"); assert_eq!(cloned.id, "42");
assert_eq!(cloned.sender, "alice"); assert_eq!(cloned.sender, "alice");
assert_eq!(cloned.reply_target, "alice");
assert_eq!(cloned.content, "ping"); assert_eq!(cloned.content, "ping");
assert_eq!(cloned.channel, "dummy"); assert_eq!(cloned.channel, "dummy");
assert_eq!(cloned.timestamp, 999); assert_eq!(cloned.timestamp, 999);

View file

@ -10,7 +10,7 @@ use uuid::Uuid;
/// happens in the gateway when Meta sends webhook events. /// happens in the gateway when Meta sends webhook events.
pub struct WhatsAppChannel { pub struct WhatsAppChannel {
access_token: String, access_token: String,
phone_number_id: String, endpoint_id: String,
verify_token: String, verify_token: String,
allowed_numbers: Vec<String>, allowed_numbers: Vec<String>,
client: reqwest::Client, client: reqwest::Client,
@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
impl WhatsAppChannel { impl WhatsAppChannel {
pub fn new( pub fn new(
access_token: String, access_token: String,
phone_number_id: String, endpoint_id: String,
verify_token: String, verify_token: String,
allowed_numbers: Vec<String>, allowed_numbers: Vec<String>,
) -> Self { ) -> Self {
Self { Self {
access_token, access_token,
phone_number_id, endpoint_id,
verify_token, verify_token,
allowed_numbers, allowed_numbers,
client: reqwest::Client::new(), client: reqwest::Client::new(),
@ -119,6 +119,7 @@ impl WhatsAppChannel {
messages.push(ChannelMessage { messages.push(ChannelMessage {
id: Uuid::new_v4().to_string(), id: Uuid::new_v4().to_string(),
reply_target: normalized_from.clone(),
sender: normalized_from, sender: normalized_from,
content, content,
channel: "whatsapp".to_string(), channel: "whatsapp".to_string(),
@ -142,7 +143,7 @@ impl Channel for WhatsAppChannel {
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages // WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
let url = format!( let url = format!(
"https://graph.facebook.com/v18.0/{}/messages", "https://graph.facebook.com/v18.0/{}/messages",
self.phone_number_id self.endpoint_id
); );
// Normalize recipient (remove leading + if present for API) // Normalize recipient (remove leading + if present for API)
@ -162,7 +163,7 @@ impl Channel for WhatsAppChannel {
let resp = self let resp = self
.client .client
.post(&url) .post(&url)
.header("Authorization", format!("Bearer {}", self.access_token)) .bearer_auth(&self.access_token)
.header("Content-Type", "application/json") .header("Content-Type", "application/json")
.json(&body) .json(&body)
.send() .send()
@ -195,11 +196,11 @@ impl Channel for WhatsAppChannel {
async fn health_check(&self) -> bool { async fn health_check(&self) -> bool {
// Check if we can reach the WhatsApp API // Check if we can reach the WhatsApp API
let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id); let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
self.client self.client
.get(&url) .get(&url)
.header("Authorization", format!("Bearer {}", self.access_token)) .bearer_auth(&self.access_token)
.send() .send()
.await .await
.map(|r| r.status().is_success()) .map(|r| r.status().is_success())

View file

@ -37,9 +37,22 @@ mod tests {
guild_id: Some("123".into()), guild_id: Some("123".into()),
allowed_users: vec![], allowed_users: vec![],
listen_to_bots: false, listen_to_bots: false,
mention_only: false,
};
let lark = LarkConfig {
app_id: "app-id".into(),
app_secret: "app-secret".into(),
encrypt_key: None,
verification_token: None,
allowed_users: vec![],
use_feishu: false,
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
port: None,
}; };
assert_eq!(telegram.allowed_users.len(), 1); assert_eq!(telegram.allowed_users.len(), 1);
assert_eq!(discord.guild_id.as_deref(), Some("123")); assert_eq!(discord.guild_id.as_deref(), Some("123"));
assert_eq!(lark.app_id, "app-id");
} }
} }

View file

@ -18,6 +18,8 @@ pub struct Config {
#[serde(skip)] #[serde(skip)]
pub config_path: PathBuf, pub config_path: PathBuf,
pub api_key: Option<String>, pub api_key: Option<String>,
/// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama)
pub api_url: Option<String>,
pub default_provider: Option<String>, pub default_provider: Option<String>,
pub default_model: Option<String>, pub default_model: Option<String>,
pub default_temperature: f64, pub default_temperature: f64,
@ -1317,6 +1319,10 @@ pub struct DiscordConfig {
/// The bot still ignores its own messages to prevent feedback loops. /// The bot still ignores its own messages to prevent feedback loops.
#[serde(default)] #[serde(default)]
pub listen_to_bots: bool, pub listen_to_bots: bool,
/// When true, only respond to messages that @-mention the bot.
/// Other messages in the guild are silently ignored.
#[serde(default)]
pub mention_only: bool,
} }
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
@ -1395,8 +1401,20 @@ fn default_irc_port() -> u16 {
6697 6697
} }
/// Lark/Feishu configuration for messaging integration /// How ZeroClaw receives events from Feishu / Lark.
/// Lark is the international version, Feishu is the Chinese version ///
/// - `websocket` (default) — persistent WSS long-connection; no public URL required.
/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LarkReceiveMode {
    /// Persistent WSS long-connection; no public URL required (default).
    #[default]
    Websocket,
    /// HTTP callback server; requires a public HTTPS endpoint.
    Webhook,
}
/// Lark/Feishu configuration for messaging integration.
/// Lark is the international version; Feishu is the Chinese version.
#[derive(Debug, Clone, Serialize, Deserialize)] #[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LarkConfig { pub struct LarkConfig {
/// App ID from Lark/Feishu developer console /// App ID from Lark/Feishu developer console
@ -1415,6 +1433,13 @@ pub struct LarkConfig {
/// Whether to use the Feishu (Chinese) endpoint instead of Lark (International) /// Whether to use the Feishu (Chinese) endpoint instead of Lark (International)
#[serde(default)] #[serde(default)]
pub use_feishu: bool, pub use_feishu: bool,
/// Event receive mode: "websocket" (default) or "webhook"
#[serde(default)]
pub receive_mode: LarkReceiveMode,
/// HTTP port for webhook mode only. Must be set when receive_mode = "webhook".
/// Not required (and ignored) for websocket mode.
#[serde(default)]
pub port: Option<u16>,
} }
// ── Security Config ───────────────────────────────────────────────── // ── Security Config ─────────────────────────────────────────────────
@ -1594,6 +1619,7 @@ impl Default for Config {
workspace_dir: zeroclaw_dir.join("workspace"), workspace_dir: zeroclaw_dir.join("workspace"),
config_path: zeroclaw_dir.join("config.toml"), config_path: zeroclaw_dir.join("config.toml"),
api_key: None, api_key: None,
api_url: None,
default_provider: Some("openrouter".to_string()), default_provider: Some("openrouter".to_string()),
default_model: Some("anthropic/claude-sonnet-4".to_string()), default_model: Some("anthropic/claude-sonnet-4".to_string()),
default_temperature: 0.7, default_temperature: 0.7,
@ -1623,35 +1649,146 @@ impl Default for Config {
} }
} }
impl Config { fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> {
pub fn load_or_init() -> Result<Self> { let home = UserDirs::new()
let home = UserDirs::new() .map(|u| u.home_dir().to_path_buf())
.map(|u| u.home_dir().to_path_buf()) .context("Could not find home directory")?;
.context("Could not find home directory")?; let config_dir = home.join(".zeroclaw");
let zeroclaw_dir = home.join(".zeroclaw"); Ok((config_dir.clone(), config_dir.join("workspace")))
let config_path = zeroclaw_dir.join("config.toml"); }
if !zeroclaw_dir.exists() { fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?; let workspace_config_dir = workspace_dir.to_path_buf();
fs::create_dir_all(zeroclaw_dir.join("workspace")) if workspace_config_dir.join("config.toml").exists() {
.context("Failed to create workspace directory")?; return workspace_config_dir;
}
let legacy_config_dir = workspace_dir
.parent()
.map(|parent| parent.join(".zeroclaw"));
if let Some(legacy_dir) = legacy_config_dir {
if legacy_dir.join("config.toml").exists() {
return legacy_dir;
} }
if workspace_dir
.file_name()
.is_some_and(|name| name == std::ffi::OsStr::new("workspace"))
{
return legacy_dir;
}
}
workspace_config_dir
}
/// Decrypt `value` in place when it holds an encrypted secret.
///
/// Plaintext values (and `None`) are left untouched, so the helper is safe to
/// call on configs that were saved with encryption disabled.
///
/// # Errors
/// Returns an error (annotated with `field_name` for diagnostics) when the
/// secret store fails to decrypt the ciphertext.
fn decrypt_optional_secret(
    store: &crate::security::SecretStore,
    value: &mut Option<String>,
    field_name: &str,
) -> Result<()> {
    // Borrow instead of cloning the whole String; only ciphertext needs work.
    if let Some(raw) = value.as_deref() {
        if crate::security::SecretStore::is_encrypted(raw) {
            let plaintext = store
                .decrypt(raw)
                .with_context(|| format!("Failed to decrypt {field_name}"))?;
            *value = Some(plaintext);
        }
    }
    Ok(())
}
/// Encrypt `value` in place when it holds a plaintext secret.
///
/// Already-encrypted values (and `None`) are left untouched, which makes the
/// helper idempotent across repeated saves.
///
/// # Errors
/// Returns an error (annotated with `field_name` for diagnostics) when the
/// secret store fails to encrypt the plaintext.
fn encrypt_optional_secret(
    store: &crate::security::SecretStore,
    value: &mut Option<String>,
    field_name: &str,
) -> Result<()> {
    // Borrow instead of cloning the whole String; only plaintext needs work.
    if let Some(raw) = value.as_deref() {
        if !crate::security::SecretStore::is_encrypted(raw) {
            let ciphertext = store
                .encrypt(raw)
                .with_context(|| format!("Failed to encrypt {field_name}"))?;
            *value = Some(ciphertext);
        }
    }
    Ok(())
}
impl Config {
pub fn load_or_init() -> Result<Self> {
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") {
Ok(custom_workspace) if !custom_workspace.is_empty() => {
let workspace = PathBuf::from(custom_workspace);
(resolve_config_dir_for_workspace(&workspace), workspace)
}
_ => default_config_and_workspace_dirs()?,
};
let config_path = zeroclaw_dir.join("config.toml");
fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?;
fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?;
if config_path.exists() { if config_path.exists() {
// Warn if config file is world-readable (may contain API keys)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Ok(meta) = fs::metadata(&config_path) {
if meta.permissions().mode() & 0o004 != 0 {
tracing::warn!(
"Config file {:?} is world-readable (mode {:o}). \
Consider restricting with: chmod 600 {:?}",
config_path,
meta.permissions().mode() & 0o777,
config_path,
);
}
}
}
let contents = let contents =
fs::read_to_string(&config_path).context("Failed to read config file")?; fs::read_to_string(&config_path).context("Failed to read config file")?;
let mut config: Config = let mut config: Config =
toml::from_str(&contents).context("Failed to parse config file")?; toml::from_str(&contents).context("Failed to parse config file")?;
// Set computed paths that are skipped during serialization // Set computed paths that are skipped during serialization
config.config_path = config_path.clone(); config.config_path = config_path.clone();
config.workspace_dir = zeroclaw_dir.join("workspace"); config.workspace_dir = workspace_dir;
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
decrypt_optional_secret(
&store,
&mut config.composio.api_key,
"config.composio.api_key",
)?;
decrypt_optional_secret(
&store,
&mut config.browser.computer_use.api_key,
"config.browser.computer_use.api_key",
)?;
for agent in config.agents.values_mut() {
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
}
config.apply_env_overrides(); config.apply_env_overrides();
Ok(config) Ok(config)
} else { } else {
let mut config = Config::default(); let mut config = Config::default();
config.config_path = config_path.clone(); config.config_path = config_path.clone();
config.workspace_dir = zeroclaw_dir.join("workspace"); config.workspace_dir = workspace_dir;
config.save()?; config.save()?;
// Restrict permissions on newly created config file (may contain API keys)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600));
}
config.apply_env_overrides(); config.apply_env_overrides();
Ok(config) Ok(config)
} }
@ -1732,23 +1869,29 @@ impl Config {
} }
pub fn save(&self) -> Result<()> { pub fn save(&self) -> Result<()> {
// Encrypt agent API keys before serialization // Encrypt secrets before serialization
let mut config_to_save = self.clone(); let mut config_to_save = self.clone();
let zeroclaw_dir = self let zeroclaw_dir = self
.config_path .config_path
.parent() .parent()
.context("Config path must have a parent directory")?; .context("Config path must have a parent directory")?;
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt); let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
encrypt_optional_secret(
&store,
&mut config_to_save.composio.api_key,
"config.composio.api_key",
)?;
encrypt_optional_secret(
&store,
&mut config_to_save.browser.computer_use.api_key,
"config.browser.computer_use.api_key",
)?;
for agent in config_to_save.agents.values_mut() { for agent in config_to_save.agents.values_mut() {
if let Some(ref plaintext_key) = agent.api_key { encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
if !crate::security::SecretStore::is_encrypted(plaintext_key) {
agent.api_key = Some(
store
.encrypt(plaintext_key)
.context("Failed to encrypt agent API key")?,
);
}
}
} }
let toml_str = let toml_str =
@ -1949,6 +2092,7 @@ default_temperature = 0.7
workspace_dir: PathBuf::from("/tmp/test/workspace"), workspace_dir: PathBuf::from("/tmp/test/workspace"),
config_path: PathBuf::from("/tmp/test/config.toml"), config_path: PathBuf::from("/tmp/test/config.toml"),
api_key: Some("sk-test-key".into()), api_key: Some("sk-test-key".into()),
api_url: None,
default_provider: Some("openrouter".into()), default_provider: Some("openrouter".into()),
default_model: Some("gpt-4o".into()), default_model: Some("gpt-4o".into()),
default_temperature: 0.5, default_temperature: 0.5,
@ -2091,6 +2235,7 @@ tool_dispatcher = "xml"
workspace_dir: dir.join("workspace"), workspace_dir: dir.join("workspace"),
config_path: config_path.clone(), config_path: config_path.clone(),
api_key: Some("sk-roundtrip".into()), api_key: Some("sk-roundtrip".into()),
api_url: None,
default_provider: Some("openrouter".into()), default_provider: Some("openrouter".into()),
default_model: Some("test-model".into()), default_model: Some("test-model".into()),
default_temperature: 0.9, default_temperature: 0.9,
@ -2123,13 +2268,82 @@ tool_dispatcher = "xml"
let contents = fs::read_to_string(&config_path).unwrap(); let contents = fs::read_to_string(&config_path).unwrap();
let loaded: Config = toml::from_str(&contents).unwrap(); let loaded: Config = toml::from_str(&contents).unwrap();
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip")); assert!(loaded
.api_key
.as_deref()
.is_some_and(crate::security::SecretStore::is_encrypted));
let store = crate::security::SecretStore::new(&dir, true);
let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
assert_eq!(decrypted, "sk-roundtrip");
assert_eq!(loaded.default_model.as_deref(), Some("test-model")); assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON); assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
let _ = fs::remove_dir_all(&dir); let _ = fs::remove_dir_all(&dir);
} }
#[test]
fn config_save_encrypts_nested_credentials() {
let dir = std::env::temp_dir().join(format!(
"zeroclaw_test_nested_credentials_{}",
uuid::Uuid::new_v4()
));
fs::create_dir_all(&dir).unwrap();
let mut config = Config::default();
config.workspace_dir = dir.join("workspace");
config.config_path = dir.join("config.toml");
config.api_key = Some("root-credential".into());
config.composio.api_key = Some("composio-credential".into());
config.browser.computer_use.api_key = Some("browser-credential".into());
config.agents.insert(
"worker".into(),
DelegateAgentConfig {
provider: "openrouter".into(),
model: "model-test".into(),
system_prompt: None,
api_key: Some("agent-credential".into()),
temperature: None,
max_depth: 3,
},
);
config.save().unwrap();
let contents = fs::read_to_string(config.config_path.clone()).unwrap();
let stored: Config = toml::from_str(&contents).unwrap();
let store = crate::security::SecretStore::new(&dir, true);
let root_encrypted = stored.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(
composio_encrypted
));
assert_eq!(
store.decrypt(composio_encrypted).unwrap(),
"composio-credential"
);
let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(
browser_encrypted
));
assert_eq!(
store.decrypt(browser_encrypted).unwrap(),
"browser-credential"
);
let worker = stored.agents.get("worker").unwrap();
let worker_encrypted = worker.api_key.as_deref().unwrap();
assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
let _ = fs::remove_dir_all(&dir);
}
#[test] #[test]
fn config_save_atomic_cleanup() { fn config_save_atomic_cleanup() {
let dir = let dir =
@ -2182,6 +2396,7 @@ tool_dispatcher = "xml"
guild_id: Some("12345".into()), guild_id: Some("12345".into()),
allowed_users: vec![], allowed_users: vec![],
listen_to_bots: false, listen_to_bots: false,
mention_only: false,
}; };
let json = serde_json::to_string(&dc).unwrap(); let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap(); let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@ -2196,6 +2411,7 @@ tool_dispatcher = "xml"
guild_id: None, guild_id: None,
allowed_users: vec![], allowed_users: vec![],
listen_to_bots: false, listen_to_bots: false,
mention_only: false,
}; };
let json = serde_json::to_string(&dc).unwrap(); let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap(); let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@ -2818,6 +3034,96 @@ default_temperature = 0.7
std::env::remove_var("ZEROCLAW_WORKSPACE"); std::env::remove_var("ZEROCLAW_WORKSPACE");
} }
#[test]
fn load_or_init_workspace_override_uses_workspace_root_for_config() {
let _env_guard = env_override_test_guard();
let temp_home =
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
let workspace_dir = temp_home.join("profile-a");
let original_home = std::env::var("HOME").ok();
std::env::set_var("HOME", &temp_home);
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
let config = Config::load_or_init().unwrap();
assert_eq!(config.workspace_dir, workspace_dir);
assert_eq!(config.config_path, workspace_dir.join("config.toml"));
assert!(workspace_dir.join("config.toml").exists());
std::env::remove_var("ZEROCLAW_WORKSPACE");
if let Some(home) = original_home {
std::env::set_var("HOME", home);
} else {
std::env::remove_var("HOME");
}
let _ = fs::remove_dir_all(temp_home);
}
#[test]
fn load_or_init_workspace_suffix_uses_legacy_config_layout() {
let _env_guard = env_override_test_guard();
let temp_home =
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
let workspace_dir = temp_home.join("workspace");
let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml");
let original_home = std::env::var("HOME").ok();
std::env::set_var("HOME", &temp_home);
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
let config = Config::load_or_init().unwrap();
assert_eq!(config.workspace_dir, workspace_dir);
assert_eq!(config.config_path, legacy_config_path);
assert!(config.config_path.exists());
std::env::remove_var("ZEROCLAW_WORKSPACE");
if let Some(home) = original_home {
std::env::set_var("HOME", home);
} else {
std::env::remove_var("HOME");
}
let _ = fs::remove_dir_all(temp_home);
}
#[test]
fn load_or_init_workspace_override_keeps_existing_legacy_config() {
let _env_guard = env_override_test_guard();
let temp_home =
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
let workspace_dir = temp_home.join("custom-workspace");
let legacy_config_dir = temp_home.join(".zeroclaw");
let legacy_config_path = legacy_config_dir.join("config.toml");
fs::create_dir_all(&legacy_config_dir).unwrap();
fs::write(
&legacy_config_path,
r#"default_temperature = 0.7
default_model = "legacy-model"
"#,
)
.unwrap();
let original_home = std::env::var("HOME").ok();
std::env::set_var("HOME", &temp_home);
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
let config = Config::load_or_init().unwrap();
assert_eq!(config.workspace_dir, workspace_dir);
assert_eq!(config.config_path, legacy_config_path);
assert_eq!(config.default_model.as_deref(), Some("legacy-model"));
std::env::remove_var("ZEROCLAW_WORKSPACE");
if let Some(home) = original_home {
std::env::set_var("HOME", home);
} else {
std::env::remove_var("HOME");
}
let _ = fs::remove_dir_all(temp_home);
}
#[test] #[test]
fn env_override_empty_values_ignored() { fn env_override_empty_values_ignored() {
let _env_guard = env_override_test_guard(); let _env_guard = env_override_test_guard();
@ -2975,4 +3281,118 @@ default_temperature = 0.7
assert_eq!(parsed.boards[0].board, "nucleo-f401re"); assert_eq!(parsed.boards[0].board, "nucleo-f401re");
assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0")); assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0"));
} }
#[test]
fn lark_config_serde() {
    // Round-trip a fully-populated LarkConfig through JSON and verify that
    // every field survives serialization and deserialization.
    let original = LarkConfig {
        app_id: "cli_123456".into(),
        app_secret: "secret_abc".into(),
        encrypt_key: Some("encrypt_key".into()),
        verification_token: Some("verify_token".into()),
        allowed_users: vec!["user_123".into(), "user_456".into()],
        use_feishu: true,
        receive_mode: LarkReceiveMode::Websocket,
        port: None,
    };
    let encoded = serde_json::to_string(&original).unwrap();
    let decoded: LarkConfig = serde_json::from_str(&encoded).unwrap();
    assert_eq!(decoded.app_id, "cli_123456");
    assert_eq!(decoded.app_secret, "secret_abc");
    assert_eq!(decoded.encrypt_key.as_deref(), Some("encrypt_key"));
    assert_eq!(decoded.verification_token.as_deref(), Some("verify_token"));
    assert_eq!(decoded.allowed_users.len(), 2);
    assert!(decoded.use_feishu);
}
#[test]
fn lark_config_toml_roundtrip() {
    // TOML round-trip: a webhook-mode config with an explicit port must keep
    // its identifying fields after encode/decode.
    let original = LarkConfig {
        app_id: "cli_123456".into(),
        app_secret: "secret_abc".into(),
        encrypt_key: Some("encrypt_key".into()),
        verification_token: Some("verify_token".into()),
        allowed_users: vec!["*".into()],
        use_feishu: false,
        receive_mode: LarkReceiveMode::Webhook,
        port: Some(9898),
    };
    let encoded = toml::to_string(&original).unwrap();
    let decoded: LarkConfig = toml::from_str(&encoded).unwrap();
    assert_eq!(decoded.app_id, "cli_123456");
    assert_eq!(decoded.app_secret, "secret_abc");
    assert!(!decoded.use_feishu);
}
#[test]
fn lark_config_deserializes_without_optional_fields() {
    // Only the two required fields are supplied; all optional fields must
    // fall back to their serde defaults.
    let minimal = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
    let cfg: LarkConfig = serde_json::from_str(minimal).unwrap();
    assert!(cfg.encrypt_key.is_none());
    assert!(cfg.verification_token.is_none());
    assert!(cfg.allowed_users.is_empty());
    assert!(!cfg.use_feishu);
}
#[test]
fn lark_config_defaults_to_lark_endpoint() {
    // Omitting use_feishu must select the Lark endpoint, not Feishu.
    let minimal = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
    let cfg: LarkConfig = serde_json::from_str(minimal).unwrap();
    assert!(
        !cfg.use_feishu,
        "use_feishu should default to false (Lark)"
    );
}
#[test]
fn lark_config_with_wildcard_allowed_users() {
    // The "*" wildcard entry must survive deserialization verbatim.
    let raw = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#;
    let cfg: LarkConfig = serde_json::from_str(raw).unwrap();
    assert_eq!(cfg.allowed_users, vec!["*"]);
}
// ── Config file permission hardening (Unix only) ───────────────
#[cfg(unix)]
#[test]
fn new_config_file_has_restricted_permissions() {
    // Verifies that a freshly saved config file can be locked down to
    // owner-only (0600) and that the mode persists on this filesystem.
    //
    // NOTE(review): this test re-applies the chmod inline ("same permission
    // logic as load_or_init") instead of exercising load_or_init itself, so
    // it primarily checks that 0600 sticks — it would not catch load_or_init
    // dropping its permission call. Consider driving load_or_init directly.
    use std::os::unix::fs::PermissionsExt;
    let tmp = tempfile::TempDir::new().unwrap();
    let config_path = tmp.path().join("config.toml");
    // Create a config and save it
    let mut config = Config::default();
    config.config_path = config_path.clone();
    config.save().unwrap();
    // Apply the same permission logic as load_or_init
    let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600));
    let meta = std::fs::metadata(&config_path).unwrap();
    // Mask to the permission bits only (drop file-type bits).
    let mode = meta.permissions().mode() & 0o777;
    assert_eq!(
        mode, 0o600,
        "New config file should be owner-only (0600), got {mode:o}"
    );
}
#[cfg(unix)]
#[test]
fn world_readable_config_is_detectable() {
    // Confirms that a world-readable config file (0644) is observable via
    // fs::metadata, i.e. the 0o004 "other-read" bit is set after chmod.
    //
    // NOTE(review): this only validates the test's own setup — it creates the
    // loose permissions and then checks they exist. It does not assert that
    // the application warns on or rejects such a file; that behavior is
    // untested here.
    use std::os::unix::fs::PermissionsExt;
    let tmp = tempfile::TempDir::new().unwrap();
    let config_path = tmp.path().join("config.toml");
    // Create a config file with intentionally loose permissions
    std::fs::write(&config_path, "# test config").unwrap();
    std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap();
    let meta = std::fs::metadata(&config_path).unwrap();
    let mode = meta.permissions().mode();
    assert!(
        mode & 0o004 != 0,
        "Test setup: file should be world-readable (mode {mode:o})"
    );
}
} }

View file

@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
dc.guild_id.clone(), dc.guild_id.clone(),
dc.allowed_users.clone(), dc.allowed_users.clone(),
dc.listen_to_bots, dc.listen_to_bots,
dc.mention_only,
); );
channel.send(output, target).await?; channel.send(output, target).await?;
} }

View file

@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool {
|| config.channels_config.matrix.is_some() || config.channels_config.matrix.is_some()
|| config.channels_config.whatsapp.is_some() || config.channels_config.whatsapp.is_some()
|| config.channels_config.email.is_some() || config.channels_config.email.is_some()
|| config.channels_config.lark.is_some()
} }
#[cfg(test)] #[cfg(test)]

View file

@ -49,6 +49,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
format!("whatsapp_{}_{}", msg.sender, msg.id) format!("whatsapp_{}_{}", msg.sender, msg.id)
} }
/// SHA-256 a webhook secret and return the digest as lowercase hex, so the
/// plaintext never needs to be retained in process state.
fn hash_webhook_secret(value: &str) -> String {
    use sha2::{Digest, Sha256};
    let mut hasher = Sha256::new();
    hasher.update(value.as_bytes());
    hex::encode(hasher.finalize())
}
/// How often the rate limiter sweeps stale IP entries from its map. /// How often the rate limiter sweeps stale IP entries from its map.
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
@ -178,7 +185,8 @@ pub struct AppState {
pub temperature: f64, pub temperature: f64,
pub mem: Arc<dyn Memory>, pub mem: Arc<dyn Memory>,
pub auto_save: bool, pub auto_save: bool,
pub webhook_secret: Option<Arc<str>>, /// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
pub webhook_secret_hash: Option<Arc<str>>,
pub pairing: Arc<PairingGuard>, pub pairing: Arc<PairingGuard>,
pub rate_limiter: Arc<GatewayRateLimiter>, pub rate_limiter: Arc<GatewayRateLimiter>,
pub idempotency_store: Arc<IdempotencyStore>, pub idempotency_store: Arc<IdempotencyStore>,
@ -208,6 +216,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider( let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"), config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(), config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability, &config.reliability,
)?); )?);
let model = config let model = config
@ -251,12 +260,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
&config, &config,
)); ));
// Extract webhook secret for authentication // Extract webhook secret for authentication
let webhook_secret: Option<Arc<str>> = config let webhook_secret_hash: Option<Arc<str>> =
.channels_config config.channels_config.webhook.as_ref().and_then(|webhook| {
.webhook webhook.secret.as_ref().and_then(|raw_secret| {
.as_ref() let trimmed_secret = raw_secret.trim();
.and_then(|w| w.secret.as_deref()) (!trimmed_secret.is_empty())
.map(Arc::from); .then(|| Arc::<str>::from(hash_webhook_secret(trimmed_secret)))
})
});
// WhatsApp channel (if configured) // WhatsApp channel (if configured)
let whatsapp_channel: Option<Arc<WhatsAppChannel>> = let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
@ -342,9 +353,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
} else { } else {
println!(" ⚠️ Pairing: DISABLED (all requests accepted)"); println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
} }
if webhook_secret.is_some() {
println!(" 🔒 Webhook secret: ENABLED");
}
println!(" Press Ctrl+C to stop.\n"); println!(" Press Ctrl+C to stop.\n");
crate::health::mark_component_ok("gateway"); crate::health::mark_component_ok("gateway");
@ -356,7 +364,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
temperature, temperature,
mem, mem,
auto_save: config.memory.auto_save, auto_save: config.memory.auto_save,
webhook_secret, webhook_secret_hash,
pairing, pairing,
rate_limiter, rate_limiter,
idempotency_store, idempotency_store,
@ -482,12 +490,15 @@ async fn handle_webhook(
} }
// ── Webhook secret auth (optional, additional layer) ── // ── Webhook secret auth (optional, additional layer) ──
if let Some(ref secret) = state.webhook_secret { if let Some(ref secret_hash) = state.webhook_secret_hash {
let header_val = headers let header_hash = headers
.get("X-Webhook-Secret") .get("X-Webhook-Secret")
.and_then(|v| v.to_str().ok()); .and_then(|v| v.to_str().ok())
match header_val { .map(str::trim)
Some(val) if constant_time_eq(val, secret.as_ref()) => {} .filter(|value| !value.is_empty())
.map(hash_webhook_secret);
match header_hash {
Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
_ => { _ => {
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret"); tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"}); let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
@ -532,7 +543,7 @@ async fn handle_webhook(
let key = webhook_memory_key(); let key = webhook_memory_key();
let _ = state let _ = state
.mem .mem
.store(&key, message, MemoryCategory::Conversation) .store(&key, message, MemoryCategory::Conversation, None)
.await; .await;
} }
@ -685,7 +696,7 @@ async fn handle_whatsapp_message(
let key = whatsapp_memory_key(msg); let key = whatsapp_memory_key(msg);
let _ = state let _ = state
.mem .mem
.store(&key, &msg.content, MemoryCategory::Conversation) .store(&key, &msg.content, MemoryCategory::Conversation, None)
.await; .await;
} }
@ -697,7 +708,7 @@ async fn handle_whatsapp_message(
{ {
Ok(response) => { Ok(response) => {
// Send reply via WhatsApp // Send reply via WhatsApp
if let Err(e) = wa.send(&response, &msg.sender).await { if let Err(e) = wa.send(&response, &msg.reply_target).await {
tracing::error!("Failed to send WhatsApp reply: {e}"); tracing::error!("Failed to send WhatsApp reply: {e}");
} }
} }
@ -706,7 +717,7 @@ async fn handle_whatsapp_message(
let _ = wa let _ = wa
.send( .send(
"Sorry, I couldn't process your message right now.", "Sorry, I couldn't process your message right now.",
&msg.sender, &msg.reply_target,
) )
.await; .await;
} }
@ -798,7 +809,9 @@ mod tests {
.requests .requests
.lock() .lock()
.unwrap_or_else(std::sync::PoisonError::into_inner); .unwrap_or_else(std::sync::PoisonError::into_inner);
guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1); guard.1 = Instant::now()
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
.unwrap();
// Clear timestamps for ip-2 and ip-3 to simulate stale entries // Clear timestamps for ip-2 and ip-3 to simulate stale entries
guard.0.get_mut("ip-2").unwrap().clear(); guard.0.get_mut("ip-2").unwrap().clear();
guard.0.get_mut("ip-3").unwrap().clear(); guard.0.get_mut("ip-3").unwrap().clear();
@ -848,6 +861,7 @@ mod tests {
let msg = ChannelMessage { let msg = ChannelMessage {
id: "wamid-123".into(), id: "wamid-123".into(),
sender: "+1234567890".into(), sender: "+1234567890".into(),
reply_target: "+1234567890".into(),
content: "hello".into(), content: "hello".into(),
channel: "whatsapp".into(), channel: "whatsapp".into(),
timestamp: 1, timestamp: 1,
@ -871,11 +885,17 @@ mod tests {
_key: &str, _key: &str,
_content: &str, _content: &str,
_category: MemoryCategory, _category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> { async fn recall(
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -886,6 +906,7 @@ mod tests {
async fn list( async fn list(
&self, &self,
_category: Option<&MemoryCategory>, _category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> { ) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -938,6 +959,7 @@ mod tests {
key: &str, key: &str,
_content: &str, _content: &str,
_category: MemoryCategory, _category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.keys self.keys
.lock() .lock()
@ -946,7 +968,12 @@ mod tests {
Ok(()) Ok(())
} }
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> { async fn recall(
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -957,6 +984,7 @@ mod tests {
async fn list( async fn list(
&self, &self,
_category: Option<&MemoryCategory>, _category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> { ) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -991,7 +1019,7 @@ mod tests {
temperature: 0.0, temperature: 0.0,
mem: memory, mem: memory,
auto_save: false, auto_save: false,
webhook_secret: None, webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])), pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@ -1039,7 +1067,7 @@ mod tests {
temperature: 0.0, temperature: 0.0,
mem: memory, mem: memory,
auto_save: true, auto_save: true,
webhook_secret: None, webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])), pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)), rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))), idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@ -1077,6 +1105,125 @@ mod tests {
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2); assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
} }
#[test]
fn webhook_secret_hash_is_deterministic_and_nonempty() {
    // Same input → same digest; different input → different digest.
    let first = hash_webhook_secret("secret-value");
    let repeat = hash_webhook_secret("secret-value");
    let different = hash_webhook_secret("other-value");
    assert_eq!(first, repeat);
    assert_ne!(first, different);
    // SHA-256 hex-encodes to exactly 64 characters.
    assert_eq!(first.len(), 64);
}
#[tokio::test]
async fn webhook_secret_hash_rejects_missing_header() {
    // With a secret hash configured, a request carrying no X-Webhook-Secret
    // header must get 401 and never reach the provider.
    let mock_provider = Arc::new(MockProvider::default());
    let provider: Arc<dyn Provider> = mock_provider.clone();
    let mem: Arc<dyn Memory> = Arc::new(MockMemory);
    let state = AppState {
        provider,
        model: "test-model".into(),
        temperature: 0.0,
        mem,
        auto_save: false,
        webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
        pairing: Arc::new(PairingGuard::new(false, &[])),
        rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
        idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
        whatsapp: None,
        whatsapp_app_secret: None,
    };
    let body = Ok(Json(WebhookBody {
        message: "hello".into(),
    }));
    let response = handle_webhook(State(state), HeaderMap::new(), body)
        .await
        .into_response();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    assert_eq!(mock_provider.calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn webhook_secret_hash_rejects_invalid_header() {
    // A header whose value hashes to something other than the configured
    // secret hash must get 401 and never reach the provider.
    let mock_provider = Arc::new(MockProvider::default());
    let provider: Arc<dyn Provider> = mock_provider.clone();
    let mem: Arc<dyn Memory> = Arc::new(MockMemory);
    let state = AppState {
        provider,
        model: "test-model".into(),
        temperature: 0.0,
        mem,
        auto_save: false,
        webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
        pairing: Arc::new(PairingGuard::new(false, &[])),
        rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
        idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
        whatsapp: None,
        whatsapp_app_secret: None,
    };
    let mut request_headers = HeaderMap::new();
    request_headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
    let body = Ok(Json(WebhookBody {
        message: "hello".into(),
    }));
    let response = handle_webhook(State(state), request_headers, body)
        .await
        .into_response();
    assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
    assert_eq!(mock_provider.calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn webhook_secret_hash_accepts_valid_header() {
    // A header whose value hashes to the configured secret hash must pass
    // auth: 200 OK and exactly one provider call.
    let mock_provider = Arc::new(MockProvider::default());
    let provider: Arc<dyn Provider> = mock_provider.clone();
    let mem: Arc<dyn Memory> = Arc::new(MockMemory);
    let state = AppState {
        provider,
        model: "test-model".into(),
        temperature: 0.0,
        mem,
        auto_save: false,
        webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
        pairing: Arc::new(PairingGuard::new(false, &[])),
        rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
        idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
        whatsapp: None,
        whatsapp_app_secret: None,
    };
    let mut request_headers = HeaderMap::new();
    request_headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
    let body = Ok(Json(WebhookBody {
        message: "hello".into(),
    }));
    let response = handle_webhook(State(state), request_headers, body)
        .await
        .into_response();
    assert_eq!(response.status(), StatusCode::OK);
    assert_eq!(mock_provider.calls.load(Ordering::SeqCst), 1);
}
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════
// WhatsApp Signature Verification Tests (CWE-345 Prevention) // WhatsApp Signature Verification Tests (CWE-345 Prevention)
// ══════════════════════════════════════════════════════════ // ══════════════════════════════════════════════════════════

View file

@ -34,8 +34,8 @@
use anyhow::{bail, Result}; use anyhow::{bail, Result};
use clap::{Parser, Subcommand}; use clap::{Parser, Subcommand};
use tracing::{info, Level}; use tracing::info;
use tracing_subscriber::FmtSubscriber; use tracing_subscriber::{fmt, EnvFilter};
mod agent; mod agent;
mod channels; mod channels;
@ -147,24 +147,24 @@ enum Commands {
/// Start the gateway server (webhooks, websockets) /// Start the gateway server (webhooks, websockets)
Gateway { Gateway {
/// Port to listen on (use 0 for random available port) /// Port to listen on (use 0 for random available port); defaults to config gateway.port
#[arg(short, long, default_value = "8080")] #[arg(short, long)]
port: u16, port: Option<u16>,
/// Host to bind to /// Host to bind to; defaults to config gateway.host
#[arg(long, default_value = "127.0.0.1")] #[arg(long)]
host: String, host: Option<String>,
}, },
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler) /// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
Daemon { Daemon {
/// Port to listen on (use 0 for random available port) /// Port to listen on (use 0 for random available port); defaults to config gateway.port
#[arg(short, long, default_value = "8080")] #[arg(short, long)]
port: u16, port: Option<u16>,
/// Host to bind to /// Host to bind to; defaults to config gateway.host
#[arg(long, default_value = "127.0.0.1")] #[arg(long)]
host: String, host: Option<String>,
}, },
/// Manage OS service lifecycle (launchd/systemd user service) /// Manage OS service lifecycle (launchd/systemd user service)
@ -367,9 +367,11 @@ async fn main() -> Result<()> {
let cli = Cli::parse(); let cli = Cli::parse();
// Initialize logging // Initialize logging - respects RUST_LOG env var, defaults to INFO
let subscriber = FmtSubscriber::builder() let subscriber = fmt::Subscriber::builder()
.with_max_level(Level::INFO) .with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.finish(); .finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
@ -434,6 +436,8 @@ async fn main() -> Result<()> {
.map(|_| ()), .map(|_| ()),
Commands::Gateway { port, host } => { Commands::Gateway { port, host } => {
let port = port.unwrap_or(config.gateway.port);
let host = host.unwrap_or_else(|| config.gateway.host.clone());
if port == 0 { if port == 0 {
info!("🚀 Starting ZeroClaw Gateway on {host} (random port)"); info!("🚀 Starting ZeroClaw Gateway on {host} (random port)");
} else { } else {
@ -443,6 +447,8 @@ async fn main() -> Result<()> {
} }
Commands::Daemon { port, host } => { Commands::Daemon { port, host } => {
let port = port.unwrap_or(config.gateway.port);
let host = host.unwrap_or_else(|| config.gateway.host.clone());
if port == 0 { if port == 0 {
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)"); info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
} else { } else {

View file

@ -7,6 +7,7 @@ pub enum MemoryBackendKind {
Unknown, Unknown,
} }
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy, Eq, PartialEq)] #[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct MemoryBackendProfile { pub struct MemoryBackendProfile {
pub key: &'static str, pub key: &'static str,

View file

@ -502,10 +502,10 @@ mod tests {
let workspace = tmp.path(); let workspace = tmp.path();
let mem = SqliteMemory::new(workspace).unwrap(); let mem = SqliteMemory::new(workspace).unwrap();
mem.store("conv_old", "outdated", MemoryCategory::Conversation) mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();
mem.store("core_keep", "durable", MemoryCategory::Core) mem.store("core_keep", "durable", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
drop(mem); drop(mem);

View file

@ -24,7 +24,9 @@ pub struct LucidMemory {
impl LucidMemory { impl LucidMemory {
const DEFAULT_LUCID_CMD: &'static str = "lucid"; const DEFAULT_LUCID_CMD: &'static str = "lucid";
const DEFAULT_TOKEN_BUDGET: usize = 200; const DEFAULT_TOKEN_BUDGET: usize = 200;
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120; // Lucid CLI cold start can exceed 120ms on slower machines, which causes
// avoidable fallback to local-only memory and premature cooldown.
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500;
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800; const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3; const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000; const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
@ -74,6 +76,7 @@ impl LucidMemory {
} }
#[cfg(test)] #[cfg(test)]
#[allow(clippy::too_many_arguments)]
fn with_options( fn with_options(
workspace_dir: &Path, workspace_dir: &Path,
local: SqliteMemory, local: SqliteMemory,
@ -307,14 +310,22 @@ impl Memory for LucidMemory {
key: &str, key: &str,
content: &str, content: &str,
category: MemoryCategory, category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
self.local.store(key, content, category.clone()).await?; self.local
.store(key, content, category.clone(), session_id)
.await?;
self.sync_to_lucid_async(key, content, &category).await; self.sync_to_lucid_async(key, content, &category).await;
Ok(()) Ok(())
} }
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> { async fn recall(
let local_results = self.local.recall(query, limit).await?; &self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let local_results = self.local.recall(query, limit, session_id).await?;
if limit == 0 if limit == 0
|| local_results.len() >= limit || local_results.len() >= limit
|| local_results.len() >= self.local_hit_threshold || local_results.len() >= self.local_hit_threshold
@ -351,8 +362,12 @@ impl Memory for LucidMemory {
self.local.get(key).await self.local.get(key).await
} }
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> { async fn list(
self.local.list(category).await &self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
self.local.list(category, session_id).await
} }
async fn forget(&self, key: &str) -> anyhow::Result<bool> { async fn forget(&self, key: &str) -> anyhow::Result<bool> {
@ -396,6 +411,38 @@ EOF
exit 0 exit 0
fi fi
echo "unsupported command" >&2
exit 1
"#;
fs::write(&script_path, script).unwrap();
let mut perms = fs::metadata(&script_path).unwrap().permissions();
perms.set_mode(0o755);
fs::set_permissions(&script_path, perms).unwrap();
script_path.display().to_string()
}
fn write_delayed_lucid_script(dir: &Path) -> String {
let script_path = dir.join("delayed-lucid.sh");
let script = r#"#!/usr/bin/env bash
set -euo pipefail
if [[ "${1:-}" == "store" ]]; then
echo '{"success":true,"id":"mem_1"}'
exit 0
fi
if [[ "${1:-}" == "context" ]]; then
# Simulate a cold start that is slower than 120ms but below the 500ms timeout.
sleep 0.2
cat <<'EOF'
<lucid-context>
- [decision] Delayed token refresh guidance
</lucid-context>
EOF
exit 0
fi
echo "unsupported command" >&2 echo "unsupported command" >&2
exit 1 exit 1
"#; "#;
@ -449,7 +496,7 @@ exit 1
cmd, cmd,
200, 200,
3, 3,
Duration::from_millis(120), Duration::from_millis(500),
Duration::from_millis(400), Duration::from_millis(400),
Duration::from_secs(2), Duration::from_secs(2),
) )
@ -468,7 +515,7 @@ exit 1
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string()); let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
memory memory
.store("lang", "User prefers Rust", MemoryCategory::Core) .store("lang", "User prefers Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -483,6 +530,30 @@ exit 1
let fake_cmd = write_fake_lucid_script(tmp.path()); let fake_cmd = write_fake_lucid_script(tmp.path());
let memory = test_memory(tmp.path(), fake_cmd); let memory = test_memory(tmp.path(), fake_cmd);
memory
.store(
"local_note",
"Local sqlite auth fallback note",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entries = memory.recall("auth", 5, None).await.unwrap();
assert!(entries
.iter()
.any(|e| e.content.contains("Local sqlite auth fallback note")));
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
}
#[tokio::test]
async fn recall_handles_lucid_cold_start_delay_within_timeout() {
let tmp = TempDir::new().unwrap();
let delayed_cmd = write_delayed_lucid_script(tmp.path());
let memory = test_memory(tmp.path(), delayed_cmd);
memory memory
.store( .store(
"local_note", "local_note",
@ -497,7 +568,9 @@ exit 1
assert!(entries assert!(entries
.iter() .iter()
.any(|e| e.content.contains("Local sqlite auth fallback note"))); .any(|e| e.content.contains("Local sqlite auth fallback note")));
assert!(entries.iter().any(|e| e.content.contains("token refresh"))); assert!(entries
.iter()
.any(|e| e.content.contains("Delayed token refresh guidance")));
} }
#[tokio::test] #[tokio::test]
@ -513,17 +586,22 @@ exit 1
probe_cmd, probe_cmd,
200, 200,
1, 1,
Duration::from_millis(120), Duration::from_millis(500),
Duration::from_millis(400), Duration::from_millis(400),
Duration::from_secs(2), Duration::from_secs(2),
); );
memory memory
.store("pref", "Rust should stay local-first", MemoryCategory::Core) .store(
"pref",
"Rust should stay local-first",
MemoryCategory::Core,
None,
)
.await .await
.unwrap(); .unwrap();
let entries = memory.recall("rust", 5).await.unwrap(); let entries = memory.recall("rust", 5, None).await.unwrap();
assert!(entries assert!(entries
.iter() .iter()
.any(|e| e.content.contains("Rust should stay local-first"))); .any(|e| e.content.contains("Rust should stay local-first")));
@ -578,13 +656,13 @@ exit 1
failing_cmd, failing_cmd,
200, 200,
99, 99,
Duration::from_millis(120), Duration::from_millis(500),
Duration::from_millis(400), Duration::from_millis(400),
Duration::from_secs(5), Duration::from_secs(5),
); );
let first = memory.recall("auth", 5).await.unwrap(); let first = memory.recall("auth", 5, None).await.unwrap();
let second = memory.recall("auth", 5).await.unwrap(); let second = memory.recall("auth", 5, None).await.unwrap();
assert!(first.is_empty()); assert!(first.is_empty());
assert!(second.is_empty()); assert!(second.is_empty());

View file

@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
key: &str, key: &str,
content: &str, content: &str,
category: MemoryCategory, category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
let entry = format!("- **{key}**: {content}"); let entry = format!("- **{key}**: {content}");
let path = match category { let path = match category {
@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
self.append_to_file(&path, &entry).await self.append_to_file(&path, &entry).await
} }
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> { async fn recall(
&self,
query: &str,
limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let all = self.read_all_entries().await?; let all = self.read_all_entries().await?;
let query_lower = query.to_lowercase(); let query_lower = query.to_lowercase();
let keywords: Vec<&str> = query_lower.split_whitespace().collect(); let keywords: Vec<&str> = query_lower.split_whitespace().collect();
@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
.find(|e| e.key == key || e.content.contains(key))) .find(|e| e.key == key || e.content.contains(key)))
} }
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> { async fn list(
&self,
category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let all = self.read_all_entries().await?; let all = self.read_all_entries().await?;
match category { match category {
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()), Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
@ -243,7 +253,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn markdown_store_core() { async fn markdown_store_core() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
mem.store("pref", "User likes Rust", MemoryCategory::Core) mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let content = sync_fs::read_to_string(mem.core_path()).unwrap(); let content = sync_fs::read_to_string(mem.core_path()).unwrap();
@ -253,7 +263,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn markdown_store_daily() { async fn markdown_store_daily() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
mem.store("note", "Finished tests", MemoryCategory::Daily) mem.store("note", "Finished tests", MemoryCategory::Daily, None)
.await .await
.unwrap(); .unwrap();
let path = mem.daily_path(); let path = mem.daily_path();
@ -264,17 +274,17 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn markdown_recall_keyword() { async fn markdown_recall_keyword() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
mem.store("a", "Rust is fast", MemoryCategory::Core) mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("b", "Python is slow", MemoryCategory::Core) mem.store("b", "Python is slow", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("c", "Rust and safety", MemoryCategory::Core) mem.store("c", "Rust and safety", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("Rust", 10).await.unwrap(); let results = mem.recall("Rust", 10, None).await.unwrap();
assert!(results.len() >= 2); assert!(results.len() >= 2);
assert!(results assert!(results
.iter() .iter()
@ -284,18 +294,20 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn markdown_recall_no_match() { async fn markdown_recall_no_match() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
mem.store("a", "Rust is great", MemoryCategory::Core) mem.store("a", "Rust is great", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("javascript", 10).await.unwrap(); let results = mem.recall("javascript", 10, None).await.unwrap();
assert!(results.is_empty()); assert!(results.is_empty());
} }
#[tokio::test] #[tokio::test]
async fn markdown_count() { async fn markdown_count() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
mem.store("a", "first", MemoryCategory::Core).await.unwrap(); mem.store("a", "first", MemoryCategory::Core, None)
mem.store("b", "second", MemoryCategory::Core) .await
.unwrap();
mem.store("b", "second", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let count = mem.count().await.unwrap(); let count = mem.count().await.unwrap();
@ -305,24 +317,24 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn markdown_list_by_category() { async fn markdown_list_by_category() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
mem.store("a", "core fact", MemoryCategory::Core) mem.store("a", "core fact", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("b", "daily note", MemoryCategory::Daily) mem.store("b", "daily note", MemoryCategory::Daily, None)
.await .await
.unwrap(); .unwrap();
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
assert!(core.iter().all(|e| e.category == MemoryCategory::Core)); assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily)); assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
} }
#[tokio::test] #[tokio::test]
async fn markdown_forget_is_noop() { async fn markdown_forget_is_noop() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
mem.store("a", "permanent", MemoryCategory::Core) mem.store("a", "permanent", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let removed = mem.forget("a").await.unwrap(); let removed = mem.forget("a").await.unwrap();
@ -332,7 +344,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn markdown_empty_recall() { async fn markdown_empty_recall() {
let (_tmp, mem) = temp_workspace(); let (_tmp, mem) = temp_workspace();
let results = mem.recall("anything", 10).await.unwrap(); let results = mem.recall("anything", 10, None).await.unwrap();
assert!(results.is_empty()); assert!(results.is_empty());
} }

View file

@ -25,11 +25,17 @@ impl Memory for NoneMemory {
_key: &str, _key: &str,
_content: &str, _content: &str,
_category: MemoryCategory, _category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
Ok(()) Ok(())
} }
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> { async fn recall(
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -37,7 +43,11 @@ impl Memory for NoneMemory {
Ok(None) Ok(None)
} }
async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> { async fn list(
&self,
_category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new()) Ok(Vec::new())
} }
@ -62,11 +72,14 @@ mod tests {
async fn none_memory_is_noop() { async fn none_memory_is_noop() {
let memory = NoneMemory::new(); let memory = NoneMemory::new();
memory.store("k", "v", MemoryCategory::Core).await.unwrap(); memory
.store("k", "v", MemoryCategory::Core, None)
.await
.unwrap();
assert!(memory.get("k").await.unwrap().is_none()); assert!(memory.get("k").await.unwrap().is_none());
assert!(memory.recall("k", 10).await.unwrap().is_empty()); assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
assert!(memory.list(None).await.unwrap().is_empty()); assert!(memory.list(None, None).await.unwrap().is_empty());
assert!(!memory.forget("k").await.unwrap()); assert!(!memory.forget("k").await.unwrap());
assert_eq!(memory.count().await.unwrap(), 0); assert_eq!(memory.count().await.unwrap(), 0);
assert!(memory.health_check().await); assert!(memory.health_check().await);

View file

@ -157,7 +157,7 @@ impl ResponseCache {
|row| row.get(0), |row| row.get(0),
)?; )?;
#[allow(clippy::cast_sign_loss)] #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok((count as usize, hits as u64, tokens_saved as u64)) Ok((count as usize, hits as u64, tokens_saved as u64))
} }

View file

@ -124,6 +124,19 @@ impl SqliteMemory {
); );
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);", CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
)?; )?;
// Migration: add session_id column if not present (safe to run repeatedly)
let has_session_id: bool = conn
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
.query_row([], |row| row.get::<_, String>(0))?
.contains("session_id");
if !has_session_id {
conn.execute_batch(
"ALTER TABLE memories ADD COLUMN session_id TEXT;
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
)?;
}
Ok(()) Ok(())
} }
@ -361,6 +374,7 @@ impl Memory for SqliteMemory {
key: &str, key: &str,
content: &str, content: &str,
category: MemoryCategory, category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()> { ) -> anyhow::Result<()> {
// Compute embedding (async, before lock) // Compute embedding (async, before lock)
let embedding_bytes = self let embedding_bytes = self
@ -377,20 +391,26 @@ impl Memory for SqliteMemory {
let id = Uuid::new_v4().to_string(); let id = Uuid::new_v4().to_string();
conn.execute( conn.execute(
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at) "INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
ON CONFLICT(key) DO UPDATE SET ON CONFLICT(key) DO UPDATE SET
content = excluded.content, content = excluded.content,
category = excluded.category, category = excluded.category,
embedding = excluded.embedding, embedding = excluded.embedding,
updated_at = excluded.updated_at", updated_at = excluded.updated_at,
params![id, key, content, cat, embedding_bytes, now, now], session_id = excluded.session_id",
params![id, key, content, cat, embedding_bytes, now, now, session_id],
)?; )?;
Ok(()) Ok(())
} }
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> { async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
if query.trim().is_empty() { if query.trim().is_empty() {
return Ok(Vec::new()); return Ok(Vec::new());
} }
@ -439,7 +459,7 @@ impl Memory for SqliteMemory {
let mut results = Vec::new(); let mut results = Vec::new();
for scored in &merged { for scored in &merged {
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1", "SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
)?; )?;
if let Ok(entry) = stmt.query_row(params![scored.id], |row| { if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
Ok(MemoryEntry { Ok(MemoryEntry {
@ -448,10 +468,16 @@ impl Memory for SqliteMemory {
content: row.get(2)?, content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?), category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?, timestamp: row.get(4)?,
session_id: None, session_id: row.get(5)?,
score: Some(f64::from(scored.final_score)), score: Some(f64::from(scored.final_score)),
}) })
}) { }) {
// Filter by session_id if requested
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry); results.push(entry);
} }
} }
@ -470,7 +496,7 @@ impl Memory for SqliteMemory {
.collect(); .collect();
let where_clause = conditions.join(" OR "); let where_clause = conditions.join(" OR ");
let sql = format!( let sql = format!(
"SELECT id, key, content, category, created_at FROM memories "SELECT id, key, content, category, created_at, session_id FROM memories
WHERE {where_clause} WHERE {where_clause}
ORDER BY updated_at DESC ORDER BY updated_at DESC
LIMIT ?{}", LIMIT ?{}",
@ -493,12 +519,18 @@ impl Memory for SqliteMemory {
content: row.get(2)?, content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?), category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?, timestamp: row.get(4)?,
session_id: None, session_id: row.get(5)?,
score: Some(1.0), score: Some(1.0),
}) })
})?; })?;
for row in rows { for row in rows {
results.push(row?); let entry = row?;
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
} }
} }
} }
@ -514,7 +546,7 @@ impl Memory for SqliteMemory {
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?; .map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1", "SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
)?; )?;
let mut rows = stmt.query_map(params![key], |row| { let mut rows = stmt.query_map(params![key], |row| {
@ -524,7 +556,7 @@ impl Memory for SqliteMemory {
content: row.get(2)?, content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?), category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?, timestamp: row.get(4)?,
session_id: None, session_id: row.get(5)?,
score: None, score: None,
}) })
})?; })?;
@ -535,7 +567,11 @@ impl Memory for SqliteMemory {
} }
} }
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> { async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self let conn = self
.conn .conn
.lock() .lock()
@ -550,7 +586,7 @@ impl Memory for SqliteMemory {
content: row.get(2)?, content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?), category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?, timestamp: row.get(4)?,
session_id: None, session_id: row.get(5)?,
score: None, score: None,
}) })
}; };
@ -558,21 +594,33 @@ impl Memory for SqliteMemory {
if let Some(cat) = category { if let Some(cat) = category {
let cat_str = Self::category_to_str(cat); let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories "SELECT id, key, content, category, created_at, session_id FROM memories
WHERE category = ?1 ORDER BY updated_at DESC", WHERE category = ?1 ORDER BY updated_at DESC",
)?; )?;
let rows = stmt.query_map(params![cat_str], row_mapper)?; let rows = stmt.query_map(params![cat_str], row_mapper)?;
for row in rows { for row in rows {
results.push(row?); let entry = row?;
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
} }
} else { } else {
let mut stmt = conn.prepare( let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories "SELECT id, key, content, category, created_at, session_id FROM memories
ORDER BY updated_at DESC", ORDER BY updated_at DESC",
)?; )?;
let rows = stmt.query_map([], row_mapper)?; let rows = stmt.query_map([], row_mapper)?;
for row in rows { for row in rows {
results.push(row?); let entry = row?;
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
} }
} }
@ -632,7 +680,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn sqlite_store_and_get() { async fn sqlite_store_and_get() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core) mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -647,10 +695,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn sqlite_store_upsert() { async fn sqlite_store_upsert() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("pref", "likes Rust", MemoryCategory::Core) mem.store("pref", "likes Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("pref", "loves Rust", MemoryCategory::Core) mem.store("pref", "loves Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -662,17 +710,22 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn sqlite_recall_keyword() { async fn sqlite_recall_keyword() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast and safe", MemoryCategory::Core) mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("b", "Python is interpreted", MemoryCategory::Core) mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
.await .await
.unwrap(); .unwrap();
mem.store(
"c",
"Rust has zero-cost abstractions",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap(); let results = mem.recall("Rust", 10, None).await.unwrap();
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
assert!(results assert!(results
.iter() .iter()
@ -682,14 +735,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn sqlite_recall_multi_keyword() { async fn sqlite_recall_multi_keyword() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast", MemoryCategory::Core) mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("b", "Rust is safe and fast", MemoryCategory::Core) mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("fast safe", 10).await.unwrap(); let results = mem.recall("fast safe", 10, None).await.unwrap();
assert!(!results.is_empty()); assert!(!results.is_empty());
// Entry with both keywords should score higher // Entry with both keywords should score higher
assert!(results[0].content.contains("safe") && results[0].content.contains("fast")); assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
@ -698,17 +751,17 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn sqlite_recall_no_match() { async fn sqlite_recall_no_match() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust rocks", MemoryCategory::Core) mem.store("a", "Rust rocks", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("javascript", 10).await.unwrap(); let results = mem.recall("javascript", 10, None).await.unwrap();
assert!(results.is_empty()); assert!(results.is_empty());
} }
#[tokio::test] #[tokio::test]
async fn sqlite_forget() { async fn sqlite_forget() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("temp", "temporary data", MemoryCategory::Conversation) mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();
assert_eq!(mem.count().await.unwrap(), 1); assert_eq!(mem.count().await.unwrap(), 1);
@ -728,29 +781,37 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn sqlite_list_all() { async fn sqlite_list_all() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "one", MemoryCategory::Core).await.unwrap(); mem.store("a", "one", MemoryCategory::Core, None)
mem.store("b", "two", MemoryCategory::Daily).await.unwrap(); .await
mem.store("c", "three", MemoryCategory::Conversation) .unwrap();
mem.store("b", "two", MemoryCategory::Daily, None)
.await
.unwrap();
mem.store("c", "three", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();
let all = mem.list(None).await.unwrap(); let all = mem.list(None, None).await.unwrap();
assert_eq!(all.len(), 3); assert_eq!(all.len(), 3);
} }
#[tokio::test] #[tokio::test]
async fn sqlite_list_by_category() { async fn sqlite_list_by_category() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "core1", MemoryCategory::Core).await.unwrap(); mem.store("a", "core1", MemoryCategory::Core, None)
mem.store("b", "core2", MemoryCategory::Core).await.unwrap(); .await
mem.store("c", "daily1", MemoryCategory::Daily) .unwrap();
mem.store("b", "core2", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "daily1", MemoryCategory::Daily, None)
.await .await
.unwrap(); .unwrap();
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap(); let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
assert_eq!(core.len(), 2); assert_eq!(core.len(), 2);
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap(); let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
assert_eq!(daily.len(), 1); assert_eq!(daily.len(), 1);
} }
@ -772,7 +833,7 @@ mod tests {
{ {
let mem = SqliteMemory::new(tmp.path()).unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("persist", "I survive restarts", MemoryCategory::Core) mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
} }
@ -795,7 +856,7 @@ mod tests {
]; ];
for (i, cat) in categories.iter().enumerate() { for (i, cat) in categories.iter().enumerate() {
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone()) mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
.await .await
.unwrap(); .unwrap();
} }
@ -815,21 +876,28 @@ mod tests {
"a", "a",
"Rust is a systems programming language", "Rust is a systems programming language",
MemoryCategory::Core, MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store(
"b",
"Python is great for scripting",
MemoryCategory::Core,
None,
) )
.await .await
.unwrap(); .unwrap();
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
.await
.unwrap();
mem.store( mem.store(
"c", "c",
"Rust and Rust and Rust everywhere", "Rust and Rust and Rust everywhere",
MemoryCategory::Core, MemoryCategory::Core,
None,
) )
.await .await
.unwrap(); .unwrap();
let results = mem.recall("Rust", 10).await.unwrap(); let results = mem.recall("Rust", 10, None).await.unwrap();
assert!(results.len() >= 2); assert!(results.len() >= 2);
// All results should contain "Rust" // All results should contain "Rust"
for r in &results { for r in &results {
@ -844,17 +912,17 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn fts5_multi_word_query() { async fn fts5_multi_word_query() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core) mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core) mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("c", "The quick dog runs fast", MemoryCategory::Core) mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("quick dog", 10).await.unwrap(); let results = mem.recall("quick dog", 10, None).await.unwrap();
assert!(!results.is_empty()); assert!(!results.is_empty());
// "The quick dog runs fast" matches both terms // "The quick dog runs fast" matches both terms
assert!(results[0].content.contains("quick")); assert!(results[0].content.contains("quick"));
@ -863,16 +931,20 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn recall_empty_query_returns_empty() { async fn recall_empty_query_returns_empty() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core).await.unwrap(); mem.store("a", "data", MemoryCategory::Core, None)
let results = mem.recall("", 10).await.unwrap(); .await
.unwrap();
let results = mem.recall("", 10, None).await.unwrap();
assert!(results.is_empty()); assert!(results.is_empty());
} }
#[tokio::test] #[tokio::test]
async fn recall_whitespace_query_returns_empty() { async fn recall_whitespace_query_returns_empty() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core).await.unwrap(); mem.store("a", "data", MemoryCategory::Core, None)
let results = mem.recall(" ", 10).await.unwrap(); .await
.unwrap();
let results = mem.recall(" ", 10, None).await.unwrap();
assert!(results.is_empty()); assert!(results.is_empty());
} }
@ -937,9 +1009,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn fts5_syncs_on_insert() { async fn fts5_syncs_on_insert() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core) mem.store(
.await "test_key",
.unwrap(); "unique_searchterm_xyz",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let conn = mem.conn.lock(); let conn = mem.conn.lock();
let count: i64 = conn let count: i64 = conn
@ -955,9 +1032,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn fts5_syncs_on_delete() { async fn fts5_syncs_on_delete() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core) mem.store(
.await "del_key",
.unwrap(); "deletable_content_abc",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.forget("del_key").await.unwrap(); mem.forget("del_key").await.unwrap();
let conn = mem.conn.lock(); let conn = mem.conn.lock();
@ -974,10 +1056,15 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn fts5_syncs_on_update() { async fn fts5_syncs_on_update() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("upd_key", "original_content_111", MemoryCategory::Core) mem.store(
.await "upd_key",
.unwrap(); "original_content_111",
mem.store("upd_key", "updated_content_222", MemoryCategory::Core) MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -1019,10 +1106,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn reindex_rebuilds_fts() { async fn reindex_rebuilds_fts() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex test alpha", MemoryCategory::Core) mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("r2", "reindex test beta", MemoryCategory::Core) mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -1031,7 +1118,7 @@ mod tests {
assert_eq!(count, 0); assert_eq!(count, 0);
// FTS should still work after rebuild // FTS should still work after rebuild
let results = mem.recall("reindex", 10).await.unwrap(); let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 2); assert_eq!(results.len(), 2);
} }
@ -1045,12 +1132,13 @@ mod tests {
&format!("k{i}"), &format!("k{i}"),
&format!("common keyword item {i}"), &format!("common keyword item {i}"),
MemoryCategory::Core, MemoryCategory::Core,
None,
) )
.await .await
.unwrap(); .unwrap();
} }
let results = mem.recall("common keyword", 5).await.unwrap(); let results = mem.recall("common keyword", 5, None).await.unwrap();
assert!(results.len() <= 5); assert!(results.len() <= 5);
} }
@ -1059,11 +1147,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn recall_results_have_scores() { async fn recall_results_have_scores() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("s1", "scored result test", MemoryCategory::Core) mem.store("s1", "scored result test", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("scored", 10).await.unwrap(); let results = mem.recall("scored", 10, None).await.unwrap();
assert!(!results.is_empty()); assert!(!results.is_empty());
for r in &results { for r in &results {
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key); assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
@ -1075,11 +1163,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn recall_with_quotes_in_query() { async fn recall_with_quotes_in_query() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("q1", "He said hello world", MemoryCategory::Core) mem.store("q1", "He said hello world", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
// Quotes in query should not crash FTS5 // Quotes in query should not crash FTS5
let results = mem.recall("\"hello\"", 10).await.unwrap(); let results = mem.recall("\"hello\"", 10, None).await.unwrap();
// May or may not match depending on FTS5 escaping, but must not error // May or may not match depending on FTS5 escaping, but must not error
assert!(results.len() <= 10); assert!(results.len() <= 10);
} }
@ -1087,31 +1175,34 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn recall_with_asterisk_in_query() { async fn recall_with_asterisk_in_query() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a1", "wildcard test content", MemoryCategory::Core) mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("wild*", 10).await.unwrap(); let results = mem.recall("wild*", 10, None).await.unwrap();
assert!(results.len() <= 10); assert!(results.len() <= 10);
} }
#[tokio::test] #[tokio::test]
async fn recall_with_parentheses_in_query() { async fn recall_with_parentheses_in_query() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("p1", "function call test", MemoryCategory::Core) mem.store("p1", "function call test", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("function()", 10).await.unwrap(); let results = mem.recall("function()", 10, None).await.unwrap();
assert!(results.len() <= 10); assert!(results.len() <= 10);
} }
#[tokio::test] #[tokio::test]
async fn recall_with_sql_injection_attempt() { async fn recall_with_sql_injection_attempt() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("safe", "normal content", MemoryCategory::Core) mem.store("safe", "normal content", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
// Should not crash or leak data // Should not crash or leak data
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap(); let results = mem
.recall("'; DROP TABLE memories; --", 10, None)
.await
.unwrap();
assert!(results.len() <= 10); assert!(results.len() <= 10);
// Table should still exist // Table should still exist
assert_eq!(mem.count().await.unwrap(), 1); assert_eq!(mem.count().await.unwrap(), 1);
@ -1122,7 +1213,9 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn store_empty_content() { async fn store_empty_content() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("empty", "", MemoryCategory::Core).await.unwrap(); mem.store("empty", "", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("empty").await.unwrap().unwrap(); let entry = mem.get("empty").await.unwrap().unwrap();
assert_eq!(entry.content, ""); assert_eq!(entry.content, "");
} }
@ -1130,7 +1223,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn store_empty_key() { async fn store_empty_key() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("", "content for empty key", MemoryCategory::Core) mem.store("", "content for empty key", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let entry = mem.get("").await.unwrap().unwrap(); let entry = mem.get("").await.unwrap().unwrap();
@ -1141,7 +1234,7 @@ mod tests {
async fn store_very_long_content() { async fn store_very_long_content() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
let long_content = "x".repeat(100_000); let long_content = "x".repeat(100_000);
mem.store("long", &long_content, MemoryCategory::Core) mem.store("long", &long_content, MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let entry = mem.get("long").await.unwrap().unwrap(); let entry = mem.get("long").await.unwrap().unwrap();
@ -1151,9 +1244,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn store_unicode_and_emoji() { async fn store_unicode_and_emoji() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core) mem.store(
.await "emoji_key_🦀",
.unwrap(); "こんにちは 🚀 Ñoño",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap(); let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
assert_eq!(entry.content, "こんにちは 🚀 Ñoño"); assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
} }
@ -1162,7 +1260,7 @@ mod tests {
async fn store_content_with_newlines_and_tabs() { async fn store_content_with_newlines_and_tabs() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph"; let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
mem.store("whitespace", content, MemoryCategory::Core) mem.store("whitespace", content, MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let entry = mem.get("whitespace").await.unwrap().unwrap(); let entry = mem.get("whitespace").await.unwrap().unwrap();
@ -1174,11 +1272,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn recall_single_character_query() { async fn recall_single_character_query() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "x marks the spot", MemoryCategory::Core) mem.store("a", "x marks the spot", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
// Single char may not match FTS5 but LIKE fallback should work // Single char may not match FTS5 but LIKE fallback should work
let results = mem.recall("x", 10).await.unwrap(); let results = mem.recall("x", 10, None).await.unwrap();
// Should not crash; may or may not find results // Should not crash; may or may not find results
assert!(results.len() <= 10); assert!(results.len() <= 10);
} }
@ -1186,23 +1284,23 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn recall_limit_zero() { async fn recall_limit_zero() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "some content", MemoryCategory::Core) mem.store("a", "some content", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("some", 0).await.unwrap(); let results = mem.recall("some", 0, None).await.unwrap();
assert!(results.is_empty()); assert!(results.is_empty());
} }
#[tokio::test] #[tokio::test]
async fn recall_limit_one() { async fn recall_limit_one() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("a", "matching content alpha", MemoryCategory::Core) mem.store("a", "matching content alpha", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("b", "matching content beta", MemoryCategory::Core) mem.store("b", "matching content beta", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("matching content", 1).await.unwrap(); let results = mem.recall("matching content", 1, None).await.unwrap();
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
} }
@ -1213,21 +1311,22 @@ mod tests {
"rust_preferences", "rust_preferences",
"User likes systems programming", "User likes systems programming",
MemoryCategory::Core, MemoryCategory::Core,
None,
) )
.await .await
.unwrap(); .unwrap();
// "rust" appears in key but not content — LIKE fallback checks key too // "rust" appears in key but not content — LIKE fallback checks key too
let results = mem.recall("rust", 10).await.unwrap(); let results = mem.recall("rust", 10, None).await.unwrap();
assert!(!results.is_empty(), "Should match by key"); assert!(!results.is_empty(), "Should match by key");
} }
#[tokio::test] #[tokio::test]
async fn recall_unicode_query() { async fn recall_unicode_query() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("jp", "日本語のテスト", MemoryCategory::Core) mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let results = mem.recall("日本語", 10).await.unwrap(); let results = mem.recall("日本語", 10, None).await.unwrap();
assert!(!results.is_empty()); assert!(!results.is_empty());
} }
@ -1238,7 +1337,9 @@ mod tests {
let tmp = TempDir::new().unwrap(); let tmp = TempDir::new().unwrap();
{ {
let mem = SqliteMemory::new(tmp.path()).unwrap(); let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap(); mem.store("k1", "v1", MemoryCategory::Core, None)
.await
.unwrap();
} }
// Open again — init_schema runs again on existing DB // Open again — init_schema runs again on existing DB
let mem2 = SqliteMemory::new(tmp.path()).unwrap(); let mem2 = SqliteMemory::new(tmp.path()).unwrap();
@ -1246,7 +1347,9 @@ mod tests {
assert!(entry.is_some()); assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "v1"); assert_eq!(entry.unwrap().content, "v1");
// Store more data — should work fine // Store more data — should work fine
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap(); mem2.store("k2", "v2", MemoryCategory::Daily, None)
.await
.unwrap();
assert_eq!(mem2.count().await.unwrap(), 2); assert_eq!(mem2.count().await.unwrap(), 2);
} }
@ -1264,11 +1367,16 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn forget_then_recall_no_ghost_results() { async fn forget_then_recall_no_ghost_results() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("ghost", "phantom memory content", MemoryCategory::Core) mem.store(
.await "ghost",
.unwrap(); "phantom memory content",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.forget("ghost").await.unwrap(); mem.forget("ghost").await.unwrap();
let results = mem.recall("phantom memory", 10).await.unwrap(); let results = mem.recall("phantom memory", 10, None).await.unwrap();
assert!( assert!(
results.is_empty(), results.is_empty(),
"Deleted memory should not appear in recall" "Deleted memory should not appear in recall"
@ -1278,11 +1386,11 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn forget_and_re_store_same_key() { async fn forget_and_re_store_same_key() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("cycle", "version 1", MemoryCategory::Core) mem.store("cycle", "version 1", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.forget("cycle").await.unwrap(); mem.forget("cycle").await.unwrap();
mem.store("cycle", "version 2", MemoryCategory::Core) mem.store("cycle", "version 2", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let entry = mem.get("cycle").await.unwrap().unwrap(); let entry = mem.get("cycle").await.unwrap().unwrap();
@ -1302,14 +1410,14 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn reindex_twice_is_safe() { async fn reindex_twice_is_safe() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex data", MemoryCategory::Core) mem.store("r1", "reindex data", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.reindex().await.unwrap(); mem.reindex().await.unwrap();
let count = mem.reindex().await.unwrap(); let count = mem.reindex().await.unwrap();
assert_eq!(count, 0); // Noop embedder → nothing to re-embed assert_eq!(count, 0); // Noop embedder → nothing to re-embed
// Data should still be intact // Data should still be intact
let results = mem.recall("reindex", 10).await.unwrap(); let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 1); assert_eq!(results.len(), 1);
} }
@ -1363,18 +1471,28 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn list_custom_category() { async fn list_custom_category() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
mem.store("c1", "custom1", MemoryCategory::Custom("project".into())) mem.store(
.await "c1",
.unwrap(); "custom1",
mem.store("c2", "custom2", MemoryCategory::Custom("project".into())) MemoryCategory::Custom("project".into()),
.await None,
.unwrap(); )
mem.store("c3", "other", MemoryCategory::Core) .await
.unwrap();
mem.store(
"c2",
"custom2",
MemoryCategory::Custom("project".into()),
None,
)
.await
.unwrap();
mem.store("c3", "other", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
let project = mem let project = mem
.list(Some(&MemoryCategory::Custom("project".into()))) .list(Some(&MemoryCategory::Custom("project".into())), None)
.await .await
.unwrap(); .unwrap();
assert_eq!(project.len(), 2); assert_eq!(project.len(), 2);
@ -1383,7 +1501,122 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn list_empty_db() { async fn list_empty_db() {
let (_tmp, mem) = temp_sqlite(); let (_tmp, mem) = temp_sqlite();
let all = mem.list(None).await.unwrap(); let all = mem.list(None, None).await.unwrap();
assert!(all.is_empty()); assert!(all.is_empty());
} }
// ── Session isolation ─────────────────────────────────────────
#[tokio::test]
async fn store_and_recall_with_session_id() {
let (_tmp, mem) = temp_sqlite();
mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
.await
.unwrap();
mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
.await
.unwrap();
mem.store("k3", "no session fact", MemoryCategory::Core, None)
.await
.unwrap();
// Recall with session-a filter returns only session-a entry
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
}
#[tokio::test]
async fn recall_no_session_filter_returns_all() {
let (_tmp, mem) = temp_sqlite();
mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
.await
.unwrap();
mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
.await
.unwrap();
mem.store("k3", "gamma fact", MemoryCategory::Core, None)
.await
.unwrap();
// Recall without session filter returns all matching entries
let results = mem.recall("fact", 10, None).await.unwrap();
assert_eq!(results.len(), 3);
}
#[tokio::test]
async fn cross_session_recall_isolation() {
let (_tmp, mem) = temp_sqlite();
mem.store(
"secret",
"session A secret data",
MemoryCategory::Core,
Some("sess-a"),
)
.await
.unwrap();
// Session B cannot see session A data
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
assert!(results.is_empty());
// Session A can see its own data
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
assert_eq!(results.len(), 1);
}
#[tokio::test]
async fn list_with_session_filter() {
let (_tmp, mem) = temp_sqlite();
mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
.await
.unwrap();
mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
.await
.unwrap();
mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
.await
.unwrap();
mem.store("k4", "none1", MemoryCategory::Core, None)
.await
.unwrap();
// List with session-a filter
let results = mem.list(None, Some("sess-a")).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
.all(|e| e.session_id.as_deref() == Some("sess-a")));
// List with session-a + category filter
let results = mem
.list(Some(&MemoryCategory::Core), Some("sess-a"))
.await
.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
}
#[tokio::test]
async fn schema_migration_idempotent_on_reopen() {
let tmp = TempDir::new().unwrap();
// First open: creates schema + migration
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
.await
.unwrap();
}
// Second open: migration runs again but is idempotent
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
assert_eq!(results.len(), 1);
assert_eq!(results[0].key, "k1");
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
}
}
} }

View file

@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
/// Backend name /// Backend name
fn name(&self) -> &str; fn name(&self) -> &str;
/// Store a memory entry /// Store a memory entry, optionally scoped to a session
async fn store(&self, key: &str, content: &str, category: MemoryCategory) async fn store(
-> anyhow::Result<()>; &self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()>;
/// Recall memories matching a query (keyword search) /// Recall memories matching a query (keyword search), optionally scoped to a session
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>; async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>>;
/// Get a specific memory by key /// Get a specific memory by key
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>; async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
/// List all memory keys, optionally filtered by category /// List all memory keys, optionally filtered by category and/or session
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>; async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>>;
/// Remove a memory by key /// Remove a memory by key
async fn forget(&self, key: &str) -> anyhow::Result<bool>; async fn forget(&self, key: &str) -> anyhow::Result<bool>;

View file

@ -95,7 +95,9 @@ async fn migrate_openclaw_memory(
stats.renamed_conflicts += 1; stats.renamed_conflicts += 1;
} }
memory.store(&key, &entry.content, entry.category).await?; memory
.store(&key, &entry.content, entry.category, None)
.await?;
stats.imported += 1; stats.imported += 1;
} }
@ -488,7 +490,7 @@ mod tests {
// Existing target memory // Existing target memory
let target_mem = SqliteMemory::new(target.path()).unwrap(); let target_mem = SqliteMemory::new(target.path()).unwrap();
target_mem target_mem
.store("k", "new value", MemoryCategory::Core) .store("k", "new value", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -510,7 +512,7 @@ mod tests {
.await .await
.unwrap(); .unwrap();
let all = target_mem.list(None).await.unwrap(); let all = target_mem.list(None, None).await.unwrap();
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value")); assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
assert!(all assert!(all
.iter() .iter()

View file

@ -48,9 +48,10 @@ impl Observer for LogObserver {
ObserverEvent::AgentEnd { ObserverEvent::AgentEnd {
duration, duration,
tokens_used, tokens_used,
cost_usd,
} => { } => {
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX); let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
info!(duration_ms = ms, tokens = ?tokens_used, "agent.end"); info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end");
} }
ObserverEvent::ToolCallStart { tool } => { ObserverEvent::ToolCallStart { tool } => {
info!(tool = %tool, "tool.start"); info!(tool = %tool, "tool.start");
@ -133,10 +134,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd { obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(500), duration: Duration::from_millis(500),
tokens_used: Some(100), tokens_used: Some(100),
cost_usd: Some(0.0015),
}); });
obs.record_event(&ObserverEvent::AgentEnd { obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO, duration: Duration::ZERO,
tokens_used: None, tokens_used: None,
cost_usd: None,
}); });
obs.record_event(&ObserverEvent::ToolCallStart { obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(), tool: "shell".into(),

View file

@ -48,10 +48,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd { obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(100), duration: Duration::from_millis(100),
tokens_used: Some(42), tokens_used: Some(42),
cost_usd: Some(0.001),
}); });
obs.record_event(&ObserverEvent::AgentEnd { obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO, duration: Duration::ZERO,
tokens_used: None, tokens_used: None,
cost_usd: None,
}); });
obs.record_event(&ObserverEvent::ToolCallStart { obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(), tool: "shell".into(),

View file

@ -227,6 +227,7 @@ impl Observer for OtelObserver {
ObserverEvent::AgentEnd { ObserverEvent::AgentEnd {
duration, duration,
tokens_used, tokens_used,
cost_usd,
} => { } => {
let secs = duration.as_secs_f64(); let secs = duration.as_secs_f64();
let start_time = SystemTime::now() let start_time = SystemTime::now()
@ -243,6 +244,9 @@ impl Observer for OtelObserver {
if let Some(t) = tokens_used { if let Some(t) = tokens_used {
span.set_attribute(KeyValue::new("tokens_used", *t as i64)); span.set_attribute(KeyValue::new("tokens_used", *t as i64));
} }
if let Some(c) = cost_usd {
span.set_attribute(KeyValue::new("cost_usd", *c));
}
span.end(); span.end();
self.agent_duration.record(secs, &[]); self.agent_duration.record(secs, &[]);
@ -394,10 +398,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd { obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(500), duration: Duration::from_millis(500),
tokens_used: Some(100), tokens_used: Some(100),
cost_usd: Some(0.0015),
}); });
obs.record_event(&ObserverEvent::AgentEnd { obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO, duration: Duration::ZERO,
tokens_used: None, tokens_used: None,
cost_usd: None,
}); });
obs.record_event(&ObserverEvent::ToolCallStart { obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(), tool: "shell".into(),

View file

@ -27,6 +27,7 @@ pub enum ObserverEvent {
AgentEnd { AgentEnd {
duration: Duration, duration: Duration,
tokens_used: Option<u64>, tokens_used: Option<u64>,
cost_usd: Option<f64>,
}, },
/// A tool call is about to be executed. /// A tool call is about to be executed.
ToolCallStart { ToolCallStart {

View file

@ -106,6 +106,7 @@ pub fn run_wizard() -> Result<Config> {
} else { } else {
Some(api_key) Some(api_key)
}, },
api_url: None,
default_provider: Some(provider), default_provider: Some(provider),
default_model: Some(model), default_model: Some(model),
default_temperature: 0.7, default_temperature: 0.7,
@ -284,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
#[allow(clippy::too_many_lines)] #[allow(clippy::too_many_lines)]
pub fn run_quick_setup( pub fn run_quick_setup(
api_key: Option<&str>, credential_override: Option<&str>,
provider: Option<&str>, provider: Option<&str>,
memory_backend: Option<&str>, memory_backend: Option<&str>,
) -> Result<Config> { ) -> Result<Config> {
@ -318,7 +319,8 @@ pub fn run_quick_setup(
let config = Config { let config = Config {
workspace_dir: workspace_dir.clone(), workspace_dir: workspace_dir.clone(),
config_path: config_path.clone(), config_path: config_path.clone(),
api_key: api_key.map(String::from), api_key: credential_override.map(String::from),
api_url: None,
default_provider: Some(provider_name.clone()), default_provider: Some(provider_name.clone()),
default_model: Some(model.clone()), default_model: Some(model.clone()),
default_temperature: 0.7, default_temperature: 0.7,
@ -377,7 +379,7 @@ pub fn run_quick_setup(
println!( println!(
" {} API Key: {}", " {} API Key: {}",
style("").green().bold(), style("").green().bold(),
if api_key.is_some() { if credential_override.is_some() {
style("set").green() style("set").green()
} else { } else {
style("not set (use --api-key or edit config.toml)").yellow() style("not set (use --api-key or edit config.toml)").yellow()
@ -426,7 +428,7 @@ pub fn run_quick_setup(
); );
println!(); println!();
println!(" {}", style("Next steps:").white().bold()); println!(" {}", style("Next steps:").white().bold());
if api_key.is_none() { if credential_override.is_none() {
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\""); println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
println!(" 2. Or edit: ~/.zeroclaw/config.toml"); println!(" 2. Or edit: ~/.zeroclaw/config.toml");
println!(" 3. Chat: zeroclaw agent -m \"Hello!\""); println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
@ -2269,14 +2271,11 @@ fn setup_memory() -> Result<MemoryConfig> {
let backend = backend_key_from_choice(choice); let backend = backend_key_from_choice(choice);
let profile = memory_backend_profile(backend); let profile = memory_backend_profile(backend);
let auto_save = if !profile.auto_save_default { let auto_save = profile.auto_save_default
false && Confirm::new()
} else {
Confirm::new()
.with_prompt(" Auto-save conversations to memory?") .with_prompt(" Auto-save conversations to memory?")
.default(true) .default(true)
.interact()? .interact()?;
};
println!( println!(
" {} Memory: {} (auto-save: {})", " {} Memory: {} (auto-save: {})",
@ -2587,6 +2586,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
guild_id: if guild.is_empty() { None } else { Some(guild) }, guild_id: if guild.is_empty() { None } else { Some(guild) },
allowed_users, allowed_users,
listen_to_bots: false, listen_to_bots: false,
mention_only: false,
}); });
} }
2 => { 2 => {
@ -2799,22 +2799,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
.header("Authorization", format!("Bearer {access_token_clone}")) .header("Authorization", format!("Bearer {access_token_clone}"))
.send()?; .send()?;
let ok = resp.status().is_success(); let ok = resp.status().is_success();
let data: serde_json::Value = resp.json().unwrap_or_default(); Ok::<_, reqwest::Error>(ok)
let user_id = data
.get("user_id")
.and_then(serde_json::Value::as_str)
.unwrap_or("unknown")
.to_string();
Ok::<_, reqwest::Error>((ok, user_id))
}) })
.join(); .join();
match thread_result { match thread_result {
Ok(Ok((true, user_id))) => { Ok(Ok(true)) => println!(
println!( "\r {} Connection verified ",
"\r {} Connected as {user_id} ", style("").green().bold()
style("").green().bold() ),
);
}
_ => { _ => {
println!( println!(
"\r {} Connection failed — check homeserver URL and token", "\r {} Connection failed — check homeserver URL and token",
@ -3779,15 +3771,7 @@ fn print_summary(config: &Config) {
); );
// Secrets // Secrets
println!( println!(" {} Secrets: configured", style("🔒").cyan());
" {} Secrets: {}",
style("🔒").cyan(),
if config.secrets.encrypt {
style("encrypted").green().to_string()
} else {
style("plaintext").yellow().to_string()
}
);
// Gateway // Gateway
println!( println!(

View file

@ -38,6 +38,10 @@ pub fn ensure_arduino_cli() -> Result<()> {
anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/"); anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/");
} }
println!("arduino-cli installed."); println!("arduino-cli installed.");
if !arduino_cli_available() {
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
}
return Ok(());
} }
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
@ -54,11 +58,6 @@ pub fn ensure_arduino_cli() -> Result<()> {
println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/"); println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/");
anyhow::bail!("arduino-cli not installed."); anyhow::bail!("arduino-cli not installed.");
} }
if !arduino_cli_available() {
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
}
Ok(())
} }
/// Ensure arduino:avr core is installed. /// Ensure arduino:avr core is installed.

View file

@ -112,6 +112,7 @@ pub struct SerialPeripheral {
impl SerialPeripheral { impl SerialPeripheral {
/// Create and connect to a serial peripheral. /// Create and connect to a serial peripheral.
#[allow(clippy::unused_async)]
pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> { pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> {
let path = config let path = config
.path .path

View file

@ -106,17 +106,17 @@ struct NativeContentIn {
} }
impl AnthropicProvider { impl AnthropicProvider {
pub fn new(api_key: Option<&str>) -> Self { pub fn new(credential: Option<&str>) -> Self {
Self::with_base_url(api_key, None) Self::with_base_url(credential, None)
} }
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self { pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
let base_url = base_url let base_url = base_url
.map(|u| u.trim_end_matches('/')) .map(|u| u.trim_end_matches('/'))
.unwrap_or("https://api.anthropic.com") .unwrap_or("https://api.anthropic.com")
.to_string(); .to_string();
Self { Self {
credential: api_key credential: credential
.map(str::trim) .map(str::trim)
.filter(|k| !k.is_empty()) .filter(|k| !k.is_empty())
.map(ToString::to_string), .map(ToString::to_string),
@ -410,9 +410,9 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let p = AnthropicProvider::new(Some("sk-ant-test123")); let p = AnthropicProvider::new(Some("anthropic-test-credential"));
assert!(p.credential.is_some()); assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
assert_eq!(p.base_url, "https://api.anthropic.com"); assert_eq!(p.base_url, "https://api.anthropic.com");
} }
@ -431,17 +431,19 @@ mod tests {
#[test] #[test]
fn creates_with_whitespace_key() { fn creates_with_whitespace_key() {
let p = AnthropicProvider::new(Some(" sk-ant-test123 ")); let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
assert!(p.credential.is_some()); assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123")); assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
} }
#[test] #[test]
fn creates_with_custom_base_url() { fn creates_with_custom_base_url() {
let p = let p = AnthropicProvider::with_base_url(
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com")); Some("anthropic-credential"),
Some("https://api.example.com"),
);
assert_eq!(p.base_url, "https://api.example.com"); assert_eq!(p.base_url, "https://api.example.com");
assert_eq!(p.credential.as_deref(), Some("sk-ant-test")); assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
} }
#[test] #[test]

View file

@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
pub struct OpenAiCompatibleProvider { pub struct OpenAiCompatibleProvider {
pub(crate) name: String, pub(crate) name: String,
pub(crate) base_url: String, pub(crate) base_url: String,
pub(crate) api_key: Option<String>, pub(crate) credential: Option<String>,
pub(crate) auth_header: AuthStyle, pub(crate) auth_header: AuthStyle,
/// When false, do not fall back to /v1/responses on chat completions 404. /// When false, do not fall back to /v1/responses on chat completions 404.
/// GLM/Zhipu does not support the responses API. /// GLM/Zhipu does not support the responses API.
@ -37,11 +37,16 @@ pub enum AuthStyle {
} }
impl OpenAiCompatibleProvider { impl OpenAiCompatibleProvider {
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self { pub fn new(
name: &str,
base_url: &str,
credential: Option<&str>,
auth_style: AuthStyle,
) -> Self {
Self { Self {
name: name.to_string(), name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(), base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
auth_header: auth_style, auth_header: auth_style,
supports_responses_fallback: true, supports_responses_fallback: true,
client: Client::builder() client: Client::builder()
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
pub fn new_no_responses_fallback( pub fn new_no_responses_fallback(
name: &str, name: &str,
base_url: &str, base_url: &str,
api_key: Option<&str>, credential: Option<&str>,
auth_style: AuthStyle, auth_style: AuthStyle,
) -> Self { ) -> Self {
Self { Self {
name: name.to_string(), name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(), base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
auth_header: auth_style, auth_header: auth_style,
supports_responses_fallback: false, supports_responses_fallback: false,
client: Client::builder() client: Client::builder()
@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
fn apply_auth_header( fn apply_auth_header(
&self, &self,
req: reqwest::RequestBuilder, req: reqwest::RequestBuilder,
api_key: &str, credential: &str,
) -> reqwest::RequestBuilder { ) -> reqwest::RequestBuilder {
match &self.auth_header { match &self.auth_header {
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")), AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
AuthStyle::XApiKey => req.header("x-api-key", api_key), AuthStyle::XApiKey => req.header("x-api-key", credential),
AuthStyle::Custom(header) => req.header(header, api_key), AuthStyle::Custom(header) => req.header(header, credential),
} }
} }
async fn chat_via_responses( async fn chat_via_responses(
&self, &self,
api_key: &str, credential: &str,
system_prompt: Option<&str>, system_prompt: Option<&str>,
message: &str, message: &str,
model: &str, model: &str,
@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
let url = self.responses_url(); let url = self.responses_url();
let response = self let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key) .apply_auth_header(self.client.post(&url).json(&request), credential)
.send() .send()
.await?; .await?;
@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name self.name
@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url(); let url = self.chat_completions_url();
let response = self let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key) .apply_auth_header(self.client.post(&url).json(&request), credential)
.send() .send()
.await?; .await?;
@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback { if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
return self return self
.chat_via_responses(api_key, system_prompt, message, model) .chat_via_responses(credential, system_prompt, message, model)
.await .await
.map_err(|responses_err| { .map_err(|responses_err| {
anyhow::anyhow!( anyhow::anyhow!(
@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.", "{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name self.name
@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url(); let url = self.chat_completions_url();
let response = self let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key) .apply_auth_header(self.client.post(&url).json(&request), credential)
.send() .send()
.await?; .await?;
@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
if let Some(user_msg) = last_user { if let Some(user_msg) = last_user {
return self return self
.chat_via_responses( .chat_via_responses(
api_key, credential,
system.map(|m| m.content.as_str()), system.map(|m| m.content.as_str()),
&user_msg.content, &user_msg.content,
model, model,
@ -791,16 +796,20 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key")); let p = make_provider(
"venice",
"https://api.venice.ai",
Some("venice-test-credential"),
);
assert_eq!(p.name, "venice"); assert_eq!(p.name, "venice");
assert_eq!(p.base_url, "https://api.venice.ai"); assert_eq!(p.base_url, "https://api.venice.ai");
assert_eq!(p.api_key.as_deref(), Some("vn-key")); assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
} }
#[test] #[test]
fn creates_without_key() { fn creates_without_key() {
let p = make_provider("test", "https://example.com", None); let p = make_provider("test", "https://example.com", None);
assert!(p.api_key.is_none()); assert!(p.credential.is_none());
} }
#[test] #[test]
@ -894,6 +903,7 @@ mod tests {
make_provider("Groq", "https://api.groq.com/openai", None), make_provider("Groq", "https://api.groq.com/openai", None),
make_provider("Mistral", "https://api.mistral.ai", None), make_provider("Mistral", "https://api.mistral.ai", None),
make_provider("xAI", "https://api.x.ai", None), make_provider("xAI", "https://api.x.ai", None),
make_provider("Astrai", "https://as-trai.com/v1", None),
]; ];
for p in providers { for p in providers {

705
src/providers/copilot.rs Normal file
View file

@ -0,0 +1,705 @@
//! GitHub Copilot provider with OAuth device-flow authentication.
//!
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
//! then exchanges the OAuth token for short-lived Copilot API keys.
//! Tokens are cached to disk and auto-refreshed.
//!
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
//! and other third-party Copilot integrations. The Copilot token endpoint is
//! private; there is no public OAuth scope or app registration for it.
//! GitHub could change or revoke this at any time, which would break all
//! third-party integrations simultaneously.
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::warn;
/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
const DEFAULT_API: &str = "https://api.githubcopilot.com";
// ── Token types ──────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
device_code: String,
user_code: String,
verification_uri: String,
#[serde(default = "default_interval")]
interval: u64,
#[serde(default = "default_expires_in")]
expires_in: u64,
}
fn default_interval() -> u64 {
5
}
fn default_expires_in() -> u64 {
900
}
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
access_token: Option<String>,
error: Option<String>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
token: String,
expires_at: i64,
#[serde(default)]
endpoints: Option<ApiEndpoints>,
}
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
api: Option<String>,
}
struct CachedApiKey {
token: String,
api_endpoint: String,
expires_at: i64,
}
// ── Chat completions types ───────────────────────────────────────
#[derive(Debug, Serialize)]
struct ApiChatRequest {
model: String,
messages: Vec<ApiMessage>,
temperature: f64,
#[serde(skip_serializing_if = "Option::is_none")]
tools: Option<Vec<NativeToolSpec>>,
#[serde(skip_serializing_if = "Option::is_none")]
tool_choice: Option<String>,
}
/// One message in an outgoing request; which optional fields are set
/// depends on the role ("system"/"user"/"assistant"/"tool").
#[derive(Debug, Serialize)]
struct ApiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    // Present only on "tool" result messages.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    // Present only on assistant messages that invoked tools.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}
/// A tool definition in the OpenAI "function" wire format.
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    // Always "function"; serialized as the JSON key "type".
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}
/// Function name, description, and JSON-schema parameters for one tool.
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    // JSON schema describing the tool's arguments.
    parameters: serde_json::Value,
}
/// A tool invocation as sent to or received from the API.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    // Call id; a UUID is generated locally if the API omits it.
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    // Always "function" when present; serialized as the JSON key "type".
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}
/// Function name plus its arguments as a raw JSON-encoded string.
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    // Left unparsed here; downstream code interprets the JSON.
    arguments: String,
}
/// Top-level `/chat/completions` response; only the first choice is consumed.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
}
/// One completion choice wrapping the assistant message.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
/// Assistant output: free text and/or native tool calls, either may be absent.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
// ── Provider ─────────────────────────────────────────────────────
/// GitHub Copilot provider with automatic OAuth and token refresh.
///
/// On first use, prompts the user to visit github.com/login/device.
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
/// automatically.
pub struct CopilotProvider {
    /// Pre-supplied GitHub access token; when set, the device flow and the
    /// on-disk access-token cache are bypassed.
    github_token: Option<String>,
    /// Mutex ensures only one caller refreshes tokens at a time,
    /// preventing duplicate device flow prompts or redundant API calls.
    /// Also holds the most recently fetched Copilot API key, if any.
    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
    /// Shared HTTP client used for all GitHub and Copilot requests.
    http: Client,
    /// Directory where the access token and API key are cached on disk.
    token_dir: PathBuf,
}
impl CopilotProvider {
    /// Build a provider. `github_token`, when non-empty, is used directly
    /// and the OAuth device flow / on-disk access-token cache are skipped.
    /// Also resolves and (best-effort) creates the token cache directory.
    pub fn new(github_token: Option<&str>) -> Self {
        let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
            .map(|dir| dir.config_dir().join("copilot"))
            .unwrap_or_else(|| {
                // Fall back to a user-specific temp directory to avoid
                // shared-directory symlink attacks.
                let user = std::env::var("USER")
                    .or_else(|_| std::env::var("USERNAME"))
                    .unwrap_or_else(|_| "unknown".to_string());
                std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
            });
        if let Err(err) = std::fs::create_dir_all(&token_dir) {
            warn!(
                "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
                token_dir
            );
        } else {
            // Restrict the cache directory to the owner (0700) so cached
            // tokens are not readable by other local users.
            #[cfg(unix)]
            {
                use std::os::unix::fs::PermissionsExt;
                if let Err(err) =
                    std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
                {
                    warn!(
                        "Failed to set Copilot token directory permissions on {:?}: {err}",
                        token_dir
                    );
                }
            }
        }
        Self {
            // An explicitly empty token is normalized to "no token".
            github_token: github_token
                .filter(|token| !token.is_empty())
                .map(String::from),
            refresh_lock: Arc::new(Mutex::new(None)),
            http: Client::builder()
                .timeout(Duration::from_secs(120))
                .connect_timeout(Duration::from_secs(10))
                .build()
                .unwrap_or_else(|_| Client::new()),
            token_dir,
        }
    }
    /// Required headers for Copilot API requests (editor identification).
    const COPILOT_HEADERS: [(&str, &str); 4] = [
        ("Editor-Version", "vscode/1.85.1"),
        ("Editor-Plugin-Version", "copilot/1.155.0"),
        ("User-Agent", "GithubCopilot/1.155.0"),
        ("Accept", "application/json"),
    ];
    /// Convert internal tool specs into the native wire format.
    /// `None` (no tools offered) stays `None`.
    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
        tools.map(|items| {
            items
                .iter()
                .map(|tool| NativeToolSpec {
                    kind: "function".to_string(),
                    function: NativeToolFunctionSpec {
                        name: tool.name.clone(),
                        description: tool.description.clone(),
                        parameters: tool.parameters.clone(),
                    },
                })
                .collect()
        })
    }
    /// Convert provider-neutral chat history into API messages.
    ///
    /// Assistant messages whose content is a JSON envelope containing a
    /// `tool_calls` array are expanded into native tool calls; tool-role
    /// messages carrying a JSON `tool_call_id`/`content` envelope become
    /// proper tool results. Anything else passes through as plain text.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
        messages
            .iter()
            .map(|message| {
                if message.role == "assistant" {
                    // Try to decode a serialized tool-call envelope; if the
                    // content is not JSON, fall through to plain text below.
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                        if let Some(tool_calls_value) = value.get("tool_calls") {
                            if let Ok(parsed_calls) =
                                serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
                            {
                                let tool_calls = parsed_calls
                                    .into_iter()
                                    .map(|tool_call| NativeToolCall {
                                        id: Some(tool_call.id),
                                        kind: Some("function".to_string()),
                                        function: NativeFunctionCall {
                                            name: tool_call.name,
                                            arguments: tool_call.arguments,
                                        },
                                    })
                                    .collect::<Vec<_>>();
                                let content = value
                                    .get("content")
                                    .and_then(serde_json::Value::as_str)
                                    .map(ToString::to_string);
                                return ApiMessage {
                                    role: "assistant".to_string(),
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Some(tool_calls),
                                };
                            }
                        }
                    }
                }
                if message.role == "tool" {
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                        let tool_call_id = value
                            .get("tool_call_id")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        let content = value
                            .get("content")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        return ApiMessage {
                            role: "tool".to_string(),
                            content,
                            tool_call_id,
                            tool_calls: None,
                        };
                    }
                }
                // Default: forward the message unchanged as plain text.
                ApiMessage {
                    role: message.role.clone(),
                    content: Some(message.content.clone()),
                    tool_call_id: None,
                    tool_calls: None,
                }
            })
            .collect()
    }
    /// Send a chat completions request with required Copilot headers.
    ///
    /// Resolves (refreshing if needed) the API key, posts the request, and
    /// maps the first choice into a provider-neutral response. Tool calls
    /// returned without an id receive a locally generated UUID.
    async fn send_chat_request(
        &self,
        messages: Vec<ApiMessage>,
        tools: Option<&[ToolSpec]>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let (token, endpoint) = self.get_api_key().await?;
        let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
        let native_tools = Self::convert_tools(tools);
        let request = ApiChatRequest {
            model: model.to_string(),
            messages,
            temperature,
            // "auto" lets the model decide whether to call a tool.
            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
            tools: native_tools,
        };
        let mut req = self
            .http
            .post(&url)
            .header("Authorization", format!("Bearer {token}"))
            .json(&request);
        for (header, value) in &Self::COPILOT_HEADERS {
            req = req.header(*header, *value);
        }
        let response = req.send().await?;
        if !response.status().is_success() {
            return Err(super::api_error("GitHub Copilot", response).await);
        }
        let api_response: ApiChatResponse = response.json().await?;
        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
        let tool_calls = choice
            .message
            .tool_calls
            .unwrap_or_default()
            .into_iter()
            .map(|tool_call| ProviderToolCall {
                id: tool_call
                    .id
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            })
            .collect();
        Ok(ProviderChatResponse {
            text: choice.message.content,
            tool_calls,
        })
    }
    /// Get a valid Copilot API key, refreshing or re-authenticating as needed.
    /// Uses a Mutex to ensure only one caller refreshes at a time.
    ///
    /// Resolution order: in-memory cache, then on-disk `api-key.json`, then
    /// a full GitHub access-token + key exchange. A 120-second safety margin
    /// before `expires_at` forces an early refresh. Returns `(token, endpoint)`.
    async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
        let mut cached = self.refresh_lock.lock().await;
        if let Some(cached_key) = cached.as_ref() {
            if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
                return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
            }
        }
        if let Some(info) = self.load_api_key_from_disk().await {
            if chrono::Utc::now().timestamp() + 120 < info.expires_at {
                let endpoint = info
                    .endpoints
                    .as_ref()
                    .and_then(|e| e.api.clone())
                    .unwrap_or_else(|| DEFAULT_API.to_string());
                let token = info.token;
                *cached = Some(CachedApiKey {
                    token: token.clone(),
                    api_endpoint: endpoint.clone(),
                    expires_at: info.expires_at,
                });
                return Ok((token, endpoint));
            }
        }
        // Both caches missing or stale: perform the full exchange.
        let access_token = self.get_github_access_token().await?;
        let api_key_info = self.exchange_for_api_key(&access_token).await?;
        self.save_api_key_to_disk(&api_key_info).await;
        let endpoint = api_key_info
            .endpoints
            .as_ref()
            .and_then(|e| e.api.clone())
            .unwrap_or_else(|| DEFAULT_API.to_string());
        *cached = Some(CachedApiKey {
            token: api_key_info.token.clone(),
            api_endpoint: endpoint.clone(),
            expires_at: api_key_info.expires_at,
        });
        Ok((api_key_info.token, endpoint))
    }
    /// Get a GitHub access token from config, cache, or device flow.
    async fn get_github_access_token(&self) -> anyhow::Result<String> {
        // 1) An explicitly configured token wins.
        if let Some(token) = &self.github_token {
            return Ok(token.clone());
        }
        // 2) A previously cached token on disk.
        let access_token_path = self.token_dir.join("access-token");
        if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
            let token = cached.trim();
            if !token.is_empty() {
                return Ok(token.to_string());
            }
        }
        // 3) Interactive device flow; cache the result with 0600 perms.
        let token = self.device_code_login().await?;
        write_file_secure(&access_token_path, &token).await;
        Ok(token)
    }
    /// Run GitHub OAuth device code flow.
    ///
    /// Prints the verification URL and user code to stderr, then polls the
    /// access-token endpoint until success, expiry, or a hard error, backing
    /// off by 5s whenever the server answers `slow_down`.
    async fn device_code_login(&self) -> anyhow::Result<String> {
        let response: DeviceCodeResponse = self
            .http
            .post(GITHUB_DEVICE_CODE_URL)
            .header("Accept", "application/json")
            .json(&serde_json::json!({
                "client_id": GITHUB_CLIENT_ID,
                "scope": "read:user"
            }))
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;
        // Never poll faster than every 5 seconds, regardless of the server.
        let mut poll_interval = Duration::from_secs(response.interval.max(5));
        let expires_in = response.expires_in.max(1);
        let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
        eprintln!(
            "\nGitHub Copilot authentication is required.\n\
             Visit: {}\n\
             Code: {}\n\
             Waiting for authorization...\n",
            response.verification_uri, response.user_code
        );
        while tokio::time::Instant::now() < expires_at {
            tokio::time::sleep(poll_interval).await;
            let token_response: AccessTokenResponse = self
                .http
                .post(GITHUB_ACCESS_TOKEN_URL)
                .header("Accept", "application/json")
                .json(&serde_json::json!({
                    "client_id": GITHUB_CLIENT_ID,
                    "device_code": response.device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
                }))
                .send()
                .await?
                .json()
                .await?;
            if let Some(token) = token_response.access_token {
                eprintln!("Authentication succeeded.\n");
                return Ok(token);
            }
            match token_response.error.as_deref() {
                // Server asked us to back off: widen the polling interval.
                Some("slow_down") => {
                    poll_interval += Duration::from_secs(5);
                }
                // Not authorized yet (or no error field): keep polling.
                Some("authorization_pending") | None => {}
                Some("expired_token") => {
                    anyhow::bail!("GitHub device authorization expired")
                }
                Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
            }
        }
        anyhow::bail!("Timed out waiting for GitHub authorization")
    }
    /// Exchange a GitHub access token for a Copilot API key.
    ///
    /// On 401/403 the cached access token file is deleted so the next
    /// attempt re-runs the device flow instead of retrying a dead token.
    async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
        let mut request = self.http.get(GITHUB_API_KEY_URL);
        for (header, value) in &Self::COPILOT_HEADERS {
            request = request.header(*header, *value);
        }
        // This endpoint uses the `token` auth scheme, not `Bearer`.
        request = request.header("Authorization", format!("token {access_token}"));
        let response = request.send().await?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            // Scrub secret-like patterns before surfacing the body.
            let sanitized = super::sanitize_api_error(&body);
            if status.as_u16() == 401 || status.as_u16() == 403 {
                let access_token_path = self.token_dir.join("access-token");
                tokio::fs::remove_file(&access_token_path).await.ok();
            }
            anyhow::bail!(
                "Failed to get Copilot API key ({status}): {sanitized}. \
                Ensure your GitHub account has an active Copilot subscription."
            );
        }
        let info: ApiKeyInfo = response.json().await?;
        Ok(info)
    }
    /// Read the cached Copilot API key from `api-key.json`, if present
    /// and parseable; any failure yields `None`.
    async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
        let path = self.token_dir.join("api-key.json");
        let data = tokio::fs::read_to_string(&path).await.ok()?;
        serde_json::from_str(&data).ok()
    }
    /// Best-effort write of the API key to `api-key.json` (0600 perms).
    async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
        let path = self.token_dir.join("api-key.json");
        if let Ok(json) = serde_json::to_string_pretty(info) {
            write_file_secure(&path, &json).await;
        }
    }
}
/// Write a file with 0600 permissions (owner read/write only).
/// Uses `spawn_blocking` to avoid blocking the async runtime.
async fn write_file_secure(path: &Path, content: &str) {
let path = path.to_path_buf();
let content = content.to_string();
let result = tokio::task::spawn_blocking(move || {
#[cfg(unix)]
{
use std::io::Write;
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
let mut file = std::fs::OpenOptions::new()
.write(true)
.create(true)
.truncate(true)
.mode(0o600)
.open(&path)?;
file.write_all(content.as_bytes())?;
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
Ok::<(), std::io::Error>(())
}
#[cfg(not(unix))]
{
std::fs::write(&path, &content)?;
Ok::<(), std::io::Error>(())
}
})
.await;
match result {
Ok(Ok(())) => {}
Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
Err(err) => warn!("Failed to spawn blocking write: {err}"),
}
}
#[async_trait]
impl Provider for CopilotProvider {
    /// One-shot chat: optional system prompt plus a single user message.
    /// Returns the assistant's text (empty string if the API sent none).
    async fn chat_with_system(
        &self,
        system_prompt: Option<&str>,
        message: &str,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let mut messages = Vec::new();
        if let Some(system) = system_prompt {
            messages.push(ApiMessage {
                role: "system".to_string(),
                content: Some(system.to_string()),
                tool_call_id: None,
                tool_calls: None,
            });
        }
        messages.push(ApiMessage {
            role: "user".to_string(),
            content: Some(message.to_string()),
            tool_call_id: None,
            tool_calls: None,
        });
        // No tools are offered on this path; only the text is returned.
        let response = self
            .send_chat_request(messages, None, model, temperature)
            .await?;
        Ok(response.text.unwrap_or_default())
    }
    /// Multi-turn chat without tools; history is converted message by message.
    async fn chat_with_history(
        &self,
        messages: &[ChatMessage],
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<String> {
        let response = self
            .send_chat_request(Self::convert_messages(messages), None, model, temperature)
            .await?;
        Ok(response.text.unwrap_or_default())
    }
    /// Full chat entry point: converted history plus optional native tools.
    async fn chat(
        &self,
        request: ProviderChatRequest<'_>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        self.send_chat_request(
            Self::convert_messages(request.messages),
            request.tools,
            model,
            temperature,
        )
        .await
    }
    /// Copilot speaks the OpenAI-style native tool-calling protocol.
    fn supports_native_tools(&self) -> bool {
        true
    }
    /// Pre-fetch (or refresh) the API key so the first chat call is fast.
    /// Note: may trigger the interactive device flow.
    async fn warmup(&self) -> anyhow::Result<()> {
        let _ = self.get_api_key().await?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn new_without_token() {
        // No token supplied: the provider stores none.
        assert!(CopilotProvider::new(None).github_token.is_none());
    }
    #[test]
    fn new_with_token() {
        let provider = CopilotProvider::new(Some("ghp_test"));
        assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
    }
    #[test]
    fn empty_token_treated_as_none() {
        // An empty string must be normalized to `None` by the constructor.
        assert!(CopilotProvider::new(Some("")).github_token.is_none());
    }
    #[tokio::test]
    async fn cache_starts_empty() {
        // A fresh provider has no cached API key.
        let provider = CopilotProvider::new(None);
        assert!(provider.refresh_lock.lock().await.is_none());
    }
    #[test]
    fn copilot_headers_include_required_fields() {
        let has_header = |wanted: &str| {
            CopilotProvider::COPILOT_HEADERS
                .iter()
                .any(|(header, _)| *header == wanted)
        };
        assert!(has_header("Editor-Version"));
        assert!(has_header("Editor-Plugin-Version"));
        assert!(has_header("User-Agent"));
    }
    #[test]
    fn default_interval_and_expiry() {
        // Serde defaults used when GitHub omits these fields.
        assert_eq!((default_interval(), default_expires_in()), (5, 900));
    }
    #[test]
    fn supports_native_tools() {
        assert!(CopilotProvider::new(None).supports_native_tools());
    }
}

View file

@ -1,5 +1,6 @@
pub mod anthropic; pub mod anthropic;
pub mod compatible; pub mod compatible;
pub mod copilot;
pub mod gemini; pub mod gemini;
pub mod ollama; pub mod ollama;
pub mod openai; pub mod openai;
@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
/// Scrub known secret-like token prefixes from provider error strings. /// Scrub known secret-like token prefixes from provider error strings.
/// ///
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`. /// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
/// `ghu_`, and `github_pat_`.
pub fn scrub_secret_patterns(input: &str) -> String { pub fn scrub_secret_patterns(input: &str) -> String {
const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"]; const PREFIXES: [&str; 7] = [
"sk-",
"xoxb-",
"xoxp-",
"ghp_",
"gho_",
"ghu_",
"github_pat_",
];
let mut scrubbed = input.to_string(); let mut scrubbed = input.to_string();
@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
/// ///
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens) /// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
/// followed by `ANTHROPIC_API_KEY` (for regular API keys). /// followed by `ANTHROPIC_API_KEY` (for regular API keys).
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> { fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) { if let Some(raw_override) = credential_override {
return Some(key.to_string()); let trimmed_override = raw_override.trim();
if !trimmed_override.is_empty() {
return Some(trimmed_override.to_owned());
}
} }
let provider_env_candidates: Vec<&str> = match name { let provider_env_candidates: Vec<&str> = match name {
@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"], "opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"], "vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"], "cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
"astrai" => vec!["ASTRAI_API_KEY"],
_ => vec![], _ => vec![],
}; };
@ -182,19 +196,28 @@ fn parse_custom_provider_url(
} }
} }
/// Factory: create the right provider from config /// Factory: create the right provider from config (without custom URL)
#[allow(clippy::too_many_lines)]
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> { pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
let resolved_key = resolve_api_key(name, api_key); create_provider_with_url(name, api_key, None)
let key = resolved_key.as_deref(); }
/// Factory: create the right provider from config with optional custom base URL
#[allow(clippy::too_many_lines)]
pub fn create_provider_with_url(
name: &str,
api_key: Option<&str>,
api_url: Option<&str>,
) -> anyhow::Result<Box<dyn Provider>> {
let resolved_credential = resolve_provider_credential(name, api_key);
#[allow(clippy::option_as_ref_deref)]
let key = resolved_credential.as_ref().map(String::as_str);
match name { match name {
// ── Primary providers (custom implementations) ─────── // ── Primary providers (custom implementations) ───────
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))), "openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))), "anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))), "openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
// Ollama is a local service that doesn't use API keys. // Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url. "ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
"gemini" | "google" | "google-gemini" => { "gemini" | "google" | "google-gemini" => {
Ok(Box::new(gemini::GeminiProvider::new(key))) Ok(Box::new(gemini::GeminiProvider::new(key)))
} }
@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer, "Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
))), ))),
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new( "mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer, "Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
))), ))),
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new( "xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
"xAI", "https://api.x.ai", key, AuthStyle::Bearer, "xAI", "https://api.x.ai", key, AuthStyle::Bearer,
@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new( "cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer, "Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
))), ))),
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new( "copilot" | "github-copilot" => {
"GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer, Ok(Box::new(copilot::CopilotProvider::new(api_key)))
))), },
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new( "lmstudio" | "lm-studio" => {
"NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer, let lm_studio_key = api_key
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or("lm-studio");
Ok(Box::new(OpenAiCompatibleProvider::new(
"LM Studio",
"http://localhost:1234/v1",
Some(lm_studio_key),
AuthStyle::Bearer,
)))
}
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
OpenAiCompatibleProvider::new(
"NVIDIA NIM",
"https://integrate.api.nvidia.com/v1",
key,
AuthStyle::Bearer,
),
)),
// ── AI inference routers ─────────────────────────────
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
))), ))),
// ── Bring Your Own Provider (custom URL) ─────────── // ── Bring Your Own Provider (custom URL) ───────────
@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
pub fn create_resilient_provider( pub fn create_resilient_provider(
primary_name: &str, primary_name: &str,
api_key: Option<&str>, api_key: Option<&str>,
api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig, reliability: &crate::config::ReliabilityConfig,
) -> anyhow::Result<Box<dyn Provider>> { ) -> anyhow::Result<Box<dyn Provider>> {
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new(); let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
providers.push(( providers.push((
primary_name.to_string(), primary_name.to_string(),
create_provider(primary_name, api_key)?, create_provider_with_url(primary_name, api_key, api_url)?,
)); ));
for fallback in &reliability.fallback_providers { for fallback in &reliability.fallback_providers {
@ -340,21 +386,13 @@ pub fn create_resilient_provider(
continue; continue;
} }
if api_key.is_some() && fallback != "ollama" { // Fallback providers don't use the custom api_url (it's specific to primary)
tracing::warn!(
fallback_provider = fallback,
primary_provider = primary_name,
"Fallback provider will use the primary provider's API key — \
this will fail if the providers require different keys"
);
}
match create_provider(fallback, api_key) { match create_provider(fallback, api_key) {
Ok(provider) => providers.push((fallback.clone(), provider)), Ok(provider) => providers.push((fallback.clone(), provider)),
Err(e) => { Err(_error) => {
tracing::warn!( tracing::warn!(
fallback_provider = fallback, fallback_provider = fallback,
"Ignoring invalid fallback provider: {e}" "Ignoring invalid fallback provider during initialization"
); );
} }
} }
@ -377,12 +415,13 @@ pub fn create_resilient_provider(
pub fn create_routed_provider( pub fn create_routed_provider(
primary_name: &str, primary_name: &str,
api_key: Option<&str>, api_key: Option<&str>,
api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig, reliability: &crate::config::ReliabilityConfig,
model_routes: &[crate::config::ModelRouteConfig], model_routes: &[crate::config::ModelRouteConfig],
default_model: &str, default_model: &str,
) -> anyhow::Result<Box<dyn Provider>> { ) -> anyhow::Result<Box<dyn Provider>> {
if model_routes.is_empty() { if model_routes.is_empty() {
return create_resilient_provider(primary_name, api_key, reliability); return create_resilient_provider(primary_name, api_key, api_url, reliability);
} }
// Collect unique provider names needed // Collect unique provider names needed
@ -396,12 +435,19 @@ pub fn create_routed_provider(
// Create each provider (with its own resilience wrapper) // Create each provider (with its own resilience wrapper)
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new(); let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
for name in &needed { for name in &needed {
let key = model_routes let routed_credential = model_routes
.iter() .iter()
.find(|r| &r.provider == name) .find(|r| &r.provider == name)
.and_then(|r| r.api_key.as_deref()) .and_then(|r| {
.or(api_key); r.api_key.as_ref().and_then(|raw_key| {
match create_resilient_provider(name, key, reliability) { let trimmed_key = raw_key.trim();
(!trimmed_key.is_empty()).then_some(trimmed_key)
})
});
let key = routed_credential.or(api_key);
// Only use api_url for the primary provider
let url = if name == primary_name { api_url } else { None };
match create_resilient_provider(name, key, url, reliability) {
Ok(provider) => providers.push((name.clone(), provider)), Ok(provider) => providers.push((name.clone(), provider)),
Err(e) => { Err(e) => {
if name == primary_name { if name == primary_name {
@ -409,7 +455,7 @@ pub fn create_routed_provider(
} }
tracing::warn!( tracing::warn!(
provider = name.as_str(), provider = name.as_str(),
"Ignoring routed provider that failed to create: {e}" "Ignoring routed provider that failed to initialize"
); );
} }
} }
@ -441,27 +487,27 @@ mod tests {
use super::*; use super::*;
#[test] #[test]
fn resolve_api_key_prefers_explicit_argument() { fn resolve_provider_credential_prefers_explicit_argument() {
let resolved = resolve_api_key("openrouter", Some(" explicit-key ")); let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
assert_eq!(resolved.as_deref(), Some("explicit-key")); assert_eq!(resolved, Some("explicit-key".to_string()));
} }
// ── Primary providers ──────────────────────────────────── // ── Primary providers ────────────────────────────────────
#[test] #[test]
fn factory_openrouter() { fn factory_openrouter() {
assert!(create_provider("openrouter", Some("sk-test")).is_ok()); assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
assert!(create_provider("openrouter", None).is_ok()); assert!(create_provider("openrouter", None).is_ok());
} }
#[test] #[test]
fn factory_anthropic() { fn factory_anthropic() {
assert!(create_provider("anthropic", Some("sk-test")).is_ok()); assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
} }
#[test] #[test]
fn factory_openai() { fn factory_openai() {
assert!(create_provider("openai", Some("sk-test")).is_ok()); assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
} }
#[test] #[test]
@ -556,6 +602,13 @@ mod tests {
assert!(create_provider("dashscope-us", Some("key")).is_ok()); assert!(create_provider("dashscope-us", Some("key")).is_ok());
} }
#[test]
fn factory_lmstudio() {
assert!(create_provider("lmstudio", Some("key")).is_ok());
assert!(create_provider("lm-studio", Some("key")).is_ok());
assert!(create_provider("lmstudio", None).is_ok());
}
// ── Extended ecosystem ─────────────────────────────────── // ── Extended ecosystem ───────────────────────────────────
#[test] #[test]
@ -614,6 +667,13 @@ mod tests {
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok()); assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
} }
// ── AI inference routers ─────────────────────────────────
#[test]
fn factory_astrai() {
assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok());
}
// ── Custom / BYOP provider ───────────────────────────── // ── Custom / BYOP provider ─────────────────────────────
#[test] #[test]
@ -761,17 +821,33 @@ mod tests {
scheduler_retries: 2, scheduler_retries: 2,
}; };
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability); let provider = create_resilient_provider(
"openrouter",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_ok()); assert!(provider.is_ok());
} }
#[test] #[test]
fn resilient_provider_errors_for_invalid_primary() { fn resilient_provider_errors_for_invalid_primary() {
let reliability = crate::config::ReliabilityConfig::default(); let reliability = crate::config::ReliabilityConfig::default();
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability); let provider = create_resilient_provider(
"totally-invalid",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_err()); assert!(provider.is_err());
} }
#[test]
fn ollama_with_custom_url() {
let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434"));
assert!(provider.is_ok());
}
#[test] #[test]
fn factory_all_providers_create_successfully() { fn factory_all_providers_create_successfully() {
let providers = [ let providers = [
@ -794,6 +870,7 @@ mod tests {
"qwen", "qwen",
"qwen-intl", "qwen-intl",
"qwen-us", "qwen-us",
"lmstudio",
"groq", "groq",
"mistral", "mistral",
"xai", "xai",
@ -888,7 +965,7 @@ mod tests {
#[test] #[test]
fn sanitize_preserves_unicode_boundaries() { fn sanitize_preserves_unicode_boundaries() {
let input = format!("{} sk-abcdef123", "こんにちは".repeat(80)); let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
let result = sanitize_api_error(&input); let result = sanitize_api_error(&input);
assert!(std::str::from_utf8(result.as_bytes()).is_ok()); assert!(std::str::from_utf8(result.as_bytes()).is_ok());
assert!(!result.contains("sk-abcdef123")); assert!(!result.contains("sk-abcdef123"));
@ -900,4 +977,32 @@ mod tests {
let result = sanitize_api_error(input); let result = sanitize_api_error(input);
assert_eq!(result, input); assert_eq!(result, input);
} }
#[test]
fn scrub_github_personal_access_token() {
let input = "auth failed with token ghp_abc123def456";
let result = scrub_secret_patterns(input);
assert_eq!(result, "auth failed with token [REDACTED]");
}
#[test]
fn scrub_github_oauth_token() {
let input = "Bearer gho_1234567890abcdef";
let result = scrub_secret_patterns(input);
assert_eq!(result, "Bearer [REDACTED]");
}
#[test]
fn scrub_github_user_token() {
let input = "token ghu_sessiontoken123";
let result = scrub_secret_patterns(input);
assert_eq!(result, "token [REDACTED]");
}
#[test]
fn scrub_github_fine_grained_pat() {
let input = "failed: github_pat_11AABBC_xyzzy789";
let result = scrub_secret_patterns(input);
assert_eq!(result, "failed: [REDACTED]");
}
} }

View file

@ -8,6 +8,8 @@ pub struct OllamaProvider {
client: Client, client: Client,
} }
// ─── Request Structures ───────────────────────────────────────────────────────
#[derive(Debug, Serialize)] #[derive(Debug, Serialize)]
struct ChatRequest { struct ChatRequest {
model: String, model: String,
@ -27,6 +29,8 @@ struct Options {
temperature: f64, temperature: f64,
} }
// ─── Response Structures ──────────────────────────────────────────────────────
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ApiChatResponse { struct ApiChatResponse {
message: ResponseMessage, message: ResponseMessage,
@ -34,9 +38,30 @@ struct ApiChatResponse {
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct ResponseMessage { struct ResponseMessage {
#[serde(default)]
content: String, content: String,
#[serde(default)]
tool_calls: Vec<OllamaToolCall>,
/// Some models return a "thinking" field with internal reasoning
#[serde(default)]
thinking: Option<String>,
} }
#[derive(Debug, Deserialize)]
struct OllamaToolCall {
id: Option<String>,
function: OllamaFunction,
}
#[derive(Debug, Deserialize)]
struct OllamaFunction {
name: String,
#[serde(default)]
arguments: serde_json::Value,
}
// ─── Implementation ───────────────────────────────────────────────────────────
impl OllamaProvider { impl OllamaProvider {
pub fn new(base_url: Option<&str>) -> Self { pub fn new(base_url: Option<&str>) -> Self {
Self { Self {
@ -45,12 +70,145 @@ impl OllamaProvider {
.trim_end_matches('/') .trim_end_matches('/')
.to_string(), .to_string(),
client: Client::builder() client: Client::builder()
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow .timeout(std::time::Duration::from_secs(300))
.connect_timeout(std::time::Duration::from_secs(10)) .connect_timeout(std::time::Duration::from_secs(10))
.build() .build()
.unwrap_or_else(|_| Client::new()), .unwrap_or_else(|_| Client::new()),
} }
} }
/// Send a request to Ollama and get the parsed response
async fn send_request(
&self,
messages: Vec<Message>,
model: &str,
temperature: f64,
) -> anyhow::Result<ApiChatResponse> {
let request = ChatRequest {
model: model.to_string(),
messages,
stream: false,
options: Options { temperature },
};
let url = format!("{}/api/chat", self.base_url);
tracing::debug!(
"Ollama request: url={} model={} message_count={} temperature={}",
url,
model,
request.messages.len(),
temperature
);
let response = self.client.post(&url).json(&request).send().await?;
let status = response.status();
tracing::debug!("Ollama response status: {}", status);
let body = response.bytes().await?;
tracing::debug!("Ollama response body length: {} bytes", body.len());
if !status.is_success() {
let raw = String::from_utf8_lossy(&body);
let sanitized = super::sanitize_api_error(&raw);
tracing::error!(
"Ollama error response: status={} body_excerpt={}",
status,
sanitized
);
anyhow::bail!(
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
status,
sanitized
);
}
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => {
let raw = String::from_utf8_lossy(&body);
let sanitized = super::sanitize_api_error(&raw);
tracing::error!(
"Ollama response deserialization failed: {e}. body_excerpt={}",
sanitized
);
anyhow::bail!("Failed to parse Ollama response: {e}");
}
};
Ok(chat_response)
}
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
///
/// Handles quirky model behavior where tool calls are wrapped:
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
/// - `{"name": "tool.shell", "arguments": {...}}`
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
let formatted_calls: Vec<serde_json::Value> = tool_calls
.iter()
.map(|tc| {
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
// Arguments must be a JSON string for parse_tool_calls compatibility
let args_str =
serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
serde_json::json!({
"id": tc.id,
"type": "function",
"function": {
"name": tool_name,
"arguments": args_str
}
})
})
.collect();
serde_json::json!({
"content": "",
"tool_calls": formatted_calls
})
.to_string()
}
/// Extract the actual tool name and arguments from potentially nested structures
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
let name = &tc.function.name;
let args = &tc.function.arguments;
// Pattern 1: Nested tool_call wrapper (various malformed versions)
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
if name == "tool_call"
|| name == "tool.call"
|| name.starts_with("tool_call>")
|| name.starts_with("tool_call<")
{
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
let nested_args = args
.get("arguments")
.cloned()
.unwrap_or(serde_json::json!({}));
tracing::debug!(
"Unwrapped nested tool call: {} -> {} with args {:?}",
name,
nested_name,
nested_args
);
return (nested_name.to_string(), nested_args);
}
}
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
if let Some(stripped) = name.strip_prefix("tool.") {
return (stripped.to_string(), args.clone());
}
// Pattern 3: Normal tool call
(name.clone(), args.clone())
}
} }
#[async_trait] #[async_trait]
@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
content: message.to_string(), content: message.to_string(),
}); });
let request = ChatRequest { let response = self.send_request(messages, model, temperature).await?;
model: model.to_string(),
messages,
stream: false,
options: Options { temperature },
};
let url = format!("{}/api/chat", self.base_url); // If model returned tool calls, format them for loop_.rs's parse_tool_calls
if !response.message.tool_calls.is_empty() {
let response = self.client.post(&url).json(&request).send().await?; tracing::debug!(
"Ollama returned {} tool call(s), formatting for loop parser",
if !response.status().is_success() { response.message.tool_calls.len()
let err = super::api_error("Ollama", response).await; );
anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)"); return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
} }
let chat_response: ApiChatResponse = response.json().await?; // Plain text response
Ok(chat_response.message.content) let content = response.message.content;
// Handle edge case: model returned only "thinking" with no content or tool calls
if content.is_empty() {
if let Some(thinking) = &response.message.thinking {
tracing::warn!(
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
if thinking.len() > 100 { &thinking[..100] } else { thinking }
);
return Ok(format!(
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
if thinking.len() > 200 { &thinking[..200] } else { thinking }
));
}
tracing::warn!("Ollama returned empty content with no tool calls");
}
Ok(content)
}
async fn chat_with_history(
&self,
messages: &[crate::providers::ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let response = self.send_request(api_messages, model, temperature).await?;
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
if !response.message.tool_calls.is_empty() {
tracing::debug!(
"Ollama returned {} tool call(s), formatting for loop parser",
response.message.tool_calls.len()
);
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
// Plain text response
let content = response.message.content;
// Handle edge case: model returned only "thinking" with no content or tool calls
// This is a model quirk - it stopped after reasoning without producing output
if content.is_empty() {
if let Some(thinking) = &response.message.thinking {
tracing::warn!(
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
if thinking.len() > 100 { &thinking[..100] } else { thinking }
);
// Return a message indicating the model's thought process but no action
return Ok(format!(
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
if thinking.len() > 200 { &thinking[..200] } else { thinking }
));
}
tracing::warn!("Ollama returned empty content with no tool calls");
}
Ok(content)
}
fn supports_native_tools(&self) -> bool {
// Return false since loop_.rs uses XML-style tool parsing via system prompt
// The model may return native tool_calls but we convert them to JSON format
// that parse_tool_calls() understands
false
} }
} }
// ─── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use super::*; use super::*;
@ -125,46 +352,6 @@ mod tests {
assert_eq!(p.base_url, ""); assert_eq!(p.base_url, "");
} }
#[test]
fn request_serializes_with_system() {
let req = ChatRequest {
model: "llama3".to_string(),
messages: vec![
Message {
role: "system".to_string(),
content: "You are ZeroClaw".to_string(),
},
Message {
role: "user".to_string(),
content: "hello".to_string(),
},
],
stream: false,
options: Options { temperature: 0.7 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"stream\":false"));
assert!(json.contains("llama3"));
assert!(json.contains("system"));
assert!(json.contains("\"temperature\":0.7"));
}
#[test]
fn request_serializes_without_system() {
let req = ChatRequest {
model: "mistral".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: "test".to_string(),
}],
stream: false,
options: Options { temperature: 0.0 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("\"role\":\"system\""));
assert!(json.contains("mistral"));
}
#[test] #[test]
fn response_deserializes() { fn response_deserializes() {
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#; let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
@ -180,9 +367,98 @@ mod tests {
} }
#[test] #[test]
fn response_with_multiline() { fn response_with_missing_content_defaults_to_empty() {
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#; let json = r#"{"message":{"role":"assistant"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap(); let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.contains("line1")); assert!(resp.message.content.is_empty());
}
#[test]
fn response_with_thinking_field_extracts_content() {
let json =
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.message.content, "hello");
}
#[test]
fn response_with_tool_calls_parses_correctly() {
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.is_empty());
assert_eq!(resp.message.tool_calls.len(), 1);
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
}
#[test]
fn extract_tool_name_handles_nested_tool_call() {
let provider = OllamaProvider::new(None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "tool_call".into(),
arguments: serde_json::json!({
"name": "shell",
"arguments": {"command": "date"}
}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "shell");
assert_eq!(args.get("command").unwrap(), "date");
}
#[test]
fn extract_tool_name_handles_prefixed_name() {
let provider = OllamaProvider::new(None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "tool.shell".into(),
arguments: serde_json::json!({"command": "ls"}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "shell");
assert_eq!(args.get("command").unwrap(), "ls");
}
#[test]
fn extract_tool_name_handles_normal_call() {
let provider = OllamaProvider::new(None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "file_read".into(),
arguments: serde_json::json!({"path": "/tmp/test"}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "file_read");
assert_eq!(args.get("path").unwrap(), "/tmp/test");
}
#[test]
fn format_tool_calls_produces_valid_json() {
let provider = OllamaProvider::new(None);
let tool_calls = vec![OllamaToolCall {
id: Some("call_abc".into()),
function: OllamaFunction {
name: "shell".into(),
arguments: serde_json::json!({"command": "date"}),
},
}];
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
assert!(parsed.get("tool_calls").is_some());
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
assert_eq!(calls.len(), 1);
let func = calls[0].get("function").unwrap();
assert_eq!(func.get("name").unwrap(), "shell");
// arguments should be a string (JSON-encoded)
assert!(func.get("arguments").unwrap().is_string());
} }
} }

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub struct OpenAiProvider { pub struct OpenAiProvider {
api_key: Option<String>, credential: Option<String>,
client: Client, client: Client,
} }
@ -110,9 +110,9 @@ struct NativeResponseMessage {
} }
impl OpenAiProvider { impl OpenAiProvider {
pub fn new(api_key: Option<&str>) -> Self { pub fn new(credential: Option<&str>) -> Self {
Self { Self {
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
client: Client::builder() client: Client::builder()
.timeout(std::time::Duration::from_secs(120)) .timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10)) .connect_timeout(std::time::Duration::from_secs(10))
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?; })?;
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
let response = self let response = self
.client .client
.post("https://api.openai.com/v1/chat/completions") .post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.json(&request) .json(&request)
.send() .send()
.await?; .await?;
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<ProviderChatResponse> { ) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.") anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?; })?;
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
let response = self let response = self
.client .client
.post("https://api.openai.com/v1/chat/completions") .post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.json(&native_request) .json(&native_request)
.send() .send()
.await?; .await?;
@ -330,20 +330,20 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let p = OpenAiProvider::new(Some("sk-proj-abc123")); let p = OpenAiProvider::new(Some("openai-test-credential"));
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123")); assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
} }
#[test] #[test]
fn creates_without_key() { fn creates_without_key() {
let p = OpenAiProvider::new(None); let p = OpenAiProvider::new(None);
assert!(p.api_key.is_none()); assert!(p.credential.is_none());
} }
#[test] #[test]
fn creates_with_empty_key() { fn creates_with_empty_key() {
let p = OpenAiProvider::new(Some("")); let p = OpenAiProvider::new(Some(""));
assert_eq!(p.api_key.as_deref(), Some("")); assert_eq!(p.credential.as_deref(), Some(""));
} }
#[tokio::test] #[tokio::test]

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
pub struct OpenRouterProvider { pub struct OpenRouterProvider {
api_key: Option<String>, credential: Option<String>,
client: Client, client: Client,
} }
@ -110,9 +110,9 @@ struct NativeResponseMessage {
} }
impl OpenRouterProvider { impl OpenRouterProvider {
pub fn new(api_key: Option<&str>) -> Self { pub fn new(credential: Option<&str>) -> Self {
Self { Self {
api_key: api_key.map(ToString::to_string), credential: credential.map(ToString::to_string),
client: Client::builder() client: Client::builder()
.timeout(std::time::Duration::from_secs(120)) .timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10)) .connect_timeout(std::time::Duration::from_secs(10))
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
async fn warmup(&self) -> anyhow::Result<()> { async fn warmup(&self) -> anyhow::Result<()> {
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool. // Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
// This prevents the first real chat request from timing out on cold start. // This prevents the first real chat request from timing out on cold start.
if let Some(api_key) = self.api_key.as_ref() { if let Some(credential) = self.credential.as_ref() {
self.client self.client
.get("https://openrouter.ai/api/v1/auth/key") .get("https://openrouter.ai/api/v1/auth/key")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.send() .send()
.await? .await?
.error_for_status()?; .error_for_status()?;
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref() let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let mut messages = Vec::new(); let mut messages = Vec::new();
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
let response = self let response = self
.client .client
.post("https://openrouter.ai/api/v1/chat/completions") .post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.header( .header(
"HTTP-Referer", "HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw", "https://github.com/theonlyhennygod/zeroclaw",
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<String> { ) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref() let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?; .ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let api_messages: Vec<Message> = messages let api_messages: Vec<Message> = messages
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
let response = self let response = self
.client .client
.post("https://openrouter.ai/api/v1/chat/completions") .post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.header( .header(
"HTTP-Referer", "HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw", "https://github.com/theonlyhennygod/zeroclaw",
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<ProviderChatResponse> { ) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
) )
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
let response = self let response = self
.client .client
.post("https://openrouter.ai/api/v1/chat/completions") .post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.header( .header(
"HTTP-Referer", "HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw", "https://github.com/theonlyhennygod/zeroclaw",
@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
model: &str, model: &str,
temperature: f64, temperature: f64,
) -> anyhow::Result<ProviderChatResponse> { ) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| { let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!( anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var." "OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
) )
@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
let response = self let response = self
.client .client
.post("https://openrouter.ai/api/v1/chat/completions") .post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}")) .header("Authorization", format!("Bearer {credential}"))
.header( .header(
"HTTP-Referer", "HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw", "https://github.com/theonlyhennygod/zeroclaw",
@ -494,14 +494,17 @@ mod tests {
#[test] #[test]
fn creates_with_key() { fn creates_with_key() {
let provider = OpenRouterProvider::new(Some("sk-or-123")); let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123")); assert_eq!(
provider.credential.as_deref(),
Some("openrouter-test-credential")
);
} }
#[test] #[test]
fn creates_without_key() { fn creates_without_key() {
let provider = OpenRouterProvider::new(None); let provider = OpenRouterProvider::new(None);
assert!(provider.api_key.is_none()); assert!(provider.credential.is_none());
} }
#[tokio::test] #[tokio::test]

View file

@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
async fn warmup(&self) -> anyhow::Result<()> { async fn warmup(&self) -> anyhow::Result<()> {
for (name, provider) in &self.providers { for (name, provider) in &self.providers {
tracing::info!(provider = name, "Warming up provider connection pool"); tracing::info!(provider = name, "Warming up provider connection pool");
if let Err(e) = provider.warmup().await { if provider.warmup().await.is_err() {
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}"); tracing::warn!(provider = name, "Warmup failed (non-fatal)");
} }
} }
Ok(()) Ok(())
@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e); let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e); let rate_limited = is_rate_limited(&e);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!( failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {e}", "{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1, attempt + 1,
self.max_retries + 1 self.max_retries + 1
)); ));
@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e); let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e); let rate_limited = is_rate_limited(&e);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!( failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {e}", "{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1, attempt + 1,
self.max_retries + 1 self.max_retries + 1
)); ));

View file

@ -193,6 +193,13 @@ pub enum StreamError {
#[async_trait] #[async_trait]
pub trait Provider: Send + Sync { pub trait Provider: Send + Sync {
/// Query provider capabilities.
///
/// Default implementation returns minimal capabilities (no native tool calling).
/// Providers should override this to declare their actual capabilities.
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities::default()
}
/// Simple one-shot chat (single user message, no explicit system prompt). /// Simple one-shot chat (single user message, no explicit system prompt).
/// ///
/// This is the preferred API for non-agentic direct interactions. /// This is the preferred API for non-agentic direct interactions.
@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
/// Whether provider supports native tool calls over API. /// Whether provider supports native tool calls over API.
fn supports_native_tools(&self) -> bool { fn supports_native_tools(&self) -> bool {
false self.capabilities().native_tool_calling
} }
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup). /// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
mod tests { mod tests {
use super::*; use super::*;
struct CapabilityMockProvider;
#[async_trait]
impl Provider for CapabilityMockProvider {
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities {
native_tool_calling: true,
}
}
async fn chat_with_system(
&self,
_system_prompt: Option<&str>,
_message: &str,
_model: &str,
_temperature: f64,
) -> anyhow::Result<String> {
Ok("ok".into())
}
}
#[test] #[test]
fn chat_message_constructors() { fn chat_message_constructors() {
let sys = ChatMessage::system("Be helpful"); let sys = ChatMessage::system("Be helpful");
@ -398,4 +426,32 @@ mod tests {
let json = serde_json::to_string(&tool_result).unwrap(); let json = serde_json::to_string(&tool_result).unwrap();
assert!(json.contains("\"type\":\"ToolResults\"")); assert!(json.contains("\"type\":\"ToolResults\""));
} }
#[test]
fn provider_capabilities_default() {
let caps = ProviderCapabilities::default();
assert!(!caps.native_tool_calling);
}
#[test]
fn provider_capabilities_equality() {
let caps1 = ProviderCapabilities {
native_tool_calling: true,
};
let caps2 = ProviderCapabilities {
native_tool_calling: true,
};
let caps3 = ProviderCapabilities {
native_tool_calling: false,
};
assert_eq!(caps1, caps2);
assert_ne!(caps1, caps3);
}
#[test]
fn supports_native_tools_reflects_capabilities_default_mapping() {
let provider = CapabilityMockProvider;
assert!(provider.supports_native_tools());
}
} }

View file

@ -81,14 +81,17 @@ mod tests {
#[test] #[test]
fn bubblewrap_sandbox_name() { fn bubblewrap_sandbox_name() {
assert_eq!(BubblewrapSandbox.name(), "bubblewrap"); let sandbox = BubblewrapSandbox;
assert_eq!(sandbox.name(), "bubblewrap");
} }
#[test] #[test]
fn bubblewrap_is_available_only_if_installed() { fn bubblewrap_is_available_only_if_installed() {
// Result depends on whether bwrap is installed // Result depends on whether bwrap is installed
let available = BubblewrapSandbox::is_available(); let sandbox = BubblewrapSandbox;
let _available = sandbox.is_available();
// Either way, the name should still work // Either way, the name should still work
assert_eq!(BubblewrapSandbox.name(), "bubblewrap"); assert_eq!(sandbox.name(), "bubblewrap");
} }
} }

View file

@ -184,7 +184,7 @@ fn generate_token() -> String {
use rand::RngCore; use rand::RngCore;
let mut bytes = [0u8; 32]; let mut bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut bytes); rand::thread_rng().fill_bytes(&mut bytes);
format!("zc_{}", hex::encode(&bytes)) format!("zc_{}", hex::encode(bytes))
} }
/// SHA-256 hash a bearer token for storage. Returns lowercase hex. /// SHA-256 hash a bearer token for storage. Returns lowercase hex.

View file

@ -343,6 +343,7 @@ impl SecurityPolicy {
/// validates each sub-command against the allowlist /// validates each sub-command against the allowlist
/// - Blocks single `&` background chaining (`&&` remains supported) /// - Blocks single `&` background chaining (`&&` remains supported)
/// - Blocks output redirections (`>`, `>>`) that could write outside workspace /// - Blocks output redirections (`>`, `>>`) that could write outside workspace
/// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
pub fn is_command_allowed(&self, command: &str) -> bool { pub fn is_command_allowed(&self, command: &str) -> bool {
if self.autonomy == AutonomyLevel::ReadOnly { if self.autonomy == AutonomyLevel::ReadOnly {
return false; return false;
@ -350,7 +351,12 @@ impl SecurityPolicy {
// Block subshell/expansion operators — these allow hiding arbitrary // Block subshell/expansion operators — these allow hiding arbitrary
// commands inside an allowed command (e.g. `echo $(rm -rf /)`) // commands inside an allowed command (e.g. `echo $(rm -rf /)`)
if command.contains('`') || command.contains("$(") || command.contains("${") { if command.contains('`')
|| command.contains("$(")
|| command.contains("${")
|| command.contains("<(")
|| command.contains(">(")
{
return false; return false;
} }
@ -359,6 +365,15 @@ impl SecurityPolicy {
return false; return false;
} }
// Block `tee` — it can write to arbitrary files, bypassing the
// redirect check above (e.g. `echo secret | tee /etc/crontab`)
if command
.split_whitespace()
.any(|w| w == "tee" || w.ends_with("/tee"))
{
return false;
}
// Block background command chaining (`&`), which can hide extra // Block background command chaining (`&`), which can hide extra
// sub-commands and outlive timeout expectations. Keep `&&` allowed. // sub-commands and outlive timeout expectations. Keep `&&` allowed.
if contains_single_ampersand(command) { if contains_single_ampersand(command) {
@ -384,13 +399,9 @@ impl SecurityPolicy {
// Strip leading env var assignments (e.g. FOO=bar cmd) // Strip leading env var assignments (e.g. FOO=bar cmd)
let cmd_part = skip_env_assignments(segment); let cmd_part = skip_env_assignments(segment);
let base_cmd = cmd_part let mut words = cmd_part.split_whitespace();
.split_whitespace() let base_raw = words.next().unwrap_or("");
.next() let base_cmd = base_raw.rsplit('/').next().unwrap_or("");
.unwrap_or("")
.rsplit('/')
.next()
.unwrap_or("");
if base_cmd.is_empty() { if base_cmd.is_empty() {
continue; continue;
@ -403,6 +414,12 @@ impl SecurityPolicy {
{ {
return false; return false;
} }
// Validate arguments for the command
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
if !self.is_args_safe(base_cmd, &args) {
return false;
}
} }
// At least one command must be present // At least one command must be present
@ -414,6 +431,29 @@ impl SecurityPolicy {
has_cmd has_cmd
} }
/// Check for dangerous arguments that allow sub-command execution.
fn is_args_safe(&self, base: &str, args: &[String]) -> bool {
let base = base.to_ascii_lowercase();
match base.as_str() {
"find" => {
// find -exec and find -ok allow arbitrary command execution
!args.iter().any(|arg| arg == "-exec" || arg == "-ok")
}
"git" => {
// git config, alias, and -c can be used to set dangerous options
// (e.g. git config core.editor "rm -rf /")
!args.iter().any(|arg| {
arg == "config"
|| arg.starts_with("config.")
|| arg == "alias"
|| arg.starts_with("alias.")
|| arg == "-c"
})
}
_ => true,
}
}
/// Check if a file path is allowed (no path traversal, within workspace) /// Check if a file path is allowed (no path traversal, within workspace)
pub fn is_path_allowed(&self, path: &str) -> bool { pub fn is_path_allowed(&self, path: &str) -> bool {
// Block null bytes (can truncate paths in C-backed syscalls) // Block null bytes (can truncate paths in C-backed syscalls)
@ -982,12 +1022,43 @@ mod tests {
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt")); assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
} }
#[test]
fn command_argument_injection_blocked() {
let p = default_policy();
// find -exec is a common bypass
assert!(!p.is_command_allowed("find . -exec rm -rf {} +"));
assert!(!p.is_command_allowed("find / -ok cat {} \\;"));
// git config/alias can execute commands
assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\""));
assert!(!p.is_command_allowed("git alias.st status"));
assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit"));
// Legitimate commands should still work
assert!(p.is_command_allowed("find . -name '*.txt'"));
assert!(p.is_command_allowed("git status"));
assert!(p.is_command_allowed("git add ."));
}
#[test] #[test]
fn command_injection_dollar_brace_blocked() { fn command_injection_dollar_brace_blocked() {
let p = default_policy(); let p = default_policy();
assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd")); assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd"));
} }
#[test]
fn command_injection_tee_blocked() {
let p = default_policy();
assert!(!p.is_command_allowed("echo secret | tee /etc/crontab"));
assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile"));
assert!(!p.is_command_allowed("tee file.txt"));
}
#[test]
fn command_injection_process_substitution_blocked() {
let p = default_policy();
assert!(!p.is_command_allowed("cat <(echo pwned)"));
assert!(!p.is_command_allowed("ls >(cat /etc/passwd)"));
}
#[test] #[test]
fn command_env_var_prefix_with_allowed_cmd() { fn command_env_var_prefix_with_allowed_cmd() {
let p = default_policy(); let p = default_policy();

View file

@ -854,7 +854,6 @@ impl BrowserTool {
} }
} }
#[allow(clippy::too_many_lines)]
#[async_trait] #[async_trait]
impl Tool for BrowserTool { impl Tool for BrowserTool {
fn name(&self) -> &str { fn name(&self) -> &str {
@ -1031,165 +1030,21 @@ impl Tool for BrowserTool {
return self.execute_computer_use_action(action_str, &args).await; return self.execute_computer_use_action(action_str, &args).await;
} }
let action = match action_str { if is_computer_use_only_action(action_str) {
"open" => { return Ok(ToolResult {
let url = args success: false,
.get("url") output: String::new(),
.and_then(|v| v.as_str()) error: Some(unavailable_action_for_backend_error(action_str, backend)),
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?; });
BrowserAction::Open { url: url.into() } }
}
"snapshot" => BrowserAction::Snapshot { let action = match parse_browser_action(action_str, &args) {
interactive_only: args Ok(a) => a,
.get("interactive_only") Err(e) => {
.and_then(serde_json::Value::as_bool)
.unwrap_or(true), // Default to interactive for AI
compact: args
.get("compact")
.and_then(serde_json::Value::as_bool)
.unwrap_or(true),
depth: args
.get("depth")
.and_then(serde_json::Value::as_u64)
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
},
"click" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
BrowserAction::Click {
selector: selector.into(),
}
}
"fill" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
let value = args
.get("value")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
BrowserAction::Fill {
selector: selector.into(),
value: value.into(),
}
}
"type" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
let text = args
.get("text")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
BrowserAction::Type {
selector: selector.into(),
text: text.into(),
}
}
"get_text" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
BrowserAction::GetText {
selector: selector.into(),
}
}
"get_title" => BrowserAction::GetTitle,
"get_url" => BrowserAction::GetUrl,
"screenshot" => BrowserAction::Screenshot {
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
full_page: args
.get("full_page")
.and_then(serde_json::Value::as_bool)
.unwrap_or(false),
},
"wait" => BrowserAction::Wait {
selector: args
.get("selector")
.and_then(|v| v.as_str())
.map(String::from),
ms: args.get("ms").and_then(serde_json::Value::as_u64),
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
},
"press" => {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
BrowserAction::Press { key: key.into() }
}
"hover" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
BrowserAction::Hover {
selector: selector.into(),
}
}
"scroll" => {
let direction = args
.get("direction")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
BrowserAction::Scroll {
direction: direction.into(),
pixels: args
.get("pixels")
.and_then(serde_json::Value::as_u64)
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
}
}
"is_visible" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
BrowserAction::IsVisible {
selector: selector.into(),
}
}
"close" => BrowserAction::Close,
"find" => {
let by = args
.get("by")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
let value = args
.get("value")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
let action = args
.get("find_action")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
BrowserAction::Find {
by: by.into(),
value: value.into(),
action: action.into(),
fill_value: args
.get("fill_value")
.and_then(|v| v.as_str())
.map(String::from),
}
}
_ => {
return Ok(ToolResult { return Ok(ToolResult {
success: false, success: false,
output: String::new(), output: String::new(),
error: Some(format!( error: Some(e.to_string()),
"Action '{action_str}' is unavailable for backend '{}'",
match backend {
ResolvedBackend::AgentBrowser => "agent_browser",
ResolvedBackend::RustNative => "rust_native",
ResolvedBackend::ComputerUse => "computer_use",
}
)),
}); });
} }
}; };
@ -1871,6 +1726,161 @@ mod native_backend {
} }
} }
// ── Action parsing ──────────────────────────────────────────────
/// Parse a JSON `args` object into a typed `BrowserAction`.
fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<BrowserAction> {
match action_str {
"open" => {
let url = args
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
Ok(BrowserAction::Open { url: url.into() })
}
"snapshot" => Ok(BrowserAction::Snapshot {
interactive_only: args
.get("interactive_only")
.and_then(serde_json::Value::as_bool)
.unwrap_or(true),
compact: args
.get("compact")
.and_then(serde_json::Value::as_bool)
.unwrap_or(true),
depth: args
.get("depth")
.and_then(serde_json::Value::as_u64)
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
}),
"click" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
Ok(BrowserAction::Click {
selector: selector.into(),
})
}
"fill" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
let value = args
.get("value")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
Ok(BrowserAction::Fill {
selector: selector.into(),
value: value.into(),
})
}
"type" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
let text = args
.get("text")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
Ok(BrowserAction::Type {
selector: selector.into(),
text: text.into(),
})
}
"get_text" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
Ok(BrowserAction::GetText {
selector: selector.into(),
})
}
"get_title" => Ok(BrowserAction::GetTitle),
"get_url" => Ok(BrowserAction::GetUrl),
"screenshot" => Ok(BrowserAction::Screenshot {
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
full_page: args
.get("full_page")
.and_then(serde_json::Value::as_bool)
.unwrap_or(false),
}),
"wait" => Ok(BrowserAction::Wait {
selector: args
.get("selector")
.and_then(|v| v.as_str())
.map(String::from),
ms: args.get("ms").and_then(serde_json::Value::as_u64),
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
}),
"press" => {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
Ok(BrowserAction::Press { key: key.into() })
}
"hover" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
Ok(BrowserAction::Hover {
selector: selector.into(),
})
}
"scroll" => {
let direction = args
.get("direction")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
Ok(BrowserAction::Scroll {
direction: direction.into(),
pixels: args
.get("pixels")
.and_then(serde_json::Value::as_u64)
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
})
}
"is_visible" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
Ok(BrowserAction::IsVisible {
selector: selector.into(),
})
}
"close" => Ok(BrowserAction::Close),
"find" => {
let by = args
.get("by")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
let value = args
.get("value")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
let action = args
.get("find_action")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
Ok(BrowserAction::Find {
by: by.into(),
value: value.into(),
action: action.into(),
fill_value: args
.get("fill_value")
.and_then(|v| v.as_str())
.map(String::from),
})
}
other => anyhow::bail!("Unsupported browser action: {other}"),
}
}
// ── Helper functions ───────────────────────────────────────────── // ── Helper functions ─────────────────────────────────────────────
fn is_supported_browser_action(action: &str) -> bool { fn is_supported_browser_action(action: &str) -> bool {
@ -1901,6 +1911,28 @@ fn is_supported_browser_action(action: &str) -> bool {
) )
} }
/// Returns true when `action` only exists on the computer_use backend
/// (OS-level pointer/keyboard/screen control).
fn is_computer_use_only_action(action: &str) -> bool {
    const COMPUTER_USE_ONLY: [&str; 6] = [
        "mouse_move",
        "mouse_click",
        "mouse_drag",
        "key_type",
        "key_press",
        "screen_capture",
    ];
    COMPUTER_USE_ONLY.contains(&action)
}
fn backend_name(backend: ResolvedBackend) -> &'static str {
match backend {
ResolvedBackend::AgentBrowser => "agent_browser",
ResolvedBackend::RustNative => "rust_native",
ResolvedBackend::ComputerUse => "computer_use",
}
}
/// Build the error message for an action the active backend cannot serve.
/// Keeps both the action and the backend name in the text so the caller
/// sees exactly which combination was rejected.
fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String {
    let backend_label = backend_name(backend);
    format!("Action '{action}' is unavailable for backend '{backend_label}'")
}
fn normalize_domains(domains: Vec<String>) -> Vec<String> { fn normalize_domains(domains: Vec<String>) -> Vec<String> {
domains domains
.into_iter() .into_iter()
@ -2342,4 +2374,28 @@ mod tests {
let tool = BrowserTool::new(security, vec![], None); let tool = BrowserTool::new(security, vec![], None);
assert!(tool.validate_url("https://example.com").is_err()); assert!(tool.validate_url("https://example.com").is_err());
} }
#[test]
fn computer_use_only_action_detection_is_correct() {
assert!(is_computer_use_only_action("mouse_move"));
assert!(is_computer_use_only_action("mouse_click"));
assert!(is_computer_use_only_action("mouse_drag"));
assert!(is_computer_use_only_action("key_type"));
assert!(is_computer_use_only_action("key_press"));
assert!(is_computer_use_only_action("screen_capture"));
assert!(!is_computer_use_only_action("open"));
assert!(!is_computer_use_only_action("snapshot"));
}
#[test]
fn unavailable_action_error_preserves_backend_context() {
assert_eq!(
unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser),
"Action 'mouse_move' is unavailable for backend 'agent_browser'"
);
assert_eq!(
unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative),
"Action 'mouse_move' is unavailable for backend 'rust_native'"
);
}
} }

View file

@ -112,12 +112,12 @@ impl ComposioTool {
action_name: &str, action_name: &str,
params: serde_json::Value, params: serde_json::Value,
entity_id: Option<&str>, entity_id: Option<&str>,
connected_account_id: Option<&str>, connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> { ) -> anyhow::Result<serde_json::Value> {
let tool_slug = normalize_tool_slug(action_name); let tool_slug = normalize_tool_slug(action_name);
match self match self
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id) .execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
.await .await
{ {
Ok(result) => Ok(result), Ok(result) => Ok(result),
@ -130,21 +130,17 @@ impl ComposioTool {
} }
} }
async fn execute_action_v3( fn build_execute_action_v3_request(
&self,
tool_slug: &str, tool_slug: &str,
params: serde_json::Value, params: serde_json::Value,
entity_id: Option<&str>, entity_id: Option<&str>,
connected_account_id: Option<&str>, connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> { ) -> (String, serde_json::Value) {
let url = if let Some(connected_account_id) = connected_account_id let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
.map(str::trim) let account_ref = connected_account_ref.and_then(|candidate| {
.filter(|id| !id.is_empty()) let trimmed_candidate = candidate.trim();
{ (!trimmed_candidate.is_empty()).then_some(trimmed_candidate)
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}") });
} else {
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
};
let mut body = json!({ let mut body = json!({
"arguments": params, "arguments": params,
@ -153,6 +149,26 @@ impl ComposioTool {
if let Some(entity) = entity_id { if let Some(entity) = entity_id {
body["user_id"] = json!(entity); body["user_id"] = json!(entity);
} }
if let Some(account_ref) = account_ref {
body["connected_account_id"] = json!(account_ref);
}
(url, body)
}
async fn execute_action_v3(
&self,
tool_slug: &str,
params: serde_json::Value,
entity_id: Option<&str>,
connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let (url, body) = Self::build_execute_action_v3_request(
tool_slug,
params,
entity_id,
connected_account_ref,
);
let resp = self let resp = self
.client .client
@ -474,11 +490,11 @@ impl Tool for ComposioTool {
})?; })?;
let params = args.get("params").cloned().unwrap_or(json!({})); let params = args.get("params").cloned().unwrap_or(json!({}));
let connected_account_id = let connected_account_ref =
args.get("connected_account_id").and_then(|v| v.as_str()); args.get("connected_account_id").and_then(|v| v.as_str());
match self match self
.execute_action(action_name, params, Some(entity_id), connected_account_id) .execute_action(action_name, params, Some(entity_id), connected_account_ref)
.await .await
{ {
Ok(result) => { Ok(result) => {
@ -594,9 +610,38 @@ async fn response_error(resp: reqwest::Response) -> String {
} }
if let Some(api_error) = extract_api_error_message(&body) { if let Some(api_error) = extract_api_error_message(&body) {
format!("HTTP {}: {api_error}", status.as_u16()) return format!(
"HTTP {}: {}",
status.as_u16(),
sanitize_error_message(&api_error)
);
}
format!("HTTP {}", status.as_u16())
}
fn sanitize_error_message(message: &str) -> String {
let mut sanitized = message.replace('\n', " ");
for marker in [
"connected_account_id",
"connectedAccountId",
"entity_id",
"entityId",
"user_id",
"userId",
] {
sanitized = sanitized.replace(marker, "[redacted]");
}
let max_chars = 240;
if sanitized.chars().count() <= max_chars {
sanitized
} else { } else {
format!("HTTP {}: {body}", status.as_u16()) let mut end = max_chars;
while end > 0 && !sanitized.is_char_boundary(end) {
end -= 1;
}
format!("{}...", &sanitized[..end])
} }
} }
@ -948,4 +993,40 @@ mod tests {
fn composio_api_base_url_is_v3() { fn composio_api_base_url_is_v3() {
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3"); assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
} }
#[test]
fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
let (url, body) = ComposioTool::build_execute_action_v3_request(
"gmail-send-email",
json!({"to": "test@example.com"}),
Some("workspace-user"),
Some("account-42"),
);
assert_eq!(
url,
"https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
);
assert_eq!(body["arguments"]["to"], json!("test@example.com"));
assert_eq!(body["user_id"], json!("workspace-user"));
assert_eq!(body["connected_account_id"], json!("account-42"));
}
#[test]
fn build_execute_action_v3_request_drops_blank_optional_fields() {
let (url, body) = ComposioTool::build_execute_action_v3_request(
"github-list-repos",
json!({}),
None,
Some(" "),
);
assert_eq!(
url,
"https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
);
assert_eq!(body["arguments"], json!({}));
assert!(body.get("connected_account_id").is_none());
assert!(body.get("user_id").is_none());
}
} }

View file

@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
/// summarization) to purpose-built sub-agents. /// summarization) to purpose-built sub-agents.
pub struct DelegateTool { pub struct DelegateTool {
agents: Arc<HashMap<String, DelegateAgentConfig>>, agents: Arc<HashMap<String, DelegateAgentConfig>>,
/// Global API key fallback (from config.api_key) /// Global credential fallback (from config.api_key)
fallback_api_key: Option<String>, fallback_credential: Option<String>,
/// Depth at which this tool instance lives in the delegation chain. /// Depth at which this tool instance lives in the delegation chain.
depth: u32, depth: u32,
} }
@ -25,11 +25,11 @@ pub struct DelegateTool {
impl DelegateTool { impl DelegateTool {
pub fn new( pub fn new(
agents: HashMap<String, DelegateAgentConfig>, agents: HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<String>, fallback_credential: Option<String>,
) -> Self { ) -> Self {
Self { Self {
agents: Arc::new(agents), agents: Arc::new(agents),
fallback_api_key, fallback_credential,
depth: 0, depth: 0,
} }
} }
@ -39,12 +39,12 @@ impl DelegateTool {
/// their DelegateTool via this method with `depth: parent.depth + 1`. /// their DelegateTool via this method with `depth: parent.depth + 1`.
pub fn with_depth( pub fn with_depth(
agents: HashMap<String, DelegateAgentConfig>, agents: HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<String>, fallback_credential: Option<String>,
depth: u32, depth: u32,
) -> Self { ) -> Self {
Self { Self {
agents: Arc::new(agents), agents: Arc::new(agents),
fallback_api_key, fallback_credential,
depth, depth,
} }
} }
@ -165,13 +165,15 @@ impl Tool for DelegateTool {
} }
// Create provider for this agent // Create provider for this agent
let api_key = agent_config let provider_credential_owned = agent_config
.api_key .api_key
.as_deref() .clone()
.or(self.fallback_api_key.as_deref()); .or_else(|| self.fallback_credential.clone());
#[allow(clippy::option_as_ref_deref)]
let provider_credential = provider_credential_owned.as_ref().map(String::as_str);
let provider: Box<dyn Provider> = let provider: Box<dyn Provider> =
match providers::create_provider(&agent_config.provider, api_key) { match providers::create_provider(&agent_config.provider, provider_credential) {
Ok(p) => p, Ok(p) => p,
Err(e) => { Err(e) => {
return Ok(ToolResult { return Ok(ToolResult {
@ -268,7 +270,7 @@ mod tests {
provider: "openrouter".to_string(), provider: "openrouter".to_string(),
model: "anthropic/claude-sonnet-4-20250514".to_string(), model: "anthropic/claude-sonnet-4-20250514".to_string(),
system_prompt: None, system_prompt: None,
api_key: Some("sk-test".to_string()), api_key: Some("delegate-test-credential".to_string()),
temperature: None, temperature: None,
max_depth: 2, max_depth: 2,
}, },

View file

@ -28,13 +28,22 @@ impl GitOperationsTool {
if arg_lower.starts_with("--exec=") if arg_lower.starts_with("--exec=")
|| arg_lower.starts_with("--upload-pack=") || arg_lower.starts_with("--upload-pack=")
|| arg_lower.starts_with("--receive-pack=") || arg_lower.starts_with("--receive-pack=")
|| arg_lower.starts_with("--pager=")
|| arg_lower.starts_with("--editor=")
|| arg_lower == "--no-verify"
|| arg_lower.contains("$(") || arg_lower.contains("$(")
|| arg_lower.contains('`') || arg_lower.contains('`')
|| arg.contains('|') || arg.contains('|')
|| arg.contains(';') || arg.contains(';')
|| arg.contains('>')
{ {
anyhow::bail!("Blocked potentially dangerous git argument: {arg}"); anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
} }
// Block `-c` config injection (exact match or `-c=...` prefix).
// This must not false-positive on `--cached` or `-cached`.
if arg_lower == "-c" || arg_lower.starts_with("-c=") {
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
}
result.push(arg.to_string()); result.push(arg.to_string());
} }
Ok(result) Ok(result)
@ -129,6 +138,9 @@ impl GitOperationsTool {
.and_then(|v| v.as_bool()) .and_then(|v| v.as_bool())
.unwrap_or(false); .unwrap_or(false);
// Validate files argument against injection patterns
self.sanitize_git_args(files)?;
let mut git_args = vec!["diff", "--unified=3"]; let mut git_args = vec!["diff", "--unified=3"];
if cached { if cached {
git_args.push("--cached"); git_args.push("--cached");
@ -267,6 +279,14 @@ impl GitOperationsTool {
}) })
} }
fn truncate_commit_message(message: &str) -> String {
if message.chars().count() > 2000 {
format!("{}...", message.chars().take(1997).collect::<String>())
} else {
message.to_string()
}
}
async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> { async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let message = args let message = args
.get("message") .get("message")
@ -286,11 +306,7 @@ impl GitOperationsTool {
} }
// Limit message length // Limit message length
let message = if sanitized.len() > 2000 { let message = Self::truncate_commit_message(&sanitized);
format!("{}...", &sanitized[..1997])
} else {
sanitized
};
let output = self.run_git_command(&["commit", "-m", &message]).await; let output = self.run_git_command(&["commit", "-m", &message]).await;
@ -314,6 +330,9 @@ impl GitOperationsTool {
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?; .ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?;
// Validate paths against injection patterns
self.sanitize_git_args(paths)?;
let output = self.run_git_command(&["add", "--", paths]).await; let output = self.run_git_command(&["add", "--", paths]).await;
match output { match output {
@ -574,6 +593,52 @@ mod tests {
assert!(tool.sanitize_git_args("arg; rm file").is_err()); assert!(tool.sanitize_git_args("arg; rm file").is_err());
} }
#[test]
fn sanitize_git_blocks_pager_editor_injection() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(tmp.path());
assert!(tool.sanitize_git_args("--pager=less").is_err());
assert!(tool.sanitize_git_args("--editor=vim").is_err());
}
#[test]
fn sanitize_git_blocks_config_injection() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(tmp.path());
// Exact `-c` flag (config injection)
assert!(tool.sanitize_git_args("-c core.sshCommand=evil").is_err());
assert!(tool.sanitize_git_args("-c=core.pager=less").is_err());
}
#[test]
fn sanitize_git_blocks_no_verify() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(tmp.path());
assert!(tool.sanitize_git_args("--no-verify").is_err());
}
#[test]
fn sanitize_git_blocks_redirect_in_args() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(tmp.path());
assert!(tool.sanitize_git_args("file.txt > /tmp/out").is_err());
}
#[test]
fn sanitize_git_cached_not_blocked() {
let tmp = TempDir::new().unwrap();
let tool = test_tool(tmp.path());
// --cached must NOT be blocked by the `-c` check
assert!(tool.sanitize_git_args("--cached").is_ok());
// Other safe flags starting with -c prefix
assert!(tool.sanitize_git_args("-cached").is_ok());
}
#[test] #[test]
fn sanitize_git_allows_safe() { fn sanitize_git_allows_safe() {
let tmp = TempDir::new().unwrap(); let tmp = TempDir::new().unwrap();
@ -583,6 +648,8 @@ mod tests {
assert!(tool.sanitize_git_args("main").is_ok()); assert!(tool.sanitize_git_args("main").is_ok());
assert!(tool.sanitize_git_args("feature/test-branch").is_ok()); assert!(tool.sanitize_git_args("feature/test-branch").is_ok());
assert!(tool.sanitize_git_args("--cached").is_ok()); assert!(tool.sanitize_git_args("--cached").is_ok());
assert!(tool.sanitize_git_args("src/main.rs").is_ok());
assert!(tool.sanitize_git_args(".").is_ok());
} }
#[test] #[test]
@ -691,4 +758,12 @@ mod tests {
.unwrap_or("") .unwrap_or("")
.contains("Unknown operation")); .contains("Unknown operation"));
} }
#[test]
fn truncates_multibyte_commit_message_without_panicking() {
let long = "🦀".repeat(2500);
let truncated = GitOperationsTool::truncate_commit_message(&long);
assert_eq!(truncated.chars().count(), 2000);
}
} }

View file

@ -124,10 +124,11 @@ impl Tool for HardwareBoardInfoTool {
}); });
} }
Err(e) => { Err(e) => {
output.push_str(&format!( use std::fmt::Write;
"probe-rs attach failed: {}. Using static info.\n\n", let _ = write!(
e output,
)); "probe-rs attach failed: {e}. Using static info.\n\n"
);
} }
} }
} }
@ -135,13 +136,15 @@ impl Tool for HardwareBoardInfoTool {
if let Some(info) = self.static_info_for_board(board) { if let Some(info) = self.static_info_for_board(board) {
output.push_str(&info); output.push_str(&info);
if let Some(mem) = memory_map_static(board) { if let Some(mem) = memory_map_static(board) {
output.push_str(&format!("\n\n**Memory map:**\n{}", mem)); use std::fmt::Write;
let _ = write!(output, "\n\n**Memory map:**\n{mem}");
} }
} else { } else {
output.push_str(&format!( use std::fmt::Write;
"Board '{}' configured. No static info available.", let _ = write!(
board output,
)); "Board '{board}' configured. No static info available."
);
} }
Ok(ToolResult { Ok(ToolResult {

View file

@ -122,14 +122,16 @@ impl Tool for HardwareMemoryMapTool {
if !probe_ok { if !probe_ok {
if let Some(map) = self.static_map_for_board(board) { if let Some(map) = self.static_map_for_board(board) {
output.push_str(&format!("**{}** (from datasheet):\n{}", board, map)); use std::fmt::Write;
let _ = write!(output, "**{board}** (from datasheet):\n{map}");
} else { } else {
use std::fmt::Write;
let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect(); let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect();
output.push_str(&format!( let _ = write!(
"No memory map for board '{}'. Known boards: {}", output,
board, "No memory map for board '{board}'. Known boards: {}",
known.join(", ") known.join(", ")
)); );
} }
} }

View file

@ -94,14 +94,16 @@ impl Tool for HardwareMemoryReadTool {
.get("address") .get("address")
.and_then(|v| v.as_str()) .and_then(|v| v.as_str())
.unwrap_or("0x20000000"); .unwrap_or("0x20000000");
let address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE); let _address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
let length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128) as usize; let requested_length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128);
let length = length.min(256).max(1); let _length = usize::try_from(requested_length)
.unwrap_or(256)
.clamp(1, 256);
#[cfg(feature = "probe")] #[cfg(feature = "probe")]
{ {
match probe_read_memory(chip.unwrap(), address, length) { match probe_read_memory(chip.unwrap(), _address, _length) {
Ok(output) => { Ok(output) => {
return Ok(ToolResult { return Ok(ToolResult {
success: true, success: true,

View file

@ -749,4 +749,54 @@ mod tests {
let _ = HttpRequestTool::redact_headers_for_display(&headers); let _ = HttpRequestTool::redact_headers_for_display(&headers);
assert_eq!(headers[0].1, "Bearer real-token"); assert_eq!(headers[0].1, "Bearer real-token");
} }
// ── SSRF: alternate IP notation bypass defense-in-depth ─────────
//
// Rust's IpAddr::parse() rejects non-standard notations (octal, hex,
// decimal integer, zero-padded). These tests document that property
// so regressions are caught if the parsing strategy ever changes.
#[test]
fn ssrf_octal_loopback_not_parsed_as_ip() {
// 0177.0.0.1 is octal for 127.0.0.1 in some languages, but
// Rust's IpAddr rejects it — it falls through as a hostname.
assert!(!is_private_or_local_host("0177.0.0.1"));
}
#[test]
fn ssrf_hex_loopback_not_parsed_as_ip() {
// 0x7f000001 is hex for 127.0.0.1 in some languages.
assert!(!is_private_or_local_host("0x7f000001"));
}
#[test]
fn ssrf_decimal_loopback_not_parsed_as_ip() {
// 2130706433 is decimal for 127.0.0.1 in some languages.
assert!(!is_private_or_local_host("2130706433"));
}
#[test]
fn ssrf_zero_padded_loopback_not_parsed_as_ip() {
// 127.000.000.001 uses zero-padded octets.
assert!(!is_private_or_local_host("127.000.000.001"));
}
#[test]
fn ssrf_alternate_notations_rejected_by_validate_url() {
// Even if is_private_or_local_host doesn't flag these, they
// fail the allowlist because they're treated as hostnames.
let tool = test_tool(vec!["example.com"]);
for notation in [
"http://0177.0.0.1",
"http://0x7f000001",
"http://2130706433",
"http://127.000.000.001",
] {
let err = tool.validate_url(notation).unwrap_err().to_string();
assert!(
err.contains("allowed_domains"),
"Expected allowlist rejection for {notation}, got: {err}"
);
}
}
} }

View file

@ -87,7 +87,7 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn forget_existing() { async fn forget_existing() {
let (_tmp, mem) = test_mem(); let (_tmp, mem) = test_mem();
mem.store("temp", "temporary", MemoryCategory::Conversation) mem.store("temp", "temporary", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();

View file

@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool {
.and_then(serde_json::Value::as_u64) .and_then(serde_json::Value::as_u64)
.map_or(5, |v| v as usize); .map_or(5, |v| v as usize);
match self.memory.recall(query, limit).await { match self.memory.recall(query, limit, None).await {
Ok(entries) if entries.is_empty() => Ok(ToolResult { Ok(entries) if entries.is_empty() => Ok(ToolResult {
success: true, success: true,
output: "No memories found matching that query.".into(), output: "No memories found matching that query.".into(),
@ -112,10 +112,10 @@ mod tests {
#[tokio::test] #[tokio::test]
async fn recall_finds_match() { async fn recall_finds_match() {
let (_tmp, mem) = seeded_mem(); let (_tmp, mem) = seeded_mem();
mem.store("lang", "User prefers Rust", MemoryCategory::Core) mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
mem.store("tz", "Timezone is EST", MemoryCategory::Core) mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -134,6 +134,7 @@ mod tests {
&format!("k{i}"), &format!("k{i}"),
&format!("Rust fact {i}"), &format!("Rust fact {i}"),
MemoryCategory::Core, MemoryCategory::Core,
None,
) )
.await .await
.unwrap(); .unwrap();

View file

@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool {
_ => MemoryCategory::Core, _ => MemoryCategory::Core,
}; };
match self.memory.store(key, content, category).await { match self.memory.store(key, content, category, None).await {
Ok(()) => Ok(ToolResult { Ok(()) => Ok(ToolResult {
success: true, success: true,
output: format!("Stored memory: {key}"), output: format!("Stored memory: {key}"),

View file

@ -19,7 +19,9 @@ pub mod image_info;
pub mod memory_forget; pub mod memory_forget;
pub mod memory_recall; pub mod memory_recall;
pub mod memory_store; pub mod memory_store;
pub mod pushover;
pub mod schedule; pub mod schedule;
pub mod schema;
pub mod screenshot; pub mod screenshot;
pub mod shell; pub mod shell;
pub mod traits; pub mod traits;
@ -45,7 +47,9 @@ pub use image_info::ImageInfoTool;
pub use memory_forget::MemoryForgetTool; pub use memory_forget::MemoryForgetTool;
pub use memory_recall::MemoryRecallTool; pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool; pub use memory_store::MemoryStoreTool;
pub use pushover::PushoverTool;
pub use schedule::ScheduleTool; pub use schedule::ScheduleTool;
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use screenshot::ScreenshotTool; pub use screenshot::ScreenshotTool;
pub use shell::ShellTool; pub use shell::ShellTool;
pub use traits::Tool; pub use traits::Tool;
@ -141,6 +145,10 @@ pub fn all_tools_with_runtime(
security.clone(), security.clone(),
workspace_dir.to_path_buf(), workspace_dir.to_path_buf(),
)), )),
Box::new(PushoverTool::new(
security.clone(),
workspace_dir.to_path_buf(),
)),
]; ];
if browser_config.enabled { if browser_config.enabled {
@ -195,9 +203,13 @@ pub fn all_tools_with_runtime(
.iter() .iter()
.map(|(name, cfg)| (name.clone(), cfg.clone())) .map(|(name, cfg)| (name.clone(), cfg.clone()))
.collect(); .collect();
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
let trimmed_value = value.trim();
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
});
tools.push(Box::new(DelegateTool::new( tools.push(Box::new(DelegateTool::new(
delegate_agents, delegate_agents,
fallback_api_key.map(String::from), delegate_fallback_credential,
))); )));
} }
@ -261,6 +273,7 @@ mod tests {
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(!names.contains(&"browser_open")); assert!(!names.contains(&"browser_open"));
assert!(names.contains(&"schedule")); assert!(names.contains(&"schedule"));
assert!(names.contains(&"pushover"));
} }
#[test] #[test]
@ -298,6 +311,7 @@ mod tests {
); );
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(names.contains(&"browser_open")); assert!(names.contains(&"browser_open"));
assert!(names.contains(&"pushover"));
} }
#[test] #[test]
@ -432,7 +446,7 @@ mod tests {
&http, &http,
tmp.path(), tmp.path(),
&agents, &agents,
Some("sk-test"), Some("delegate-test-credential"),
&cfg, &cfg,
); );
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect(); let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();

442
src/tools/pushover.rs Normal file
View file

@ -0,0 +1,442 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
/// Pushover REST endpoint for sending a message.
const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json";
/// Per-request HTTP timeout applied to the Pushover client.
const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15;
/// Tool that sends push notifications through the Pushover API.
///
/// Credentials are not stored on the struct; they are read from the
/// workspace `.env` file (`PUSHOVER_TOKEN` / `PUSHOVER_USER_KEY`) at
/// send time.
pub struct PushoverTool {
    // HTTP client, built with a request timeout in `new`.
    client: Client,
    // Security policy consulted (autonomy + rate limit) before sending.
    security: Arc<SecurityPolicy>,
    // Directory whose `.env` file holds the Pushover credentials.
    workspace_dir: PathBuf,
}
impl PushoverTool {
    /// Create a new Pushover tool.
    ///
    /// Builds an HTTP client with a request timeout; if the builder fails
    /// we fall back to a default client rather than failing construction.
    pub fn new(security: Arc<SecurityPolicy>, workspace_dir: PathBuf) -> Self {
        let client = Client::builder()
            .timeout(Duration::from_secs(PUSHOVER_REQUEST_TIMEOUT_SECS))
            .build()
            .unwrap_or_else(|_| Client::new());
        Self {
            client,
            security,
            workspace_dir,
        }
    }

    /// Normalize a raw `.env` value.
    ///
    /// Quoted values (`"..."` or `'...'`) are returned verbatim (minus the
    /// quotes) — quoting exists precisely to preserve whitespace and `#`.
    /// Unquoted values keep support for inline comments:
    /// `KEY=value # comment`.
    ///
    /// Fix: previously the inline-comment stripping ran *after* quote
    /// removal, so a quoted value containing `" #"` was truncated and a
    /// quoted value with intentional padding was re-trimmed.
    fn parse_env_value(raw: &str) -> String {
        let raw = raw.trim();
        let is_quoted = raw.len() >= 2
            && ((raw.starts_with('"') && raw.ends_with('"'))
                || (raw.starts_with('\'') && raw.ends_with('\'')));
        if is_quoted {
            return raw[1..raw.len() - 1].to_string();
        }
        // Inline comments only apply to unquoted values:
        // KEY=value # comment
        raw.split_once(" #").map_or_else(
            || raw.trim().to_string(),
            |(value, _)| value.trim().to_string(),
        )
    }

    /// Read `PUSHOVER_TOKEN` and `PUSHOVER_USER_KEY` from the workspace
    /// `.env` file.
    ///
    /// Supports comment lines, `export KEY=value`, quoted values, and
    /// case-insensitive key matching.
    ///
    /// # Errors
    /// Fails if the file cannot be read or either key is missing.
    fn get_credentials(&self) -> anyhow::Result<(String, String)> {
        let env_path = self.workspace_dir.join(".env");
        let content = std::fs::read_to_string(&env_path)
            .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?;

        let mut token = None;
        let mut user_key = None;

        for line in content.lines() {
            let line = line.trim();
            if line.starts_with('#') || line.is_empty() {
                continue;
            }
            let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line);
            if let Some((key, value)) = line.split_once('=') {
                let key = key.trim();
                let value = Self::parse_env_value(value);
                if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") {
                    token = Some(value);
                } else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") {
                    user_key = Some(value);
                }
            }
        }

        let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?;
        let user_key =
            user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?;

        Ok((token, user_key))
    }
}
#[async_trait]
impl Tool for PushoverTool {
    fn name(&self) -> &str {
        "pushover"
    }

    fn description(&self) -> &str {
        "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file."
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "message": {
                    "type": "string",
                    "description": "The notification message to send"
                },
                "title": {
                    "type": "string",
                    "description": "Optional notification title"
                },
                "priority": {
                    "type": "integer",
                    "enum": [-2, -1, 0, 1, 2],
                    "description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)"
                },
                "sound": {
                    "type": "string",
                    "description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)"
                }
            },
            "required": ["message"]
        })
    }

    /// Validate security gates and arguments, then POST the notification.
    ///
    /// Every failure is reported as `Ok(ToolResult { success: false, .. })`
    /// so the agent can read the error text; `Err` is reserved for argument
    /// contract violations (missing `message`).
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Security gates first: autonomy level, then the hourly rate limiter.
        if !self.security.can_act() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Action blocked: autonomy is read-only".into()),
            });
        }
        if !self.security.record_action() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Action blocked: rate limit exceeded".into()),
            });
        }

        // `message` is required and must be non-blank after trimming.
        let message = args
            .get("message")
            .and_then(|v| v.as_str())
            .map(str::trim)
            .filter(|v| !v.is_empty())
            .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?
            .to_string();
        let title = args.get("title").and_then(|v| v.as_str()).map(String::from);
        // Pushover only accepts priorities in -2..=2; reject anything else
        // early instead of letting the API return an opaque error.
        let priority = match args.get("priority").and_then(|v| v.as_i64()) {
            Some(value) if (-2..=2).contains(&value) => Some(value),
            Some(value) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "Invalid 'priority': {value}. Expected integer in range -2..=2"
                    )),
                })
            }
            None => None,
        };
        let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from);

        // Consistency fix: credential problems are reported as a structured
        // failure like every other error path, instead of aborting the tool
        // call with a bare `Err`.
        let (token, user_key) = match self.get_credentials() {
            Ok(credentials) => credentials,
            Err(e) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(e.to_string()),
                })
            }
        };

        let mut form = reqwest::multipart::Form::new()
            .text("token", token)
            .text("user", user_key)
            .text("message", message);
        if let Some(title) = title {
            form = form.text("title", title);
        }
        if let Some(priority) = priority {
            form = form.text("priority", priority.to_string());
        }
        if let Some(sound) = sound {
            form = form.text("sound", sound);
        }

        let response = self
            .client
            .post(PUSHOVER_API_URL)
            .multipart(form)
            .send()
            .await?;
        let status = response.status();
        let body = response.text().await.unwrap_or_default();

        if !status.is_success() {
            return Ok(ToolResult {
                success: false,
                output: body,
                error: Some(format!("Pushover API returned status {status}")),
            });
        }

        // Pushover signals application-level success with `"status": 1` in
        // the JSON body even when HTTP status is 200.
        let api_status = serde_json::from_str::<serde_json::Value>(&body)
            .ok()
            .and_then(|json| json.get("status").and_then(|value| value.as_i64()));
        if api_status == Some(1) {
            Ok(ToolResult {
                success: true,
                output: format!(
                    "Pushover notification sent successfully. Response: {}",
                    body
                ),
                error: None,
            })
        } else {
            Ok(ToolResult {
                success: false,
                output: body,
                error: Some("Pushover API returned an application-level error".into()),
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::AutonomyLevel;
    use std::fs;
    use tempfile::TempDir;

    // Build a SecurityPolicy with the given autonomy level and hourly action
    // budget; all other fields keep their defaults.
    fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: level,
            max_actions_per_hour,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    // --- Metadata: name / description / schema shape ---

    #[test]
    fn pushover_tool_name() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        assert_eq!(tool.name(), "pushover");
    }

    #[test]
    fn pushover_tool_description() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn pushover_tool_has_parameters_schema() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].get("message").is_some());
    }

    #[test]
    fn pushover_tool_requires_message() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::Value::String("message".to_string())));
    }

    // --- Credential parsing from the workspace .env file ---

    #[test]
    fn credentials_parsed_from_env_file() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(
            &env_path,
            "PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n",
        )
        .unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "testtoken123");
        assert_eq!(user_key, "userkey456");
    }

    #[test]
    fn credentials_fail_without_env_file() {
        let tmp = TempDir::new().unwrap();
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_err());
    }

    #[test]
    fn credentials_fail_without_token() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap();
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_err());
    }

    #[test]
    fn credentials_fail_without_user_key() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap();
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_err());
    }

    #[test]
    fn credentials_ignore_comments() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap();
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "realtoken");
        assert_eq!(user_key, "realuser");
    }

    #[test]
    fn pushover_tool_supports_priority() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert!(schema["properties"].get("priority").is_some());
    }

    #[test]
    fn pushover_tool_supports_sound() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert!(schema["properties"].get("sound").is_some());
    }

    #[test]
    fn credentials_support_export_and_quoted_values() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(
            &env_path,
            "export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n",
        )
        .unwrap();
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "quotedtoken");
        assert_eq!(user_key, "quoteduser");
    }

    // --- Execute-path security gates (no network involved) ---

    #[tokio::test]
    async fn execute_blocks_readonly_mode() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::ReadOnly, 100),
            PathBuf::from("/tmp"),
        );
        let result = tool.execute(json!({"message": "hello"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("read-only"));
    }

    #[tokio::test]
    async fn execute_blocks_rate_limit() {
        // max_actions_per_hour = 0 means the very first action is rejected.
        let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp"));
        let result = tool.execute(json!({"message": "hello"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("rate limit"));
    }

    #[tokio::test]
    async fn execute_rejects_priority_out_of_range() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let result = tool
            .execute(json!({"message": "hello", "priority": 5}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("-2..=2"));
    }
}

838
src/tools/schema.rs Normal file
View file

@ -0,0 +1,838 @@
//! JSON Schema cleaning and validation for LLM tool-calling compatibility.
//!
//! Different providers support different subsets of JSON Schema. This module
//! normalizes tool schemas to improve cross-provider compatibility while
//! preserving semantic intent.
//!
//! ## What this module does
//!
//! 1. Removes unsupported keywords per provider strategy
//! 2. Resolves local `$ref` entries from `$defs` and `definitions`
//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum`
//! 4. Strips nullable variants from unions and `type` arrays
//! 5. Converts `const` to single-value `enum`
//! 6. Detects circular references and stops recursion safely
//!
//! # Example
//!
//! ```rust
//! use serde_json::json;
//! use zeroclaw::tools::schema::SchemaCleanr;
//!
//! let dirty_schema = json!({
//! "type": "object",
//! "properties": {
//! "name": {
//! "type": "string",
//! "minLength": 1, // Gemini rejects this
//! "pattern": "^[a-z]+$" // Gemini rejects this
//! },
//! "age": {
//! "$ref": "#/$defs/Age" // Needs resolution
//! }
//! },
//! "$defs": {
//! "Age": {
//! "type": "integer",
//! "minimum": 0 // Gemini rejects this
//! }
//! }
//! });
//!
//! let cleaned = SchemaCleanr::clean_for_gemini(dirty_schema);
//!
//! // Result:
//! // {
//! // "type": "object",
//! // "properties": {
//! // "name": { "type": "string" },
//! // "age": { "type": "integer" }
//! // }
//! // }
//! ```
//!
use serde_json::{json, Map, Value};
use std::collections::{HashMap, HashSet};
/// Keywords that Gemini rejects for tool schemas.
///
/// Used by `CleaningStrategy::Gemini`; other strategies strip smaller
/// subsets of these.
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
    // Schema composition
    "$ref",
    "$schema",
    "$id",
    "$defs",
    "definitions",
    // Property constraints
    "additionalProperties",
    "patternProperties",
    // String constraints
    "minLength",
    "maxLength",
    "pattern",
    "format",
    // Number constraints
    "minimum",
    "maximum",
    "multipleOf",
    // Array constraints
    "minItems",
    "maxItems",
    "uniqueItems",
    // Object constraints
    "minProperties",
    "maxProperties",
    // Non-standard
    "examples", // OpenAPI keyword, not JSON Schema
];

/// Keywords that should be preserved during cleaning (metadata).
/// Copied onto replacement schemas when a `$ref` or union is simplified away.
const SCHEMA_META_KEYS: &[&str] = &["description", "title", "default"];
/// Schema cleaning strategies for different LLM providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleaningStrategy {
    /// Gemini (Google AI / Vertex AI) - Most restrictive
    Gemini,
    /// Anthropic Claude - Moderately permissive
    Anthropic,
    /// OpenAI GPT - Most permissive
    OpenAI,
    /// Conservative: Remove only universally unsupported keywords
    Conservative,
}

impl CleaningStrategy {
    /// Get the list of unsupported keywords for this strategy.
    ///
    /// Each arm yields a `'static` slice, so no allocation occurs here.
    pub fn unsupported_keywords(&self) -> &'static [&'static str] {
        match *self {
            CleaningStrategy::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
            // Anthropic doesn't resolve refs, so references must be stripped.
            CleaningStrategy::Anthropic => &["$ref", "$defs", "definitions"],
            // OpenAI is most permissive: nothing needs removal.
            CleaningStrategy::OpenAI => &[],
            CleaningStrategy::Conservative => {
                &["$ref", "$defs", "definitions", "additionalProperties"]
            }
        }
    }
}
/// JSON Schema cleaner optimized for LLM tool calling.
///
/// Stateless; all methods are associated functions.
pub struct SchemaCleanr;

impl SchemaCleanr {
    /// Clean schema for Gemini compatibility (strictest).
    ///
    /// This is the most aggressive cleaning strategy, removing all keywords
    /// that Gemini's API rejects.
    pub fn clean_for_gemini(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::Gemini)
    }

    /// Clean schema for Anthropic compatibility.
    pub fn clean_for_anthropic(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::Anthropic)
    }

    /// Clean schema for OpenAI compatibility (most permissive).
    pub fn clean_for_openai(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::OpenAI)
    }

    /// Clean schema with specified strategy.
    ///
    /// `$defs`/`definitions` are collected up front so that `$ref` entries
    /// can be resolved inline even after those keys are stripped.
    pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
        // Extract $defs for reference resolution
        let defs = if let Some(obj) = schema.as_object() {
            Self::extract_defs(obj)
        } else {
            HashMap::new()
        };

        Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
    }

    /// Validate that a schema is suitable for LLM tool calling.
    ///
    /// Returns an error if the schema is invalid or missing required fields.
    pub fn validate(schema: &Value) -> anyhow::Result<()> {
        let obj = schema
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;

        // Must have 'type' field
        if !obj.contains_key("type") {
            anyhow::bail!("Schema missing required 'type' field");
        }

        // If type is 'object', should have 'properties'
        if let Some(Value::String(t)) = obj.get("type") {
            if t == "object" && !obj.contains_key("properties") {
                tracing::warn!("Object schema without 'properties' field may cause issues");
            }
        }

        Ok(())
    }

    // --------------------------------------------------------------------
    // Internal implementation
    // --------------------------------------------------------------------

    /// Extract $defs and definitions into a flat map for reference resolution.
    ///
    /// NOTE(review): only top-level `$defs`/`definitions` are collected;
    /// definitions nested deeper in the schema are not resolvable.
    fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
        let mut defs = HashMap::new();

        // Extract from $defs (JSON Schema 2019-09+)
        if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
            for (key, value) in defs_obj {
                defs.insert(key.clone(), value.clone());
            }
        }

        // Extract from definitions (JSON Schema draft-07)
        // If the same name appears in both, this entry wins.
        if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
            for (key, value) in defs_obj {
                defs.insert(key.clone(), value.clone());
            }
        }

        defs
    }

    /// Recursively clean a schema value.
    ///
    /// Objects get full keyword handling; arrays are cleaned element-wise;
    /// scalars pass through unchanged.
    fn clean_with_defs(
        schema: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        match schema {
            Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
            Value::Array(arr) => Value::Array(
                arr.into_iter()
                    .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
                    .collect(),
            ),
            other => other,
        }
    }

    /// Clean an object schema.
    ///
    /// Handling order matters: `$ref` resolution first, then union
    /// simplification, then per-key keyword filtering and recursion.
    fn clean_object(
        obj: Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        // Handle $ref resolution
        if let Some(Value::String(ref_value)) = obj.get("$ref") {
            return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
        }

        // Handle anyOf/oneOf simplification
        if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
            if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
                return simplified;
            }
        }

        // Build cleaned object
        let mut cleaned = Map::new();
        let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
        let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");

        for (key, value) in obj {
            // Skip unsupported keywords
            if unsupported.contains(key.as_str()) {
                continue;
            }

            // Special handling for specific keys
            match key.as_str() {
                // Convert const to enum (semantically equivalent, more
                // widely supported by providers)
                "const" => {
                    cleaned.insert("enum".to_string(), json!([value]));
                }
                // Skip type if we have anyOf/oneOf (they define the type)
                "type" if has_union => {
                    // Skip
                }
                // Handle type arrays (remove null)
                "type" if matches!(value, Value::Array(_)) => {
                    let cleaned_value = Self::clean_type_array(value);
                    cleaned.insert(key, cleaned_value);
                }
                // Recursively clean nested schemas
                "properties" => {
                    let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                "items" => {
                    let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                "anyOf" | "oneOf" | "allOf" => {
                    let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                // Keep all other keys, cleaning nested objects/arrays recursively.
                _ => {
                    let cleaned_value = match value {
                        Value::Object(_) | Value::Array(_) => {
                            Self::clean_with_defs(value, defs, strategy, ref_stack)
                        }
                        other => other,
                    };
                    cleaned.insert(key, cleaned_value);
                }
            }
        }

        Value::Object(cleaned)
    }

    /// Resolve a $ref to its definition.
    ///
    /// On success the referenced definition is inlined (cleaned) and the
    /// referring object's metadata keys are copied onto it. Circular or
    /// unresolvable refs degrade to an empty object plus metadata.
    fn resolve_ref(
        ref_value: &str,
        obj: &Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        // Prevent circular references
        if ref_stack.contains(ref_value) {
            tracing::warn!("Circular $ref detected: {}", ref_value);
            return Self::preserve_meta(obj, Value::Object(Map::new()));
        }

        // Try to resolve local ref (#/$defs/Name or #/definitions/Name)
        if let Some(def_name) = Self::parse_local_ref(ref_value) {
            if let Some(definition) = defs.get(def_name.as_str()) {
                // Push/pop around the recursion so sibling refs to the same
                // definition still resolve.
                ref_stack.insert(ref_value.to_string());
                let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
                ref_stack.remove(ref_value);
                return Self::preserve_meta(obj, cleaned);
            }
        }

        // Can't resolve: return empty object with metadata
        tracing::warn!("Cannot resolve $ref: {}", ref_value);
        Self::preserve_meta(obj, Value::Object(Map::new()))
    }

    /// Parse a local JSON Pointer ref (#/$defs/Name).
    fn parse_local_ref(ref_value: &str) -> Option<String> {
        ref_value
            .strip_prefix("#/$defs/")
            .or_else(|| ref_value.strip_prefix("#/definitions/"))
            .map(Self::decode_json_pointer)
    }

    /// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`), per RFC 6901.
    fn decode_json_pointer(segment: &str) -> String {
        if !segment.contains('~') {
            return segment.to_string();
        }

        let mut decoded = String::with_capacity(segment.len());
        let mut chars = segment.chars().peekable();
        while let Some(ch) = chars.next() {
            if ch == '~' {
                match chars.peek().copied() {
                    Some('0') => {
                        chars.next();
                        decoded.push('~');
                    }
                    Some('1') => {
                        chars.next();
                        decoded.push('/');
                    }
                    // Trailing/unknown escape: keep the tilde literally.
                    _ => decoded.push('~'),
                }
            } else {
                decoded.push(ch);
            }
        }
        decoded
    }

    /// Try to simplify anyOf/oneOf to a simpler form.
    ///
    /// Returns `None` when the union cannot be simplified; the caller then
    /// keeps (and recursively cleans) the union as-is.
    fn try_simplify_union(
        obj: &Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Option<Value> {
        let union_key = if obj.contains_key("anyOf") {
            "anyOf"
        } else if obj.contains_key("oneOf") {
            "oneOf"
        } else {
            return None;
        };

        let variants = obj.get(union_key)?.as_array()?;

        // Clean all variants first
        let cleaned_variants: Vec<Value> = variants
            .iter()
            .map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
            .collect();

        // Strip null variants
        let non_null: Vec<Value> = cleaned_variants
            .into_iter()
            .filter(|v| !Self::is_null_schema(v))
            .collect();

        // If only one variant remains after stripping nulls, return it
        if non_null.len() == 1 {
            return Some(Self::preserve_meta(obj, non_null[0].clone()));
        }

        // Try to flatten to enum if all variants are literals
        if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
            return Some(Self::preserve_meta(obj, enum_value));
        }

        None
    }

    /// Check if a schema represents null type.
    fn is_null_schema(value: &Value) -> bool {
        if let Some(obj) = value.as_object() {
            // { const: null }
            if let Some(Value::Null) = obj.get("const") {
                return true;
            }
            // { enum: [null] }
            if let Some(Value::Array(arr)) = obj.get("enum") {
                if arr.len() == 1 && matches!(arr[0], Value::Null) {
                    return true;
                }
            }
            // { type: "null" }
            if let Some(Value::String(t)) = obj.get("type") {
                if t == "null" {
                    return true;
                }
            }
        }
        false
    }

    /// Try to flatten anyOf/oneOf with only literal values to enum.
    ///
    /// Example: `anyOf: [{const: "a"}, {const: "b"}]` -> `{type: "string", enum: ["a", "b"]}`
    ///
    /// Requires every variant to carry the same `type`; bails with `None`
    /// on any non-literal or mixed-type variant.
    fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
        if variants.is_empty() {
            return None;
        }

        let mut all_values = Vec::new();
        let mut common_type: Option<String> = None;

        for variant in variants {
            let obj = variant.as_object()?;

            // Extract literal value from const or single-item enum
            let literal_value = if let Some(const_val) = obj.get("const") {
                const_val.clone()
            } else if let Some(Value::Array(arr)) = obj.get("enum") {
                if arr.len() == 1 {
                    arr[0].clone()
                } else {
                    return None;
                }
            } else {
                return None;
            };

            // Check type consistency
            let variant_type = obj.get("type")?.as_str()?;
            match &common_type {
                None => common_type = Some(variant_type.to_string()),
                Some(t) if t != variant_type => return None,
                _ => {}
            }

            all_values.push(literal_value);
        }

        common_type.map(|t| {
            json!({
                "type": t,
                "enum": all_values
            })
        })
    }

    /// Clean type array, removing null.
    ///
    /// `["string", "null"]` -> `"string"`; `["null"]` -> `"null"` (kept so
    /// the schema still has a type).
    fn clean_type_array(value: Value) -> Value {
        if let Value::Array(types) = value {
            let non_null: Vec<Value> = types
                .into_iter()
                .filter(|v| v.as_str() != Some("null"))
                .collect();

            match non_null.len() {
                0 => Value::String("null".to_string()),
                1 => non_null
                    .into_iter()
                    .next()
                    .unwrap_or(Value::String("null".to_string())),
                _ => Value::Array(non_null),
            }
        } else {
            value
        }
    }

    /// Clean properties object.
    fn clean_properties(
        value: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        if let Value::Object(props) = value {
            let cleaned: Map<String, Value> = props
                .into_iter()
                .map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
                .collect();
            Value::Object(cleaned)
        } else {
            value
        }
    }

    /// Clean union (anyOf/oneOf/allOf).
    fn clean_union(
        value: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        if let Value::Array(variants) = value {
            let cleaned: Vec<Value> = variants
                .into_iter()
                .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
                .collect();
            Value::Array(cleaned)
        } else {
            value
        }
    }

    /// Preserve metadata (description, title, default) from source to target.
    ///
    /// Source metadata overwrites any same-named keys already on the target.
    fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
        if let Value::Object(target_obj) = &mut target {
            for &key in SCHEMA_META_KEYS {
                if let Some(value) = source.get(key) {
                    target_obj.insert(key.to_string(), value.clone());
                }
            }
        }
        target
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Keyword stripping under the Gemini (strictest) strategy.
    #[test]
    fn test_remove_unsupported_keywords() {
        let schema = json!({
            "type": "string",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^[a-z]+$",
            "description": "A lowercase string"
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
        assert_eq!(cleaned["description"], "A lowercase string");
        assert!(cleaned.get("minLength").is_none());
        assert!(cleaned.get("maxLength").is_none());
        assert!(cleaned.get("pattern").is_none());
    }

    // $ref inlining: the definition is resolved and then cleaned itself.
    #[test]
    fn test_resolve_ref() {
        let schema = json!({
            "type": "object",
            "properties": {
                "age": {
                    "$ref": "#/$defs/Age"
                }
            },
            "$defs": {
                "Age": {
                    "type": "integer",
                    "minimum": 0
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["properties"]["age"]["type"], "integer");
        assert!(cleaned["properties"]["age"].get("minimum").is_none()); // Stripped by Gemini strategy
        assert!(cleaned.get("$defs").is_none());
    }

    // anyOf of same-typed consts collapses to a single enum.
    #[test]
    fn test_flatten_literal_union() {
        let schema = json!({
            "anyOf": [
                { "const": "admin", "type": "string" },
                { "const": "user", "type": "string" },
                { "const": "guest", "type": "string" }
            ]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
        assert!(cleaned["enum"].is_array());
        let enum_values = cleaned["enum"].as_array().unwrap();
        assert_eq!(enum_values.len(), 3);
        assert!(enum_values.contains(&json!("admin")));
        assert!(enum_values.contains(&json!("user")));
        assert!(enum_values.contains(&json!("guest")));
    }

    #[test]
    fn test_strip_null_from_union() {
        let schema = json!({
            "oneOf": [
                { "type": "string" },
                { "type": "null" }
            ]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        // Should simplify to just { type: "string" }
        assert_eq!(cleaned["type"], "string");
        assert!(cleaned.get("oneOf").is_none());
    }

    #[test]
    fn test_const_to_enum() {
        let schema = json!({
            "const": "fixed_value",
            "description": "A constant"
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["enum"], json!(["fixed_value"]));
        assert_eq!(cleaned["description"], "A constant");
        assert!(cleaned.get("const").is_none());
    }

    // Metadata on the referring object survives $ref replacement.
    #[test]
    fn test_preserve_metadata() {
        let schema = json!({
            "$ref": "#/$defs/Name",
            "description": "User's name",
            "title": "Name Field",
            "default": "Anonymous",
            "$defs": {
                "Name": {
                    "type": "string"
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
        assert_eq!(cleaned["description"], "User's name");
        assert_eq!(cleaned["title"], "Name Field");
        assert_eq!(cleaned["default"], "Anonymous");
    }

    #[test]
    fn test_circular_ref_prevention() {
        let schema = json!({
            "type": "object",
            "properties": {
                "parent": {
                    "$ref": "#/$defs/Node"
                }
            },
            "$defs": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "child": {
                            "$ref": "#/$defs/Node"
                        }
                    }
                }
            }
        });

        // Should not panic on circular reference
        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["properties"]["parent"]["type"], "object");
        // Circular reference should be broken
    }

    #[test]
    fn test_validate_schema() {
        let valid = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        assert!(SchemaCleanr::validate(&valid).is_ok());

        let invalid = json!({
            "properties": {
                "name": { "type": "string" }
            }
        });
        assert!(SchemaCleanr::validate(&invalid).is_err());
    }

    // Same schema, different strategies: Gemini strips, OpenAI keeps.
    #[test]
    fn test_strategy_differences() {
        let schema = json!({
            "type": "string",
            "minLength": 1,
            "description": "A string field"
        });

        // Gemini: Most restrictive (removes minLength)
        let gemini = SchemaCleanr::clean_for_gemini(schema.clone());
        assert!(gemini.get("minLength").is_none());
        assert_eq!(gemini["type"], "string");
        assert_eq!(gemini["description"], "A string field");

        // OpenAI: Most permissive (keeps minLength)
        let openai = SchemaCleanr::clean_for_openai(schema.clone());
        assert_eq!(openai["minLength"], 1); // OpenAI allows validation keywords
        assert_eq!(openai["type"], "string");
    }

    #[test]
    fn test_nested_properties() {
        let schema = json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "minLength": 1
                        }
                    },
                    "additionalProperties": false
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert!(cleaned["properties"]["user"]["properties"]["name"]
            .get("minLength")
            .is_none());
        assert!(cleaned["properties"]["user"]
            .get("additionalProperties")
            .is_none());
    }

    #[test]
    fn test_type_array_null_removal() {
        let schema = json!({
            "type": ["string", "null"]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        // Should simplify to just "string"
        assert_eq!(cleaned["type"], "string");
    }

    #[test]
    fn test_type_array_only_null_preserved() {
        let schema = json!({
            "type": ["null"]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "null");
    }

    // RFC 6901 escaping: `~1` in a $ref decodes to `/` in the def name.
    #[test]
    fn test_ref_with_json_pointer_escape() {
        let schema = json!({
            "$ref": "#/$defs/Foo~1Bar",
            "$defs": {
                "Foo/Bar": {
                    "type": "string"
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["type"], "string");
    }

    // A union that cannot be flattened keeps oneOf but drops the sibling type.
    #[test]
    fn test_skip_type_when_non_simplifiable_union_exists() {
        let schema = json!({
            "type": "object",
            "oneOf": [
                {
                    "type": "object",
                    "properties": {
                        "a": { "type": "string" }
                    }
                },
                {
                    "type": "object",
                    "properties": {
                        "b": { "type": "number" }
                    }
                }
            ]
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert!(cleaned.get("type").is_none());
        assert!(cleaned.get("oneOf").is_some());
    }

    // Unknown keywords (e.g. "not") still get their nested values cleaned.
    #[test]
    fn test_clean_nested_unknown_schema_keyword() {
        let schema = json!({
            "not": {
                "$ref": "#/$defs/Age"
            },
            "$defs": {
                "Age": {
                    "type": "integer",
                    "minimum": 0
                }
            }
        });

        let cleaned = SchemaCleanr::clean_for_gemini(schema);

        assert_eq!(cleaned["not"]["type"], "integer");
        assert!(cleaned["not"].get("minimum").is_none());
    }
}

View file

@ -36,6 +36,7 @@ async fn compare_store_speed() {
&format!("key_{i}"), &format!("key_{i}"),
&format!("Memory entry number {i} about Rust programming"), &format!("Memory entry number {i} about Rust programming"),
MemoryCategory::Core, MemoryCategory::Core,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -49,6 +50,7 @@ async fn compare_store_speed() {
&format!("key_{i}"), &format!("key_{i}"),
&format!("Memory entry number {i} about Rust programming"), &format!("Memory entry number {i} about Rust programming"),
MemoryCategory::Core, MemoryCategory::Core,
None,
) )
.await .await
.unwrap(); .unwrap();
@ -127,8 +129,8 @@ async fn compare_recall_quality() {
]; ];
for (key, content, cat) in &entries { for (key, content, cat) in &entries {
sq.store(key, content, cat.clone()).await.unwrap(); sq.store(key, content, cat.clone(), None).await.unwrap();
md.store(key, content, cat.clone()).await.unwrap(); md.store(key, content, cat.clone(), None).await.unwrap();
} }
// Test queries and compare results // Test queries and compare results
@ -145,8 +147,8 @@ async fn compare_recall_quality() {
println!("RECALL QUALITY (10 entries seeded):\n"); println!("RECALL QUALITY (10 entries seeded):\n");
for (query, desc) in &queries { for (query, desc) in &queries {
let sq_results = sq.recall(query, 10).await.unwrap(); let sq_results = sq.recall(query, 10, None).await.unwrap();
let md_results = md.recall(query, 10).await.unwrap(); let md_results = md.recall(query, 10, None).await.unwrap();
println!(" Query: \"{query}\"{desc}"); println!(" Query: \"{query}\"{desc}");
println!(" SQLite: {} results", sq_results.len()); println!(" SQLite: {} results", sq_results.len());
@ -190,21 +192,21 @@ async fn compare_recall_speed() {
} else { } else {
format!("TypeScript powers modern web apps, entry {i}") format!("TypeScript powers modern web apps, entry {i}")
}; };
sq.store(&format!("e{i}"), &content, MemoryCategory::Core) sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
md.store(&format!("e{i}"), &content, MemoryCategory::Daily) md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None)
.await .await
.unwrap(); .unwrap();
} }
// Benchmark recall // Benchmark recall
let start = Instant::now(); let start = Instant::now();
let sq_results = sq.recall("Rust systems", 10).await.unwrap(); let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
let sq_dur = start.elapsed(); let sq_dur = start.elapsed();
let start = Instant::now(); let start = Instant::now();
let md_results = md.recall("Rust systems", 10).await.unwrap(); let md_results = md.recall("Rust systems", 10, None).await.unwrap();
let md_dur = start.elapsed(); let md_dur = start.elapsed();
println!("\n============================================================"); println!("\n============================================================");
@ -227,15 +229,25 @@ async fn compare_persistence() {
// Store in both, then drop and re-open // Store in both, then drop and re-open
{ {
let sq = sqlite_backend(tmp_sq.path()); let sq = sqlite_backend(tmp_sq.path());
sq.store("persist_test", "I should survive", MemoryCategory::Core) sq.store(
.await "persist_test",
.unwrap(); "I should survive",
MemoryCategory::Core,
None,
)
.await
.unwrap();
} }
{ {
let md = markdown_backend(tmp_md.path()); let md = markdown_backend(tmp_md.path());
md.store("persist_test", "I should survive", MemoryCategory::Core) md.store(
.await "persist_test",
.unwrap(); "I should survive",
MemoryCategory::Core,
None,
)
.await
.unwrap();
} }
// Re-open // Re-open
@ -282,17 +294,17 @@ async fn compare_upsert() {
let md = markdown_backend(tmp_md.path()); let md = markdown_backend(tmp_md.path());
// Store twice with same key, different content // Store twice with same key, different content
sq.store("pref", "likes Rust", MemoryCategory::Core) sq.store("pref", "likes Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
sq.store("pref", "loves Rust", MemoryCategory::Core) sq.store("pref", "loves Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
md.store("pref", "likes Rust", MemoryCategory::Core) md.store("pref", "likes Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
md.store("pref", "loves Rust", MemoryCategory::Core) md.store("pref", "loves Rust", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -300,7 +312,7 @@ async fn compare_upsert() {
let md_count = md.count().await.unwrap(); let md_count = md.count().await.unwrap();
let sq_entry = sq.get("pref").await.unwrap(); let sq_entry = sq.get("pref").await.unwrap();
let md_results = md.recall("loves Rust", 5).await.unwrap(); let md_results = md.recall("loves Rust", 5, None).await.unwrap();
println!("\n============================================================"); println!("\n============================================================");
println!("UPSERT (store same key twice):"); println!("UPSERT (store same key twice):");
@ -328,10 +340,10 @@ async fn compare_forget() {
let sq = sqlite_backend(tmp_sq.path()); let sq = sqlite_backend(tmp_sq.path());
let md = markdown_backend(tmp_md.path()); let md = markdown_backend(tmp_md.path());
sq.store("secret", "API key: sk-1234", MemoryCategory::Core) sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
md.store("secret", "API key: sk-1234", MemoryCategory::Core) md.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
@ -372,37 +384,40 @@ async fn compare_category_filter() {
let md = markdown_backend(tmp_md.path()); let md = markdown_backend(tmp_md.path());
// Mix of categories // Mix of categories
sq.store("a", "core fact 1", MemoryCategory::Core) sq.store("a", "core fact 1", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
sq.store("b", "core fact 2", MemoryCategory::Core) sq.store("b", "core fact 2", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
sq.store("c", "daily note", MemoryCategory::Daily) sq.store("c", "daily note", MemoryCategory::Daily, None)
.await .await
.unwrap(); .unwrap();
sq.store("d", "convo msg", MemoryCategory::Conversation) sq.store("d", "convo msg", MemoryCategory::Conversation, None)
.await .await
.unwrap(); .unwrap();
md.store("a", "core fact 1", MemoryCategory::Core) md.store("a", "core fact 1", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
md.store("b", "core fact 2", MemoryCategory::Core) md.store("b", "core fact 2", MemoryCategory::Core, None)
.await .await
.unwrap(); .unwrap();
md.store("c", "daily note", MemoryCategory::Daily) md.store("c", "daily note", MemoryCategory::Daily, None)
.await .await
.unwrap(); .unwrap();
let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap(); let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap();
let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap(); let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap();
let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap(); let sq_conv = sq
let sq_all = sq.list(None).await.unwrap(); .list(Some(&MemoryCategory::Conversation), None)
.await
.unwrap();
let sq_all = sq.list(None, None).await.unwrap();
let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap(); let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap();
let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap(); let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap();
let md_all = md.list(None).await.unwrap(); let md_all = md.list(None, None).await.unwrap();
println!("\n============================================================"); println!("\n============================================================");
println!("CATEGORY FILTERING:"); println!("CATEGORY FILTERING:");