Merge branch 'main' into pr-484-clean

This commit is contained in:
Will Sarg 2026-02-17 08:54:24 -05:00 committed by GitHub
commit ee05d62ce4
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
90 changed files with 6937 additions and 1403 deletions

View file

@ -1,25 +1,69 @@
# ZeroClaw Environment Variables
# Copy this file to .env and fill in your values.
# NEVER commit .env — it is listed in .gitignore.
# Copy this file to `.env` and fill in your local values.
# Never commit `.env` or any real secrets.
# ── Required ──────────────────────────────────────────────────
# Your LLM provider API key
# ZEROCLAW_API_KEY=sk-your-key-here
# ── Core Runtime ──────────────────────────────────────────────
# Provider key resolution at runtime:
# 1) explicit key passed from config/CLI
# 2) provider-specific env var (OPENROUTER_API_KEY, OPENAI_API_KEY, ...)
# 3) generic fallback env vars below
# Generic fallback API key (used when provider-specific key is absent)
API_KEY=your-api-key-here
# ZEROCLAW_API_KEY=your-api-key-here
# ── Provider & Model ─────────────────────────────────────────
# LLM provider: openrouter, openai, anthropic, ollama, glm
# Default provider/model (can be overridden by CLI flags)
PROVIDER=openrouter
# ZEROCLAW_PROVIDER=openrouter
# ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514
# ZEROCLAW_TEMPERATURE=0.7
# Workspace directory override
# ZEROCLAW_WORKSPACE=/path/to/workspace
# ── Provider-Specific API Keys ────────────────────────────────
# OpenRouter
# OPENROUTER_API_KEY=sk-or-v1-...
# Anthropic
# ANTHROPIC_OAUTH_TOKEN=...
# ANTHROPIC_API_KEY=sk-ant-...
# OpenAI / Gemini
# OPENAI_API_KEY=sk-...
# GEMINI_API_KEY=...
# GOOGLE_API_KEY=...
# Other supported providers
# VENICE_API_KEY=...
# GROQ_API_KEY=...
# MISTRAL_API_KEY=...
# DEEPSEEK_API_KEY=...
# XAI_API_KEY=...
# TOGETHER_API_KEY=...
# FIREWORKS_API_KEY=...
# PERPLEXITY_API_KEY=...
# COHERE_API_KEY=...
# MOONSHOT_API_KEY=...
# GLM_API_KEY=...
# MINIMAX_API_KEY=...
# QIANFAN_API_KEY=...
# DASHSCOPE_API_KEY=...
# ZAI_API_KEY=...
# SYNTHETIC_API_KEY=...
# OPENCODE_API_KEY=...
# VERCEL_API_KEY=...
# CLOUDFLARE_API_KEY=...
# ── Gateway ──────────────────────────────────────────────────
# ZEROCLAW_GATEWAY_PORT=3000
# ZEROCLAW_GATEWAY_HOST=127.0.0.1
# ZEROCLAW_ALLOW_PUBLIC_BIND=false
# ── Workspace ────────────────────────────────────────────────
# ZEROCLAW_WORKSPACE=/path/to/workspace
# ── Optional Integrations ────────────────────────────────────
# Pushover notifications (`pushover` tool)
# PUSHOVER_TOKEN=your-pushover-app-token
# PUSHOVER_USER_KEY=your-pushover-user-key
# ── Docker Compose ───────────────────────────────────────────
# Host port mapping (used by docker-compose.yml)

8
.githooks/pre-commit Executable file
View file

@ -0,0 +1,8 @@
#!/usr/bin/env bash
# Pre-commit hook: scan staged changes for secrets with gitleaks.
# If gitleaks is not installed, warn and allow the commit to proceed.
set -euo pipefail

if ! command -v gitleaks >/dev/null 2>&1; then
  echo "warning: gitleaks not found; skipping staged secret scan" >&2
else
  # Non-zero exit here aborts the commit (set -e).
  gitleaks protect --staged --redact
fi

View file

@ -12,7 +12,11 @@ Describe this PR in 2-5 bullets:
- Risk label (`risk: low|medium|high`):
- Size label (`size: XS|S|M|L|XL`, auto-managed/read-only):
- Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated):
- Module labels (`<module>: <component>`, for example `channel: telegram`, `provider: kimi`, `tool: shell`):
- Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50):
- If any auto-label is incorrect, note requested correction:

View file

@ -18,6 +18,7 @@ jobs:
runs-on: blacksmith-2vcpu-ubuntu-2404
permissions:
issues: write
pull-requests: write
steps:
- name: Apply contributor tier label for issue author
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8

View file

@ -35,7 +35,7 @@ jobs:
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- name: Setup Blacksmith Builder
uses: useblacksmith/setup-docker-builder@v1
uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
- name: Extract metadata (tags, labels)
id: meta
@ -46,7 +46,7 @@ jobs:
type=ref,event=pr
- name: Build smoke image
uses: useblacksmith/build-push-action@v2
uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
with:
context: .
push: false
@ -71,7 +71,7 @@ jobs:
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
- name: Setup Blacksmith Builder
uses: useblacksmith/setup-docker-builder@v1
uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
- name: Log in to Container Registry
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
@ -102,7 +102,7 @@ jobs:
echo "tags=${TAGS}" >> "$GITHUB_OUTPUT"
- name: Build and push Docker image
uses: useblacksmith/build-push-action@v2
uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
with:
context: .
push: true

View file

@ -325,13 +325,18 @@ jobs:
return pattern.test(text);
}
// Render a module label in the canonical "<prefix>: <segment>" form
// (note the single space after the colon).
function formatModuleLabel(prefix, segment) {
  const label = `${prefix}: ${segment}`;
  return label;
}
/**
 * Parse a module label of the form "<prefix>:<segment>" or
 * "<prefix>: <segment>" into its normalized parts.
 *
 * Defect fixed: the span contained two merged bodies (the pre-merge
 * indexOf-based implementation followed by the regex-based one), so the
 * newer normalizing code was unreachable dead text. This keeps only the
 * intended implementation.
 *
 * @param {string} label - Candidate label text.
 * @returns {{prefix: string, segment: string}|null} Lower-cased, trimmed
 *   parts, or null when the input is not a string or not a valid label.
 */
function parseModuleLabel(label) {
  if (typeof label !== "string") return null;
  // Prefix is everything before the first colon; optional whitespace
  // after the colon is tolerated so both spaced and unspaced forms parse.
  const match = label.match(/^([^:]+):\s*(.+)$/);
  if (!match) return null;
  const prefix = match[1].trim().toLowerCase();
  const segment = (match[2] || "").trim().toLowerCase();
  // Reject labels whose prefix or segment is empty after trimming.
  if (!prefix || !segment) return null;
  return { prefix, segment };
}
function sortByPriority(labels, priorityIndex) {
@ -389,7 +394,7 @@ jobs:
for (const [prefix, segments] of segmentsByPrefix) {
const hasSpecificSegment = [...segments].some((segment) => segment !== "core");
if (hasSpecificSegment) {
refined.delete(`${prefix}:core`);
refined.delete(formatModuleLabel(prefix, "core"));
}
}
@ -418,7 +423,7 @@ jobs:
if (uniqueSegments.length === 0) continue;
if (uniqueSegments.length === 1) {
compactedModuleLabels.add(`${prefix}:${uniqueSegments[0]}`);
compactedModuleLabels.add(formatModuleLabel(prefix, uniqueSegments[0]));
} else {
forcePathPrefixes.add(prefix);
}
@ -609,7 +614,7 @@ jobs:
segment = normalizeLabelSegment(segment);
if (!segment) continue;
detectedModuleLabels.add(`${rule.prefix}:${segment}`);
detectedModuleLabels.add(formatModuleLabel(rule.prefix, segment));
}
}
@ -635,7 +640,7 @@ jobs:
for (const keyword of providerKeywordHints) {
if (containsKeyword(searchableText, keyword)) {
detectedModuleLabels.add(`provider:${keyword}`);
detectedModuleLabels.add(formatModuleLabel("provider", keyword));
}
}
}
@ -661,7 +666,7 @@ jobs:
for (const keyword of channelKeywordHints) {
if (containsKeyword(searchableText, keyword)) {
detectedModuleLabels.add(`channel:${keyword}`);
detectedModuleLabels.add(formatModuleLabel("channel", keyword));
}
}
}

22
.gitignore vendored
View file

@ -4,6 +4,26 @@ firmware/*/target
*.db-journal
.DS_Store
.wt-pr37/
.env
__pycache__/
*.pyc
docker-compose.override.yml
# Environment files (may contain secrets)
.env
# Python virtual environments
.venv/
venv/
# ESP32 build cache (esp-idf-sys managed)
.embuild/
.env.local
.env.*.local
# Secret keys and credentials
.secret_key
*.key
*.pem
credentials.json

View file

@ -79,6 +79,94 @@ git push --no-verify
> **Note:** CI runs the same checks, so skipped hooks will be caught on the PR.
## Local Secret Management (Required)
ZeroClaw supports layered secret management for local development and CI hygiene.
### Secret Storage Options
1. **Environment variables** (recommended for local development)
- Copy `.env.example` to `.env` and fill in values
- `.env` files are Git-ignored and should stay local
- Best for temporary/local API keys
2. **Config file** (`~/.zeroclaw/config.toml`)
- Persistent setup for long-term use
- When `secrets.encrypt = true` (default), secret values are encrypted before save
- Secret key is stored at `~/.zeroclaw/.secret_key` with restricted permissions
- Use `zeroclaw onboard` for guided setup
### Runtime Resolution Rules
API key resolution follows this order:
1. Explicit key passed from config/CLI
2. Provider-specific env vars (`OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, ...)
3. Generic env vars (`ZEROCLAW_API_KEY`, `API_KEY`)
Provider/model config overrides:
- `ZEROCLAW_PROVIDER` / `PROVIDER`
- `ZEROCLAW_MODEL`
See `.env.example` for practical examples and currently supported provider key env vars.
### Pre-Commit Secret Hygiene (Mandatory)
Before every commit, verify:
- [ ] No `.env` files are staged (`.env.example` only)
- [ ] No raw API keys/tokens in code, tests, fixtures, examples, logs, or commit messages
- [ ] No credentials in debug output or error payloads
- [ ] `git diff --cached` has no accidental secret-like strings
Quick local audit:
```bash
# Search staged diff for common secret markers
git diff --cached | grep -iE '(api[_-]?key|secret|token|password|bearer|sk-)'
# Confirm no .env file is staged
git status --short | grep -E '\.env$'
```
### Optional Local Secret Scanning
For extra guardrails, install one of:
- **gitleaks**: [GitHub - gitleaks/gitleaks](https://github.com/gitleaks/gitleaks)
- **truffleHog**: [GitHub - trufflesecurity/trufflehog](https://github.com/trufflesecurity/trufflehog)
- **git-secrets**: [GitHub - awslabs/git-secrets](https://github.com/awslabs/git-secrets)
This repo includes `.githooks/pre-commit` to run `gitleaks protect --staged --redact` when gitleaks is installed.
Enable hooks with:
```bash
git config core.hooksPath .githooks
```
If gitleaks is not installed, the pre-commit hook prints a warning and continues.
### What Must Never Be Committed
- `.env` files (use `.env.example` only)
- API keys, tokens, passwords, or credentials (plain or encrypted)
- OAuth tokens or session identifiers
- Webhook signing secrets
- `~/.zeroclaw/.secret_key` or similar key files
- Personal identifiers or real user data in tests/fixtures
### If a Secret Is Committed Accidentally
1. Revoke/rotate the credential immediately
2. Do not rely only on `git revert` (history still contains the secret)
3. Purge history with `git filter-repo` or BFG
4. Force-push cleaned history (coordinate with maintainers)
5. Ensure the leaked value is removed from PR/issue/discussion/comment history
Reference: [GitHub guide: removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository)
## Collaboration Tracks (Risk-Based)
To keep review throughput high without lowering quality, every PR should map to one track:

51
Cargo.lock generated
View file

@ -209,6 +209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
dependencies = [
"axum-core",
"base64",
"bytes",
"form_urlencoded",
"futures-util",
@ -227,8 +228,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha1",
"sync_wrapper",
"tokio",
"tokio-tungstenite 0.28.0",
"tower",
"tower-layer",
"tower-service",
@ -2057,6 +2060,15 @@ dependencies = [
"hashify",
]
[[package]]
name = "matchers"
version = "0.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
dependencies = [
"regex-automata",
]
[[package]]
name = "matchit"
version = "0.8.4"
@ -3747,10 +3759,22 @@ dependencies = [
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tungstenite",
"tungstenite 0.24.0",
"webpki-roots 0.26.11",
]
[[package]]
name = "tokio-tungstenite"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite 0.28.0",
]
[[package]]
name = "tokio-util"
version = "0.7.18"
@ -3940,9 +3964,13 @@ version = "0.3.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
dependencies = [
"matchers",
"nu-ansi-term",
"once_cell",
"regex-automata",
"sharded-slab",
"thread_local",
"tracing",
"tracing-core",
]
@ -3978,6 +4006,23 @@ dependencies = [
"utf-8",
]
[[package]]
name = "tungstenite"
version = "0.28.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442"
dependencies = [
"bytes",
"data-encoding",
"http 1.4.0",
"httparse",
"log",
"rand 0.9.2",
"sha1",
"thiserror 2.0.18",
"utf-8",
]
[[package]]
name = "twox-hash"
version = "2.1.2"
@ -4880,7 +4925,9 @@ dependencies = [
"pdf-extract",
"probe-rs",
"prometheus",
"prost",
"rand 0.8.5",
"regex",
"reqwest",
"rppal",
"rusqlite",
@ -4896,7 +4943,7 @@ dependencies = [
"tokio-rustls",
"tokio-serial",
"tokio-test",
"tokio-tungstenite",
"tokio-tungstenite 0.24.0",
"toml",
"tower",
"tower-http",

View file

@ -1,3 +1,7 @@
[workspace]
members = ["."]
resolver = "2"
[package]
name = "zeroclaw"
version = "0.1.0"
@ -31,7 +35,7 @@ shellexpand = "3.1"
# Logging - minimal
tracing = { version = "0.1", default-features = false }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
# Observability - Prometheus metrics
prometheus = { version = "0.14", default-features = false }
@ -63,12 +67,12 @@ rand = "0.8"
# Fast mutexes that don't poison on panic
parking_lot = "0.12"
# Landlock (Linux sandbox) - optional dependency
landlock = { version = "0.4", optional = true }
# Async traits
async-trait = "0.1"
# Protobuf encode/decode (Feishu WS long-connection frame codec)
prost = { version = "0.14", default-features = false }
# Memory / persistence
rusqlite = { version = "0.38", features = ["bundled"] }
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
@ -86,6 +90,7 @@ glob = "0.3"
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
futures = "0.3"
regex = "1.10"
hostname = "0.4.2"
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
mail-parser = "0.11.2"
@ -95,7 +100,7 @@ tokio-rustls = "0.26.4"
webpki-roots = "1.0.6"
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query"] }
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] }
tower = { version = "0.5", default-features = false }
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
http-body-util = "0.1"
@ -117,19 +122,28 @@ probe-rs = { version = "0.30", optional = true }
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
pdf-extract = { version = "0.10", optional = true }
# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS
# Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS
[target.'cfg(target_os = "linux")'.dependencies]
rppal = { version = "0.14", optional = true }
landlock = { version = "0.4", optional = true }
[features]
default = ["hardware"]
hardware = ["nusb", "tokio-serial"]
peripheral-rpi = ["rppal"]
# Browser backend feature alias used by cfg(feature = "browser-native")
browser-native = ["dep:fantoccini"]
# Backward-compatible alias for older invocations
fantoccini = ["browser-native"]
# Sandbox feature aliases used by cfg(feature = "sandbox-*")
sandbox-landlock = ["dep:landlock"]
sandbox-bubblewrap = []
# Backward-compatible alias for older invocations
landlock = ["sandbox-landlock"]
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
probe = ["dep:probe-rs"]
# rag-pdf = PDF ingestion for datasheet RAG
rag-pdf = ["dep:pdf-extract"]
[profile.release]
opt-level = "z" # Optimize for size
lto = "thin" # Lower memory use during release builds

211
LICENSE
View file

@ -1,197 +1,28 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
MIT License
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
Copyright (c) 2025 ZeroClaw Labs
1. Definitions.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
================================================================================
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
This product includes software developed by ZeroClaw Labs and contributors:
https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to the Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
Copyright 2025-2026 Argenis Delarosa
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
===============================================================================
This product includes software developed by ZeroClaw Labs and contributors:
https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
See NOTICE file for full contributor attribution.
See NOTICE file for full contributor attribution.

View file

@ -10,14 +10,14 @@
</p>
<p align="center">
<a href="LICENSE"><img src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" alt="License: Apache 2.0" /></a>
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
</p>
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
```
~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything
~3.4MB binary · <10ms startup · 1,017 tests · 23+ providers · 8 traits · Pluggable everything
```
### ✨ Features
@ -132,6 +132,9 @@ cd zeroclaw
cargo build --release --locked
cargo install --path . --force --locked
# Ensure ~/.cargo/bin is in your PATH
export PATH="$HOME/.cargo/bin:$PATH"
# Quick setup (no prompts)
zeroclaw onboard --api-key sk-... --provider openrouter
@ -187,7 +190,7 @@ Every subsystem is a **trait** — swap implementations with a config change, ze
| Subsystem | Trait | Ships with | Extend |
|-----------|-------|------------|--------|
| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
| **AI Models** | `Provider` | 23+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, Astrai, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend |
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability |
@ -287,6 +290,21 @@ rerun channel setup only:
zeroclaw onboard --channels-only
```
### Telegram media replies
Telegram routing now replies to the source **chat ID** from incoming updates (instead of usernames),
which avoids `Bad Request: chat not found` failures.
For non-text replies, ZeroClaw can send Telegram attachments when the assistant includes markers:
- `[IMAGE:<path-or-url>]`
- `[DOCUMENT:<path-or-url>]`
- `[VIDEO:<path-or-url>]`
- `[AUDIO:<path-or-url>]`
- `[VOICE:<path-or-url>]`
Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs.
### WhatsApp Business Cloud API Setup
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
@ -610,7 +628,7 @@ We're building in the open because the best ideas come from everywhere. If you'r
## License
Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
MIT — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
## Contributing
@ -624,7 +642,6 @@ See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR:
- New `Tunnel``src/tunnel/`
- New `Skill``~/.zeroclaw/workspace/skills/<name>/`
---
**ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀

View file

@ -1,4 +1,4 @@
FROM ubuntu:22.04
FROM ubuntu:22.04@sha256:c7eb020043d8fc2ae0793fb35a37bff1cf33f156d4d4b12ccc7f3ef8706c38b1
# Prevent interactive prompts during package installation
ENV DEBIAN_FRONTEND=noninteractive

View file

@ -27,7 +27,7 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
### Optional Repository Automation
- `.github/workflows/labeler.yml` (`PR Labeler`)
- Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>:<component>`)
- Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>: <component>`)
- Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule
- Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`)
- Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`)

View file

@ -244,7 +244,7 @@ Label discipline:
- Path labels identify subsystem ownership quickly.
- Size labels drive batching strategy.
- Risk labels drive review depth (`risk: low/medium/high`).
- Module labels (`<module>:<component>`) improve reviewer routing for integration-specific changes and future newly-added modules.
- Module labels (`<module>: <component>`) improve reviewer routing for integration-specific changes and future newly-added modules.
- `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context.
- `no-stale` is reserved for accepted-but-blocked work.

View file

@ -14,7 +14,7 @@ Use it to reduce review latency without reducing quality.
For every new PR, do a fast intake pass:
1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`).
2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel:*`/`provider:*`/`tool:*`, and contributor tier labels when applicable) are present and plausible.
2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel: *`/`provider: *`/`tool: *`, and contributor tier labels when applicable) are present and plausible.
3. Confirm CI signal status (`CI Required Gate`).
4. Confirm scope is one concern (reject mixed mega-PRs unless justified).
5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied.

View file

@ -12,6 +12,8 @@ use tokio::sync::mpsc;
pub struct ChannelMessage {
pub id: String,
pub sender: String,
/// Channel-specific reply address (e.g. Telegram chat_id, Discord channel_id).
pub reply_to: String,
pub content: String,
pub channel: String,
pub timestamp: u64,
@ -90,9 +92,12 @@ impl Channel for TelegramChannel {
continue;
}
let chat_id = msg["chat"]["id"].to_string();
let channel_msg = ChannelMessage {
id: msg["message_id"].to_string(),
sender,
reply_to: chat_id,
content: msg["text"].as_str().unwrap_or("").to_string(),
channel: "telegram".into(),
timestamp: msg["date"].as_u64().unwrap_or(0),

View file

@ -2,4 +2,10 @@
target = "riscv32imc-esp-espidf"
[target.riscv32imc-esp-espidf]
linker = "ldproxy"
runner = "espflash flash --monitor"
# ESP-IDF 5.x uses 64-bit time_t
rustflags = ["-C", "default-linker-libraries", "--cfg", "espidf_time64"]
[unstable]
build-std = ["std", "panic_abort"]

View file

@ -58,24 +58,22 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "bindgen"
version = "0.63.0"
version = "0.71.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885"
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
dependencies = [
"bitflags 1.3.2",
"bitflags 2.11.0",
"cexpr",
"clang-sys",
"lazy_static",
"lazycell",
"itertools",
"log",
"peeking_take_while",
"prettyplease",
"proc-macro2",
"quote",
"regex",
"rustc-hash",
"shlex",
"syn 1.0.109",
"which",
"syn 2.0.116",
]
[[package]]
@ -374,14 +372,15 @@ checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01"
[[package]]
name = "embassy-sync"
version = "0.5.0"
version = "0.7.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dd938f25c0798db4280fcd8026bf4c2f48789aebf8f77b6e5cf8a7693ba114ec"
checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b"
dependencies = [
"cfg-if",
"critical-section",
"embedded-io-async",
"futures-util",
"futures-core",
"futures-sink",
"heapless",
]
@ -446,16 +445,15 @@ dependencies = [
[[package]]
name = "embedded-svc"
version = "0.27.1"
version = "0.28.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ac6f87e7654f28018340aa55f933803017aefabaa5417820a3b2f808033c7bbc"
checksum = "a7770e30ab55cfbf954c00019522490d6ce26a3334bede05a732ba61010e98e0"
dependencies = [
"defmt 0.3.100",
"embedded-io",
"embedded-io-async",
"enumset",
"heapless",
"no-std-net",
"num_enum",
"serde",
"strum 0.25.0",
@ -463,9 +461,9 @@ dependencies = [
[[package]]
name = "embuild"
version = "0.31.4"
version = "0.33.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4caa4f198bb9152a55c0103efb83fa4edfcbb8625f4c9e94ae8ec8e23827c563"
checksum = "e188ad2bbe82afa841ea4a29880651e53ab86815db036b2cb9f8de3ac32dad75"
dependencies = [
"anyhow",
"bindgen",
@ -475,6 +473,7 @@ dependencies = [
"globwalk",
"home",
"log",
"regex",
"remove_dir_all",
"serde",
"serde_json",
@ -533,9 +532,8 @@ dependencies = [
[[package]]
name = "esp-idf-hal"
version = "0.43.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f7adf3fb19a9ca016cbea1ab8a7b852ac69df8fcde4923c23d3b155efbc42a74"
version = "0.45.2"
source = "git+https://github.com/esp-rs/esp-idf-hal#bc48639bd626c72afc1e25e5d497b5c639161d30"
dependencies = [
"atomic-waker",
"embassy-sync",
@ -552,14 +550,12 @@ dependencies = [
"heapless",
"log",
"nb 1.1.0",
"num_enum",
]
[[package]]
name = "esp-idf-svc"
version = "0.48.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2180642ca122a7fec1ec417a9b1a77aa66aaa067fdf1daae683dd8caba84f26b"
version = "0.51.0"
source = "git+https://github.com/esp-rs/esp-idf-svc#dee202f146c7681e54eabbf118a216fc0195d203"
dependencies = [
"embassy-futures",
"embedded-hal-async",
@ -567,6 +563,7 @@ dependencies = [
"embuild",
"enumset",
"esp-idf-hal",
"futures-io",
"heapless",
"log",
"num_enum",
@ -575,14 +572,13 @@ dependencies = [
[[package]]
name = "esp-idf-sys"
version = "0.34.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e148f97c04ed3e9181a08bcdc9560a515aad939b0ba7f50a0022e294665e0af"
version = "0.36.1"
source = "git+https://github.com/esp-rs/esp-idf-sys#64667a38fb8004e1fc3b032488af6857ca3cd849"
dependencies = [
"anyhow",
"bindgen",
"build-time",
"cargo_metadata",
"cmake",
"const_format",
"embuild",
"envy",
@ -649,21 +645,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
[[package]]
name = "futures-task"
name = "futures-io"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
[[package]]
name = "futures-util"
name = "futures-sink"
version = "0.3.32"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
dependencies = [
"futures-core",
"futures-task",
"pin-project-lite",
]
checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
[[package]]
name = "getrandom"
@ -827,6 +818,15 @@ dependencies = [
"serde_core",
]
[[package]]
name = "itertools"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
dependencies = [
"either",
]
[[package]]
name = "itoa"
version = "1.0.17"
@ -843,18 +843,6 @@ dependencies = [
"wasm-bindgen",
]
[[package]]
name = "lazy_static"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
[[package]]
name = "lazycell"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
[[package]]
name = "leb128fmt"
version = "0.1.0"
@ -945,12 +933,6 @@ dependencies = [
"libc",
]
[[package]]
name = "no-std-net"
version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1bcece43b12349917e096cddfa66107277f123e6c96a5aea78711dc601a47152"
[[package]]
name = "nom"
version = "7.1.3"
@ -1007,18 +989,6 @@ version = "1.21.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
[[package]]
name = "peeking_take_while"
version = "0.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
[[package]]
name = "pin-project-lite"
version = "0.2.16"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
[[package]]
name = "prettyplease"
version = "0.2.37"
@ -1138,9 +1108,9 @@ dependencies = [
[[package]]
name = "rustc-hash"
version = "1.1.0"
version = "2.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
[[package]]
name = "rustix"

View file

@ -14,15 +14,21 @@ edition = "2021"
license = "MIT"
description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial"
[patch.crates-io]
# Use latest esp-rs crates to fix u8/i8 char pointer compatibility with ESP-IDF 5.x
esp-idf-sys = { git = "https://github.com/esp-rs/esp-idf-sys" }
esp-idf-hal = { git = "https://github.com/esp-rs/esp-idf-hal" }
esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
[dependencies]
esp-idf-svc = "0.48"
esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
log = "0.4"
anyhow = "1.0"
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
[build-dependencies]
embuild = "0.31"
embuild = { version = "0.33", features = ["espidf"] }
[profile.release]
opt-level = "s"

View file

@ -2,8 +2,11 @@
Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial.
**New to this?** See [SETUP.md](SETUP.md) for step-by-step commands and troubleshooting.
## Protocol
- **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n`
- **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n`
@ -11,19 +14,44 @@ Commands: `gpio_read`, `gpio_write`.
## Prerequisites
1. **ESP toolchain** (espup):
1. **RISC-V ESP-IDF** (ESP32-C2/C3): Uses nightly Rust with `build-std`.
**Python**: ESP-IDF requires Python 3.10–3.13 (not 3.14). If you have Python 3.14:
```sh
brew install python@3.12
```
**virtualenv** (needed by ESP-IDF tools; PEP 668 workaround on macOS):
```sh
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
```
**Rust tools**:
```sh
cargo install espflash ldproxy
```
The project's `rust-toolchain.toml` pins nightly + rust-src. `esp-idf-sys` downloads ESP-IDF automatically on first build. Use Python 3.12 for the build:
```sh
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
```
2. **Xtensa targets** (ESP32, ESP32-S2, ESP32-S3): Use espup instead:
```sh
cargo install espup espflash
espup install
source ~/export-esp.sh # or ~/export-esp.fish for Fish
source ~/export-esp.sh
```
2. **Target**: ESP32-C3 (RISC-V) by default. Edit `.cargo/config.toml` for other targets (e.g. `xtensa-esp32-espidf` for original ESP32).
Then edit `.cargo/config.toml` to change the target (e.g. `xtensa-esp32-espidf`).
## Build & Flash
```sh
cd firmware/zeroclaw-esp32
# Use Python 3.12 (required if you have 3.14)
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
# Optional: pin MCU (esp32c3 or esp32c2)
export MCU=esp32c3
cargo build --release
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
```

View file

@ -0,0 +1,156 @@
# ESP32 Firmware Setup Guide
Step-by-step setup for building the ZeroClaw ESP32 firmware. Follow this if you run into issues.
## Quick Start (copy-paste)
```sh
# 1. Install Python 3.12 (ESP-IDF needs 3.10–3.13, not 3.14)
brew install python@3.12
# 2. Install virtualenv (PEP 668 workaround on macOS)
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
# 3. Install Rust tools
cargo install espflash ldproxy
# 4. Build
cd firmware/zeroclaw-esp32
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
cargo build --release
# 5. Flash (connect ESP32 via USB)
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
```
---
## Detailed Steps
### 1. Python
ESP-IDF requires Python 3.10–3.13. **Python 3.14 is not supported.**
```sh
brew install python@3.12
```
### 2. virtualenv
ESP-IDF tools need `virtualenv`. On macOS with Homebrew Python, PEP 668 blocks `pip install`; use:
```sh
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
```
### 3. Rust Tools
```sh
cargo install espflash ldproxy
```
- **espflash**: flash and monitor
- **ldproxy**: linker for ESP-IDF builds
### 4. Use Python 3.12 for Builds
Before every build (or add to `~/.zshrc`):
```sh
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
```
### 5. Build
```sh
cd firmware/zeroclaw-esp32
cargo build --release
```
First build downloads and compiles ESP-IDF (~5–15 min).
### 6. Flash
```sh
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
```
---
## Troubleshooting
### "No space left on device"
Free disk space. Common targets:
```sh
# Cargo cache (often 5–20 GB)
rm -rf ~/.cargo/registry/cache ~/.cargo/registry/src
# Unused Rust toolchains
rustup toolchain list
rustup toolchain uninstall <name>
# iOS Simulator runtimes (~35 GB)
xcrun simctl delete unavailable
# Temp files
rm -rf /var/folders/*/T/cargo-install*
```
### "can't find crate for `core`" / "riscv32imc-esp-espidf target may not be installed"
This project uses **nightly Rust with build-std**, not espup. Ensure:
- `rust-toolchain.toml` exists (pins nightly + rust-src)
- You are **not** sourcing `~/export-esp.sh` (that's for Xtensa targets)
- Run `cargo build` from `firmware/zeroclaw-esp32`
### "externally-managed-environment" / "No module named 'virtualenv'"
Install virtualenv with the PEP 668 workaround:
```sh
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
```
### "expected `i64`, found `i32`" (time_t mismatch)
Already fixed in `.cargo/config.toml` with `espidf_time64` for ESP-IDF 5.x. If you use ESP-IDF 4.4, switch to `espidf_time32`.
### "expected `*const u8`, found `*const i8`" (esp-idf-svc)
Already fixed via `[patch.crates-io]` in `Cargo.toml` using esp-rs crates from git. Do not remove the patch.
### 10,000+ files in `git status`
The `.embuild/` directory (ESP-IDF cache) has ~100k+ files. It is in `.gitignore`. If you see them, ensure `.gitignore` contains:
```
.embuild/
```
---
## Optional: Auto-load Python 3.12
Add to `~/.zshrc`:
```sh
# ESP32 firmware build
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
```
---
## Xtensa Targets (ESP32, ESP32-S2, ESP32-S3)
For non-RISC-V chips, use espup instead:
```sh
cargo install espup espflash
espup install
source ~/export-esp.sh
```
Then edit `.cargo/config.toml` to use `xtensa-esp32-espidf` (or the correct target).

View file

@ -0,0 +1,3 @@
[toolchain]
channel = "nightly"
components = ["rust-src"]

View file

@ -6,8 +6,9 @@
//! Protocol: same as STM32 — see docs/hardware-peripherals-design.md
use esp_idf_svc::hal::gpio::PinDriver;
use esp_idf_svc::hal::prelude::*;
use esp_idf_svc::hal::uart::*;
use esp_idf_svc::hal::peripherals::Peripherals;
use esp_idf_svc::hal::uart::{UartConfig, UartDriver};
use esp_idf_svc::hal::units::Hertz;
use log::info;
use serde::{Deserialize, Serialize};
@ -36,9 +37,13 @@ fn main() -> anyhow::Result<()> {
let peripherals = Peripherals::take()?;
let pins = peripherals.pins;
// Create GPIO output drivers first (they take ownership of pins)
let mut gpio2 = PinDriver::output(pins.gpio2)?;
let mut gpio13 = PinDriver::output(pins.gpio13)?;
// UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board
let config = UartConfig::new().baudrate(Hertz(115_200));
let mut uart = UartDriver::new(
let uart = UartDriver::new(
peripherals.uart0,
pins.gpio21,
pins.gpio20,
@ -60,7 +65,8 @@ fn main() -> anyhow::Result<()> {
if b == b'\n' {
if !line.is_empty() {
if let Ok(line_str) = std::str::from_utf8(&line) {
if let Ok(resp) = handle_request(line_str, &peripherals) {
if let Ok(resp) = handle_request(line_str, &mut gpio2, &mut gpio13)
{
let out = serde_json::to_string(&resp).unwrap_or_default();
let _ = uart.write(format!("{}\n", out).as_bytes());
}
@ -80,10 +86,15 @@ fn main() -> anyhow::Result<()> {
}
}
fn handle_request(
fn handle_request<G2, G13>(
line: &str,
peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
) -> anyhow::Result<Response> {
gpio2: &mut PinDriver<'_, G2>,
gpio13: &mut PinDriver<'_, G13>,
) -> anyhow::Result<Response>
where
G2: esp_idf_svc::hal::gpio::OutputMode,
G13: esp_idf_svc::hal::gpio::OutputMode,
{
let req: Request = serde_json::from_str(line.trim())?;
let id = req.id.clone();
@ -98,13 +109,13 @@ fn handle_request(
}
"gpio_read" => {
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
let value = gpio_read(peripherals, pin_num)?;
let value = gpio_read(pin_num)?;
Ok(value.to_string())
}
"gpio_write" => {
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0);
gpio_write(peripherals, pin_num, value)?;
gpio_write(gpio2, gpio13, pin_num, value)?;
Ok("done".into())
}
_ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)),
@ -126,28 +137,26 @@ fn handle_request(
}
}
fn gpio_read(_peripherals: &esp_idf_svc::hal::peripherals::Peripherals, _pin: i32) -> anyhow::Result<u8> {
/// Read a GPIO input level. Currently a stub that always reports 0.
fn gpio_read(_pin: i32) -> anyhow::Result<u8> {
// TODO: implement input pin read — requires storing InputPin drivers per pin
Ok(0)
}
fn gpio_write(
peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
fn gpio_write<G2, G13>(
gpio2: &mut PinDriver<'_, G2>,
gpio13: &mut PinDriver<'_, G13>,
pin: i32,
value: u64,
) -> anyhow::Result<()> {
let pins = peripherals.pins;
let level = value != 0;
) -> anyhow::Result<()>
where
G2: esp_idf_svc::hal::gpio::OutputMode,
G13: esp_idf_svc::hal::gpio::OutputMode,
{
let level = esp_idf_svc::hal::gpio::Level::from(value != 0);
match pin {
2 => {
let mut out = PinDriver::output(pins.gpio2)?;
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
}
13 => {
let mut out = PinDriver::output(pins.gpio13)?;
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
}
2 => gpio2.set_level(level)?,
13 => gpio13.set_level(level)?,
_ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin),
}
Ok(())

View file

@ -0,0 +1,324 @@
#!/usr/bin/env bash
set -euo pipefail
SCRIPT_NAME="$(basename "$0")"
# Print CLI help to stdout. The heredoc text is user-facing output, so
# comments live outside it.
usage() {
cat <<USAGE
Recompute contributor tier labels for historical PRs/issues.
Usage:
./$SCRIPT_NAME [options]
Options:
--repo <owner/repo> Target repository (default: current gh repo)
--kind <both|prs|issues>
Target objects (default: both)
--state <all|open|closed>
State filter for listing objects (default: all)
--limit <N> Limit processed objects after fetch (default: 0 = no limit)
--apply Apply label updates (default is dry-run)
--dry-run Preview only (default)
-h, --help Show this help
Examples:
./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --limit 50
./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --kind prs --state open --apply
USAGE
}
# Print a fatal error message to stderr and abort the script with status 1.
die() {
  printf '[%s] ERROR: %s\n' "$SCRIPT_NAME" "$*" >&2
  exit 1
}
# Abort unless the named command is resolvable on PATH.
require_cmd() {
  command -v "$1" >/dev/null 2>&1 || die "Required command not found: $1"
}
# Percent-encode a single string (for use in gh API URL path segments).
urlencode() {
  local raw="$1"
  jq -nr --arg value "$raw" '$value|@uri'
}
# Map a merged-PR count to a contributor tier label.
# Thresholds: >=50 distinguished, >=20 principal, >=10 experienced,
# >=5 trusted; below 5 prints an empty string (no tier).
select_contributor_tier() {
  local merged="$1"
  local tier=""
  if (( merged >= 50 )); then
    tier="distinguished contributor"
  elif (( merged >= 20 )); then
    tier="principal contributor"
  elif (( merged >= 10 )); then
    tier="experienced contributor"
  elif (( merged >= 5 )); then
    tier="trusted contributor"
  fi
  echo "$tier"
}
# Defaults: preview-only unless --apply is given.
DRY_RUN=1
KIND="both"
STATE="all"
LIMIT=0
REPO=""
# Parse command-line options; any unknown flag is fatal.
while (($# > 0)); do
case "$1" in
--repo)
[[ $# -ge 2 ]] || die "Missing value for --repo"
REPO="$2"
shift 2
;;
--kind)
[[ $# -ge 2 ]] || die "Missing value for --kind"
KIND="$2"
shift 2
;;
--state)
[[ $# -ge 2 ]] || die "Missing value for --state"
STATE="$2"
shift 2
;;
--limit)
[[ $# -ge 2 ]] || die "Missing value for --limit"
LIMIT="$2"
shift 2
;;
--apply)
DRY_RUN=0
shift
;;
--dry-run)
DRY_RUN=1
shift
;;
-h|--help)
usage
exit 0
;;
*)
die "Unknown option: $1"
;;
esac
done
# Validate enum-style options and the numeric limit.
case "$KIND" in
both|prs|issues) ;;
*) die "--kind must be one of: both, prs, issues" ;;
esac
case "$STATE" in
all|open|closed) ;;
*) die "--state must be one of: all, open, closed" ;;
esac
if ! [[ "$LIMIT" =~ ^[0-9]+$ ]]; then
die "--limit must be a non-negative integer"
fi
# Both gh and jq are hard requirements; gh must also be authenticated.
require_cmd gh
require_cmd jq
if ! gh auth status >/dev/null 2>&1; then
die "gh CLI is not authenticated. Run: gh auth login"
fi
# Infer the repo from the current directory when --repo was not given.
if [[ -z "$REPO" ]]; then
REPO="$(gh repo view --json nameWithOwner --jq '.nameWithOwner' 2>/dev/null || true)"
[[ -n "$REPO" ]] || die "Unable to infer repo. Pass --repo <owner/repo>."
fi
echo "[$SCRIPT_NAME] Repo: $REPO"
echo "[$SCRIPT_NAME] Mode: $([[ "$DRY_RUN" -eq 1 ]] && echo "dry-run" || echo "apply")"
echo "[$SCRIPT_NAME] Kind: $KIND | State: $STATE | Limit: $LIMIT"
# Ordered list of tier labels, lowest to highest; used for probing,
# creation, and stripping stale tiers below.
TIERS_JSON='["trusted contributor","experienced contributor","principal contributor","distinguished contributor"]'
TMP_FILES=()
# Delete every temp file registered in TMP_FILES. Hooked on EXIT so the
# files are removed even when the script dies early.
cleanup() {
  ((${#TMP_FILES[@]} == 0)) || rm -f "${TMP_FILES[@]}"
}
trap cleanup EXIT
# Create a temp file, register it for EXIT cleanup, and print its path.
new_tmp_file() {
  local path
  path="$(mktemp)"
  TMP_FILES+=("$path")
  echo "$path"
}
# Collect all target PRs/issues as one JSON object per line (JSONL).
targets_file="$(new_tmp_file)"
# PRs: one record per pull request with author and current labels.
if [[ "$KIND" == "both" || "$KIND" == "prs" ]]; then
gh api --paginate "repos/$REPO/pulls?state=$STATE&per_page=100" \
--jq '.[] | {
kind: "pr",
number: .number,
author: (.user.login // ""),
author_type: (.user.type // ""),
labels: [(.labels[]?.name // empty)]
}' >> "$targets_file"
fi
# Issues: the issues endpoint also returns PRs, so filter those out.
if [[ "$KIND" == "both" || "$KIND" == "issues" ]]; then
gh api --paginate "repos/$REPO/issues?state=$STATE&per_page=100" \
--jq '.[] | select(.pull_request | not) | {
kind: "issue",
number: .number,
author: (.user.login // ""),
author_type: (.user.type // ""),
labels: [(.labels[]?.name // empty)]
}' >> "$targets_file"
fi
# --limit truncates AFTER fetching (PRs come first in the file).
if [[ "$LIMIT" -gt 0 ]]; then
limited_file="$(new_tmp_file)"
head -n "$LIMIT" "$targets_file" > "$limited_file"
mv "$limited_file" "$targets_file"
fi
target_count="$(wc -l < "$targets_file" | tr -d ' ')"
if [[ "$target_count" -eq 0 ]]; then
echo "[$SCRIPT_NAME] No targets found."
exit 0
fi
echo "[$SCRIPT_NAME] Targets fetched: $target_count"
# Ensure tier labels exist (trusted contributor might be new).
# Probe existing tier labels to reuse their color for any label we create.
label_color=""
for probe_label in "experienced contributor" "principal contributor" "distinguished contributor" "trusted contributor"; do
encoded_label="$(urlencode "$probe_label")"
# NOTE(review): the trailing `|| true` makes this assignment always
# succeed, so the outer `if` is always taken; the inner -n check is what
# actually decides whether a color was found.
if color_candidate="$(gh api "repos/$REPO/labels/$encoded_label" --jq '.color' 2>/dev/null || true)"; then
if [[ -n "$color_candidate" ]]; then
label_color="$(echo "$color_candidate" | tr '[:lower:]' '[:upper:]')"
break
fi
fi
done
# Fallback color when no existing tier label could be probed.
[[ -n "$label_color" ]] || label_color="C5D7A2"
# Create any tier label that does not exist yet (dry-run only reports).
while IFS= read -r tier_label; do
[[ -n "$tier_label" ]] || continue
encoded_label="$(urlencode "$tier_label")"
if gh api "repos/$REPO/labels/$encoded_label" >/dev/null 2>&1; then
continue
fi
if [[ "$DRY_RUN" -eq 1 ]]; then
echo "[dry-run] Would create missing label: $tier_label (color=$label_color)"
else
gh api -X POST "repos/$REPO/labels" \
-f name="$tier_label" \
-f color="$label_color" >/dev/null
echo "[apply] Created missing label: $tier_label"
fi
done < <(jq -r '.[]' <<<"$TIERS_JSON")
# Build merged PR count cache by unique human authors.
# One GitHub search per author; cached as "author<TAB>count" lines so the
# main loop never re-queries.
authors_file="$(new_tmp_file)"
jq -r 'select(.author != "" and .author_type != "Bot") | .author' "$targets_file" | sort -u > "$authors_file"
author_count="$(wc -l < "$authors_file" | tr -d ' ')"
echo "[$SCRIPT_NAME] Unique human authors: $author_count"
author_counts_file="$(new_tmp_file)"
while IFS= read -r author; do
[[ -n "$author" ]] || continue
query="repo:$REPO is:pr is:merged author:$author"
# total_count is all we need; per_page=1 keeps the response small.
merged_count="$(gh api search/issues -f q="$query" -F per_page=1 --jq '.total_count' 2>/dev/null || true)"
# Treat any non-numeric (API failure) result as zero merged PRs.
if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
merged_count=0
fi
printf '%s\t%s\n' "$author" "$merged_count" >> "$author_counts_file"
done < "$authors_file"
# Main pass: recompute the desired tier per target and sync labels.
updated=0
unchanged=0
skipped=0
failed=0
while IFS= read -r target_json; do
[[ -n "$target_json" ]] || continue
number="$(jq -r '.number' <<<"$target_json")"
kind="$(jq -r '.kind' <<<"$target_json")"
author="$(jq -r '.author' <<<"$target_json")"
author_type="$(jq -r '.author_type' <<<"$target_json")"
current_labels_json="$(jq -c '.labels // []' <<<"$target_json")"
# Bots and author-less records never receive tier labels.
if [[ -z "$author" || "$author_type" == "Bot" ]]; then
skipped=$((skipped + 1))
continue
fi
# Look up the cached merged-PR count for this author.
merged_count="$(awk -F '\t' -v key="$author" '$1 == key { print $2; exit }' "$author_counts_file")"
if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
merged_count=0
fi
desired_tier="$(select_contributor_tier "$merged_count")"
# Current tier = first label on the object that is one of the tier names.
if ! current_tier="$(jq -r --argjson tiers "$TIERS_JSON" '[.[] | select(. as $label | ($tiers | index($label)) != null)][0] // ""' <<<"$current_labels_json" 2>/dev/null)"; then
echo "[warn] Skipping ${kind} #${number}: cannot parse current labels JSON" >&2
failed=$((failed + 1))
continue
fi
# Next label set: strip ALL tier labels, then add the desired one (if any).
if ! next_labels_json="$(jq -c --arg desired "$desired_tier" --argjson tiers "$TIERS_JSON" '
(. // [])
| map(select(. as $label | ($tiers | index($label)) == null))
| if $desired != "" then . + [$desired] else . end
| unique
' <<<"$current_labels_json" 2>/dev/null)"; then
echo "[warn] Skipping ${kind} #${number}: cannot compute next labels" >&2
failed=$((failed + 1))
continue
fi
# Normalize both sets so comparison ignores order and duplicates.
if ! normalized_current="$(jq -c 'unique | sort' <<<"$current_labels_json" 2>/dev/null)"; then
echo "[warn] Skipping ${kind} #${number}: cannot normalize current labels" >&2
failed=$((failed + 1))
continue
fi
if ! normalized_next="$(jq -c 'unique | sort' <<<"$next_labels_json" 2>/dev/null)"; then
echo "[warn] Skipping ${kind} #${number}: cannot normalize next labels" >&2
failed=$((failed + 1))
continue
fi
if [[ "$normalized_current" == "$normalized_next" ]]; then
unchanged=$((unchanged + 1))
continue
fi
if [[ "$DRY_RUN" -eq 1 ]]; then
echo "[dry-run] ${kind} #${number} @${author} merged=${merged_count} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
updated=$((updated + 1))
continue
fi
# PUT replaces the full label set on the issue/PR in one call.
payload="$(jq -cn --argjson labels "$next_labels_json" '{labels: $labels}')"
if gh api -X PUT "repos/$REPO/issues/$number/labels" --input - <<<"$payload" >/dev/null; then
echo "[apply] Updated ${kind} #${number} @${author} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
updated=$((updated + 1))
else
echo "[apply] FAILED ${kind} #${number}" >&2
failed=$((failed + 1))
fi
done < "$targets_file"
echo ""
echo "[$SCRIPT_NAME] Summary"
echo " Targets: $target_count"
echo " Updated: $updated"
echo " Unchanged: $unchanged"
echo " Skipped: $skipped"
echo " Failed: $failed"
# Non-zero exit when any target could not be processed or updated.
if [[ "$failed" -gt 0 ]]; then
exit 1
fi

View file

@ -251,6 +251,7 @@ impl Agent {
let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
&model_name,
@ -388,7 +389,7 @@ impl Agent {
if self.auto_save {
let _ = self
.memory
.store("user_msg", user_message, MemoryCategory::Conversation)
.store("user_msg", user_message, MemoryCategory::Conversation, None)
.await;
}
@ -447,7 +448,7 @@ impl Agent {
let summary = truncate_with_ellipsis(&final_text, 100);
let _ = self
.memory
.store("assistant_resp", &summary, MemoryCategory::Daily)
.store("assistant_resp", &summary, MemoryCategory::Daily, None)
.await;
}
@ -557,6 +558,7 @@ pub async fn run(
agent.observer.record_event(&ObserverEvent::AgentEnd {
duration: start.elapsed(),
tokens_used: None,
cost_usd: None,
});
Ok(())

View file

@ -7,14 +7,70 @@ use crate::security::SecurityPolicy;
use crate::tools::{self, Tool};
use crate::util::truncate_with_ellipsis;
use anyhow::Result;
use regex::{Regex, RegexSet};
use std::fmt::Write;
use std::io::Write as _;
use std::sync::Arc;
use std::sync::{Arc, LazyLock};
use std::time::Instant;
use uuid::Uuid;
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
const MAX_TOOL_ITERATIONS: usize = 10;
// Case-insensitive key-name patterns that classify a key as sensitive.
// NOTE(review): this set is not referenced by the code visible in this
// file (scrubbing uses SENSITIVE_KV_REGEX below); presumably consumed
// elsewhere — confirm before removing.
static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
RegexSet::new([
r"(?i)token",
r"(?i)api[_-]?key",
r"(?i)password",
r"(?i)secret",
r"(?i)user[_-]?key",
r"(?i)bearer",
r"(?i)credential",
])
.unwrap()
});
// Matches `key : value` / `key = value` pairs where the key looks
// sensitive and the value is at least 8 chars. The value is captured by
// one of three alternates: double-quoted (group 2), single-quoted
// (group 3), or a bare token of [a-zA-Z0-9_\-.] (group 4).
static SENSITIVE_KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap()
});
/// Scrub credentials from tool output to prevent accidental exfiltration.
/// Replaces known credential patterns with a redacted placeholder while
/// preserving a small prefix for context.
///
/// Fixes over the naive version:
/// - the preserved prefix is taken on char boundaries, so a multi-byte
///   UTF-8 value inside quotes cannot cause a byte-slice panic;
/// - the key/value separator is detected from the text *after* the key,
///   so a `:` inside a quoted value no longer flips `=` pairs into the
///   colon output format.
fn scrub_credentials(input: &str) -> String {
    SENSITIVE_KV_REGEX
        .replace_all(input, |caps: &regex::Captures| {
            let full_match = &caps[0];
            let key = &caps[1];
            // Value comes from whichever alternate matched:
            // double-quoted (2), single-quoted (3), or bare (4).
            let val = caps
                .get(2)
                .or(caps.get(3))
                .or(caps.get(4))
                .map(|m| m.as_str())
                .unwrap_or("");
            // Preserve the first 4 *characters* for context, then redact.
            // (`&val[..4]` would panic on a non-ASCII char boundary.)
            let prefix: String = if val.chars().count() > 4 {
                val.chars().take(4).collect()
            } else {
                String::new()
            };
            // The capture starts at the key, so the first ':' or '=' after
            // the key text is the actual separator.
            let sep = full_match[key.len()..]
                .chars()
                .find(|&c| c == ':' || c == '=');
            let quoted = full_match.contains('"');
            match sep {
                Some('=') => {
                    if quoted {
                        format!("{}=\"{}*[REDACTED]\"", key, prefix)
                    } else {
                        format!("{}={}*[REDACTED]", key, prefix)
                    }
                }
                // ':' separator, or (unreachable per the regex) no separator.
                _ => {
                    if quoted {
                        format!("\"{}\": \"{}*[REDACTED]\"", key, prefix)
                    } else {
                        format!("{}: {}*[REDACTED]", key, prefix)
                    }
                }
            }
        })
        .to_string()
}
/// Trigger auto-compaction when non-system message count exceeds this threshold.
const MAX_HISTORY_MESSAGES: usize = 50;
@ -145,7 +201,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new();
// Pull relevant memories for this message
if let Ok(entries) = mem.recall(user_msg, 5).await {
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if !entries.is_empty() {
context.push_str("[Memory context]\n");
for entry in &entries {
@ -436,6 +492,7 @@ struct ParsedToolCall {
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
/// When `silent` is true, suppresses stdout (for channel use).
#[allow(clippy::too_many_arguments)]
pub(crate) async fn agent_turn(
provider: &dyn Provider,
history: &mut Vec<ChatMessage>,
@ -461,6 +518,7 @@ pub(crate) async fn agent_turn(
/// Execute a single turn of the agent loop: send messages, parse tool calls,
/// execute tools, and loop until the LLM produces a final text response.
#[allow(clippy::too_many_arguments)]
pub(crate) async fn run_tool_call_loop(
provider: &dyn Provider,
history: &mut Vec<ChatMessage>,
@ -606,7 +664,7 @@ pub(crate) async fn run_tool_call_loop(
success: r.success,
});
if r.success {
r.output
scrub_credentials(&r.output)
} else {
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
}
@ -749,6 +807,7 @@ pub async fn run(
let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
model_name,
@ -912,7 +971,7 @@ pub async fn run(
if config.memory.auto_save {
let user_key = autosave_memory_key("user_msg");
let _ = mem
.store(&user_key, &msg, MemoryCategory::Conversation)
.store(&user_key, &msg, MemoryCategory::Conversation, None)
.await;
}
@ -955,7 +1014,7 @@ pub async fn run(
let summary = truncate_with_ellipsis(&response, 100);
let response_key = autosave_memory_key("assistant_resp");
let _ = mem
.store(&response_key, &summary, MemoryCategory::Daily)
.store(&response_key, &summary, MemoryCategory::Daily, None)
.await;
}
} else {
@ -978,7 +1037,7 @@ pub async fn run(
if config.memory.auto_save {
let user_key = autosave_memory_key("user_msg");
let _ = mem
.store(&user_key, &msg.content, MemoryCategory::Conversation)
.store(&user_key, &msg.content, MemoryCategory::Conversation, None)
.await;
}
@ -1036,7 +1095,7 @@ pub async fn run(
let summary = truncate_with_ellipsis(&response, 100);
let response_key = autosave_memory_key("assistant_resp");
let _ = mem
.store(&response_key, &summary, MemoryCategory::Daily)
.store(&response_key, &summary, MemoryCategory::Daily, None)
.await;
}
}
@ -1048,6 +1107,7 @@ pub async fn run(
observer.record_event(&ObserverEvent::AgentEnd {
duration,
tokens_used: None,
cost_usd: None,
});
Ok(final_output)
@ -1104,6 +1164,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
let provider: Box<dyn Provider> = providers::create_routed_provider(
provider_name,
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
&config.model_routes,
&model_name,
@ -1217,6 +1278,25 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_scrub_credentials() {
let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
let scrubbed = scrub_credentials(input);
assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
assert!(scrubbed.contains("token: 1234*[REDACTED]"));
assert!(scrubbed.contains("password=\"secr*[REDACTED]\""));
assert!(!scrubbed.contains("abcdef"));
assert!(!scrubbed.contains("secret123456"));
}
#[test]
fn test_scrub_credentials_json() {
let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
let scrubbed = scrub_credentials(input);
assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
assert!(scrubbed.contains("public"));
}
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
use tempfile::TempDir;
@ -1496,16 +1576,16 @@ I will now call the tool with this payload:
let key1 = autosave_memory_key("user_msg");
let key2 = autosave_memory_key("user_msg");
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation)
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None)
.await
.unwrap();
mem.store(&key2, "I'm 45", MemoryCategory::Conversation)
mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 2);
let recalled = mem.recall("45", 5).await.unwrap();
let recalled = mem.recall("45", 5, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
}

View file

@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader {
memory: &dyn Memory,
user_message: &str,
) -> anyhow::Result<String> {
let entries = memory.recall(user_message, self.limit).await?;
let entries = memory.recall(user_message, self.limit, None).await?;
if entries.is_empty() {
return Ok(String::new());
}
@ -61,11 +61,17 @@ mod tests {
_key: &str,
_content: &str,
_category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
async fn recall(
&self,
_query: &str,
limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
if limit == 0 {
return Ok(vec![]);
}
@ -87,6 +93,7 @@ mod tests {
async fn list(
&self,
_category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(vec![])
}

View file

@ -40,6 +40,7 @@ impl Channel for CliChannel {
let msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: "user".to_string(),
reply_target: "user".to_string(),
content: line,
channel: "cli".to_string(),
timestamp: std::time::SystemTime::now()
@ -90,12 +91,14 @@ mod tests {
let msg = ChannelMessage {
id: "test-id".into(),
sender: "user".into(),
reply_target: "user".into(),
content: "hello".into(),
channel: "cli".into(),
timestamp: 1_234_567_890,
};
assert_eq!(msg.id, "test-id");
assert_eq!(msg.sender, "user");
assert_eq!(msg.reply_target, "user");
assert_eq!(msg.content, "hello");
assert_eq!(msg.channel, "cli");
assert_eq!(msg.timestamp, 1_234_567_890);
@ -106,6 +109,7 @@ mod tests {
let msg = ChannelMessage {
id: "id".into(),
sender: "s".into(),
reply_target: "s".into(),
content: "c".into(),
channel: "ch".into(),
timestamp: 0,

View file

@ -7,7 +7,7 @@ use tokio::sync::RwLock;
use tokio_tungstenite::tungstenite::Message;
use uuid::Uuid;
/// DingTalk (钉钉) channel — connects via Stream Mode WebSocket for real-time messages.
/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages.
/// Replies are sent through per-message session webhook URLs.
pub struct DingTalkChannel {
client_id: String,
@ -64,6 +64,18 @@ impl DingTalkChannel {
let gw: GatewayResponse = resp.json().await?;
Ok(gw)
}
fn resolve_reply_target(
sender_id: &str,
conversation_type: &str,
conversation_id: Option<&str>,
) -> String {
if conversation_type == "1" {
sender_id.to_string()
} else {
conversation_id.unwrap_or(sender_id).to_string()
}
}
}
#[async_trait]
@ -193,14 +205,11 @@ impl Channel for DingTalkChannel {
.unwrap_or("1");
// Private chat uses sender ID, group chat uses conversation ID
let chat_id = if conversation_type == "1" {
sender_id.to_string()
} else {
data.get("conversationId")
.and_then(|c| c.as_str())
.unwrap_or(sender_id)
.to_string()
};
let chat_id = Self::resolve_reply_target(
sender_id,
conversation_type,
data.get("conversationId").and_then(|c| c.as_str()),
);
// Store session webhook for later replies
if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
@ -229,6 +238,7 @@ impl Channel for DingTalkChannel {
let channel_msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: sender_id.to_string(),
reply_target: chat_id,
content: content.to_string(),
channel: "dingtalk".to_string(),
timestamp: std::time::SystemTime::now()
@ -305,4 +315,22 @@ client_secret = "secret"
let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
assert!(config.allowed_users.is_empty());
}
#[test]
fn test_resolve_reply_target_private_chat_uses_sender_id() {
let target = DingTalkChannel::resolve_reply_target("staff_1", "1", Some("conv_1"));
assert_eq!(target, "staff_1");
}
#[test]
fn test_resolve_reply_target_group_chat_uses_conversation_id() {
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", Some("conv_1"));
assert_eq!(target, "conv_1");
}
#[test]
fn test_resolve_reply_target_group_chat_falls_back_to_sender_id() {
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", None);
assert_eq!(target, "staff_1");
}
}

View file

@ -11,6 +11,7 @@ pub struct DiscordChannel {
guild_id: Option<String>,
allowed_users: Vec<String>,
listen_to_bots: bool,
mention_only: bool,
client: reqwest::Client,
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}
@ -21,12 +22,14 @@ impl DiscordChannel {
guild_id: Option<String>,
allowed_users: Vec<String>,
listen_to_bots: bool,
mention_only: bool,
) -> Self {
Self {
bot_token,
guild_id,
allowed_users,
listen_to_bots,
mention_only,
client: reqwest::Client::new(),
typing_handle: std::sync::Mutex::new(None),
}
@ -343,6 +346,22 @@ impl Channel for DiscordChannel {
continue;
}
// Skip messages that don't @-mention the bot (when mention_only is enabled)
if self.mention_only {
let mention_tag = format!("<@{bot_user_id}>");
if !content.contains(&mention_tag) {
continue;
}
}
// Strip the bot mention from content so the agent sees clean text
let clean_content = if self.mention_only {
let mention_tag = format!("<@{bot_user_id}>");
content.replace(&mention_tag, "").trim().to_string()
} else {
content.to_string()
};
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
@ -353,6 +372,11 @@ impl Channel for DiscordChannel {
format!("discord_{message_id}")
},
sender: author_id.to_string(),
reply_target: if channel_id.is_empty() {
author_id.to_string()
} else {
channel_id
},
content: content.to_string(),
channel: channel_id,
timestamp: std::time::SystemTime::now()
@ -423,7 +447,7 @@ mod tests {
#[test]
fn discord_channel_name() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert_eq!(ch.name(), "discord");
}
@ -444,21 +468,27 @@ mod tests {
#[test]
fn empty_allowlist_denies_everyone() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert!(!ch.is_user_allowed("12345"));
assert!(!ch.is_user_allowed("anyone"));
}
#[test]
fn wildcard_allows_everyone() {
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false);
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false);
assert!(ch.is_user_allowed("12345"));
assert!(ch.is_user_allowed("anyone"));
}
#[test]
fn specific_allowlist_filters() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false);
let ch = DiscordChannel::new(
"fake".into(),
None,
vec!["111".into(), "222".into()],
false,
false,
);
assert!(ch.is_user_allowed("111"));
assert!(ch.is_user_allowed("222"));
assert!(!ch.is_user_allowed("333"));
@ -467,7 +497,7 @@ mod tests {
#[test]
fn allowlist_is_exact_match_not_substring() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
assert!(!ch.is_user_allowed("1111"));
assert!(!ch.is_user_allowed("11"));
assert!(!ch.is_user_allowed("0111"));
@ -475,20 +505,26 @@ mod tests {
#[test]
fn allowlist_empty_string_user_id() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
assert!(!ch.is_user_allowed(""));
}
#[test]
fn allowlist_with_wildcard_and_specific() {
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false);
let ch = DiscordChannel::new(
"fake".into(),
None,
vec!["111".into(), "*".into()],
false,
false,
);
assert!(ch.is_user_allowed("111"));
assert!(ch.is_user_allowed("anyone_else"));
}
#[test]
fn allowlist_case_sensitive() {
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false);
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false);
assert!(ch.is_user_allowed("ABC"));
assert!(!ch.is_user_allowed("abc"));
assert!(!ch.is_user_allowed("Abc"));
@ -663,14 +699,14 @@ mod tests {
#[test]
fn typing_handle_starts_as_none() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let guard = ch.typing_handle.lock().unwrap();
assert!(guard.is_none());
}
#[tokio::test]
async fn start_typing_sets_handle() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("123456").await;
let guard = ch.typing_handle.lock().unwrap();
assert!(guard.is_some());
@ -678,7 +714,7 @@ mod tests {
#[tokio::test]
async fn stop_typing_clears_handle() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("123456").await;
let _ = ch.stop_typing("123456").await;
let guard = ch.typing_handle.lock().unwrap();
@ -687,14 +723,14 @@ mod tests {
#[tokio::test]
async fn stop_typing_is_idempotent() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
assert!(ch.stop_typing("123456").await.is_ok());
assert!(ch.stop_typing("123456").await.is_ok());
}
#[tokio::test]
async fn start_typing_replaces_existing_task() {
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
let _ = ch.start_typing("111").await;
let _ = ch.start_typing("222").await;
let guard = ch.typing_handle.lock().unwrap();

View file

@ -10,6 +10,7 @@
use anyhow::{anyhow, Result};
use async_trait::async_trait;
use lettre::message::SinglePart;
use lettre::transport::smtp::authentication::Credentials;
use lettre::{Message, SmtpTransport, Transport};
use mail_parser::{MessageParser, MimeHeaders};
@ -39,7 +40,7 @@ pub struct EmailConfig {
pub imap_folder: String,
/// SMTP server hostname
pub smtp_host: String,
/// SMTP server port (default: 587 for STARTTLS)
/// SMTP server port (default: 465 for TLS)
#[serde(default = "default_smtp_port")]
pub smtp_port: u16,
/// Use TLS for SMTP (default: true)
@ -63,7 +64,7 @@ fn default_imap_port() -> u16 {
993
}
fn default_smtp_port() -> u16 {
587
465
}
fn default_imap_folder() -> String {
"INBOX".into()
@ -389,7 +390,7 @@ impl Channel for EmailChannel {
.from(self.config.from_address.parse()?)
.to(recipient.parse()?)
.subject(subject)
.body(body.to_string())?;
.singlepart(SinglePart::plain(body.to_string()))?;
let transport = self.create_smtp_transport()?;
transport.send(&email)?;
@ -427,6 +428,7 @@ impl Channel for EmailChannel {
} // MutexGuard dropped before await
let msg = ChannelMessage {
id,
reply_target: sender.clone(),
sender,
content,
channel: "email".to_string(),
@ -464,6 +466,18 @@ impl Channel for EmailChannel {
mod tests {
use super::*;
#[test]
fn default_smtp_port_uses_tls_port() {
assert_eq!(default_smtp_port(), 465);
}
#[test]
fn email_config_default_uses_tls_smtp_defaults() {
let config = EmailConfig::default();
assert_eq!(config.smtp_port, 465);
assert!(config.smtp_tls);
}
#[test]
fn build_imap_tls_config_succeeds() {
let tls_config =
@ -504,7 +518,7 @@ mod tests {
assert_eq!(config.imap_port, 993);
assert_eq!(config.imap_folder, "INBOX");
assert_eq!(config.smtp_host, "");
assert_eq!(config.smtp_port, 587);
assert_eq!(config.smtp_port, 465);
assert!(config.smtp_tls);
assert_eq!(config.username, "");
assert_eq!(config.password, "");
@ -765,8 +779,8 @@ mod tests {
}
#[test]
fn default_smtp_port_returns_587() {
assert_eq!(default_smtp_port(), 587);
fn default_smtp_port_returns_465() {
assert_eq!(default_smtp_port(), 465);
}
#[test]
@ -822,7 +836,7 @@ mod tests {
let config: EmailConfig = serde_json::from_str(json).unwrap();
assert_eq!(config.imap_port, 993); // default
assert_eq!(config.smtp_port, 587); // default
assert_eq!(config.smtp_port, 465); // default
assert!(config.smtp_tls); // default
assert_eq!(config.poll_interval_secs, 60); // default
}

View file

@ -172,6 +172,7 @@ end tell"#
let msg = ChannelMessage {
id: rowid.to_string(),
sender: sender.clone(),
reply_target: sender.clone(),
content: text,
channel: "imessage".to_string(),
timestamp: std::time::SystemTime::now()

View file

@ -220,32 +220,34 @@ fn split_message(message: &str, max_bytes: usize) -> Vec<String> {
chunks
}
/// Configuration for constructing an `IrcChannel`.
pub struct IrcChannelConfig {
pub server: String,
pub port: u16,
pub nickname: String,
pub username: Option<String>,
pub channels: Vec<String>,
pub allowed_users: Vec<String>,
pub server_password: Option<String>,
pub nickserv_password: Option<String>,
pub sasl_password: Option<String>,
pub verify_tls: bool,
}
impl IrcChannel {
#[allow(clippy::too_many_arguments)]
pub fn new(
server: String,
port: u16,
nickname: String,
username: Option<String>,
channels: Vec<String>,
allowed_users: Vec<String>,
server_password: Option<String>,
nickserv_password: Option<String>,
sasl_password: Option<String>,
verify_tls: bool,
) -> Self {
let username = username.unwrap_or_else(|| nickname.clone());
pub fn new(cfg: IrcChannelConfig) -> Self {
let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone());
Self {
server,
port,
nickname,
server: cfg.server,
port: cfg.port,
nickname: cfg.nickname,
username,
channels,
allowed_users,
server_password,
nickserv_password,
sasl_password,
verify_tls,
channels: cfg.channels,
allowed_users: cfg.allowed_users,
server_password: cfg.server_password,
nickserv_password: cfg.nickserv_password,
sasl_password: cfg.sasl_password,
verify_tls: cfg.verify_tls,
writer: Arc::new(Mutex::new(None)),
}
}
@ -563,7 +565,8 @@ impl Channel for IrcChannel {
let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed);
let channel_msg = ChannelMessage {
id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()),
sender: reply_to,
sender: sender_nick.to_string(),
reply_target: reply_to,
content,
channel: "irc".to_string(),
timestamp: std::time::SystemTime::now()
@ -807,18 +810,18 @@ mod tests {
#[test]
fn specific_user_allowed() {
let ch = IrcChannel::new(
"irc.test".into(),
6697,
"bot".into(),
None,
vec![],
vec!["alice".into(), "bob".into()],
None,
None,
None,
true,
);
let ch = IrcChannel::new(IrcChannelConfig {
server: "irc.test".into(),
port: 6697,
nickname: "bot".into(),
username: None,
channels: vec![],
allowed_users: vec!["alice".into(), "bob".into()],
server_password: None,
nickserv_password: None,
sasl_password: None,
verify_tls: true,
});
assert!(ch.is_user_allowed("alice"));
assert!(ch.is_user_allowed("bob"));
assert!(!ch.is_user_allowed("eve"));
@ -826,18 +829,18 @@ mod tests {
#[test]
fn allowlist_case_insensitive() {
let ch = IrcChannel::new(
"irc.test".into(),
6697,
"bot".into(),
None,
vec![],
vec!["Alice".into()],
None,
None,
None,
true,
);
let ch = IrcChannel::new(IrcChannelConfig {
server: "irc.test".into(),
port: 6697,
nickname: "bot".into(),
username: None,
channels: vec![],
allowed_users: vec!["Alice".into()],
server_password: None,
nickserv_password: None,
sasl_password: None,
verify_tls: true,
});
assert!(ch.is_user_allowed("alice"));
assert!(ch.is_user_allowed("ALICE"));
assert!(ch.is_user_allowed("Alice"));
@ -845,18 +848,18 @@ mod tests {
#[test]
fn empty_allowlist_denies_all() {
let ch = IrcChannel::new(
"irc.test".into(),
6697,
"bot".into(),
None,
vec![],
vec![],
None,
None,
None,
true,
);
let ch = IrcChannel::new(IrcChannelConfig {
server: "irc.test".into(),
port: 6697,
nickname: "bot".into(),
username: None,
channels: vec![],
allowed_users: vec![],
server_password: None,
nickserv_password: None,
sasl_password: None,
verify_tls: true,
});
assert!(!ch.is_user_allowed("anyone"));
}
@ -864,35 +867,35 @@ mod tests {
#[test]
fn new_defaults_username_to_nickname() {
let ch = IrcChannel::new(
"irc.test".into(),
6697,
"mybot".into(),
None,
vec![],
vec![],
None,
None,
None,
true,
);
let ch = IrcChannel::new(IrcChannelConfig {
server: "irc.test".into(),
port: 6697,
nickname: "mybot".into(),
username: None,
channels: vec![],
allowed_users: vec![],
server_password: None,
nickserv_password: None,
sasl_password: None,
verify_tls: true,
});
assert_eq!(ch.username, "mybot");
}
#[test]
fn new_uses_explicit_username() {
let ch = IrcChannel::new(
"irc.test".into(),
6697,
"mybot".into(),
Some("customuser".into()),
vec![],
vec![],
None,
None,
None,
true,
);
let ch = IrcChannel::new(IrcChannelConfig {
server: "irc.test".into(),
port: 6697,
nickname: "mybot".into(),
username: Some("customuser".into()),
channels: vec![],
allowed_users: vec![],
server_password: None,
nickserv_password: None,
sasl_password: None,
verify_tls: true,
});
assert_eq!(ch.username, "customuser");
assert_eq!(ch.nickname, "mybot");
}
@ -905,18 +908,18 @@ mod tests {
#[test]
fn new_stores_all_fields() {
let ch = IrcChannel::new(
"irc.example.com".into(),
6697,
"zcbot".into(),
Some("zeroclaw".into()),
vec!["#test".into()],
vec!["alice".into()],
Some("serverpass".into()),
Some("nspass".into()),
Some("saslpass".into()),
false,
);
let ch = IrcChannel::new(IrcChannelConfig {
server: "irc.example.com".into(),
port: 6697,
nickname: "zcbot".into(),
username: Some("zeroclaw".into()),
channels: vec!["#test".into()],
allowed_users: vec!["alice".into()],
server_password: Some("serverpass".into()),
nickserv_password: Some("nspass".into()),
sasl_password: Some("saslpass".into()),
verify_tls: false,
});
assert_eq!(ch.server, "irc.example.com");
assert_eq!(ch.port, 6697);
assert_eq!(ch.nickname, "zcbot");
@ -995,17 +998,17 @@ nickname = "bot"
// ── Helpers ─────────────────────────────────────────────
fn make_channel() -> IrcChannel {
IrcChannel::new(
"irc.example.com".into(),
6697,
"zcbot".into(),
None,
vec!["#zeroclaw".into()],
vec!["*".into()],
None,
None,
None,
true,
)
IrcChannel::new(IrcChannelConfig {
server: "irc.example.com".into(),
port: 6697,
nickname: "zcbot".into(),
username: None,
channels: vec!["#zeroclaw".into()],
allowed_users: vec!["*".into()],
server_password: None,
nickserv_password: None,
sasl_password: None,
verify_tls: true,
})
}
}

View file

@ -1,21 +1,152 @@
use super::traits::{Channel, ChannelMessage};
use async_trait::async_trait;
use futures_util::{SinkExt, StreamExt};
use prost::Message as ProstMessage;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tokio_tungstenite::tungstenite::Message as WsMsg;
use uuid::Uuid;
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn";
const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis";
const LARK_WS_BASE_URL: &str = "https://open.larksuite.com";
/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API
// ─────────────────────────────────────────────────────────────────────────────
// Feishu WebSocket long-connection: pbbp2.proto frame codec
// ─────────────────────────────────────────────────────────────────────────────
#[derive(Clone, PartialEq, prost::Message)]
struct PbHeader {
#[prost(string, tag = "1")]
pub key: String,
#[prost(string, tag = "2")]
pub value: String,
}
/// Feishu WS frame (pbbp2.proto).
/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
#[derive(Clone, PartialEq, prost::Message)]
struct PbFrame {
#[prost(uint64, tag = "1")]
pub seq_id: u64,
#[prost(uint64, tag = "2")]
pub log_id: u64,
#[prost(int32, tag = "3")]
pub service: i32,
#[prost(int32, tag = "4")]
pub method: i32,
#[prost(message, repeated, tag = "5")]
pub headers: Vec<PbHeader>,
#[prost(bytes = "vec", optional, tag = "8")]
pub payload: Option<Vec<u8>>,
}
impl PbFrame {
fn header_value<'a>(&'a self, key: &str) -> &'a str {
self.headers
.iter()
.find(|h| h.key == key)
.map(|h| h.value.as_str())
.unwrap_or("")
}
}
/// Server-sent client config (parsed from pong payload)
#[derive(Debug, serde::Deserialize, Default, Clone)]
struct WsClientConfig {
#[serde(rename = "PingInterval")]
ping_interval: Option<u64>,
}
/// POST /callback/ws/endpoint response
#[derive(Debug, serde::Deserialize)]
struct WsEndpointResp {
code: i32,
#[serde(default)]
msg: Option<String>,
#[serde(default)]
data: Option<WsEndpoint>,
}
#[derive(Debug, serde::Deserialize)]
struct WsEndpoint {
#[serde(rename = "URL")]
url: String,
#[serde(rename = "ClientConfig")]
client_config: Option<WsClientConfig>,
}
/// LarkEvent envelope (method=1 / type=event payload)
#[derive(Debug, serde::Deserialize)]
struct LarkEvent {
header: LarkEventHeader,
event: serde_json::Value,
}
#[derive(Debug, serde::Deserialize)]
struct LarkEventHeader {
event_type: String,
#[allow(dead_code)]
event_id: String,
}
#[derive(Debug, serde::Deserialize)]
struct MsgReceivePayload {
sender: LarkSender,
message: LarkMessage,
}
#[derive(Debug, serde::Deserialize)]
struct LarkSender {
sender_id: LarkSenderId,
#[serde(default)]
sender_type: String,
}
#[derive(Debug, serde::Deserialize, Default)]
struct LarkSenderId {
open_id: Option<String>,
}
#[derive(Debug, serde::Deserialize)]
struct LarkMessage {
message_id: String,
chat_id: String,
chat_type: String,
message_type: String,
#[serde(default)]
content: String,
#[serde(default)]
mentions: Vec<serde_json::Value>,
}
/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
/// If no binary frame (pong or event) is received within this window, reconnect.
const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
/// Lark/Feishu channel.
///
/// Supports two receive modes (configured via `receive_mode` in config):
/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed.
/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint.
pub struct LarkChannel {
app_id: String,
app_secret: String,
verification_token: String,
port: u16,
port: Option<u16>,
allowed_users: Vec<String>,
/// When true, use Feishu (CN) endpoints; when false, use Lark (international).
use_feishu: bool,
/// How to receive events: WebSocket long-connection or HTTP webhook.
receive_mode: crate::config::schema::LarkReceiveMode,
client: reqwest::Client,
/// Cached tenant access token
tenant_token: Arc<RwLock<Option<String>>>,
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
}
impl LarkChannel {
@ -23,7 +154,7 @@ impl LarkChannel {
app_id: String,
app_secret: String,
verification_token: String,
port: u16,
port: Option<u16>,
allowed_users: Vec<String>,
) -> Self {
Self {
@ -32,11 +163,310 @@ impl LarkChannel {
verification_token,
port,
allowed_users,
use_feishu: true,
receive_mode: crate::config::schema::LarkReceiveMode::default(),
client: reqwest::Client::new(),
tenant_token: Arc::new(RwLock::new(None)),
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`).
pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self {
let mut ch = Self::new(
config.app_id.clone(),
config.app_secret.clone(),
config.verification_token.clone().unwrap_or_default(),
config.port,
config.allowed_users.clone(),
);
ch.use_feishu = config.use_feishu;
ch.receive_mode = config.receive_mode.clone();
ch
}
fn api_base(&self) -> &'static str {
if self.use_feishu {
FEISHU_BASE_URL
} else {
LARK_BASE_URL
}
}
fn ws_base(&self) -> &'static str {
if self.use_feishu {
FEISHU_WS_BASE_URL
} else {
LARK_WS_BASE_URL
}
}
fn tenant_access_token_url(&self) -> String {
format!("{}/auth/v3/tenant_access_token/internal", self.api_base())
}
fn send_message_url(&self) -> String {
format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base())
}
/// POST /callback/ws/endpoint → (wss_url, client_config)
async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> {
let resp = self
.client
.post(format!("{}/callback/ws/endpoint", self.ws_base()))
.header("locale", if self.use_feishu { "zh" } else { "en" })
.json(&serde_json::json!({
"AppID": self.app_id,
"AppSecret": self.app_secret,
}))
.send()
.await?
.json::<WsEndpointResp>()
.await?;
if resp.code != 0 {
anyhow::bail!(
"Lark WS endpoint failed: code={} msg={}",
resp.code,
resp.msg.as_deref().unwrap_or("(none)")
);
}
let ep = resp
.data
.ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?;
Ok((ep.url, ep.client_config.unwrap_or_default()))
}
/// WS long-connection event loop. Returns Ok(()) when the connection closes
/// (the caller reconnects).
#[allow(clippy::too_many_lines)]
async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
let (wss_url, client_config) = self.get_ws_endpoint().await?;
let service_id = wss_url
.split('?')
.nth(1)
.and_then(|qs| {
qs.split('&')
.find(|kv| kv.starts_with("service_id="))
.and_then(|kv| kv.split('=').nth(1))
.and_then(|v| v.parse::<i32>().ok())
})
.unwrap_or(0);
tracing::info!("Lark: connecting to {wss_url}");
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?;
let (mut write, mut read) = ws_stream.split();
tracing::info!("Lark: WS connected (service_id={service_id})");
let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10);
let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
let mut timeout_check = tokio::time::interval(Duration::from_secs(10));
hb_interval.tick().await; // consume immediate tick
let mut seq: u64 = 0;
let mut last_recv = Instant::now();
// Send initial ping immediately (like the official SDK) so the server
// starts responding with pongs and we can calibrate the ping_interval.
seq = seq.wrapping_add(1);
let initial_ping = PbFrame {
seq_id: seq,
log_id: 0,
service: service_id,
method: 0,
headers: vec![PbHeader {
key: "type".into(),
value: "ping".into(),
}],
payload: None,
};
if write
.send(WsMsg::Binary(initial_ping.encode_to_vec()))
.await
.is_err()
{
anyhow::bail!("Lark: initial ping failed");
}
// message_id → (fragment_slots, created_at) for multi-part reassembly
type FragEntry = (Vec<Option<Vec<u8>>>, Instant);
let mut frag_cache: HashMap<String, FragEntry> = HashMap::new();
loop {
tokio::select! {
biased;
_ = hb_interval.tick() => {
seq = seq.wrapping_add(1);
let ping = PbFrame {
seq_id: seq, log_id: 0, service: service_id, method: 0,
headers: vec![PbHeader { key: "type".into(), value: "ping".into() }],
payload: None,
};
if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() {
tracing::warn!("Lark: ping failed, reconnecting");
break;
}
// GC stale fragments > 5 min
let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now());
frag_cache.retain(|_, (_, ts)| *ts > cutoff);
}
_ = timeout_check.tick() => {
if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
tracing::warn!("Lark: heartbeat timeout, reconnecting");
break;
}
}
msg = read.next() => {
let raw = match msg {
Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b }
Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; }
Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; }
Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; }
_ => continue,
};
let frame = match PbFrame::decode(&raw[..]) {
Ok(f) => f,
Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; }
};
// CONTROL frame
if frame.method == 0 {
if frame.header_value("type") == "pong" {
if let Some(p) = &frame.payload {
if let Ok(cfg) = serde_json::from_slice::<WsClientConfig>(p) {
if let Some(secs) = cfg.ping_interval {
let secs = secs.max(10);
if secs != ping_secs {
ping_secs = secs;
hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
tracing::info!("Lark: ping_interval → {ping_secs}s");
}
}
}
}
}
continue;
}
// DATA frame
let msg_type = frame.header_value("type").to_string();
let msg_id = frame.header_value("message_id").to_string();
let sum = frame.header_value("sum").parse::<usize>().unwrap_or(1);
let seq_num = frame.header_value("seq").parse::<usize>().unwrap_or(0);
// ACK immediately (Feishu requires within 3 s)
{
let mut ack = frame.clone();
ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() });
let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await;
}
// Fragment reassembly
let sum = if sum == 0 { 1 } else { sum };
let payload: Vec<u8> = if sum == 1 || msg_id.is_empty() || seq_num >= sum {
frame.payload.clone().unwrap_or_default()
} else {
let entry = frag_cache.entry(msg_id.clone())
.or_insert_with(|| (vec![None; sum], Instant::now()));
if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); }
entry.0[seq_num] = frame.payload.clone();
if entry.0.iter().all(|s| s.is_some()) {
let full: Vec<u8> = entry.0.iter()
.flat_map(|s| s.as_deref().unwrap_or(&[]))
.copied().collect();
frag_cache.remove(&msg_id);
full
} else { continue; }
};
if msg_type != "event" { continue; }
let event: LarkEvent = match serde_json::from_slice(&payload) {
Ok(e) => e,
Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; }
};
if event.header.event_type != "im.message.receive_v1" { continue; }
let recv: MsgReceivePayload = match serde_json::from_value(event.event) {
Ok(r) => r,
Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; }
};
if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; }
let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or("");
if !self.is_user_allowed(sender_open_id) {
tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)");
continue;
}
let lark_msg = &recv.message;
// Dedup
{
let now = Instant::now();
let mut seen = self.ws_seen_ids.write().await;
// GC
seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
if seen.contains_key(&lark_msg.message_id) {
tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
continue;
}
seen.insert(lark_msg.message_id.clone(), now);
}
// Decode content by type (mirrors clawdbot-feishu parsing)
let text = match lark_msg.message_type.as_str() {
"text" => {
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
Ok(v) => v,
Err(_) => continue,
};
match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) {
Some(t) => t.to_string(),
None => continue,
}
}
"post" => match parse_post_content(&lark_msg.content) {
Some(t) => t,
None => continue,
},
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
};
// Strip @_user_N placeholders
let text = strip_at_placeholders(&text);
let text = text.trim().to_string();
if text.is_empty() { continue; }
// Group-chat: only respond when explicitly @-mentioned
if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) {
continue;
}
let channel_msg = ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: lark_msg.chat_id.clone(),
reply_target: lark_msg.chat_id.clone(),
content: text,
channel: "lark".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
tracing::debug!("Lark WS: message in {}", lark_msg.chat_id);
if tx.send(channel_msg).await.is_err() { break; }
}
}
}
Ok(())
}
/// Check if a user open_id is allowed
fn is_user_allowed(&self, open_id: &str) -> bool {
self.allowed_users.iter().any(|u| u == "*" || u == open_id)
@ -52,7 +482,7 @@ impl LarkChannel {
}
}
let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal");
let url = self.tenant_access_token_url();
let body = serde_json::json!({
"app_id": self.app_id,
"app_secret": self.app_secret,
@ -127,31 +557,41 @@ impl LarkChannel {
return messages;
}
// Extract message content (text only)
// Extract message content (text and post supported)
let msg_type = event
.pointer("/message/message_type")
.and_then(|t| t.as_str())
.unwrap_or("");
if msg_type != "text" {
tracing::debug!("Lark: skipping non-text message type: {msg_type}");
return messages;
}
let content_str = event
.pointer("/message/content")
.and_then(|c| c.as_str())
.unwrap_or("");
// content is a JSON string like "{\"text\":\"hello\"}"
let text = serde_json::from_str::<serde_json::Value>(content_str)
let text: String = match msg_type {
"text" => {
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
.ok()
.and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from))
.unwrap_or_default();
if text.is_empty() {
.and_then(|v| {
v.get("text")
.and_then(|t| t.as_str())
.filter(|s| !s.is_empty())
.map(String::from)
});
match extracted {
Some(t) => t,
None => return messages,
}
}
"post" => match parse_post_content(content_str) {
Some(t) => t,
None => return messages,
},
_ => {
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
return messages;
}
};
let timestamp = event
.pointer("/message/create_time")
@ -174,6 +614,7 @@ impl LarkChannel {
messages.push(ChannelMessage {
id: Uuid::new_v4().to_string(),
sender: chat_id.to_string(),
reply_target: chat_id.to_string(),
content: text,
channel: "lark".to_string(),
timestamp,
@ -191,7 +632,7 @@ impl Channel for LarkChannel {
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
let token = self.get_tenant_access_token().await?;
let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id");
let url = self.send_message_url();
let content = serde_json::json!({ "text": message }).to_string();
let body = serde_json::json!({
@ -238,6 +679,25 @@ impl Channel for LarkChannel {
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
use crate::config::schema::LarkReceiveMode;
match self.receive_mode {
LarkReceiveMode::Websocket => self.listen_ws(tx).await,
LarkReceiveMode::Webhook => self.listen_http(tx).await,
}
}
async fn health_check(&self) -> bool {
self.get_tenant_access_token().await.is_ok()
}
}
impl LarkChannel {
/// HTTP callback server (legacy — requires a public endpoint).
/// Use `listen()` (WS long-connection) for new deployments.
pub async fn listen_http(
&self,
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
) -> anyhow::Result<()> {
use axum::{extract::State, routing::post, Json, Router};
#[derive(Clone)]
@ -282,13 +742,17 @@ impl Channel for LarkChannel {
(StatusCode::OK, "ok").into_response()
}
let port = self.port.ok_or_else(|| {
anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]")
})?;
let state = AppState {
verification_token: self.verification_token.clone(),
channel: Arc::new(LarkChannel::new(
self.app_id.clone(),
self.app_secret.clone(),
self.verification_token.clone(),
self.port,
None,
self.allowed_users.clone(),
)),
tx,
@ -298,7 +762,7 @@ impl Channel for LarkChannel {
.route("/lark", post(handle_event))
.with_state(state);
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
tracing::info!("Lark event callback server listening on {addr}");
let listener = tokio::net::TcpListener::bind(addr).await?;
@ -306,10 +770,110 @@ impl Channel for LarkChannel {
Ok(())
}
}
async fn health_check(&self) -> bool {
self.get_tenant_access_token().await.is_ok()
// ─────────────────────────────────────────────────────────────────────────────
// WS helper functions
// ─────────────────────────────────────────────────────────────────────────────
/// Flatten a Feishu `post` rich-text message to plain text.
///
/// Returns `None` when the content cannot be parsed or yields no usable text,
/// so callers can simply `continue` rather than forwarding a meaningless
/// placeholder string to the agent.
fn parse_post_content(content: &str) -> Option<String> {
    // `content` is itself a JSON string, keyed by locale (e.g. "zh_cn").
    let parsed = serde_json::from_str::<serde_json::Value>(content).ok()?;
    // Locale preference: zh_cn, then en_us, then the first object-valued
    // entry found in the map (order is whatever the JSON map yields).
    let locale = parsed
        .get("zh_cn")
        .or_else(|| parsed.get("en_us"))
        .or_else(|| {
            parsed
                .as_object()
                .and_then(|m| m.values().find(|v| v.is_object()))
        })?;
    let mut text = String::new();
    // Optional post title, separated from the body by a blank line.
    if let Some(title) = locale
        .get("title")
        .and_then(|t| t.as_str())
        .filter(|s| !s.is_empty())
    {
        text.push_str(title);
        text.push_str("\n\n");
    }
    // The body is an array of paragraphs, each an array of inline elements
    // tagged "text", "a" (link), "at" (mention), etc.
    if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) {
        for para in paragraphs {
            if let Some(elements) = para.as_array() {
                for el in elements {
                    match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") {
                        // Plain text run.
                        "text" => {
                            if let Some(t) = el.get("text").and_then(|t| t.as_str()) {
                                text.push_str(t);
                            }
                        }
                        // Link: prefer its visible text, fall back to the href.
                        "a" => {
                            text.push_str(
                                el.get("text")
                                    .and_then(|t| t.as_str())
                                    .filter(|s| !s.is_empty())
                                    .or_else(|| el.get("href").and_then(|h| h.as_str()))
                                    .unwrap_or(""),
                            );
                        }
                        // Mention: render as "@name", falling back to the user
                        // id, then the literal "user".
                        "at" => {
                            let n = el
                                .get("user_name")
                                .and_then(|n| n.as_str())
                                .or_else(|| el.get("user_id").and_then(|i| i.as_str()))
                                .unwrap_or("user");
                            text.push('@');
                            text.push_str(n);
                        }
                        // Unsupported tags (images, media, ...) are dropped.
                        _ => {}
                    }
                }
                // Newline at each paragraph boundary.
                text.push('\n');
            }
        }
    }
    let result = text.trim().to_string();
    // Empty output means nothing usable was extracted.
    if result.is_empty() {
        None
    } else {
        Some(result)
    }
}
/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats.
///
/// A placeholder is `@` followed by `_user_` and one or more ASCII digits;
/// a single trailing space (the separator Feishu inserts after a mention)
/// is removed along with it. Text that merely resembles a placeholder
/// (e.g. `@_user_name`, with no digits) is left intact.
///
/// Fixes two defects in the previous version: `for _ in 0..=skip` consumed
/// one character *past* the placeholder (so `"@_user_12abc"` lost the `a`),
/// and `@_user_<non-digit>` text was partially stripped.
fn strip_at_placeholders(text: &str) -> String {
    let mut result = String::with_capacity(text.len());
    let mut chars = text.char_indices().peekable();
    while let Some((_, ch)) = chars.next() {
        if ch == '@' {
            // Look ahead without consuming: everything after the '@'.
            let rest: String = chars.clone().map(|(_, c)| c).collect();
            if let Some(after) = rest.strip_prefix("_user_") {
                let digits = after.chars().take_while(|c| c.is_ascii_digit()).count();
                if digits > 0 {
                    // Consume exactly `_user_` plus the digits (the '@' was
                    // already consumed above) — no extra character.
                    let skip = "_user_".len() + digits;
                    for _ in 0..skip {
                        chars.next();
                    }
                    // Swallow one following space, if present.
                    if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) {
                        chars.next();
                    }
                    continue;
                }
            }
        }
        result.push(ch);
    }
    result
}
/// In group chats, only respond when the bot is explicitly @-mentioned.
///
/// Any entry in the message's `mentions` array counts as a mention
/// (presumably Feishu only populates it when the bot is @-ed — verify
/// against the event payload if group behavior looks wrong).
fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool {
    mentions.first().is_some()
}
#[cfg(test)]
@ -321,7 +885,7 @@ mod tests {
"cli_test_app_id".into(),
"test_app_secret".into(),
"test_verification_token".into(),
9898,
None,
vec!["ou_testuser123".into()],
)
}
@ -345,7 +909,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
9898,
None,
vec!["*".into()],
);
assert!(ch.is_user_allowed("ou_anyone"));
@ -353,7 +917,7 @@ mod tests {
#[test]
fn lark_user_denied_empty() {
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]);
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]);
assert!(!ch.is_user_allowed("ou_anyone"));
}
@ -426,7 +990,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
9898,
None,
vec!["*".into()],
);
let payload = serde_json::json!({
@ -451,7 +1015,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
9898,
None,
vec!["*".into()],
);
let payload = serde_json::json!({
@ -488,7 +1052,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
9898,
None,
vec!["*".into()],
);
let payload = serde_json::json!({
@ -512,7 +1076,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
9898,
None,
vec!["*".into()],
);
let payload = serde_json::json!({
@ -550,7 +1114,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
9898,
None,
vec!["*".into()],
);
let payload = serde_json::json!({
@ -571,7 +1135,7 @@ mod tests {
#[test]
fn lark_config_serde() {
use crate::config::schema::LarkConfig;
use crate::config::schema::{LarkConfig, LarkReceiveMode};
let lc = LarkConfig {
app_id: "cli_app123".into(),
app_secret: "secret456".into(),
@ -579,6 +1143,8 @@ mod tests {
verification_token: Some("vtoken789".into()),
allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
use_feishu: false,
receive_mode: LarkReceiveMode::default(),
port: None,
};
let json = serde_json::to_string(&lc).unwrap();
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
@ -590,7 +1156,7 @@ mod tests {
#[test]
fn lark_config_toml_roundtrip() {
use crate::config::schema::LarkConfig;
use crate::config::schema::{LarkConfig, LarkReceiveMode};
let lc = LarkConfig {
app_id: "app".into(),
app_secret: "secret".into(),
@ -598,6 +1164,8 @@ mod tests {
verification_token: Some("tok".into()),
allowed_users: vec!["*".into()],
use_feishu: false,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
};
let toml_str = toml::to_string(&lc).unwrap();
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
@ -608,11 +1176,36 @@ mod tests {
#[test]
fn lark_config_defaults_optional_fields() {
use crate::config::schema::LarkConfig;
use crate::config::schema::{LarkConfig, LarkReceiveMode};
let json = r#"{"app_id":"a","app_secret":"s"}"#;
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
assert!(parsed.verification_token.is_none());
assert!(parsed.allowed_users.is_empty());
assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket);
assert!(parsed.port.is_none());
}
#[test]
fn lark_from_config_preserves_mode_and_region() {
use crate::config::schema::{LarkConfig, LarkReceiveMode};
let cfg = LarkConfig {
app_id: "cli_app123".into(),
app_secret: "secret456".into(),
encrypt_key: None,
verification_token: Some("vtoken789".into()),
allowed_users: vec!["*".into()],
use_feishu: false,
receive_mode: LarkReceiveMode::Webhook,
port: Some(9898),
};
let ch = LarkChannel::from_config(&cfg);
assert_eq!(ch.api_base(), LARK_BASE_URL);
assert_eq!(ch.ws_base(), LARK_WS_BASE_URL);
assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook);
assert_eq!(ch.port, Some(9898));
}
#[test]
@ -622,7 +1215,7 @@ mod tests {
"id".into(),
"secret".into(),
"token".into(),
9898,
None,
vec!["*".into()],
);
let payload = serde_json::json!({

View file

@ -230,6 +230,7 @@ impl Channel for MatrixChannel {
let msg = ChannelMessage {
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
sender: event.sender.clone(),
reply_target: event.sender.clone(),
content: body.clone(),
channel: "matrix".to_string(),
timestamp: std::time::SystemTime::now()

View file

@ -69,10 +69,19 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
}
/// Extra system-prompt text describing channel-specific delivery syntax.
///
/// Returns `None` for channels that need no special instructions; currently
/// only Telegram (attachment markers) has an addendum.
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
    if channel_name == "telegram" {
        return Some(
            "When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.",
        );
    }
    None
}
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
let mut context = String::new();
if let Ok(entries) = mem.recall(user_msg, 5).await {
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
if !entries.is_empty() {
context.push_str("[Memory context]\n");
for entry in &entries {
@ -158,6 +167,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
&autosave_key,
&msg.content,
crate::memory::MemoryCategory::Conversation,
None,
)
.await;
}
@ -171,7 +181,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.start_typing(&msg.sender).await {
if let Err(e) = channel.start_typing(&msg.reply_target).await {
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
}
}
@ -184,6 +194,10 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
ChatMessage::user(&enriched_message),
];
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
history.push(ChatMessage::system(instructions));
}
let llm_result = tokio::time::timeout(
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
run_tool_call_loop(
@ -200,7 +214,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
.await;
if let Some(channel) = target_channel.as_ref() {
if let Err(e) = channel.stop_typing(&msg.sender).await {
if let Err(e) = channel.stop_typing(&msg.reply_target).await {
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
}
}
@ -224,7 +238,9 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
started_at.elapsed().as_millis()
);
if let Some(channel) = target_channel.as_ref() {
let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
let _ = channel
.send(&format!("⚠️ Error: {e}"), &msg.reply_target)
.await;
}
}
Err(_) => {
@ -241,7 +257,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
let _ = channel
.send(
"⚠️ Request timed out while waiting for the model. Please try again.",
&msg.sender,
&msg.reply_target,
)
.await;
}
@ -483,6 +499,16 @@ pub fn build_system_prompt(
std::env::consts::OS,
);
// ── 8. Channel Capabilities ─────────────────────────────────────
prompt.push_str("## Channel Capabilities\n\n");
prompt.push_str(
"- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n",
);
prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n");
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
if prompt.is_empty() {
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
} else {
@ -619,6 +645,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
dc.mention_only,
)),
));
}
@ -672,32 +699,23 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
if let Some(ref irc) = config.channels_config.irc {
channels.push((
"IRC",
Arc::new(IrcChannel::new(
irc.server.clone(),
irc.port,
irc.nickname.clone(),
irc.username.clone(),
irc.channels.clone(),
irc.allowed_users.clone(),
irc.server_password.clone(),
irc.nickserv_password.clone(),
irc.sasl_password.clone(),
irc.verify_tls.unwrap_or(true),
)),
Arc::new(IrcChannel::new(irc::IrcChannelConfig {
server: irc.server.clone(),
port: irc.port,
nickname: irc.nickname.clone(),
username: irc.username.clone(),
channels: irc.channels.clone(),
allowed_users: irc.allowed_users.clone(),
server_password: irc.server_password.clone(),
nickserv_password: irc.nickserv_password.clone(),
sasl_password: irc.sasl_password.clone(),
verify_tls: irc.verify_tls.unwrap_or(true),
})),
));
}
if let Some(ref lk) = config.channels_config.lark {
channels.push((
"Lark",
Arc::new(LarkChannel::new(
lk.app_id.clone(),
lk.app_secret.clone(),
lk.verification_token.clone().unwrap_or_default(),
9898,
lk.allowed_users.clone(),
)),
));
channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
}
if let Some(ref dt) = config.channels_config.dingtalk {
@ -762,6 +780,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
&provider_name,
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
)?);
@ -860,6 +879,10 @@ pub async fn start_channels(config: Config) -> Result<()> {
"schedule",
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
));
tool_descs.push((
"pushover",
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
));
if !config.agents.is_empty() {
tool_descs.push((
"delegate",
@ -909,6 +932,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
dc.mention_only,
)));
}
@ -947,28 +971,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
}
if let Some(ref irc) = config.channels_config.irc {
channels.push(Arc::new(IrcChannel::new(
irc.server.clone(),
irc.port,
irc.nickname.clone(),
irc.username.clone(),
irc.channels.clone(),
irc.allowed_users.clone(),
irc.server_password.clone(),
irc.nickserv_password.clone(),
irc.sasl_password.clone(),
irc.verify_tls.unwrap_or(true),
)));
channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
server: irc.server.clone(),
port: irc.port,
nickname: irc.nickname.clone(),
username: irc.username.clone(),
channels: irc.channels.clone(),
allowed_users: irc.allowed_users.clone(),
server_password: irc.server_password.clone(),
nickserv_password: irc.nickserv_password.clone(),
sasl_password: irc.sasl_password.clone(),
verify_tls: irc.verify_tls.unwrap_or(true),
})));
}
if let Some(ref lk) = config.channels_config.lark {
channels.push(Arc::new(LarkChannel::new(
lk.app_id.clone(),
lk.app_secret.clone(),
lk.verification_token.clone().unwrap_or_default(),
9898,
lk.allowed_users.clone(),
)));
channels.push(Arc::new(LarkChannel::from_config(lk)));
}
if let Some(ref dt) = config.channels_config.dingtalk {
@ -1242,6 +1260,7 @@ mod tests {
traits::ChannelMessage {
id: "msg-1".to_string(),
sender: "alice".to_string(),
reply_target: "chat-42".to_string(),
content: "What is the BTC price now?".to_string(),
channel: "test-channel".to_string(),
timestamp: 1,
@ -1251,6 +1270,7 @@ mod tests {
let sent_messages = channel_impl.sent_messages.lock().await;
assert_eq!(sent_messages.len(), 1);
assert!(sent_messages[0].starts_with("chat-42:"));
assert!(sent_messages[0].contains("BTC is currently around"));
assert!(!sent_messages[0].contains("\"tool_calls\""));
assert!(!sent_messages[0].contains("mock_price"));
@ -1269,6 +1289,7 @@ mod tests {
_key: &str,
_content: &str,
_category: crate::memory::MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
@ -1277,6 +1298,7 @@ mod tests {
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(Vec::new())
}
@ -1288,6 +1310,7 @@ mod tests {
async fn list(
&self,
_category: Option<&crate::memory::MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
Ok(Vec::new())
}
@ -1331,6 +1354,7 @@ mod tests {
tx.send(traits::ChannelMessage {
id: "1".to_string(),
sender: "alice".to_string(),
reply_target: "alice".to_string(),
content: "hello".to_string(),
channel: "test-channel".to_string(),
timestamp: 1,
@ -1340,6 +1364,7 @@ mod tests {
tx.send(traits::ChannelMessage {
id: "2".to_string(),
sender: "bob".to_string(),
reply_target: "bob".to_string(),
content: "world".to_string(),
channel: "test-channel".to_string(),
timestamp: 2,
@ -1570,6 +1595,25 @@ mod tests {
assert!(truncated.is_char_boundary(truncated.len()));
}
#[test]
fn prompt_contains_channel_capabilities() {
let ws = make_workspace();
let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);
assert!(
prompt.contains("## Channel Capabilities"),
"missing Channel Capabilities section"
);
assert!(
prompt.contains("running as a Discord bot"),
"missing Discord context"
);
assert!(
prompt.contains("NEVER repeat, describe, or echo credentials"),
"missing security instruction"
);
}
#[test]
fn prompt_workspace_path() {
let ws = make_workspace();
@ -1583,6 +1627,7 @@ mod tests {
let msg = traits::ChannelMessage {
id: "msg_abc123".into(),
sender: "U123".into(),
reply_target: "C456".into(),
content: "hello".into(),
channel: "slack".into(),
timestamp: 1,
@ -1596,6 +1641,7 @@ mod tests {
let msg1 = traits::ChannelMessage {
id: "msg_1".into(),
sender: "U123".into(),
reply_target: "C456".into(),
content: "first".into(),
channel: "slack".into(),
timestamp: 1,
@ -1603,6 +1649,7 @@ mod tests {
let msg2 = traits::ChannelMessage {
id: "msg_2".into(),
sender: "U123".into(),
reply_target: "C456".into(),
content: "second".into(),
channel: "slack".into(),
timestamp: 2,
@ -1622,6 +1669,7 @@ mod tests {
let msg1 = traits::ChannelMessage {
id: "msg_1".into(),
sender: "U123".into(),
reply_target: "C456".into(),
content: "I'm Paul".into(),
channel: "slack".into(),
timestamp: 1,
@ -1629,6 +1677,7 @@ mod tests {
let msg2 = traits::ChannelMessage {
id: "msg_2".into(),
sender: "U123".into(),
reply_target: "C456".into(),
content: "I'm 45".into(),
channel: "slack".into(),
timestamp: 2,
@ -1638,6 +1687,7 @@ mod tests {
&conversation_memory_key(&msg1),
&msg1.content,
MemoryCategory::Conversation,
None,
)
.await
.unwrap();
@ -1645,13 +1695,14 @@ mod tests {
&conversation_memory_key(&msg2),
&msg2.content,
MemoryCategory::Conversation,
None,
)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 2);
let recalled = mem.recall("45", 5).await.unwrap();
let recalled = mem.recall("45", 5, None).await.unwrap();
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
}
@ -1659,7 +1710,7 @@ mod tests {
async fn build_memory_context_includes_recalled_entries() {
let tmp = TempDir::new().unwrap();
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation)
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
.await
.unwrap();

View file

@ -161,6 +161,7 @@ impl Channel for SlackChannel {
let channel_msg = ChannelMessage {
id: format!("slack_{channel_id}_{ts}"),
sender: user.to_string(),
reply_target: channel_id.clone(),
content: text.to_string(),
channel: "slack".to_string(),
timestamp: std::time::SystemTime::now()

View file

@ -51,6 +51,133 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
chunks
}
/// The Telegram media categories an attachment marker can resolve to.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TelegramAttachmentKind {
    Image,
    Document,
    Video,
    Audio,
    Voice,
}

/// One parsed attachment: what kind of media, and the path or URL to send.
#[derive(Debug, Clone, PartialEq, Eq)]
struct TelegramAttachment {
    kind: TelegramAttachmentKind,
    target: String,
}

impl TelegramAttachmentKind {
    /// Map a `[MARKER:...]` tag name onto a kind. Matching is
    /// case-insensitive and ignores surrounding whitespace; `PHOTO` and
    /// `FILE` are accepted as aliases. Unknown markers yield `None`.
    fn from_marker(marker: &str) -> Option<Self> {
        let normalized = marker.trim().to_ascii_uppercase();
        let kind = match normalized.as_str() {
            "IMAGE" | "PHOTO" => Self::Image,
            "DOCUMENT" | "FILE" => Self::Document,
            "VIDEO" => Self::Video,
            "AUDIO" => Self::Audio,
            "VOICE" => Self::Voice,
            _ => return None,
        };
        Some(kind)
    }
}
/// True when the attachment target is a web URL rather than a local path.
fn is_http_url(target: &str) -> bool {
    ["http://", "https://"]
        .iter()
        .any(|scheme| target.starts_with(scheme))
}
/// Guess the attachment kind from a path/URL file extension.
///
/// `?query` and `#fragment` suffixes are stripped before the extension is
/// read, so URLs resolve on their path component. Unknown or missing
/// extensions yield `None` so the caller can fall back to plain text.
fn infer_attachment_kind_from_target(target: &str) -> Option<TelegramAttachmentKind> {
    // Drop the query string, then the fragment (split() always yields at
    // least one piece, so the fallbacks are defensive only).
    let without_query = target.split('?').next().unwrap_or(target);
    let path_part = without_query.split('#').next().unwrap_or(without_query);

    let extension = Path::new(path_part)
        .extension()
        .and_then(|ext| ext.to_str())?
        .to_ascii_lowercase();

    let kind = match extension.as_str() {
        "png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => TelegramAttachmentKind::Image,
        "mp4" | "mov" | "mkv" | "avi" | "webm" => TelegramAttachmentKind::Video,
        "mp3" | "m4a" | "wav" | "flac" => TelegramAttachmentKind::Audio,
        "ogg" | "oga" | "opus" => TelegramAttachmentKind::Voice,
        "pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx"
        | "xls" | "xlsx" | "ppt" | "pptx" => TelegramAttachmentKind::Document,
        _ => return None,
    };
    Some(kind)
}
/// Treat a message that is nothing but a bare file path or URL as an
/// attachment.
///
/// Returns `None` unless the (possibly quoted/backticked) single-token
/// message has a recognizable media extension AND either is an HTTP(S) URL
/// or names an existing local file.
fn parse_path_only_attachment(message: &str) -> Option<TelegramAttachment> {
    let trimmed = message.trim();
    if trimmed.is_empty() || trimmed.contains('\n') {
        return None;
    }
    // Strip wrapping backticks/quotes that models often add around paths.
    let unquoted = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\''));
    if unquoted.chars().any(char::is_whitespace) {
        return None;
    }
    let target = unquoted.strip_prefix("file://").unwrap_or(unquoted);
    let kind = infer_attachment_kind_from_target(target)?;
    // URLs are taken on faith; local paths must actually exist.
    let reachable = is_http_url(target) || Path::new(target).exists();
    if !reachable {
        return None;
    }
    Some(TelegramAttachment {
        kind,
        target: target.to_string(),
    })
}
/// Extract `[KIND:target]` attachment markers from `message`.
///
/// Returns the message text with all recognized markers removed (and the
/// result trimmed), plus the parsed attachments in order of appearance.
/// Bracketed spans that do not parse as a valid marker (unknown kind,
/// missing `:`, empty target) are kept in the text verbatim. Scanning uses
/// byte indices; `[` and `]` are single-byte ASCII, so all slice boundaries
/// fall on char boundaries even with multi-byte text in between.
fn parse_attachment_markers(message: &str) -> (String, Vec<TelegramAttachment>) {
    let mut cleaned = String::with_capacity(message.len());
    let mut attachments = Vec::new();
    let mut cursor = 0;
    while cursor < message.len() {
        // Find the next '[' at or after the cursor; if none, flush the rest.
        let Some(open_rel) = message[cursor..].find('[') else {
            cleaned.push_str(&message[cursor..]);
            break;
        };
        let open = cursor + open_rel;
        // Keep everything before the bracket.
        cleaned.push_str(&message[cursor..open]);
        // An unterminated '[' is not a marker — keep it and stop scanning.
        let Some(close_rel) = message[open..].find(']') else {
            cleaned.push_str(&message[open..]);
            break;
        };
        let close = open + close_rel;
        let marker = &message[open + 1..close];
        // A valid marker is "<KIND>:<non-empty target>".
        let parsed = marker.split_once(':').and_then(|(kind, target)| {
            let kind = TelegramAttachmentKind::from_marker(kind)?;
            let target = target.trim();
            if target.is_empty() {
                return None;
            }
            Some(TelegramAttachment {
                kind,
                target: target.to_string(),
            })
        });
        if let Some(attachment) = parsed {
            attachments.push(attachment);
        } else {
            // Not a marker — preserve the bracketed span in the output.
            cleaned.push_str(&message[open..=close]);
        }
        cursor = close + 1;
    }
    (cleaned.trim().to_string(), attachments)
}
/// Telegram channel — long-polls the Bot API for updates
pub struct TelegramChannel {
bot_token: String,
@ -82,6 +209,216 @@ impl TelegramChannel {
identities.into_iter().any(|id| self.is_user_allowed(id))
}
/// Convert one long-poll Bot API `update` object into a `ChannelMessage`.
///
/// Returns `None` when the update carries no text message, the sender is
/// not on the allowlist, or the chat id is missing. The allowlist is
/// checked against both the @username and the numeric user id, so either
/// form may appear in config.
fn parse_update_message(&self, update: &serde_json::Value) -> Option<ChannelMessage> {
    let message = update.get("message")?;
    // Only plain text messages are forwarded; media/service updates yield None.
    let text = message.get("text").and_then(serde_json::Value::as_str)?;
    let username = message
        .get("from")
        .and_then(|from| from.get("username"))
        .and_then(serde_json::Value::as_str)
        .unwrap_or("unknown")
        .to_string();
    // Numeric Telegram user id, stringified for comparison with the allowlist.
    let user_id = message
        .get("from")
        .and_then(|from| from.get("id"))
        .and_then(serde_json::Value::as_i64)
        .map(|id| id.to_string());
    // Display identity: prefer the @username, fall back to the numeric id.
    let sender_identity = if username == "unknown" {
        user_id.clone().unwrap_or_else(|| "unknown".to_string())
    } else {
        username.clone()
    };
    // Authorize if EITHER identity (username or id) is allowlisted.
    let mut identities = vec![username.as_str()];
    if let Some(id) = user_id.as_deref() {
        identities.push(id);
    }
    if !self.is_any_user_allowed(identities.iter().copied()) {
        tracing::warn!(
            "Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
            Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
            user_id.as_deref().unwrap_or("unknown")
        );
        return None;
    }
    // Replies are addressed to the chat id (distinct from the sender in groups).
    let chat_id = message
        .get("chat")
        .and_then(|chat| chat.get("id"))
        .and_then(serde_json::Value::as_i64)
        .map(|id| id.to_string())?;
    let message_id = message
        .get("message_id")
        .and_then(serde_json::Value::as_i64)
        .unwrap_or(0);
    Some(ChannelMessage {
        id: format!("telegram_{chat_id}_{message_id}"),
        sender: sender_identity,
        reply_target: chat_id,
        content: text.to_string(),
        channel: "telegram".to_string(),
        // Arrival time (local clock), not Telegram's message timestamp.
        timestamp: std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs(),
    })
}
/// Send `message` to `chat_id` as one or more Telegram text messages.
///
/// The text is split via `split_message_for_telegram` (Telegram caps a
/// message at 4096 chars); multi-part sends are annotated with
/// "(continues...)"/"(continued)" markers and separated by a short delay
/// to avoid rate limiting. Each chunk is first sent with
/// `parse_mode: "Markdown"`; if Telegram rejects it (commonly due to
/// unbalanced Markdown entities in model output) the same chunk is resent
/// as plain text. Errors only surface when BOTH attempts fail.
async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
    let chunks = split_message_for_telegram(message);
    for (index, chunk) in chunks.iter().enumerate() {
        // Mark continuation boundaries on multi-part messages.
        let text = if chunks.len() > 1 {
            if index == 0 {
                format!("{chunk}\n\n(continues...)")
            } else if index == chunks.len() - 1 {
                format!("(continued)\n\n{chunk}")
            } else {
                format!("(continued)\n\n{chunk}\n\n(continues...)")
            }
        } else {
            chunk.to_string()
        };
        // First attempt: Markdown formatting.
        let markdown_body = serde_json::json!({
            "chat_id": chat_id,
            "text": text,
            "parse_mode": "Markdown"
        });
        let markdown_resp = self
            .client
            .post(self.api_url("sendMessage"))
            .json(&markdown_body)
            .send()
            .await?;
        if markdown_resp.status().is_success() {
            // Brief pause between chunks to stay under rate limits.
            if index < chunks.len() - 1 {
                tokio::time::sleep(Duration::from_millis(100)).await;
            }
            continue;
        }
        // Markdown failed — remember why, then retry as plain text.
        let markdown_status = markdown_resp.status();
        let markdown_err = markdown_resp.text().await.unwrap_or_default();
        tracing::warn!(
            status = ?markdown_status,
            "Telegram sendMessage with Markdown failed; retrying without parse_mode"
        );
        let plain_body = serde_json::json!({
            "chat_id": chat_id,
            "text": text,
        });
        let plain_resp = self
            .client
            .post(self.api_url("sendMessage"))
            .json(&plain_body)
            .send()
            .await?;
        if !plain_resp.status().is_success() {
            // Both attempts failed — report both statuses/bodies for debugging.
            let plain_status = plain_resp.status();
            let plain_err = plain_resp.text().await.unwrap_or_default();
            anyhow::bail!(
                "Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
                markdown_status,
                markdown_err,
                plain_status,
                plain_err
            );
        }
        if index < chunks.len() - 1 {
            tokio::time::sleep(Duration::from_millis(100)).await;
        }
    }
    Ok(())
}
/// Shared POST helper for the Bot API "send media by URL" endpoints.
///
/// Builds `{"chat_id": ..., <media_field>: url, ["caption": ...]}` and
/// POSTs it to `method` (e.g. `sendVideo`); Telegram downloads the media
/// from `url` itself. Non-2xx responses are surfaced as errors including
/// the API response body.
async fn send_media_by_url(
    &self,
    method: &str,
    media_field: &str,
    chat_id: &str,
    url: &str,
    caption: Option<&str>,
) -> anyhow::Result<()> {
    let mut body = serde_json::json!({
        "chat_id": chat_id,
    });
    // The media field name differs per method ("photo", "video", ...).
    body[media_field] = serde_json::Value::String(url.to_string());
    if let Some(cap) = caption {
        body["caption"] = serde_json::Value::String(cap.to_string());
    }
    let resp = self
        .client
        .post(self.api_url(method))
        .json(&body)
        .send()
        .await?;
    if !resp.status().is_success() {
        let err = resp.text().await?;
        anyhow::bail!("Telegram {method} by URL failed: {err}");
    }
    tracing::info!("Telegram {method} sent to {chat_id}: {url}");
    Ok(())
}
/// Deliver one parsed attachment to `chat_id`.
///
/// HTTP(S) targets are handed to the Bot API by URL (Telegram fetches the
/// media server-side); anything else is treated as a local file path and
/// uploaded, failing fast when the path does not exist.
async fn send_attachment(
    &self,
    chat_id: &str,
    attachment: &TelegramAttachment,
) -> anyhow::Result<()> {
    let target = attachment.target.trim();
    if is_http_url(target) {
        // URL delivery: dispatch to the matching send-*-by-URL endpoint.
        return match attachment.kind {
            TelegramAttachmentKind::Image => {
                self.send_photo_by_url(chat_id, target, None).await
            }
            TelegramAttachmentKind::Document => {
                self.send_document_by_url(chat_id, target, None).await
            }
            TelegramAttachmentKind::Video => {
                self.send_video_by_url(chat_id, target, None).await
            }
            TelegramAttachmentKind::Audio => {
                self.send_audio_by_url(chat_id, target, None).await
            }
            TelegramAttachmentKind::Voice => {
                self.send_voice_by_url(chat_id, target, None).await
            }
        };
    }
    // Local file upload: validate existence before attempting the upload.
    let path = Path::new(target);
    if !path.exists() {
        anyhow::bail!("Telegram attachment path not found: {target}");
    }
    match attachment.kind {
        TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await,
        TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await,
        TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await,
        TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await,
        TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await,
    }
}
/// Send a document/file to a Telegram chat
pub async fn send_document(
&self,
@ -408,6 +745,39 @@ impl TelegramChannel {
tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}");
Ok(())
}
/// Send a video by URL (Telegram will download it)
///
/// Thin wrapper over `send_media_by_url` using the `sendVideo` Bot API
/// method with the `video` payload field. `caption` is attached when
/// present; errors from the underlying request are propagated as-is.
pub async fn send_video_by_url(
    &self,
    chat_id: &str,
    url: &str,
    caption: Option<&str>,
) -> anyhow::Result<()> {
    self.send_media_by_url("sendVideo", "video", chat_id, url, caption)
        .await
}
/// Send an audio file by URL (Telegram will download it)
///
/// Thin wrapper over `send_media_by_url` using the `sendAudio` Bot API
/// method with the `audio` payload field. `caption` is attached when
/// present; errors from the underlying request are propagated as-is.
pub async fn send_audio_by_url(
    &self,
    chat_id: &str,
    url: &str,
    caption: Option<&str>,
) -> anyhow::Result<()> {
    self.send_media_by_url("sendAudio", "audio", chat_id, url, caption)
        .await
}
/// Send a voice message by URL (Telegram will download it)
///
/// Thin wrapper over `send_media_by_url` using the `sendVoice` Bot API
/// method with the `voice` payload field. `caption` is attached when
/// present; errors from the underlying request are propagated as-is.
pub async fn send_voice_by_url(
    &self,
    chat_id: &str,
    url: &str,
    caption: Option<&str>,
) -> anyhow::Result<()> {
    self.send_media_by_url("sendVoice", "voice", chat_id, url, caption)
        .await
}
}
#[async_trait]
@ -417,82 +787,27 @@ impl Channel for TelegramChannel {
}
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
// Split message if it exceeds Telegram's 4096 character limit
let chunks = split_message_for_telegram(message);
let (text_without_markers, attachments) = parse_attachment_markers(message);
for (i, chunk) in chunks.iter().enumerate() {
// Add continuation marker for multi-part messages
let text = if chunks.len() > 1 {
if i == 0 {
format!("{chunk}\n\n(continues...)")
} else if i == chunks.len() - 1 {
format!("(continued)\n\n{chunk}")
} else {
format!("(continued)\n\n{chunk}\n\n(continues...)")
}
} else {
chunk.to_string()
};
let markdown_body = serde_json::json!({
"chat_id": chat_id,
"text": text,
"parse_mode": "Markdown"
});
let markdown_resp = self
.client
.post(self.api_url("sendMessage"))
.json(&markdown_body)
.send()
if !attachments.is_empty() {
if !text_without_markers.is_empty() {
self.send_text_chunks(&text_without_markers, chat_id)
.await?;
if markdown_resp.status().is_success() {
// Small delay between chunks to avoid rate limiting
if i < chunks.len() - 1 {
tokio::time::sleep(Duration::from_millis(100)).await;
}
continue;
}
let markdown_status = markdown_resp.status();
let markdown_err = markdown_resp.text().await.unwrap_or_default();
tracing::warn!(
status = ?markdown_status,
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
);
// Retry without parse_mode as a compatibility fallback.
let plain_body = serde_json::json!({
"chat_id": chat_id,
"text": text,
});
let plain_resp = self
.client
.post(self.api_url("sendMessage"))
.json(&plain_body)
.send()
.await?;
if !plain_resp.status().is_success() {
let plain_status = plain_resp.status();
let plain_err = plain_resp.text().await.unwrap_or_default();
anyhow::bail!(
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
markdown_status,
markdown_err,
plain_status,
plain_err
);
for attachment in &attachments {
self.send_attachment(chat_id, attachment).await?;
}
// Small delay between chunks to avoid rate limiting
if i < chunks.len() - 1 {
tokio::time::sleep(Duration::from_millis(100)).await;
}
return Ok(());
}
Ok(())
if let Some(attachment) = parse_path_only_attachment(message) {
self.send_attachment(chat_id, &attachment).await?;
return Ok(());
}
self.send_text_chunks(message, chat_id).await
}
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
@ -533,59 +848,13 @@ impl Channel for TelegramChannel {
offset = uid + 1;
}
let Some(message) = update.get("message") else {
let Some(msg) = self.parse_update_message(update) else {
continue;
};
let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
continue;
};
let username_opt = message
.get("from")
.and_then(|f| f.get("username"))
.and_then(|u| u.as_str());
let username = username_opt.unwrap_or("unknown");
let user_id = message
.get("from")
.and_then(|f| f.get("id"))
.and_then(serde_json::Value::as_i64);
let user_id_str = user_id.map(|id| id.to_string());
let mut identities = vec![username];
if let Some(ref id) = user_id_str {
identities.push(id.as_str());
}
if !self.is_any_user_allowed(identities.iter().copied()) {
tracing::warn!(
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
user_id_str.as_deref().unwrap_or("unknown")
);
continue;
}
let chat_id = message
.get("chat")
.and_then(|c| c.get("id"))
.and_then(serde_json::Value::as_i64)
.map(|id| id.to_string());
let Some(chat_id) = chat_id else {
tracing::warn!("Telegram: missing chat_id in message, skipping");
continue;
};
let message_id = message
.get("message_id")
.and_then(|v| v.as_i64())
.unwrap_or(0);
// Send "typing" indicator immediately when we receive a message
let typing_body = serde_json::json!({
"chat_id": &chat_id,
"chat_id": &msg.reply_target,
"action": "typing"
});
let _ = self
@ -595,17 +864,6 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch
.send()
.await; // Ignore errors for typing indicator
let msg = ChannelMessage {
id: format!("telegram_{chat_id}_{message_id}"),
sender: username.to_string(),
content: text.to_string(),
channel: "telegram".to_string(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs(),
};
if tx.send(msg).await.is_err() {
return Ok(());
}
@ -716,6 +974,107 @@ mod tests {
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
}
#[test]
// Both local-path and URL markers are extracted (kind + target) and the
// marker text is stripped from the remaining message.
fn parse_attachment_markers_extracts_multiple_types() {
    let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]";
    let (cleaned, attachments) = parse_attachment_markers(message);
    // Markers removed; surrounding whitespace collapsed.
    assert_eq!(cleaned, "Here are files and");
    assert_eq!(attachments.len(), 2);
    assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image);
    assert_eq!(attachments[0].target, "/tmp/a.png");
    assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document);
    assert_eq!(attachments[1].target, "https://example.com/a.pdf");
}
#[test]
// Markers with an unrecognized kind are left verbatim in the text and
// produce no attachments.
fn parse_attachment_markers_keeps_invalid_markers_in_text() {
    let message = "Report [UNKNOWN:/tmp/a.bin]";
    let (cleaned, attachments) = parse_attachment_markers(message);
    assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]");
    assert!(attachments.is_empty());
}
#[test]
// A message consisting solely of a path to an existing file is treated
// as an attachment (here a .png resolves to the Image kind).
fn parse_path_only_attachment_detects_existing_file() {
    let dir = tempfile::tempdir().unwrap();
    let image_path = dir.path().join("snap.png");
    std::fs::write(&image_path, b"fake-png").unwrap();

    let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref())
        .expect("expected attachment");
    assert_eq!(parsed.kind, TelegramAttachmentKind::Image);
    assert_eq!(parsed.target, image_path.to_string_lossy());
}
#[test]
// Prose that merely mentions a path must not be mistaken for a
// path-only attachment message.
fn parse_path_only_attachment_rejects_sentence_text() {
    assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none());
}
#[test]
// Extension sniffing works on URLs even when a query string follows
// the file name.
fn infer_attachment_kind_from_target_detects_document_extension() {
    assert_eq!(
        infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"),
        Some(TelegramAttachmentKind::Document)
    );
}
#[test]
// The chat id (not the sender's user id) becomes the reply target.
fn parse_update_message_uses_chat_id_as_reply_target() {
    let ch = TelegramChannel::new("token".into(), vec!["*".into()]);
    let update = serde_json::json!({
        "update_id": 1,
        "message": {
            "message_id": 33,
            "text": "hello",
            "from": {
                "id": 555,
                "username": "alice"
            },
            "chat": {
                "id": -100200300
            }
        }
    });

    let msg = ch
        .parse_update_message(&update)
        .expect("message should parse");
    assert_eq!(msg.sender, "alice");
    assert_eq!(msg.reply_target, "-100200300");
    assert_eq!(msg.content, "hello");
    // Message id format: telegram_<chat_id>_<message_id>.
    assert_eq!(msg.id, "telegram_-100200300_33");
}
#[test]
// A sender with no username is still admitted when their numeric user
// id is on the allowlist; the id string doubles as the sender name.
fn parse_update_message_allows_numeric_id_without_username() {
    let ch = TelegramChannel::new("token".into(), vec!["555".into()]);
    let update = serde_json::json!({
        "update_id": 2,
        "message": {
            "message_id": 9,
            "text": "ping",
            "from": {
                "id": 555
            },
            "chat": {
                "id": 12345
            }
        }
    });

    let msg = ch
        .parse_update_message(&update)
        .expect("numeric allowlist should pass");
    assert_eq!(msg.sender, "555");
    assert_eq!(msg.reply_target, "12345");
}
// ── File sending API URL tests ──────────────────────────────────
#[test]

View file

@ -5,6 +5,7 @@ use async_trait::async_trait;
pub struct ChannelMessage {
pub id: String,
pub sender: String,
pub reply_target: String,
pub content: String,
pub channel: String,
pub timestamp: u64,
@ -62,6 +63,7 @@ mod tests {
tx.send(ChannelMessage {
id: "1".into(),
sender: "tester".into(),
reply_target: "tester".into(),
content: "hello".into(),
channel: "dummy".into(),
timestamp: 123,
@ -76,6 +78,7 @@ mod tests {
let message = ChannelMessage {
id: "42".into(),
sender: "alice".into(),
reply_target: "alice".into(),
content: "ping".into(),
channel: "dummy".into(),
timestamp: 999,
@ -84,6 +87,7 @@ mod tests {
let cloned = message.clone();
assert_eq!(cloned.id, "42");
assert_eq!(cloned.sender, "alice");
assert_eq!(cloned.reply_target, "alice");
assert_eq!(cloned.content, "ping");
assert_eq!(cloned.channel, "dummy");
assert_eq!(cloned.timestamp, 999);

View file

@ -10,7 +10,7 @@ use uuid::Uuid;
/// happens in the gateway when Meta sends webhook events.
pub struct WhatsAppChannel {
access_token: String,
phone_number_id: String,
endpoint_id: String,
verify_token: String,
allowed_numbers: Vec<String>,
client: reqwest::Client,
@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
impl WhatsAppChannel {
pub fn new(
access_token: String,
phone_number_id: String,
endpoint_id: String,
verify_token: String,
allowed_numbers: Vec<String>,
) -> Self {
Self {
access_token,
phone_number_id,
endpoint_id,
verify_token,
allowed_numbers,
client: reqwest::Client::new(),
@ -119,6 +119,7 @@ impl WhatsAppChannel {
messages.push(ChannelMessage {
id: Uuid::new_v4().to_string(),
reply_target: normalized_from.clone(),
sender: normalized_from,
content,
channel: "whatsapp".to_string(),
@ -142,7 +143,7 @@ impl Channel for WhatsAppChannel {
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
let url = format!(
"https://graph.facebook.com/v18.0/{}/messages",
self.phone_number_id
self.endpoint_id
);
// Normalize recipient (remove leading + if present for API)
@ -162,7 +163,7 @@ impl Channel for WhatsAppChannel {
let resp = self
.client
.post(&url)
.header("Authorization", format!("Bearer {}", self.access_token))
.bearer_auth(&self.access_token)
.header("Content-Type", "application/json")
.json(&body)
.send()
@ -195,11 +196,11 @@ impl Channel for WhatsAppChannel {
async fn health_check(&self) -> bool {
// Check if we can reach the WhatsApp API
let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id);
let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
self.client
.get(&url)
.header("Authorization", format!("Bearer {}", self.access_token))
.bearer_auth(&self.access_token)
.send()
.await
.map(|r| r.status().is_success())

View file

@ -37,9 +37,22 @@ mod tests {
guild_id: Some("123".into()),
allowed_users: vec![],
listen_to_bots: false,
mention_only: false,
};
let lark = LarkConfig {
app_id: "app-id".into(),
app_secret: "app-secret".into(),
encrypt_key: None,
verification_token: None,
allowed_users: vec![],
use_feishu: false,
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
port: None,
};
assert_eq!(telegram.allowed_users.len(), 1);
assert_eq!(discord.guild_id.as_deref(), Some("123"));
assert_eq!(lark.app_id, "app-id");
}
}

View file

@ -18,6 +18,8 @@ pub struct Config {
#[serde(skip)]
pub config_path: PathBuf,
pub api_key: Option<String>,
/// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama)
pub api_url: Option<String>,
pub default_provider: Option<String>,
pub default_model: Option<String>,
pub default_temperature: f64,
@ -1317,6 +1319,10 @@ pub struct DiscordConfig {
/// The bot still ignores its own messages to prevent feedback loops.
#[serde(default)]
pub listen_to_bots: bool,
/// When true, only respond to messages that @-mention the bot.
/// Other messages in the guild are silently ignored.
#[serde(default)]
pub mention_only: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
@ -1395,8 +1401,20 @@ fn default_irc_port() -> u16 {
6697
}
/// Lark/Feishu configuration for messaging integration
/// Lark is the international version, Feishu is the Chinese version
/// How ZeroClaw receives events from Feishu / Lark.
///
/// - `websocket` (default) — persistent WSS long-connection; no public URL required.
/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
#[serde(rename_all = "lowercase")]
pub enum LarkReceiveMode {
#[default]
Websocket,
Webhook,
}
/// Lark/Feishu configuration for messaging integration.
/// Lark is the international version; Feishu is the Chinese version.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LarkConfig {
/// App ID from Lark/Feishu developer console
@ -1415,6 +1433,13 @@ pub struct LarkConfig {
/// Whether to use the Feishu (Chinese) endpoint instead of Lark (International)
#[serde(default)]
pub use_feishu: bool,
/// Event receive mode: "websocket" (default) or "webhook"
#[serde(default)]
pub receive_mode: LarkReceiveMode,
/// HTTP port for webhook mode only. Must be set when receive_mode = "webhook".
/// Not required (and ignored) for websocket mode.
#[serde(default)]
pub port: Option<u16>,
}
// ── Security Config ─────────────────────────────────────────────────
@ -1594,6 +1619,7 @@ impl Default for Config {
workspace_dir: zeroclaw_dir.join("workspace"),
config_path: zeroclaw_dir.join("config.toml"),
api_key: None,
api_url: None,
default_provider: Some("openrouter".to_string()),
default_model: Some("anthropic/claude-sonnet-4".to_string()),
default_temperature: 0.7,
@ -1623,35 +1649,146 @@ impl Default for Config {
}
}
impl Config {
pub fn load_or_init() -> Result<Self> {
fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> {
let home = UserDirs::new()
.map(|u| u.home_dir().to_path_buf())
.context("Could not find home directory")?;
let zeroclaw_dir = home.join(".zeroclaw");
let config_path = zeroclaw_dir.join("config.toml");
let config_dir = home.join(".zeroclaw");
Ok((config_dir.clone(), config_dir.join("workspace")))
}
if !zeroclaw_dir.exists() {
fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?;
fs::create_dir_all(zeroclaw_dir.join("workspace"))
.context("Failed to create workspace directory")?;
/// Decide which directory holds `config.toml` for a given workspace.
///
/// Precedence:
/// 1. a `config.toml` directly inside the workspace directory;
/// 2. an existing `<parent>/.zeroclaw/config.toml` (legacy layout);
/// 3. the legacy `<parent>/.zeroclaw` directory when the workspace is
///    literally named `workspace` (the historical default), even if no
///    config file exists there yet;
/// 4. otherwise the workspace directory itself.
fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
    let in_workspace = workspace_dir.to_path_buf();
    if in_workspace.join("config.toml").exists() {
        return in_workspace;
    }

    // Fall back to the sibling legacy layout when one can be derived.
    if let Some(legacy_dir) = workspace_dir.parent().map(|p| p.join(".zeroclaw")) {
        let has_default_name = workspace_dir
            .file_name()
            .is_some_and(|name| name == std::ffi::OsStr::new("workspace"));
        if legacy_dir.join("config.toml").exists() || has_default_name {
            return legacy_dir;
        }
    }

    in_workspace
}
/// Decrypt `value` in place when it holds an encrypted secret.
///
/// `None` and already-plaintext values are left untouched. Decryption
/// failures are surfaced with the offending `field_name` attached for
/// easier diagnosis.
fn decrypt_optional_secret(
    store: &crate::security::SecretStore,
    value: &mut Option<String>,
    field_name: &str,
) -> Result<()> {
    let Some(raw) = value.clone() else {
        return Ok(());
    };
    if !crate::security::SecretStore::is_encrypted(&raw) {
        return Ok(());
    }
    let plaintext = store
        .decrypt(&raw)
        .with_context(|| format!("Failed to decrypt {field_name}"))?;
    *value = Some(plaintext);
    Ok(())
}
/// Encrypt `value` in place when it holds a plaintext secret.
///
/// `None` and already-encrypted values are left untouched. Encryption
/// failures are surfaced with the offending `field_name` attached for
/// easier diagnosis.
fn encrypt_optional_secret(
    store: &crate::security::SecretStore,
    value: &mut Option<String>,
    field_name: &str,
) -> Result<()> {
    let Some(raw) = value.clone() else {
        return Ok(());
    };
    if crate::security::SecretStore::is_encrypted(&raw) {
        return Ok(());
    }
    let ciphertext = store
        .encrypt(&raw)
        .with_context(|| format!("Failed to encrypt {field_name}"))?;
    *value = Some(ciphertext);
    Ok(())
}
impl Config {
pub fn load_or_init() -> Result<Self> {
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") {
Ok(custom_workspace) if !custom_workspace.is_empty() => {
let workspace = PathBuf::from(custom_workspace);
(resolve_config_dir_for_workspace(&workspace), workspace)
}
_ => default_config_and_workspace_dirs()?,
};
let config_path = zeroclaw_dir.join("config.toml");
fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?;
fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?;
if config_path.exists() {
// Warn if config file is world-readable (may contain API keys)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
if let Ok(meta) = fs::metadata(&config_path) {
if meta.permissions().mode() & 0o004 != 0 {
tracing::warn!(
"Config file {:?} is world-readable (mode {:o}). \
Consider restricting with: chmod 600 {:?}",
config_path,
meta.permissions().mode() & 0o777,
config_path,
);
}
}
}
let contents =
fs::read_to_string(&config_path).context("Failed to read config file")?;
let mut config: Config =
toml::from_str(&contents).context("Failed to parse config file")?;
// Set computed paths that are skipped during serialization
config.config_path = config_path.clone();
config.workspace_dir = zeroclaw_dir.join("workspace");
config.workspace_dir = workspace_dir;
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
decrypt_optional_secret(
&store,
&mut config.composio.api_key,
"config.composio.api_key",
)?;
decrypt_optional_secret(
&store,
&mut config.browser.computer_use.api_key,
"config.browser.computer_use.api_key",
)?;
for agent in config.agents.values_mut() {
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
}
config.apply_env_overrides();
Ok(config)
} else {
let mut config = Config::default();
config.config_path = config_path.clone();
config.workspace_dir = zeroclaw_dir.join("workspace");
config.workspace_dir = workspace_dir;
config.save()?;
// Restrict permissions on newly created config file (may contain API keys)
#[cfg(unix)]
{
use std::os::unix::fs::PermissionsExt;
let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600));
}
config.apply_env_overrides();
Ok(config)
}
@ -1732,23 +1869,29 @@ impl Config {
}
pub fn save(&self) -> Result<()> {
// Encrypt agent API keys before serialization
// Encrypt secrets before serialization
let mut config_to_save = self.clone();
let zeroclaw_dir = self
.config_path
.parent()
.context("Config path must have a parent directory")?;
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
encrypt_optional_secret(
&store,
&mut config_to_save.composio.api_key,
"config.composio.api_key",
)?;
encrypt_optional_secret(
&store,
&mut config_to_save.browser.computer_use.api_key,
"config.browser.computer_use.api_key",
)?;
for agent in config_to_save.agents.values_mut() {
if let Some(ref plaintext_key) = agent.api_key {
if !crate::security::SecretStore::is_encrypted(plaintext_key) {
agent.api_key = Some(
store
.encrypt(plaintext_key)
.context("Failed to encrypt agent API key")?,
);
}
}
encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
}
let toml_str =
@ -1949,6 +2092,7 @@ default_temperature = 0.7
workspace_dir: PathBuf::from("/tmp/test/workspace"),
config_path: PathBuf::from("/tmp/test/config.toml"),
api_key: Some("sk-test-key".into()),
api_url: None,
default_provider: Some("openrouter".into()),
default_model: Some("gpt-4o".into()),
default_temperature: 0.5,
@ -2091,6 +2235,7 @@ tool_dispatcher = "xml"
workspace_dir: dir.join("workspace"),
config_path: config_path.clone(),
api_key: Some("sk-roundtrip".into()),
api_url: None,
default_provider: Some("openrouter".into()),
default_model: Some("test-model".into()),
default_temperature: 0.9,
@ -2123,13 +2268,82 @@ tool_dispatcher = "xml"
let contents = fs::read_to_string(&config_path).unwrap();
let loaded: Config = toml::from_str(&contents).unwrap();
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
assert!(loaded
.api_key
.as_deref()
.is_some_and(crate::security::SecretStore::is_encrypted));
let store = crate::security::SecretStore::new(&dir, true);
let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
assert_eq!(decrypted, "sk-roundtrip");
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
let _ = fs::remove_dir_all(&dir);
}
#[test]
// Saving must encrypt every secret field, including nested ones: the
// root api_key, composio, browser computer-use, and per-agent keys.
fn config_save_encrypts_nested_credentials() {
    let dir = std::env::temp_dir().join(format!(
        "zeroclaw_test_nested_credentials_{}",
        uuid::Uuid::new_v4()
    ));
    fs::create_dir_all(&dir).unwrap();

    let mut config = Config::default();
    config.workspace_dir = dir.join("workspace");
    config.config_path = dir.join("config.toml");
    config.api_key = Some("root-credential".into());
    config.composio.api_key = Some("composio-credential".into());
    config.browser.computer_use.api_key = Some("browser-credential".into());
    config.agents.insert(
        "worker".into(),
        DelegateAgentConfig {
            provider: "openrouter".into(),
            model: "model-test".into(),
            system_prompt: None,
            api_key: Some("agent-credential".into()),
            temperature: None,
            max_depth: 3,
        },
    );

    config.save().unwrap();

    // Re-read what actually hit disk: each secret must be stored
    // encrypted yet decrypt back to its original plaintext.
    let contents = fs::read_to_string(config.config_path.clone()).unwrap();
    let stored: Config = toml::from_str(&contents).unwrap();
    let store = crate::security::SecretStore::new(&dir, true);

    let root_encrypted = stored.api_key.as_deref().unwrap();
    assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
    assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");

    let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
    assert!(crate::security::SecretStore::is_encrypted(
        composio_encrypted
    ));
    assert_eq!(
        store.decrypt(composio_encrypted).unwrap(),
        "composio-credential"
    );

    let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
    assert!(crate::security::SecretStore::is_encrypted(
        browser_encrypted
    ));
    assert_eq!(
        store.decrypt(browser_encrypted).unwrap(),
        "browser-credential"
    );

    let worker = stored.agents.get("worker").unwrap();
    let worker_encrypted = worker.api_key.as_deref().unwrap();
    assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
    assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");

    let _ = fs::remove_dir_all(&dir);
}
#[test]
fn config_save_atomic_cleanup() {
let dir =
@ -2182,6 +2396,7 @@ tool_dispatcher = "xml"
guild_id: Some("12345".into()),
allowed_users: vec![],
listen_to_bots: false,
mention_only: false,
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@ -2196,6 +2411,7 @@ tool_dispatcher = "xml"
guild_id: None,
allowed_users: vec![],
listen_to_bots: false,
mention_only: false,
};
let json = serde_json::to_string(&dc).unwrap();
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
@ -2818,6 +3034,96 @@ default_temperature = 0.7
std::env::remove_var("ZEROCLAW_WORKSPACE");
}
#[test]
// With ZEROCLAW_WORKSPACE pointing at a fresh directory (no legacy
// layout), config.toml is created inside the workspace root itself.
fn load_or_init_workspace_override_uses_workspace_root_for_config() {
    let _env_guard = env_override_test_guard();
    let temp_home =
        std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
    let workspace_dir = temp_home.join("profile-a");
    let original_home = std::env::var("HOME").ok();
    std::env::set_var("HOME", &temp_home);
    std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);

    let config = Config::load_or_init().unwrap();

    assert_eq!(config.workspace_dir, workspace_dir);
    assert_eq!(config.config_path, workspace_dir.join("config.toml"));
    assert!(workspace_dir.join("config.toml").exists());

    // Restore the process environment for subsequent tests.
    std::env::remove_var("ZEROCLAW_WORKSPACE");
    if let Some(home) = original_home {
        std::env::set_var("HOME", home);
    } else {
        std::env::remove_var("HOME");
    }
    let _ = fs::remove_dir_all(temp_home);
}
#[test]
// A workspace whose directory is literally named "workspace" keeps the
// legacy ~/.zeroclaw/config.toml location.
fn load_or_init_workspace_suffix_uses_legacy_config_layout() {
    let _env_guard = env_override_test_guard();
    let temp_home =
        std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
    let workspace_dir = temp_home.join("workspace");
    let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml");
    let original_home = std::env::var("HOME").ok();
    std::env::set_var("HOME", &temp_home);
    std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);

    let config = Config::load_or_init().unwrap();

    assert_eq!(config.workspace_dir, workspace_dir);
    assert_eq!(config.config_path, legacy_config_path);
    assert!(config.config_path.exists());

    // Restore the process environment for subsequent tests.
    std::env::remove_var("ZEROCLAW_WORKSPACE");
    if let Some(home) = original_home {
        std::env::set_var("HOME", home);
    } else {
        std::env::remove_var("HOME");
    }
    let _ = fs::remove_dir_all(temp_home);
}
#[test]
// An existing legacy ~/.zeroclaw/config.toml must be loaded (not
// overwritten) even when a custom workspace override is supplied.
fn load_or_init_workspace_override_keeps_existing_legacy_config() {
    let _env_guard = env_override_test_guard();
    let temp_home =
        std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
    let workspace_dir = temp_home.join("custom-workspace");
    let legacy_config_dir = temp_home.join(".zeroclaw");
    let legacy_config_path = legacy_config_dir.join("config.toml");
    fs::create_dir_all(&legacy_config_dir).unwrap();
    fs::write(
        &legacy_config_path,
        r#"default_temperature = 0.7
default_model = "legacy-model"
"#,
    )
    .unwrap();

    let original_home = std::env::var("HOME").ok();
    std::env::set_var("HOME", &temp_home);
    std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);

    let config = Config::load_or_init().unwrap();

    assert_eq!(config.workspace_dir, workspace_dir);
    assert_eq!(config.config_path, legacy_config_path);
    // Value from the pre-existing file proves it was read, not recreated.
    assert_eq!(config.default_model.as_deref(), Some("legacy-model"));

    // Restore the process environment for subsequent tests.
    std::env::remove_var("ZEROCLAW_WORKSPACE");
    if let Some(home) = original_home {
        std::env::set_var("HOME", home);
    } else {
        std::env::remove_var("HOME");
    }
    let _ = fs::remove_dir_all(temp_home);
}
#[test]
fn env_override_empty_values_ignored() {
let _env_guard = env_override_test_guard();
@ -2975,4 +3281,118 @@ default_temperature = 0.7
assert_eq!(parsed.boards[0].board, "nucleo-f401re");
assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0"));
}
#[test]
// JSON round-trip preserves every populated LarkConfig field.
fn lark_config_serde() {
    let lc = LarkConfig {
        app_id: "cli_123456".into(),
        app_secret: "secret_abc".into(),
        encrypt_key: Some("encrypt_key".into()),
        verification_token: Some("verify_token".into()),
        allowed_users: vec!["user_123".into(), "user_456".into()],
        use_feishu: true,
        receive_mode: LarkReceiveMode::Websocket,
        port: None,
    };
    let json = serde_json::to_string(&lc).unwrap();
    let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
    assert_eq!(parsed.app_id, "cli_123456");
    assert_eq!(parsed.app_secret, "secret_abc");
    assert_eq!(parsed.encrypt_key.as_deref(), Some("encrypt_key"));
    assert_eq!(parsed.verification_token.as_deref(), Some("verify_token"));
    assert_eq!(parsed.allowed_users.len(), 2);
    assert!(parsed.use_feishu);
}
#[test]
// TOML round-trip, exercising webhook receive mode with an explicit port.
fn lark_config_toml_roundtrip() {
    let lc = LarkConfig {
        app_id: "cli_123456".into(),
        app_secret: "secret_abc".into(),
        encrypt_key: Some("encrypt_key".into()),
        verification_token: Some("verify_token".into()),
        allowed_users: vec!["*".into()],
        use_feishu: false,
        receive_mode: LarkReceiveMode::Webhook,
        port: Some(9898),
    };
    let toml_str = toml::to_string(&lc).unwrap();
    let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
    assert_eq!(parsed.app_id, "cli_123456");
    assert_eq!(parsed.app_secret, "secret_abc");
    assert!(!parsed.use_feishu);
}
#[test]
// Optional fields take their serde defaults when absent from the JSON.
fn lark_config_deserializes_without_optional_fields() {
    let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
    let parsed: LarkConfig = serde_json::from_str(json).unwrap();
    assert!(parsed.encrypt_key.is_none());
    assert!(parsed.verification_token.is_none());
    assert!(parsed.allowed_users.is_empty());
    assert!(!parsed.use_feishu);
}
#[test]
// use_feishu defaults to false, i.e. the international Lark endpoint.
fn lark_config_defaults_to_lark_endpoint() {
    let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
    let parsed: LarkConfig = serde_json::from_str(json).unwrap();
    assert!(
        !parsed.use_feishu,
        "use_feishu should default to false (Lark)"
    );
}
#[test]
// The "*" wildcard entry survives deserialization of allowed_users.
fn lark_config_with_wildcard_allowed_users() {
    let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#;
    let parsed: LarkConfig = serde_json::from_str(json).unwrap();
    assert_eq!(parsed.allowed_users, vec!["*"]);
}
// ── Config file permission hardening (Unix only) ───────────────
#[cfg(unix)]
#[test]
// A freshly saved config, after the same chmod that load_or_init
// applies, must end up owner-read/write only (0600).
fn new_config_file_has_restricted_permissions() {
    use std::os::unix::fs::PermissionsExt;

    let tmp = tempfile::TempDir::new().unwrap();
    let config_path = tmp.path().join("config.toml");

    // Create a config and save it
    let mut config = Config::default();
    config.config_path = config_path.clone();
    config.save().unwrap();

    // Apply the same permission logic as load_or_init
    let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600));

    let meta = std::fs::metadata(&config_path).unwrap();
    let mode = meta.permissions().mode() & 0o777;
    assert_eq!(
        mode, 0o600,
        "New config file should be owner-only (0600), got {mode:o}"
    );
}
#[cfg(unix)]
#[test]
// Sanity check for the world-readable warning: the 0o004 bit test used
// by load_or_init does flag a 0644 file.
fn world_readable_config_is_detectable() {
    use std::os::unix::fs::PermissionsExt;

    let tmp = tempfile::TempDir::new().unwrap();
    let config_path = tmp.path().join("config.toml");

    // Create a config file with intentionally loose permissions
    std::fs::write(&config_path, "# test config").unwrap();
    std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap();

    let meta = std::fs::metadata(&config_path).unwrap();
    let mode = meta.permissions().mode();
    assert!(
        mode & 0o004 != 0,
        "Test setup: file should be world-readable (mode {mode:o})"
    );
}
}

View file

@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
dc.guild_id.clone(),
dc.allowed_users.clone(),
dc.listen_to_bots,
dc.mention_only,
);
channel.send(output, target).await?;
}

View file

@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool {
|| config.channels_config.matrix.is_some()
|| config.channels_config.whatsapp.is_some()
|| config.channels_config.email.is_some()
|| config.channels_config.lark.is_some()
}
#[cfg(test)]

View file

@ -49,6 +49,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
format!("whatsapp_{}_{}", msg.sender, msg.id)
}
/// Hex-encoded SHA-256 digest of a webhook secret.
///
/// The gateway stores and compares this hash rather than the plaintext
/// secret, so the secret itself never sits in long-lived process state.
fn hash_webhook_secret(value: &str) -> String {
    use sha2::{Digest, Sha256};
    hex::encode(Sha256::digest(value.as_bytes()))
}
/// How often the rate limiter sweeps stale IP entries from its map.
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
@ -178,7 +185,8 @@ pub struct AppState {
pub temperature: f64,
pub mem: Arc<dyn Memory>,
pub auto_save: bool,
pub webhook_secret: Option<Arc<str>>,
/// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
pub webhook_secret_hash: Option<Arc<str>>,
pub pairing: Arc<PairingGuard>,
pub rate_limiter: Arc<GatewayRateLimiter>,
pub idempotency_store: Arc<IdempotencyStore>,
@ -208,6 +216,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
config.default_provider.as_deref().unwrap_or("openrouter"),
config.api_key.as_deref(),
config.api_url.as_deref(),
&config.reliability,
)?);
let model = config
@ -251,12 +260,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
&config,
));
// Extract webhook secret for authentication
let webhook_secret: Option<Arc<str>> = config
.channels_config
.webhook
.as_ref()
.and_then(|w| w.secret.as_deref())
.map(Arc::from);
let webhook_secret_hash: Option<Arc<str>> =
config.channels_config.webhook.as_ref().and_then(|webhook| {
webhook.secret.as_ref().and_then(|raw_secret| {
let trimmed_secret = raw_secret.trim();
(!trimmed_secret.is_empty())
.then(|| Arc::<str>::from(hash_webhook_secret(trimmed_secret)))
})
});
// WhatsApp channel (if configured)
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
@ -342,9 +353,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
} else {
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
}
if webhook_secret.is_some() {
println!(" 🔒 Webhook secret: ENABLED");
}
println!(" Press Ctrl+C to stop.\n");
crate::health::mark_component_ok("gateway");
@ -356,7 +364,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
temperature,
mem,
auto_save: config.memory.auto_save,
webhook_secret,
webhook_secret_hash,
pairing,
rate_limiter,
idempotency_store,
@ -482,12 +490,15 @@ async fn handle_webhook(
}
// ── Webhook secret auth (optional, additional layer) ──
if let Some(ref secret) = state.webhook_secret {
let header_val = headers
if let Some(ref secret_hash) = state.webhook_secret_hash {
let header_hash = headers
.get("X-Webhook-Secret")
.and_then(|v| v.to_str().ok());
match header_val {
Some(val) if constant_time_eq(val, secret.as_ref()) => {}
.and_then(|v| v.to_str().ok())
.map(str::trim)
.filter(|value| !value.is_empty())
.map(hash_webhook_secret);
match header_hash {
Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
_ => {
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
@ -532,7 +543,7 @@ async fn handle_webhook(
let key = webhook_memory_key();
let _ = state
.mem
.store(&key, message, MemoryCategory::Conversation)
.store(&key, message, MemoryCategory::Conversation, None)
.await;
}
@ -685,7 +696,7 @@ async fn handle_whatsapp_message(
let key = whatsapp_memory_key(msg);
let _ = state
.mem
.store(&key, &msg.content, MemoryCategory::Conversation)
.store(&key, &msg.content, MemoryCategory::Conversation, None)
.await;
}
@ -697,7 +708,7 @@ async fn handle_whatsapp_message(
{
Ok(response) => {
// Send reply via WhatsApp
if let Err(e) = wa.send(&response, &msg.sender).await {
if let Err(e) = wa.send(&response, &msg.reply_target).await {
tracing::error!("Failed to send WhatsApp reply: {e}");
}
}
@ -706,7 +717,7 @@ async fn handle_whatsapp_message(
let _ = wa
.send(
"Sorry, I couldn't process your message right now.",
&msg.sender,
&msg.reply_target,
)
.await;
}
@ -798,7 +809,9 @@ mod tests {
.requests
.lock()
.unwrap_or_else(std::sync::PoisonError::into_inner);
guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1);
guard.1 = Instant::now()
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
.unwrap();
// Clear timestamps for ip-2 and ip-3 to simulate stale entries
guard.0.get_mut("ip-2").unwrap().clear();
guard.0.get_mut("ip-3").unwrap().clear();
@ -848,6 +861,7 @@ mod tests {
let msg = ChannelMessage {
id: "wamid-123".into(),
sender: "+1234567890".into(),
reply_target: "+1234567890".into(),
content: "hello".into(),
channel: "whatsapp".into(),
timestamp: 1,
@ -871,11 +885,17 @@ mod tests {
_key: &str,
_content: &str,
_category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
async fn recall(
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@ -886,6 +906,7 @@ mod tests {
async fn list(
&self,
_category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@ -938,6 +959,7 @@ mod tests {
key: &str,
_content: &str,
_category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
self.keys
.lock()
@ -946,7 +968,12 @@ mod tests {
Ok(())
}
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
async fn recall(
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@ -957,6 +984,7 @@ mod tests {
async fn list(
&self,
_category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@ -991,7 +1019,7 @@ mod tests {
temperature: 0.0,
mem: memory,
auto_save: false,
webhook_secret: None,
webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@ -1039,7 +1067,7 @@ mod tests {
temperature: 0.0,
mem: memory,
auto_save: true,
webhook_secret: None,
webhook_secret_hash: None,
pairing: Arc::new(PairingGuard::new(false, &[])),
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
@ -1077,6 +1105,125 @@ mod tests {
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
}
#[test]
fn webhook_secret_hash_is_deterministic_and_nonempty() {
    // Same input must always hash to the same value; distinct inputs must not.
    let first = hash_webhook_secret("secret-value");
    assert_eq!(first, hash_webhook_secret("secret-value"));
    assert_ne!(first, hash_webhook_secret("other-value"));
    // A hex-encoded SHA-256 digest is exactly 64 characters.
    assert_eq!(first.len(), 64);
}
#[tokio::test]
async fn webhook_secret_hash_rejects_missing_header() {
    // With a secret hash configured, a request that omits the
    // `X-Webhook-Secret` header must be rejected with 401 and the
    // provider must never be called.
    let mock_provider = Arc::new(MockProvider::default());
    let provider: Arc<dyn Provider> = mock_provider.clone();
    let mem: Arc<dyn Memory> = Arc::new(MockMemory);
    let app_state = AppState {
        provider,
        mem,
        model: "test-model".into(),
        temperature: 0.0,
        auto_save: false,
        webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
        pairing: Arc::new(PairingGuard::new(false, &[])),
        rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
        idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
        whatsapp: None,
        whatsapp_app_secret: None,
    };

    let body = Ok(Json(WebhookBody {
        message: "hello".into(),
    }));
    let resp = handle_webhook(State(app_state), HeaderMap::new(), body)
        .await
        .into_response();

    assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    assert_eq!(mock_provider.calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn webhook_secret_hash_rejects_invalid_header() {
    // A wrong `X-Webhook-Secret` value must be rejected with 401 and the
    // provider must never be called.
    let mock_provider = Arc::new(MockProvider::default());
    let provider: Arc<dyn Provider> = mock_provider.clone();
    let mem: Arc<dyn Memory> = Arc::new(MockMemory);
    let app_state = AppState {
        provider,
        mem,
        model: "test-model".into(),
        temperature: 0.0,
        auto_save: false,
        webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
        pairing: Arc::new(PairingGuard::new(false, &[])),
        rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
        idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
        whatsapp: None,
        whatsapp_app_secret: None,
    };

    let mut headers = HeaderMap::new();
    headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
    let body = Ok(Json(WebhookBody {
        message: "hello".into(),
    }));
    let resp = handle_webhook(State(app_state), headers, body)
        .await
        .into_response();

    assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
    assert_eq!(mock_provider.calls.load(Ordering::SeqCst), 0);
}
#[tokio::test]
async fn webhook_secret_hash_accepts_valid_header() {
    // The correct plaintext secret in `X-Webhook-Secret` must authenticate:
    // the handler returns 200 and forwards exactly one call to the provider.
    let mock_provider = Arc::new(MockProvider::default());
    let provider: Arc<dyn Provider> = mock_provider.clone();
    let mem: Arc<dyn Memory> = Arc::new(MockMemory);
    let app_state = AppState {
        provider,
        mem,
        model: "test-model".into(),
        temperature: 0.0,
        auto_save: false,
        webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
        pairing: Arc::new(PairingGuard::new(false, &[])),
        rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
        idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
        whatsapp: None,
        whatsapp_app_secret: None,
    };

    let mut headers = HeaderMap::new();
    headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
    let body = Ok(Json(WebhookBody {
        message: "hello".into(),
    }));
    let resp = handle_webhook(State(app_state), headers, body)
        .await
        .into_response();

    assert_eq!(resp.status(), StatusCode::OK);
    assert_eq!(mock_provider.calls.load(Ordering::SeqCst), 1);
}
// ══════════════════════════════════════════════════════════
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
// ══════════════════════════════════════════════════════════

View file

@ -34,8 +34,8 @@
use anyhow::{bail, Result};
use clap::{Parser, Subcommand};
use tracing::{info, Level};
use tracing_subscriber::FmtSubscriber;
use tracing::info;
use tracing_subscriber::{fmt, EnvFilter};
mod agent;
mod channels;
@ -147,24 +147,24 @@ enum Commands {
/// Start the gateway server (webhooks, websockets)
Gateway {
/// Port to listen on (use 0 for random available port)
#[arg(short, long, default_value = "8080")]
port: u16,
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
#[arg(short, long)]
port: Option<u16>,
/// Host to bind to
#[arg(long, default_value = "127.0.0.1")]
host: String,
/// Host to bind to; defaults to config gateway.host
#[arg(long)]
host: Option<String>,
},
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
Daemon {
/// Port to listen on (use 0 for random available port)
#[arg(short, long, default_value = "8080")]
port: u16,
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
#[arg(short, long)]
port: Option<u16>,
/// Host to bind to
#[arg(long, default_value = "127.0.0.1")]
host: String,
/// Host to bind to; defaults to config gateway.host
#[arg(long)]
host: Option<String>,
},
/// Manage OS service lifecycle (launchd/systemd user service)
@ -367,9 +367,11 @@ async fn main() -> Result<()> {
let cli = Cli::parse();
// Initialize logging
let subscriber = FmtSubscriber::builder()
.with_max_level(Level::INFO)
// Initialize logging - respects RUST_LOG env var, defaults to INFO
let subscriber = fmt::Subscriber::builder()
.with_env_filter(
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
)
.finish();
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
@ -434,6 +436,8 @@ async fn main() -> Result<()> {
.map(|_| ()),
Commands::Gateway { port, host } => {
let port = port.unwrap_or(config.gateway.port);
let host = host.unwrap_or_else(|| config.gateway.host.clone());
if port == 0 {
info!("🚀 Starting ZeroClaw Gateway on {host} (random port)");
} else {
@ -443,6 +447,8 @@ async fn main() -> Result<()> {
}
Commands::Daemon { port, host } => {
let port = port.unwrap_or(config.gateway.port);
let host = host.unwrap_or_else(|| config.gateway.host.clone());
if port == 0 {
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
} else {

View file

@ -7,6 +7,7 @@ pub enum MemoryBackendKind {
Unknown,
}
#[allow(clippy::struct_excessive_bools)]
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub struct MemoryBackendProfile {
pub key: &'static str,

View file

@ -502,10 +502,10 @@ mod tests {
let workspace = tmp.path();
let mem = SqliteMemory::new(workspace).unwrap();
mem.store("conv_old", "outdated", MemoryCategory::Conversation)
mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
.await
.unwrap();
mem.store("core_keep", "durable", MemoryCategory::Core)
mem.store("core_keep", "durable", MemoryCategory::Core, None)
.await
.unwrap();
drop(mem);

View file

@ -24,7 +24,9 @@ pub struct LucidMemory {
impl LucidMemory {
const DEFAULT_LUCID_CMD: &'static str = "lucid";
const DEFAULT_TOKEN_BUDGET: usize = 200;
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120;
// Lucid CLI cold start can exceed 120ms on slower machines, which causes
// avoidable fallback to local-only memory and premature cooldown.
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500;
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
@ -74,6 +76,7 @@ impl LucidMemory {
}
#[cfg(test)]
#[allow(clippy::too_many_arguments)]
fn with_options(
workspace_dir: &Path,
local: SqliteMemory,
@ -307,14 +310,22 @@ impl Memory for LucidMemory {
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()> {
self.local.store(key, content, category.clone()).await?;
self.local
.store(key, content, category.clone(), session_id)
.await?;
self.sync_to_lucid_async(key, content, &category).await;
Ok(())
}
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
let local_results = self.local.recall(query, limit).await?;
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let local_results = self.local.recall(query, limit, session_id).await?;
if limit == 0
|| local_results.len() >= limit
|| local_results.len() >= self.local_hit_threshold
@ -351,8 +362,12 @@ impl Memory for LucidMemory {
self.local.get(key).await
}
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
self.local.list(category).await
async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
self.local.list(category, session_id).await
}
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
@ -396,6 +411,38 @@ EOF
exit 0
fi
echo "unsupported command" >&2
exit 1
"#;
fs::write(&script_path, script).unwrap();
let mut perms = fs::metadata(&script_path).unwrap().permissions();
perms.set_mode(0o755);
fs::set_permissions(&script_path, perms).unwrap();
script_path.display().to_string()
}
fn write_delayed_lucid_script(dir: &Path) -> String {
let script_path = dir.join("delayed-lucid.sh");
let script = r#"#!/usr/bin/env bash
set -euo pipefail
if [[ "${1:-}" == "store" ]]; then
echo '{"success":true,"id":"mem_1"}'
exit 0
fi
if [[ "${1:-}" == "context" ]]; then
# Simulate a cold start that is slower than 120ms but below the 500ms timeout.
sleep 0.2
cat <<'EOF'
<lucid-context>
- [decision] Delayed token refresh guidance
</lucid-context>
EOF
exit 0
fi
echo "unsupported command" >&2
exit 1
"#;
@ -449,7 +496,7 @@ exit 1
cmd,
200,
3,
Duration::from_millis(120),
Duration::from_millis(500),
Duration::from_millis(400),
Duration::from_secs(2),
)
@ -468,7 +515,7 @@ exit 1
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
memory
.store("lang", "User prefers Rust", MemoryCategory::Core)
.store("lang", "User prefers Rust", MemoryCategory::Core, None)
.await
.unwrap();
@ -483,6 +530,30 @@ exit 1
let fake_cmd = write_fake_lucid_script(tmp.path());
let memory = test_memory(tmp.path(), fake_cmd);
memory
.store(
"local_note",
"Local sqlite auth fallback note",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entries = memory.recall("auth", 5, None).await.unwrap();
assert!(entries
.iter()
.any(|e| e.content.contains("Local sqlite auth fallback note")));
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
}
#[tokio::test]
async fn recall_handles_lucid_cold_start_delay_within_timeout() {
let tmp = TempDir::new().unwrap();
let delayed_cmd = write_delayed_lucid_script(tmp.path());
let memory = test_memory(tmp.path(), delayed_cmd);
memory
.store(
"local_note",
@ -497,7 +568,9 @@ exit 1
assert!(entries
.iter()
.any(|e| e.content.contains("Local sqlite auth fallback note")));
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
assert!(entries
.iter()
.any(|e| e.content.contains("Delayed token refresh guidance")));
}
#[tokio::test]
@ -513,17 +586,22 @@ exit 1
probe_cmd,
200,
1,
Duration::from_millis(120),
Duration::from_millis(500),
Duration::from_millis(400),
Duration::from_secs(2),
);
memory
.store("pref", "Rust should stay local-first", MemoryCategory::Core)
.store(
"pref",
"Rust should stay local-first",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entries = memory.recall("rust", 5).await.unwrap();
let entries = memory.recall("rust", 5, None).await.unwrap();
assert!(entries
.iter()
.any(|e| e.content.contains("Rust should stay local-first")));
@ -578,13 +656,13 @@ exit 1
failing_cmd,
200,
99,
Duration::from_millis(120),
Duration::from_millis(500),
Duration::from_millis(400),
Duration::from_secs(5),
);
let first = memory.recall("auth", 5).await.unwrap();
let second = memory.recall("auth", 5).await.unwrap();
let first = memory.recall("auth", 5, None).await.unwrap();
let second = memory.recall("auth", 5, None).await.unwrap();
assert!(first.is_empty());
assert!(second.is_empty());

View file

@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
key: &str,
content: &str,
category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
let entry = format!("- **{key}**: {content}");
let path = match category {
@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
self.append_to_file(&path, &entry).await
}
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
async fn recall(
&self,
query: &str,
limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let all = self.read_all_entries().await?;
let query_lower = query.to_lowercase();
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
.find(|e| e.key == key || e.content.contains(key)))
}
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
async fn list(
&self,
category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let all = self.read_all_entries().await?;
match category {
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
@ -243,7 +253,7 @@ mod tests {
#[tokio::test]
async fn markdown_store_core() {
let (_tmp, mem) = temp_workspace();
mem.store("pref", "User likes Rust", MemoryCategory::Core)
mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
.await
.unwrap();
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
@ -253,7 +263,7 @@ mod tests {
#[tokio::test]
async fn markdown_store_daily() {
let (_tmp, mem) = temp_workspace();
mem.store("note", "Finished tests", MemoryCategory::Daily)
mem.store("note", "Finished tests", MemoryCategory::Daily, None)
.await
.unwrap();
let path = mem.daily_path();
@ -264,17 +274,17 @@ mod tests {
#[tokio::test]
async fn markdown_recall_keyword() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "Rust is fast", MemoryCategory::Core)
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "Python is slow", MemoryCategory::Core)
mem.store("b", "Python is slow", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "Rust and safety", MemoryCategory::Core)
mem.store("c", "Rust and safety", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
assert!(results.len() >= 2);
assert!(results
.iter()
@ -284,18 +294,20 @@ mod tests {
#[tokio::test]
async fn markdown_recall_no_match() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "Rust is great", MemoryCategory::Core)
mem.store("a", "Rust is great", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("javascript", 10).await.unwrap();
let results = mem.recall("javascript", 10, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn markdown_count() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
mem.store("b", "second", MemoryCategory::Core)
mem.store("a", "first", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "second", MemoryCategory::Core, None)
.await
.unwrap();
let count = mem.count().await.unwrap();
@ -305,24 +317,24 @@ mod tests {
#[tokio::test]
async fn markdown_list_by_category() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "core fact", MemoryCategory::Core)
mem.store("a", "core fact", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "daily note", MemoryCategory::Daily)
mem.store("b", "daily note", MemoryCategory::Daily, None)
.await
.unwrap();
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
}
#[tokio::test]
async fn markdown_forget_is_noop() {
let (_tmp, mem) = temp_workspace();
mem.store("a", "permanent", MemoryCategory::Core)
mem.store("a", "permanent", MemoryCategory::Core, None)
.await
.unwrap();
let removed = mem.forget("a").await.unwrap();
@ -332,7 +344,7 @@ mod tests {
#[tokio::test]
async fn markdown_empty_recall() {
let (_tmp, mem) = temp_workspace();
let results = mem.recall("anything", 10).await.unwrap();
let results = mem.recall("anything", 10, None).await.unwrap();
assert!(results.is_empty());
}

View file

@ -25,11 +25,17 @@ impl Memory for NoneMemory {
_key: &str,
_content: &str,
_category: MemoryCategory,
_session_id: Option<&str>,
) -> anyhow::Result<()> {
Ok(())
}
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
async fn recall(
&self,
_query: &str,
_limit: usize,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@ -37,7 +43,11 @@ impl Memory for NoneMemory {
Ok(None)
}
async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
async fn list(
&self,
_category: Option<&MemoryCategory>,
_session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
Ok(Vec::new())
}
@ -62,11 +72,14 @@ mod tests {
async fn none_memory_is_noop() {
let memory = NoneMemory::new();
memory.store("k", "v", MemoryCategory::Core).await.unwrap();
memory
.store("k", "v", MemoryCategory::Core, None)
.await
.unwrap();
assert!(memory.get("k").await.unwrap().is_none());
assert!(memory.recall("k", 10).await.unwrap().is_empty());
assert!(memory.list(None).await.unwrap().is_empty());
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
assert!(memory.list(None, None).await.unwrap().is_empty());
assert!(!memory.forget("k").await.unwrap());
assert_eq!(memory.count().await.unwrap(), 0);
assert!(memory.health_check().await);

View file

@ -157,7 +157,7 @@ impl ResponseCache {
|row| row.get(0),
)?;
#[allow(clippy::cast_sign_loss)]
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
Ok((count as usize, hits as u64, tokens_saved as u64))
}

View file

@ -124,6 +124,19 @@ impl SqliteMemory {
);
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
)?;
// Migration: add session_id column if not present (safe to run repeatedly)
let has_session_id: bool = conn
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
.query_row([], |row| row.get::<_, String>(0))?
.contains("session_id");
if !has_session_id {
conn.execute_batch(
"ALTER TABLE memories ADD COLUMN session_id TEXT;
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
)?;
}
Ok(())
}
@ -361,6 +374,7 @@ impl Memory for SqliteMemory {
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()> {
// Compute embedding (async, before lock)
let embedding_bytes = self
@ -377,20 +391,26 @@ impl Memory for SqliteMemory {
let id = Uuid::new_v4().to_string();
conn.execute(
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
ON CONFLICT(key) DO UPDATE SET
content = excluded.content,
category = excluded.category,
embedding = excluded.embedding,
updated_at = excluded.updated_at",
params![id, key, content, cat, embedding_bytes, now, now],
updated_at = excluded.updated_at,
session_id = excluded.session_id",
params![id, key, content, cat, embedding_bytes, now, now, session_id],
)?;
Ok(())
}
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
if query.trim().is_empty() {
return Ok(Vec::new());
}
@ -439,7 +459,7 @@ impl Memory for SqliteMemory {
let mut results = Vec::new();
for scored in &merged {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
)?;
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
Ok(MemoryEntry {
@ -448,10 +468,16 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
session_id: row.get(5)?,
score: Some(f64::from(scored.final_score)),
})
}) {
// Filter by session_id if requested
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
}
}
@ -470,7 +496,7 @@ impl Memory for SqliteMemory {
.collect();
let where_clause = conditions.join(" OR ");
let sql = format!(
"SELECT id, key, content, category, created_at FROM memories
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE {where_clause}
ORDER BY updated_at DESC
LIMIT ?{}",
@ -493,12 +519,18 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
session_id: row.get(5)?,
score: Some(1.0),
})
})?;
for row in rows {
results.push(row?);
let entry = row?;
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
}
}
}
@ -514,7 +546,7 @@ impl Memory for SqliteMemory {
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
)?;
let mut rows = stmt.query_map(params![key], |row| {
@ -524,7 +556,7 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
session_id: row.get(5)?,
score: None,
})
})?;
@ -535,7 +567,11 @@ impl Memory for SqliteMemory {
}
}
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>> {
let conn = self
.conn
.lock()
@ -550,7 +586,7 @@ impl Memory for SqliteMemory {
content: row.get(2)?,
category: Self::str_to_category(&row.get::<_, String>(3)?),
timestamp: row.get(4)?,
session_id: None,
session_id: row.get(5)?,
score: None,
})
};
@ -558,21 +594,33 @@ impl Memory for SqliteMemory {
if let Some(cat) = category {
let cat_str = Self::category_to_str(cat);
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories
"SELECT id, key, content, category, created_at, session_id FROM memories
WHERE category = ?1 ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map(params![cat_str], row_mapper)?;
for row in rows {
results.push(row?);
let entry = row?;
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
}
} else {
let mut stmt = conn.prepare(
"SELECT id, key, content, category, created_at FROM memories
"SELECT id, key, content, category, created_at, session_id FROM memories
ORDER BY updated_at DESC",
)?;
let rows = stmt.query_map([], row_mapper)?;
for row in rows {
results.push(row?);
let entry = row?;
if let Some(sid) = session_id {
if entry.session_id.as_deref() != Some(sid) {
continue;
}
}
results.push(entry);
}
}
@ -632,7 +680,7 @@ mod tests {
#[tokio::test]
async fn sqlite_store_and_get() {
let (_tmp, mem) = temp_sqlite();
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
.await
.unwrap();
@ -647,10 +695,10 @@ mod tests {
#[tokio::test]
async fn sqlite_store_upsert() {
let (_tmp, mem) = temp_sqlite();
mem.store("pref", "likes Rust", MemoryCategory::Core)
mem.store("pref", "likes Rust", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("pref", "loves Rust", MemoryCategory::Core)
mem.store("pref", "loves Rust", MemoryCategory::Core, None)
.await
.unwrap();
@ -662,17 +710,22 @@ mod tests {
#[tokio::test]
async fn sqlite_recall_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "Python is interpreted", MemoryCategory::Core)
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
mem.store(
"c",
"Rust has zero-cost abstractions",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
assert_eq!(results.len(), 2);
assert!(results
.iter()
@ -682,14 +735,14 @@ mod tests {
#[tokio::test]
async fn sqlite_recall_multi_keyword() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust is fast", MemoryCategory::Core)
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("fast safe", 10).await.unwrap();
let results = mem.recall("fast safe", 10, None).await.unwrap();
assert!(!results.is_empty());
// Entry with both keywords should score higher
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
@ -698,17 +751,17 @@ mod tests {
#[tokio::test]
async fn sqlite_recall_no_match() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "Rust rocks", MemoryCategory::Core)
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("javascript", 10).await.unwrap();
let results = mem.recall("javascript", 10, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn sqlite_forget() {
let (_tmp, mem) = temp_sqlite();
mem.store("temp", "temporary data", MemoryCategory::Conversation)
mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
.await
.unwrap();
assert_eq!(mem.count().await.unwrap(), 1);
@ -728,29 +781,37 @@ mod tests {
#[tokio::test]
async fn sqlite_list_all() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
mem.store("c", "three", MemoryCategory::Conversation)
mem.store("a", "one", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "two", MemoryCategory::Daily, None)
.await
.unwrap();
mem.store("c", "three", MemoryCategory::Conversation, None)
.await
.unwrap();
let all = mem.list(None).await.unwrap();
let all = mem.list(None, None).await.unwrap();
assert_eq!(all.len(), 3);
}
#[tokio::test]
async fn sqlite_list_by_category() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
mem.store("c", "daily1", MemoryCategory::Daily)
mem.store("a", "core1", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "core2", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "daily1", MemoryCategory::Daily, None)
.await
.unwrap();
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
assert_eq!(core.len(), 2);
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
assert_eq!(daily.len(), 1);
}
@ -772,7 +833,7 @@ mod tests {
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("persist", "I survive restarts", MemoryCategory::Core)
mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
.await
.unwrap();
}
@ -795,7 +856,7 @@ mod tests {
];
for (i, cat) in categories.iter().enumerate() {
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
.await
.unwrap();
}
@ -815,21 +876,28 @@ mod tests {
"a",
"Rust is a systems programming language",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
mem.store(
"b",
"Python is great for scripting",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store(
"c",
"Rust and Rust and Rust everywhere",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let results = mem.recall("Rust", 10).await.unwrap();
let results = mem.recall("Rust", 10, None).await.unwrap();
assert!(results.len() >= 2);
// All results should contain "Rust"
for r in &results {
@ -844,17 +912,17 @@ mod tests {
#[tokio::test]
async fn fts5_multi_word_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("quick dog", 10).await.unwrap();
let results = mem.recall("quick dog", 10, None).await.unwrap();
assert!(!results.is_empty());
// "The quick dog runs fast" matches both terms
assert!(results[0].content.contains("quick"));
@ -863,16 +931,20 @@ mod tests {
#[tokio::test]
async fn recall_empty_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
let results = mem.recall("", 10).await.unwrap();
mem.store("a", "data", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("", 10, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_whitespace_query_returns_empty() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
let results = mem.recall(" ", 10).await.unwrap();
mem.store("a", "data", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall(" ", 10, None).await.unwrap();
assert!(results.is_empty());
}
@ -937,7 +1009,12 @@ mod tests {
#[tokio::test]
async fn fts5_syncs_on_insert() {
let (_tmp, mem) = temp_sqlite();
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
mem.store(
"test_key",
"unique_searchterm_xyz",
MemoryCategory::Core,
None,
)
.await
.unwrap();
@ -955,7 +1032,12 @@ mod tests {
#[tokio::test]
async fn fts5_syncs_on_delete() {
let (_tmp, mem) = temp_sqlite();
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
mem.store(
"del_key",
"deletable_content_abc",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.forget("del_key").await.unwrap();
@ -974,10 +1056,15 @@ mod tests {
#[tokio::test]
async fn fts5_syncs_on_update() {
let (_tmp, mem) = temp_sqlite();
mem.store("upd_key", "original_content_111", MemoryCategory::Core)
mem.store(
"upd_key",
"original_content_111",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
.await
.unwrap();
@ -1019,10 +1106,10 @@ mod tests {
#[tokio::test]
async fn reindex_rebuilds_fts() {
let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex test alpha", MemoryCategory::Core)
mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("r2", "reindex test beta", MemoryCategory::Core)
mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
.await
.unwrap();
@ -1031,7 +1118,7 @@ mod tests {
assert_eq!(count, 0);
// FTS should still work after rebuild
let results = mem.recall("reindex", 10).await.unwrap();
let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 2);
}
@ -1045,12 +1132,13 @@ mod tests {
&format!("k{i}"),
&format!("common keyword item {i}"),
MemoryCategory::Core,
None,
)
.await
.unwrap();
}
let results = mem.recall("common keyword", 5).await.unwrap();
let results = mem.recall("common keyword", 5, None).await.unwrap();
assert!(results.len() <= 5);
}
@ -1059,11 +1147,11 @@ mod tests {
#[tokio::test]
async fn recall_results_have_scores() {
let (_tmp, mem) = temp_sqlite();
mem.store("s1", "scored result test", MemoryCategory::Core)
mem.store("s1", "scored result test", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("scored", 10).await.unwrap();
let results = mem.recall("scored", 10, None).await.unwrap();
assert!(!results.is_empty());
for r in &results {
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
@ -1075,11 +1163,11 @@ mod tests {
#[tokio::test]
async fn recall_with_quotes_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("q1", "He said hello world", MemoryCategory::Core)
mem.store("q1", "He said hello world", MemoryCategory::Core, None)
.await
.unwrap();
// Quotes in query should not crash FTS5
let results = mem.recall("\"hello\"", 10).await.unwrap();
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
// May or may not match depending on FTS5 escaping, but must not error
assert!(results.len() <= 10);
}
@ -1087,31 +1175,34 @@ mod tests {
#[tokio::test]
async fn recall_with_asterisk_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a1", "wildcard test content", MemoryCategory::Core)
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("wild*", 10).await.unwrap();
let results = mem.recall("wild*", 10, None).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_parentheses_in_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("p1", "function call test", MemoryCategory::Core)
mem.store("p1", "function call test", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("function()", 10).await.unwrap();
let results = mem.recall("function()", 10, None).await.unwrap();
assert!(results.len() <= 10);
}
#[tokio::test]
async fn recall_with_sql_injection_attempt() {
let (_tmp, mem) = temp_sqlite();
mem.store("safe", "normal content", MemoryCategory::Core)
mem.store("safe", "normal content", MemoryCategory::Core, None)
.await
.unwrap();
// Should not crash or leak data
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
let results = mem
.recall("'; DROP TABLE memories; --", 10, None)
.await
.unwrap();
assert!(results.len() <= 10);
// Table should still exist
assert_eq!(mem.count().await.unwrap(), 1);
@ -1122,7 +1213,9 @@ mod tests {
#[tokio::test]
async fn store_empty_content() {
let (_tmp, mem) = temp_sqlite();
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
mem.store("empty", "", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("empty").await.unwrap().unwrap();
assert_eq!(entry.content, "");
}
@ -1130,7 +1223,7 @@ mod tests {
#[tokio::test]
async fn store_empty_key() {
let (_tmp, mem) = temp_sqlite();
mem.store("", "content for empty key", MemoryCategory::Core)
mem.store("", "content for empty key", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("").await.unwrap().unwrap();
@ -1141,7 +1234,7 @@ mod tests {
async fn store_very_long_content() {
let (_tmp, mem) = temp_sqlite();
let long_content = "x".repeat(100_000);
mem.store("long", &long_content, MemoryCategory::Core)
mem.store("long", &long_content, MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("long").await.unwrap().unwrap();
@ -1151,7 +1244,12 @@ mod tests {
#[tokio::test]
async fn store_unicode_and_emoji() {
let (_tmp, mem) = temp_sqlite();
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
mem.store(
"emoji_key_🦀",
"こんにちは 🚀 Ñoño",
MemoryCategory::Core,
None,
)
.await
.unwrap();
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
@ -1162,7 +1260,7 @@ mod tests {
async fn store_content_with_newlines_and_tabs() {
let (_tmp, mem) = temp_sqlite();
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
mem.store("whitespace", content, MemoryCategory::Core)
mem.store("whitespace", content, MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("whitespace").await.unwrap().unwrap();
@ -1174,11 +1272,11 @@ mod tests {
#[tokio::test]
async fn recall_single_character_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "x marks the spot", MemoryCategory::Core)
mem.store("a", "x marks the spot", MemoryCategory::Core, None)
.await
.unwrap();
// Single char may not match FTS5 but LIKE fallback should work
let results = mem.recall("x", 10).await.unwrap();
let results = mem.recall("x", 10, None).await.unwrap();
// Should not crash; may or may not find results
assert!(results.len() <= 10);
}
@ -1186,23 +1284,23 @@ mod tests {
#[tokio::test]
async fn recall_limit_zero() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "some content", MemoryCategory::Core)
mem.store("a", "some content", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("some", 0).await.unwrap();
let results = mem.recall("some", 0, None).await.unwrap();
assert!(results.is_empty());
}
#[tokio::test]
async fn recall_limit_one() {
let (_tmp, mem) = temp_sqlite();
mem.store("a", "matching content alpha", MemoryCategory::Core)
mem.store("a", "matching content alpha", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("b", "matching content beta", MemoryCategory::Core)
mem.store("b", "matching content beta", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("matching content", 1).await.unwrap();
let results = mem.recall("matching content", 1, None).await.unwrap();
assert_eq!(results.len(), 1);
}
@ -1213,21 +1311,22 @@ mod tests {
"rust_preferences",
"User likes systems programming",
MemoryCategory::Core,
None,
)
.await
.unwrap();
// "rust" appears in key but not content — LIKE fallback checks key too
let results = mem.recall("rust", 10).await.unwrap();
let results = mem.recall("rust", 10, None).await.unwrap();
assert!(!results.is_empty(), "Should match by key");
}
#[tokio::test]
async fn recall_unicode_query() {
let (_tmp, mem) = temp_sqlite();
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
.await
.unwrap();
let results = mem.recall("日本語", 10).await.unwrap();
let results = mem.recall("日本語", 10, None).await.unwrap();
assert!(!results.is_empty());
}
@ -1238,7 +1337,9 @@ mod tests {
let tmp = TempDir::new().unwrap();
{
let mem = SqliteMemory::new(tmp.path()).unwrap();
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
mem.store("k1", "v1", MemoryCategory::Core, None)
.await
.unwrap();
}
// Open again — init_schema runs again on existing DB
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
@ -1246,7 +1347,9 @@ mod tests {
assert!(entry.is_some());
assert_eq!(entry.unwrap().content, "v1");
// Store more data — should work fine
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
mem2.store("k2", "v2", MemoryCategory::Daily, None)
.await
.unwrap();
assert_eq!(mem2.count().await.unwrap(), 2);
}
@ -1264,11 +1367,16 @@ mod tests {
#[tokio::test]
async fn forget_then_recall_no_ghost_results() {
let (_tmp, mem) = temp_sqlite();
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
mem.store(
"ghost",
"phantom memory content",
MemoryCategory::Core,
None,
)
.await
.unwrap();
mem.forget("ghost").await.unwrap();
let results = mem.recall("phantom memory", 10).await.unwrap();
let results = mem.recall("phantom memory", 10, None).await.unwrap();
assert!(
results.is_empty(),
"Deleted memory should not appear in recall"
@ -1278,11 +1386,11 @@ mod tests {
#[tokio::test]
async fn forget_and_re_store_same_key() {
let (_tmp, mem) = temp_sqlite();
mem.store("cycle", "version 1", MemoryCategory::Core)
mem.store("cycle", "version 1", MemoryCategory::Core, None)
.await
.unwrap();
mem.forget("cycle").await.unwrap();
mem.store("cycle", "version 2", MemoryCategory::Core)
mem.store("cycle", "version 2", MemoryCategory::Core, None)
.await
.unwrap();
let entry = mem.get("cycle").await.unwrap().unwrap();
@ -1302,14 +1410,14 @@ mod tests {
#[tokio::test]
async fn reindex_twice_is_safe() {
let (_tmp, mem) = temp_sqlite();
mem.store("r1", "reindex data", MemoryCategory::Core)
mem.store("r1", "reindex data", MemoryCategory::Core, None)
.await
.unwrap();
mem.reindex().await.unwrap();
let count = mem.reindex().await.unwrap();
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
// Data should still be intact
let results = mem.recall("reindex", 10).await.unwrap();
let results = mem.recall("reindex", 10, None).await.unwrap();
assert_eq!(results.len(), 1);
}
@ -1363,18 +1471,28 @@ mod tests {
#[tokio::test]
async fn list_custom_category() {
let (_tmp, mem) = temp_sqlite();
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
mem.store(
"c1",
"custom1",
MemoryCategory::Custom("project".into()),
None,
)
.await
.unwrap();
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
mem.store(
"c2",
"custom2",
MemoryCategory::Custom("project".into()),
None,
)
.await
.unwrap();
mem.store("c3", "other", MemoryCategory::Core)
mem.store("c3", "other", MemoryCategory::Core, None)
.await
.unwrap();
let project = mem
.list(Some(&MemoryCategory::Custom("project".into())))
.list(Some(&MemoryCategory::Custom("project".into())), None)
.await
.unwrap();
assert_eq!(project.len(), 2);
@ -1383,7 +1501,122 @@ mod tests {
#[tokio::test]
async fn list_empty_db() {
let (_tmp, mem) = temp_sqlite();
let all = mem.list(None).await.unwrap();
let all = mem.list(None, None).await.unwrap();
assert!(all.is_empty());
}
// ── Session isolation ─────────────────────────────────────────
#[tokio::test]
async fn store_and_recall_with_session_id() {
    let (_tmp, mem) = temp_sqlite();
    // Seed one entry per scope: session A, session B, and unscoped.
    for (key, content, session) in [
        ("k1", "session A fact", Some("sess-a")),
        ("k2", "session B fact", Some("sess-b")),
        ("k3", "no session fact", None),
    ] {
        mem.store(key, content, MemoryCategory::Core, session)
            .await
            .unwrap();
    }
    // Filtering recall by session A must surface only the session-A entry.
    let hits = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "k1");
    assert_eq!(hits[0].session_id.as_deref(), Some("sess-a"));
}
#[tokio::test]
async fn recall_no_session_filter_returns_all() {
    let (_tmp, mem) = temp_sqlite();
    // Entries spread across two sessions plus one unscoped entry.
    for (key, content, session) in [
        ("k1", "alpha fact", Some("sess-a")),
        ("k2", "beta fact", Some("sess-b")),
        ("k3", "gamma fact", None),
    ] {
        mem.store(key, content, MemoryCategory::Core, session)
            .await
            .unwrap();
    }
    // A `None` session filter acts as a wildcard: every matching entry
    // is returned regardless of session scope.
    let hits = mem.recall("fact", 10, None).await.unwrap();
    assert_eq!(hits.len(), 3);
}
#[tokio::test]
async fn cross_session_recall_isolation() {
    let (_tmp, mem) = temp_sqlite();
    mem.store(
        "secret",
        "session A secret data",
        MemoryCategory::Core,
        Some("sess-a"),
    )
    .await
    .unwrap();
    // Recalling under a different session must not leak session A's data.
    let foreign = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
    assert!(foreign.is_empty());
    // The owning session still sees its own entry.
    let own = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
    assert_eq!(own.len(), 1);
}
#[tokio::test]
async fn list_with_session_filter() {
    let (_tmp, mem) = temp_sqlite();
    // Two session-a entries (different categories), one session-b, one unscoped.
    for (key, content, category, session) in [
        ("k1", "a1", MemoryCategory::Core, Some("sess-a")),
        ("k2", "a2", MemoryCategory::Conversation, Some("sess-a")),
        ("k3", "b1", MemoryCategory::Core, Some("sess-b")),
        ("k4", "none1", MemoryCategory::Core, None),
    ] {
        mem.store(key, content, category, session).await.unwrap();
    }
    // Session filter alone: both session-a entries and nothing else.
    let sess_a = mem.list(None, Some("sess-a")).await.unwrap();
    assert_eq!(sess_a.len(), 2);
    assert!(sess_a
        .iter()
        .all(|e| e.session_id.as_deref() == Some("sess-a")));
    // Session and category filters compose (AND semantics).
    let core_a = mem
        .list(Some(&MemoryCategory::Core), Some("sess-a"))
        .await
        .unwrap();
    assert_eq!(core_a.len(), 1);
    assert_eq!(core_a[0].key, "k1");
}
#[tokio::test]
async fn schema_migration_idempotent_on_reopen() {
    let tmp = TempDir::new().unwrap();
    // First open creates the schema (running the session_id migration)
    // and persists one session-scoped entry.
    {
        let mem = SqliteMemory::new(tmp.path()).unwrap();
        mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
            .await
            .unwrap();
    }
    // Re-opening the same database re-runs the migration; it must be a
    // no-op the second time and leave the stored row fully readable.
    let reopened = SqliteMemory::new(tmp.path()).unwrap();
    let hits = reopened
        .recall("reopen", 10, Some("sess-x"))
        .await
        .unwrap();
    assert_eq!(hits.len(), 1);
    assert_eq!(hits[0].key, "k1");
    assert_eq!(hits[0].session_id.as_deref(), Some("sess-x"));
}
}

View file

@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
/// Backend name
fn name(&self) -> &str;
/// Store a memory entry
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
-> anyhow::Result<()>;
/// Store a memory entry, optionally scoped to a session
async fn store(
&self,
key: &str,
content: &str,
category: MemoryCategory,
session_id: Option<&str>,
) -> anyhow::Result<()>;
/// Recall memories matching a query (keyword search)
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
/// Recall memories matching a query (keyword search), optionally scoped to a session
async fn recall(
&self,
query: &str,
limit: usize,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>>;
/// Get a specific memory by key
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
/// List all memory keys, optionally filtered by category
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
/// List all memory keys, optionally filtered by category and/or session
async fn list(
&self,
category: Option<&MemoryCategory>,
session_id: Option<&str>,
) -> anyhow::Result<Vec<MemoryEntry>>;
/// Remove a memory by key
async fn forget(&self, key: &str) -> anyhow::Result<bool>;

View file

@ -95,7 +95,9 @@ async fn migrate_openclaw_memory(
stats.renamed_conflicts += 1;
}
memory.store(&key, &entry.content, entry.category).await?;
memory
.store(&key, &entry.content, entry.category, None)
.await?;
stats.imported += 1;
}
@ -488,7 +490,7 @@ mod tests {
// Existing target memory
let target_mem = SqliteMemory::new(target.path()).unwrap();
target_mem
.store("k", "new value", MemoryCategory::Core)
.store("k", "new value", MemoryCategory::Core, None)
.await
.unwrap();
@ -510,7 +512,7 @@ mod tests {
.await
.unwrap();
let all = target_mem.list(None).await.unwrap();
let all = target_mem.list(None, None).await.unwrap();
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
assert!(all
.iter()

View file

@ -48,9 +48,10 @@ impl Observer for LogObserver {
ObserverEvent::AgentEnd {
duration,
tokens_used,
cost_usd,
} => {
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
info!(duration_ms = ms, tokens = ?tokens_used, "agent.end");
info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end");
}
ObserverEvent::ToolCallStart { tool } => {
info!(tool = %tool, "tool.start");
@ -133,10 +134,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(500),
tokens_used: Some(100),
cost_usd: Some(0.0015),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
cost_usd: None,
});
obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(),

View file

@ -48,10 +48,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(100),
tokens_used: Some(42),
cost_usd: Some(0.001),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
cost_usd: None,
});
obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(),

View file

@ -227,6 +227,7 @@ impl Observer for OtelObserver {
ObserverEvent::AgentEnd {
duration,
tokens_used,
cost_usd,
} => {
let secs = duration.as_secs_f64();
let start_time = SystemTime::now()
@ -243,6 +244,9 @@ impl Observer for OtelObserver {
if let Some(t) = tokens_used {
span.set_attribute(KeyValue::new("tokens_used", *t as i64));
}
if let Some(c) = cost_usd {
span.set_attribute(KeyValue::new("cost_usd", *c));
}
span.end();
self.agent_duration.record(secs, &[]);
@ -394,10 +398,12 @@ mod tests {
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::from_millis(500),
tokens_used: Some(100),
cost_usd: Some(0.0015),
});
obs.record_event(&ObserverEvent::AgentEnd {
duration: Duration::ZERO,
tokens_used: None,
cost_usd: None,
});
obs.record_event(&ObserverEvent::ToolCallStart {
tool: "shell".into(),

View file

@ -27,6 +27,7 @@ pub enum ObserverEvent {
AgentEnd {
duration: Duration,
tokens_used: Option<u64>,
cost_usd: Option<f64>,
},
/// A tool call is about to be executed.
ToolCallStart {

View file

@ -106,6 +106,7 @@ pub fn run_wizard() -> Result<Config> {
} else {
Some(api_key)
},
api_url: None,
default_provider: Some(provider),
default_model: Some(model),
default_temperature: 0.7,
@ -284,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
#[allow(clippy::too_many_lines)]
pub fn run_quick_setup(
api_key: Option<&str>,
credential_override: Option<&str>,
provider: Option<&str>,
memory_backend: Option<&str>,
) -> Result<Config> {
@ -318,7 +319,8 @@ pub fn run_quick_setup(
let config = Config {
workspace_dir: workspace_dir.clone(),
config_path: config_path.clone(),
api_key: api_key.map(String::from),
api_key: credential_override.map(String::from),
api_url: None,
default_provider: Some(provider_name.clone()),
default_model: Some(model.clone()),
default_temperature: 0.7,
@ -377,7 +379,7 @@ pub fn run_quick_setup(
println!(
" {} API Key: {}",
style("").green().bold(),
if api_key.is_some() {
if credential_override.is_some() {
style("set").green()
} else {
style("not set (use --api-key or edit config.toml)").yellow()
@ -426,7 +428,7 @@ pub fn run_quick_setup(
);
println!();
println!(" {}", style("Next steps:").white().bold());
if api_key.is_none() {
if credential_override.is_none() {
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
@ -2269,14 +2271,11 @@ fn setup_memory() -> Result<MemoryConfig> {
let backend = backend_key_from_choice(choice);
let profile = memory_backend_profile(backend);
let auto_save = if !profile.auto_save_default {
false
} else {
Confirm::new()
let auto_save = profile.auto_save_default
&& Confirm::new()
.with_prompt(" Auto-save conversations to memory?")
.default(true)
.interact()?
};
.interact()?;
println!(
" {} Memory: {} (auto-save: {})",
@ -2587,6 +2586,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
guild_id: if guild.is_empty() { None } else { Some(guild) },
allowed_users,
listen_to_bots: false,
mention_only: false,
});
}
2 => {
@ -2799,22 +2799,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
.header("Authorization", format!("Bearer {access_token_clone}"))
.send()?;
let ok = resp.status().is_success();
let data: serde_json::Value = resp.json().unwrap_or_default();
let user_id = data
.get("user_id")
.and_then(serde_json::Value::as_str)
.unwrap_or("unknown")
.to_string();
Ok::<_, reqwest::Error>((ok, user_id))
Ok::<_, reqwest::Error>(ok)
})
.join();
match thread_result {
Ok(Ok((true, user_id))) => {
println!(
"\r {} Connected as {user_id} ",
Ok(Ok(true)) => println!(
"\r {} Connection verified ",
style("").green().bold()
);
}
),
_ => {
println!(
"\r {} Connection failed — check homeserver URL and token",
@ -3779,15 +3771,7 @@ fn print_summary(config: &Config) {
);
// Secrets
println!(
" {} Secrets: {}",
style("🔒").cyan(),
if config.secrets.encrypt {
style("encrypted").green().to_string()
} else {
style("plaintext").yellow().to_string()
}
);
println!(" {} Secrets: configured", style("🔒").cyan());
// Gateway
println!(

View file

@ -38,6 +38,10 @@ pub fn ensure_arduino_cli() -> Result<()> {
anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/");
}
println!("arduino-cli installed.");
if !arduino_cli_available() {
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
}
return Ok(());
}
#[cfg(target_os = "linux")]
@ -54,11 +58,6 @@ pub fn ensure_arduino_cli() -> Result<()> {
println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/");
anyhow::bail!("arduino-cli not installed.");
}
if !arduino_cli_available() {
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
}
Ok(())
}
/// Ensure arduino:avr core is installed.

View file

@ -112,6 +112,7 @@ pub struct SerialPeripheral {
impl SerialPeripheral {
/// Create and connect to a serial peripheral.
#[allow(clippy::unused_async)]
pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> {
let path = config
.path

View file

@ -106,17 +106,17 @@ struct NativeContentIn {
}
impl AnthropicProvider {
pub fn new(api_key: Option<&str>) -> Self {
Self::with_base_url(api_key, None)
pub fn new(credential: Option<&str>) -> Self {
Self::with_base_url(credential, None)
}
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
let base_url = base_url
.map(|u| u.trim_end_matches('/'))
.unwrap_or("https://api.anthropic.com")
.to_string();
Self {
credential: api_key
credential: credential
.map(str::trim)
.filter(|k| !k.is_empty())
.map(ToString::to_string),
@ -410,9 +410,9 @@ mod tests {
#[test]
fn creates_with_key() {
let p = AnthropicProvider::new(Some("sk-ant-test123"));
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
assert_eq!(p.base_url, "https://api.anthropic.com");
}
@ -431,17 +431,19 @@ mod tests {
#[test]
fn creates_with_whitespace_key() {
let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
assert!(p.credential.is_some());
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
}
#[test]
fn creates_with_custom_base_url() {
let p =
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
let p = AnthropicProvider::with_base_url(
Some("anthropic-credential"),
Some("https://api.example.com"),
);
assert_eq!(p.base_url, "https://api.example.com");
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
}
#[test]

View file

@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
pub struct OpenAiCompatibleProvider {
pub(crate) name: String,
pub(crate) base_url: String,
pub(crate) api_key: Option<String>,
pub(crate) credential: Option<String>,
pub(crate) auth_header: AuthStyle,
/// When false, do not fall back to /v1/responses on chat completions 404.
/// GLM/Zhipu does not support the responses API.
@ -37,11 +37,16 @@ pub enum AuthStyle {
}
impl OpenAiCompatibleProvider {
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
pub fn new(
name: &str,
base_url: &str,
credential: Option<&str>,
auth_style: AuthStyle,
) -> Self {
Self {
name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
auth_header: auth_style,
supports_responses_fallback: true,
client: Client::builder()
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
pub fn new_no_responses_fallback(
name: &str,
base_url: &str,
api_key: Option<&str>,
credential: Option<&str>,
auth_style: AuthStyle,
) -> Self {
Self {
name: name.to_string(),
base_url: base_url.trim_end_matches('/').to_string(),
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
auth_header: auth_style,
supports_responses_fallback: false,
client: Client::builder()
@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
fn apply_auth_header(
&self,
req: reqwest::RequestBuilder,
api_key: &str,
credential: &str,
) -> reqwest::RequestBuilder {
match &self.auth_header {
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
AuthStyle::XApiKey => req.header("x-api-key", api_key),
AuthStyle::Custom(header) => req.header(header, api_key),
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
AuthStyle::XApiKey => req.header("x-api-key", credential),
AuthStyle::Custom(header) => req.header(header, credential),
}
}
async fn chat_via_responses(
&self,
api_key: &str,
credential: &str,
system_prompt: Option<&str>,
message: &str,
model: &str,
@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
let url = self.responses_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key)
.apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key)
.apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
return self
.chat_via_responses(api_key, system_prompt, message, model)
.chat_via_responses(credential, system_prompt, message, model)
.await
.map_err(|responses_err| {
anyhow::anyhow!(
@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
self.name
@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
let url = self.chat_completions_url();
let response = self
.apply_auth_header(self.client.post(&url).json(&request), api_key)
.apply_auth_header(self.client.post(&url).json(&request), credential)
.send()
.await?;
@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
if let Some(user_msg) = last_user {
return self
.chat_via_responses(
api_key,
credential,
system.map(|m| m.content.as_str()),
&user_msg.content,
model,
@ -791,16 +796,20 @@ mod tests {
#[test]
fn creates_with_key() {
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
let p = make_provider(
"venice",
"https://api.venice.ai",
Some("venice-test-credential"),
);
assert_eq!(p.name, "venice");
assert_eq!(p.base_url, "https://api.venice.ai");
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
}
#[test]
fn creates_without_key() {
let p = make_provider("test", "https://example.com", None);
assert!(p.api_key.is_none());
assert!(p.credential.is_none());
}
#[test]
@ -894,6 +903,7 @@ mod tests {
make_provider("Groq", "https://api.groq.com/openai", None),
make_provider("Mistral", "https://api.mistral.ai", None),
make_provider("xAI", "https://api.x.ai", None),
make_provider("Astrai", "https://as-trai.com/v1", None),
];
for p in providers {

705
src/providers/copilot.rs Normal file
View file

@ -0,0 +1,705 @@
//! GitHub Copilot provider with OAuth device-flow authentication.
//!
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
//! then exchanges the OAuth token for short-lived Copilot API keys.
//! Tokens are cached to disk and auto-refreshed.
//!
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
//! and other third-party Copilot integrations. The Copilot token endpoint is
//! private; there is no public OAuth scope or app registration for it.
//! GitHub could change or revoke this at any time, which would break all
//! third-party integrations simultaneously.
use crate::providers::traits::{
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
Provider, ToolCall as ProviderToolCall,
};
use crate::tools::ToolSpec;
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use tracing::warn;
/// GitHub OAuth client ID for Copilot (VS Code extension).
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
// OAuth device-flow endpoints (RFC 8628 style).
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
// Exchanges a GitHub OAuth token for a short-lived Copilot API key.
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
// Fallback chat endpoint used when the token response carries no `endpoints.api`.
const DEFAULT_API: &str = "https://api.githubcopilot.com";
// ── Token types ──────────────────────────────────────────────────
/// Response from GitHub's device-code endpoint (`POST /login/device/code`).
#[derive(Debug, Deserialize)]
struct DeviceCodeResponse {
    device_code: String,
    // Short code the user types at `verification_uri`.
    user_code: String,
    verification_uri: String,
    // Seconds to wait between polls; defaults to 5 when omitted.
    #[serde(default = "default_interval")]
    interval: u64,
    // Lifetime of the device code in seconds; defaults to 900 when omitted.
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}
/// Serde fallback for `DeviceCodeResponse::interval` when GitHub omits it.
fn default_interval() -> u64 {
    // Seconds to wait between polls of the access-token endpoint.
    const FALLBACK_POLL_SECS: u64 = 5;
    FALLBACK_POLL_SECS
}
/// Serde fallback for `DeviceCodeResponse::expires_in` when GitHub omits it.
fn default_expires_in() -> u64 {
    // Device codes are treated as valid for 15 minutes by default.
    const FALLBACK_EXPIRY_SECS: u64 = 900;
    FALLBACK_EXPIRY_SECS
}
/// Response from polling the OAuth access-token endpoint; carries either a
/// token or an error string (e.g. "authorization_pending", "slow_down").
#[derive(Debug, Deserialize)]
struct AccessTokenResponse {
    access_token: Option<String>,
    error: Option<String>,
}
/// Copilot API key payload from the token-exchange endpoint; also persisted
/// verbatim to `api-key.json` for reuse across processes.
#[derive(Debug, Serialize, Deserialize)]
struct ApiKeyInfo {
    token: String,
    // Unix timestamp (seconds) after which `token` is no longer valid.
    expires_at: i64,
    #[serde(default)]
    endpoints: Option<ApiEndpoints>,
}
/// Optional endpoint overrides; `api` replaces `DEFAULT_API` when present.
#[derive(Debug, Serialize, Deserialize)]
struct ApiEndpoints {
    api: Option<String>,
}
/// In-memory copy of the current Copilot API key, held inside `refresh_lock`.
struct CachedApiKey {
    token: String,
    api_endpoint: String,
    // Unix timestamp (seconds); checked with a 120s safety margin.
    expires_at: i64,
}
// ── Chat completions types ───────────────────────────────────────
/// Outgoing `/chat/completions` request body (OpenAI-compatible shape).
#[derive(Debug, Serialize)]
struct ApiChatRequest {
    model: String,
    messages: Vec<ApiMessage>,
    temperature: f64,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<NativeToolSpec>>,
    // Set to "auto" when tools are supplied, omitted otherwise.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_choice: Option<String>,
}
/// One wire-format chat message. Which optional fields are populated depends
/// on the role: `tool_call_id` for tool results, `tool_calls` for assistant
/// tool invocations; plain text messages only set `content`.
#[derive(Debug, Serialize)]
struct ApiMessage {
    role: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<NativeToolCall>>,
}
/// OpenAI-style tool declaration: `{"type": "function", "function": {...}}`.
#[derive(Debug, Serialize)]
struct NativeToolSpec {
    #[serde(rename = "type")]
    kind: String,
    function: NativeToolFunctionSpec,
}
/// Function schema advertised to the model; `parameters` is copied raw from
/// the internal `ToolSpec.parameters` JSON.
#[derive(Debug, Serialize)]
struct NativeToolFunctionSpec {
    name: String,
    description: String,
    parameters: serde_json::Value,
}
/// A tool invocation as sent/received on the wire. `id` and `type` are
/// optional because some responses omit them; a UUID is generated downstream
/// when `id` is missing.
#[derive(Debug, Serialize, Deserialize)]
struct NativeToolCall {
    #[serde(skip_serializing_if = "Option::is_none")]
    id: Option<String>,
    #[serde(rename = "type", skip_serializing_if = "Option::is_none")]
    kind: Option<String>,
    function: NativeFunctionCall,
}
/// Function name plus its arguments as a JSON-encoded string.
#[derive(Debug, Serialize, Deserialize)]
struct NativeFunctionCall {
    name: String,
    arguments: String,
}
/// Top-level `/chat/completions` response; only `choices` is consumed.
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
    choices: Vec<Choice>,
}
/// One completion candidate; only the first choice is used.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ResponseMessage,
}
/// Assistant message inside a choice; text and tool calls may each be absent.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    #[serde(default)]
    content: Option<String>,
    #[serde(default)]
    tool_calls: Option<Vec<NativeToolCall>>,
}
// ── Provider ─────────────────────────────────────────────────────
/// GitHub Copilot provider with automatic OAuth and token refresh.
///
/// On first use, prompts the user to visit github.com/login/device.
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
/// automatically.
pub struct CopilotProvider {
    // OAuth token supplied at construction; when present, the device flow
    // is skipped entirely.
    github_token: Option<String>,
    /// Mutex ensures only one caller refreshes tokens at a time,
    /// preventing duplicate device flow prompts or redundant API calls.
    refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
    // Shared HTTP client (120s request / 10s connect timeouts).
    http: Client,
    // Directory holding the `access-token` and `api-key.json` caches.
    token_dir: PathBuf,
}
impl CopilotProvider {
    /// Build a provider and prepare the on-disk token cache directory.
    ///
    /// `github_token` optionally seeds the OAuth token (skipping the device
    /// flow); empty strings are treated as "no token". Directory-creation or
    /// permission failures are logged, not fatal — caching degrades gracefully.
    pub fn new(github_token: Option<&str>) -> Self {
        let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
            .map(|dir| dir.config_dir().join("copilot"))
            .unwrap_or_else(|| {
                // Fall back to a user-specific temp directory to avoid
                // shared-directory symlink attacks.
                let user = std::env::var("USER")
                    .or_else(|_| std::env::var("USERNAME"))
                    .unwrap_or_else(|_| "unknown".to_string());
                std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
            });
        if let Err(err) = std::fs::create_dir_all(&token_dir) {
            warn!(
                "Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
                token_dir
            );
        } else {
            // Restrict the cache directory to the owner: it holds tokens.
            #[cfg(unix)]
            {
                use std::os::unix::fs::PermissionsExt;
                if let Err(err) =
                    std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
                {
                    warn!(
                        "Failed to set Copilot token directory permissions on {:?}: {err}",
                        token_dir
                    );
                }
            }
        }
        Self {
            github_token: github_token
                .filter(|token| !token.is_empty())
                .map(String::from),
            refresh_lock: Arc::new(Mutex::new(None)),
            http: Client::builder()
                .timeout(Duration::from_secs(120))
                .connect_timeout(Duration::from_secs(10))
                .build()
                .unwrap_or_else(|_| Client::new()),
            token_dir,
        }
    }

    /// Required headers for Copilot API requests (editor identification).
    const COPILOT_HEADERS: [(&str, &str); 4] = [
        ("Editor-Version", "vscode/1.85.1"),
        ("Editor-Plugin-Version", "copilot/1.155.0"),
        ("User-Agent", "GithubCopilot/1.155.0"),
        ("Accept", "application/json"),
    ];

    /// Map internal `ToolSpec`s to the OpenAI-style `function` tool schema.
    /// Returns `None` when no tools were supplied.
    fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
        tools.map(|items| {
            items
                .iter()
                .map(|tool| NativeToolSpec {
                    kind: "function".to_string(),
                    function: NativeToolFunctionSpec {
                        name: tool.name.clone(),
                        description: tool.description.clone(),
                        parameters: tool.parameters.clone(),
                    },
                })
                .collect()
        })
    }

    /// Convert transcript `ChatMessage`s into wire-format `ApiMessage`s.
    ///
    /// Assistant and tool messages may carry JSON-encoded payloads in
    /// `content` (tool calls / tool results); those are unpacked into the
    /// native fields. Anything that does not parse as the expected JSON
    /// shape falls through and is sent as plain text.
    fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
        messages
            .iter()
            .map(|message| {
                // Assistant message whose content encodes native tool calls.
                if message.role == "assistant" {
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                        if let Some(tool_calls_value) = value.get("tool_calls") {
                            if let Ok(parsed_calls) =
                                serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
                            {
                                let tool_calls = parsed_calls
                                    .into_iter()
                                    .map(|tool_call| NativeToolCall {
                                        id: Some(tool_call.id),
                                        kind: Some("function".to_string()),
                                        function: NativeFunctionCall {
                                            name: tool_call.name,
                                            arguments: tool_call.arguments,
                                        },
                                    })
                                    .collect::<Vec<_>>();
                                let content = value
                                    .get("content")
                                    .and_then(serde_json::Value::as_str)
                                    .map(ToString::to_string);
                                return ApiMessage {
                                    role: "assistant".to_string(),
                                    content,
                                    tool_call_id: None,
                                    tool_calls: Some(tool_calls),
                                };
                            }
                        }
                    }
                }
                // Tool result message: unpack `tool_call_id` + `content`.
                if message.role == "tool" {
                    if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
                        let tool_call_id = value
                            .get("tool_call_id")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        let content = value
                            .get("content")
                            .and_then(serde_json::Value::as_str)
                            .map(ToString::to_string);
                        return ApiMessage {
                            role: "tool".to_string(),
                            content,
                            tool_call_id,
                            tool_calls: None,
                        };
                    }
                }
                // Default: pass the message through as plain text.
                ApiMessage {
                    role: message.role.clone(),
                    content: Some(message.content.clone()),
                    tool_call_id: None,
                    tool_calls: None,
                }
            })
            .collect()
    }

    /// Send a chat completions request with required Copilot headers.
    ///
    /// Fetches (or refreshes) the API key first, then POSTs to
    /// `<endpoint>/chat/completions`. Returns the first choice's text and
    /// any native tool calls; missing call ids get a fresh UUID.
    async fn send_chat_request(
        &self,
        messages: Vec<ApiMessage>,
        tools: Option<&[ToolSpec]>,
        model: &str,
        temperature: f64,
    ) -> anyhow::Result<ProviderChatResponse> {
        let (token, endpoint) = self.get_api_key().await?;
        let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
        let native_tools = Self::convert_tools(tools);
        let request = ApiChatRequest {
            model: model.to_string(),
            messages,
            temperature,
            // Only send tool_choice when tools are present.
            tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
            tools: native_tools,
        };
        let mut req = self
            .http
            .post(&url)
            .header("Authorization", format!("Bearer {token}"))
            .json(&request);
        for (header, value) in &Self::COPILOT_HEADERS {
            req = req.header(*header, *value);
        }
        let response = req.send().await?;
        if !response.status().is_success() {
            // api_error sanitizes the body so secrets never reach logs.
            return Err(super::api_error("GitHub Copilot", response).await);
        }
        let api_response: ApiChatResponse = response.json().await?;
        let choice = api_response
            .choices
            .into_iter()
            .next()
            .ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
        let tool_calls = choice
            .message
            .tool_calls
            .unwrap_or_default()
            .into_iter()
            .map(|tool_call| ProviderToolCall {
                // Synthesize an id when the API omits one so downstream
                // tool-result correlation still works.
                id: tool_call
                    .id
                    .unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                name: tool_call.function.name,
                arguments: tool_call.function.arguments,
            })
            .collect();
        Ok(ProviderChatResponse {
            text: choice.message.content,
            tool_calls,
        })
    }

    /// Get a valid Copilot API key, refreshing or re-authenticating as needed.
    /// Uses a Mutex to ensure only one caller refreshes at a time.
    ///
    /// Lookup order: in-memory cache → `api-key.json` on disk → full refresh
    /// (OAuth token, then token exchange). A 120-second margin is applied to
    /// expiry checks so a key never expires mid-request.
    async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
        let mut cached = self.refresh_lock.lock().await;
        if let Some(cached_key) = cached.as_ref() {
            if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
                return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
            }
        }
        // A previous process may have left a still-valid key on disk.
        if let Some(info) = self.load_api_key_from_disk().await {
            if chrono::Utc::now().timestamp() + 120 < info.expires_at {
                let endpoint = info
                    .endpoints
                    .as_ref()
                    .and_then(|e| e.api.clone())
                    .unwrap_or_else(|| DEFAULT_API.to_string());
                let token = info.token;
                *cached = Some(CachedApiKey {
                    token: token.clone(),
                    api_endpoint: endpoint.clone(),
                    expires_at: info.expires_at,
                });
                return Ok((token, endpoint));
            }
        }
        // Full refresh: OAuth access token (cached or device flow),
        // then exchange it for a short-lived Copilot API key.
        let access_token = self.get_github_access_token().await?;
        let api_key_info = self.exchange_for_api_key(&access_token).await?;
        self.save_api_key_to_disk(&api_key_info).await;
        let endpoint = api_key_info
            .endpoints
            .as_ref()
            .and_then(|e| e.api.clone())
            .unwrap_or_else(|| DEFAULT_API.to_string());
        *cached = Some(CachedApiKey {
            token: api_key_info.token.clone(),
            api_endpoint: endpoint.clone(),
            expires_at: api_key_info.expires_at,
        });
        Ok((api_key_info.token, endpoint))
    }

    /// Get a GitHub access token from config, cache, or device flow.
    ///
    /// Order: explicit `github_token` → cached `access-token` file →
    /// interactive device-code login (result is persisted with 0600 perms).
    async fn get_github_access_token(&self) -> anyhow::Result<String> {
        if let Some(token) = &self.github_token {
            return Ok(token.clone());
        }
        let access_token_path = self.token_dir.join("access-token");
        if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
            let token = cached.trim();
            if !token.is_empty() {
                return Ok(token.to_string());
            }
        }
        let token = self.device_code_login().await?;
        write_file_secure(&access_token_path, &token).await;
        Ok(token)
    }

    /// Run GitHub OAuth device code flow.
    ///
    /// Prints the verification URL/code to stderr, then polls the access-token
    /// endpoint until success, error, or expiry of the device code.
    async fn device_code_login(&self) -> anyhow::Result<String> {
        let response: DeviceCodeResponse = self
            .http
            .post(GITHUB_DEVICE_CODE_URL)
            .header("Accept", "application/json")
            .json(&serde_json::json!({
                "client_id": GITHUB_CLIENT_ID,
                "scope": "read:user"
            }))
            .send()
            .await?
            .error_for_status()?
            .json()
            .await?;
        // Never poll faster than every 5s, and honor a positive expiry.
        let mut poll_interval = Duration::from_secs(response.interval.max(5));
        let expires_in = response.expires_in.max(1);
        let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
        eprintln!(
            "\nGitHub Copilot authentication is required.\n\
             Visit: {}\n\
             Code: {}\n\
             Waiting for authorization...\n",
            response.verification_uri, response.user_code
        );
        while tokio::time::Instant::now() < expires_at {
            tokio::time::sleep(poll_interval).await;
            let token_response: AccessTokenResponse = self
                .http
                .post(GITHUB_ACCESS_TOKEN_URL)
                .header("Accept", "application/json")
                .json(&serde_json::json!({
                    "client_id": GITHUB_CLIENT_ID,
                    "device_code": response.device_code,
                    "grant_type": "urn:ietf:params:oauth:grant-type:device_code"
                }))
                .send()
                .await?
                .json()
                .await?;
            if let Some(token) = token_response.access_token {
                eprintln!("Authentication succeeded.\n");
                return Ok(token);
            }
            match token_response.error.as_deref() {
                Some("slow_down") => {
                    // Back off by 5 seconds when GitHub asks us to slow down.
                    poll_interval += Duration::from_secs(5);
                }
                // Pending (or missing error field): keep polling.
                Some("authorization_pending") | None => {}
                Some("expired_token") => {
                    anyhow::bail!("GitHub device authorization expired")
                }
                Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
            }
        }
        anyhow::bail!("Timed out waiting for GitHub authorization")
    }

    /// Exchange a GitHub access token for a Copilot API key.
    ///
    /// On 401/403 the cached access token is deleted so the next attempt
    /// re-runs the device flow instead of retrying with a dead token.
    async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
        let mut request = self.http.get(GITHUB_API_KEY_URL);
        for (header, value) in &Self::COPILOT_HEADERS {
            request = request.header(*header, *value);
        }
        // This endpoint uses the `token` auth scheme rather than `Bearer`.
        request = request.header("Authorization", format!("token {access_token}"));
        let response = request.send().await?;
        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            let sanitized = super::sanitize_api_error(&body);
            if status.as_u16() == 401 || status.as_u16() == 403 {
                let access_token_path = self.token_dir.join("access-token");
                tokio::fs::remove_file(&access_token_path).await.ok();
            }
            anyhow::bail!(
                "Failed to get Copilot API key ({status}): {sanitized}. \
                 Ensure your GitHub account has an active Copilot subscription."
            );
        }
        let info: ApiKeyInfo = response.json().await?;
        Ok(info)
    }

    /// Read a previously cached API key from `api-key.json`; `None` on any
    /// read or parse failure.
    async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
        let path = self.token_dir.join("api-key.json");
        let data = tokio::fs::read_to_string(&path).await.ok()?;
        serde_json::from_str(&data).ok()
    }

    /// Best-effort persist of the API key to `api-key.json` (0600 perms).
    async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
        let path = self.token_dir.join("api-key.json");
        if let Ok(json) = serde_json::to_string_pretty(info) {
            write_file_secure(&path, &json).await;
        }
    }
}
/// Write a file with 0600 permissions (owner read/write only).
/// Uses `spawn_blocking` to avoid blocking the async runtime.
///
/// Failures are logged rather than returned: callers treat the token cache
/// as best-effort — losing it only forces a re-authentication later.
async fn write_file_secure(path: &Path, content: &str) {
    // Own the data so it can move into the blocking task.
    let path = path.to_path_buf();
    let content = content.to_string();
    let result = tokio::task::spawn_blocking(move || {
        #[cfg(unix)]
        {
            use std::io::Write;
            use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
            // `.mode(0o600)` only applies when the file is created...
            let mut file = std::fs::OpenOptions::new()
                .write(true)
                .create(true)
                .truncate(true)
                .mode(0o600)
                .open(&path)?;
            file.write_all(content.as_bytes())?;
            // ...so also tighten permissions explicitly in case the file
            // already existed with a looser mode.
            std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
            Ok::<(), std::io::Error>(())
        }
        #[cfg(not(unix))]
        {
            // No POSIX file modes here (e.g. Windows): plain write.
            std::fs::write(&path, &content)?;
            Ok::<(), std::io::Error>(())
        }
    })
    .await;
    match result {
        Ok(Ok(())) => {}
        Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
        Err(err) => warn!("Failed to spawn blocking write: {err}"),
    }
}
#[async_trait]
impl Provider for CopilotProvider {
async fn chat_with_system(
&self,
system_prompt: Option<&str>,
message: &str,
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let mut messages = Vec::new();
if let Some(system) = system_prompt {
messages.push(ApiMessage {
role: "system".to_string(),
content: Some(system.to_string()),
tool_call_id: None,
tool_calls: None,
});
}
messages.push(ApiMessage {
role: "user".to_string(),
content: Some(message.to_string()),
tool_call_id: None,
tool_calls: None,
});
let response = self
.send_chat_request(messages, None, model, temperature)
.await?;
Ok(response.text.unwrap_or_default())
}
async fn chat_with_history(
&self,
messages: &[ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let response = self
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
.await?;
Ok(response.text.unwrap_or_default())
}
async fn chat(
&self,
request: ProviderChatRequest<'_>,
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
self.send_chat_request(
Self::convert_messages(request.messages),
request.tools,
model,
temperature,
)
.await
}
fn supports_native_tools(&self) -> bool {
true
}
async fn warmup(&self) -> anyhow::Result<()> {
let _ = self.get_api_key().await?;
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Constructor behavior around the optional GitHub token.
    #[test]
    fn new_without_token() {
        let provider = CopilotProvider::new(None);
        assert!(provider.github_token.is_none());
    }

    #[test]
    fn new_with_token() {
        let provider = CopilotProvider::new(Some("ghp_test"));
        assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
    }

    // Empty strings must not be mistaken for a real token.
    #[test]
    fn empty_token_treated_as_none() {
        let provider = CopilotProvider::new(Some(""));
        assert!(provider.github_token.is_none());
    }

    // The in-memory API-key cache starts unpopulated.
    #[tokio::test]
    async fn cache_starts_empty() {
        let provider = CopilotProvider::new(None);
        let cached = provider.refresh_lock.lock().await;
        assert!(cached.is_none());
    }

    // Sanity-check the static editor-identification header table.
    #[test]
    fn copilot_headers_include_required_fields() {
        let headers = CopilotProvider::COPILOT_HEADERS;
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Version"));
        assert!(headers
            .iter()
            .any(|(header, _)| *header == "Editor-Plugin-Version"));
        assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
    }

    // Serde fallbacks for omitted device-flow fields.
    #[test]
    fn default_interval_and_expiry() {
        assert_eq!(default_interval(), 5);
        assert_eq!(default_expires_in(), 900);
    }

    #[test]
    fn supports_native_tools() {
        let provider = CopilotProvider::new(None);
        assert!(provider.supports_native_tools());
    }
}

View file

@ -1,5 +1,6 @@
pub mod anthropic;
pub mod compatible;
pub mod copilot;
pub mod gemini;
pub mod ollama;
pub mod openai;
@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
/// Scrub known secret-like token prefixes from provider error strings.
///
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`.
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
/// `ghu_`, and `github_pat_`.
pub fn scrub_secret_patterns(input: &str) -> String {
const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
const PREFIXES: [&str; 7] = [
"sk-",
"xoxb-",
"xoxp-",
"ghp_",
"gho_",
"ghu_",
"github_pat_",
];
let mut scrubbed = input.to_string();
@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
///
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
return Some(key.to_string());
fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
if let Some(raw_override) = credential_override {
let trimmed_override = raw_override.trim();
if !trimmed_override.is_empty() {
return Some(trimmed_override.to_owned());
}
}
let provider_env_candidates: Vec<&str> = match name {
@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
"astrai" => vec!["ASTRAI_API_KEY"],
_ => vec![],
};
@ -182,19 +196,28 @@ fn parse_custom_provider_url(
}
}
/// Factory: create the right provider from config
#[allow(clippy::too_many_lines)]
/// Factory: create the right provider from config (without custom URL)
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
let resolved_key = resolve_api_key(name, api_key);
let key = resolved_key.as_deref();
create_provider_with_url(name, api_key, None)
}
/// Factory: create the right provider from config with optional custom base URL
#[allow(clippy::too_many_lines)]
pub fn create_provider_with_url(
name: &str,
api_key: Option<&str>,
api_url: Option<&str>,
) -> anyhow::Result<Box<dyn Provider>> {
let resolved_credential = resolve_provider_credential(name, api_key);
#[allow(clippy::option_as_ref_deref)]
let key = resolved_credential.as_ref().map(String::as_str);
match name {
// ── Primary providers (custom implementations) ───────
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
// Ollama is a local service that doesn't use API keys.
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
// Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
"gemini" | "google" | "google-gemini" => {
Ok(Box::new(gemini::GeminiProvider::new(key)))
}
@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
))),
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
"Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
))),
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
))),
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
"GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
))),
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new(
"NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer,
"copilot" | "github-copilot" => {
Ok(Box::new(copilot::CopilotProvider::new(api_key)))
},
"lmstudio" | "lm-studio" => {
let lm_studio_key = api_key
.map(str::trim)
.filter(|value| !value.is_empty())
.unwrap_or("lm-studio");
Ok(Box::new(OpenAiCompatibleProvider::new(
"LM Studio",
"http://localhost:1234/v1",
Some(lm_studio_key),
AuthStyle::Bearer,
)))
}
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
OpenAiCompatibleProvider::new(
"NVIDIA NIM",
"https://integrate.api.nvidia.com/v1",
key,
AuthStyle::Bearer,
),
)),
// ── AI inference routers ─────────────────────────────
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
))),
// ── Bring Your Own Provider (custom URL) ───────────
@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
pub fn create_resilient_provider(
primary_name: &str,
api_key: Option<&str>,
api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
) -> anyhow::Result<Box<dyn Provider>> {
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
providers.push((
primary_name.to_string(),
create_provider(primary_name, api_key)?,
create_provider_with_url(primary_name, api_key, api_url)?,
));
for fallback in &reliability.fallback_providers {
@ -340,21 +386,13 @@ pub fn create_resilient_provider(
continue;
}
if api_key.is_some() && fallback != "ollama" {
tracing::warn!(
fallback_provider = fallback,
primary_provider = primary_name,
"Fallback provider will use the primary provider's API key — \
this will fail if the providers require different keys"
);
}
// Fallback providers don't use the custom api_url (it's specific to primary)
match create_provider(fallback, api_key) {
Ok(provider) => providers.push((fallback.clone(), provider)),
Err(e) => {
Err(_error) => {
tracing::warn!(
fallback_provider = fallback,
"Ignoring invalid fallback provider: {e}"
"Ignoring invalid fallback provider during initialization"
);
}
}
@ -377,12 +415,13 @@ pub fn create_resilient_provider(
pub fn create_routed_provider(
primary_name: &str,
api_key: Option<&str>,
api_url: Option<&str>,
reliability: &crate::config::ReliabilityConfig,
model_routes: &[crate::config::ModelRouteConfig],
default_model: &str,
) -> anyhow::Result<Box<dyn Provider>> {
if model_routes.is_empty() {
return create_resilient_provider(primary_name, api_key, reliability);
return create_resilient_provider(primary_name, api_key, api_url, reliability);
}
// Collect unique provider names needed
@ -396,12 +435,19 @@ pub fn create_routed_provider(
// Create each provider (with its own resilience wrapper)
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
for name in &needed {
let key = model_routes
let routed_credential = model_routes
.iter()
.find(|r| &r.provider == name)
.and_then(|r| r.api_key.as_deref())
.or(api_key);
match create_resilient_provider(name, key, reliability) {
.and_then(|r| {
r.api_key.as_ref().and_then(|raw_key| {
let trimmed_key = raw_key.trim();
(!trimmed_key.is_empty()).then_some(trimmed_key)
})
});
let key = routed_credential.or(api_key);
// Only use api_url for the primary provider
let url = if name == primary_name { api_url } else { None };
match create_resilient_provider(name, key, url, reliability) {
Ok(provider) => providers.push((name.clone(), provider)),
Err(e) => {
if name == primary_name {
@ -409,7 +455,7 @@ pub fn create_routed_provider(
}
tracing::warn!(
provider = name.as_str(),
"Ignoring routed provider that failed to create: {e}"
"Ignoring routed provider that failed to initialize"
);
}
}
@ -441,27 +487,27 @@ mod tests {
use super::*;
#[test]
fn resolve_api_key_prefers_explicit_argument() {
let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
assert_eq!(resolved.as_deref(), Some("explicit-key"));
fn resolve_provider_credential_prefers_explicit_argument() {
let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
assert_eq!(resolved, Some("explicit-key".to_string()));
}
// ── Primary providers ────────────────────────────────────
#[test]
fn factory_openrouter() {
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
assert!(create_provider("openrouter", None).is_ok());
}
#[test]
fn factory_anthropic() {
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
}
#[test]
fn factory_openai() {
assert!(create_provider("openai", Some("sk-test")).is_ok());
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
}
#[test]
@ -556,6 +602,13 @@ mod tests {
assert!(create_provider("dashscope-us", Some("key")).is_ok());
}
#[test]
fn factory_lmstudio() {
assert!(create_provider("lmstudio", Some("key")).is_ok());
assert!(create_provider("lm-studio", Some("key")).is_ok());
assert!(create_provider("lmstudio", None).is_ok());
}
// ── Extended ecosystem ───────────────────────────────────
#[test]
@ -614,6 +667,13 @@ mod tests {
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
}
// ── AI inference routers ─────────────────────────────────
#[test]
fn factory_astrai() {
assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok());
}
// ── Custom / BYOP provider ─────────────────────────────
#[test]
@ -761,17 +821,33 @@ mod tests {
scheduler_retries: 2,
};
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
let provider = create_resilient_provider(
"openrouter",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_ok());
}
#[test]
fn resilient_provider_errors_for_invalid_primary() {
let reliability = crate::config::ReliabilityConfig::default();
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
let provider = create_resilient_provider(
"totally-invalid",
Some("provider-test-credential"),
None,
&reliability,
);
assert!(provider.is_err());
}
#[test]
fn ollama_with_custom_url() {
let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434"));
assert!(provider.is_ok());
}
#[test]
fn factory_all_providers_create_successfully() {
let providers = [
@ -794,6 +870,7 @@ mod tests {
"qwen",
"qwen-intl",
"qwen-us",
"lmstudio",
"groq",
"mistral",
"xai",
@ -888,7 +965,7 @@ mod tests {
#[test]
fn sanitize_preserves_unicode_boundaries() {
let input = format!("{} sk-abcdef123", "こんにちは".repeat(80));
let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
let result = sanitize_api_error(&input);
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
assert!(!result.contains("sk-abcdef123"));
@ -900,4 +977,32 @@ mod tests {
let result = sanitize_api_error(input);
assert_eq!(result, input);
}
#[test]
fn scrub_github_personal_access_token() {
let input = "auth failed with token ghp_abc123def456";
let result = scrub_secret_patterns(input);
assert_eq!(result, "auth failed with token [REDACTED]");
}
#[test]
fn scrub_github_oauth_token() {
let input = "Bearer gho_1234567890abcdef";
let result = scrub_secret_patterns(input);
assert_eq!(result, "Bearer [REDACTED]");
}
#[test]
fn scrub_github_user_token() {
let input = "token ghu_sessiontoken123";
let result = scrub_secret_patterns(input);
assert_eq!(result, "token [REDACTED]");
}
#[test]
fn scrub_github_fine_grained_pat() {
let input = "failed: github_pat_11AABBC_xyzzy789";
let result = scrub_secret_patterns(input);
assert_eq!(result, "failed: [REDACTED]");
}
}

View file

@ -8,6 +8,8 @@ pub struct OllamaProvider {
client: Client,
}
// ─── Request Structures ───────────────────────────────────────────────────────
#[derive(Debug, Serialize)]
struct ChatRequest {
model: String,
@ -27,6 +29,8 @@ struct Options {
temperature: f64,
}
// ─── Response Structures ──────────────────────────────────────────────────────
#[derive(Debug, Deserialize)]
struct ApiChatResponse {
message: ResponseMessage,
@ -34,9 +38,30 @@ struct ApiChatResponse {
/// Message portion of an Ollama chat response.
#[derive(Debug, Deserialize)]
struct ResponseMessage {
    // Assistant text; serde default gives "" when the field is absent.
    #[serde(default)]
    content: String,
    // Native tool calls; serde default gives an empty vec when absent.
    #[serde(default)]
    tool_calls: Vec<OllamaToolCall>,
    /// Some models return a "thinking" field with internal reasoning
    #[serde(default)]
    thinking: Option<String>,
}
/// One tool invocation in an Ollama chat response.
#[derive(Debug, Deserialize)]
struct OllamaToolCall {
    // Optional: the call id may be absent in Ollama responses —
    // TODO(review) confirm against the Ollama API reference.
    id: Option<String>,
    function: OllamaFunction,
}
#[derive(Debug, Deserialize)]
struct OllamaFunction {
    name: String,
    // Arguments arrive as structured JSON (not a pre-encoded string);
    // they are re-serialized to a string before reaching parse_tool_calls.
    #[serde(default)]
    arguments: serde_json::Value,
}
// ─── Implementation ───────────────────────────────────────────────────────────
impl OllamaProvider {
pub fn new(base_url: Option<&str>) -> Self {
Self {
@ -45,12 +70,145 @@ impl OllamaProvider {
.trim_end_matches('/')
.to_string(),
client: Client::builder()
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
.timeout(std::time::Duration::from_secs(300))
.connect_timeout(std::time::Duration::from_secs(10))
.build()
.unwrap_or_else(|_| Client::new()),
}
}
/// Send a request to Ollama and get the parsed response
async fn send_request(
&self,
messages: Vec<Message>,
model: &str,
temperature: f64,
) -> anyhow::Result<ApiChatResponse> {
let request = ChatRequest {
model: model.to_string(),
messages,
stream: false,
options: Options { temperature },
};
let url = format!("{}/api/chat", self.base_url);
tracing::debug!(
"Ollama request: url={} model={} message_count={} temperature={}",
url,
model,
request.messages.len(),
temperature
);
let response = self.client.post(&url).json(&request).send().await?;
let status = response.status();
tracing::debug!("Ollama response status: {}", status);
let body = response.bytes().await?;
tracing::debug!("Ollama response body length: {} bytes", body.len());
if !status.is_success() {
let raw = String::from_utf8_lossy(&body);
let sanitized = super::sanitize_api_error(&raw);
tracing::error!(
"Ollama error response: status={} body_excerpt={}",
status,
sanitized
);
anyhow::bail!(
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
status,
sanitized
);
}
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
Ok(r) => r,
Err(e) => {
let raw = String::from_utf8_lossy(&body);
let sanitized = super::sanitize_api_error(&raw);
tracing::error!(
"Ollama response deserialization failed: {e}. body_excerpt={}",
sanitized
);
anyhow::bail!("Failed to parse Ollama response: {e}");
}
};
Ok(chat_response)
}
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
///
/// Handles quirky model behavior where tool calls are wrapped:
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
/// - `{"name": "tool.shell", "arguments": {...}}`
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
let formatted_calls: Vec<serde_json::Value> = tool_calls
.iter()
.map(|tc| {
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
// Arguments must be a JSON string for parse_tool_calls compatibility
let args_str =
serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
serde_json::json!({
"id": tc.id,
"type": "function",
"function": {
"name": tool_name,
"arguments": args_str
}
})
})
.collect();
serde_json::json!({
"content": "",
"tool_calls": formatted_calls
})
.to_string()
}
/// Extract the actual tool name and arguments from potentially nested structures
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
let name = &tc.function.name;
let args = &tc.function.arguments;
// Pattern 1: Nested tool_call wrapper (various malformed versions)
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
if name == "tool_call"
|| name == "tool.call"
|| name.starts_with("tool_call>")
|| name.starts_with("tool_call<")
{
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
let nested_args = args
.get("arguments")
.cloned()
.unwrap_or(serde_json::json!({}));
tracing::debug!(
"Unwrapped nested tool call: {} -> {} with args {:?}",
name,
nested_name,
nested_args
);
return (nested_name.to_string(), nested_args);
}
}
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
if let Some(stripped) = name.strip_prefix("tool.") {
return (stripped.to_string(), args.clone());
}
// Pattern 3: Normal tool call
(name.clone(), args.clone())
}
}
#[async_trait]
@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
content: message.to_string(),
});
let request = ChatRequest {
model: model.to_string(),
messages,
stream: false,
options: Options { temperature },
};
let response = self.send_request(messages, model, temperature).await?;
let url = format!("{}/api/chat", self.base_url);
let response = self.client.post(&url).json(&request).send().await?;
if !response.status().is_success() {
let err = super::api_error("Ollama", response).await;
anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)");
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
if !response.message.tool_calls.is_empty() {
tracing::debug!(
"Ollama returned {} tool call(s), formatting for loop parser",
response.message.tool_calls.len()
);
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
let chat_response: ApiChatResponse = response.json().await?;
Ok(chat_response.message.content)
// Plain text response
let content = response.message.content;
// Handle edge case: model returned only "thinking" with no content or tool calls
if content.is_empty() {
if let Some(thinking) = &response.message.thinking {
tracing::warn!(
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
if thinking.len() > 100 { &thinking[..100] } else { thinking }
);
return Ok(format!(
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
if thinking.len() > 200 { &thinking[..200] } else { thinking }
));
}
tracing::warn!("Ollama returned empty content with no tool calls");
}
Ok(content)
}
async fn chat_with_history(
&self,
messages: &[crate::providers::ChatMessage],
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_messages: Vec<Message> = messages
.iter()
.map(|m| Message {
role: m.role.clone(),
content: m.content.clone(),
})
.collect();
let response = self.send_request(api_messages, model, temperature).await?;
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
if !response.message.tool_calls.is_empty() {
tracing::debug!(
"Ollama returned {} tool call(s), formatting for loop parser",
response.message.tool_calls.len()
);
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
}
// Plain text response
let content = response.message.content;
// Handle edge case: model returned only "thinking" with no content or tool calls
// This is a model quirk - it stopped after reasoning without producing output
if content.is_empty() {
if let Some(thinking) = &response.message.thinking {
tracing::warn!(
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
if thinking.len() > 100 { &thinking[..100] } else { thinking }
);
// Return a message indicating the model's thought process but no action
return Ok(format!(
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
if thinking.len() > 200 { &thinking[..200] } else { thinking }
));
}
tracing::warn!("Ollama returned empty content with no tool calls");
}
Ok(content)
}
    /// Whether this provider exposes native tool calling to the agent loop.
    ///
    /// Always `false`: loop_.rs drives tools via XML-style parsing from the
    /// system prompt. The model may still return native tool_calls, but those
    /// are converted (see `format_tool_calls_for_loop`) into the JSON format
    /// that parse_tool_calls() understands, so native mode is never needed.
    fn supports_native_tools(&self) -> bool {
        false
    }
}
// ─── Tests ────────────────────────────────────────────────────────────────────
#[cfg(test)]
mod tests {
use super::*;
@ -125,46 +352,6 @@ mod tests {
assert_eq!(p.base_url, "");
}
#[test]
fn request_serializes_with_system() {
let req = ChatRequest {
model: "llama3".to_string(),
messages: vec![
Message {
role: "system".to_string(),
content: "You are ZeroClaw".to_string(),
},
Message {
role: "user".to_string(),
content: "hello".to_string(),
},
],
stream: false,
options: Options { temperature: 0.7 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(json.contains("\"stream\":false"));
assert!(json.contains("llama3"));
assert!(json.contains("system"));
assert!(json.contains("\"temperature\":0.7"));
}
#[test]
fn request_serializes_without_system() {
let req = ChatRequest {
model: "mistral".to_string(),
messages: vec![Message {
role: "user".to_string(),
content: "test".to_string(),
}],
stream: false,
options: Options { temperature: 0.0 },
};
let json = serde_json::to_string(&req).unwrap();
assert!(!json.contains("\"role\":\"system\""));
assert!(json.contains("mistral"));
}
#[test]
fn response_deserializes() {
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
@ -180,9 +367,98 @@ mod tests {
}
#[test]
fn response_with_multiline() {
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
fn response_with_missing_content_defaults_to_empty() {
let json = r#"{"message":{"role":"assistant"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.contains("line1"));
assert!(resp.message.content.is_empty());
}
#[test]
fn response_with_thinking_field_extracts_content() {
let json =
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert_eq!(resp.message.content, "hello");
}
#[test]
fn response_with_tool_calls_parses_correctly() {
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
assert!(resp.message.content.is_empty());
assert_eq!(resp.message.tool_calls.len(), 1);
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
}
#[test]
fn extract_tool_name_handles_nested_tool_call() {
let provider = OllamaProvider::new(None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "tool_call".into(),
arguments: serde_json::json!({
"name": "shell",
"arguments": {"command": "date"}
}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "shell");
assert_eq!(args.get("command").unwrap(), "date");
}
#[test]
fn extract_tool_name_handles_prefixed_name() {
let provider = OllamaProvider::new(None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "tool.shell".into(),
arguments: serde_json::json!({"command": "ls"}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "shell");
assert_eq!(args.get("command").unwrap(), "ls");
}
#[test]
fn extract_tool_name_handles_normal_call() {
let provider = OllamaProvider::new(None);
let tc = OllamaToolCall {
id: Some("call_123".into()),
function: OllamaFunction {
name: "file_read".into(),
arguments: serde_json::json!({"path": "/tmp/test"}),
},
};
let (name, args) = provider.extract_tool_name_and_args(&tc);
assert_eq!(name, "file_read");
assert_eq!(args.get("path").unwrap(), "/tmp/test");
}
#[test]
fn format_tool_calls_produces_valid_json() {
let provider = OllamaProvider::new(None);
let tool_calls = vec![OllamaToolCall {
id: Some("call_abc".into()),
function: OllamaFunction {
name: "shell".into(),
arguments: serde_json::json!({"command": "date"}),
},
}];
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
assert!(parsed.get("tool_calls").is_some());
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
assert_eq!(calls.len(), 1);
let func = calls[0].get("function").unwrap();
assert_eq!(func.get("name").unwrap(), "shell");
// arguments should be a string (JSON-encoded)
assert!(func.get("arguments").unwrap().is_string());
}
}

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenAiProvider {
api_key: Option<String>,
credential: Option<String>,
client: Client,
}
@ -110,9 +110,9 @@ struct NativeResponseMessage {
}
impl OpenAiProvider {
pub fn new(api_key: Option<&str>) -> Self {
pub fn new(credential: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.json(&request)
.send()
.await?;
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
})?;
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
let response = self
.client
.post("https://api.openai.com/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.json(&native_request)
.send()
.await?;
@ -330,20 +330,20 @@ mod tests {
#[test]
fn creates_with_key() {
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
let p = OpenAiProvider::new(Some("openai-test-credential"));
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
}
#[test]
fn creates_without_key() {
let p = OpenAiProvider::new(None);
assert!(p.api_key.is_none());
assert!(p.credential.is_none());
}
#[test]
fn creates_with_empty_key() {
let p = OpenAiProvider::new(Some(""));
assert_eq!(p.api_key.as_deref(), Some(""));
assert_eq!(p.credential.as_deref(), Some(""));
}
#[tokio::test]

View file

@ -8,7 +8,7 @@ use reqwest::Client;
use serde::{Deserialize, Serialize};
pub struct OpenRouterProvider {
api_key: Option<String>,
credential: Option<String>,
client: Client,
}
@ -110,9 +110,9 @@ struct NativeResponseMessage {
}
impl OpenRouterProvider {
pub fn new(api_key: Option<&str>) -> Self {
pub fn new(credential: Option<&str>) -> Self {
Self {
api_key: api_key.map(ToString::to_string),
credential: credential.map(ToString::to_string),
client: Client::builder()
.timeout(std::time::Duration::from_secs(120))
.connect_timeout(std::time::Duration::from_secs(10))
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
async fn warmup(&self) -> anyhow::Result<()> {
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
// This prevents the first real chat request from timing out on cold start.
if let Some(api_key) = self.api_key.as_ref() {
if let Some(credential) = self.credential.as_ref() {
self.client
.get("https://openrouter.ai/api/v1/auth/key")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.send()
.await?
.error_for_status()?;
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref()
let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let mut messages = Vec::new();
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<String> {
let api_key = self.api_key.as_ref()
let credential = self.credential.as_ref()
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
let api_messages: Vec<Message> = messages
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
model: &str,
temperature: f64,
) -> anyhow::Result<ProviderChatResponse> {
let api_key = self.api_key.as_ref().ok_or_else(|| {
let credential = self.credential.as_ref().ok_or_else(|| {
anyhow::anyhow!(
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
)
@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
let response = self
.client
.post("https://openrouter.ai/api/v1/chat/completions")
.header("Authorization", format!("Bearer {api_key}"))
.header("Authorization", format!("Bearer {credential}"))
.header(
"HTTP-Referer",
"https://github.com/theonlyhennygod/zeroclaw",
@ -494,14 +494,17 @@ mod tests {
#[test]
fn creates_with_key() {
let provider = OpenRouterProvider::new(Some("sk-or-123"));
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
assert_eq!(
provider.credential.as_deref(),
Some("openrouter-test-credential")
);
}
#[test]
fn creates_without_key() {
let provider = OpenRouterProvider::new(None);
assert!(provider.api_key.is_none());
assert!(provider.credential.is_none());
}
#[tokio::test]

View file

@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
async fn warmup(&self) -> anyhow::Result<()> {
for (name, provider) in &self.providers {
tracing::info!(provider = name, "Warming up provider connection pool");
if let Err(e) = provider.warmup().await {
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
if provider.warmup().await.is_err() {
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
}
}
Ok(())
@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {e}",
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));
@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
let non_retryable = is_non_retryable(&e);
let rate_limited = is_rate_limited(&e);
let failure_reason = if rate_limited {
"rate_limited"
} else if non_retryable {
"non_retryable"
} else {
"retryable"
};
failures.push(format!(
"{provider_name}/{current_model} attempt {}/{}: {e}",
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
attempt + 1,
self.max_retries + 1
));

View file

@ -193,6 +193,13 @@ pub enum StreamError {
#[async_trait]
pub trait Provider: Send + Sync {
/// Query provider capabilities.
///
/// Default implementation returns minimal capabilities (no native tool calling).
/// Providers should override this to declare their actual capabilities.
fn capabilities(&self) -> ProviderCapabilities {
ProviderCapabilities::default()
}
/// Simple one-shot chat (single user message, no explicit system prompt).
///
/// This is the preferred API for non-agentic direct interactions.
@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
/// Whether provider supports native tool calls over API.
fn supports_native_tools(&self) -> bool {
false
self.capabilities().native_tool_calling
}
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
mod tests {
use super::*;
    // Mock provider that advertises native tool calling; used to verify the
    // default supports_native_tools() -> capabilities() mapping.
    struct CapabilityMockProvider;

    #[async_trait]
    impl Provider for CapabilityMockProvider {
        // Override the trait default to report native tool-calling support.
        fn capabilities(&self) -> ProviderCapabilities {
            ProviderCapabilities {
                native_tool_calling: true,
            }
        }

        // Minimal chat implementation: ignores all inputs and answers "ok".
        async fn chat_with_system(
            &self,
            _system_prompt: Option<&str>,
            _message: &str,
            _model: &str,
            _temperature: f64,
        ) -> anyhow::Result<String> {
            Ok("ok".into())
        }
    }
#[test]
fn chat_message_constructors() {
let sys = ChatMessage::system("Be helpful");
@ -398,4 +426,32 @@ mod tests {
let json = serde_json::to_string(&tool_result).unwrap();
assert!(json.contains("\"type\":\"ToolResults\""));
}
#[test]
fn provider_capabilities_default() {
let caps = ProviderCapabilities::default();
assert!(!caps.native_tool_calling);
}
#[test]
fn provider_capabilities_equality() {
let caps1 = ProviderCapabilities {
native_tool_calling: true,
};
let caps2 = ProviderCapabilities {
native_tool_calling: true,
};
let caps3 = ProviderCapabilities {
native_tool_calling: false,
};
assert_eq!(caps1, caps2);
assert_ne!(caps1, caps3);
}
#[test]
fn supports_native_tools_reflects_capabilities_default_mapping() {
let provider = CapabilityMockProvider;
assert!(provider.supports_native_tools());
}
}

View file

@ -81,14 +81,17 @@ mod tests {
#[test]
fn bubblewrap_sandbox_name() {
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
let sandbox = BubblewrapSandbox;
assert_eq!(sandbox.name(), "bubblewrap");
}
#[test]
fn bubblewrap_is_available_only_if_installed() {
// Result depends on whether bwrap is installed
let available = BubblewrapSandbox::is_available();
let sandbox = BubblewrapSandbox;
let _available = sandbox.is_available();
// Either way, the name should still work
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
assert_eq!(sandbox.name(), "bubblewrap");
}
}

View file

@ -184,7 +184,7 @@ fn generate_token() -> String {
use rand::RngCore;
let mut bytes = [0u8; 32];
rand::thread_rng().fill_bytes(&mut bytes);
format!("zc_{}", hex::encode(&bytes))
format!("zc_{}", hex::encode(bytes))
}
/// SHA-256 hash a bearer token for storage. Returns lowercase hex.

View file

@ -343,6 +343,7 @@ impl SecurityPolicy {
/// validates each sub-command against the allowlist
/// - Blocks single `&` background chaining (`&&` remains supported)
/// - Blocks output redirections (`>`, `>>`) that could write outside workspace
/// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
pub fn is_command_allowed(&self, command: &str) -> bool {
if self.autonomy == AutonomyLevel::ReadOnly {
return false;
@ -350,7 +351,12 @@ impl SecurityPolicy {
// Block subshell/expansion operators — these allow hiding arbitrary
// commands inside an allowed command (e.g. `echo $(rm -rf /)`)
if command.contains('`') || command.contains("$(") || command.contains("${") {
if command.contains('`')
|| command.contains("$(")
|| command.contains("${")
|| command.contains("<(")
|| command.contains(">(")
{
return false;
}
@ -359,6 +365,15 @@ impl SecurityPolicy {
return false;
}
// Block `tee` — it can write to arbitrary files, bypassing the
// redirect check above (e.g. `echo secret | tee /etc/crontab`)
if command
.split_whitespace()
.any(|w| w == "tee" || w.ends_with("/tee"))
{
return false;
}
// Block background command chaining (`&`), which can hide extra
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
if contains_single_ampersand(command) {
@ -384,13 +399,9 @@ impl SecurityPolicy {
// Strip leading env var assignments (e.g. FOO=bar cmd)
let cmd_part = skip_env_assignments(segment);
let base_cmd = cmd_part
.split_whitespace()
.next()
.unwrap_or("")
.rsplit('/')
.next()
.unwrap_or("");
let mut words = cmd_part.split_whitespace();
let base_raw = words.next().unwrap_or("");
let base_cmd = base_raw.rsplit('/').next().unwrap_or("");
if base_cmd.is_empty() {
continue;
@ -403,6 +414,12 @@ impl SecurityPolicy {
{
return false;
}
// Validate arguments for the command
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
if !self.is_args_safe(base_cmd, &args) {
return false;
}
}
// At least one command must be present
@ -414,6 +431,29 @@ impl SecurityPolicy {
has_cmd
}
/// Check for dangerous arguments that allow sub-command execution.
fn is_args_safe(&self, base: &str, args: &[String]) -> bool {
let base = base.to_ascii_lowercase();
match base.as_str() {
"find" => {
// find -exec and find -ok allow arbitrary command execution
!args.iter().any(|arg| arg == "-exec" || arg == "-ok")
}
"git" => {
// git config, alias, and -c can be used to set dangerous options
// (e.g. git config core.editor "rm -rf /")
!args.iter().any(|arg| {
arg == "config"
|| arg.starts_with("config.")
|| arg == "alias"
|| arg.starts_with("alias.")
|| arg == "-c"
})
}
_ => true,
}
}
/// Check if a file path is allowed (no path traversal, within workspace)
pub fn is_path_allowed(&self, path: &str) -> bool {
// Block null bytes (can truncate paths in C-backed syscalls)
@ -982,12 +1022,43 @@ mod tests {
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
}
#[test]
fn command_argument_injection_blocked() {
let p = default_policy();
// find -exec is a common bypass
assert!(!p.is_command_allowed("find . -exec rm -rf {} +"));
assert!(!p.is_command_allowed("find / -ok cat {} \\;"));
// git config/alias can execute commands
assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\""));
assert!(!p.is_command_allowed("git alias.st status"));
assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit"));
// Legitimate commands should still work
assert!(p.is_command_allowed("find . -name '*.txt'"));
assert!(p.is_command_allowed("git status"));
assert!(p.is_command_allowed("git add ."));
}
#[test]
fn command_injection_dollar_brace_blocked() {
let p = default_policy();
assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd"));
}
#[test]
fn command_injection_tee_blocked() {
let p = default_policy();
assert!(!p.is_command_allowed("echo secret | tee /etc/crontab"));
assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile"));
assert!(!p.is_command_allowed("tee file.txt"));
}
#[test]
fn command_injection_process_substitution_blocked() {
let p = default_policy();
assert!(!p.is_command_allowed("cat <(echo pwned)"));
assert!(!p.is_command_allowed("ls >(cat /etc/passwd)"));
}
#[test]
fn command_env_var_prefix_with_allowed_cmd() {
let p = default_policy();

View file

@ -854,7 +854,6 @@ impl BrowserTool {
}
}
#[allow(clippy::too_many_lines)]
#[async_trait]
impl Tool for BrowserTool {
fn name(&self) -> &str {
@ -1031,165 +1030,21 @@ impl Tool for BrowserTool {
return self.execute_computer_use_action(action_str, &args).await;
}
let action = match action_str {
"open" => {
let url = args
.get("url")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
BrowserAction::Open { url: url.into() }
}
"snapshot" => BrowserAction::Snapshot {
interactive_only: args
.get("interactive_only")
.and_then(serde_json::Value::as_bool)
.unwrap_or(true), // Default to interactive for AI
compact: args
.get("compact")
.and_then(serde_json::Value::as_bool)
.unwrap_or(true),
depth: args
.get("depth")
.and_then(serde_json::Value::as_u64)
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
},
"click" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
BrowserAction::Click {
selector: selector.into(),
}
}
"fill" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
let value = args
.get("value")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
BrowserAction::Fill {
selector: selector.into(),
value: value.into(),
}
}
"type" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
let text = args
.get("text")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
BrowserAction::Type {
selector: selector.into(),
text: text.into(),
}
}
"get_text" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
BrowserAction::GetText {
selector: selector.into(),
}
}
"get_title" => BrowserAction::GetTitle,
"get_url" => BrowserAction::GetUrl,
"screenshot" => BrowserAction::Screenshot {
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
full_page: args
.get("full_page")
.and_then(serde_json::Value::as_bool)
.unwrap_or(false),
},
"wait" => BrowserAction::Wait {
selector: args
.get("selector")
.and_then(|v| v.as_str())
.map(String::from),
ms: args.get("ms").and_then(serde_json::Value::as_u64),
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
},
"press" => {
let key = args
.get("key")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
BrowserAction::Press { key: key.into() }
}
"hover" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
BrowserAction::Hover {
selector: selector.into(),
}
}
"scroll" => {
let direction = args
.get("direction")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
BrowserAction::Scroll {
direction: direction.into(),
pixels: args
.get("pixels")
.and_then(serde_json::Value::as_u64)
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
}
}
"is_visible" => {
let selector = args
.get("selector")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
BrowserAction::IsVisible {
selector: selector.into(),
}
}
"close" => BrowserAction::Close,
"find" => {
let by = args
.get("by")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
let value = args
.get("value")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
let action = args
.get("find_action")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
BrowserAction::Find {
by: by.into(),
value: value.into(),
action: action.into(),
fill_value: args
.get("fill_value")
.and_then(|v| v.as_str())
.map(String::from),
}
}
_ => {
if is_computer_use_only_action(action_str) {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(format!(
"Action '{action_str}' is unavailable for backend '{}'",
match backend {
ResolvedBackend::AgentBrowser => "agent_browser",
ResolvedBackend::RustNative => "rust_native",
ResolvedBackend::ComputerUse => "computer_use",
error: Some(unavailable_action_for_backend_error(action_str, backend)),
});
}
)),
let action = match parse_browser_action(action_str, &args) {
Ok(a) => a,
Err(e) => {
return Ok(ToolResult {
success: false,
output: String::new(),
error: Some(e.to_string()),
});
}
};
@ -1871,6 +1726,161 @@ mod native_backend {
}
}
// ── Action parsing ──────────────────────────────────────────────
/// Parse a JSON `args` object into a typed `BrowserAction`.
///
/// Returns an error naming the missing key when a required argument is
/// absent, and an "Unsupported browser action" error for unknown actions.
fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<BrowserAction> {
    // Fetch a required string argument, reproducing the existing
    // "Missing '<key>' for <context>" error wording. This removes fifteen
    // copies of the same .get().and_then().ok_or_else() chain.
    fn required_str<'a>(args: &'a Value, key: &str, context: &str) -> anyhow::Result<&'a str> {
        args.get(key)
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing '{key}' for {context}"))
    }

    match action_str {
        "open" => Ok(BrowserAction::Open {
            url: required_str(args, "url", "open action")?.into(),
        }),
        "snapshot" => Ok(BrowserAction::Snapshot {
            // Default to interactive-only, compact output for AI consumption.
            interactive_only: args
                .get("interactive_only")
                .and_then(serde_json::Value::as_bool)
                .unwrap_or(true),
            compact: args
                .get("compact")
                .and_then(serde_json::Value::as_bool)
                .unwrap_or(true),
            depth: args
                .get("depth")
                .and_then(serde_json::Value::as_u64)
                .map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
        }),
        "click" => Ok(BrowserAction::Click {
            selector: required_str(args, "selector", "click")?.into(),
        }),
        "fill" => Ok(BrowserAction::Fill {
            selector: required_str(args, "selector", "fill")?.into(),
            value: required_str(args, "value", "fill")?.into(),
        }),
        "type" => Ok(BrowserAction::Type {
            selector: required_str(args, "selector", "type")?.into(),
            text: required_str(args, "text", "type")?.into(),
        }),
        "get_text" => Ok(BrowserAction::GetText {
            selector: required_str(args, "selector", "get_text")?.into(),
        }),
        "get_title" => Ok(BrowserAction::GetTitle),
        "get_url" => Ok(BrowserAction::GetUrl),
        "screenshot" => Ok(BrowserAction::Screenshot {
            path: args.get("path").and_then(|v| v.as_str()).map(String::from),
            full_page: args
                .get("full_page")
                .and_then(serde_json::Value::as_bool)
                .unwrap_or(false),
        }),
        "wait" => Ok(BrowserAction::Wait {
            // All three wait criteria are optional; backends decide precedence.
            selector: args
                .get("selector")
                .and_then(|v| v.as_str())
                .map(String::from),
            ms: args.get("ms").and_then(serde_json::Value::as_u64),
            text: args.get("text").and_then(|v| v.as_str()).map(String::from),
        }),
        "press" => Ok(BrowserAction::Press {
            key: required_str(args, "key", "press")?.into(),
        }),
        "hover" => Ok(BrowserAction::Hover {
            selector: required_str(args, "selector", "hover")?.into(),
        }),
        "scroll" => Ok(BrowserAction::Scroll {
            direction: required_str(args, "direction", "scroll")?.into(),
            pixels: args
                .get("pixels")
                .and_then(serde_json::Value::as_u64)
                .map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
        }),
        "is_visible" => Ok(BrowserAction::IsVisible {
            selector: required_str(args, "selector", "is_visible")?.into(),
        }),
        "close" => Ok(BrowserAction::Close),
        "find" => Ok(BrowserAction::Find {
            by: required_str(args, "by", "find")?.into(),
            value: required_str(args, "value", "find")?.into(),
            action: required_str(args, "find_action", "find")?.into(),
            fill_value: args
                .get("fill_value")
                .and_then(|v| v.as_str())
                .map(String::from),
        }),
        other => anyhow::bail!("Unsupported browser action: {other}"),
    }
}
// ── Helper functions ─────────────────────────────────────────────
fn is_supported_browser_action(action: &str) -> bool {
@ -1901,6 +1911,28 @@ fn is_supported_browser_action(action: &str) -> bool {
)
}
/// Returns true when `action` is a raw input primitive (mouse/keyboard/
/// screen) that only the computer-use backend can service, as opposed to a
/// DOM-level browser action.
fn is_computer_use_only_action(action: &str) -> bool {
    const COMPUTER_USE_ONLY_ACTIONS: [&str; 6] = [
        "mouse_move",
        "mouse_click",
        "mouse_drag",
        "key_type",
        "key_press",
        "screen_capture",
    ];
    COMPUTER_USE_ONLY_ACTIONS.contains(&action)
}
/// Stable, machine-readable identifier for a resolved browser backend.
/// Used in diagnostics such as `unavailable_action_for_backend_error`.
fn backend_name(backend: ResolvedBackend) -> &'static str {
    match backend {
        ResolvedBackend::ComputerUse => "computer_use",
        ResolvedBackend::RustNative => "rust_native",
        ResolvedBackend::AgentBrowser => "agent_browser",
    }
}
/// Builds the user-facing error string for an action that the currently
/// resolved backend cannot execute.
fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String {
    let backend_label = backend_name(backend);
    format!("Action '{action}' is unavailable for backend '{backend_label}'")
}
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
domains
.into_iter()
@ -2342,4 +2374,28 @@ mod tests {
let tool = BrowserTool::new(security, vec![], None);
assert!(tool.validate_url("https://example.com").is_err());
}
#[test]
fn computer_use_only_action_detection_is_correct() {
    // Every raw input primitive must be classified as computer-use-only…
    assert!(is_computer_use_only_action("mouse_move"));
    assert!(is_computer_use_only_action("mouse_click"));
    assert!(is_computer_use_only_action("mouse_drag"));
    assert!(is_computer_use_only_action("key_type"));
    assert!(is_computer_use_only_action("key_press"));
    assert!(is_computer_use_only_action("screen_capture"));
    // …while ordinary DOM-level browser actions must not be.
    assert!(!is_computer_use_only_action("open"));
    assert!(!is_computer_use_only_action("snapshot"));
}
#[test]
fn unavailable_action_error_preserves_backend_context() {
    // The error text must name BOTH the action and the backend so users can
    // tell which backend rejected the request.
    assert_eq!(
        unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser),
        "Action 'mouse_move' is unavailable for backend 'agent_browser'"
    );
    assert_eq!(
        unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative),
        "Action 'mouse_move' is unavailable for backend 'rust_native'"
    );
}
}

View file

@ -112,12 +112,12 @@ impl ComposioTool {
action_name: &str,
params: serde_json::Value,
entity_id: Option<&str>,
connected_account_id: Option<&str>,
connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let tool_slug = normalize_tool_slug(action_name);
match self
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id)
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
.await
{
Ok(result) => Ok(result),
@ -130,21 +130,17 @@ impl ComposioTool {
}
}
async fn execute_action_v3(
&self,
fn build_execute_action_v3_request(
tool_slug: &str,
params: serde_json::Value,
entity_id: Option<&str>,
connected_account_id: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let url = if let Some(connected_account_id) = connected_account_id
.map(str::trim)
.filter(|id| !id.is_empty())
{
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}")
} else {
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
};
connected_account_ref: Option<&str>,
) -> (String, serde_json::Value) {
let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
let account_ref = connected_account_ref.and_then(|candidate| {
let trimmed_candidate = candidate.trim();
(!trimmed_candidate.is_empty()).then_some(trimmed_candidate)
});
let mut body = json!({
"arguments": params,
@ -153,6 +149,26 @@ impl ComposioTool {
if let Some(entity) = entity_id {
body["user_id"] = json!(entity);
}
if let Some(account_ref) = account_ref {
body["connected_account_id"] = json!(account_ref);
}
(url, body)
}
async fn execute_action_v3(
&self,
tool_slug: &str,
params: serde_json::Value,
entity_id: Option<&str>,
connected_account_ref: Option<&str>,
) -> anyhow::Result<serde_json::Value> {
let (url, body) = Self::build_execute_action_v3_request(
tool_slug,
params,
entity_id,
connected_account_ref,
);
let resp = self
.client
@ -474,11 +490,11 @@ impl Tool for ComposioTool {
})?;
let params = args.get("params").cloned().unwrap_or(json!({}));
let connected_account_id =
let connected_account_ref =
args.get("connected_account_id").and_then(|v| v.as_str());
match self
.execute_action(action_name, params, Some(entity_id), connected_account_id)
.execute_action(action_name, params, Some(entity_id), connected_account_ref)
.await
{
Ok(result) => {
@ -594,9 +610,38 @@ async fn response_error(resp: reqwest::Response) -> String {
}
if let Some(api_error) = extract_api_error_message(&body) {
format!("HTTP {}: {api_error}", status.as_u16())
return format!(
"HTTP {}: {}",
status.as_u16(),
sanitize_error_message(&api_error)
);
}
format!("HTTP {}", status.as_u16())
}
/// Sanitizes an upstream API error message before surfacing it to the user:
/// flattens newlines, redacts identifier field names that could leak
/// account/user routing details, and caps the result at 240 characters.
fn sanitize_error_message(message: &str) -> String {
    let mut sanitized = message.replace('\n', " ");
    // Redact field names that reveal how the request was routed or
    // authenticated (both snake_case and camelCase spellings).
    for marker in [
        "connected_account_id",
        "connectedAccountId",
        "entity_id",
        "entityId",
        "user_id",
        "userId",
    ] {
        sanitized = sanitized.replace(marker, "[redacted]");
    }
    const MAX_CHARS: usize = 240;
    if sanitized.chars().count() <= MAX_CHARS {
        return sanitized;
    }
    // Truncate by characters, not bytes: a byte-indexed slice could cut a
    // multi-byte UTF-8 message far below the advertised 240-char limit
    // (the length check above counts chars, so truncation must too).
    let truncated: String = sanitized.chars().take(MAX_CHARS).collect();
    format!("{truncated}...")
}
@ -948,4 +993,40 @@ mod tests {
fn composio_api_base_url_is_v3() {
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
}
#[test]
fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
    // The v3 endpoint no longer embeds the connected account in the URL
    // path; it must travel in the request body as `connected_account_id`.
    let (url, body) = ComposioTool::build_execute_action_v3_request(
        "gmail-send-email",
        json!({"to": "test@example.com"}),
        Some("workspace-user"),
        Some("account-42"),
    );
    assert_eq!(
        url,
        "https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
    );
    assert_eq!(body["arguments"]["to"], json!("test@example.com"));
    assert_eq!(body["user_id"], json!("workspace-user"));
    assert_eq!(body["connected_account_id"], json!("account-42"));
}
#[test]
fn build_execute_action_v3_request_drops_blank_optional_fields() {
    // A whitespace-only account ref and an absent entity id must not emit
    // empty routing fields in the request body.
    let (url, body) = ComposioTool::build_execute_action_v3_request(
        "github-list-repos",
        json!({}),
        None,
        Some("   "),
    );
    assert_eq!(
        url,
        "https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
    );
    assert_eq!(body["arguments"], json!({}));
    assert!(body.get("connected_account_id").is_none());
    assert!(body.get("user_id").is_none());
}
}

View file

@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
/// summarization) to purpose-built sub-agents.
pub struct DelegateTool {
agents: Arc<HashMap<String, DelegateAgentConfig>>,
/// Global API key fallback (from config.api_key)
fallback_api_key: Option<String>,
/// Global credential fallback (from config.api_key)
fallback_credential: Option<String>,
/// Depth at which this tool instance lives in the delegation chain.
depth: u32,
}
@ -25,11 +25,11 @@ pub struct DelegateTool {
impl DelegateTool {
pub fn new(
agents: HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<String>,
fallback_credential: Option<String>,
) -> Self {
Self {
agents: Arc::new(agents),
fallback_api_key,
fallback_credential,
depth: 0,
}
}
@ -39,12 +39,12 @@ impl DelegateTool {
/// their DelegateTool via this method with `depth: parent.depth + 1`.
pub fn with_depth(
agents: HashMap<String, DelegateAgentConfig>,
fallback_api_key: Option<String>,
fallback_credential: Option<String>,
depth: u32,
) -> Self {
Self {
agents: Arc::new(agents),
fallback_api_key,
fallback_credential,
depth,
}
}
@ -165,13 +165,15 @@ impl Tool for DelegateTool {
}
// Create provider for this agent
let api_key = agent_config
let provider_credential_owned = agent_config
.api_key
.as_deref()
.or(self.fallback_api_key.as_deref());
.clone()
.or_else(|| self.fallback_credential.clone());
#[allow(clippy::option_as_ref_deref)]
let provider_credential = provider_credential_owned.as_ref().map(String::as_str);
let provider: Box<dyn Provider> =
match providers::create_provider(&agent_config.provider, api_key) {
match providers::create_provider(&agent_config.provider, provider_credential) {
Ok(p) => p,
Err(e) => {
return Ok(ToolResult {
@ -268,7 +270,7 @@ mod tests {
provider: "openrouter".to_string(),
model: "anthropic/claude-sonnet-4-20250514".to_string(),
system_prompt: None,
api_key: Some("sk-test".to_string()),
api_key: Some("delegate-test-credential".to_string()),
temperature: None,
max_depth: 2,
},

View file

@ -28,13 +28,22 @@ impl GitOperationsTool {
if arg_lower.starts_with("--exec=")
|| arg_lower.starts_with("--upload-pack=")
|| arg_lower.starts_with("--receive-pack=")
|| arg_lower.starts_with("--pager=")
|| arg_lower.starts_with("--editor=")
|| arg_lower == "--no-verify"
|| arg_lower.contains("$(")
|| arg_lower.contains('`')
|| arg.contains('|')
|| arg.contains(';')
|| arg.contains('>')
{
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
}
// Block `-c` config injection (exact match or `-c=...` prefix).
// This must not false-positive on `--cached` or `-cached`.
if arg_lower == "-c" || arg_lower.starts_with("-c=") {
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
}
result.push(arg.to_string());
}
Ok(result)
@ -129,6 +138,9 @@ impl GitOperationsTool {
.and_then(|v| v.as_bool())
.unwrap_or(false);
// Validate files argument against injection patterns
self.sanitize_git_args(files)?;
let mut git_args = vec!["diff", "--unified=3"];
if cached {
git_args.push("--cached");
@ -267,6 +279,14 @@ impl GitOperationsTool {
})
}
/// Caps a commit message at 2000 characters, appending "..." when trimmed.
/// Works on `char`s, so multi-byte UTF-8 input is never split mid-code-point.
fn truncate_commit_message(message: &str) -> String {
    const LIMIT: usize = 2000;
    if message.chars().count() <= LIMIT {
        return message.to_string();
    }
    // Keep LIMIT - 3 chars so the ellipsis brings the total to exactly LIMIT.
    let kept: String = message.chars().take(LIMIT - 3).collect();
    format!("{kept}...")
}
async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
let message = args
.get("message")
@ -286,11 +306,7 @@ impl GitOperationsTool {
}
// Limit message length
let message = if sanitized.len() > 2000 {
format!("{}...", &sanitized[..1997])
} else {
sanitized
};
let message = Self::truncate_commit_message(&sanitized);
let output = self.run_git_command(&["commit", "-m", &message]).await;
@ -314,6 +330,9 @@ impl GitOperationsTool {
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?;
// Validate paths against injection patterns
self.sanitize_git_args(paths)?;
let output = self.run_git_command(&["add", "--", paths]).await;
match output {
@ -574,6 +593,52 @@ mod tests {
assert!(tool.sanitize_git_args("arg; rm file").is_err());
}
#[test]
fn sanitize_git_blocks_pager_editor_injection() {
    let tmp = TempDir::new().unwrap();
    let tool = test_tool(tmp.path());
    // --pager/--editor let an argument name an arbitrary command to run,
    // so both must be rejected outright.
    assert!(tool.sanitize_git_args("--pager=less").is_err());
    assert!(tool.sanitize_git_args("--editor=vim").is_err());
}
#[test]
fn sanitize_git_blocks_config_injection() {
    let tmp = TempDir::new().unwrap();
    let tool = test_tool(tmp.path());
    // Exact `-c` flag (config injection)
    assert!(tool.sanitize_git_args("-c core.sshCommand=evil").is_err());
    // The `-c=<key>=<value>` spelling must be caught as well.
    assert!(tool.sanitize_git_args("-c=core.pager=less").is_err());
}
#[test]
fn sanitize_git_blocks_no_verify() {
    let tmp = TempDir::new().unwrap();
    let tool = test_tool(tmp.path());
    // --no-verify would skip commit hooks, bypassing local policy checks.
    assert!(tool.sanitize_git_args("--no-verify").is_err());
}
#[test]
fn sanitize_git_blocks_redirect_in_args() {
    let tmp = TempDir::new().unwrap();
    let tool = test_tool(tmp.path());
    // `>` is a shell redirection character; it must never reach git args.
    assert!(tool.sanitize_git_args("file.txt > /tmp/out").is_err());
}
#[test]
fn sanitize_git_cached_not_blocked() {
    let tmp = TempDir::new().unwrap();
    let tool = test_tool(tmp.path());
    // --cached must NOT be blocked by the `-c` check
    assert!(tool.sanitize_git_args("--cached").is_ok());
    // Other safe flags starting with -c prefix
    // (guards against an over-broad starts_with("-c") false positive).
    assert!(tool.sanitize_git_args("-cached").is_ok());
}
#[test]
fn sanitize_git_allows_safe() {
let tmp = TempDir::new().unwrap();
@ -583,6 +648,8 @@ mod tests {
assert!(tool.sanitize_git_args("main").is_ok());
assert!(tool.sanitize_git_args("feature/test-branch").is_ok());
assert!(tool.sanitize_git_args("--cached").is_ok());
assert!(tool.sanitize_git_args("src/main.rs").is_ok());
assert!(tool.sanitize_git_args(".").is_ok());
}
#[test]
@ -691,4 +758,12 @@ mod tests {
.unwrap_or("")
.contains("Unknown operation"));
}
#[test]
fn truncates_multibyte_commit_message_without_panicking() {
    // 🦀 is a 4-byte code point; byte-indexed slicing at position 1997
    // would panic. 2500 chars exceeds the 2000-char cap, so truncation runs.
    let long = "🦀".repeat(2500);
    let truncated = GitOperationsTool::truncate_commit_message(&long);
    // 1997 kept chars + "..." == exactly 2000 chars.
    assert_eq!(truncated.chars().count(), 2000);
}
}

View file

@ -124,10 +124,11 @@ impl Tool for HardwareBoardInfoTool {
});
}
Err(e) => {
output.push_str(&format!(
"probe-rs attach failed: {}. Using static info.\n\n",
e
));
use std::fmt::Write;
let _ = write!(
output,
"probe-rs attach failed: {e}. Using static info.\n\n"
);
}
}
}
@ -135,13 +136,15 @@ impl Tool for HardwareBoardInfoTool {
if let Some(info) = self.static_info_for_board(board) {
output.push_str(&info);
if let Some(mem) = memory_map_static(board) {
output.push_str(&format!("\n\n**Memory map:**\n{}", mem));
use std::fmt::Write;
let _ = write!(output, "\n\n**Memory map:**\n{mem}");
}
} else {
output.push_str(&format!(
"Board '{}' configured. No static info available.",
board
));
use std::fmt::Write;
let _ = write!(
output,
"Board '{board}' configured. No static info available."
);
}
Ok(ToolResult {

View file

@ -122,14 +122,16 @@ impl Tool for HardwareMemoryMapTool {
if !probe_ok {
if let Some(map) = self.static_map_for_board(board) {
output.push_str(&format!("**{}** (from datasheet):\n{}", board, map));
use std::fmt::Write;
let _ = write!(output, "**{board}** (from datasheet):\n{map}");
} else {
use std::fmt::Write;
let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect();
output.push_str(&format!(
"No memory map for board '{}'. Known boards: {}",
board,
let _ = write!(
output,
"No memory map for board '{board}'. Known boards: {}",
known.join(", ")
));
);
}
}

View file

@ -94,14 +94,16 @@ impl Tool for HardwareMemoryReadTool {
.get("address")
.and_then(|v| v.as_str())
.unwrap_or("0x20000000");
let address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
let _address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
let length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
let length = length.min(256).max(1);
let requested_length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128);
let _length = usize::try_from(requested_length)
.unwrap_or(256)
.clamp(1, 256);
#[cfg(feature = "probe")]
{
match probe_read_memory(chip.unwrap(), address, length) {
match probe_read_memory(chip.unwrap(), _address, _length) {
Ok(output) => {
return Ok(ToolResult {
success: true,

View file

@ -749,4 +749,54 @@ mod tests {
let _ = HttpRequestTool::redact_headers_for_display(&headers);
assert_eq!(headers[0].1, "Bearer real-token");
}
// ── SSRF: alternate IP notation bypass defense-in-depth ─────────
//
// Rust's IpAddr::parse() rejects non-standard notations (octal, hex,
// decimal integer, zero-padded). These tests document that property
// so regressions are caught if the parsing strategy ever changes.
#[test]
fn ssrf_octal_loopback_not_parsed_as_ip() {
    // 0177.0.0.1 is octal for 127.0.0.1 in some languages, but
    // Rust's IpAddr rejects it — it falls through as a hostname.
    // (The allowlist then rejects it; see the validate_url test below.)
    assert!(!is_private_or_local_host("0177.0.0.1"));
}
#[test]
fn ssrf_hex_loopback_not_parsed_as_ip() {
    // 0x7f000001 is hex for 127.0.0.1 in some languages.
    // Rust's IpAddr::parse rejects hex notation, so this stays a hostname.
    assert!(!is_private_or_local_host("0x7f000001"));
}
#[test]
fn ssrf_decimal_loopback_not_parsed_as_ip() {
    // 2130706433 is decimal for 127.0.0.1 in some languages.
    // A bare integer is not a valid IpAddr in Rust, so it stays a hostname.
    assert!(!is_private_or_local_host("2130706433"));
}
#[test]
fn ssrf_zero_padded_loopback_not_parsed_as_ip() {
    // 127.000.000.001 uses zero-padded octets.
    // Rust's IpAddr parser rejects leading zeros, so this stays a hostname.
    assert!(!is_private_or_local_host("127.000.000.001"));
}
#[test]
fn ssrf_alternate_notations_rejected_by_validate_url() {
    // Even if is_private_or_local_host doesn't flag these, they
    // fail the allowlist because they're treated as hostnames.
    // This is the second defense layer backing the tests above.
    let tool = test_tool(vec!["example.com"]);
    for notation in [
        "http://0177.0.0.1",
        "http://0x7f000001",
        "http://2130706433",
        "http://127.000.000.001",
    ] {
        let err = tool.validate_url(notation).unwrap_err().to_string();
        assert!(
            err.contains("allowed_domains"),
            "Expected allowlist rejection for {notation}, got: {err}"
        );
    }
}
}

View file

@ -87,7 +87,7 @@ mod tests {
#[tokio::test]
async fn forget_existing() {
let (_tmp, mem) = test_mem();
mem.store("temp", "temporary", MemoryCategory::Conversation)
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
.await
.unwrap();

View file

@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool {
.and_then(serde_json::Value::as_u64)
.map_or(5, |v| v as usize);
match self.memory.recall(query, limit).await {
match self.memory.recall(query, limit, None).await {
Ok(entries) if entries.is_empty() => Ok(ToolResult {
success: true,
output: "No memories found matching that query.".into(),
@ -112,10 +112,10 @@ mod tests {
#[tokio::test]
async fn recall_finds_match() {
let (_tmp, mem) = seeded_mem();
mem.store("lang", "User prefers Rust", MemoryCategory::Core)
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
.await
.unwrap();
mem.store("tz", "Timezone is EST", MemoryCategory::Core)
mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
.await
.unwrap();
@ -134,6 +134,7 @@ mod tests {
&format!("k{i}"),
&format!("Rust fact {i}"),
MemoryCategory::Core,
None,
)
.await
.unwrap();

View file

@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool {
_ => MemoryCategory::Core,
};
match self.memory.store(key, content, category).await {
match self.memory.store(key, content, category, None).await {
Ok(()) => Ok(ToolResult {
success: true,
output: format!("Stored memory: {key}"),

View file

@ -19,7 +19,9 @@ pub mod image_info;
pub mod memory_forget;
pub mod memory_recall;
pub mod memory_store;
pub mod pushover;
pub mod schedule;
pub mod schema;
pub mod screenshot;
pub mod shell;
pub mod traits;
@ -45,7 +47,9 @@ pub use image_info::ImageInfoTool;
pub use memory_forget::MemoryForgetTool;
pub use memory_recall::MemoryRecallTool;
pub use memory_store::MemoryStoreTool;
pub use pushover::PushoverTool;
pub use schedule::ScheduleTool;
pub use schema::{CleaningStrategy, SchemaCleanr};
pub use screenshot::ScreenshotTool;
pub use shell::ShellTool;
pub use traits::Tool;
@ -141,6 +145,10 @@ pub fn all_tools_with_runtime(
security.clone(),
workspace_dir.to_path_buf(),
)),
Box::new(PushoverTool::new(
security.clone(),
workspace_dir.to_path_buf(),
)),
];
if browser_config.enabled {
@ -195,9 +203,13 @@ pub fn all_tools_with_runtime(
.iter()
.map(|(name, cfg)| (name.clone(), cfg.clone()))
.collect();
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
let trimmed_value = value.trim();
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
});
tools.push(Box::new(DelegateTool::new(
delegate_agents,
fallback_api_key.map(String::from),
delegate_fallback_credential,
)));
}
@ -261,6 +273,7 @@ mod tests {
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(!names.contains(&"browser_open"));
assert!(names.contains(&"schedule"));
assert!(names.contains(&"pushover"));
}
#[test]
@ -298,6 +311,7 @@ mod tests {
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
assert!(names.contains(&"browser_open"));
assert!(names.contains(&"pushover"));
}
#[test]
@ -432,7 +446,7 @@ mod tests {
&http,
tmp.path(),
&agents,
Some("sk-test"),
Some("delegate-test-credential"),
&cfg,
);
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();

442
src/tools/pushover.rs Normal file
View file

@ -0,0 +1,442 @@
use super::traits::{Tool, ToolResult};
use crate::security::SecurityPolicy;
use async_trait::async_trait;
use reqwest::Client;
use serde_json::json;
use std::path::PathBuf;
use std::sync::Arc;
use std::time::Duration;
const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json";
const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15;
pub struct PushoverTool {
    // HTTP client; built in `new` with a PUSHOVER_REQUEST_TIMEOUT_SECS timeout.
    client: Client,
    // Gates autonomy level (read-only vs. acting) and hourly rate limiting.
    security: Arc<SecurityPolicy>,
    // Workspace root; credentials are read from `<workspace_dir>/.env`.
    workspace_dir: PathBuf,
}
impl PushoverTool {
/// Creates a Pushover tool rooted at `workspace_dir`; credentials are read
/// from `<workspace_dir>/.env` when a notification is sent.
pub fn new(security: Arc<SecurityPolicy>, workspace_dir: PathBuf) -> Self {
    // Building with only a timeout should not fail; fall back to a default
    // client rather than panicking at construction time.
    let http_client = Client::builder()
        .timeout(Duration::from_secs(PUSHOVER_REQUEST_TIMEOUT_SECS))
        .build()
        .unwrap_or_else(|_| Client::new());
    Self {
        client: http_client,
        security,
        workspace_dir,
    }
}
/// Parses a raw `.env` value: trims whitespace, strips one matching pair of
/// surrounding quotes, and removes an inline ` # comment` suffix from
/// UNQUOTED values only.
///
/// Fix: comment stripping previously ran on the unquoted *contents* of
/// quoted values as well, corrupting any quoted value containing " #".
/// Quotes exist precisely to protect such characters, so quoted values are
/// now returned verbatim (minus the quotes).
fn parse_env_value(raw: &str) -> String {
    let raw = raw.trim();
    let is_quoted = raw.len() >= 2
        && ((raw.starts_with('"') && raw.ends_with('"'))
            || (raw.starts_with('\'') && raw.ends_with('\'')));
    if is_quoted {
        return raw[1..raw.len() - 1].to_string();
    }
    // Keep support for inline comments in unquoted values:
    // KEY=value # comment
    raw.split_once(" #").map_or_else(
        || raw.trim().to_string(),
        |(value, _)| value.trim().to_string(),
    )
}
/// Reads `PUSHOVER_TOKEN` and `PUSHOVER_USER_KEY` from `<workspace>/.env`.
///
/// Accepts shell-style `export KEY=value` lines, quoted values, full-line
/// `#` comments, and case-insensitive key names; values are normalized by
/// `Self::parse_env_value`. If a key appears more than once, the last
/// occurrence wins.
///
/// # Errors
/// Fails when the `.env` file cannot be read or either key is absent.
fn get_credentials(&self) -> anyhow::Result<(String, String)> {
    let env_path = self.workspace_dir.join(".env");
    let content = std::fs::read_to_string(&env_path)
        .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?;

    let mut token = None;
    let mut user_key = None;

    for line in content.lines() {
        let line = line.trim();
        // Skip blank lines and full-line comments.
        if line.starts_with('#') || line.is_empty() {
            continue;
        }
        // Tolerate a shell-style `export ` prefix before the assignment.
        let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line);
        if let Some((key, value)) = line.split_once('=') {
            let key = key.trim();
            let value = Self::parse_env_value(value);
            if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") {
                token = Some(value);
            } else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") {
                user_key = Some(value);
            }
        }
    }

    let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?;
    let user_key =
        user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?;
    Ok((token, user_key))
}
#[async_trait]
impl Tool for PushoverTool {
    fn name(&self) -> &str {
        "pushover"
    }

    fn description(&self) -> &str {
        "Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file."
    }

    // JSON Schema describing the tool's arguments; only `message` is required.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "message": {
                    "type": "string",
                    "description": "The notification message to send"
                },
                "title": {
                    "type": "string",
                    "description": "Optional notification title"
                },
                "priority": {
                    "type": "integer",
                    "enum": [-2, -1, 0, 1, 2],
                    "description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)"
                },
                "sound": {
                    "type": "string",
                    "description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)"
                }
            },
            "required": ["message"]
        })
    }

    // Sends the notification through the Pushover messages API.
    //
    // Check order: autonomy gate, rate-limit gate, argument validation,
    // credential lookup (propagates as `Err` via `?`), then the HTTP call.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        // Sending a notification is an externally visible action, so it is
        // refused entirely under read-only autonomy.
        if !self.security.can_act() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Action blocked: autonomy is read-only".into()),
            });
        }
        // Count this call against the hourly action budget.
        if !self.security.record_action() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Action blocked: rate limit exceeded".into()),
            });
        }

        // `message` is required and must be non-empty after trimming.
        let message = args
            .get("message")
            .and_then(|v| v.as_str())
            .map(str::trim)
            .filter(|v| !v.is_empty())
            .ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?
            .to_string();

        let title = args.get("title").and_then(|v| v.as_str()).map(String::from);
        // Only priorities in -2..=2 are valid; reject anything else instead
        // of silently forwarding it to the API.
        let priority = match args.get("priority").and_then(|v| v.as_i64()) {
            Some(value) if (-2..=2).contains(&value) => Some(value),
            Some(value) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!(
                        "Invalid 'priority': {value}. Expected integer in range -2..=2"
                    )),
                })
            }
            None => None,
        };
        let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from);

        let (token, user_key) = self.get_credentials()?;

        // Build the multipart form; optional fields are attached only when
        // the caller supplied them.
        let mut form = reqwest::multipart::Form::new()
            .text("token", token)
            .text("user", user_key)
            .text("message", message);
        if let Some(title) = title {
            form = form.text("title", title);
        }
        if let Some(priority) = priority {
            form = form.text("priority", priority.to_string());
        }
        if let Some(sound) = sound {
            form = form.text("sound", sound);
        }

        let response = self
            .client
            .post(PUSHOVER_API_URL)
            .multipart(form)
            .send()
            .await?;

        let status = response.status();
        let body = response.text().await.unwrap_or_default();

        if !status.is_success() {
            return Ok(ToolResult {
                success: false,
                output: body,
                error: Some(format!("Pushover API returned status {}", status)),
            });
        }

        // A 200 response can still be an application-level failure: success
        // is signaled by `"status": 1` in the JSON body.
        let api_status = serde_json::from_str::<serde_json::Value>(&body)
            .ok()
            .and_then(|json| json.get("status").and_then(|value| value.as_i64()));

        if api_status == Some(1) {
            Ok(ToolResult {
                success: true,
                output: format!(
                    "Pushover notification sent successfully. Response: {}",
                    body
                ),
                error: None,
            })
        } else {
            Ok(ToolResult {
                success: false,
                output: body,
                error: Some("Pushover API returned an application-level error".into()),
            })
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for PushoverTool. Nothing here performs a real HTTP call:
    //! network-dependent paths are exercised only up to the security gates
    //! and argument validation, which return before any request is sent.
    use super::*;
    use crate::security::AutonomyLevel;
    use std::fs;
    use tempfile::TempDir;

    // Builds a minimal SecurityPolicy with the given autonomy level and
    // hourly action budget; remaining fields take their defaults.
    fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: level,
            max_actions_per_hour,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    #[test]
    fn pushover_tool_name() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        assert_eq!(tool.name(), "pushover");
    }

    #[test]
    fn pushover_tool_description() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn pushover_tool_has_parameters_schema() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"].get("message").is_some());
    }

    #[test]
    fn pushover_tool_requires_message() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        let required = schema["required"].as_array().unwrap();
        assert!(required.contains(&serde_json::Value::String("message".to_string())));
    }

    // ── .env credential parsing ──────────────────────────────────────

    #[test]
    fn credentials_parsed_from_env_file() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(
            &env_path,
            "PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n",
        )
        .unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "testtoken123");
        assert_eq!(user_key, "userkey456");
    }

    #[test]
    fn credentials_fail_without_env_file() {
        // No .env file in the workspace at all.
        let tmp = TempDir::new().unwrap();
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_err());
    }

    #[test]
    fn credentials_fail_without_token() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_err());
    }

    #[test]
    fn credentials_fail_without_user_key() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_err());
    }

    #[test]
    fn credentials_ignore_comments() {
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "realtoken");
        assert_eq!(user_key, "realuser");
    }

    #[test]
    fn pushover_tool_supports_priority() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert!(schema["properties"].get("priority").is_some());
    }

    #[test]
    fn pushover_tool_supports_sound() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let schema = tool.parameters_schema();
        assert!(schema["properties"].get("sound").is_some());
    }

    #[test]
    fn credentials_support_export_and_quoted_values() {
        // `export KEY="value"` and single-quoted values must both unwrap.
        let tmp = TempDir::new().unwrap();
        let env_path = tmp.path().join(".env");
        fs::write(
            &env_path,
            "export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n",
        )
        .unwrap();

        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            tmp.path().to_path_buf(),
        );
        let result = tool.get_credentials();
        assert!(result.is_ok());
        let (token, user_key) = result.unwrap();
        assert_eq!(token, "quotedtoken");
        assert_eq!(user_key, "quoteduser");
    }

    // ── Security gates (returned before any HTTP request is made) ────

    #[tokio::test]
    async fn execute_blocks_readonly_mode() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::ReadOnly, 100),
            PathBuf::from("/tmp"),
        );
        let result = tool.execute(json!({"message": "hello"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("read-only"));
    }

    #[tokio::test]
    async fn execute_blocks_rate_limit() {
        // A zero-action budget makes the very first call exceed the limit.
        let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp"));
        let result = tool.execute(json!({"message": "hello"})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("rate limit"));
    }

    #[tokio::test]
    async fn execute_rejects_priority_out_of_range() {
        let tool = PushoverTool::new(
            test_security(AutonomyLevel::Full, 100),
            PathBuf::from("/tmp"),
        );
        let result = tool
            .execute(json!({"message": "hello", "priority": 5}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.unwrap().contains("-2..=2"));
    }
}

838
src/tools/schema.rs Normal file
View file

@ -0,0 +1,838 @@
//! JSON Schema cleaning and validation for LLM tool-calling compatibility.
//!
//! Different providers support different subsets of JSON Schema. This module
//! normalizes tool schemas to improve cross-provider compatibility while
//! preserving semantic intent.
//!
//! ## What this module does
//!
//! 1. Removes unsupported keywords per provider strategy
//! 2. Resolves local `$ref` entries from `$defs` and `definitions`
//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum`
//! 4. Strips nullable variants from unions and `type` arrays
//! 5. Converts `const` to single-value `enum`
//! 6. Detects circular references and stops recursion safely
//!
//! # Example
//!
//! ```rust
//! use serde_json::json;
//! use zeroclaw::tools::schema::SchemaCleanr;
//!
//! let dirty_schema = json!({
//!     "type": "object",
//!     "properties": {
//!         "name": {
//!             "type": "string",
//!             "minLength": 1,       // Gemini rejects this
//!             "pattern": "^[a-z]+$" // Gemini rejects this
//!         },
//!         "age": {
//!             "$ref": "#/$defs/Age" // Needs resolution
//!         }
//!     },
//!     "$defs": {
//!         "Age": {
//!             "type": "integer",
//!             "minimum": 0 // Gemini rejects this
//!         }
//!     }
//! });
//!
//! let cleaned = SchemaCleanr::clean_for_gemini(dirty_schema);
//!
//! // Result:
//! // {
//! //   "type": "object",
//! //   "properties": {
//! //     "name": { "type": "string" },
//! //     "age": { "type": "integer" }
//! //   }
//! // }
//! ```
//!
use serde_json::{json, Map, Value};
use std::collections::{HashMap, HashSet};

/// Keywords that Gemini's function-calling API rejects in tool schemas.
///
/// Stripping these loses validation strictness but keeps the schema accepted;
/// the semantic skeleton (`type`, `enum`, `properties`, `required`) survives.
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
    // Schema composition
    "$ref",
    "$schema",
    "$id",
    "$defs",
    "definitions",
    // Property constraints
    "additionalProperties",
    "patternProperties",
    // String constraints
    "minLength",
    "maxLength",
    "pattern",
    "format",
    // Number constraints
    "minimum",
    "maximum",
    "multipleOf",
    // Array constraints
    "minItems",
    "maxItems",
    "uniqueItems",
    // Object constraints
    "minProperties",
    "maxProperties",
    // Annotation keyword (JSON Schema draft-06+; the singular `example` is
    // the OpenAPI spelling) — rejected by Gemini either way.
    "examples",
];

/// Metadata keys copied from a replaced schema node onto its replacement
/// (e.g. when a `$ref` is resolved or a union is collapsed), so human-facing
/// annotations are never lost during cleaning.
const SCHEMA_META_KEYS: &[&str] = &["description", "title", "default"];
/// Schema cleaning strategies for different LLM providers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CleaningStrategy {
    /// Gemini (Google AI / Vertex AI) - Most restrictive
    Gemini,
    /// Anthropic Claude - Moderately permissive
    Anthropic,
    /// OpenAI GPT - Most permissive
    OpenAI,
    /// Conservative: Remove only universally unsupported keywords
    Conservative,
}

impl CleaningStrategy {
    /// Keywords that must be stripped from tool schemas for this strategy.
    ///
    /// Returns a static slice so callers can build lookup sets cheaply.
    pub fn unsupported_keywords(&self) -> &'static [&'static str] {
        match *self {
            // Gemini rejects the widest set of keywords.
            CleaningStrategy::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
            // Anthropic accepts most validation keywords but does not
            // resolve local references itself.
            CleaningStrategy::Anthropic => &["$ref", "$defs", "definitions"],
            // OpenAI accepts everything we emit.
            CleaningStrategy::OpenAI => &[],
            // Middle ground for unknown providers.
            CleaningStrategy::Conservative => {
                &["$ref", "$defs", "definitions", "additionalProperties"]
            }
        }
    }
}
/// JSON Schema cleaner optimized for LLM tool calling.
///
/// Stateless: all entry points are associated functions that consume a
/// `serde_json::Value` and return the cleaned value.
pub struct SchemaCleanr;

impl SchemaCleanr {
    /// Clean schema for Gemini compatibility (strictest).
    ///
    /// This is the most aggressive cleaning strategy, removing all keywords
    /// that Gemini's API rejects.
    pub fn clean_for_gemini(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::Gemini)
    }

    /// Clean schema for Anthropic compatibility.
    pub fn clean_for_anthropic(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::Anthropic)
    }

    /// Clean schema for OpenAI compatibility (most permissive).
    pub fn clean_for_openai(schema: Value) -> Value {
        Self::clean(schema, CleaningStrategy::OpenAI)
    }

    /// Clean schema with specified strategy.
    ///
    /// Collects `$defs`/`definitions` from the schema ROOT only — nested
    /// definition blocks deeper in the tree are not consulted when
    /// resolving `$ref`s.
    pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
        // Extract $defs for reference resolution
        let defs = if let Some(obj) = schema.as_object() {
            Self::extract_defs(obj)
        } else {
            HashMap::new()
        };
        Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
    }

    /// Validate that a schema is suitable for LLM tool calling.
    ///
    /// Returns an error if the schema is not an object or is missing the
    /// required `type` field. An object schema without `properties` only
    /// logs a warning — it is accepted.
    pub fn validate(schema: &Value) -> anyhow::Result<()> {
        let obj = schema
            .as_object()
            .ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;
        // Must have 'type' field
        if !obj.contains_key("type") {
            anyhow::bail!("Schema missing required 'type' field");
        }
        // If type is 'object', should have 'properties'
        if let Some(Value::String(t)) = obj.get("type") {
            if t == "object" && !obj.contains_key("properties") {
                tracing::warn!("Object schema without 'properties' field may cause issues");
            }
        }
        Ok(())
    }

    // --------------------------------------------------------------------
    // Internal implementation
    // --------------------------------------------------------------------

    /// Extract $defs and definitions into a flat map for reference resolution.
    ///
    /// `$defs` (2019-09+) is read first; a same-named entry under
    /// `definitions` (draft-07) overwrites it.
    fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
        let mut defs = HashMap::new();
        // Extract from $defs (JSON Schema 2019-09+)
        if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
            for (key, value) in defs_obj {
                defs.insert(key.clone(), value.clone());
            }
        }
        // Extract from definitions (JSON Schema draft-07)
        if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
            for (key, value) in defs_obj {
                defs.insert(key.clone(), value.clone());
            }
        }
        defs
    }

    /// Recursively clean a schema value.
    ///
    /// Objects get full keyword handling; arrays are cleaned element-wise;
    /// scalars pass through untouched.
    fn clean_with_defs(
        schema: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        match schema {
            Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
            Value::Array(arr) => Value::Array(
                arr.into_iter()
                    .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
                    .collect(),
            ),
            other => other,
        }
    }

    /// Clean an object schema.
    ///
    /// Dispatch order matters: `$ref` resolution runs first (for every
    /// strategy, even ones whose keyword list keeps `$ref`), then union
    /// simplification, then the per-key cleanup loop.
    fn clean_object(
        obj: Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        // Handle $ref resolution
        if let Some(Value::String(ref_value)) = obj.get("$ref") {
            return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
        }
        // Handle anyOf/oneOf simplification
        if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
            if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
                return simplified;
            }
        }
        // Build cleaned object
        let mut cleaned = Map::new();
        let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
        let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");
        for (key, value) in obj {
            // Skip unsupported keywords
            if unsupported.contains(key.as_str()) {
                continue;
            }
            // Special handling for specific keys
            match key.as_str() {
                // Convert const to enum
                "const" => {
                    cleaned.insert("enum".to_string(), json!([value]));
                }
                // Skip type if we have anyOf/oneOf (they define the type)
                "type" if has_union => {
                    // Skip
                }
                // Handle type arrays (remove null)
                "type" if matches!(value, Value::Array(_)) => {
                    let cleaned_value = Self::clean_type_array(value);
                    cleaned.insert(key, cleaned_value);
                }
                // Recursively clean nested schemas
                "properties" => {
                    let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                "items" => {
                    let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                "anyOf" | "oneOf" | "allOf" => {
                    let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
                    cleaned.insert(key, cleaned_value);
                }
                // Keep all other keys, cleaning nested objects/arrays recursively.
                _ => {
                    let cleaned_value = match value {
                        Value::Object(_) | Value::Array(_) => {
                            Self::clean_with_defs(value, defs, strategy, ref_stack)
                        }
                        other => other,
                    };
                    cleaned.insert(key, cleaned_value);
                }
            }
        }
        Value::Object(cleaned)
    }

    /// Resolve a $ref to its definition.
    ///
    /// Invariant: `ref_stack` holds exactly the refs on the current
    /// resolution path — inserted before recursing into a definition and
    /// removed immediately after, so sibling refs to the same definition
    /// still resolve. A ref already on the path means a cycle; the cycle is
    /// broken by substituting an empty object (metadata preserved).
    fn resolve_ref(
        ref_value: &str,
        obj: &Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        // Prevent circular references
        if ref_stack.contains(ref_value) {
            tracing::warn!("Circular $ref detected: {}", ref_value);
            return Self::preserve_meta(obj, Value::Object(Map::new()));
        }
        // Try to resolve local ref (#/$defs/Name or #/definitions/Name)
        if let Some(def_name) = Self::parse_local_ref(ref_value) {
            if let Some(definition) = defs.get(def_name.as_str()) {
                ref_stack.insert(ref_value.to_string());
                let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
                ref_stack.remove(ref_value);
                return Self::preserve_meta(obj, cleaned);
            }
        }
        // Can't resolve (external/unknown ref): degrade to an empty object
        // that keeps the original node's metadata rather than failing.
        tracing::warn!("Cannot resolve $ref: {}", ref_value);
        Self::preserve_meta(obj, Value::Object(Map::new()))
    }

    /// Parse a local JSON Pointer ref (#/$defs/Name).
    ///
    /// Only single-segment local refs are supported; multi-segment pointers
    /// (e.g. `#/$defs/A/B`) yield a name that won't be found in `defs`.
    fn parse_local_ref(ref_value: &str) -> Option<String> {
        ref_value
            .strip_prefix("#/$defs/")
            .or_else(|| ref_value.strip_prefix("#/definitions/"))
            .map(Self::decode_json_pointer)
    }

    /// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`).
    ///
    /// Per RFC 6901; a trailing or unrecognized `~` escape is kept as a
    /// literal `~`.
    fn decode_json_pointer(segment: &str) -> String {
        if !segment.contains('~') {
            return segment.to_string();
        }
        let mut decoded = String::with_capacity(segment.len());
        let mut chars = segment.chars().peekable();
        while let Some(ch) = chars.next() {
            if ch == '~' {
                match chars.peek().copied() {
                    Some('0') => {
                        chars.next();
                        decoded.push('~');
                    }
                    Some('1') => {
                        chars.next();
                        decoded.push('/');
                    }
                    _ => decoded.push('~'),
                }
            } else {
                decoded.push(ch);
            }
        }
        decoded
    }

    /// Try to simplify anyOf/oneOf to a simpler form.
    ///
    /// Returns `Some` only when the union collapses to a single non-null
    /// variant or to a literal `enum`; otherwise `None`, and the caller
    /// keeps the (re-cleaned) union as-is.
    fn try_simplify_union(
        obj: &Map<String, Value>,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Option<Value> {
        let union_key = if obj.contains_key("anyOf") {
            "anyOf"
        } else if obj.contains_key("oneOf") {
            "oneOf"
        } else {
            return None;
        };
        let variants = obj.get(union_key)?.as_array()?;
        // Clean all variants first
        let cleaned_variants: Vec<Value> = variants
            .iter()
            .map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
            .collect();
        // Strip null variants
        let non_null: Vec<Value> = cleaned_variants
            .into_iter()
            .filter(|v| !Self::is_null_schema(v))
            .collect();
        // If only one variant remains after stripping nulls, return it
        if non_null.len() == 1 {
            return Some(Self::preserve_meta(obj, non_null[0].clone()));
        }
        // Try to flatten to enum if all variants are literals
        if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
            return Some(Self::preserve_meta(obj, enum_value));
        }
        None
    }

    /// Check if a schema represents null type.
    ///
    /// Recognizes `{const: null}`, `{enum: [null]}`, and `{type: "null"}`.
    fn is_null_schema(value: &Value) -> bool {
        if let Some(obj) = value.as_object() {
            // { const: null }
            if let Some(Value::Null) = obj.get("const") {
                return true;
            }
            // { enum: [null] }
            if let Some(Value::Array(arr)) = obj.get("enum") {
                if arr.len() == 1 && matches!(arr[0], Value::Null) {
                    return true;
                }
            }
            // { type: "null" }
            if let Some(Value::String(t)) = obj.get("type") {
                if t == "null" {
                    return true;
                }
            }
        }
        false
    }

    /// Try to flatten anyOf/oneOf with only literal values to enum.
    ///
    /// Example: `anyOf: [{const: "a"}, {const: "b"}]` -> `{type: "string", enum: ["a", "b"]}`
    ///
    /// Requires every variant to carry an explicit `type`, and all variant
    /// types to match; otherwise returns `None`.
    fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
        if variants.is_empty() {
            return None;
        }
        let mut all_values = Vec::new();
        let mut common_type: Option<String> = None;
        for variant in variants {
            let obj = variant.as_object()?;
            // Extract literal value from const or single-item enum
            let literal_value = if let Some(const_val) = obj.get("const") {
                const_val.clone()
            } else if let Some(Value::Array(arr)) = obj.get("enum") {
                if arr.len() == 1 {
                    arr[0].clone()
                } else {
                    return None;
                }
            } else {
                return None;
            };
            // Check type consistency
            let variant_type = obj.get("type")?.as_str()?;
            match &common_type {
                None => common_type = Some(variant_type.to_string()),
                Some(t) if t != variant_type => return None,
                _ => {}
            }
            all_values.push(literal_value);
        }
        common_type.map(|t| {
            json!({
                "type": t,
                "enum": all_values
            })
        })
    }

    /// Clean type array, removing null.
    ///
    /// `["string","null"]` -> `"string"`; a single survivor is unwrapped;
    /// an array left empty after filtering collapses to `"null"` so the
    /// schema still has a valid `type`.
    fn clean_type_array(value: Value) -> Value {
        if let Value::Array(types) = value {
            let non_null: Vec<Value> = types
                .into_iter()
                .filter(|v| v.as_str() != Some("null"))
                .collect();
            match non_null.len() {
                0 => Value::String("null".to_string()),
                1 => non_null
                    .into_iter()
                    .next()
                    .unwrap_or(Value::String("null".to_string())),
                _ => Value::Array(non_null),
            }
        } else {
            value
        }
    }

    /// Clean properties object.
    ///
    /// Property NAMES pass through untouched; only their schemas are
    /// cleaned. Non-object `properties` values are returned unchanged.
    fn clean_properties(
        value: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        if let Value::Object(props) = value {
            let cleaned: Map<String, Value> = props
                .into_iter()
                .map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
                .collect();
            Value::Object(cleaned)
        } else {
            value
        }
    }

    /// Clean union (anyOf/oneOf/allOf).
    ///
    /// Element-wise cleaning only; collapsing is done earlier by
    /// `try_simplify_union`.
    fn clean_union(
        value: Value,
        defs: &HashMap<String, Value>,
        strategy: CleaningStrategy,
        ref_stack: &mut HashSet<String>,
    ) -> Value {
        if let Value::Array(variants) = value {
            let cleaned: Vec<Value> = variants
                .into_iter()
                .map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
                .collect();
            Value::Array(cleaned)
        } else {
            value
        }
    }

    /// Preserve metadata (description, title, default) from source to target.
    ///
    /// Source metadata OVERWRITES same-named keys on the target; non-object
    /// targets are returned unchanged.
    fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
        if let Value::Object(target_obj) = &mut target {
            for &key in SCHEMA_META_KEYS {
                if let Some(value) = source.get(key) {
                    target_obj.insert(key.to_string(), value.clone());
                }
            }
        }
        target
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_remove_unsupported_keywords() {
        // Validation keywords vanish; metadata survives.
        let input = json!({
            "type": "string",
            "minLength": 1,
            "maxLength": 100,
            "pattern": "^[a-z]+$",
            "description": "A lowercase string"
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["type"], "string");
        assert_eq!(output["description"], "A lowercase string");
        assert_eq!(output.get("minLength"), None);
        assert_eq!(output.get("maxLength"), None);
        assert_eq!(output.get("pattern"), None);
    }

    #[test]
    fn test_resolve_ref() {
        // $ref is inlined from $defs, then cleaned per-strategy.
        let input = json!({
            "type": "object",
            "properties": {
                "age": {
                    "$ref": "#/$defs/Age"
                }
            },
            "$defs": {
                "Age": {
                    "type": "integer",
                    "minimum": 0
                }
            }
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["properties"]["age"]["type"], "integer");
        // "minimum" is stripped by the Gemini strategy.
        assert_eq!(output["properties"]["age"].get("minimum"), None);
        assert_eq!(output.get("$defs"), None);
    }

    #[test]
    fn test_flatten_literal_union() {
        // anyOf of same-typed consts collapses to a single enum.
        let input = json!({
            "anyOf": [
                { "const": "admin", "type": "string" },
                { "const": "user", "type": "string" },
                { "const": "guest", "type": "string" }
            ]
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["type"], "string");
        assert!(output["enum"].is_array());
        let values = output["enum"].as_array().unwrap();
        assert_eq!(values.len(), 3);
        for role in ["admin", "user", "guest"] {
            assert!(values.contains(&json!(role)));
        }
    }

    #[test]
    fn test_strip_null_from_union() {
        let input = json!({
            "oneOf": [
                { "type": "string" },
                { "type": "null" }
            ]
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        // Union with a lone non-null variant simplifies to { type: "string" }.
        assert_eq!(output["type"], "string");
        assert_eq!(output.get("oneOf"), None);
    }

    #[test]
    fn test_const_to_enum() {
        let input = json!({
            "const": "fixed_value",
            "description": "A constant"
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["enum"], json!(["fixed_value"]));
        assert_eq!(output["description"], "A constant");
        assert_eq!(output.get("const"), None);
    }

    #[test]
    fn test_preserve_metadata() {
        // Metadata on the referencing node outlives $ref resolution.
        let input = json!({
            "$ref": "#/$defs/Name",
            "description": "User's name",
            "title": "Name Field",
            "default": "Anonymous",
            "$defs": {
                "Name": {
                    "type": "string"
                }
            }
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["type"], "string");
        assert_eq!(output["description"], "User's name");
        assert_eq!(output["title"], "Name Field");
        assert_eq!(output["default"], "Anonymous");
    }

    #[test]
    fn test_circular_ref_prevention() {
        let input = json!({
            "type": "object",
            "properties": {
                "parent": {
                    "$ref": "#/$defs/Node"
                }
            },
            "$defs": {
                "Node": {
                    "type": "object",
                    "properties": {
                        "child": {
                            "$ref": "#/$defs/Node"
                        }
                    }
                }
            }
        });
        // Must terminate (no panic / infinite recursion) on a self-referential def.
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["properties"]["parent"]["type"], "object");
        // Circular reference should be broken
    }

    #[test]
    fn test_validate_schema() {
        let ok_schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" }
            }
        });
        assert!(SchemaCleanr::validate(&ok_schema).is_ok());

        // Missing "type" must be rejected.
        let bad_schema = json!({
            "properties": {
                "name": { "type": "string" }
            }
        });
        assert!(SchemaCleanr::validate(&bad_schema).is_err());
    }

    #[test]
    fn test_strategy_differences() {
        let input = json!({
            "type": "string",
            "minLength": 1,
            "description": "A string field"
        });

        // Gemini: Most restrictive (removes minLength)
        let gemini = SchemaCleanr::clean_for_gemini(input.clone());
        assert_eq!(gemini.get("minLength"), None);
        assert_eq!(gemini["type"], "string");
        assert_eq!(gemini["description"], "A string field");

        // OpenAI: Most permissive (keeps validation keywords)
        let openai = SchemaCleanr::clean_for_openai(input.clone());
        assert_eq!(openai["minLength"], 1);
        assert_eq!(openai["type"], "string");
    }

    #[test]
    fn test_nested_properties() {
        // Cleaning recurses through nested object schemas.
        let input = json!({
            "type": "object",
            "properties": {
                "user": {
                    "type": "object",
                    "properties": {
                        "name": {
                            "type": "string",
                            "minLength": 1
                        }
                    },
                    "additionalProperties": false
                }
            }
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        let name_schema = &output["properties"]["user"]["properties"]["name"];
        assert_eq!(name_schema.get("minLength"), None);
        assert_eq!(output["properties"]["user"].get("additionalProperties"), None);
    }

    #[test]
    fn test_type_array_null_removal() {
        let input = json!({
            "type": ["string", "null"]
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        // Nullable type array collapses to the bare non-null type.
        assert_eq!(output["type"], "string");
    }

    #[test]
    fn test_type_array_only_null_preserved() {
        let input = json!({
            "type": ["null"]
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["type"], "null");
    }

    #[test]
    fn test_ref_with_json_pointer_escape() {
        // "~1" decodes to "/" per RFC 6901 before the defs lookup.
        let input = json!({
            "$ref": "#/$defs/Foo~1Bar",
            "$defs": {
                "Foo/Bar": {
                    "type": "string"
                }
            }
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["type"], "string");
    }

    #[test]
    fn test_skip_type_when_non_simplifiable_union_exists() {
        // A oneOf that can't collapse keeps the union and drops the
        // sibling "type" (the variants define it).
        let input = json!({
            "type": "object",
            "oneOf": [
                {
                    "type": "object",
                    "properties": {
                        "a": { "type": "string" }
                    }
                },
                {
                    "type": "object",
                    "properties": {
                        "b": { "type": "number" }
                    }
                }
            ]
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output.get("type"), None);
        assert!(output.get("oneOf").is_some());
    }

    #[test]
    fn test_clean_nested_unknown_schema_keyword() {
        // Unknown keywords ("not") still get their subschemas cleaned.
        let input = json!({
            "not": {
                "$ref": "#/$defs/Age"
            },
            "$defs": {
                "Age": {
                    "type": "integer",
                    "minimum": 0
                }
            }
        });
        let output = SchemaCleanr::clean_for_gemini(input);
        assert_eq!(output["not"]["type"], "integer");
        assert_eq!(output["not"].get("minimum"), None);
    }
}

View file

@ -36,6 +36,7 @@ async fn compare_store_speed() {
&format!("key_{i}"),
&format!("Memory entry number {i} about Rust programming"),
MemoryCategory::Core,
None,
)
.await
.unwrap();
@ -49,6 +50,7 @@ async fn compare_store_speed() {
&format!("key_{i}"),
&format!("Memory entry number {i} about Rust programming"),
MemoryCategory::Core,
None,
)
.await
.unwrap();
@ -127,8 +129,8 @@ async fn compare_recall_quality() {
];
for (key, content, cat) in &entries {
sq.store(key, content, cat.clone()).await.unwrap();
md.store(key, content, cat.clone()).await.unwrap();
sq.store(key, content, cat.clone(), None).await.unwrap();
md.store(key, content, cat.clone(), None).await.unwrap();
}
// Test queries and compare results
@ -145,8 +147,8 @@ async fn compare_recall_quality() {
println!("RECALL QUALITY (10 entries seeded):\n");
for (query, desc) in &queries {
let sq_results = sq.recall(query, 10).await.unwrap();
let md_results = md.recall(query, 10).await.unwrap();
let sq_results = sq.recall(query, 10, None).await.unwrap();
let md_results = md.recall(query, 10, None).await.unwrap();
println!(" Query: \"{query}\"{desc}");
println!(" SQLite: {} results", sq_results.len());
@ -190,21 +192,21 @@ async fn compare_recall_speed() {
} else {
format!("TypeScript powers modern web apps, entry {i}")
};
sq.store(&format!("e{i}"), &content, MemoryCategory::Core)
sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None)
.await
.unwrap();
md.store(&format!("e{i}"), &content, MemoryCategory::Daily)
md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None)
.await
.unwrap();
}
// Benchmark recall
let start = Instant::now();
let sq_results = sq.recall("Rust systems", 10).await.unwrap();
let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
let sq_dur = start.elapsed();
let start = Instant::now();
let md_results = md.recall("Rust systems", 10).await.unwrap();
let md_results = md.recall("Rust systems", 10, None).await.unwrap();
let md_dur = start.elapsed();
println!("\n============================================================");
@ -227,13 +229,23 @@ async fn compare_persistence() {
// Store in both, then drop and re-open
{
let sq = sqlite_backend(tmp_sq.path());
sq.store("persist_test", "I should survive", MemoryCategory::Core)
sq.store(
"persist_test",
"I should survive",
MemoryCategory::Core,
None,
)
.await
.unwrap();
}
{
let md = markdown_backend(tmp_md.path());
md.store("persist_test", "I should survive", MemoryCategory::Core)
md.store(
"persist_test",
"I should survive",
MemoryCategory::Core,
None,
)
.await
.unwrap();
}
@ -282,17 +294,17 @@ async fn compare_upsert() {
let md = markdown_backend(tmp_md.path());
// Store twice with same key, different content
sq.store("pref", "likes Rust", MemoryCategory::Core)
sq.store("pref", "likes Rust", MemoryCategory::Core, None)
.await
.unwrap();
sq.store("pref", "loves Rust", MemoryCategory::Core)
sq.store("pref", "loves Rust", MemoryCategory::Core, None)
.await
.unwrap();
md.store("pref", "likes Rust", MemoryCategory::Core)
md.store("pref", "likes Rust", MemoryCategory::Core, None)
.await
.unwrap();
md.store("pref", "loves Rust", MemoryCategory::Core)
md.store("pref", "loves Rust", MemoryCategory::Core, None)
.await
.unwrap();
@ -300,7 +312,7 @@ async fn compare_upsert() {
let md_count = md.count().await.unwrap();
let sq_entry = sq.get("pref").await.unwrap();
let md_results = md.recall("loves Rust", 5).await.unwrap();
let md_results = md.recall("loves Rust", 5, None).await.unwrap();
println!("\n============================================================");
println!("UPSERT (store same key twice):");
@ -328,10 +340,10 @@ async fn compare_forget() {
let sq = sqlite_backend(tmp_sq.path());
let md = markdown_backend(tmp_md.path());
sq.store("secret", "API key: sk-1234", MemoryCategory::Core)
sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
.await
.unwrap();
md.store("secret", "API key: sk-1234", MemoryCategory::Core)
md.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
.await
.unwrap();
@ -372,37 +384,40 @@ async fn compare_category_filter() {
let md = markdown_backend(tmp_md.path());
// Mix of categories
sq.store("a", "core fact 1", MemoryCategory::Core)
sq.store("a", "core fact 1", MemoryCategory::Core, None)
.await
.unwrap();
sq.store("b", "core fact 2", MemoryCategory::Core)
sq.store("b", "core fact 2", MemoryCategory::Core, None)
.await
.unwrap();
sq.store("c", "daily note", MemoryCategory::Daily)
sq.store("c", "daily note", MemoryCategory::Daily, None)
.await
.unwrap();
sq.store("d", "convo msg", MemoryCategory::Conversation)
sq.store("d", "convo msg", MemoryCategory::Conversation, None)
.await
.unwrap();
md.store("a", "core fact 1", MemoryCategory::Core)
md.store("a", "core fact 1", MemoryCategory::Core, None)
.await
.unwrap();
md.store("b", "core fact 2", MemoryCategory::Core)
md.store("b", "core fact 2", MemoryCategory::Core, None)
.await
.unwrap();
md.store("c", "daily note", MemoryCategory::Daily)
md.store("c", "daily note", MemoryCategory::Daily, None)
.await
.unwrap();
let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap();
let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap();
let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap();
let sq_all = sq.list(None).await.unwrap();
let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap();
let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap();
let sq_conv = sq
.list(Some(&MemoryCategory::Conversation), None)
.await
.unwrap();
let sq_all = sq.list(None, None).await.unwrap();
let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap();
let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap();
let md_all = md.list(None).await.unwrap();
let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap();
let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap();
let md_all = md.list(None, None).await.unwrap();
println!("\n============================================================");
println!("CATEGORY FILTERING:");