Merge branch 'main' into pr-484-clean
This commit is contained in:
commit
ee05d62ce4
90 changed files with 6937 additions and 1403 deletions
62
.env.example
62
.env.example
|
|
@ -1,25 +1,69 @@
|
|||
# ZeroClaw Environment Variables
|
||||
# Copy this file to .env and fill in your values.
|
||||
# NEVER commit .env — it is listed in .gitignore.
|
||||
# Copy this file to `.env` and fill in your local values.
|
||||
# Never commit `.env` or any real secrets.
|
||||
|
||||
# ── Required ──────────────────────────────────────────────────
|
||||
# Your LLM provider API key
|
||||
# ZEROCLAW_API_KEY=sk-your-key-here
|
||||
# ── Core Runtime ──────────────────────────────────────────────
|
||||
# Provider key resolution at runtime:
|
||||
# 1) explicit key passed from config/CLI
|
||||
# 2) provider-specific env var (OPENROUTER_API_KEY, OPENAI_API_KEY, ...)
|
||||
# 3) generic fallback env vars below
|
||||
|
||||
# Generic fallback API key (used when provider-specific key is absent)
|
||||
API_KEY=your-api-key-here
|
||||
# ZEROCLAW_API_KEY=your-api-key-here
|
||||
|
||||
# ── Provider & Model ─────────────────────────────────────────
|
||||
# LLM provider: openrouter, openai, anthropic, ollama, glm
|
||||
# Default provider/model (can be overridden by CLI flags)
|
||||
PROVIDER=openrouter
|
||||
# ZEROCLAW_PROVIDER=openrouter
|
||||
# ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514
|
||||
# ZEROCLAW_TEMPERATURE=0.7
|
||||
|
||||
# Workspace directory override
|
||||
# ZEROCLAW_WORKSPACE=/path/to/workspace
|
||||
|
||||
# ── Provider-Specific API Keys ────────────────────────────────
|
||||
# OpenRouter
|
||||
# OPENROUTER_API_KEY=sk-or-v1-...
|
||||
|
||||
# Anthropic
|
||||
# ANTHROPIC_OAUTH_TOKEN=...
|
||||
# ANTHROPIC_API_KEY=sk-ant-...
|
||||
|
||||
# OpenAI / Gemini
|
||||
# OPENAI_API_KEY=sk-...
|
||||
# GEMINI_API_KEY=...
|
||||
# GOOGLE_API_KEY=...
|
||||
|
||||
# Other supported providers
|
||||
# VENICE_API_KEY=...
|
||||
# GROQ_API_KEY=...
|
||||
# MISTRAL_API_KEY=...
|
||||
# DEEPSEEK_API_KEY=...
|
||||
# XAI_API_KEY=...
|
||||
# TOGETHER_API_KEY=...
|
||||
# FIREWORKS_API_KEY=...
|
||||
# PERPLEXITY_API_KEY=...
|
||||
# COHERE_API_KEY=...
|
||||
# MOONSHOT_API_KEY=...
|
||||
# GLM_API_KEY=...
|
||||
# MINIMAX_API_KEY=...
|
||||
# QIANFAN_API_KEY=...
|
||||
# DASHSCOPE_API_KEY=...
|
||||
# ZAI_API_KEY=...
|
||||
# SYNTHETIC_API_KEY=...
|
||||
# OPENCODE_API_KEY=...
|
||||
# VERCEL_API_KEY=...
|
||||
# CLOUDFLARE_API_KEY=...
|
||||
|
||||
# ── Gateway ──────────────────────────────────────────────────
|
||||
# ZEROCLAW_GATEWAY_PORT=3000
|
||||
# ZEROCLAW_GATEWAY_HOST=127.0.0.1
|
||||
# ZEROCLAW_ALLOW_PUBLIC_BIND=false
|
||||
|
||||
# ── Workspace ────────────────────────────────────────────────
|
||||
# ZEROCLAW_WORKSPACE=/path/to/workspace
|
||||
# ── Optional Integrations ────────────────────────────────────
|
||||
# Pushover notifications (`pushover` tool)
|
||||
# PUSHOVER_TOKEN=your-pushover-app-token
|
||||
# PUSHOVER_USER_KEY=your-pushover-user-key
|
||||
|
||||
# ── Docker Compose ───────────────────────────────────────────
|
||||
# Host port mapping (used by docker-compose.yml)
|
||||
|
|
|
|||
8
.githooks/pre-commit
Executable file
8
.githooks/pre-commit
Executable file
|
|
@ -0,0 +1,8 @@
|
|||
#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
if command -v gitleaks >/dev/null 2>&1; then
|
||||
gitleaks protect --staged --redact
|
||||
else
|
||||
echo "warning: gitleaks not found; skipping staged secret scan" >&2
|
||||
fi
|
||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
|
|
@ -12,7 +12,11 @@ Describe this PR in 2-5 bullets:
|
|||
- Risk label (`risk: low|medium|high`):
|
||||
- Size label (`size: XS|S|M|L|XL`, auto-managed/read-only):
|
||||
- Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated):
|
||||
<<<<<<< chore/labeler-spacing-trusted-tier
|
||||
- Module labels (`<module>: <component>`, for example `channel: telegram`, `provider: kimi`, `tool: shell`):
|
||||
=======
|
||||
- Module labels (`<module>:<component>`, for example `channel:telegram`, `provider:kimi`, `tool:shell`):
|
||||
>>>>>>> main
|
||||
- Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50):
|
||||
- If any auto-label is incorrect, note requested correction:
|
||||
|
||||
|
|
|
|||
1
.github/workflows/auto-response.yml
vendored
1
.github/workflows/auto-response.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
|||
runs-on: blacksmith-2vcpu-ubuntu-2404
|
||||
permissions:
|
||||
issues: write
|
||||
pull-requests: write
|
||||
steps:
|
||||
- name: Apply contributor tier label for issue author
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8
|
||||
|
|
|
|||
8
.github/workflows/docker.yml
vendored
8
.github/workflows/docker.yml
vendored
|
|
@ -35,7 +35,7 @@ jobs:
|
|||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
|
||||
- name: Setup Blacksmith Builder
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
|
||||
|
||||
- name: Extract metadata (tags, labels)
|
||||
id: meta
|
||||
|
|
@ -46,7 +46,7 @@ jobs:
|
|||
type=ref,event=pr
|
||||
|
||||
- name: Build smoke image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
|
||||
with:
|
||||
context: .
|
||||
push: false
|
||||
|
|
@ -71,7 +71,7 @@ jobs:
|
|||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||
|
||||
- name: Setup Blacksmith Builder
|
||||
uses: useblacksmith/setup-docker-builder@v1
|
||||
uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
|
||||
|
||||
- name: Log in to Container Registry
|
||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||
|
|
@ -102,7 +102,7 @@ jobs:
|
|||
echo "tags=${TAGS}" >> "$GITHUB_OUTPUT"
|
||||
|
||||
- name: Build and push Docker image
|
||||
uses: useblacksmith/build-push-action@v2
|
||||
uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
|
||||
with:
|
||||
context: .
|
||||
push: true
|
||||
|
|
|
|||
27
.github/workflows/labeler.yml
vendored
27
.github/workflows/labeler.yml
vendored
|
|
@ -325,13 +325,18 @@ jobs:
|
|||
return pattern.test(text);
|
||||
}
|
||||
|
||||
function formatModuleLabel(prefix, segment) {
|
||||
return `${prefix}: ${segment}`;
|
||||
}
|
||||
|
||||
function parseModuleLabel(label) {
|
||||
const separatorIndex = label.indexOf(":");
|
||||
if (separatorIndex <= 0 || separatorIndex >= label.length - 1) return null;
|
||||
return {
|
||||
prefix: label.slice(0, separatorIndex),
|
||||
segment: label.slice(separatorIndex + 1),
|
||||
};
|
||||
if (typeof label !== "string") return null;
|
||||
const match = label.match(/^([^:]+):\s*(.+)$/);
|
||||
if (!match) return null;
|
||||
const prefix = match[1].trim().toLowerCase();
|
||||
const segment = (match[2] || "").trim().toLowerCase();
|
||||
if (!prefix || !segment) return null;
|
||||
return { prefix, segment };
|
||||
}
|
||||
|
||||
function sortByPriority(labels, priorityIndex) {
|
||||
|
|
@ -389,7 +394,7 @@ jobs:
|
|||
for (const [prefix, segments] of segmentsByPrefix) {
|
||||
const hasSpecificSegment = [...segments].some((segment) => segment !== "core");
|
||||
if (hasSpecificSegment) {
|
||||
refined.delete(`${prefix}:core`);
|
||||
refined.delete(formatModuleLabel(prefix, "core"));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -418,7 +423,7 @@ jobs:
|
|||
if (uniqueSegments.length === 0) continue;
|
||||
|
||||
if (uniqueSegments.length === 1) {
|
||||
compactedModuleLabels.add(`${prefix}:${uniqueSegments[0]}`);
|
||||
compactedModuleLabels.add(formatModuleLabel(prefix, uniqueSegments[0]));
|
||||
} else {
|
||||
forcePathPrefixes.add(prefix);
|
||||
}
|
||||
|
|
@ -609,7 +614,7 @@ jobs:
|
|||
segment = normalizeLabelSegment(segment);
|
||||
if (!segment) continue;
|
||||
|
||||
detectedModuleLabels.add(`${rule.prefix}:${segment}`);
|
||||
detectedModuleLabels.add(formatModuleLabel(rule.prefix, segment));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -635,7 +640,7 @@ jobs:
|
|||
|
||||
for (const keyword of providerKeywordHints) {
|
||||
if (containsKeyword(searchableText, keyword)) {
|
||||
detectedModuleLabels.add(`provider:${keyword}`);
|
||||
detectedModuleLabels.add(formatModuleLabel("provider", keyword));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -661,7 +666,7 @@ jobs:
|
|||
|
||||
for (const keyword of channelKeywordHints) {
|
||||
if (containsKeyword(searchableText, keyword)) {
|
||||
detectedModuleLabels.add(`channel:${keyword}`);
|
||||
detectedModuleLabels.add(formatModuleLabel("channel", keyword));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
22
.gitignore
vendored
22
.gitignore
vendored
|
|
@ -4,6 +4,26 @@ firmware/*/target
|
|||
*.db-journal
|
||||
.DS_Store
|
||||
.wt-pr37/
|
||||
.env
|
||||
__pycache__/
|
||||
*.pyc
|
||||
docker-compose.override.yml
|
||||
|
||||
# Environment files (may contain secrets)
|
||||
.env
|
||||
|
||||
# Python virtual environments
|
||||
|
||||
.venv/
|
||||
venv/
|
||||
|
||||
# ESP32 build cache (esp-idf-sys managed)
|
||||
|
||||
.embuild/
|
||||
.env.local
|
||||
.env.*.local
|
||||
|
||||
# Secret keys and credentials
|
||||
.secret_key
|
||||
*.key
|
||||
*.pem
|
||||
credentials.json
|
||||
|
|
|
|||
|
|
@ -79,6 +79,94 @@ git push --no-verify
|
|||
|
||||
> **Note:** CI runs the same checks, so skipped hooks will be caught on the PR.
|
||||
|
||||
## Local Secret Management (Required)
|
||||
|
||||
ZeroClaw supports layered secret management for local development and CI hygiene.
|
||||
|
||||
### Secret Storage Options
|
||||
|
||||
1. **Environment variables** (recommended for local development)
|
||||
- Copy `.env.example` to `.env` and fill in values
|
||||
- `.env` files are Git-ignored and should stay local
|
||||
- Best for temporary/local API keys
|
||||
|
||||
2. **Config file** (`~/.zeroclaw/config.toml`)
|
||||
- Persistent setup for long-term use
|
||||
- When `secrets.encrypt = true` (default), secret values are encrypted before save
|
||||
- Secret key is stored at `~/.zeroclaw/.secret_key` with restricted permissions
|
||||
- Use `zeroclaw onboard` for guided setup
|
||||
|
||||
### Runtime Resolution Rules
|
||||
|
||||
API key resolution follows this order:
|
||||
|
||||
1. Explicit key passed from config/CLI
|
||||
2. Provider-specific env vars (`OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, ...)
|
||||
3. Generic env vars (`ZEROCLAW_API_KEY`, `API_KEY`)
|
||||
|
||||
Provider/model config overrides:
|
||||
|
||||
- `ZEROCLAW_PROVIDER` / `PROVIDER`
|
||||
- `ZEROCLAW_MODEL`
|
||||
|
||||
See `.env.example` for practical examples and currently supported provider key env vars.
|
||||
|
||||
### Pre-Commit Secret Hygiene (Mandatory)
|
||||
|
||||
Before every commit, verify:
|
||||
|
||||
- [ ] No `.env` files are staged (`.env.example` only)
|
||||
- [ ] No raw API keys/tokens in code, tests, fixtures, examples, logs, or commit messages
|
||||
- [ ] No credentials in debug output or error payloads
|
||||
- [ ] `git diff --cached` has no accidental secret-like strings
|
||||
|
||||
Quick local audit:
|
||||
|
||||
```bash
|
||||
# Search staged diff for common secret markers
|
||||
git diff --cached | grep -iE '(api[_-]?key|secret|token|password|bearer|sk-)'
|
||||
|
||||
# Confirm no .env file is staged
|
||||
git status --short | grep -E '\.env$'
|
||||
```
|
||||
|
||||
### Optional Local Secret Scanning
|
||||
|
||||
For extra guardrails, install one of:
|
||||
|
||||
- **gitleaks**: [GitHub - gitleaks/gitleaks](https://github.com/gitleaks/gitleaks)
|
||||
- **truffleHog**: [GitHub - trufflesecurity/trufflehog](https://github.com/trufflesecurity/trufflehog)
|
||||
- **git-secrets**: [GitHub - awslabs/git-secrets](https://github.com/awslabs/git-secrets)
|
||||
|
||||
This repo includes `.githooks/pre-commit` to run `gitleaks protect --staged --redact` when gitleaks is installed.
|
||||
|
||||
Enable hooks with:
|
||||
|
||||
```bash
|
||||
git config core.hooksPath .githooks
|
||||
```
|
||||
|
||||
If gitleaks is not installed, the pre-commit hook prints a warning and continues.
|
||||
|
||||
### What Must Never Be Committed
|
||||
|
||||
- `.env` files (use `.env.example` only)
|
||||
- API keys, tokens, passwords, or credentials (plain or encrypted)
|
||||
- OAuth tokens or session identifiers
|
||||
- Webhook signing secrets
|
||||
- `~/.zeroclaw/.secret_key` or similar key files
|
||||
- Personal identifiers or real user data in tests/fixtures
|
||||
|
||||
### If a Secret Is Committed Accidentally
|
||||
|
||||
1. Revoke/rotate the credential immediately
|
||||
2. Do not rely only on `git revert` (history still contains the secret)
|
||||
3. Purge history with `git filter-repo` or BFG
|
||||
4. Force-push cleaned history (coordinate with maintainers)
|
||||
5. Ensure the leaked value is removed from PR/issue/discussion/comment history
|
||||
|
||||
Reference: [GitHub guide: removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository)
|
||||
|
||||
## Collaboration Tracks (Risk-Based)
|
||||
|
||||
To keep review throughput high without lowering quality, every PR should map to one track:
|
||||
|
|
|
|||
51
Cargo.lock
generated
51
Cargo.lock
generated
|
|
@ -209,6 +209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||
dependencies = [
|
||||
"axum-core",
|
||||
"base64",
|
||||
"bytes",
|
||||
"form_urlencoded",
|
||||
"futures-util",
|
||||
|
|
@ -227,8 +228,10 @@ dependencies = [
|
|||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper",
|
||||
"tokio",
|
||||
"tokio-tungstenite 0.28.0",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
|
|
@ -2057,6 +2060,15 @@ dependencies = [
|
|||
"hashify",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchers"
|
||||
version = "0.2.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
|
||||
dependencies = [
|
||||
"regex-automata",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "matchit"
|
||||
version = "0.8.4"
|
||||
|
|
@ -3747,10 +3759,22 @@ dependencies = [
|
|||
"rustls-pki-types",
|
||||
"tokio",
|
||||
"tokio-rustls",
|
||||
"tungstenite",
|
||||
"tungstenite 0.24.0",
|
||||
"webpki-roots 0.26.11",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite 0.28.0",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-util"
|
||||
version = "0.7.18"
|
||||
|
|
@ -3940,9 +3964,13 @@ version = "0.3.22"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
|
||||
dependencies = [
|
||||
"matchers",
|
||||
"nu-ansi-term",
|
||||
"once_cell",
|
||||
"regex-automata",
|
||||
"sharded-slab",
|
||||
"thread_local",
|
||||
"tracing",
|
||||
"tracing-core",
|
||||
]
|
||||
|
||||
|
|
@ -3978,6 +4006,23 @@ dependencies = [
|
|||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.28.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442"
|
||||
dependencies = [
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http 1.4.0",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand 0.9.2",
|
||||
"sha1",
|
||||
"thiserror 2.0.18",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "twox-hash"
|
||||
version = "2.1.2"
|
||||
|
|
@ -4880,7 +4925,9 @@ dependencies = [
|
|||
"pdf-extract",
|
||||
"probe-rs",
|
||||
"prometheus",
|
||||
"prost",
|
||||
"rand 0.8.5",
|
||||
"regex",
|
||||
"reqwest",
|
||||
"rppal",
|
||||
"rusqlite",
|
||||
|
|
@ -4896,7 +4943,7 @@ dependencies = [
|
|||
"tokio-rustls",
|
||||
"tokio-serial",
|
||||
"tokio-test",
|
||||
"tokio-tungstenite",
|
||||
"tokio-tungstenite 0.24.0",
|
||||
"toml",
|
||||
"tower",
|
||||
"tower-http",
|
||||
|
|
|
|||
28
Cargo.toml
28
Cargo.toml
|
|
@ -1,3 +1,7 @@
|
|||
[workspace]
|
||||
members = ["."]
|
||||
resolver = "2"
|
||||
|
||||
[package]
|
||||
name = "zeroclaw"
|
||||
version = "0.1.0"
|
||||
|
|
@ -31,7 +35,7 @@ shellexpand = "3.1"
|
|||
|
||||
# Logging - minimal
|
||||
tracing = { version = "0.1", default-features = false }
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
|
||||
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
|
||||
|
||||
# Observability - Prometheus metrics
|
||||
prometheus = { version = "0.14", default-features = false }
|
||||
|
|
@ -63,12 +67,12 @@ rand = "0.8"
|
|||
# Fast mutexes that don't poison on panic
|
||||
parking_lot = "0.12"
|
||||
|
||||
# Landlock (Linux sandbox) - optional dependency
|
||||
landlock = { version = "0.4", optional = true }
|
||||
|
||||
# Async traits
|
||||
async-trait = "0.1"
|
||||
|
||||
# Protobuf encode/decode (Feishu WS long-connection frame codec)
|
||||
prost = { version = "0.14", default-features = false }
|
||||
|
||||
# Memory / persistence
|
||||
rusqlite = { version = "0.38", features = ["bundled"] }
|
||||
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
|
||||
|
|
@ -86,6 +90,7 @@ glob = "0.3"
|
|||
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
|
||||
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
|
||||
futures = "0.3"
|
||||
regex = "1.10"
|
||||
hostname = "0.4.2"
|
||||
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
|
||||
mail-parser = "0.11.2"
|
||||
|
|
@ -95,7 +100,7 @@ tokio-rustls = "0.26.4"
|
|||
webpki-roots = "1.0.6"
|
||||
|
||||
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
|
||||
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query"] }
|
||||
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] }
|
||||
tower = { version = "0.5", default-features = false }
|
||||
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
|
||||
http-body-util = "0.1"
|
||||
|
|
@ -117,19 +122,28 @@ probe-rs = { version = "0.30", optional = true }
|
|||
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
||||
pdf-extract = { version = "0.10", optional = true }
|
||||
|
||||
# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS
|
||||
# Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS
|
||||
[target.'cfg(target_os = "linux")'.dependencies]
|
||||
rppal = { version = "0.14", optional = true }
|
||||
landlock = { version = "0.4", optional = true }
|
||||
|
||||
[features]
|
||||
default = ["hardware"]
|
||||
hardware = ["nusb", "tokio-serial"]
|
||||
peripheral-rpi = ["rppal"]
|
||||
# Browser backend feature alias used by cfg(feature = "browser-native")
|
||||
browser-native = ["dep:fantoccini"]
|
||||
# Backward-compatible alias for older invocations
|
||||
fantoccini = ["browser-native"]
|
||||
# Sandbox feature aliases used by cfg(feature = "sandbox-*")
|
||||
sandbox-landlock = ["dep:landlock"]
|
||||
sandbox-bubblewrap = []
|
||||
# Backward-compatible alias for older invocations
|
||||
landlock = ["sandbox-landlock"]
|
||||
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
|
||||
probe = ["dep:probe-rs"]
|
||||
# rag-pdf = PDF ingestion for datasheet RAG
|
||||
rag-pdf = ["dep:pdf-extract"]
|
||||
|
||||
[profile.release]
|
||||
opt-level = "z" # Optimize for size
|
||||
lto = "thin" # Lower memory use during release builds
|
||||
|
|
|
|||
211
LICENSE
211
LICENSE
|
|
@ -1,197 +1,28 @@
|
|||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
MIT License
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
Copyright (c) 2025 ZeroClaw Labs
|
||||
|
||||
1. Definitions.
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
================================================================================
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
This product includes software developed by ZeroClaw Labs and contributors:
|
||||
https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to the Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
Copyright 2025-2026 Argenis Delarosa
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
|
||||
===============================================================================
|
||||
|
||||
This product includes software developed by ZeroClaw Labs and contributors:
|
||||
https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
|
||||
|
||||
See NOTICE file for full contributor attribution.
|
||||
See NOTICE file for full contributor attribution.
|
||||
|
|
|
|||
27
README.md
27
README.md
|
|
@ -10,14 +10,14 @@
|
|||
</p>
|
||||
|
||||
<p align="center">
|
||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" alt="License: Apache 2.0" /></a>
|
||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
||||
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
|
||||
</p>
|
||||
|
||||
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
|
||||
|
||||
```
|
||||
~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything
|
||||
~3.4MB binary · <10ms startup · 1,017 tests · 23+ providers · 8 traits · Pluggable everything
|
||||
```
|
||||
|
||||
### ✨ Features
|
||||
|
|
@ -132,6 +132,9 @@ cd zeroclaw
|
|||
cargo build --release --locked
|
||||
cargo install --path . --force --locked
|
||||
|
||||
# Ensure ~/.cargo/bin is in your PATH
|
||||
export PATH="$HOME/.cargo/bin:$PATH"
|
||||
|
||||
# Quick setup (no prompts)
|
||||
zeroclaw onboard --api-key sk-... --provider openrouter
|
||||
|
||||
|
|
@ -187,7 +190,7 @@ Every subsystem is a **trait** — swap implementations with a config change, ze
|
|||
|
||||
| Subsystem | Trait | Ships with | Extend |
|
||||
|-----------|-------|------------|--------|
|
||||
| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
|
||||
| **AI Models** | `Provider` | 23+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, Astrai, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
|
||||
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
|
||||
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend |
|
||||
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability |
|
||||
|
|
@ -287,6 +290,21 @@ rerun channel setup only:
|
|||
zeroclaw onboard --channels-only
|
||||
```
|
||||
|
||||
### Telegram media replies
|
||||
|
||||
Telegram routing now replies to the source **chat ID** from incoming updates (instead of usernames),
|
||||
which avoids `Bad Request: chat not found` failures.
|
||||
|
||||
For non-text replies, ZeroClaw can send Telegram attachments when the assistant includes markers:
|
||||
|
||||
- `[IMAGE:<path-or-url>]`
|
||||
- `[DOCUMENT:<path-or-url>]`
|
||||
- `[VIDEO:<path-or-url>]`
|
||||
- `[AUDIO:<path-or-url>]`
|
||||
- `[VOICE:<path-or-url>]`
|
||||
|
||||
Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs.
|
||||
|
||||
### WhatsApp Business Cloud API Setup
|
||||
|
||||
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
|
||||
|
|
@ -610,7 +628,7 @@ We're building in the open because the best ideas come from everywhere. If you'r
|
|||
|
||||
## License
|
||||
|
||||
Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
|
||||
MIT — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
|
||||
|
||||
## Contributing
|
||||
|
||||
|
|
@ -624,7 +642,6 @@ See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR:
|
|||
- New `Tunnel` → `src/tunnel/`
|
||||
- New `Skill` → `~/.zeroclaw/workspace/skills/<name>/`
|
||||
|
||||
|
||||
---
|
||||
|
||||
**ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
FROM ubuntu:22.04
|
||||
FROM ubuntu:22.04@sha256:c7eb020043d8fc2ae0793fb35a37bff1cf33f156d4d4b12ccc7f3ef8706c38b1
|
||||
|
||||
# Prevent interactive prompts during package installation
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
|
|
|
|||
|
|
@ -27,7 +27,7 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
|
|||
### Optional Repository Automation
|
||||
|
||||
- `.github/workflows/labeler.yml` (`PR Labeler`)
|
||||
- Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>:<component>`)
|
||||
- Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>: <component>`)
|
||||
- Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule
|
||||
- Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`)
|
||||
- Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`)
|
||||
|
|
|
|||
|
|
@ -244,7 +244,7 @@ Label discipline:
|
|||
- Path labels identify subsystem ownership quickly.
|
||||
- Size labels drive batching strategy.
|
||||
- Risk labels drive review depth (`risk: low/medium/high`).
|
||||
- Module labels (`<module>:<component>`) improve reviewer routing for integration-specific changes and future newly-added modules.
|
||||
- Module labels (`<module>: <component>`) improve reviewer routing for integration-specific changes and future newly-added modules.
|
||||
- `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context.
|
||||
- `no-stale` is reserved for accepted-but-blocked work.
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ Use it to reduce review latency without reducing quality.
|
|||
For every new PR, do a fast intake pass:
|
||||
|
||||
1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`).
|
||||
2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel:*`/`provider:*`/`tool:*`, and contributor tier labels when applicable) are present and plausible.
|
||||
2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel: *`/`provider: *`/`tool: *`, and contributor tier labels when applicable) are present and plausible.
|
||||
3. Confirm CI signal status (`CI Required Gate`).
|
||||
4. Confirm scope is one concern (reject mixed mega-PRs unless justified).
|
||||
5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied.
|
||||
|
|
|
|||
|
|
@ -12,6 +12,8 @@ use tokio::sync::mpsc;
|
|||
pub struct ChannelMessage {
|
||||
pub id: String,
|
||||
pub sender: String,
|
||||
/// Channel-specific reply address (e.g. Telegram chat_id, Discord channel_id).
|
||||
pub reply_to: String,
|
||||
pub content: String,
|
||||
pub channel: String,
|
||||
pub timestamp: u64,
|
||||
|
|
@ -90,9 +92,12 @@ impl Channel for TelegramChannel {
|
|||
continue;
|
||||
}
|
||||
|
||||
let chat_id = msg["chat"]["id"].to_string();
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: msg["message_id"].to_string(),
|
||||
sender,
|
||||
reply_to: chat_id,
|
||||
content: msg["text"].as_str().unwrap_or("").to_string(),
|
||||
channel: "telegram".into(),
|
||||
timestamp: msg["date"].as_u64().unwrap_or(0),
|
||||
|
|
|
|||
|
|
@ -2,4 +2,10 @@
|
|||
target = "riscv32imc-esp-espidf"
|
||||
|
||||
[target.riscv32imc-esp-espidf]
|
||||
linker = "ldproxy"
|
||||
runner = "espflash flash --monitor"
|
||||
# ESP-IDF 5.x uses 64-bit time_t
|
||||
rustflags = ["-C", "default-linker-libraries", "--cfg", "espidf_time64"]
|
||||
|
||||
[unstable]
|
||||
build-std = ["std", "panic_abort"]
|
||||
|
|
|
|||
106
firmware/zeroclaw-esp32/Cargo.lock
generated
106
firmware/zeroclaw-esp32/Cargo.lock
generated
|
|
@ -58,24 +58,22 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
|||
|
||||
[[package]]
|
||||
name = "bindgen"
|
||||
version = "0.63.0"
|
||||
version = "0.71.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885"
|
||||
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
|
||||
dependencies = [
|
||||
"bitflags 1.3.2",
|
||||
"bitflags 2.11.0",
|
||||
"cexpr",
|
||||
"clang-sys",
|
||||
"lazy_static",
|
||||
"lazycell",
|
||||
"itertools",
|
||||
"log",
|
||||
"peeking_take_while",
|
||||
"prettyplease",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"regex",
|
||||
"rustc-hash",
|
||||
"shlex",
|
||||
"syn 1.0.109",
|
||||
"which",
|
||||
"syn 2.0.116",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
|
@ -374,14 +372,15 @@ checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01"
|
|||
|
||||
[[package]]
|
||||
name = "embassy-sync"
|
||||
version = "0.5.0"
|
||||
version = "0.7.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "dd938f25c0798db4280fcd8026bf4c2f48789aebf8f77b6e5cf8a7693ba114ec"
|
||||
checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b"
|
||||
dependencies = [
|
||||
"cfg-if",
|
||||
"critical-section",
|
||||
"embedded-io-async",
|
||||
"futures-util",
|
||||
"futures-core",
|
||||
"futures-sink",
|
||||
"heapless",
|
||||
]
|
||||
|
||||
|
|
@ -446,16 +445,15 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "embedded-svc"
|
||||
version = "0.27.1"
|
||||
version = "0.28.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "ac6f87e7654f28018340aa55f933803017aefabaa5417820a3b2f808033c7bbc"
|
||||
checksum = "a7770e30ab55cfbf954c00019522490d6ce26a3334bede05a732ba61010e98e0"
|
||||
dependencies = [
|
||||
"defmt 0.3.100",
|
||||
"embedded-io",
|
||||
"embedded-io-async",
|
||||
"enumset",
|
||||
"heapless",
|
||||
"no-std-net",
|
||||
"num_enum",
|
||||
"serde",
|
||||
"strum 0.25.0",
|
||||
|
|
@ -463,9 +461,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "embuild"
|
||||
version = "0.31.4"
|
||||
version = "0.33.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "4caa4f198bb9152a55c0103efb83fa4edfcbb8625f4c9e94ae8ec8e23827c563"
|
||||
checksum = "e188ad2bbe82afa841ea4a29880651e53ab86815db036b2cb9f8de3ac32dad75"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bindgen",
|
||||
|
|
@ -475,6 +473,7 @@ dependencies = [
|
|||
"globwalk",
|
||||
"home",
|
||||
"log",
|
||||
"regex",
|
||||
"remove_dir_all",
|
||||
"serde",
|
||||
"serde_json",
|
||||
|
|
@ -533,9 +532,8 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "esp-idf-hal"
|
||||
version = "0.43.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "f7adf3fb19a9ca016cbea1ab8a7b852ac69df8fcde4923c23d3b155efbc42a74"
|
||||
version = "0.45.2"
|
||||
source = "git+https://github.com/esp-rs/esp-idf-hal#bc48639bd626c72afc1e25e5d497b5c639161d30"
|
||||
dependencies = [
|
||||
"atomic-waker",
|
||||
"embassy-sync",
|
||||
|
|
@ -552,14 +550,12 @@ dependencies = [
|
|||
"heapless",
|
||||
"log",
|
||||
"nb 1.1.0",
|
||||
"num_enum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "esp-idf-svc"
|
||||
version = "0.48.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2180642ca122a7fec1ec417a9b1a77aa66aaa067fdf1daae683dd8caba84f26b"
|
||||
version = "0.51.0"
|
||||
source = "git+https://github.com/esp-rs/esp-idf-svc#dee202f146c7681e54eabbf118a216fc0195d203"
|
||||
dependencies = [
|
||||
"embassy-futures",
|
||||
"embedded-hal-async",
|
||||
|
|
@ -567,6 +563,7 @@ dependencies = [
|
|||
"embuild",
|
||||
"enumset",
|
||||
"esp-idf-hal",
|
||||
"futures-io",
|
||||
"heapless",
|
||||
"log",
|
||||
"num_enum",
|
||||
|
|
@ -575,14 +572,13 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "esp-idf-sys"
|
||||
version = "0.34.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "2e148f97c04ed3e9181a08bcdc9560a515aad939b0ba7f50a0022e294665e0af"
|
||||
version = "0.36.1"
|
||||
source = "git+https://github.com/esp-rs/esp-idf-sys#64667a38fb8004e1fc3b032488af6857ca3cd849"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"bindgen",
|
||||
"build-time",
|
||||
"cargo_metadata",
|
||||
"cmake",
|
||||
"const_format",
|
||||
"embuild",
|
||||
"envy",
|
||||
|
|
@ -649,21 +645,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
|||
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
||||
|
||||
[[package]]
|
||||
name = "futures-task"
|
||||
name = "futures-io"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
|
||||
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||
|
||||
[[package]]
|
||||
name = "futures-util"
|
||||
name = "futures-sink"
|
||||
version = "0.3.32"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
||||
dependencies = [
|
||||
"futures-core",
|
||||
"futures-task",
|
||||
"pin-project-lite",
|
||||
]
|
||||
checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
|
||||
|
||||
[[package]]
|
||||
name = "getrandom"
|
||||
|
|
@ -827,6 +818,15 @@ dependencies = [
|
|||
"serde_core",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itertools"
|
||||
version = "0.13.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
|
||||
dependencies = [
|
||||
"either",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "itoa"
|
||||
version = "1.0.17"
|
||||
|
|
@ -843,18 +843,6 @@ dependencies = [
|
|||
"wasm-bindgen",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "lazy_static"
|
||||
version = "1.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
||||
|
||||
[[package]]
|
||||
name = "lazycell"
|
||||
version = "1.3.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
||||
|
||||
[[package]]
|
||||
name = "leb128fmt"
|
||||
version = "0.1.0"
|
||||
|
|
@ -945,12 +933,6 @@ dependencies = [
|
|||
"libc",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "no-std-net"
|
||||
version = "0.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1bcece43b12349917e096cddfa66107277f123e6c96a5aea78711dc601a47152"
|
||||
|
||||
[[package]]
|
||||
name = "nom"
|
||||
version = "7.1.3"
|
||||
|
|
@ -1007,18 +989,6 @@ version = "1.21.3"
|
|||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||
|
||||
[[package]]
|
||||
name = "peeking_take_while"
|
||||
version = "0.1.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
||||
|
||||
[[package]]
|
||||
name = "pin-project-lite"
|
||||
version = "0.2.16"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
|
||||
|
||||
[[package]]
|
||||
name = "prettyplease"
|
||||
version = "0.2.37"
|
||||
|
|
@ -1138,9 +1108,9 @@ dependencies = [
|
|||
|
||||
[[package]]
|
||||
name = "rustc-hash"
|
||||
version = "1.1.0"
|
||||
version = "2.1.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
||||
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
|
||||
|
||||
[[package]]
|
||||
name = "rustix"
|
||||
|
|
|
|||
|
|
@ -14,15 +14,21 @@ edition = "2021"
|
|||
license = "MIT"
|
||||
description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial"
|
||||
|
||||
[patch.crates-io]
|
||||
# Use latest esp-rs crates to fix u8/i8 char pointer compatibility with ESP-IDF 5.x
|
||||
esp-idf-sys = { git = "https://github.com/esp-rs/esp-idf-sys" }
|
||||
esp-idf-hal = { git = "https://github.com/esp-rs/esp-idf-hal" }
|
||||
esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
|
||||
|
||||
[dependencies]
|
||||
esp-idf-svc = "0.48"
|
||||
esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
|
||||
log = "0.4"
|
||||
anyhow = "1.0"
|
||||
serde = { version = "1.0", features = ["derive"] }
|
||||
serde_json = "1.0"
|
||||
|
||||
[build-dependencies]
|
||||
embuild = "0.31"
|
||||
embuild = { version = "0.33", features = ["espidf"] }
|
||||
|
||||
[profile.release]
|
||||
opt-level = "s"
|
||||
|
|
|
|||
|
|
@ -2,8 +2,11 @@
|
|||
|
||||
Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial.
|
||||
|
||||
**New to this?** See [SETUP.md](SETUP.md) for step-by-step commands and troubleshooting.
|
||||
|
||||
## Protocol
|
||||
|
||||
|
||||
- **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n`
|
||||
- **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n`
|
||||
|
||||
|
|
@ -11,19 +14,44 @@ Commands: `gpio_read`, `gpio_write`.
|
|||
|
||||
## Prerequisites
|
||||
|
||||
1. **ESP toolchain** (espup):
|
||||
1. **RISC-V ESP-IDF** (ESP32-C2/C3): Uses nightly Rust with `build-std`.
|
||||
|
||||
**Python**: ESP-IDF requires Python 3.10–3.13 (not 3.14). If you have Python 3.14:
|
||||
```sh
|
||||
brew install python@3.12
|
||||
```
|
||||
|
||||
**virtualenv** (needed by ESP-IDF tools; PEP 668 workaround on macOS):
|
||||
```sh
|
||||
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||
```
|
||||
|
||||
**Rust tools**:
|
||||
```sh
|
||||
cargo install espflash ldproxy
|
||||
```
|
||||
|
||||
The project's `rust-toolchain.toml` pins nightly + rust-src. `esp-idf-sys` downloads ESP-IDF automatically on first build. Use Python 3.12 for the build:
|
||||
```sh
|
||||
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
2. **Xtensa targets** (ESP32, ESP32-S2, ESP32-S3): Use espup instead:
|
||||
```sh
|
||||
cargo install espup espflash
|
||||
espup install
|
||||
source ~/export-esp.sh # or ~/export-esp.fish for Fish
|
||||
source ~/export-esp.sh
|
||||
```
|
||||
|
||||
2. **Target**: ESP32-C3 (RISC-V) by default. Edit `.cargo/config.toml` for other targets (e.g. `xtensa-esp32-espidf` for original ESP32).
|
||||
Then edit `.cargo/config.toml` to change the target (e.g. `xtensa-esp32-espidf`).
|
||||
|
||||
## Build & Flash
|
||||
|
||||
```sh
|
||||
cd firmware/zeroclaw-esp32
|
||||
# Use Python 3.12 (required if you have 3.14)
|
||||
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||
# Optional: pin MCU (esp32c3 or esp32c2)
|
||||
export MCU=esp32c3
|
||||
cargo build --release
|
||||
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
|
||||
```
|
||||
|
|
|
|||
156
firmware/zeroclaw-esp32/SETUP.md
Normal file
156
firmware/zeroclaw-esp32/SETUP.md
Normal file
|
|
@ -0,0 +1,156 @@
|
|||
# ESP32 Firmware Setup Guide
|
||||
|
||||
Step-by-step setup for building the ZeroClaw ESP32 firmware. Follow this if you run into issues.
|
||||
|
||||
## Quick Start (copy-paste)
|
||||
|
||||
```sh
|
||||
# 1. Install Python 3.12 (ESP-IDF needs 3.10–3.13, not 3.14)
|
||||
brew install python@3.12
|
||||
|
||||
# 2. Install virtualenv (PEP 668 workaround on macOS)
|
||||
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||
|
||||
# 3. Install Rust tools
|
||||
cargo install espflash ldproxy
|
||||
|
||||
# 4. Build
|
||||
cd firmware/zeroclaw-esp32
|
||||
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||
cargo build --release
|
||||
|
||||
# 5. Flash (connect ESP32 via USB)
|
||||
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Detailed Steps
|
||||
|
||||
### 1. Python
|
||||
|
||||
ESP-IDF requires Python 3.10–3.13. **Python 3.14 is not supported.**
|
||||
|
||||
```sh
|
||||
brew install python@3.12
|
||||
```
|
||||
|
||||
### 2. virtualenv
|
||||
|
||||
ESP-IDF tools need `virtualenv`. On macOS with Homebrew Python, PEP 668 blocks `pip install`; use:
|
||||
|
||||
```sh
|
||||
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||
```
|
||||
|
||||
### 3. Rust Tools
|
||||
|
||||
```sh
|
||||
cargo install espflash ldproxy
|
||||
```
|
||||
|
||||
- **espflash**: flash and monitor
|
||||
- **ldproxy**: linker for ESP-IDF builds
|
||||
|
||||
### 4. Use Python 3.12 for Builds
|
||||
|
||||
Before every build (or add to `~/.zshrc`):
|
||||
|
||||
```sh
|
||||
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
### 5. Build
|
||||
|
||||
```sh
|
||||
cd firmware/zeroclaw-esp32
|
||||
cargo build --release
|
||||
```
|
||||
|
||||
First build downloads and compiles ESP-IDF (~5–15 min).
|
||||
|
||||
### 6. Flash
|
||||
|
||||
```sh
|
||||
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "No space left on device"
|
||||
|
||||
Free disk space. Common targets:
|
||||
|
||||
```sh
|
||||
# Cargo cache (often 5–20 GB)
|
||||
rm -rf ~/.cargo/registry/cache ~/.cargo/registry/src
|
||||
|
||||
# Unused Rust toolchains
|
||||
rustup toolchain list
|
||||
rustup toolchain uninstall <name>
|
||||
|
||||
# iOS Simulator runtimes (~35 GB)
|
||||
xcrun simctl delete unavailable
|
||||
|
||||
# Temp files
|
||||
rm -rf /var/folders/*/T/cargo-install*
|
||||
```
|
||||
|
||||
### "can't find crate for `core`" / "riscv32imc-esp-espidf target may not be installed"
|
||||
|
||||
This project uses **nightly Rust with build-std**, not espup. Ensure:
|
||||
|
||||
- `rust-toolchain.toml` exists (pins nightly + rust-src)
|
||||
- You are **not** sourcing `~/export-esp.sh` (that's for Xtensa targets)
|
||||
- Run `cargo build` from `firmware/zeroclaw-esp32`
|
||||
|
||||
### "externally-managed-environment" / "No module named 'virtualenv'"
|
||||
|
||||
Install virtualenv with the PEP 668 workaround:
|
||||
|
||||
```sh
|
||||
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||
```
|
||||
|
||||
### "expected `i64`, found `i32`" (time_t mismatch)
|
||||
|
||||
Already fixed in `.cargo/config.toml` with `espidf_time64` for ESP-IDF 5.x. If you use ESP-IDF 4.4, switch to `espidf_time32`.
|
||||
|
||||
### "expected `*const u8`, found `*const i8`" (esp-idf-svc)
|
||||
|
||||
Already fixed via `[patch.crates-io]` in `Cargo.toml` using esp-rs crates from git. Do not remove the patch.
|
||||
|
||||
### 10,000+ files in `git status`
|
||||
|
||||
The `.embuild/` directory (ESP-IDF cache) has ~100k+ files. It is in `.gitignore`. If you see them, ensure `.gitignore` contains:
|
||||
|
||||
```
|
||||
.embuild/
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Optional: Auto-load Python 3.12
|
||||
|
||||
Add to `~/.zshrc`:
|
||||
|
||||
```sh
|
||||
# ESP32 firmware build
|
||||
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Xtensa Targets (ESP32, ESP32-S2, ESP32-S3)
|
||||
|
||||
For non–RISC-V chips, use espup instead:
|
||||
|
||||
```sh
|
||||
cargo install espup espflash
|
||||
espup install
|
||||
source ~/export-esp.sh
|
||||
```
|
||||
|
||||
Then edit `.cargo/config.toml` to use `xtensa-esp32-espidf` (or the correct target).
|
||||
3
firmware/zeroclaw-esp32/rust-toolchain.toml
Normal file
3
firmware/zeroclaw-esp32/rust-toolchain.toml
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
[toolchain]
|
||||
channel = "nightly"
|
||||
components = ["rust-src"]
|
||||
|
|
@ -6,8 +6,9 @@
|
|||
//! Protocol: same as STM32 — see docs/hardware-peripherals-design.md
|
||||
|
||||
use esp_idf_svc::hal::gpio::PinDriver;
|
||||
use esp_idf_svc::hal::prelude::*;
|
||||
use esp_idf_svc::hal::uart::*;
|
||||
use esp_idf_svc::hal::peripherals::Peripherals;
|
||||
use esp_idf_svc::hal::uart::{UartConfig, UartDriver};
|
||||
use esp_idf_svc::hal::units::Hertz;
|
||||
use log::info;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
|
|
@ -36,9 +37,13 @@ fn main() -> anyhow::Result<()> {
|
|||
let peripherals = Peripherals::take()?;
|
||||
let pins = peripherals.pins;
|
||||
|
||||
// Create GPIO output drivers first (they take ownership of pins)
|
||||
let mut gpio2 = PinDriver::output(pins.gpio2)?;
|
||||
let mut gpio13 = PinDriver::output(pins.gpio13)?;
|
||||
|
||||
// UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board
|
||||
let config = UartConfig::new().baudrate(Hertz(115_200));
|
||||
let mut uart = UartDriver::new(
|
||||
let uart = UartDriver::new(
|
||||
peripherals.uart0,
|
||||
pins.gpio21,
|
||||
pins.gpio20,
|
||||
|
|
@ -60,7 +65,8 @@ fn main() -> anyhow::Result<()> {
|
|||
if b == b'\n' {
|
||||
if !line.is_empty() {
|
||||
if let Ok(line_str) = std::str::from_utf8(&line) {
|
||||
if let Ok(resp) = handle_request(line_str, &peripherals) {
|
||||
if let Ok(resp) = handle_request(line_str, &mut gpio2, &mut gpio13)
|
||||
{
|
||||
let out = serde_json::to_string(&resp).unwrap_or_default();
|
||||
let _ = uart.write(format!("{}\n", out).as_bytes());
|
||||
}
|
||||
|
|
@ -80,10 +86,15 @@ fn main() -> anyhow::Result<()> {
|
|||
}
|
||||
}
|
||||
|
||||
fn handle_request(
|
||||
fn handle_request<G2, G13>(
|
||||
line: &str,
|
||||
peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
|
||||
) -> anyhow::Result<Response> {
|
||||
gpio2: &mut PinDriver<'_, G2>,
|
||||
gpio13: &mut PinDriver<'_, G13>,
|
||||
) -> anyhow::Result<Response>
|
||||
where
|
||||
G2: esp_idf_svc::hal::gpio::OutputMode,
|
||||
G13: esp_idf_svc::hal::gpio::OutputMode,
|
||||
{
|
||||
let req: Request = serde_json::from_str(line.trim())?;
|
||||
let id = req.id.clone();
|
||||
|
||||
|
|
@ -98,13 +109,13 @@ fn handle_request(
|
|||
}
|
||||
"gpio_read" => {
|
||||
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
|
||||
let value = gpio_read(peripherals, pin_num)?;
|
||||
let value = gpio_read(pin_num)?;
|
||||
Ok(value.to_string())
|
||||
}
|
||||
"gpio_write" => {
|
||||
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
|
||||
let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||
gpio_write(peripherals, pin_num, value)?;
|
||||
gpio_write(gpio2, gpio13, pin_num, value)?;
|
||||
Ok("done".into())
|
||||
}
|
||||
_ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)),
|
||||
|
|
@ -126,28 +137,26 @@ fn handle_request(
|
|||
}
|
||||
}
|
||||
|
||||
fn gpio_read(_peripherals: &esp_idf_svc::hal::peripherals::Peripherals, _pin: i32) -> anyhow::Result<u8> {
|
||||
fn gpio_read(_pin: i32) -> anyhow::Result<u8> {
|
||||
// TODO: implement input pin read — requires storing InputPin drivers per pin
|
||||
Ok(0)
|
||||
}
|
||||
|
||||
fn gpio_write(
|
||||
peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
|
||||
fn gpio_write<G2, G13>(
|
||||
gpio2: &mut PinDriver<'_, G2>,
|
||||
gpio13: &mut PinDriver<'_, G13>,
|
||||
pin: i32,
|
||||
value: u64,
|
||||
) -> anyhow::Result<()> {
|
||||
let pins = peripherals.pins;
|
||||
let level = value != 0;
|
||||
) -> anyhow::Result<()>
|
||||
where
|
||||
G2: esp_idf_svc::hal::gpio::OutputMode,
|
||||
G13: esp_idf_svc::hal::gpio::OutputMode,
|
||||
{
|
||||
let level = esp_idf_svc::hal::gpio::Level::from(value != 0);
|
||||
|
||||
match pin {
|
||||
2 => {
|
||||
let mut out = PinDriver::output(pins.gpio2)?;
|
||||
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
|
||||
}
|
||||
13 => {
|
||||
let mut out = PinDriver::output(pins.gpio13)?;
|
||||
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
|
||||
}
|
||||
2 => gpio2.set_level(level)?,
|
||||
13 => gpio13.set_level(level)?,
|
||||
_ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin),
|
||||
}
|
||||
Ok(())
|
||||
|
|
|
|||
324
scripts/recompute_contributor_tiers.sh
Executable file
324
scripts/recompute_contributor_tiers.sh
Executable file
|
|
@ -0,0 +1,324 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
SCRIPT_NAME="$(basename "$0")"
|
||||
|
||||
usage() {
|
||||
cat <<USAGE
|
||||
Recompute contributor tier labels for historical PRs/issues.
|
||||
|
||||
Usage:
|
||||
./$SCRIPT_NAME [options]
|
||||
|
||||
Options:
|
||||
--repo <owner/repo> Target repository (default: current gh repo)
|
||||
--kind <both|prs|issues>
|
||||
Target objects (default: both)
|
||||
--state <all|open|closed>
|
||||
State filter for listing objects (default: all)
|
||||
--limit <N> Limit processed objects after fetch (default: 0 = no limit)
|
||||
--apply Apply label updates (default is dry-run)
|
||||
--dry-run Preview only (default)
|
||||
-h, --help Show this help
|
||||
|
||||
Examples:
|
||||
./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --limit 50
|
||||
./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --kind prs --state open --apply
|
||||
USAGE
|
||||
}
|
||||
|
||||
die() {
|
||||
echo "[$SCRIPT_NAME] ERROR: $*" >&2
|
||||
exit 1
|
||||
}
|
||||
|
||||
require_cmd() {
|
||||
if ! command -v "$1" >/dev/null 2>&1; then
|
||||
die "Required command not found: $1"
|
||||
fi
|
||||
}
|
||||
|
||||
urlencode() {
|
||||
jq -nr --arg value "$1" '$value|@uri'
|
||||
}
|
||||
|
||||
select_contributor_tier() {
|
||||
local merged_count="$1"
|
||||
if (( merged_count >= 50 )); then
|
||||
echo "distinguished contributor"
|
||||
elif (( merged_count >= 20 )); then
|
||||
echo "principal contributor"
|
||||
elif (( merged_count >= 10 )); then
|
||||
echo "experienced contributor"
|
||||
elif (( merged_count >= 5 )); then
|
||||
echo "trusted contributor"
|
||||
else
|
||||
echo ""
|
||||
fi
|
||||
}
|
||||
|
||||
DRY_RUN=1
|
||||
KIND="both"
|
||||
STATE="all"
|
||||
LIMIT=0
|
||||
REPO=""
|
||||
|
||||
while (($# > 0)); do
|
||||
case "$1" in
|
||||
--repo)
|
||||
[[ $# -ge 2 ]] || die "Missing value for --repo"
|
||||
REPO="$2"
|
||||
shift 2
|
||||
;;
|
||||
--kind)
|
||||
[[ $# -ge 2 ]] || die "Missing value for --kind"
|
||||
KIND="$2"
|
||||
shift 2
|
||||
;;
|
||||
--state)
|
||||
[[ $# -ge 2 ]] || die "Missing value for --state"
|
||||
STATE="$2"
|
||||
shift 2
|
||||
;;
|
||||
--limit)
|
||||
[[ $# -ge 2 ]] || die "Missing value for --limit"
|
||||
LIMIT="$2"
|
||||
shift 2
|
||||
;;
|
||||
--apply)
|
||||
DRY_RUN=0
|
||||
shift
|
||||
;;
|
||||
--dry-run)
|
||||
DRY_RUN=1
|
||||
shift
|
||||
;;
|
||||
-h|--help)
|
||||
usage
|
||||
exit 0
|
||||
;;
|
||||
*)
|
||||
die "Unknown option: $1"
|
||||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
case "$KIND" in
|
||||
both|prs|issues) ;;
|
||||
*) die "--kind must be one of: both, prs, issues" ;;
|
||||
esac
|
||||
|
||||
case "$STATE" in
|
||||
all|open|closed) ;;
|
||||
*) die "--state must be one of: all, open, closed" ;;
|
||||
esac
|
||||
|
||||
if ! [[ "$LIMIT" =~ ^[0-9]+$ ]]; then
|
||||
die "--limit must be a non-negative integer"
|
||||
fi
|
||||
|
||||
require_cmd gh
|
||||
require_cmd jq
|
||||
|
||||
if ! gh auth status >/dev/null 2>&1; then
|
||||
die "gh CLI is not authenticated. Run: gh auth login"
|
||||
fi
|
||||
|
||||
if [[ -z "$REPO" ]]; then
|
||||
REPO="$(gh repo view --json nameWithOwner --jq '.nameWithOwner' 2>/dev/null || true)"
|
||||
[[ -n "$REPO" ]] || die "Unable to infer repo. Pass --repo <owner/repo>."
|
||||
fi
|
||||
|
||||
echo "[$SCRIPT_NAME] Repo: $REPO"
|
||||
echo "[$SCRIPT_NAME] Mode: $([[ "$DRY_RUN" -eq 1 ]] && echo "dry-run" || echo "apply")"
|
||||
echo "[$SCRIPT_NAME] Kind: $KIND | State: $STATE | Limit: $LIMIT"
|
||||
|
||||
TIERS_JSON='["trusted contributor","experienced contributor","principal contributor","distinguished contributor"]'
|
||||
|
||||
TMP_FILES=()
|
||||
cleanup() {
|
||||
if ((${#TMP_FILES[@]} > 0)); then
|
||||
rm -f "${TMP_FILES[@]}"
|
||||
fi
|
||||
}
|
||||
trap cleanup EXIT
|
||||
|
||||
new_tmp_file() {
|
||||
local tmp
|
||||
tmp="$(mktemp)"
|
||||
TMP_FILES+=("$tmp")
|
||||
echo "$tmp"
|
||||
}
|
||||
|
||||
targets_file="$(new_tmp_file)"
|
||||
|
||||
if [[ "$KIND" == "both" || "$KIND" == "prs" ]]; then
|
||||
gh api --paginate "repos/$REPO/pulls?state=$STATE&per_page=100" \
|
||||
--jq '.[] | {
|
||||
kind: "pr",
|
||||
number: .number,
|
||||
author: (.user.login // ""),
|
||||
author_type: (.user.type // ""),
|
||||
labels: [(.labels[]?.name // empty)]
|
||||
}' >> "$targets_file"
|
||||
fi
|
||||
|
||||
if [[ "$KIND" == "both" || "$KIND" == "issues" ]]; then
|
||||
gh api --paginate "repos/$REPO/issues?state=$STATE&per_page=100" \
|
||||
--jq '.[] | select(.pull_request | not) | {
|
||||
kind: "issue",
|
||||
number: .number,
|
||||
author: (.user.login // ""),
|
||||
author_type: (.user.type // ""),
|
||||
labels: [(.labels[]?.name // empty)]
|
||||
}' >> "$targets_file"
|
||||
fi
|
||||
|
||||
if [[ "$LIMIT" -gt 0 ]]; then
|
||||
limited_file="$(new_tmp_file)"
|
||||
head -n "$LIMIT" "$targets_file" > "$limited_file"
|
||||
mv "$limited_file" "$targets_file"
|
||||
fi
|
||||
|
||||
target_count="$(wc -l < "$targets_file" | tr -d ' ')"
|
||||
if [[ "$target_count" -eq 0 ]]; then
|
||||
echo "[$SCRIPT_NAME] No targets found."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "[$SCRIPT_NAME] Targets fetched: $target_count"
|
||||
|
||||
# Ensure tier labels exist (trusted contributor might be new).
|
||||
label_color=""
|
||||
for probe_label in "experienced contributor" "principal contributor" "distinguished contributor" "trusted contributor"; do
|
||||
encoded_label="$(urlencode "$probe_label")"
|
||||
if color_candidate="$(gh api "repos/$REPO/labels/$encoded_label" --jq '.color' 2>/dev/null || true)"; then
|
||||
if [[ -n "$color_candidate" ]]; then
|
||||
label_color="$(echo "$color_candidate" | tr '[:lower:]' '[:upper:]')"
|
||||
break
|
||||
fi
|
||||
fi
|
||||
done
|
||||
[[ -n "$label_color" ]] || label_color="C5D7A2"
|
||||
|
||||
while IFS= read -r tier_label; do
|
||||
[[ -n "$tier_label" ]] || continue
|
||||
encoded_label="$(urlencode "$tier_label")"
|
||||
if gh api "repos/$REPO/labels/$encoded_label" >/dev/null 2>&1; then
|
||||
continue
|
||||
fi
|
||||
|
||||
if [[ "$DRY_RUN" -eq 1 ]]; then
|
||||
echo "[dry-run] Would create missing label: $tier_label (color=$label_color)"
|
||||
else
|
||||
gh api -X POST "repos/$REPO/labels" \
|
||||
-f name="$tier_label" \
|
||||
-f color="$label_color" >/dev/null
|
||||
echo "[apply] Created missing label: $tier_label"
|
||||
fi
|
||||
done < <(jq -r '.[]' <<<"$TIERS_JSON")
|
||||
|
||||
# Build merged PR count cache by unique human authors.
|
||||
authors_file="$(new_tmp_file)"
|
||||
jq -r 'select(.author != "" and .author_type != "Bot") | .author' "$targets_file" | sort -u > "$authors_file"
|
||||
author_count="$(wc -l < "$authors_file" | tr -d ' ')"
|
||||
echo "[$SCRIPT_NAME] Unique human authors: $author_count"
|
||||
|
||||
author_counts_file="$(new_tmp_file)"
|
||||
while IFS= read -r author; do
|
||||
[[ -n "$author" ]] || continue
|
||||
query="repo:$REPO is:pr is:merged author:$author"
|
||||
merged_count="$(gh api search/issues -f q="$query" -F per_page=1 --jq '.total_count' 2>/dev/null || true)"
|
||||
if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
|
||||
merged_count=0
|
||||
fi
|
||||
printf '%s\t%s\n' "$author" "$merged_count" >> "$author_counts_file"
|
||||
done < "$authors_file"
|
||||
|
||||
updated=0
|
||||
unchanged=0
|
||||
skipped=0
|
||||
failed=0
|
||||
|
||||
while IFS= read -r target_json; do
|
||||
[[ -n "$target_json" ]] || continue
|
||||
|
||||
number="$(jq -r '.number' <<<"$target_json")"
|
||||
kind="$(jq -r '.kind' <<<"$target_json")"
|
||||
author="$(jq -r '.author' <<<"$target_json")"
|
||||
author_type="$(jq -r '.author_type' <<<"$target_json")"
|
||||
current_labels_json="$(jq -c '.labels // []' <<<"$target_json")"
|
||||
|
||||
if [[ -z "$author" || "$author_type" == "Bot" ]]; then
|
||||
skipped=$((skipped + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
merged_count="$(awk -F '\t' -v key="$author" '$1 == key { print $2; exit }' "$author_counts_file")"
|
||||
if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
|
||||
merged_count=0
|
||||
fi
|
||||
desired_tier="$(select_contributor_tier "$merged_count")"
|
||||
|
||||
if ! current_tier="$(jq -r --argjson tiers "$TIERS_JSON" '[.[] | select(. as $label | ($tiers | index($label)) != null)][0] // ""' <<<"$current_labels_json" 2>/dev/null)"; then
|
||||
echo "[warn] Skipping ${kind} #${number}: cannot parse current labels JSON" >&2
|
||||
failed=$((failed + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! next_labels_json="$(jq -c --arg desired "$desired_tier" --argjson tiers "$TIERS_JSON" '
|
||||
(. // [])
|
||||
| map(select(. as $label | ($tiers | index($label)) == null))
|
||||
| if $desired != "" then . + [$desired] else . end
|
||||
| unique
|
||||
' <<<"$current_labels_json" 2>/dev/null)"; then
|
||||
echo "[warn] Skipping ${kind} #${number}: cannot compute next labels" >&2
|
||||
failed=$((failed + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! normalized_current="$(jq -c 'unique | sort' <<<"$current_labels_json" 2>/dev/null)"; then
|
||||
echo "[warn] Skipping ${kind} #${number}: cannot normalize current labels" >&2
|
||||
failed=$((failed + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
if ! normalized_next="$(jq -c 'unique | sort' <<<"$next_labels_json" 2>/dev/null)"; then
|
||||
echo "[warn] Skipping ${kind} #${number}: cannot normalize next labels" >&2
|
||||
failed=$((failed + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
if [[ "$normalized_current" == "$normalized_next" ]]; then
|
||||
unchanged=$((unchanged + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
if [[ "$DRY_RUN" -eq 1 ]]; then
|
||||
echo "[dry-run] ${kind} #${number} @${author} merged=${merged_count} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
|
||||
updated=$((updated + 1))
|
||||
continue
|
||||
fi
|
||||
|
||||
payload="$(jq -cn --argjson labels "$next_labels_json" '{labels: $labels}')"
|
||||
if gh api -X PUT "repos/$REPO/issues/$number/labels" --input - <<<"$payload" >/dev/null; then
|
||||
echo "[apply] Updated ${kind} #${number} @${author} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
|
||||
updated=$((updated + 1))
|
||||
else
|
||||
echo "[apply] FAILED ${kind} #${number}" >&2
|
||||
failed=$((failed + 1))
|
||||
fi
|
||||
done < "$targets_file"
|
||||
|
||||
echo ""
|
||||
echo "[$SCRIPT_NAME] Summary"
|
||||
echo " Targets: $target_count"
|
||||
echo " Updated: $updated"
|
||||
echo " Unchanged: $unchanged"
|
||||
echo " Skipped: $skipped"
|
||||
echo " Failed: $failed"
|
||||
|
||||
if [[ "$failed" -gt 0 ]]; then
|
||||
exit 1
|
||||
fi
|
||||
|
|
@ -251,6 +251,7 @@ impl Agent {
|
|||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&model_name,
|
||||
|
|
@ -388,7 +389,7 @@ impl Agent {
|
|||
if self.auto_save {
|
||||
let _ = self
|
||||
.memory
|
||||
.store("user_msg", user_message, MemoryCategory::Conversation)
|
||||
.store("user_msg", user_message, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -447,7 +448,7 @@ impl Agent {
|
|||
let summary = truncate_with_ellipsis(&final_text, 100);
|
||||
let _ = self
|
||||
.memory
|
||||
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
||||
.store("assistant_resp", &summary, MemoryCategory::Daily, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -557,6 +558,7 @@ pub async fn run(
|
|||
agent.observer.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: start.elapsed(),
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
|
||||
Ok(())
|
||||
|
|
|
|||
|
|
@ -7,14 +7,70 @@ use crate::security::SecurityPolicy;
|
|||
use crate::tools::{self, Tool};
|
||||
use crate::util::truncate_with_ellipsis;
|
||||
use anyhow::Result;
|
||||
use regex::{Regex, RegexSet};
|
||||
use std::fmt::Write;
|
||||
use std::io::Write as _;
|
||||
use std::sync::Arc;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
use std::time::Instant;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||
const MAX_TOOL_ITERATIONS: usize = 10;
|
||||
|
||||
static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
|
||||
RegexSet::new([
|
||||
r"(?i)token",
|
||||
r"(?i)api[_-]?key",
|
||||
r"(?i)password",
|
||||
r"(?i)secret",
|
||||
r"(?i)user[_-]?key",
|
||||
r"(?i)bearer",
|
||||
r"(?i)credential",
|
||||
])
|
||||
.unwrap()
|
||||
});
|
||||
|
||||
static SENSITIVE_KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||
Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap()
|
||||
});
|
||||
|
||||
/// Scrub credentials from tool output to prevent accidental exfiltration.
|
||||
/// Replaces known credential patterns with a redacted placeholder while preserving
|
||||
/// a small prefix for context.
|
||||
fn scrub_credentials(input: &str) -> String {
|
||||
SENSITIVE_KV_REGEX
|
||||
.replace_all(input, |caps: ®ex::Captures| {
|
||||
let full_match = &caps[0];
|
||||
let key = &caps[1];
|
||||
let val = caps
|
||||
.get(2)
|
||||
.or(caps.get(3))
|
||||
.or(caps.get(4))
|
||||
.map(|m| m.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// Preserve first 4 chars for context, then redact
|
||||
let prefix = if val.len() > 4 { &val[..4] } else { "" };
|
||||
|
||||
if full_match.contains(':') {
|
||||
if full_match.contains('"') {
|
||||
format!("\"{}\": \"{}*[REDACTED]\"", key, prefix)
|
||||
} else {
|
||||
format!("{}: {}*[REDACTED]", key, prefix)
|
||||
}
|
||||
} else if full_match.contains('=') {
|
||||
if full_match.contains('"') {
|
||||
format!("{}=\"{}*[REDACTED]\"", key, prefix)
|
||||
} else {
|
||||
format!("{}={}*[REDACTED]", key, prefix)
|
||||
}
|
||||
} else {
|
||||
format!("{}: {}*[REDACTED]", key, prefix)
|
||||
}
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Trigger auto-compaction when non-system message count exceeds this threshold.
|
||||
const MAX_HISTORY_MESSAGES: usize = 50;
|
||||
|
||||
|
|
@ -145,7 +201,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
|||
let mut context = String::new();
|
||||
|
||||
// Pull relevant memories for this message
|
||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if !entries.is_empty() {
|
||||
context.push_str("[Memory context]\n");
|
||||
for entry in &entries {
|
||||
|
|
@ -436,6 +492,7 @@ struct ParsedToolCall {
|
|||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
/// When `silent` is true, suppresses stdout (for channel use).
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn agent_turn(
|
||||
provider: &dyn Provider,
|
||||
history: &mut Vec<ChatMessage>,
|
||||
|
|
@ -461,6 +518,7 @@ pub(crate) async fn agent_turn(
|
|||
|
||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||
/// execute tools, and loop until the LLM produces a final text response.
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub(crate) async fn run_tool_call_loop(
|
||||
provider: &dyn Provider,
|
||||
history: &mut Vec<ChatMessage>,
|
||||
|
|
@ -606,7 +664,7 @@ pub(crate) async fn run_tool_call_loop(
|
|||
success: r.success,
|
||||
});
|
||||
if r.success {
|
||||
r.output
|
||||
scrub_credentials(&r.output)
|
||||
} else {
|
||||
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
|
||||
}
|
||||
|
|
@ -749,6 +807,7 @@ pub async fn run(
|
|||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
model_name,
|
||||
|
|
@ -912,7 +971,7 @@ pub async fn run(
|
|||
if config.memory.auto_save {
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(&user_key, &msg, MemoryCategory::Conversation)
|
||||
.store(&user_key, &msg, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -955,7 +1014,7 @@ pub async fn run(
|
|||
let summary = truncate_with_ellipsis(&response, 100);
|
||||
let response_key = autosave_memory_key("assistant_resp");
|
||||
let _ = mem
|
||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
||||
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||
.await;
|
||||
}
|
||||
} else {
|
||||
|
|
@ -978,7 +1037,7 @@ pub async fn run(
|
|||
if config.memory.auto_save {
|
||||
let user_key = autosave_memory_key("user_msg");
|
||||
let _ = mem
|
||||
.store(&user_key, &msg.content, MemoryCategory::Conversation)
|
||||
.store(&user_key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -1036,7 +1095,7 @@ pub async fn run(
|
|||
let summary = truncate_with_ellipsis(&response, 100);
|
||||
let response_key = autosave_memory_key("assistant_resp");
|
||||
let _ = mem
|
||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
||||
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
|
@ -1048,6 +1107,7 @@ pub async fn run(
|
|||
observer.record_event(&ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
|
||||
Ok(final_output)
|
||||
|
|
@ -1104,6 +1164,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||
provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
&config.model_routes,
|
||||
&model_name,
|
||||
|
|
@ -1217,6 +1278,25 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
|||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_scrub_credentials() {
|
||||
let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
|
||||
let scrubbed = scrub_credentials(input);
|
||||
assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
|
||||
assert!(scrubbed.contains("token: 1234*[REDACTED]"));
|
||||
assert!(scrubbed.contains("password=\"secr*[REDACTED]\""));
|
||||
assert!(!scrubbed.contains("abcdef"));
|
||||
assert!(!scrubbed.contains("secret123456"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_scrub_credentials_json() {
|
||||
let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
|
||||
let scrubbed = scrub_credentials(input);
|
||||
assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
|
||||
assert!(scrubbed.contains("public"));
|
||||
}
|
||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||
use tempfile::TempDir;
|
||||
|
||||
|
|
@ -1496,16 +1576,16 @@ I will now call the tool with this payload:
|
|||
let key1 = autosave_memory_key("user_msg");
|
||||
let key2 = autosave_memory_key("user_msg");
|
||||
|
||||
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation)
|
||||
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(&key2, "I'm 45", MemoryCategory::Conversation)
|
||||
mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader {
|
|||
memory: &dyn Memory,
|
||||
user_message: &str,
|
||||
) -> anyhow::Result<String> {
|
||||
let entries = memory.recall(user_message, self.limit).await?;
|
||||
let entries = memory.recall(user_message, self.limit, None).await?;
|
||||
if entries.is_empty() {
|
||||
return Ok(String::new());
|
||||
}
|
||||
|
|
@ -61,11 +61,17 @@ mod tests {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if limit == 0 {
|
||||
return Ok(vec![]);
|
||||
}
|
||||
|
|
@ -87,6 +93,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(vec![])
|
||||
}
|
||||
|
|
|
|||
|
|
@ -40,6 +40,7 @@ impl Channel for CliChannel {
|
|||
let msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: "user".to_string(),
|
||||
reply_target: "user".to_string(),
|
||||
content: line,
|
||||
channel: "cli".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -90,12 +91,14 @@ mod tests {
|
|||
let msg = ChannelMessage {
|
||||
id: "test-id".into(),
|
||||
sender: "user".into(),
|
||||
reply_target: "user".into(),
|
||||
content: "hello".into(),
|
||||
channel: "cli".into(),
|
||||
timestamp: 1_234_567_890,
|
||||
};
|
||||
assert_eq!(msg.id, "test-id");
|
||||
assert_eq!(msg.sender, "user");
|
||||
assert_eq!(msg.reply_target, "user");
|
||||
assert_eq!(msg.content, "hello");
|
||||
assert_eq!(msg.channel, "cli");
|
||||
assert_eq!(msg.timestamp, 1_234_567_890);
|
||||
|
|
@ -106,6 +109,7 @@ mod tests {
|
|||
let msg = ChannelMessage {
|
||||
id: "id".into(),
|
||||
sender: "s".into(),
|
||||
reply_target: "s".into(),
|
||||
content: "c".into(),
|
||||
channel: "ch".into(),
|
||||
timestamp: 0,
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ use tokio::sync::RwLock;
|
|||
use tokio_tungstenite::tungstenite::Message;
|
||||
use uuid::Uuid;
|
||||
|
||||
/// DingTalk (钉钉) channel — connects via Stream Mode WebSocket for real-time messages.
|
||||
/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages.
|
||||
/// Replies are sent through per-message session webhook URLs.
|
||||
pub struct DingTalkChannel {
|
||||
client_id: String,
|
||||
|
|
@ -64,6 +64,18 @@ impl DingTalkChannel {
|
|||
let gw: GatewayResponse = resp.json().await?;
|
||||
Ok(gw)
|
||||
}
|
||||
|
||||
fn resolve_reply_target(
|
||||
sender_id: &str,
|
||||
conversation_type: &str,
|
||||
conversation_id: Option<&str>,
|
||||
) -> String {
|
||||
if conversation_type == "1" {
|
||||
sender_id.to_string()
|
||||
} else {
|
||||
conversation_id.unwrap_or(sender_id).to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -193,14 +205,11 @@ impl Channel for DingTalkChannel {
|
|||
.unwrap_or("1");
|
||||
|
||||
// Private chat uses sender ID, group chat uses conversation ID
|
||||
let chat_id = if conversation_type == "1" {
|
||||
sender_id.to_string()
|
||||
} else {
|
||||
data.get("conversationId")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or(sender_id)
|
||||
.to_string()
|
||||
};
|
||||
let chat_id = Self::resolve_reply_target(
|
||||
sender_id,
|
||||
conversation_type,
|
||||
data.get("conversationId").and_then(|c| c.as_str()),
|
||||
);
|
||||
|
||||
// Store session webhook for later replies
|
||||
if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
|
||||
|
|
@ -229,6 +238,7 @@ impl Channel for DingTalkChannel {
|
|||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: sender_id.to_string(),
|
||||
reply_target: chat_id,
|
||||
content: content.to_string(),
|
||||
channel: "dingtalk".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -305,4 +315,22 @@ client_secret = "secret"
|
|||
let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
|
||||
assert!(config.allowed_users.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_reply_target_private_chat_uses_sender_id() {
|
||||
let target = DingTalkChannel::resolve_reply_target("staff_1", "1", Some("conv_1"));
|
||||
assert_eq!(target, "staff_1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_reply_target_group_chat_uses_conversation_id() {
|
||||
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", Some("conv_1"));
|
||||
assert_eq!(target, "conv_1");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_reply_target_group_chat_falls_back_to_sender_id() {
|
||||
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", None);
|
||||
assert_eq!(target, "staff_1");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ pub struct DiscordChannel {
|
|||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
listen_to_bots: bool,
|
||||
mention_only: bool,
|
||||
client: reqwest::Client,
|
||||
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||
}
|
||||
|
|
@ -21,12 +22,14 @@ impl DiscordChannel {
|
|||
guild_id: Option<String>,
|
||||
allowed_users: Vec<String>,
|
||||
listen_to_bots: bool,
|
||||
mention_only: bool,
|
||||
) -> Self {
|
||||
Self {
|
||||
bot_token,
|
||||
guild_id,
|
||||
allowed_users,
|
||||
listen_to_bots,
|
||||
mention_only,
|
||||
client: reqwest::Client::new(),
|
||||
typing_handle: std::sync::Mutex::new(None),
|
||||
}
|
||||
|
|
@ -343,6 +346,22 @@ impl Channel for DiscordChannel {
|
|||
continue;
|
||||
}
|
||||
|
||||
// Skip messages that don't @-mention the bot (when mention_only is enabled)
|
||||
if self.mention_only {
|
||||
let mention_tag = format!("<@{bot_user_id}>");
|
||||
if !content.contains(&mention_tag) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
// Strip the bot mention from content so the agent sees clean text
|
||||
let clean_content = if self.mention_only {
|
||||
let mention_tag = format!("<@{bot_user_id}>");
|
||||
content.replace(&mention_tag, "").trim().to_string()
|
||||
} else {
|
||||
content.to_string()
|
||||
};
|
||||
|
||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
||||
|
||||
|
|
@ -353,6 +372,11 @@ impl Channel for DiscordChannel {
|
|||
format!("discord_{message_id}")
|
||||
},
|
||||
sender: author_id.to_string(),
|
||||
reply_target: if channel_id.is_empty() {
|
||||
author_id.to_string()
|
||||
} else {
|
||||
channel_id
|
||||
},
|
||||
content: content.to_string(),
|
||||
channel: channel_id,
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -423,7 +447,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn discord_channel_name() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
assert_eq!(ch.name(), "discord");
|
||||
}
|
||||
|
||||
|
|
@ -444,21 +468,27 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn empty_allowlist_denies_everyone() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
assert!(!ch.is_user_allowed("12345"));
|
||||
assert!(!ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn wildcard_allows_everyone() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false);
|
||||
assert!(ch.is_user_allowed("12345"));
|
||||
assert!(ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn specific_allowlist_filters() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false);
|
||||
let ch = DiscordChannel::new(
|
||||
"fake".into(),
|
||||
None,
|
||||
vec!["111".into(), "222".into()],
|
||||
false,
|
||||
false,
|
||||
);
|
||||
assert!(ch.is_user_allowed("111"));
|
||||
assert!(ch.is_user_allowed("222"));
|
||||
assert!(!ch.is_user_allowed("333"));
|
||||
|
|
@ -467,7 +497,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn allowlist_is_exact_match_not_substring() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||
assert!(!ch.is_user_allowed("1111"));
|
||||
assert!(!ch.is_user_allowed("11"));
|
||||
assert!(!ch.is_user_allowed("0111"));
|
||||
|
|
@ -475,20 +505,26 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn allowlist_empty_string_user_id() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||
assert!(!ch.is_user_allowed(""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlist_with_wildcard_and_specific() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false);
|
||||
let ch = DiscordChannel::new(
|
||||
"fake".into(),
|
||||
None,
|
||||
vec!["111".into(), "*".into()],
|
||||
false,
|
||||
false,
|
||||
);
|
||||
assert!(ch.is_user_allowed("111"));
|
||||
assert!(ch.is_user_allowed("anyone_else"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn allowlist_case_sensitive() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false);
|
||||
assert!(ch.is_user_allowed("ABC"));
|
||||
assert!(!ch.is_user_allowed("abc"));
|
||||
assert!(!ch.is_user_allowed("Abc"));
|
||||
|
|
@ -663,14 +699,14 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn typing_handle_starts_as_none() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
assert!(guard.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_typing_sets_handle() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let _ = ch.start_typing("123456").await;
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
assert!(guard.is_some());
|
||||
|
|
@ -678,7 +714,7 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn stop_typing_clears_handle() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let _ = ch.start_typing("123456").await;
|
||||
let _ = ch.stop_typing("123456").await;
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
|
|
@ -687,14 +723,14 @@ mod tests {
|
|||
|
||||
#[tokio::test]
|
||||
async fn stop_typing_is_idempotent() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
assert!(ch.stop_typing("123456").await.is_ok());
|
||||
assert!(ch.stop_typing("123456").await.is_ok());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_typing_replaces_existing_task() {
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||
let _ = ch.start_typing("111").await;
|
||||
let _ = ch.start_typing("222").await;
|
||||
let guard = ch.typing_handle.lock().unwrap();
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@
|
|||
|
||||
use anyhow::{anyhow, Result};
|
||||
use async_trait::async_trait;
|
||||
use lettre::message::SinglePart;
|
||||
use lettre::transport::smtp::authentication::Credentials;
|
||||
use lettre::{Message, SmtpTransport, Transport};
|
||||
use mail_parser::{MessageParser, MimeHeaders};
|
||||
|
|
@ -39,7 +40,7 @@ pub struct EmailConfig {
|
|||
pub imap_folder: String,
|
||||
/// SMTP server hostname
|
||||
pub smtp_host: String,
|
||||
/// SMTP server port (default: 587 for STARTTLS)
|
||||
/// SMTP server port (default: 465 for TLS)
|
||||
#[serde(default = "default_smtp_port")]
|
||||
pub smtp_port: u16,
|
||||
/// Use TLS for SMTP (default: true)
|
||||
|
|
@ -63,7 +64,7 @@ fn default_imap_port() -> u16 {
|
|||
993
|
||||
}
|
||||
fn default_smtp_port() -> u16 {
|
||||
587
|
||||
465
|
||||
}
|
||||
fn default_imap_folder() -> String {
|
||||
"INBOX".into()
|
||||
|
|
@ -389,7 +390,7 @@ impl Channel for EmailChannel {
|
|||
.from(self.config.from_address.parse()?)
|
||||
.to(recipient.parse()?)
|
||||
.subject(subject)
|
||||
.body(body.to_string())?;
|
||||
.singlepart(SinglePart::plain(body.to_string()))?;
|
||||
|
||||
let transport = self.create_smtp_transport()?;
|
||||
transport.send(&email)?;
|
||||
|
|
@ -427,6 +428,7 @@ impl Channel for EmailChannel {
|
|||
} // MutexGuard dropped before await
|
||||
let msg = ChannelMessage {
|
||||
id,
|
||||
reply_target: sender.clone(),
|
||||
sender,
|
||||
content,
|
||||
channel: "email".to_string(),
|
||||
|
|
@ -464,6 +466,18 @@ impl Channel for EmailChannel {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn default_smtp_port_uses_tls_port() {
|
||||
assert_eq!(default_smtp_port(), 465);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn email_config_default_uses_tls_smtp_defaults() {
|
||||
let config = EmailConfig::default();
|
||||
assert_eq!(config.smtp_port, 465);
|
||||
assert!(config.smtp_tls);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_imap_tls_config_succeeds() {
|
||||
let tls_config =
|
||||
|
|
@ -504,7 +518,7 @@ mod tests {
|
|||
assert_eq!(config.imap_port, 993);
|
||||
assert_eq!(config.imap_folder, "INBOX");
|
||||
assert_eq!(config.smtp_host, "");
|
||||
assert_eq!(config.smtp_port, 587);
|
||||
assert_eq!(config.smtp_port, 465);
|
||||
assert!(config.smtp_tls);
|
||||
assert_eq!(config.username, "");
|
||||
assert_eq!(config.password, "");
|
||||
|
|
@ -765,8 +779,8 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn default_smtp_port_returns_587() {
|
||||
assert_eq!(default_smtp_port(), 587);
|
||||
fn default_smtp_port_returns_465() {
|
||||
assert_eq!(default_smtp_port(), 465);
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -822,7 +836,7 @@ mod tests {
|
|||
|
||||
let config: EmailConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(config.imap_port, 993); // default
|
||||
assert_eq!(config.smtp_port, 587); // default
|
||||
assert_eq!(config.smtp_port, 465); // default
|
||||
assert!(config.smtp_tls); // default
|
||||
assert_eq!(config.poll_interval_secs, 60); // default
|
||||
}
|
||||
|
|
|
|||
|
|
@ -172,6 +172,7 @@ end tell"#
|
|||
let msg = ChannelMessage {
|
||||
id: rowid.to_string(),
|
||||
sender: sender.clone(),
|
||||
reply_target: sender.clone(),
|
||||
content: text,
|
||||
channel: "imessage".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
|
|||
|
|
@ -220,32 +220,34 @@ fn split_message(message: &str, max_bytes: usize) -> Vec<String> {
|
|||
chunks
|
||||
}
|
||||
|
||||
/// Configuration for constructing an `IrcChannel`.
|
||||
pub struct IrcChannelConfig {
|
||||
pub server: String,
|
||||
pub port: u16,
|
||||
pub nickname: String,
|
||||
pub username: Option<String>,
|
||||
pub channels: Vec<String>,
|
||||
pub allowed_users: Vec<String>,
|
||||
pub server_password: Option<String>,
|
||||
pub nickserv_password: Option<String>,
|
||||
pub sasl_password: Option<String>,
|
||||
pub verify_tls: bool,
|
||||
}
|
||||
|
||||
impl IrcChannel {
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
pub fn new(
|
||||
server: String,
|
||||
port: u16,
|
||||
nickname: String,
|
||||
username: Option<String>,
|
||||
channels: Vec<String>,
|
||||
allowed_users: Vec<String>,
|
||||
server_password: Option<String>,
|
||||
nickserv_password: Option<String>,
|
||||
sasl_password: Option<String>,
|
||||
verify_tls: bool,
|
||||
) -> Self {
|
||||
let username = username.unwrap_or_else(|| nickname.clone());
|
||||
pub fn new(cfg: IrcChannelConfig) -> Self {
|
||||
let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone());
|
||||
Self {
|
||||
server,
|
||||
port,
|
||||
nickname,
|
||||
server: cfg.server,
|
||||
port: cfg.port,
|
||||
nickname: cfg.nickname,
|
||||
username,
|
||||
channels,
|
||||
allowed_users,
|
||||
server_password,
|
||||
nickserv_password,
|
||||
sasl_password,
|
||||
verify_tls,
|
||||
channels: cfg.channels,
|
||||
allowed_users: cfg.allowed_users,
|
||||
server_password: cfg.server_password,
|
||||
nickserv_password: cfg.nickserv_password,
|
||||
sasl_password: cfg.sasl_password,
|
||||
verify_tls: cfg.verify_tls,
|
||||
writer: Arc::new(Mutex::new(None)),
|
||||
}
|
||||
}
|
||||
|
|
@ -563,7 +565,8 @@ impl Channel for IrcChannel {
|
|||
let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed);
|
||||
let channel_msg = ChannelMessage {
|
||||
id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()),
|
||||
sender: reply_to,
|
||||
sender: sender_nick.to_string(),
|
||||
reply_target: reply_to,
|
||||
content,
|
||||
channel: "irc".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
@ -807,18 +810,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn specific_user_allowed() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"bot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec!["alice".into(), "bob".into()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "bot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec!["alice".into(), "bob".into()],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert!(ch.is_user_allowed("alice"));
|
||||
assert!(ch.is_user_allowed("bob"));
|
||||
assert!(!ch.is_user_allowed("eve"));
|
||||
|
|
@ -826,18 +829,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn allowlist_case_insensitive() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"bot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec!["Alice".into()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "bot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec!["Alice".into()],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert!(ch.is_user_allowed("alice"));
|
||||
assert!(ch.is_user_allowed("ALICE"));
|
||||
assert!(ch.is_user_allowed("Alice"));
|
||||
|
|
@ -845,18 +848,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn empty_allowlist_denies_all() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"bot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "bot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec![],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert!(!ch.is_user_allowed("anyone"));
|
||||
}
|
||||
|
||||
|
|
@ -864,35 +867,35 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn new_defaults_username_to_nickname() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"mybot".into(),
|
||||
None,
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "mybot".into(),
|
||||
username: None,
|
||||
channels: vec![],
|
||||
allowed_users: vec![],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert_eq!(ch.username, "mybot");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_uses_explicit_username() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.test".into(),
|
||||
6697,
|
||||
"mybot".into(),
|
||||
Some("customuser".into()),
|
||||
vec![],
|
||||
vec![],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.test".into(),
|
||||
port: 6697,
|
||||
nickname: "mybot".into(),
|
||||
username: Some("customuser".into()),
|
||||
channels: vec![],
|
||||
allowed_users: vec![],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
});
|
||||
assert_eq!(ch.username, "customuser");
|
||||
assert_eq!(ch.nickname, "mybot");
|
||||
}
|
||||
|
|
@ -905,18 +908,18 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn new_stores_all_fields() {
|
||||
let ch = IrcChannel::new(
|
||||
"irc.example.com".into(),
|
||||
6697,
|
||||
"zcbot".into(),
|
||||
Some("zeroclaw".into()),
|
||||
vec!["#test".into()],
|
||||
vec!["alice".into()],
|
||||
Some("serverpass".into()),
|
||||
Some("nspass".into()),
|
||||
Some("saslpass".into()),
|
||||
false,
|
||||
);
|
||||
let ch = IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.example.com".into(),
|
||||
port: 6697,
|
||||
nickname: "zcbot".into(),
|
||||
username: Some("zeroclaw".into()),
|
||||
channels: vec!["#test".into()],
|
||||
allowed_users: vec!["alice".into()],
|
||||
server_password: Some("serverpass".into()),
|
||||
nickserv_password: Some("nspass".into()),
|
||||
sasl_password: Some("saslpass".into()),
|
||||
verify_tls: false,
|
||||
});
|
||||
assert_eq!(ch.server, "irc.example.com");
|
||||
assert_eq!(ch.port, 6697);
|
||||
assert_eq!(ch.nickname, "zcbot");
|
||||
|
|
@ -995,17 +998,17 @@ nickname = "bot"
|
|||
// ── Helpers ─────────────────────────────────────────────
|
||||
|
||||
fn make_channel() -> IrcChannel {
|
||||
IrcChannel::new(
|
||||
"irc.example.com".into(),
|
||||
6697,
|
||||
"zcbot".into(),
|
||||
None,
|
||||
vec!["#zeroclaw".into()],
|
||||
vec!["*".into()],
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
true,
|
||||
)
|
||||
IrcChannel::new(IrcChannelConfig {
|
||||
server: "irc.example.com".into(),
|
||||
port: 6697,
|
||||
nickname: "zcbot".into(),
|
||||
username: None,
|
||||
channels: vec!["#zeroclaw".into()],
|
||||
allowed_users: vec!["*".into()],
|
||||
server_password: None,
|
||||
nickserv_password: None,
|
||||
sasl_password: None,
|
||||
verify_tls: true,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,21 +1,152 @@
|
|||
use super::traits::{Channel, ChannelMessage};
|
||||
use async_trait::async_trait;
|
||||
use futures_util::{SinkExt, StreamExt};
|
||||
use prost::Message as ProstMessage;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
use tokio::sync::RwLock;
|
||||
use tokio_tungstenite::tungstenite::Message as WsMsg;
|
||||
use uuid::Uuid;
|
||||
|
||||
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
|
||||
const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn";
|
||||
const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis";
|
||||
const LARK_WS_BASE_URL: &str = "https://open.larksuite.com";
|
||||
|
||||
/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// Feishu WebSocket long-connection: pbbp2.proto frame codec
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Clone, PartialEq, prost::Message)]
|
||||
struct PbHeader {
|
||||
#[prost(string, tag = "1")]
|
||||
pub key: String,
|
||||
#[prost(string, tag = "2")]
|
||||
pub value: String,
|
||||
}
|
||||
|
||||
/// Feishu WS frame (pbbp2.proto).
|
||||
/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
|
||||
#[derive(Clone, PartialEq, prost::Message)]
|
||||
struct PbFrame {
|
||||
#[prost(uint64, tag = "1")]
|
||||
pub seq_id: u64,
|
||||
#[prost(uint64, tag = "2")]
|
||||
pub log_id: u64,
|
||||
#[prost(int32, tag = "3")]
|
||||
pub service: i32,
|
||||
#[prost(int32, tag = "4")]
|
||||
pub method: i32,
|
||||
#[prost(message, repeated, tag = "5")]
|
||||
pub headers: Vec<PbHeader>,
|
||||
#[prost(bytes = "vec", optional, tag = "8")]
|
||||
pub payload: Option<Vec<u8>>,
|
||||
}
|
||||
|
||||
impl PbFrame {
|
||||
fn header_value<'a>(&'a self, key: &str) -> &'a str {
|
||||
self.headers
|
||||
.iter()
|
||||
.find(|h| h.key == key)
|
||||
.map(|h| h.value.as_str())
|
||||
.unwrap_or("")
|
||||
}
|
||||
}
|
||||
|
||||
/// Server-sent client config (parsed from pong payload)
|
||||
#[derive(Debug, serde::Deserialize, Default, Clone)]
|
||||
struct WsClientConfig {
|
||||
#[serde(rename = "PingInterval")]
|
||||
ping_interval: Option<u64>,
|
||||
}
|
||||
|
||||
/// POST /callback/ws/endpoint response
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct WsEndpointResp {
|
||||
code: i32,
|
||||
#[serde(default)]
|
||||
msg: Option<String>,
|
||||
#[serde(default)]
|
||||
data: Option<WsEndpoint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct WsEndpoint {
|
||||
#[serde(rename = "URL")]
|
||||
url: String,
|
||||
#[serde(rename = "ClientConfig")]
|
||||
client_config: Option<WsClientConfig>,
|
||||
}
|
||||
|
||||
/// LarkEvent envelope (method=1 / type=event payload)
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkEvent {
|
||||
header: LarkEventHeader,
|
||||
event: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkEventHeader {
|
||||
event_type: String,
|
||||
#[allow(dead_code)]
|
||||
event_id: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct MsgReceivePayload {
|
||||
sender: LarkSender,
|
||||
message: LarkMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkSender {
|
||||
sender_id: LarkSenderId,
|
||||
#[serde(default)]
|
||||
sender_type: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Default)]
|
||||
struct LarkSenderId {
|
||||
open_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, serde::Deserialize)]
|
||||
struct LarkMessage {
|
||||
message_id: String,
|
||||
chat_id: String,
|
||||
chat_type: String,
|
||||
message_type: String,
|
||||
#[serde(default)]
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
mentions: Vec<serde_json::Value>,
|
||||
}
|
||||
|
||||
/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
|
||||
/// If no binary frame (pong or event) is received within this window, reconnect.
|
||||
const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
|
||||
|
||||
/// Lark/Feishu channel.
|
||||
///
|
||||
/// Supports two receive modes (configured via `receive_mode` in config):
|
||||
/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed.
|
||||
/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint.
|
||||
pub struct LarkChannel {
|
||||
app_id: String,
|
||||
app_secret: String,
|
||||
verification_token: String,
|
||||
port: u16,
|
||||
port: Option<u16>,
|
||||
allowed_users: Vec<String>,
|
||||
/// When true, use Feishu (CN) endpoints; when false, use Lark (international).
|
||||
use_feishu: bool,
|
||||
/// How to receive events: WebSocket long-connection or HTTP webhook.
|
||||
receive_mode: crate::config::schema::LarkReceiveMode,
|
||||
client: reqwest::Client,
|
||||
/// Cached tenant access token
|
||||
tenant_token: Arc<RwLock<Option<String>>>,
|
||||
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
|
||||
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||
}
|
||||
|
||||
impl LarkChannel {
|
||||
|
|
@ -23,7 +154,7 @@ impl LarkChannel {
|
|||
app_id: String,
|
||||
app_secret: String,
|
||||
verification_token: String,
|
||||
port: u16,
|
||||
port: Option<u16>,
|
||||
allowed_users: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
|
|
@ -32,11 +163,310 @@ impl LarkChannel {
|
|||
verification_token,
|
||||
port,
|
||||
allowed_users,
|
||||
use_feishu: true,
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::default(),
|
||||
client: reqwest::Client::new(),
|
||||
tenant_token: Arc::new(RwLock::new(None)),
|
||||
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||
}
|
||||
}
|
||||
|
||||
/// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`).
|
||||
pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self {
|
||||
let mut ch = Self::new(
|
||||
config.app_id.clone(),
|
||||
config.app_secret.clone(),
|
||||
config.verification_token.clone().unwrap_or_default(),
|
||||
config.port,
|
||||
config.allowed_users.clone(),
|
||||
);
|
||||
ch.use_feishu = config.use_feishu;
|
||||
ch.receive_mode = config.receive_mode.clone();
|
||||
ch
|
||||
}
|
||||
|
||||
fn api_base(&self) -> &'static str {
|
||||
if self.use_feishu {
|
||||
FEISHU_BASE_URL
|
||||
} else {
|
||||
LARK_BASE_URL
|
||||
}
|
||||
}
|
||||
|
||||
fn ws_base(&self) -> &'static str {
|
||||
if self.use_feishu {
|
||||
FEISHU_WS_BASE_URL
|
||||
} else {
|
||||
LARK_WS_BASE_URL
|
||||
}
|
||||
}
|
||||
|
||||
fn tenant_access_token_url(&self) -> String {
|
||||
format!("{}/auth/v3/tenant_access_token/internal", self.api_base())
|
||||
}
|
||||
|
||||
fn send_message_url(&self) -> String {
|
||||
format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base())
|
||||
}
|
||||
|
||||
/// POST /callback/ws/endpoint → (wss_url, client_config)
|
||||
async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> {
|
||||
let resp = self
|
||||
.client
|
||||
.post(format!("{}/callback/ws/endpoint", self.ws_base()))
|
||||
.header("locale", if self.use_feishu { "zh" } else { "en" })
|
||||
.json(&serde_json::json!({
|
||||
"AppID": self.app_id,
|
||||
"AppSecret": self.app_secret,
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.json::<WsEndpointResp>()
|
||||
.await?;
|
||||
if resp.code != 0 {
|
||||
anyhow::bail!(
|
||||
"Lark WS endpoint failed: code={} msg={}",
|
||||
resp.code,
|
||||
resp.msg.as_deref().unwrap_or("(none)")
|
||||
);
|
||||
}
|
||||
let ep = resp
|
||||
.data
|
||||
.ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?;
|
||||
Ok((ep.url, ep.client_config.unwrap_or_default()))
|
||||
}
|
||||
|
||||
/// WS long-connection event loop. Returns Ok(()) when the connection closes
|
||||
/// (the caller reconnects).
|
||||
#[allow(clippy::too_many_lines)]
|
||||
async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
let (wss_url, client_config) = self.get_ws_endpoint().await?;
|
||||
let service_id = wss_url
|
||||
.split('?')
|
||||
.nth(1)
|
||||
.and_then(|qs| {
|
||||
qs.split('&')
|
||||
.find(|kv| kv.starts_with("service_id="))
|
||||
.and_then(|kv| kv.split('=').nth(1))
|
||||
.and_then(|v| v.parse::<i32>().ok())
|
||||
})
|
||||
.unwrap_or(0);
|
||||
tracing::info!("Lark: connecting to {wss_url}");
|
||||
|
||||
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?;
|
||||
let (mut write, mut read) = ws_stream.split();
|
||||
tracing::info!("Lark: WS connected (service_id={service_id})");
|
||||
|
||||
let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10);
|
||||
let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
|
||||
let mut timeout_check = tokio::time::interval(Duration::from_secs(10));
|
||||
hb_interval.tick().await; // consume immediate tick
|
||||
|
||||
let mut seq: u64 = 0;
|
||||
let mut last_recv = Instant::now();
|
||||
|
||||
// Send initial ping immediately (like the official SDK) so the server
|
||||
// starts responding with pongs and we can calibrate the ping_interval.
|
||||
seq = seq.wrapping_add(1);
|
||||
let initial_ping = PbFrame {
|
||||
seq_id: seq,
|
||||
log_id: 0,
|
||||
service: service_id,
|
||||
method: 0,
|
||||
headers: vec![PbHeader {
|
||||
key: "type".into(),
|
||||
value: "ping".into(),
|
||||
}],
|
||||
payload: None,
|
||||
};
|
||||
if write
|
||||
.send(WsMsg::Binary(initial_ping.encode_to_vec()))
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
anyhow::bail!("Lark: initial ping failed");
|
||||
}
|
||||
// message_id → (fragment_slots, created_at) for multi-part reassembly
|
||||
type FragEntry = (Vec<Option<Vec<u8>>>, Instant);
|
||||
let mut frag_cache: HashMap<String, FragEntry> = HashMap::new();
|
||||
|
||||
loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
|
||||
_ = hb_interval.tick() => {
|
||||
seq = seq.wrapping_add(1);
|
||||
let ping = PbFrame {
|
||||
seq_id: seq, log_id: 0, service: service_id, method: 0,
|
||||
headers: vec![PbHeader { key: "type".into(), value: "ping".into() }],
|
||||
payload: None,
|
||||
};
|
||||
if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() {
|
||||
tracing::warn!("Lark: ping failed, reconnecting");
|
||||
break;
|
||||
}
|
||||
// GC stale fragments > 5 min
|
||||
let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now());
|
||||
frag_cache.retain(|_, (_, ts)| *ts > cutoff);
|
||||
}
|
||||
|
||||
_ = timeout_check.tick() => {
|
||||
if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
|
||||
tracing::warn!("Lark: heartbeat timeout, reconnecting");
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
msg = read.next() => {
|
||||
let raw = match msg {
|
||||
Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b }
|
||||
Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; }
|
||||
Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; }
|
||||
Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; }
|
||||
_ => continue,
|
||||
};
|
||||
|
||||
let frame = match PbFrame::decode(&raw[..]) {
|
||||
Ok(f) => f,
|
||||
Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; }
|
||||
};
|
||||
|
||||
// CONTROL frame
|
||||
if frame.method == 0 {
|
||||
if frame.header_value("type") == "pong" {
|
||||
if let Some(p) = &frame.payload {
|
||||
if let Ok(cfg) = serde_json::from_slice::<WsClientConfig>(p) {
|
||||
if let Some(secs) = cfg.ping_interval {
|
||||
let secs = secs.max(10);
|
||||
if secs != ping_secs {
|
||||
ping_secs = secs;
|
||||
hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
|
||||
tracing::info!("Lark: ping_interval → {ping_secs}s");
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
// DATA frame
|
||||
let msg_type = frame.header_value("type").to_string();
|
||||
let msg_id = frame.header_value("message_id").to_string();
|
||||
let sum = frame.header_value("sum").parse::<usize>().unwrap_or(1);
|
||||
let seq_num = frame.header_value("seq").parse::<usize>().unwrap_or(0);
|
||||
|
||||
// ACK immediately (Feishu requires within 3 s)
|
||||
{
|
||||
let mut ack = frame.clone();
|
||||
ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
|
||||
ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() });
|
||||
let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await;
|
||||
}
|
||||
|
||||
// Fragment reassembly
|
||||
let sum = if sum == 0 { 1 } else { sum };
|
||||
let payload: Vec<u8> = if sum == 1 || msg_id.is_empty() || seq_num >= sum {
|
||||
frame.payload.clone().unwrap_or_default()
|
||||
} else {
|
||||
let entry = frag_cache.entry(msg_id.clone())
|
||||
.or_insert_with(|| (vec![None; sum], Instant::now()));
|
||||
if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); }
|
||||
entry.0[seq_num] = frame.payload.clone();
|
||||
if entry.0.iter().all(|s| s.is_some()) {
|
||||
let full: Vec<u8> = entry.0.iter()
|
||||
.flat_map(|s| s.as_deref().unwrap_or(&[]))
|
||||
.copied().collect();
|
||||
frag_cache.remove(&msg_id);
|
||||
full
|
||||
} else { continue; }
|
||||
};
|
||||
|
||||
if msg_type != "event" { continue; }
|
||||
|
||||
let event: LarkEvent = match serde_json::from_slice(&payload) {
|
||||
Ok(e) => e,
|
||||
Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; }
|
||||
};
|
||||
if event.header.event_type != "im.message.receive_v1" { continue; }
|
||||
|
||||
let recv: MsgReceivePayload = match serde_json::from_value(event.event) {
|
||||
Ok(r) => r,
|
||||
Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; }
|
||||
};
|
||||
|
||||
if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; }
|
||||
|
||||
let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or("");
|
||||
if !self.is_user_allowed(sender_open_id) {
|
||||
tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)");
|
||||
continue;
|
||||
}
|
||||
|
||||
let lark_msg = &recv.message;
|
||||
|
||||
// Dedup
|
||||
{
|
||||
let now = Instant::now();
|
||||
let mut seen = self.ws_seen_ids.write().await;
|
||||
// GC
|
||||
seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
|
||||
if seen.contains_key(&lark_msg.message_id) {
|
||||
tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
|
||||
continue;
|
||||
}
|
||||
seen.insert(lark_msg.message_id.clone(), now);
|
||||
}
|
||||
|
||||
// Decode content by type (mirrors clawdbot-feishu parsing)
|
||||
let text = match lark_msg.message_type.as_str() {
|
||||
"text" => {
|
||||
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||
Ok(v) => v,
|
||||
Err(_) => continue,
|
||||
};
|
||||
match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) {
|
||||
Some(t) => t.to_string(),
|
||||
None => continue,
|
||||
}
|
||||
}
|
||||
"post" => match parse_post_content(&lark_msg.content) {
|
||||
Some(t) => t,
|
||||
None => continue,
|
||||
},
|
||||
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
|
||||
};
|
||||
|
||||
// Strip @_user_N placeholders
|
||||
let text = strip_at_placeholders(&text);
|
||||
let text = text.trim().to_string();
|
||||
if text.is_empty() { continue; }
|
||||
|
||||
// Group-chat: only respond when explicitly @-mentioned
|
||||
if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let channel_msg = ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: lark_msg.chat_id.clone(),
|
||||
reply_target: lark_msg.chat_id.clone(),
|
||||
content: text,
|
||||
channel: "lark".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
tracing::debug!("Lark WS: message in {}", lark_msg.chat_id);
|
||||
if tx.send(channel_msg).await.is_err() { break; }
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Check if a user open_id is allowed
|
||||
fn is_user_allowed(&self, open_id: &str) -> bool {
|
||||
self.allowed_users.iter().any(|u| u == "*" || u == open_id)
|
||||
|
|
@ -52,7 +482,7 @@ impl LarkChannel {
|
|||
}
|
||||
}
|
||||
|
||||
let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal");
|
||||
let url = self.tenant_access_token_url();
|
||||
let body = serde_json::json!({
|
||||
"app_id": self.app_id,
|
||||
"app_secret": self.app_secret,
|
||||
|
|
@ -127,31 +557,41 @@ impl LarkChannel {
|
|||
return messages;
|
||||
}
|
||||
|
||||
// Extract message content (text only)
|
||||
// Extract message content (text and post supported)
|
||||
let msg_type = event
|
||||
.pointer("/message/message_type")
|
||||
.and_then(|t| t.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
if msg_type != "text" {
|
||||
tracing::debug!("Lark: skipping non-text message type: {msg_type}");
|
||||
return messages;
|
||||
}
|
||||
|
||||
let content_str = event
|
||||
.pointer("/message/content")
|
||||
.and_then(|c| c.as_str())
|
||||
.unwrap_or("");
|
||||
|
||||
// content is a JSON string like "{\"text\":\"hello\"}"
|
||||
let text = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
.ok()
|
||||
.and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from))
|
||||
.unwrap_or_default();
|
||||
|
||||
if text.is_empty() {
|
||||
return messages;
|
||||
}
|
||||
let text: String = match msg_type {
|
||||
"text" => {
|
||||
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
|
||||
.ok()
|
||||
.and_then(|v| {
|
||||
v.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
.map(String::from)
|
||||
});
|
||||
match extracted {
|
||||
Some(t) => t,
|
||||
None => return messages,
|
||||
}
|
||||
}
|
||||
"post" => match parse_post_content(content_str) {
|
||||
Some(t) => t,
|
||||
None => return messages,
|
||||
},
|
||||
_ => {
|
||||
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
|
||||
return messages;
|
||||
}
|
||||
};
|
||||
|
||||
let timestamp = event
|
||||
.pointer("/message/create_time")
|
||||
|
|
@ -174,6 +614,7 @@ impl LarkChannel {
|
|||
messages.push(ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
sender: chat_id.to_string(),
|
||||
reply_target: chat_id.to_string(),
|
||||
content: text,
|
||||
channel: "lark".to_string(),
|
||||
timestamp,
|
||||
|
|
@ -191,7 +632,7 @@ impl Channel for LarkChannel {
|
|||
|
||||
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
|
||||
let token = self.get_tenant_access_token().await?;
|
||||
let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id");
|
||||
let url = self.send_message_url();
|
||||
|
||||
let content = serde_json::json!({ "text": message }).to_string();
|
||||
let body = serde_json::json!({
|
||||
|
|
@ -238,6 +679,25 @@ impl Channel for LarkChannel {
|
|||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
use crate::config::schema::LarkReceiveMode;
|
||||
match self.receive_mode {
|
||||
LarkReceiveMode::Websocket => self.listen_ws(tx).await,
|
||||
LarkReceiveMode::Webhook => self.listen_http(tx).await,
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.get_tenant_access_token().await.is_ok()
|
||||
}
|
||||
}
|
||||
|
||||
impl LarkChannel {
|
||||
/// HTTP callback server (legacy — requires a public endpoint).
|
||||
/// Use `listen()` (WS long-connection) for new deployments.
|
||||
pub async fn listen_http(
|
||||
&self,
|
||||
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
||||
) -> anyhow::Result<()> {
|
||||
use axum::{extract::State, routing::post, Json, Router};
|
||||
|
||||
#[derive(Clone)]
|
||||
|
|
@ -282,13 +742,17 @@ impl Channel for LarkChannel {
|
|||
(StatusCode::OK, "ok").into_response()
|
||||
}
|
||||
|
||||
let port = self.port.ok_or_else(|| {
|
||||
anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]")
|
||||
})?;
|
||||
|
||||
let state = AppState {
|
||||
verification_token: self.verification_token.clone(),
|
||||
channel: Arc::new(LarkChannel::new(
|
||||
self.app_id.clone(),
|
||||
self.app_secret.clone(),
|
||||
self.verification_token.clone(),
|
||||
self.port,
|
||||
None,
|
||||
self.allowed_users.clone(),
|
||||
)),
|
||||
tx,
|
||||
|
|
@ -298,7 +762,7 @@ impl Channel for LarkChannel {
|
|||
.route("/lark", post(handle_event))
|
||||
.with_state(state);
|
||||
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
|
||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||
tracing::info!("Lark event callback server listening on {addr}");
|
||||
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
|
|
@ -306,10 +770,110 @@ impl Channel for LarkChannel {
|
|||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
async fn health_check(&self) -> bool {
|
||||
self.get_tenant_access_token().await.is_ok()
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
// WS helper functions
|
||||
// ─────────────────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Flatten a Feishu `post` rich-text message to plain text.
|
||||
///
|
||||
/// Returns `None` when the content cannot be parsed or yields no usable text,
|
||||
/// so callers can simply `continue` rather than forwarding a meaningless
|
||||
/// placeholder string to the agent.
|
||||
fn parse_post_content(content: &str) -> Option<String> {
|
||||
let parsed = serde_json::from_str::<serde_json::Value>(content).ok()?;
|
||||
let locale = parsed
|
||||
.get("zh_cn")
|
||||
.or_else(|| parsed.get("en_us"))
|
||||
.or_else(|| {
|
||||
parsed
|
||||
.as_object()
|
||||
.and_then(|m| m.values().find(|v| v.is_object()))
|
||||
})?;
|
||||
|
||||
let mut text = String::new();
|
||||
|
||||
if let Some(title) = locale
|
||||
.get("title")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
{
|
||||
text.push_str(title);
|
||||
text.push_str("\n\n");
|
||||
}
|
||||
|
||||
if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) {
|
||||
for para in paragraphs {
|
||||
if let Some(elements) = para.as_array() {
|
||||
for el in elements {
|
||||
match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") {
|
||||
"text" => {
|
||||
if let Some(t) = el.get("text").and_then(|t| t.as_str()) {
|
||||
text.push_str(t);
|
||||
}
|
||||
}
|
||||
"a" => {
|
||||
text.push_str(
|
||||
el.get("text")
|
||||
.and_then(|t| t.as_str())
|
||||
.filter(|s| !s.is_empty())
|
||||
.or_else(|| el.get("href").and_then(|h| h.as_str()))
|
||||
.unwrap_or(""),
|
||||
);
|
||||
}
|
||||
"at" => {
|
||||
let n = el
|
||||
.get("user_name")
|
||||
.and_then(|n| n.as_str())
|
||||
.or_else(|| el.get("user_id").and_then(|i| i.as_str()))
|
||||
.unwrap_or("user");
|
||||
text.push('@');
|
||||
text.push_str(n);
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
text.push('\n');
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let result = text.trim().to_string();
|
||||
if result.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(result)
|
||||
}
|
||||
}
|
||||
|
||||
/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats.
|
||||
fn strip_at_placeholders(text: &str) -> String {
|
||||
let mut result = String::with_capacity(text.len());
|
||||
let mut chars = text.char_indices().peekable();
|
||||
while let Some((_, ch)) = chars.next() {
|
||||
if ch == '@' {
|
||||
let rest: String = chars.clone().map(|(_, c)| c).collect();
|
||||
if let Some(after) = rest.strip_prefix("_user_") {
|
||||
let skip =
|
||||
"_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count();
|
||||
for _ in 0..=skip {
|
||||
chars.next();
|
||||
}
|
||||
if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) {
|
||||
chars.next();
|
||||
}
|
||||
continue;
|
||||
}
|
||||
}
|
||||
result.push(ch);
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
/// In group chats, only respond when the bot is explicitly @-mentioned.
|
||||
fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool {
|
||||
!mentions.is_empty()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
@ -321,7 +885,7 @@ mod tests {
|
|||
"cli_test_app_id".into(),
|
||||
"test_app_secret".into(),
|
||||
"test_verification_token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["ou_testuser123".into()],
|
||||
)
|
||||
}
|
||||
|
|
@ -345,7 +909,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
assert!(ch.is_user_allowed("ou_anyone"));
|
||||
|
|
@ -353,7 +917,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_user_denied_empty() {
|
||||
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]);
|
||||
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]);
|
||||
assert!(!ch.is_user_allowed("ou_anyone"));
|
||||
}
|
||||
|
||||
|
|
@ -426,7 +990,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -451,7 +1015,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -488,7 +1052,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -512,7 +1076,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -550,7 +1114,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
@ -571,7 +1135,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_config_serde() {
|
||||
use crate::config::schema::LarkConfig;
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
let lc = LarkConfig {
|
||||
app_id: "cli_app123".into(),
|
||||
app_secret: "secret456".into(),
|
||||
|
|
@ -579,6 +1143,8 @@ mod tests {
|
|||
verification_token: Some("vtoken789".into()),
|
||||
allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::default(),
|
||||
port: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -590,7 +1156,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_config_toml_roundtrip() {
|
||||
use crate::config::schema::LarkConfig;
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
let lc = LarkConfig {
|
||||
app_id: "app".into(),
|
||||
app_secret: "secret".into(),
|
||||
|
|
@ -598,6 +1164,8 @@ mod tests {
|
|||
verification_token: Some("tok".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
|
|
@ -608,11 +1176,36 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn lark_config_defaults_optional_fields() {
|
||||
use crate::config::schema::LarkConfig;
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
let json = r#"{"app_id":"a","app_secret":"s"}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.verification_token.is_none());
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket);
|
||||
assert!(parsed.port.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_from_config_preserves_mode_and_region() {
|
||||
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||
|
||||
let cfg = LarkConfig {
|
||||
app_id: "cli_app123".into(),
|
||||
app_secret: "secret456".into(),
|
||||
encrypt_key: None,
|
||||
verification_token: Some("vtoken789".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
};
|
||||
|
||||
let ch = LarkChannel::from_config(&cfg);
|
||||
|
||||
assert_eq!(ch.api_base(), LARK_BASE_URL);
|
||||
assert_eq!(ch.ws_base(), LARK_WS_BASE_URL);
|
||||
assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook);
|
||||
assert_eq!(ch.port, Some(9898));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -622,7 +1215,7 @@ mod tests {
|
|||
"id".into(),
|
||||
"secret".into(),
|
||||
"token".into(),
|
||||
9898,
|
||||
None,
|
||||
vec!["*".into()],
|
||||
);
|
||||
let payload = serde_json::json!({
|
||||
|
|
|
|||
|
|
@ -230,6 +230,7 @@ impl Channel for MatrixChannel {
|
|||
let msg = ChannelMessage {
|
||||
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
|
||||
sender: event.sender.clone(),
|
||||
reply_target: event.sender.clone(),
|
||||
content: body.clone(),
|
||||
channel: "matrix".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
|
|||
|
|
@ -69,10 +69,19 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
|||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||||
}
|
||||
|
||||
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
||||
match channel_name {
|
||||
"telegram" => Some(
|
||||
"When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.",
|
||||
),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||
let mut context = String::new();
|
||||
|
||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
||||
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||
if !entries.is_empty() {
|
||||
context.push_str("[Memory context]\n");
|
||||
for entry in &entries {
|
||||
|
|
@ -158,6 +167,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
&autosave_key,
|
||||
&msg.content,
|
||||
crate::memory::MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -171,7 +181,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
if let Err(e) = channel.start_typing(&msg.sender).await {
|
||||
if let Err(e) = channel.start_typing(&msg.reply_target).await {
|
||||
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
||||
}
|
||||
}
|
||||
|
|
@ -184,6 +194,10 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
ChatMessage::user(&enriched_message),
|
||||
];
|
||||
|
||||
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
|
||||
history.push(ChatMessage::system(instructions));
|
||||
}
|
||||
|
||||
let llm_result = tokio::time::timeout(
|
||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||
run_tool_call_loop(
|
||||
|
|
@ -200,7 +214,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
.await;
|
||||
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
if let Err(e) = channel.stop_typing(&msg.sender).await {
|
||||
if let Err(e) = channel.stop_typing(&msg.reply_target).await {
|
||||
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
||||
}
|
||||
}
|
||||
|
|
@ -224,7 +238,9 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
started_at.elapsed().as_millis()
|
||||
);
|
||||
if let Some(channel) = target_channel.as_ref() {
|
||||
let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
|
||||
let _ = channel
|
||||
.send(&format!("⚠️ Error: {e}"), &msg.reply_target)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
|
|
@ -241,7 +257,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
|||
let _ = channel
|
||||
.send(
|
||||
"⚠️ Request timed out while waiting for the model. Please try again.",
|
||||
&msg.sender,
|
||||
&msg.reply_target,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -483,6 +499,16 @@ pub fn build_system_prompt(
|
|||
std::env::consts::OS,
|
||||
);
|
||||
|
||||
// ── 8. Channel Capabilities ─────────────────────────────────────
|
||||
prompt.push_str("## Channel Capabilities\n\n");
|
||||
prompt.push_str(
|
||||
"- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n",
|
||||
);
|
||||
prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n");
|
||||
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||||
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||||
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
|
||||
|
||||
if prompt.is_empty() {
|
||||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
|
||||
} else {
|
||||
|
|
@ -619,6 +645,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
|||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)),
|
||||
));
|
||||
}
|
||||
|
|
@ -672,32 +699,23 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
|||
if let Some(ref irc) = config.channels_config.irc {
|
||||
channels.push((
|
||||
"IRC",
|
||||
Arc::new(IrcChannel::new(
|
||||
irc.server.clone(),
|
||||
irc.port,
|
||||
irc.nickname.clone(),
|
||||
irc.username.clone(),
|
||||
irc.channels.clone(),
|
||||
irc.allowed_users.clone(),
|
||||
irc.server_password.clone(),
|
||||
irc.nickserv_password.clone(),
|
||||
irc.sasl_password.clone(),
|
||||
irc.verify_tls.unwrap_or(true),
|
||||
)),
|
||||
Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||||
server: irc.server.clone(),
|
||||
port: irc.port,
|
||||
nickname: irc.nickname.clone(),
|
||||
username: irc.username.clone(),
|
||||
channels: irc.channels.clone(),
|
||||
allowed_users: irc.allowed_users.clone(),
|
||||
server_password: irc.server_password.clone(),
|
||||
nickserv_password: irc.nickserv_password.clone(),
|
||||
sasl_password: irc.sasl_password.clone(),
|
||||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||||
})),
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(ref lk) = config.channels_config.lark {
|
||||
channels.push((
|
||||
"Lark",
|
||||
Arc::new(LarkChannel::new(
|
||||
lk.app_id.clone(),
|
||||
lk.app_secret.clone(),
|
||||
lk.verification_token.clone().unwrap_or_default(),
|
||||
9898,
|
||||
lk.allowed_users.clone(),
|
||||
)),
|
||||
));
|
||||
channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
|
||||
}
|
||||
|
||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||
|
|
@ -762,6 +780,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||
&provider_name,
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
)?);
|
||||
|
||||
|
|
@ -860,6 +879,10 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
"schedule",
|
||||
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
|
||||
));
|
||||
tool_descs.push((
|
||||
"pushover",
|
||||
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
|
||||
));
|
||||
if !config.agents.is_empty() {
|
||||
tool_descs.push((
|
||||
"delegate",
|
||||
|
|
@ -909,6 +932,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
@ -947,28 +971,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
|||
}
|
||||
|
||||
if let Some(ref irc) = config.channels_config.irc {
|
||||
channels.push(Arc::new(IrcChannel::new(
|
||||
irc.server.clone(),
|
||||
irc.port,
|
||||
irc.nickname.clone(),
|
||||
irc.username.clone(),
|
||||
irc.channels.clone(),
|
||||
irc.allowed_users.clone(),
|
||||
irc.server_password.clone(),
|
||||
irc.nickserv_password.clone(),
|
||||
irc.sasl_password.clone(),
|
||||
irc.verify_tls.unwrap_or(true),
|
||||
)));
|
||||
channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||||
server: irc.server.clone(),
|
||||
port: irc.port,
|
||||
nickname: irc.nickname.clone(),
|
||||
username: irc.username.clone(),
|
||||
channels: irc.channels.clone(),
|
||||
allowed_users: irc.allowed_users.clone(),
|
||||
server_password: irc.server_password.clone(),
|
||||
nickserv_password: irc.nickserv_password.clone(),
|
||||
sasl_password: irc.sasl_password.clone(),
|
||||
verify_tls: irc.verify_tls.unwrap_or(true),
|
||||
})));
|
||||
}
|
||||
|
||||
if let Some(ref lk) = config.channels_config.lark {
|
||||
channels.push(Arc::new(LarkChannel::new(
|
||||
lk.app_id.clone(),
|
||||
lk.app_secret.clone(),
|
||||
lk.verification_token.clone().unwrap_or_default(),
|
||||
9898,
|
||||
lk.allowed_users.clone(),
|
||||
)));
|
||||
channels.push(Arc::new(LarkChannel::from_config(lk)));
|
||||
}
|
||||
|
||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||
|
|
@ -1242,6 +1260,7 @@ mod tests {
|
|||
traits::ChannelMessage {
|
||||
id: "msg-1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "chat-42".to_string(),
|
||||
content: "What is the BTC price now?".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1251,6 +1270,7 @@ mod tests {
|
|||
|
||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||
assert_eq!(sent_messages.len(), 1);
|
||||
assert!(sent_messages[0].starts_with("chat-42:"));
|
||||
assert!(sent_messages[0].contains("BTC is currently around"));
|
||||
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
||||
assert!(!sent_messages[0].contains("mock_price"));
|
||||
|
|
@ -1269,6 +1289,7 @@ mod tests {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: crate::memory::MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
|
@ -1277,6 +1298,7 @@ mod tests {
|
|||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -1288,6 +1310,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&crate::memory::MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -1331,6 +1354,7 @@ mod tests {
|
|||
tx.send(traits::ChannelMessage {
|
||||
id: "1".to_string(),
|
||||
sender: "alice".to_string(),
|
||||
reply_target: "alice".to_string(),
|
||||
content: "hello".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1340,6 +1364,7 @@ mod tests {
|
|||
tx.send(traits::ChannelMessage {
|
||||
id: "2".to_string(),
|
||||
sender: "bob".to_string(),
|
||||
reply_target: "bob".to_string(),
|
||||
content: "world".to_string(),
|
||||
channel: "test-channel".to_string(),
|
||||
timestamp: 2,
|
||||
|
|
@ -1570,6 +1595,25 @@ mod tests {
|
|||
assert!(truncated.is_char_boundary(truncated.len()));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_contains_channel_capabilities() {
|
||||
let ws = make_workspace();
|
||||
let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);
|
||||
|
||||
assert!(
|
||||
prompt.contains("## Channel Capabilities"),
|
||||
"missing Channel Capabilities section"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("running as a Discord bot"),
|
||||
"missing Discord context"
|
||||
);
|
||||
assert!(
|
||||
prompt.contains("NEVER repeat, describe, or echo credentials"),
|
||||
"missing security instruction"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn prompt_workspace_path() {
|
||||
let ws = make_workspace();
|
||||
|
|
@ -1583,6 +1627,7 @@ mod tests {
|
|||
let msg = traits::ChannelMessage {
|
||||
id: "msg_abc123".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "hello".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1596,6 +1641,7 @@ mod tests {
|
|||
let msg1 = traits::ChannelMessage {
|
||||
id: "msg_1".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "first".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1603,6 +1649,7 @@ mod tests {
|
|||
let msg2 = traits::ChannelMessage {
|
||||
id: "msg_2".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "second".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 2,
|
||||
|
|
@ -1622,6 +1669,7 @@ mod tests {
|
|||
let msg1 = traits::ChannelMessage {
|
||||
id: "msg_1".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "I'm Paul".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -1629,6 +1677,7 @@ mod tests {
|
|||
let msg2 = traits::ChannelMessage {
|
||||
id: "msg_2".into(),
|
||||
sender: "U123".into(),
|
||||
reply_target: "C456".into(),
|
||||
content: "I'm 45".into(),
|
||||
channel: "slack".into(),
|
||||
timestamp: 2,
|
||||
|
|
@ -1638,6 +1687,7 @@ mod tests {
|
|||
&conversation_memory_key(&msg1),
|
||||
&msg1.content,
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -1645,13 +1695,14 @@ mod tests {
|
|||
&conversation_memory_key(&msg2),
|
||||
&msg2.content,
|
||||
MemoryCategory::Conversation,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(mem.count().await.unwrap(), 2);
|
||||
|
||||
let recalled = mem.recall("45", 5).await.unwrap();
|
||||
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||
}
|
||||
|
||||
|
|
@ -1659,7 +1710,7 @@ mod tests {
|
|||
async fn build_memory_context_includes_recalled_entries() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation)
|
||||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -161,6 +161,7 @@ impl Channel for SlackChannel {
|
|||
let channel_msg = ChannelMessage {
|
||||
id: format!("slack_{channel_id}_{ts}"),
|
||||
sender: user.to_string(),
|
||||
reply_target: channel_id.clone(),
|
||||
content: text.to_string(),
|
||||
channel: "slack".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
|
|
|
|||
|
|
@ -51,6 +51,133 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
|
|||
chunks
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
enum TelegramAttachmentKind {
|
||||
Image,
|
||||
Document,
|
||||
Video,
|
||||
Audio,
|
||||
Voice,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
struct TelegramAttachment {
|
||||
kind: TelegramAttachmentKind,
|
||||
target: String,
|
||||
}
|
||||
|
||||
impl TelegramAttachmentKind {
|
||||
fn from_marker(marker: &str) -> Option<Self> {
|
||||
match marker.trim().to_ascii_uppercase().as_str() {
|
||||
"IMAGE" | "PHOTO" => Some(Self::Image),
|
||||
"DOCUMENT" | "FILE" => Some(Self::Document),
|
||||
"VIDEO" => Some(Self::Video),
|
||||
"AUDIO" => Some(Self::Audio),
|
||||
"VOICE" => Some(Self::Voice),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn is_http_url(target: &str) -> bool {
|
||||
target.starts_with("http://") || target.starts_with("https://")
|
||||
}
|
||||
|
||||
fn infer_attachment_kind_from_target(target: &str) -> Option<TelegramAttachmentKind> {
|
||||
let normalized = target
|
||||
.split('?')
|
||||
.next()
|
||||
.unwrap_or(target)
|
||||
.split('#')
|
||||
.next()
|
||||
.unwrap_or(target);
|
||||
|
||||
let extension = Path::new(normalized)
|
||||
.extension()
|
||||
.and_then(|ext| ext.to_str())?
|
||||
.to_ascii_lowercase();
|
||||
|
||||
match extension.as_str() {
|
||||
"png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => Some(TelegramAttachmentKind::Image),
|
||||
"mp4" | "mov" | "mkv" | "avi" | "webm" => Some(TelegramAttachmentKind::Video),
|
||||
"mp3" | "m4a" | "wav" | "flac" => Some(TelegramAttachmentKind::Audio),
|
||||
"ogg" | "oga" | "opus" => Some(TelegramAttachmentKind::Voice),
|
||||
"pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx" | "xls"
|
||||
| "xlsx" | "ppt" | "pptx" => Some(TelegramAttachmentKind::Document),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_path_only_attachment(message: &str) -> Option<TelegramAttachment> {
|
||||
let trimmed = message.trim();
|
||||
if trimmed.is_empty() || trimmed.contains('\n') {
|
||||
return None;
|
||||
}
|
||||
|
||||
let candidate = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\''));
|
||||
if candidate.chars().any(char::is_whitespace) {
|
||||
return None;
|
||||
}
|
||||
|
||||
let candidate = candidate.strip_prefix("file://").unwrap_or(candidate);
|
||||
let kind = infer_attachment_kind_from_target(candidate)?;
|
||||
|
||||
if !is_http_url(candidate) && !Path::new(candidate).exists() {
|
||||
return None;
|
||||
}
|
||||
|
||||
Some(TelegramAttachment {
|
||||
kind,
|
||||
target: candidate.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_attachment_markers(message: &str) -> (String, Vec<TelegramAttachment>) {
|
||||
let mut cleaned = String::with_capacity(message.len());
|
||||
let mut attachments = Vec::new();
|
||||
let mut cursor = 0;
|
||||
|
||||
while cursor < message.len() {
|
||||
let Some(open_rel) = message[cursor..].find('[') else {
|
||||
cleaned.push_str(&message[cursor..]);
|
||||
break;
|
||||
};
|
||||
|
||||
let open = cursor + open_rel;
|
||||
cleaned.push_str(&message[cursor..open]);
|
||||
|
||||
let Some(close_rel) = message[open..].find(']') else {
|
||||
cleaned.push_str(&message[open..]);
|
||||
break;
|
||||
};
|
||||
|
||||
let close = open + close_rel;
|
||||
let marker = &message[open + 1..close];
|
||||
|
||||
let parsed = marker.split_once(':').and_then(|(kind, target)| {
|
||||
let kind = TelegramAttachmentKind::from_marker(kind)?;
|
||||
let target = target.trim();
|
||||
if target.is_empty() {
|
||||
return None;
|
||||
}
|
||||
Some(TelegramAttachment {
|
||||
kind,
|
||||
target: target.to_string(),
|
||||
})
|
||||
});
|
||||
|
||||
if let Some(attachment) = parsed {
|
||||
attachments.push(attachment);
|
||||
} else {
|
||||
cleaned.push_str(&message[open..=close]);
|
||||
}
|
||||
|
||||
cursor = close + 1;
|
||||
}
|
||||
|
||||
(cleaned.trim().to_string(), attachments)
|
||||
}
|
||||
|
||||
/// Telegram channel — long-polls the Bot API for updates
|
||||
pub struct TelegramChannel {
|
||||
bot_token: String,
|
||||
|
|
@ -82,6 +209,216 @@ impl TelegramChannel {
|
|||
identities.into_iter().any(|id| self.is_user_allowed(id))
|
||||
}
|
||||
|
||||
fn parse_update_message(&self, update: &serde_json::Value) -> Option<ChannelMessage> {
|
||||
let message = update.get("message")?;
|
||||
|
||||
let text = message.get("text").and_then(serde_json::Value::as_str)?;
|
||||
|
||||
let username = message
|
||||
.get("from")
|
||||
.and_then(|from| from.get("username"))
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
|
||||
let user_id = message
|
||||
.get("from")
|
||||
.and_then(|from| from.get("id"))
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.map(|id| id.to_string());
|
||||
|
||||
let sender_identity = if username == "unknown" {
|
||||
user_id.clone().unwrap_or_else(|| "unknown".to_string())
|
||||
} else {
|
||||
username.clone()
|
||||
};
|
||||
|
||||
let mut identities = vec![username.as_str()];
|
||||
if let Some(id) = user_id.as_deref() {
|
||||
identities.push(id);
|
||||
}
|
||||
|
||||
if !self.is_any_user_allowed(identities.iter().copied()) {
|
||||
tracing::warn!(
|
||||
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
|
||||
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
|
||||
user_id.as_deref().unwrap_or("unknown")
|
||||
);
|
||||
return None;
|
||||
}
|
||||
|
||||
let chat_id = message
|
||||
.get("chat")
|
||||
.and_then(|chat| chat.get("id"))
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.map(|id| id.to_string())?;
|
||||
|
||||
let message_id = message
|
||||
.get("message_id")
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.unwrap_or(0);
|
||||
|
||||
Some(ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: sender_identity,
|
||||
reply_target: chat_id,
|
||||
content: text.to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
})
|
||||
}
|
||||
|
||||
async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
||||
let chunks = split_message_for_telegram(message);
|
||||
|
||||
for (index, chunk) in chunks.iter().enumerate() {
|
||||
let text = if chunks.len() > 1 {
|
||||
if index == 0 {
|
||||
format!("{chunk}\n\n(continues...)")
|
||||
} else if index == chunks.len() - 1 {
|
||||
format!("(continued)\n\n{chunk}")
|
||||
} else {
|
||||
format!("(continued)\n\n{chunk}\n\n(continues...)")
|
||||
}
|
||||
} else {
|
||||
chunk.to_string()
|
||||
};
|
||||
|
||||
let markdown_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
"parse_mode": "Markdown"
|
||||
});
|
||||
|
||||
let markdown_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&markdown_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if markdown_resp.status().is_success() {
|
||||
if index < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
let markdown_status = markdown_resp.status();
|
||||
let markdown_err = markdown_resp.text().await.unwrap_or_default();
|
||||
tracing::warn!(
|
||||
status = ?markdown_status,
|
||||
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
|
||||
);
|
||||
|
||||
let plain_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
});
|
||||
let plain_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&plain_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !plain_resp.status().is_success() {
|
||||
let plain_status = plain_resp.status();
|
||||
let plain_err = plain_resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
|
||||
markdown_status,
|
||||
markdown_err,
|
||||
plain_status,
|
||||
plain_err
|
||||
);
|
||||
}
|
||||
|
||||
if index < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_media_by_url(
|
||||
&self,
|
||||
method: &str,
|
||||
media_field: &str,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let mut body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
});
|
||||
body[media_field] = serde_json::Value::String(url.to_string());
|
||||
|
||||
if let Some(cap) = caption {
|
||||
body["caption"] = serde_json::Value::String(cap.to_string());
|
||||
}
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
.post(self.api_url(method))
|
||||
.json(&body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !resp.status().is_success() {
|
||||
let err = resp.text().await?;
|
||||
anyhow::bail!("Telegram {method} by URL failed: {err}");
|
||||
}
|
||||
|
||||
tracing::info!("Telegram {method} sent to {chat_id}: {url}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn send_attachment(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
attachment: &TelegramAttachment,
|
||||
) -> anyhow::Result<()> {
|
||||
let target = attachment.target.trim();
|
||||
|
||||
if is_http_url(target) {
|
||||
return match attachment.kind {
|
||||
TelegramAttachmentKind::Image => {
|
||||
self.send_photo_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Document => {
|
||||
self.send_document_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Video => {
|
||||
self.send_video_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Audio => {
|
||||
self.send_audio_by_url(chat_id, target, None).await
|
||||
}
|
||||
TelegramAttachmentKind::Voice => {
|
||||
self.send_voice_by_url(chat_id, target, None).await
|
||||
}
|
||||
};
|
||||
}
|
||||
|
||||
let path = Path::new(target);
|
||||
if !path.exists() {
|
||||
anyhow::bail!("Telegram attachment path not found: {target}");
|
||||
}
|
||||
|
||||
match attachment.kind {
|
||||
TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await,
|
||||
TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await,
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a document/file to a Telegram chat
|
||||
pub async fn send_document(
|
||||
&self,
|
||||
|
|
@ -408,6 +745,39 @@ impl TelegramChannel {
|
|||
tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}");
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Send a video by URL (Telegram will download it)
|
||||
pub async fn send_video_by_url(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_media_by_url("sendVideo", "video", chat_id, url, caption)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send an audio file by URL (Telegram will download it)
|
||||
pub async fn send_audio_by_url(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_media_by_url("sendAudio", "audio", chat_id, url, caption)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Send a voice message by URL (Telegram will download it)
|
||||
pub async fn send_voice_by_url(
|
||||
&self,
|
||||
chat_id: &str,
|
||||
url: &str,
|
||||
caption: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.send_media_by_url("sendVoice", "voice", chat_id, url, caption)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -417,82 +787,27 @@ impl Channel for TelegramChannel {
|
|||
}
|
||||
|
||||
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
||||
// Split message if it exceeds Telegram's 4096 character limit
|
||||
let chunks = split_message_for_telegram(message);
|
||||
let (text_without_markers, attachments) = parse_attachment_markers(message);
|
||||
|
||||
for (i, chunk) in chunks.iter().enumerate() {
|
||||
// Add continuation marker for multi-part messages
|
||||
let text = if chunks.len() > 1 {
|
||||
if i == 0 {
|
||||
format!("{chunk}\n\n(continues...)")
|
||||
} else if i == chunks.len() - 1 {
|
||||
format!("(continued)\n\n{chunk}")
|
||||
} else {
|
||||
format!("(continued)\n\n{chunk}\n\n(continues...)")
|
||||
}
|
||||
} else {
|
||||
chunk.to_string()
|
||||
};
|
||||
|
||||
let markdown_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
"parse_mode": "Markdown"
|
||||
});
|
||||
|
||||
let markdown_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&markdown_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if markdown_resp.status().is_success() {
|
||||
// Small delay between chunks to avoid rate limiting
|
||||
if i < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
continue;
|
||||
if !attachments.is_empty() {
|
||||
if !text_without_markers.is_empty() {
|
||||
self.send_text_chunks(&text_without_markers, chat_id)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let markdown_status = markdown_resp.status();
|
||||
let markdown_err = markdown_resp.text().await.unwrap_or_default();
|
||||
tracing::warn!(
|
||||
status = ?markdown_status,
|
||||
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
|
||||
);
|
||||
|
||||
// Retry without parse_mode as a compatibility fallback.
|
||||
let plain_body = serde_json::json!({
|
||||
"chat_id": chat_id,
|
||||
"text": text,
|
||||
});
|
||||
let plain_resp = self
|
||||
.client
|
||||
.post(self.api_url("sendMessage"))
|
||||
.json(&plain_body)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
if !plain_resp.status().is_success() {
|
||||
let plain_status = plain_resp.status();
|
||||
let plain_err = plain_resp.text().await.unwrap_or_default();
|
||||
anyhow::bail!(
|
||||
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
|
||||
markdown_status,
|
||||
markdown_err,
|
||||
plain_status,
|
||||
plain_err
|
||||
);
|
||||
for attachment in &attachments {
|
||||
self.send_attachment(chat_id, attachment).await?;
|
||||
}
|
||||
|
||||
// Small delay between chunks to avoid rate limiting
|
||||
if i < chunks.len() - 1 {
|
||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
Ok(())
|
||||
if let Some(attachment) = parse_path_only_attachment(message) {
|
||||
self.send_attachment(chat_id, &attachment).await?;
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
self.send_text_chunks(message, chat_id).await
|
||||
}
|
||||
|
||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||
|
|
@ -533,59 +848,13 @@ impl Channel for TelegramChannel {
|
|||
offset = uid + 1;
|
||||
}
|
||||
|
||||
let Some(message) = update.get("message") else {
|
||||
let Some(msg) = self.parse_update_message(update) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let username_opt = message
|
||||
.get("from")
|
||||
.and_then(|f| f.get("username"))
|
||||
.and_then(|u| u.as_str());
|
||||
let username = username_opt.unwrap_or("unknown");
|
||||
|
||||
let user_id = message
|
||||
.get("from")
|
||||
.and_then(|f| f.get("id"))
|
||||
.and_then(serde_json::Value::as_i64);
|
||||
let user_id_str = user_id.map(|id| id.to_string());
|
||||
|
||||
let mut identities = vec![username];
|
||||
if let Some(ref id) = user_id_str {
|
||||
identities.push(id.as_str());
|
||||
}
|
||||
|
||||
if !self.is_any_user_allowed(identities.iter().copied()) {
|
||||
tracing::warn!(
|
||||
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
|
||||
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
|
||||
user_id_str.as_deref().unwrap_or("unknown")
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
let chat_id = message
|
||||
.get("chat")
|
||||
.and_then(|c| c.get("id"))
|
||||
.and_then(serde_json::Value::as_i64)
|
||||
.map(|id| id.to_string());
|
||||
|
||||
let Some(chat_id) = chat_id else {
|
||||
tracing::warn!("Telegram: missing chat_id in message, skipping");
|
||||
continue;
|
||||
};
|
||||
|
||||
let message_id = message
|
||||
.get("message_id")
|
||||
.and_then(|v| v.as_i64())
|
||||
.unwrap_or(0);
|
||||
|
||||
// Send "typing" indicator immediately when we receive a message
|
||||
let typing_body = serde_json::json!({
|
||||
"chat_id": &chat_id,
|
||||
"chat_id": &msg.reply_target,
|
||||
"action": "typing"
|
||||
});
|
||||
let _ = self
|
||||
|
|
@ -595,17 +864,6 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch
|
|||
.send()
|
||||
.await; // Ignore errors for typing indicator
|
||||
|
||||
let msg = ChannelMessage {
|
||||
id: format!("telegram_{chat_id}_{message_id}"),
|
||||
sender: username.to_string(),
|
||||
content: text.to_string(),
|
||||
channel: "telegram".to_string(),
|
||||
timestamp: std::time::SystemTime::now()
|
||||
.duration_since(std::time::UNIX_EPOCH)
|
||||
.unwrap_or_default()
|
||||
.as_secs(),
|
||||
};
|
||||
|
||||
if tx.send(msg).await.is_err() {
|
||||
return Ok(());
|
||||
}
|
||||
|
|
@ -716,6 +974,107 @@ mod tests {
|
|||
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_attachment_markers_extracts_multiple_types() {
|
||||
let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]";
|
||||
let (cleaned, attachments) = parse_attachment_markers(message);
|
||||
|
||||
assert_eq!(cleaned, "Here are files and");
|
||||
assert_eq!(attachments.len(), 2);
|
||||
assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image);
|
||||
assert_eq!(attachments[0].target, "/tmp/a.png");
|
||||
assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document);
|
||||
assert_eq!(attachments[1].target, "https://example.com/a.pdf");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_attachment_markers_keeps_invalid_markers_in_text() {
|
||||
let message = "Report [UNKNOWN:/tmp/a.bin]";
|
||||
let (cleaned, attachments) = parse_attachment_markers(message);
|
||||
|
||||
assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]");
|
||||
assert!(attachments.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_path_only_attachment_detects_existing_file() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let image_path = dir.path().join("snap.png");
|
||||
std::fs::write(&image_path, b"fake-png").unwrap();
|
||||
|
||||
let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref())
|
||||
.expect("expected attachment");
|
||||
|
||||
assert_eq!(parsed.kind, TelegramAttachmentKind::Image);
|
||||
assert_eq!(parsed.target, image_path.to_string_lossy());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_path_only_attachment_rejects_sentence_text() {
|
||||
assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn infer_attachment_kind_from_target_detects_document_extension() {
|
||||
assert_eq!(
|
||||
infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"),
|
||||
Some(TelegramAttachmentKind::Document)
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_uses_chat_id_as_reply_target() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["*".into()]);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 1,
|
||||
"message": {
|
||||
"message_id": 33,
|
||||
"text": "hello",
|
||||
"from": {
|
||||
"id": 555,
|
||||
"username": "alice"
|
||||
},
|
||||
"chat": {
|
||||
"id": -100200300
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("message should parse");
|
||||
|
||||
assert_eq!(msg.sender, "alice");
|
||||
assert_eq!(msg.reply_target, "-100200300");
|
||||
assert_eq!(msg.content, "hello");
|
||||
assert_eq!(msg.id, "telegram_-100200300_33");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn parse_update_message_allows_numeric_id_without_username() {
|
||||
let ch = TelegramChannel::new("token".into(), vec!["555".into()]);
|
||||
let update = serde_json::json!({
|
||||
"update_id": 2,
|
||||
"message": {
|
||||
"message_id": 9,
|
||||
"text": "ping",
|
||||
"from": {
|
||||
"id": 555
|
||||
},
|
||||
"chat": {
|
||||
"id": 12345
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let msg = ch
|
||||
.parse_update_message(&update)
|
||||
.expect("numeric allowlist should pass");
|
||||
|
||||
assert_eq!(msg.sender, "555");
|
||||
assert_eq!(msg.reply_target, "12345");
|
||||
}
|
||||
|
||||
// ── File sending API URL tests ──────────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ use async_trait::async_trait;
|
|||
pub struct ChannelMessage {
|
||||
pub id: String,
|
||||
pub sender: String,
|
||||
pub reply_target: String,
|
||||
pub content: String,
|
||||
pub channel: String,
|
||||
pub timestamp: u64,
|
||||
|
|
@ -62,6 +63,7 @@ mod tests {
|
|||
tx.send(ChannelMessage {
|
||||
id: "1".into(),
|
||||
sender: "tester".into(),
|
||||
reply_target: "tester".into(),
|
||||
content: "hello".into(),
|
||||
channel: "dummy".into(),
|
||||
timestamp: 123,
|
||||
|
|
@ -76,6 +78,7 @@ mod tests {
|
|||
let message = ChannelMessage {
|
||||
id: "42".into(),
|
||||
sender: "alice".into(),
|
||||
reply_target: "alice".into(),
|
||||
content: "ping".into(),
|
||||
channel: "dummy".into(),
|
||||
timestamp: 999,
|
||||
|
|
@ -84,6 +87,7 @@ mod tests {
|
|||
let cloned = message.clone();
|
||||
assert_eq!(cloned.id, "42");
|
||||
assert_eq!(cloned.sender, "alice");
|
||||
assert_eq!(cloned.reply_target, "alice");
|
||||
assert_eq!(cloned.content, "ping");
|
||||
assert_eq!(cloned.channel, "dummy");
|
||||
assert_eq!(cloned.timestamp, 999);
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ use uuid::Uuid;
|
|||
/// happens in the gateway when Meta sends webhook events.
|
||||
pub struct WhatsAppChannel {
|
||||
access_token: String,
|
||||
phone_number_id: String,
|
||||
endpoint_id: String,
|
||||
verify_token: String,
|
||||
allowed_numbers: Vec<String>,
|
||||
client: reqwest::Client,
|
||||
|
|
@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
|
|||
impl WhatsAppChannel {
|
||||
pub fn new(
|
||||
access_token: String,
|
||||
phone_number_id: String,
|
||||
endpoint_id: String,
|
||||
verify_token: String,
|
||||
allowed_numbers: Vec<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
access_token,
|
||||
phone_number_id,
|
||||
endpoint_id,
|
||||
verify_token,
|
||||
allowed_numbers,
|
||||
client: reqwest::Client::new(),
|
||||
|
|
@ -119,6 +119,7 @@ impl WhatsAppChannel {
|
|||
|
||||
messages.push(ChannelMessage {
|
||||
id: Uuid::new_v4().to_string(),
|
||||
reply_target: normalized_from.clone(),
|
||||
sender: normalized_from,
|
||||
content,
|
||||
channel: "whatsapp".to_string(),
|
||||
|
|
@ -142,7 +143,7 @@ impl Channel for WhatsAppChannel {
|
|||
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
|
||||
let url = format!(
|
||||
"https://graph.facebook.com/v18.0/{}/messages",
|
||||
self.phone_number_id
|
||||
self.endpoint_id
|
||||
);
|
||||
|
||||
// Normalize recipient (remove leading + if present for API)
|
||||
|
|
@ -162,7 +163,7 @@ impl Channel for WhatsAppChannel {
|
|||
let resp = self
|
||||
.client
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.bearer_auth(&self.access_token)
|
||||
.header("Content-Type", "application/json")
|
||||
.json(&body)
|
||||
.send()
|
||||
|
|
@ -195,11 +196,11 @@ impl Channel for WhatsAppChannel {
|
|||
|
||||
async fn health_check(&self) -> bool {
|
||||
// Check if we can reach the WhatsApp API
|
||||
let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id);
|
||||
let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
|
||||
|
||||
self.client
|
||||
.get(&url)
|
||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
||||
.bearer_auth(&self.access_token)
|
||||
.send()
|
||||
.await
|
||||
.map(|r| r.status().is_success())
|
||||
|
|
|
|||
|
|
@ -37,9 +37,22 @@ mod tests {
|
|||
guild_id: Some("123".into()),
|
||||
allowed_users: vec![],
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
};
|
||||
|
||||
let lark = LarkConfig {
|
||||
app_id: "app-id".into(),
|
||||
app_secret: "app-secret".into(),
|
||||
encrypt_key: None,
|
||||
verification_token: None,
|
||||
allowed_users: vec![],
|
||||
use_feishu: false,
|
||||
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
};
|
||||
|
||||
assert_eq!(telegram.allowed_users.len(), 1);
|
||||
assert_eq!(discord.guild_id.as_deref(), Some("123"));
|
||||
assert_eq!(lark.app_id, "app-id");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -18,6 +18,8 @@ pub struct Config {
|
|||
#[serde(skip)]
|
||||
pub config_path: PathBuf,
|
||||
pub api_key: Option<String>,
|
||||
/// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama)
|
||||
pub api_url: Option<String>,
|
||||
pub default_provider: Option<String>,
|
||||
pub default_model: Option<String>,
|
||||
pub default_temperature: f64,
|
||||
|
|
@ -1317,6 +1319,10 @@ pub struct DiscordConfig {
|
|||
/// The bot still ignores its own messages to prevent feedback loops.
|
||||
#[serde(default)]
|
||||
pub listen_to_bots: bool,
|
||||
/// When true, only respond to messages that @-mention the bot.
|
||||
/// Other messages in the guild are silently ignored.
|
||||
#[serde(default)]
|
||||
pub mention_only: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
|
|
@ -1395,8 +1401,20 @@ fn default_irc_port() -> u16 {
|
|||
6697
|
||||
}
|
||||
|
||||
/// Lark/Feishu configuration for messaging integration
|
||||
/// Lark is the international version, Feishu is the Chinese version
|
||||
/// How ZeroClaw receives events from Feishu / Lark.
|
||||
///
|
||||
/// - `websocket` (default) — persistent WSS long-connection; no public URL required.
|
||||
/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum LarkReceiveMode {
|
||||
#[default]
|
||||
Websocket,
|
||||
Webhook,
|
||||
}
|
||||
|
||||
/// Lark/Feishu configuration for messaging integration.
|
||||
/// Lark is the international version; Feishu is the Chinese version.
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct LarkConfig {
|
||||
/// App ID from Lark/Feishu developer console
|
||||
|
|
@ -1415,6 +1433,13 @@ pub struct LarkConfig {
|
|||
/// Whether to use the Feishu (Chinese) endpoint instead of Lark (International)
|
||||
#[serde(default)]
|
||||
pub use_feishu: bool,
|
||||
/// Event receive mode: "websocket" (default) or "webhook"
|
||||
#[serde(default)]
|
||||
pub receive_mode: LarkReceiveMode,
|
||||
/// HTTP port for webhook mode only. Must be set when receive_mode = "webhook".
|
||||
/// Not required (and ignored) for websocket mode.
|
||||
#[serde(default)]
|
||||
pub port: Option<u16>,
|
||||
}
|
||||
|
||||
// ── Security Config ─────────────────────────────────────────────────
|
||||
|
|
@ -1594,6 +1619,7 @@ impl Default for Config {
|
|||
workspace_dir: zeroclaw_dir.join("workspace"),
|
||||
config_path: zeroclaw_dir.join("config.toml"),
|
||||
api_key: None,
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".to_string()),
|
||||
default_model: Some("anthropic/claude-sonnet-4".to_string()),
|
||||
default_temperature: 0.7,
|
||||
|
|
@ -1623,35 +1649,146 @@ impl Default for Config {
|
|||
}
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load_or_init() -> Result<Self> {
|
||||
let home = UserDirs::new()
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.context("Could not find home directory")?;
|
||||
let zeroclaw_dir = home.join(".zeroclaw");
|
||||
let config_path = zeroclaw_dir.join("config.toml");
|
||||
fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> {
|
||||
let home = UserDirs::new()
|
||||
.map(|u| u.home_dir().to_path_buf())
|
||||
.context("Could not find home directory")?;
|
||||
let config_dir = home.join(".zeroclaw");
|
||||
Ok((config_dir.clone(), config_dir.join("workspace")))
|
||||
}
|
||||
|
||||
if !zeroclaw_dir.exists() {
|
||||
fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?;
|
||||
fs::create_dir_all(zeroclaw_dir.join("workspace"))
|
||||
.context("Failed to create workspace directory")?;
|
||||
fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
|
||||
let workspace_config_dir = workspace_dir.to_path_buf();
|
||||
if workspace_config_dir.join("config.toml").exists() {
|
||||
return workspace_config_dir;
|
||||
}
|
||||
|
||||
let legacy_config_dir = workspace_dir
|
||||
.parent()
|
||||
.map(|parent| parent.join(".zeroclaw"));
|
||||
if let Some(legacy_dir) = legacy_config_dir {
|
||||
if legacy_dir.join("config.toml").exists() {
|
||||
return legacy_dir;
|
||||
}
|
||||
|
||||
if workspace_dir
|
||||
.file_name()
|
||||
.is_some_and(|name| name == std::ffi::OsStr::new("workspace"))
|
||||
{
|
||||
return legacy_dir;
|
||||
}
|
||||
}
|
||||
|
||||
workspace_config_dir
|
||||
}
|
||||
|
||||
fn decrypt_optional_secret(
|
||||
store: &crate::security::SecretStore,
|
||||
value: &mut Option<String>,
|
||||
field_name: &str,
|
||||
) -> Result<()> {
|
||||
if let Some(raw) = value.clone() {
|
||||
if crate::security::SecretStore::is_encrypted(&raw) {
|
||||
*value = Some(
|
||||
store
|
||||
.decrypt(&raw)
|
||||
.with_context(|| format!("Failed to decrypt {field_name}"))?,
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn encrypt_optional_secret(
|
||||
store: &crate::security::SecretStore,
|
||||
value: &mut Option<String>,
|
||||
field_name: &str,
|
||||
) -> Result<()> {
|
||||
if let Some(raw) = value.clone() {
|
||||
if !crate::security::SecretStore::is_encrypted(&raw) {
|
||||
*value = Some(
|
||||
store
|
||||
.encrypt(&raw)
|
||||
.with_context(|| format!("Failed to encrypt {field_name}"))?,
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
impl Config {
|
||||
pub fn load_or_init() -> Result<Self> {
|
||||
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
|
||||
let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") {
|
||||
Ok(custom_workspace) if !custom_workspace.is_empty() => {
|
||||
let workspace = PathBuf::from(custom_workspace);
|
||||
(resolve_config_dir_for_workspace(&workspace), workspace)
|
||||
}
|
||||
_ => default_config_and_workspace_dirs()?,
|
||||
};
|
||||
|
||||
let config_path = zeroclaw_dir.join("config.toml");
|
||||
|
||||
fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?;
|
||||
fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?;
|
||||
|
||||
if config_path.exists() {
|
||||
// Warn if config file is world-readable (may contain API keys)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
if let Ok(meta) = fs::metadata(&config_path) {
|
||||
if meta.permissions().mode() & 0o004 != 0 {
|
||||
tracing::warn!(
|
||||
"Config file {:?} is world-readable (mode {:o}). \
|
||||
Consider restricting with: chmod 600 {:?}",
|
||||
config_path,
|
||||
meta.permissions().mode() & 0o777,
|
||||
config_path,
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let contents =
|
||||
fs::read_to_string(&config_path).context("Failed to read config file")?;
|
||||
let mut config: Config =
|
||||
toml::from_str(&contents).context("Failed to parse config file")?;
|
||||
// Set computed paths that are skipped during serialization
|
||||
config.config_path = config_path.clone();
|
||||
config.workspace_dir = zeroclaw_dir.join("workspace");
|
||||
config.workspace_dir = workspace_dir;
|
||||
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
|
||||
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.composio.api_key,
|
||||
"config.composio.api_key",
|
||||
)?;
|
||||
|
||||
decrypt_optional_secret(
|
||||
&store,
|
||||
&mut config.browser.computer_use.api_key,
|
||||
"config.browser.computer_use.api_key",
|
||||
)?;
|
||||
|
||||
for agent in config.agents.values_mut() {
|
||||
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||
}
|
||||
config.apply_env_overrides();
|
||||
Ok(config)
|
||||
} else {
|
||||
let mut config = Config::default();
|
||||
config.config_path = config_path.clone();
|
||||
config.workspace_dir = zeroclaw_dir.join("workspace");
|
||||
config.workspace_dir = workspace_dir;
|
||||
config.save()?;
|
||||
|
||||
// Restrict permissions on newly created config file (may contain API keys)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600));
|
||||
}
|
||||
|
||||
config.apply_env_overrides();
|
||||
Ok(config)
|
||||
}
|
||||
|
|
@ -1732,23 +1869,29 @@ impl Config {
|
|||
}
|
||||
|
||||
pub fn save(&self) -> Result<()> {
|
||||
// Encrypt agent API keys before serialization
|
||||
// Encrypt secrets before serialization
|
||||
let mut config_to_save = self.clone();
|
||||
let zeroclaw_dir = self
|
||||
.config_path
|
||||
.parent()
|
||||
.context("Config path must have a parent directory")?;
|
||||
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
||||
|
||||
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.composio.api_key,
|
||||
"config.composio.api_key",
|
||||
)?;
|
||||
|
||||
encrypt_optional_secret(
|
||||
&store,
|
||||
&mut config_to_save.browser.computer_use.api_key,
|
||||
"config.browser.computer_use.api_key",
|
||||
)?;
|
||||
|
||||
for agent in config_to_save.agents.values_mut() {
|
||||
if let Some(ref plaintext_key) = agent.api_key {
|
||||
if !crate::security::SecretStore::is_encrypted(plaintext_key) {
|
||||
agent.api_key = Some(
|
||||
store
|
||||
.encrypt(plaintext_key)
|
||||
.context("Failed to encrypt agent API key")?,
|
||||
);
|
||||
}
|
||||
}
|
||||
encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||
}
|
||||
|
||||
let toml_str =
|
||||
|
|
@ -1949,6 +2092,7 @@ default_temperature = 0.7
|
|||
workspace_dir: PathBuf::from("/tmp/test/workspace"),
|
||||
config_path: PathBuf::from("/tmp/test/config.toml"),
|
||||
api_key: Some("sk-test-key".into()),
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".into()),
|
||||
default_model: Some("gpt-4o".into()),
|
||||
default_temperature: 0.5,
|
||||
|
|
@ -2091,6 +2235,7 @@ tool_dispatcher = "xml"
|
|||
workspace_dir: dir.join("workspace"),
|
||||
config_path: config_path.clone(),
|
||||
api_key: Some("sk-roundtrip".into()),
|
||||
api_url: None,
|
||||
default_provider: Some("openrouter".into()),
|
||||
default_model: Some("test-model".into()),
|
||||
default_temperature: 0.9,
|
||||
|
|
@ -2123,13 +2268,82 @@ tool_dispatcher = "xml"
|
|||
|
||||
let contents = fs::read_to_string(&config_path).unwrap();
|
||||
let loaded: Config = toml::from_str(&contents).unwrap();
|
||||
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
|
||||
assert!(loaded
|
||||
.api_key
|
||||
.as_deref()
|
||||
.is_some_and(crate::security::SecretStore::is_encrypted));
|
||||
let store = crate::security::SecretStore::new(&dir, true);
|
||||
let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
|
||||
assert_eq!(decrypted, "sk-roundtrip");
|
||||
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
|
||||
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
|
||||
|
||||
let _ = fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_save_encrypts_nested_credentials() {
|
||||
let dir = std::env::temp_dir().join(format!(
|
||||
"zeroclaw_test_nested_credentials_{}",
|
||||
uuid::Uuid::new_v4()
|
||||
));
|
||||
fs::create_dir_all(&dir).unwrap();
|
||||
|
||||
let mut config = Config::default();
|
||||
config.workspace_dir = dir.join("workspace");
|
||||
config.config_path = dir.join("config.toml");
|
||||
config.api_key = Some("root-credential".into());
|
||||
config.composio.api_key = Some("composio-credential".into());
|
||||
config.browser.computer_use.api_key = Some("browser-credential".into());
|
||||
|
||||
config.agents.insert(
|
||||
"worker".into(),
|
||||
DelegateAgentConfig {
|
||||
provider: "openrouter".into(),
|
||||
model: "model-test".into(),
|
||||
system_prompt: None,
|
||||
api_key: Some("agent-credential".into()),
|
||||
temperature: None,
|
||||
max_depth: 3,
|
||||
},
|
||||
);
|
||||
|
||||
config.save().unwrap();
|
||||
|
||||
let contents = fs::read_to_string(config.config_path.clone()).unwrap();
|
||||
let stored: Config = toml::from_str(&contents).unwrap();
|
||||
let store = crate::security::SecretStore::new(&dir, true);
|
||||
|
||||
let root_encrypted = stored.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
|
||||
assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
|
||||
|
||||
let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(
|
||||
composio_encrypted
|
||||
));
|
||||
assert_eq!(
|
||||
store.decrypt(composio_encrypted).unwrap(),
|
||||
"composio-credential"
|
||||
);
|
||||
|
||||
let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(
|
||||
browser_encrypted
|
||||
));
|
||||
assert_eq!(
|
||||
store.decrypt(browser_encrypted).unwrap(),
|
||||
"browser-credential"
|
||||
);
|
||||
|
||||
let worker = stored.agents.get("worker").unwrap();
|
||||
let worker_encrypted = worker.api_key.as_deref().unwrap();
|
||||
assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
|
||||
assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
|
||||
|
||||
let _ = fs::remove_dir_all(&dir);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn config_save_atomic_cleanup() {
|
||||
let dir =
|
||||
|
|
@ -2182,6 +2396,7 @@ tool_dispatcher = "xml"
|
|||
guild_id: Some("12345".into()),
|
||||
allowed_users: vec![],
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -2196,6 +2411,7 @@ tool_dispatcher = "xml"
|
|||
guild_id: None,
|
||||
allowed_users: vec![],
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
};
|
||||
let json = serde_json::to_string(&dc).unwrap();
|
||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||
|
|
@ -2818,6 +3034,96 @@ default_temperature = 0.7
|
|||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_or_init_workspace_override_uses_workspace_root_for_config() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
let temp_home =
|
||||
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||
let workspace_dir = temp_home.join("profile-a");
|
||||
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
std::env::set_var("HOME", &temp_home);
|
||||
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||
|
||||
let config = Config::load_or_init().unwrap();
|
||||
|
||||
assert_eq!(config.workspace_dir, workspace_dir);
|
||||
assert_eq!(config.config_path, workspace_dir.join("config.toml"));
|
||||
assert!(workspace_dir.join("config.toml").exists());
|
||||
|
||||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
if let Some(home) = original_home {
|
||||
std::env::set_var("HOME", home);
|
||||
} else {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
let _ = fs::remove_dir_all(temp_home);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_or_init_workspace_suffix_uses_legacy_config_layout() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
let temp_home =
|
||||
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||
let workspace_dir = temp_home.join("workspace");
|
||||
let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml");
|
||||
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
std::env::set_var("HOME", &temp_home);
|
||||
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||
|
||||
let config = Config::load_or_init().unwrap();
|
||||
|
||||
assert_eq!(config.workspace_dir, workspace_dir);
|
||||
assert_eq!(config.config_path, legacy_config_path);
|
||||
assert!(config.config_path.exists());
|
||||
|
||||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
if let Some(home) = original_home {
|
||||
std::env::set_var("HOME", home);
|
||||
} else {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
let _ = fs::remove_dir_all(temp_home);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn load_or_init_workspace_override_keeps_existing_legacy_config() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
let temp_home =
|
||||
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||
let workspace_dir = temp_home.join("custom-workspace");
|
||||
let legacy_config_dir = temp_home.join(".zeroclaw");
|
||||
let legacy_config_path = legacy_config_dir.join("config.toml");
|
||||
|
||||
fs::create_dir_all(&legacy_config_dir).unwrap();
|
||||
fs::write(
|
||||
&legacy_config_path,
|
||||
r#"default_temperature = 0.7
|
||||
default_model = "legacy-model"
|
||||
"#,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let original_home = std::env::var("HOME").ok();
|
||||
std::env::set_var("HOME", &temp_home);
|
||||
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||
|
||||
let config = Config::load_or_init().unwrap();
|
||||
|
||||
assert_eq!(config.workspace_dir, workspace_dir);
|
||||
assert_eq!(config.config_path, legacy_config_path);
|
||||
assert_eq!(config.default_model.as_deref(), Some("legacy-model"));
|
||||
|
||||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||
if let Some(home) = original_home {
|
||||
std::env::set_var("HOME", home);
|
||||
} else {
|
||||
std::env::remove_var("HOME");
|
||||
}
|
||||
let _ = fs::remove_dir_all(temp_home);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn env_override_empty_values_ignored() {
|
||||
let _env_guard = env_override_test_guard();
|
||||
|
|
@ -2975,4 +3281,118 @@ default_temperature = 0.7
|
|||
assert_eq!(parsed.boards[0].board, "nucleo-f401re");
|
||||
assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_serde() {
|
||||
let lc = LarkConfig {
|
||||
app_id: "cli_123456".into(),
|
||||
app_secret: "secret_abc".into(),
|
||||
encrypt_key: Some("encrypt_key".into()),
|
||||
verification_token: Some("verify_token".into()),
|
||||
allowed_users: vec!["user_123".into(), "user_456".into()],
|
||||
use_feishu: true,
|
||||
receive_mode: LarkReceiveMode::Websocket,
|
||||
port: None,
|
||||
};
|
||||
let json = serde_json::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||
assert_eq!(parsed.app_id, "cli_123456");
|
||||
assert_eq!(parsed.app_secret, "secret_abc");
|
||||
assert_eq!(parsed.encrypt_key.as_deref(), Some("encrypt_key"));
|
||||
assert_eq!(parsed.verification_token.as_deref(), Some("verify_token"));
|
||||
assert_eq!(parsed.allowed_users.len(), 2);
|
||||
assert!(parsed.use_feishu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_toml_roundtrip() {
|
||||
let lc = LarkConfig {
|
||||
app_id: "cli_123456".into(),
|
||||
app_secret: "secret_abc".into(),
|
||||
encrypt_key: Some("encrypt_key".into()),
|
||||
verification_token: Some("verify_token".into()),
|
||||
allowed_users: vec!["*".into()],
|
||||
use_feishu: false,
|
||||
receive_mode: LarkReceiveMode::Webhook,
|
||||
port: Some(9898),
|
||||
};
|
||||
let toml_str = toml::to_string(&lc).unwrap();
|
||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||
assert_eq!(parsed.app_id, "cli_123456");
|
||||
assert_eq!(parsed.app_secret, "secret_abc");
|
||||
assert!(!parsed.use_feishu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_deserializes_without_optional_fields() {
|
||||
let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(parsed.encrypt_key.is_none());
|
||||
assert!(parsed.verification_token.is_none());
|
||||
assert!(parsed.allowed_users.is_empty());
|
||||
assert!(!parsed.use_feishu);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_defaults_to_lark_endpoint() {
|
||||
let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert!(
|
||||
!parsed.use_feishu,
|
||||
"use_feishu should default to false (Lark)"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn lark_config_with_wildcard_allowed_users() {
|
||||
let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#;
|
||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(parsed.allowed_users, vec!["*"]);
|
||||
}
|
||||
|
||||
// ── Config file permission hardening (Unix only) ───────────────
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn new_config_file_has_restricted_permissions() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
|
||||
// Create a config and save it
|
||||
let mut config = Config::default();
|
||||
config.config_path = config_path.clone();
|
||||
config.save().unwrap();
|
||||
|
||||
// Apply the same permission logic as load_or_init
|
||||
let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600));
|
||||
|
||||
let meta = std::fs::metadata(&config_path).unwrap();
|
||||
let mode = meta.permissions().mode() & 0o777;
|
||||
assert_eq!(
|
||||
mode, 0o600,
|
||||
"New config file should be owner-only (0600), got {mode:o}"
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
#[test]
|
||||
fn world_readable_config_is_detectable() {
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
let tmp = tempfile::TempDir::new().unwrap();
|
||||
let config_path = tmp.path().join("config.toml");
|
||||
|
||||
// Create a config file with intentionally loose permissions
|
||||
std::fs::write(&config_path, "# test config").unwrap();
|
||||
std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap();
|
||||
|
||||
let meta = std::fs::metadata(&config_path).unwrap();
|
||||
let mode = meta.permissions().mode();
|
||||
assert!(
|
||||
mode & 0o004 != 0,
|
||||
"Test setup: file should be world-readable (mode {mode:o})"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
|
|||
dc.guild_id.clone(),
|
||||
dc.allowed_users.clone(),
|
||||
dc.listen_to_bots,
|
||||
dc.mention_only,
|
||||
);
|
||||
channel.send(output, target).await?;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool {
|
|||
|| config.channels_config.matrix.is_some()
|
||||
|| config.channels_config.whatsapp.is_some()
|
||||
|| config.channels_config.email.is_some()
|
||||
|| config.channels_config.lark.is_some()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
|
|
|
|||
|
|
@ -49,6 +49,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
|
|||
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
||||
}
|
||||
|
||||
fn hash_webhook_secret(value: &str) -> String {
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
let digest = Sha256::digest(value.as_bytes());
|
||||
hex::encode(digest)
|
||||
}
|
||||
|
||||
/// How often the rate limiter sweeps stale IP entries from its map.
|
||||
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
||||
|
||||
|
|
@ -178,7 +185,8 @@ pub struct AppState {
|
|||
pub temperature: f64,
|
||||
pub mem: Arc<dyn Memory>,
|
||||
pub auto_save: bool,
|
||||
pub webhook_secret: Option<Arc<str>>,
|
||||
/// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
|
||||
pub webhook_secret_hash: Option<Arc<str>>,
|
||||
pub pairing: Arc<PairingGuard>,
|
||||
pub rate_limiter: Arc<GatewayRateLimiter>,
|
||||
pub idempotency_store: Arc<IdempotencyStore>,
|
||||
|
|
@ -208,6 +216,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||
config.api_key.as_deref(),
|
||||
config.api_url.as_deref(),
|
||||
&config.reliability,
|
||||
)?);
|
||||
let model = config
|
||||
|
|
@ -251,12 +260,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
&config,
|
||||
));
|
||||
// Extract webhook secret for authentication
|
||||
let webhook_secret: Option<Arc<str>> = config
|
||||
.channels_config
|
||||
.webhook
|
||||
.as_ref()
|
||||
.and_then(|w| w.secret.as_deref())
|
||||
.map(Arc::from);
|
||||
let webhook_secret_hash: Option<Arc<str>> =
|
||||
config.channels_config.webhook.as_ref().and_then(|webhook| {
|
||||
webhook.secret.as_ref().and_then(|raw_secret| {
|
||||
let trimmed_secret = raw_secret.trim();
|
||||
(!trimmed_secret.is_empty())
|
||||
.then(|| Arc::<str>::from(hash_webhook_secret(trimmed_secret)))
|
||||
})
|
||||
});
|
||||
|
||||
// WhatsApp channel (if configured)
|
||||
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
|
||||
|
|
@ -342,9 +353,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
} else {
|
||||
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
||||
}
|
||||
if webhook_secret.is_some() {
|
||||
println!(" 🔒 Webhook secret: ENABLED");
|
||||
}
|
||||
println!(" Press Ctrl+C to stop.\n");
|
||||
|
||||
crate::health::mark_component_ok("gateway");
|
||||
|
|
@ -356,7 +364,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
|||
temperature,
|
||||
mem,
|
||||
auto_save: config.memory.auto_save,
|
||||
webhook_secret,
|
||||
webhook_secret_hash,
|
||||
pairing,
|
||||
rate_limiter,
|
||||
idempotency_store,
|
||||
|
|
@ -482,12 +490,15 @@ async fn handle_webhook(
|
|||
}
|
||||
|
||||
// ── Webhook secret auth (optional, additional layer) ──
|
||||
if let Some(ref secret) = state.webhook_secret {
|
||||
let header_val = headers
|
||||
if let Some(ref secret_hash) = state.webhook_secret_hash {
|
||||
let header_hash = headers
|
||||
.get("X-Webhook-Secret")
|
||||
.and_then(|v| v.to_str().ok());
|
||||
match header_val {
|
||||
Some(val) if constant_time_eq(val, secret.as_ref()) => {}
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.map(hash_webhook_secret);
|
||||
match header_hash {
|
||||
Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
|
||||
_ => {
|
||||
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
||||
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
||||
|
|
@ -532,7 +543,7 @@ async fn handle_webhook(
|
|||
let key = webhook_memory_key();
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, message, MemoryCategory::Conversation)
|
||||
.store(&key, message, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -685,7 +696,7 @@ async fn handle_whatsapp_message(
|
|||
let key = whatsapp_memory_key(msg);
|
||||
let _ = state
|
||||
.mem
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation)
|
||||
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||
.await;
|
||||
}
|
||||
|
||||
|
|
@ -697,7 +708,7 @@ async fn handle_whatsapp_message(
|
|||
{
|
||||
Ok(response) => {
|
||||
// Send reply via WhatsApp
|
||||
if let Err(e) = wa.send(&response, &msg.sender).await {
|
||||
if let Err(e) = wa.send(&response, &msg.reply_target).await {
|
||||
tracing::error!("Failed to send WhatsApp reply: {e}");
|
||||
}
|
||||
}
|
||||
|
|
@ -706,7 +717,7 @@ async fn handle_whatsapp_message(
|
|||
let _ = wa
|
||||
.send(
|
||||
"Sorry, I couldn't process your message right now.",
|
||||
&msg.sender,
|
||||
&msg.reply_target,
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
|
@ -798,7 +809,9 @@ mod tests {
|
|||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||
guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1);
|
||||
guard.1 = Instant::now()
|
||||
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
||||
.unwrap();
|
||||
// Clear timestamps for ip-2 and ip-3 to simulate stale entries
|
||||
guard.0.get_mut("ip-2").unwrap().clear();
|
||||
guard.0.get_mut("ip-3").unwrap().clear();
|
||||
|
|
@ -848,6 +861,7 @@ mod tests {
|
|||
let msg = ChannelMessage {
|
||||
id: "wamid-123".into(),
|
||||
sender: "+1234567890".into(),
|
||||
reply_target: "+1234567890".into(),
|
||||
content: "hello".into(),
|
||||
channel: "whatsapp".into(),
|
||||
timestamp: 1,
|
||||
|
|
@ -871,11 +885,17 @@ mod tests {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -886,6 +906,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -938,6 +959,7 @@ mod tests {
|
|||
key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.keys
|
||||
.lock()
|
||||
|
|
@ -946,7 +968,12 @@ mod tests {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -957,6 +984,7 @@ mod tests {
|
|||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
|
@ -991,7 +1019,7 @@ mod tests {
|
|||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret: None,
|
||||
webhook_secret_hash: None,
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
|
|
@ -1039,7 +1067,7 @@ mod tests {
|
|||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: true,
|
||||
webhook_secret: None,
|
||||
webhook_secret_hash: None,
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
|
|
@ -1077,6 +1105,125 @@ mod tests {
|
|||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn webhook_secret_hash_is_deterministic_and_nonempty() {
|
||||
let one = hash_webhook_secret("secret-value");
|
||||
let two = hash_webhook_secret("secret-value");
|
||||
let other = hash_webhook_secret("other-value");
|
||||
|
||||
assert_eq!(one, two);
|
||||
assert_ne!(one, other);
|
||||
assert_eq!(one.len(), 64);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_secret_hash_rejects_missing_header() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
|
||||
let state = AppState {
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
};
|
||||
|
||||
let response = handle_webhook(
|
||||
State(state),
|
||||
HeaderMap::new(),
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_secret_hash_rejects_invalid_header() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
|
||||
let state = AppState {
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
|
||||
|
||||
let response = handle_webhook(
|
||||
State(state),
|
||||
headers,
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn webhook_secret_hash_accepts_valid_header() {
|
||||
let provider_impl = Arc::new(MockProvider::default());
|
||||
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||
|
||||
let state = AppState {
|
||||
provider,
|
||||
model: "test-model".into(),
|
||||
temperature: 0.0,
|
||||
mem: memory,
|
||||
auto_save: false,
|
||||
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||
whatsapp: None,
|
||||
whatsapp_app_secret: None,
|
||||
};
|
||||
|
||||
let mut headers = HeaderMap::new();
|
||||
headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
|
||||
|
||||
let response = handle_webhook(
|
||||
State(state),
|
||||
headers,
|
||||
Ok(Json(WebhookBody {
|
||||
message: "hello".into(),
|
||||
})),
|
||||
)
|
||||
.await
|
||||
.into_response();
|
||||
|
||||
assert_eq!(response.status(), StatusCode::OK);
|
||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
|
||||
}
|
||||
|
||||
// ══════════════════════════════════════════════════════════
|
||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||
// ══════════════════════════════════════════════════════════
|
||||
|
|
|
|||
40
src/main.rs
40
src/main.rs
|
|
@ -34,8 +34,8 @@
|
|||
|
||||
use anyhow::{bail, Result};
|
||||
use clap::{Parser, Subcommand};
|
||||
use tracing::{info, Level};
|
||||
use tracing_subscriber::FmtSubscriber;
|
||||
use tracing::info;
|
||||
use tracing_subscriber::{fmt, EnvFilter};
|
||||
|
||||
mod agent;
|
||||
mod channels;
|
||||
|
|
@ -147,24 +147,24 @@ enum Commands {
|
|||
|
||||
/// Start the gateway server (webhooks, websockets)
|
||||
Gateway {
|
||||
/// Port to listen on (use 0 for random available port)
|
||||
#[arg(short, long, default_value = "8080")]
|
||||
port: u16,
|
||||
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
|
||||
#[arg(short, long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Host to bind to
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
/// Host to bind to; defaults to config gateway.host
|
||||
#[arg(long)]
|
||||
host: Option<String>,
|
||||
},
|
||||
|
||||
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
|
||||
Daemon {
|
||||
/// Port to listen on (use 0 for random available port)
|
||||
#[arg(short, long, default_value = "8080")]
|
||||
port: u16,
|
||||
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
|
||||
#[arg(short, long)]
|
||||
port: Option<u16>,
|
||||
|
||||
/// Host to bind to
|
||||
#[arg(long, default_value = "127.0.0.1")]
|
||||
host: String,
|
||||
/// Host to bind to; defaults to config gateway.host
|
||||
#[arg(long)]
|
||||
host: Option<String>,
|
||||
},
|
||||
|
||||
/// Manage OS service lifecycle (launchd/systemd user service)
|
||||
|
|
@ -367,9 +367,11 @@ async fn main() -> Result<()> {
|
|||
|
||||
let cli = Cli::parse();
|
||||
|
||||
// Initialize logging
|
||||
let subscriber = FmtSubscriber::builder()
|
||||
.with_max_level(Level::INFO)
|
||||
// Initialize logging - respects RUST_LOG env var, defaults to INFO
|
||||
let subscriber = fmt::Subscriber::builder()
|
||||
.with_env_filter(
|
||||
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||
)
|
||||
.finish();
|
||||
|
||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||
|
|
@ -434,6 +436,8 @@ async fn main() -> Result<()> {
|
|||
.map(|_| ()),
|
||||
|
||||
Commands::Gateway { port, host } => {
|
||||
let port = port.unwrap_or(config.gateway.port);
|
||||
let host = host.unwrap_or_else(|| config.gateway.host.clone());
|
||||
if port == 0 {
|
||||
info!("🚀 Starting ZeroClaw Gateway on {host} (random port)");
|
||||
} else {
|
||||
|
|
@ -443,6 +447,8 @@ async fn main() -> Result<()> {
|
|||
}
|
||||
|
||||
Commands::Daemon { port, host } => {
|
||||
let port = port.unwrap_or(config.gateway.port);
|
||||
let host = host.unwrap_or_else(|| config.gateway.host.clone());
|
||||
if port == 0 {
|
||||
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -7,6 +7,7 @@ pub enum MemoryBackendKind {
|
|||
Unknown,
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_excessive_bools)]
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||
pub struct MemoryBackendProfile {
|
||||
pub key: &'static str,
|
||||
|
|
|
|||
|
|
@ -502,10 +502,10 @@ mod tests {
|
|||
let workspace = tmp.path();
|
||||
|
||||
let mem = SqliteMemory::new(workspace).unwrap();
|
||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation)
|
||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("core_keep", "durable", MemoryCategory::Core)
|
||||
mem.store("core_keep", "durable", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
drop(mem);
|
||||
|
|
|
|||
|
|
@ -24,7 +24,9 @@ pub struct LucidMemory {
|
|||
impl LucidMemory {
|
||||
const DEFAULT_LUCID_CMD: &'static str = "lucid";
|
||||
const DEFAULT_TOKEN_BUDGET: usize = 200;
|
||||
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120;
|
||||
// Lucid CLI cold start can exceed 120ms on slower machines, which causes
|
||||
// avoidable fallback to local-only memory and premature cooldown.
|
||||
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500;
|
||||
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
|
||||
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
|
||||
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
|
||||
|
|
@ -74,6 +76,7 @@ impl LucidMemory {
|
|||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
fn with_options(
|
||||
workspace_dir: &Path,
|
||||
local: SqliteMemory,
|
||||
|
|
@ -307,14 +310,22 @@ impl Memory for LucidMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
self.local.store(key, content, category.clone()).await?;
|
||||
self.local
|
||||
.store(key, content, category.clone(), session_id)
|
||||
.await?;
|
||||
self.sync_to_lucid_async(key, content, &category).await;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit).await?;
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let local_results = self.local.recall(query, limit, session_id).await?;
|
||||
if limit == 0
|
||||
|| local_results.len() >= limit
|
||||
|| local_results.len() >= self.local_hit_threshold
|
||||
|
|
@ -351,8 +362,12 @@ impl Memory for LucidMemory {
|
|||
self.local.get(key).await
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.local.list(category).await
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
self.local.list(category, session_id).await
|
||||
}
|
||||
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||
|
|
@ -396,6 +411,38 @@ EOF
|
|||
exit 0
|
||||
fi
|
||||
|
||||
echo "unsupported command" >&2
|
||||
exit 1
|
||||
"#;
|
||||
|
||||
fs::write(&script_path, script).unwrap();
|
||||
let mut perms = fs::metadata(&script_path).unwrap().permissions();
|
||||
perms.set_mode(0o755);
|
||||
fs::set_permissions(&script_path, perms).unwrap();
|
||||
script_path.display().to_string()
|
||||
}
|
||||
|
||||
fn write_delayed_lucid_script(dir: &Path) -> String {
|
||||
let script_path = dir.join("delayed-lucid.sh");
|
||||
let script = r#"#!/usr/bin/env bash
|
||||
set -euo pipefail
|
||||
|
||||
if [[ "${1:-}" == "store" ]]; then
|
||||
echo '{"success":true,"id":"mem_1"}'
|
||||
exit 0
|
||||
fi
|
||||
|
||||
if [[ "${1:-}" == "context" ]]; then
|
||||
# Simulate a cold start that is slower than 120ms but below the 500ms timeout.
|
||||
sleep 0.2
|
||||
cat <<'EOF'
|
||||
<lucid-context>
|
||||
- [decision] Delayed token refresh guidance
|
||||
</lucid-context>
|
||||
EOF
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "unsupported command" >&2
|
||||
exit 1
|
||||
"#;
|
||||
|
|
@ -449,7 +496,7 @@ exit 1
|
|||
cmd,
|
||||
200,
|
||||
3,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(2),
|
||||
)
|
||||
|
|
@ -468,7 +515,7 @@ exit 1
|
|||
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
|
||||
|
||||
memory
|
||||
.store("lang", "User prefers Rust", MemoryCategory::Core)
|
||||
.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -483,6 +530,30 @@ exit 1
|
|||
let fake_cmd = write_fake_lucid_script(tmp.path());
|
||||
let memory = test_memory(tmp.path(), fake_cmd);
|
||||
|
||||
memory
|
||||
.store(
|
||||
"local_note",
|
||||
"Local sqlite auth fallback note",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_handles_lucid_cold_start_delay_within_timeout() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let delayed_cmd = write_delayed_lucid_script(tmp.path());
|
||||
let memory = test_memory(tmp.path(), delayed_cmd);
|
||||
|
||||
memory
|
||||
.store(
|
||||
"local_note",
|
||||
|
|
@ -497,7 +568,9 @@ exit 1
|
|||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Delayed token refresh guidance")));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
@ -513,17 +586,22 @@ exit 1
|
|||
probe_cmd,
|
||||
200,
|
||||
1,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(2),
|
||||
);
|
||||
|
||||
memory
|
||||
.store("pref", "Rust should stay local-first", MemoryCategory::Core)
|
||||
.store(
|
||||
"pref",
|
||||
"Rust should stay local-first",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let entries = memory.recall("rust", 5).await.unwrap();
|
||||
let entries = memory.recall("rust", 5, None).await.unwrap();
|
||||
assert!(entries
|
||||
.iter()
|
||||
.any(|e| e.content.contains("Rust should stay local-first")));
|
||||
|
|
@ -578,13 +656,13 @@ exit 1
|
|||
failing_cmd,
|
||||
200,
|
||||
99,
|
||||
Duration::from_millis(120),
|
||||
Duration::from_millis(500),
|
||||
Duration::from_millis(400),
|
||||
Duration::from_secs(5),
|
||||
);
|
||||
|
||||
let first = memory.recall("auth", 5).await.unwrap();
|
||||
let second = memory.recall("auth", 5).await.unwrap();
|
||||
let first = memory.recall("auth", 5, None).await.unwrap();
|
||||
let second = memory.recall("auth", 5, None).await.unwrap();
|
||||
|
||||
assert!(first.is_empty());
|
||||
assert!(second.is_empty());
|
||||
|
|
|
|||
|
|
@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
let entry = format!("- **{key}**: {content}");
|
||||
let path = match category {
|
||||
|
|
@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
|
|||
self.append_to_file(&path, &entry).await
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
let query_lower = query.to_lowercase();
|
||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||
|
|
@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
|
|||
.find(|e| e.key == key || e.content.contains(key)))
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let all = self.read_all_entries().await?;
|
||||
match category {
|
||||
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
||||
|
|
@ -243,7 +253,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_store_core() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
||||
|
|
@ -253,7 +263,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_store_daily() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily)
|
||||
mem.store("note", "Finished tests", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let path = mem.daily_path();
|
||||
|
|
@ -264,17 +274,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_recall_keyword() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core)
|
||||
mem.store("b", "Python is slow", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core)
|
||||
mem.store("c", "Rust and safety", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
|
|
@ -284,18 +294,20 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_recall_no_match() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is great", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_count() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core)
|
||||
mem.store("a", "first", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "second", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let count = mem.count().await.unwrap();
|
||||
|
|
@ -305,24 +317,24 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_list_by_category() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "core fact", MemoryCategory::Core)
|
||||
mem.store("a", "core fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "daily note", MemoryCategory::Daily)
|
||||
mem.store("b", "daily note", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn markdown_forget_is_noop() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
mem.store("a", "permanent", MemoryCategory::Core)
|
||||
mem.store("a", "permanent", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let removed = mem.forget("a").await.unwrap();
|
||||
|
|
@ -332,7 +344,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn markdown_empty_recall() {
|
||||
let (_tmp, mem) = temp_workspace();
|
||||
let results = mem.recall("anything", 10).await.unwrap();
|
||||
let results = mem.recall("anything", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,11 +25,17 @@ impl Memory for NoneMemory {
|
|||
_key: &str,
|
||||
_content: &str,
|
||||
_category: MemoryCategory,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
_query: &str,
|
||||
_limit: usize,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -37,7 +43,11 @@ impl Memory for NoneMemory {
|
|||
Ok(None)
|
||||
}
|
||||
|
||||
async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
_category: Option<&MemoryCategory>,
|
||||
_session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
Ok(Vec::new())
|
||||
}
|
||||
|
||||
|
|
@ -62,11 +72,14 @@ mod tests {
|
|||
async fn none_memory_is_noop() {
|
||||
let memory = NoneMemory::new();
|
||||
|
||||
memory.store("k", "v", MemoryCategory::Core).await.unwrap();
|
||||
memory
|
||||
.store("k", "v", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(memory.get("k").await.unwrap().is_none());
|
||||
assert!(memory.recall("k", 10).await.unwrap().is_empty());
|
||||
assert!(memory.list(None).await.unwrap().is_empty());
|
||||
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
|
||||
assert!(memory.list(None, None).await.unwrap().is_empty());
|
||||
assert!(!memory.forget("k").await.unwrap());
|
||||
assert_eq!(memory.count().await.unwrap(), 0);
|
||||
assert!(memory.health_check().await);
|
||||
|
|
|
|||
|
|
@ -157,7 +157,7 @@ impl ResponseCache {
|
|||
|row| row.get(0),
|
||||
)?;
|
||||
|
||||
#[allow(clippy::cast_sign_loss)]
|
||||
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||
Ok((count as usize, hits as u64, tokens_saved as u64))
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -124,6 +124,19 @@ impl SqliteMemory {
|
|||
);
|
||||
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
|
||||
)?;
|
||||
|
||||
// Migration: add session_id column if not present (safe to run repeatedly)
|
||||
let has_session_id: bool = conn
|
||||
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
|
||||
.query_row([], |row| row.get::<_, String>(0))?
|
||||
.contains("session_id");
|
||||
if !has_session_id {
|
||||
conn.execute_batch(
|
||||
"ALTER TABLE memories ADD COLUMN session_id TEXT;
|
||||
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
|
||||
)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
|
|
@ -361,6 +374,7 @@ impl Memory for SqliteMemory {
|
|||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()> {
|
||||
// Compute embedding (async, before lock)
|
||||
let embedding_bytes = self
|
||||
|
|
@ -377,20 +391,26 @@ impl Memory for SqliteMemory {
|
|||
let id = Uuid::new_v4().to_string();
|
||||
|
||||
conn.execute(
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||
ON CONFLICT(key) DO UPDATE SET
|
||||
content = excluded.content,
|
||||
category = excluded.category,
|
||||
embedding = excluded.embedding,
|
||||
updated_at = excluded.updated_at",
|
||||
params![id, key, content, cat, embedding_bytes, now, now],
|
||||
updated_at = excluded.updated_at,
|
||||
session_id = excluded.session_id",
|
||||
params![id, key, content, cat, embedding_bytes, now, now, session_id],
|
||||
)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
if query.trim().is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
|
@ -439,7 +459,7 @@ impl Memory for SqliteMemory {
|
|||
let mut results = Vec::new();
|
||||
for scored in &merged {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
||||
)?;
|
||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||
Ok(MemoryEntry {
|
||||
|
|
@ -448,10 +468,16 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(f64::from(scored.final_score)),
|
||||
})
|
||||
}) {
|
||||
// Filter by session_id if requested
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
|
|
@ -470,7 +496,7 @@ impl Memory for SqliteMemory {
|
|||
.collect();
|
||||
let where_clause = conditions.join(" OR ");
|
||||
let sql = format!(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE {where_clause}
|
||||
ORDER BY updated_at DESC
|
||||
LIMIT ?{}",
|
||||
|
|
@ -493,12 +519,18 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: Some(1.0),
|
||||
})
|
||||
})?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -514,7 +546,7 @@ impl Memory for SqliteMemory {
|
|||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||
)?;
|
||||
|
||||
let mut rows = stmt.query_map(params![key], |row| {
|
||||
|
|
@ -524,7 +556,7 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
})?;
|
||||
|
|
@ -535,7 +567,11 @@ impl Memory for SqliteMemory {
|
|||
}
|
||||
}
|
||||
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||
let conn = self
|
||||
.conn
|
||||
.lock()
|
||||
|
|
@ -550,7 +586,7 @@ impl Memory for SqliteMemory {
|
|||
content: row.get(2)?,
|
||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||
timestamp: row.get(4)?,
|
||||
session_id: None,
|
||||
session_id: row.get(5)?,
|
||||
score: None,
|
||||
})
|
||||
};
|
||||
|
|
@ -558,21 +594,33 @@ impl Memory for SqliteMemory {
|
|||
if let Some(cat) = category {
|
||||
let cat_str = Self::category_to_str(cat);
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
} else {
|
||||
let mut stmt = conn.prepare(
|
||||
"SELECT id, key, content, category, created_at FROM memories
|
||||
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||
ORDER BY updated_at DESC",
|
||||
)?;
|
||||
let rows = stmt.query_map([], row_mapper)?;
|
||||
for row in rows {
|
||||
results.push(row?);
|
||||
let entry = row?;
|
||||
if let Some(sid) = session_id {
|
||||
if entry.session_id.as_deref() != Some(sid) {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
results.push(entry);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -632,7 +680,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_store_and_get() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
|
||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -647,10 +695,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_store_upsert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
mem.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -662,17 +710,22 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
|
||||
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c",
|
||||
"Rust has zero-cost abstractions",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
|
|
@ -682,14 +735,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_multi_keyword() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
||||
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
|
||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("fast safe", 10).await.unwrap();
|
||||
let results = mem.recall("fast safe", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// Entry with both keywords should score higher
|
||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||
|
|
@ -698,17 +751,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_recall_no_match() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core)
|
||||
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("javascript", 10).await.unwrap();
|
||||
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_forget() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation)
|
||||
mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
|
@ -728,29 +781,37 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn sqlite_list_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation)
|
||||
mem.store("a", "one", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "two", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "three", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let all = mem.list(None).await.unwrap();
|
||||
let all = mem.list(None, None).await.unwrap();
|
||||
assert_eq!(all.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn sqlite_list_by_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily)
|
||||
mem.store("a", "core1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "core2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "daily1", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
assert_eq!(core.len(), 2);
|
||||
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
assert_eq!(daily.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -772,7 +833,7 @@ mod tests {
|
|||
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core)
|
||||
mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -795,7 +856,7 @@ mod tests {
|
|||
];
|
||||
|
||||
for (i, cat) in categories.iter().enumerate() {
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
|
||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
|
@ -815,21 +876,28 @@ mod tests {
|
|||
"a",
|
||||
"Rust is a systems programming language",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"b",
|
||||
"Python is great for scripting",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c",
|
||||
"Rust and Rust and Rust everywhere",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("Rust", 10).await.unwrap();
|
||||
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||
assert!(results.len() >= 2);
|
||||
// All results should contain "Rust"
|
||||
for r in &results {
|
||||
|
|
@ -844,17 +912,17 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_multi_word_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
|
||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
|
||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
|
||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("quick dog", 10).await.unwrap();
|
||||
let results = mem.recall("quick dog", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
// "The quick dog runs fast" matches both terms
|
||||
assert!(results[0].content.contains("quick"));
|
||||
|
|
@ -863,16 +931,20 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_empty_query_returns_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
||||
let results = mem.recall("", 10).await.unwrap();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_whitespace_query_returns_empty() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
||||
let results = mem.recall(" ", 10).await.unwrap();
|
||||
mem.store("a", "data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall(" ", 10, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
|
|
@ -937,9 +1009,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_insert() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"test_key",
|
||||
"unique_searchterm_xyz",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let conn = mem.conn.lock();
|
||||
let count: i64 = conn
|
||||
|
|
@ -955,9 +1032,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_delete() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"del_key",
|
||||
"deletable_content_abc",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("del_key").await.unwrap();
|
||||
|
||||
let conn = mem.conn.lock();
|
||||
|
|
@ -974,10 +1056,15 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn fts5_syncs_on_update() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("upd_key", "original_content_111", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
|
||||
mem.store(
|
||||
"upd_key",
|
||||
"original_content_111",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1019,10 +1106,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn reindex_rebuilds_fts() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core)
|
||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("r2", "reindex test beta", MemoryCategory::Core)
|
||||
mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -1031,7 +1118,7 @@ mod tests {
|
|||
assert_eq!(count, 0);
|
||||
|
||||
// FTS should still work after rebuild
|
||||
let results = mem.recall("reindex", 10).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
}
|
||||
|
||||
|
|
@ -1045,12 +1132,13 @@ mod tests {
|
|||
&format!("k{i}"),
|
||||
&format!("common keyword item {i}"),
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
let results = mem.recall("common keyword", 5).await.unwrap();
|
||||
let results = mem.recall("common keyword", 5, None).await.unwrap();
|
||||
assert!(results.len() <= 5);
|
||||
}
|
||||
|
||||
|
|
@ -1059,11 +1147,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_results_have_scores() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("s1", "scored result test", MemoryCategory::Core)
|
||||
mem.store("s1", "scored result test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let results = mem.recall("scored", 10).await.unwrap();
|
||||
let results = mem.recall("scored", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
for r in &results {
|
||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||
|
|
@ -1075,11 +1163,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_with_quotes_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("q1", "He said hello world", MemoryCategory::Core)
|
||||
mem.store("q1", "He said hello world", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Quotes in query should not crash FTS5
|
||||
let results = mem.recall("\"hello\"", 10).await.unwrap();
|
||||
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
|
||||
// May or may not match depending on FTS5 escaping, but must not error
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
|
@ -1087,31 +1175,34 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_with_asterisk_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core)
|
||||
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("wild*", 10).await.unwrap();
|
||||
let results = mem.recall("wild*", 10, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_with_parentheses_in_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("p1", "function call test", MemoryCategory::Core)
|
||||
mem.store("p1", "function call test", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("function()", 10).await.unwrap();
|
||||
let results = mem.recall("function()", 10, None).await.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_with_sql_injection_attempt() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("safe", "normal content", MemoryCategory::Core)
|
||||
mem.store("safe", "normal content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Should not crash or leak data
|
||||
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
|
||||
let results = mem
|
||||
.recall("'; DROP TABLE memories; --", 10, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(results.len() <= 10);
|
||||
// Table should still exist
|
||||
assert_eq!(mem.count().await.unwrap(), 1);
|
||||
|
|
@ -1122,7 +1213,9 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_empty_content() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("empty", "", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("empty").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "");
|
||||
}
|
||||
|
|
@ -1130,7 +1223,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_empty_key() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("", "content for empty key", MemoryCategory::Core)
|
||||
mem.store("", "content for empty key", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("").await.unwrap().unwrap();
|
||||
|
|
@ -1141,7 +1234,7 @@ mod tests {
|
|||
async fn store_very_long_content() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let long_content = "x".repeat(100_000);
|
||||
mem.store("long", &long_content, MemoryCategory::Core)
|
||||
mem.store("long", &long_content, MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("long").await.unwrap().unwrap();
|
||||
|
|
@ -1151,9 +1244,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn store_unicode_and_emoji() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"emoji_key_🦀",
|
||||
"こんにちは 🚀 Ñoño",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
||||
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
|
||||
}
|
||||
|
|
@ -1162,7 +1260,7 @@ mod tests {
|
|||
async fn store_content_with_newlines_and_tabs() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
||||
mem.store("whitespace", content, MemoryCategory::Core)
|
||||
mem.store("whitespace", content, MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
||||
|
|
@ -1174,11 +1272,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_single_character_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "x marks the spot", MemoryCategory::Core)
|
||||
mem.store("a", "x marks the spot", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Single char may not match FTS5 but LIKE fallback should work
|
||||
let results = mem.recall("x", 10).await.unwrap();
|
||||
let results = mem.recall("x", 10, None).await.unwrap();
|
||||
// Should not crash; may or may not find results
|
||||
assert!(results.len() <= 10);
|
||||
}
|
||||
|
|
@ -1186,23 +1284,23 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_limit_zero() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "some content", MemoryCategory::Core)
|
||||
mem.store("a", "some content", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("some", 0).await.unwrap();
|
||||
let results = mem.recall("some", 0, None).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_limit_one() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("a", "matching content alpha", MemoryCategory::Core)
|
||||
mem.store("a", "matching content alpha", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core)
|
||||
mem.store("b", "matching content beta", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("matching content", 1).await.unwrap();
|
||||
let results = mem.recall("matching content", 1, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -1213,21 +1311,22 @@ mod tests {
|
|||
"rust_preferences",
|
||||
"User likes systems programming",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||
let results = mem.recall("rust", 10).await.unwrap();
|
||||
let results = mem.recall("rust", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty(), "Should match by key");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_unicode_query() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
|
||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let results = mem.recall("日本語", 10).await.unwrap();
|
||||
let results = mem.recall("日本語", 10, None).await.unwrap();
|
||||
assert!(!results.is_empty());
|
||||
}
|
||||
|
||||
|
|
@ -1238,7 +1337,9 @@ mod tests {
|
|||
let tmp = TempDir::new().unwrap();
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
|
||||
mem.store("k1", "v1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
// Open again — init_schema runs again on existing DB
|
||||
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||
|
|
@ -1246,7 +1347,9 @@ mod tests {
|
|||
assert!(entry.is_some());
|
||||
assert_eq!(entry.unwrap().content, "v1");
|
||||
// Store more data — should work fine
|
||||
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
|
||||
mem2.store("k2", "v2", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(mem2.count().await.unwrap(), 2);
|
||||
}
|
||||
|
||||
|
|
@ -1264,11 +1367,16 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_then_recall_no_ghost_results() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"ghost",
|
||||
"phantom memory content",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("ghost").await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10).await.unwrap();
|
||||
let results = mem.recall("phantom memory", 10, None).await.unwrap();
|
||||
assert!(
|
||||
results.is_empty(),
|
||||
"Deleted memory should not appear in recall"
|
||||
|
|
@ -1278,11 +1386,11 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_and_re_store_same_key() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("cycle", "version 1", MemoryCategory::Core)
|
||||
mem.store("cycle", "version 1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.forget("cycle").await.unwrap();
|
||||
mem.store("cycle", "version 2", MemoryCategory::Core)
|
||||
mem.store("cycle", "version 2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
let entry = mem.get("cycle").await.unwrap().unwrap();
|
||||
|
|
@ -1302,14 +1410,14 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn reindex_twice_is_safe() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("r1", "reindex data", MemoryCategory::Core)
|
||||
mem.store("r1", "reindex data", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.reindex().await.unwrap();
|
||||
let count = mem.reindex().await.unwrap();
|
||||
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||
// Data should still be intact
|
||||
let results = mem.recall("reindex", 10).await.unwrap();
|
||||
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
|
|
@ -1363,18 +1471,28 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn list_custom_category() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c3", "other", MemoryCategory::Core)
|
||||
mem.store(
|
||||
"c1",
|
||||
"custom1",
|
||||
MemoryCategory::Custom("project".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store(
|
||||
"c2",
|
||||
"custom2",
|
||||
MemoryCategory::Custom("project".into()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("c3", "other", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let project = mem
|
||||
.list(Some(&MemoryCategory::Custom("project".into())))
|
||||
.list(Some(&MemoryCategory::Custom("project".into())), None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(project.len(), 2);
|
||||
|
|
@ -1383,7 +1501,122 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn list_empty_db() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
let all = mem.list(None).await.unwrap();
|
||||
let all = mem.list(None, None).await.unwrap();
|
||||
assert!(all.is_empty());
|
||||
}
|
||||
|
||||
// ── Session isolation ─────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn store_and_recall_with_session_id() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "no session fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Recall with session-a filter returns only session-a entry
|
||||
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn recall_no_session_filter_returns_all() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "gamma fact", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Recall without session filter returns all matching entries
|
||||
let results = mem.recall("fact", 10, None).await.unwrap();
|
||||
assert_eq!(results.len(), 3);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cross_session_recall_isolation() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store(
|
||||
"secret",
|
||||
"session A secret data",
|
||||
MemoryCategory::Core,
|
||||
Some("sess-a"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Session B cannot see session A data
|
||||
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
|
||||
assert!(results.is_empty());
|
||||
|
||||
// Session A can see its own data
|
||||
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn list_with_session_filter() {
|
||||
let (_tmp, mem) = temp_sqlite();
|
||||
mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("k4", "none1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// List with session-a filter
|
||||
let results = mem.list(None, Some("sess-a")).await.unwrap();
|
||||
assert_eq!(results.len(), 2);
|
||||
assert!(results
|
||||
.iter()
|
||||
.all(|e| e.session_id.as_deref() == Some("sess-a")));
|
||||
|
||||
// List with session-a + category filter
|
||||
let results = mem
|
||||
.list(Some(&MemoryCategory::Core), Some("sess-a"))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn schema_migration_idempotent_on_reopen() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
||||
// First open: creates schema + migration
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Second open: migration runs again but is idempotent
|
||||
{
|
||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
|
||||
assert_eq!(results.len(), 1);
|
||||
assert_eq!(results[0].key, "k1");
|
||||
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
|
|||
/// Backend name
|
||||
fn name(&self) -> &str;
|
||||
|
||||
/// Store a memory entry
|
||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
||||
-> anyhow::Result<()>;
|
||||
/// Store a memory entry, optionally scoped to a session
|
||||
async fn store(
|
||||
&self,
|
||||
key: &str,
|
||||
content: &str,
|
||||
category: MemoryCategory,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<()>;
|
||||
|
||||
/// Recall memories matching a query (keyword search)
|
||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
||||
async fn recall(
|
||||
&self,
|
||||
query: &str,
|
||||
limit: usize,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Get a specific memory by key
|
||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||
|
||||
/// List all memory keys, optionally filtered by category
|
||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
/// List all memory keys, optionally filtered by category and/or session
|
||||
async fn list(
|
||||
&self,
|
||||
category: Option<&MemoryCategory>,
|
||||
session_id: Option<&str>,
|
||||
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||
|
||||
/// Remove a memory by key
|
||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||
|
|
|
|||
|
|
@ -95,7 +95,9 @@ async fn migrate_openclaw_memory(
|
|||
stats.renamed_conflicts += 1;
|
||||
}
|
||||
|
||||
memory.store(&key, &entry.content, entry.category).await?;
|
||||
memory
|
||||
.store(&key, &entry.content, entry.category, None)
|
||||
.await?;
|
||||
stats.imported += 1;
|
||||
}
|
||||
|
||||
|
|
@ -488,7 +490,7 @@ mod tests {
|
|||
// Existing target memory
|
||||
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
||||
target_mem
|
||||
.store("k", "new value", MemoryCategory::Core)
|
||||
.store("k", "new value", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -510,7 +512,7 @@ mod tests {
|
|||
.await
|
||||
.unwrap();
|
||||
|
||||
let all = target_mem.list(None).await.unwrap();
|
||||
let all = target_mem.list(None, None).await.unwrap();
|
||||
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
|
||||
assert!(all
|
||||
.iter()
|
||||
|
|
|
|||
|
|
@ -48,9 +48,10 @@ impl Observer for LogObserver {
|
|||
ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used,
|
||||
cost_usd,
|
||||
} => {
|
||||
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
|
||||
info!(duration_ms = ms, tokens = ?tokens_used, "agent.end");
|
||||
info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end");
|
||||
}
|
||||
ObserverEvent::ToolCallStart { tool } => {
|
||||
info!(tool = %tool, "tool.start");
|
||||
|
|
@ -133,10 +134,12 @@ mod tests {
|
|||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(500),
|
||||
tokens_used: Some(100),
|
||||
cost_usd: Some(0.0015),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||
tool: "shell".into(),
|
||||
|
|
|
|||
|
|
@ -48,10 +48,12 @@ mod tests {
|
|||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(100),
|
||||
tokens_used: Some(42),
|
||||
cost_usd: Some(0.001),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||
tool: "shell".into(),
|
||||
|
|
|
|||
|
|
@ -227,6 +227,7 @@ impl Observer for OtelObserver {
|
|||
ObserverEvent::AgentEnd {
|
||||
duration,
|
||||
tokens_used,
|
||||
cost_usd,
|
||||
} => {
|
||||
let secs = duration.as_secs_f64();
|
||||
let start_time = SystemTime::now()
|
||||
|
|
@ -243,6 +244,9 @@ impl Observer for OtelObserver {
|
|||
if let Some(t) = tokens_used {
|
||||
span.set_attribute(KeyValue::new("tokens_used", *t as i64));
|
||||
}
|
||||
if let Some(c) = cost_usd {
|
||||
span.set_attribute(KeyValue::new("cost_usd", *c));
|
||||
}
|
||||
span.end();
|
||||
|
||||
self.agent_duration.record(secs, &[]);
|
||||
|
|
@ -394,10 +398,12 @@ mod tests {
|
|||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::from_millis(500),
|
||||
tokens_used: Some(100),
|
||||
cost_usd: Some(0.0015),
|
||||
});
|
||||
obs.record_event(&ObserverEvent::AgentEnd {
|
||||
duration: Duration::ZERO,
|
||||
tokens_used: None,
|
||||
cost_usd: None,
|
||||
});
|
||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||
tool: "shell".into(),
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ pub enum ObserverEvent {
|
|||
AgentEnd {
|
||||
duration: Duration,
|
||||
tokens_used: Option<u64>,
|
||||
cost_usd: Option<f64>,
|
||||
},
|
||||
/// A tool call is about to be executed.
|
||||
ToolCallStart {
|
||||
|
|
|
|||
|
|
@ -106,6 +106,7 @@ pub fn run_wizard() -> Result<Config> {
|
|||
} else {
|
||||
Some(api_key)
|
||||
},
|
||||
api_url: None,
|
||||
default_provider: Some(provider),
|
||||
default_model: Some(model),
|
||||
default_temperature: 0.7,
|
||||
|
|
@ -284,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
|||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn run_quick_setup(
|
||||
api_key: Option<&str>,
|
||||
credential_override: Option<&str>,
|
||||
provider: Option<&str>,
|
||||
memory_backend: Option<&str>,
|
||||
) -> Result<Config> {
|
||||
|
|
@ -318,7 +319,8 @@ pub fn run_quick_setup(
|
|||
let config = Config {
|
||||
workspace_dir: workspace_dir.clone(),
|
||||
config_path: config_path.clone(),
|
||||
api_key: api_key.map(String::from),
|
||||
api_key: credential_override.map(String::from),
|
||||
api_url: None,
|
||||
default_provider: Some(provider_name.clone()),
|
||||
default_model: Some(model.clone()),
|
||||
default_temperature: 0.7,
|
||||
|
|
@ -377,7 +379,7 @@ pub fn run_quick_setup(
|
|||
println!(
|
||||
" {} API Key: {}",
|
||||
style("✓").green().bold(),
|
||||
if api_key.is_some() {
|
||||
if credential_override.is_some() {
|
||||
style("set").green()
|
||||
} else {
|
||||
style("not set (use --api-key or edit config.toml)").yellow()
|
||||
|
|
@ -426,7 +428,7 @@ pub fn run_quick_setup(
|
|||
);
|
||||
println!();
|
||||
println!(" {}", style("Next steps:").white().bold());
|
||||
if api_key.is_none() {
|
||||
if credential_override.is_none() {
|
||||
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
|
||||
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
|
||||
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
|
||||
|
|
@ -2269,14 +2271,11 @@ fn setup_memory() -> Result<MemoryConfig> {
|
|||
let backend = backend_key_from_choice(choice);
|
||||
let profile = memory_backend_profile(backend);
|
||||
|
||||
let auto_save = if !profile.auto_save_default {
|
||||
false
|
||||
} else {
|
||||
Confirm::new()
|
||||
let auto_save = profile.auto_save_default
|
||||
&& Confirm::new()
|
||||
.with_prompt(" Auto-save conversations to memory?")
|
||||
.default(true)
|
||||
.interact()?
|
||||
};
|
||||
.interact()?;
|
||||
|
||||
println!(
|
||||
" {} Memory: {} (auto-save: {})",
|
||||
|
|
@ -2587,6 +2586,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
|||
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
||||
allowed_users,
|
||||
listen_to_bots: false,
|
||||
mention_only: false,
|
||||
});
|
||||
}
|
||||
2 => {
|
||||
|
|
@ -2799,22 +2799,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
|||
.header("Authorization", format!("Bearer {access_token_clone}"))
|
||||
.send()?;
|
||||
let ok = resp.status().is_success();
|
||||
let data: serde_json::Value = resp.json().unwrap_or_default();
|
||||
let user_id = data
|
||||
.get("user_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.unwrap_or("unknown")
|
||||
.to_string();
|
||||
Ok::<_, reqwest::Error>((ok, user_id))
|
||||
Ok::<_, reqwest::Error>(ok)
|
||||
})
|
||||
.join();
|
||||
match thread_result {
|
||||
Ok(Ok((true, user_id))) => {
|
||||
println!(
|
||||
"\r {} Connected as {user_id} ",
|
||||
style("✅").green().bold()
|
||||
);
|
||||
}
|
||||
Ok(Ok(true)) => println!(
|
||||
"\r {} Connection verified ",
|
||||
style("✅").green().bold()
|
||||
),
|
||||
_ => {
|
||||
println!(
|
||||
"\r {} Connection failed — check homeserver URL and token",
|
||||
|
|
@ -3779,15 +3771,7 @@ fn print_summary(config: &Config) {
|
|||
);
|
||||
|
||||
// Secrets
|
||||
println!(
|
||||
" {} Secrets: {}",
|
||||
style("🔒").cyan(),
|
||||
if config.secrets.encrypt {
|
||||
style("encrypted").green().to_string()
|
||||
} else {
|
||||
style("plaintext").yellow().to_string()
|
||||
}
|
||||
);
|
||||
println!(" {} Secrets: configured", style("🔒").cyan());
|
||||
|
||||
// Gateway
|
||||
println!(
|
||||
|
|
|
|||
|
|
@ -38,6 +38,10 @@ pub fn ensure_arduino_cli() -> Result<()> {
|
|||
anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/");
|
||||
}
|
||||
println!("arduino-cli installed.");
|
||||
if !arduino_cli_available() {
|
||||
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
|
||||
}
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
#[cfg(target_os = "linux")]
|
||||
|
|
@ -54,11 +58,6 @@ pub fn ensure_arduino_cli() -> Result<()> {
|
|||
println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/");
|
||||
anyhow::bail!("arduino-cli not installed.");
|
||||
}
|
||||
|
||||
if !arduino_cli_available() {
|
||||
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Ensure arduino:avr core is installed.
|
||||
|
|
|
|||
|
|
@ -112,6 +112,7 @@ pub struct SerialPeripheral {
|
|||
|
||||
impl SerialPeripheral {
|
||||
/// Create and connect to a serial peripheral.
|
||||
#[allow(clippy::unused_async)]
|
||||
pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> {
|
||||
let path = config
|
||||
.path
|
||||
|
|
|
|||
|
|
@ -106,17 +106,17 @@ struct NativeContentIn {
|
|||
}
|
||||
|
||||
impl AnthropicProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
Self::with_base_url(api_key, None)
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self::with_base_url(credential, None)
|
||||
}
|
||||
|
||||
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
|
||||
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
|
||||
let base_url = base_url
|
||||
.map(|u| u.trim_end_matches('/'))
|
||||
.unwrap_or("https://api.anthropic.com")
|
||||
.to_string();
|
||||
Self {
|
||||
credential: api_key
|
||||
credential: credential
|
||||
.map(str::trim)
|
||||
.filter(|k| !k.is_empty())
|
||||
.map(ToString::to_string),
|
||||
|
|
@ -410,9 +410,9 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = AnthropicProvider::new(Some("sk-ant-test123"));
|
||||
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||
}
|
||||
|
||||
|
|
@ -431,17 +431,19 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_whitespace_key() {
|
||||
let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
|
||||
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
|
||||
assert!(p.credential.is_some());
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_custom_base_url() {
|
||||
let p =
|
||||
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
|
||||
let p = AnthropicProvider::with_base_url(
|
||||
Some("anthropic-credential"),
|
||||
Some("https://api.example.com"),
|
||||
);
|
||||
assert_eq!(p.base_url, "https://api.example.com");
|
||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
|
||||
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
|
|||
|
|
@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
|
|||
pub struct OpenAiCompatibleProvider {
|
||||
pub(crate) name: String,
|
||||
pub(crate) base_url: String,
|
||||
pub(crate) api_key: Option<String>,
|
||||
pub(crate) credential: Option<String>,
|
||||
pub(crate) auth_header: AuthStyle,
|
||||
/// When false, do not fall back to /v1/responses on chat completions 404.
|
||||
/// GLM/Zhipu does not support the responses API.
|
||||
|
|
@ -37,11 +37,16 @@ pub enum AuthStyle {
|
|||
}
|
||||
|
||||
impl OpenAiCompatibleProvider {
|
||||
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
|
||||
pub fn new(
|
||||
name: &str,
|
||||
base_url: &str,
|
||||
credential: Option<&str>,
|
||||
auth_style: AuthStyle,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
auth_header: auth_style,
|
||||
supports_responses_fallback: true,
|
||||
client: Client::builder()
|
||||
|
|
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
|
|||
pub fn new_no_responses_fallback(
|
||||
name: &str,
|
||||
base_url: &str,
|
||||
api_key: Option<&str>,
|
||||
credential: Option<&str>,
|
||||
auth_style: AuthStyle,
|
||||
) -> Self {
|
||||
Self {
|
||||
name: name.to_string(),
|
||||
base_url: base_url.trim_end_matches('/').to_string(),
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
auth_header: auth_style,
|
||||
supports_responses_fallback: false,
|
||||
client: Client::builder()
|
||||
|
|
@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
|
|||
fn apply_auth_header(
|
||||
&self,
|
||||
req: reqwest::RequestBuilder,
|
||||
api_key: &str,
|
||||
credential: &str,
|
||||
) -> reqwest::RequestBuilder {
|
||||
match &self.auth_header {
|
||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
|
||||
AuthStyle::XApiKey => req.header("x-api-key", api_key),
|
||||
AuthStyle::Custom(header) => req.header(header, api_key),
|
||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
|
||||
AuthStyle::XApiKey => req.header("x-api-key", credential),
|
||||
AuthStyle::Custom(header) => req.header(header, credential),
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_via_responses(
|
||||
&self,
|
||||
api_key: &str,
|
||||
credential: &str,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
|
|
@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
|
|||
let url = self.responses_url();
|
||||
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
|
|
@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
let url = self.chat_completions_url();
|
||||
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||
return self
|
||||
.chat_via_responses(api_key, system_prompt, message, model)
|
||||
.chat_via_responses(credential, system_prompt, message, model)
|
||||
.await
|
||||
.map_err(|responses_err| {
|
||||
anyhow::anyhow!(
|
||||
|
|
@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||
self.name
|
||||
|
|
@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
|
||||
let url = self.chat_completions_url();
|
||||
let response = self
|
||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
||||
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
|
|
@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
|
|||
if let Some(user_msg) = last_user {
|
||||
return self
|
||||
.chat_via_responses(
|
||||
api_key,
|
||||
credential,
|
||||
system.map(|m| m.content.as_str()),
|
||||
&user_msg.content,
|
||||
model,
|
||||
|
|
@ -791,16 +796,20 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
|
||||
let p = make_provider(
|
||||
"venice",
|
||||
"https://api.venice.ai",
|
||||
Some("venice-test-credential"),
|
||||
);
|
||||
assert_eq!(p.name, "venice");
|
||||
assert_eq!(p.base_url, "https://api.venice.ai");
|
||||
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
|
||||
assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = make_provider("test", "https://example.com", None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -894,6 +903,7 @@ mod tests {
|
|||
make_provider("Groq", "https://api.groq.com/openai", None),
|
||||
make_provider("Mistral", "https://api.mistral.ai", None),
|
||||
make_provider("xAI", "https://api.x.ai", None),
|
||||
make_provider("Astrai", "https://as-trai.com/v1", None),
|
||||
];
|
||||
|
||||
for p in providers {
|
||||
|
|
|
|||
705
src/providers/copilot.rs
Normal file
705
src/providers/copilot.rs
Normal file
|
|
@ -0,0 +1,705 @@
|
|||
//! GitHub Copilot provider with OAuth device-flow authentication.
|
||||
//!
|
||||
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
|
||||
//! then exchanges the OAuth token for short-lived Copilot API keys.
|
||||
//! Tokens are cached to disk and auto-refreshed.
|
||||
//!
|
||||
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
|
||||
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
|
||||
//! and other third-party Copilot integrations. The Copilot token endpoint is
|
||||
//! private; there is no public OAuth scope or app registration for it.
|
||||
//! GitHub could change or revoke this at any time, which would break all
|
||||
//! third-party integrations simultaneously.
|
||||
|
||||
use crate::providers::traits::{
|
||||
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||
Provider, ToolCall as ProviderToolCall,
|
||||
};
|
||||
use crate::tools::ToolSpec;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use tokio::sync::Mutex;
|
||||
use tracing::warn;
|
||||
|
||||
/// GitHub OAuth client ID for Copilot (VS Code extension).
|
||||
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
|
||||
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
|
||||
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
|
||||
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
|
||||
const DEFAULT_API: &str = "https://api.githubcopilot.com";
|
||||
|
||||
// ── Token types ──────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct DeviceCodeResponse {
|
||||
device_code: String,
|
||||
user_code: String,
|
||||
verification_uri: String,
|
||||
#[serde(default = "default_interval")]
|
||||
interval: u64,
|
||||
#[serde(default = "default_expires_in")]
|
||||
expires_in: u64,
|
||||
}
|
||||
|
||||
fn default_interval() -> u64 {
|
||||
5
|
||||
}
|
||||
|
||||
fn default_expires_in() -> u64 {
|
||||
900
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct AccessTokenResponse {
|
||||
access_token: Option<String>,
|
||||
error: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiKeyInfo {
|
||||
token: String,
|
||||
expires_at: i64,
|
||||
#[serde(default)]
|
||||
endpoints: Option<ApiEndpoints>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct ApiEndpoints {
|
||||
api: Option<String>,
|
||||
}
|
||||
|
||||
struct CachedApiKey {
|
||||
token: String,
|
||||
api_endpoint: String,
|
||||
expires_at: i64,
|
||||
}
|
||||
|
||||
// ── Chat completions types ───────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiChatRequest {
|
||||
model: String,
|
||||
messages: Vec<ApiMessage>,
|
||||
temperature: f64,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tools: Option<Vec<NativeToolSpec>>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_choice: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ApiMessage {
|
||||
role: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
content: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_call_id: Option<String>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolSpec {
|
||||
#[serde(rename = "type")]
|
||||
kind: String,
|
||||
function: NativeToolFunctionSpec,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct NativeToolFunctionSpec {
|
||||
name: String,
|
||||
description: String,
|
||||
parameters: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeToolCall {
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
id: Option<String>,
|
||||
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||
kind: Option<String>,
|
||||
function: NativeFunctionCall,
|
||||
}
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize)]
|
||||
struct NativeFunctionCall {
|
||||
name: String,
|
||||
arguments: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
choices: Vec<Choice>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Choice {
|
||||
message: ResponseMessage,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: Option<String>,
|
||||
#[serde(default)]
|
||||
tool_calls: Option<Vec<NativeToolCall>>,
|
||||
}
|
||||
|
||||
// ── Provider ─────────────────────────────────────────────────────
|
||||
|
||||
/// GitHub Copilot provider with automatic OAuth and token refresh.
|
||||
///
|
||||
/// On first use, prompts the user to visit github.com/login/device.
|
||||
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
|
||||
/// automatically.
|
||||
pub struct CopilotProvider {
|
||||
github_token: Option<String>,
|
||||
/// Mutex ensures only one caller refreshes tokens at a time,
|
||||
/// preventing duplicate device flow prompts or redundant API calls.
|
||||
refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
|
||||
http: Client,
|
||||
token_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl CopilotProvider {
|
||||
pub fn new(github_token: Option<&str>) -> Self {
|
||||
let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
|
||||
.map(|dir| dir.config_dir().join("copilot"))
|
||||
.unwrap_or_else(|| {
|
||||
// Fall back to a user-specific temp directory to avoid
|
||||
// shared-directory symlink attacks.
|
||||
let user = std::env::var("USER")
|
||||
.or_else(|_| std::env::var("USERNAME"))
|
||||
.unwrap_or_else(|_| "unknown".to_string());
|
||||
std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
|
||||
});
|
||||
|
||||
if let Err(err) = std::fs::create_dir_all(&token_dir) {
|
||||
warn!(
|
||||
"Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
|
||||
token_dir
|
||||
);
|
||||
} else {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
|
||||
if let Err(err) =
|
||||
std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
|
||||
{
|
||||
warn!(
|
||||
"Failed to set Copilot token directory permissions on {:?}: {err}",
|
||||
token_dir
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Self {
|
||||
github_token: github_token
|
||||
.filter(|token| !token.is_empty())
|
||||
.map(String::from),
|
||||
refresh_lock: Arc::new(Mutex::new(None)),
|
||||
http: Client::builder()
|
||||
.timeout(Duration::from_secs(120))
|
||||
.connect_timeout(Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
token_dir,
|
||||
}
|
||||
}
|
||||
|
||||
/// Required headers for Copilot API requests (editor identification).
|
||||
const COPILOT_HEADERS: [(&str, &str); 4] = [
|
||||
("Editor-Version", "vscode/1.85.1"),
|
||||
("Editor-Plugin-Version", "copilot/1.155.0"),
|
||||
("User-Agent", "GithubCopilot/1.155.0"),
|
||||
("Accept", "application/json"),
|
||||
];
|
||||
|
||||
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||
tools.map(|items| {
|
||||
items
|
||||
.iter()
|
||||
.map(|tool| NativeToolSpec {
|
||||
kind: "function".to_string(),
|
||||
function: NativeToolFunctionSpec {
|
||||
name: tool.name.clone(),
|
||||
description: tool.description.clone(),
|
||||
parameters: tool.parameters.clone(),
|
||||
},
|
||||
})
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
|
||||
messages
|
||||
.iter()
|
||||
.map(|message| {
|
||||
if message.role == "assistant" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||
if let Ok(parsed_calls) =
|
||||
serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
|
||||
{
|
||||
let tool_calls = parsed_calls
|
||||
.into_iter()
|
||||
.map(|tool_call| NativeToolCall {
|
||||
id: Some(tool_call.id),
|
||||
kind: Some("function".to_string()),
|
||||
function: NativeFunctionCall {
|
||||
name: tool_call.name,
|
||||
arguments: tool_call.arguments,
|
||||
},
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "assistant".to_string(),
|
||||
content,
|
||||
tool_call_id: None,
|
||||
tool_calls: Some(tool_calls),
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if message.role == "tool" {
|
||||
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||
let tool_call_id = value
|
||||
.get("tool_call_id")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
let content = value
|
||||
.get("content")
|
||||
.and_then(serde_json::Value::as_str)
|
||||
.map(ToString::to_string);
|
||||
|
||||
return ApiMessage {
|
||||
role: "tool".to_string(),
|
||||
content,
|
||||
tool_call_id,
|
||||
tool_calls: None,
|
||||
};
|
||||
}
|
||||
}
|
||||
|
||||
ApiMessage {
|
||||
role: message.role.clone(),
|
||||
content: Some(message.content.clone()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Send a chat completions request with required Copilot headers.
|
||||
async fn send_chat_request(
|
||||
&self,
|
||||
messages: Vec<ApiMessage>,
|
||||
tools: Option<&[ToolSpec]>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let (token, endpoint) = self.get_api_key().await?;
|
||||
let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
|
||||
|
||||
let native_tools = Self::convert_tools(tools);
|
||||
let request = ApiChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
temperature,
|
||||
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||
tools: native_tools,
|
||||
};
|
||||
|
||||
let mut req = self
|
||||
.http
|
||||
.post(&url)
|
||||
.header("Authorization", format!("Bearer {token}"))
|
||||
.json(&request);
|
||||
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
req = req.header(*header, *value);
|
||||
}
|
||||
|
||||
let response = req.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
return Err(super::api_error("GitHub Copilot", response).await);
|
||||
}
|
||||
|
||||
let api_response: ApiChatResponse = response.json().await?;
|
||||
let choice = api_response
|
||||
.choices
|
||||
.into_iter()
|
||||
.next()
|
||||
.ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
|
||||
|
||||
let tool_calls = choice
|
||||
.message
|
||||
.tool_calls
|
||||
.unwrap_or_default()
|
||||
.into_iter()
|
||||
.map(|tool_call| ProviderToolCall {
|
||||
id: tool_call
|
||||
.id
|
||||
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||
name: tool_call.function.name,
|
||||
arguments: tool_call.function.arguments,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(ProviderChatResponse {
|
||||
text: choice.message.content,
|
||||
tool_calls,
|
||||
})
|
||||
}
|
||||
|
||||
/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
|
||||
/// Uses a Mutex to ensure only one caller refreshes at a time.
|
||||
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
|
||||
let mut cached = self.refresh_lock.lock().await;
|
||||
|
||||
if let Some(cached_key) = cached.as_ref() {
|
||||
if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
|
||||
return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(info) = self.load_api_key_from_disk().await {
|
||||
if chrono::Utc::now().timestamp() + 120 < info.expires_at {
|
||||
let endpoint = info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
let token = info.token;
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: info.expires_at,
|
||||
});
|
||||
return Ok((token, endpoint));
|
||||
}
|
||||
}
|
||||
|
||||
let access_token = self.get_github_access_token().await?;
|
||||
let api_key_info = self.exchange_for_api_key(&access_token).await?;
|
||||
self.save_api_key_to_disk(&api_key_info).await;
|
||||
|
||||
let endpoint = api_key_info
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.and_then(|e| e.api.clone())
|
||||
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||
|
||||
*cached = Some(CachedApiKey {
|
||||
token: api_key_info.token.clone(),
|
||||
api_endpoint: endpoint.clone(),
|
||||
expires_at: api_key_info.expires_at,
|
||||
});
|
||||
|
||||
Ok((api_key_info.token, endpoint))
|
||||
}
|
||||
|
||||
/// Get a GitHub access token from config, cache, or device flow.
|
||||
async fn get_github_access_token(&self) -> anyhow::Result<String> {
|
||||
if let Some(token) = &self.github_token {
|
||||
return Ok(token.clone());
|
||||
}
|
||||
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
|
||||
let token = cached.trim();
|
||||
if !token.is_empty() {
|
||||
return Ok(token.to_string());
|
||||
}
|
||||
}
|
||||
|
||||
let token = self.device_code_login().await?;
|
||||
write_file_secure(&access_token_path, &token).await;
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
/// Run GitHub OAuth device code flow.
|
||||
async fn device_code_login(&self) -> anyhow::Result<String> {
|
||||
let response: DeviceCodeResponse = self
|
||||
.http
|
||||
.post(GITHUB_DEVICE_CODE_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"scope": "read:user"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
let mut poll_interval = Duration::from_secs(response.interval.max(5));
|
||||
let expires_in = response.expires_in.max(1);
|
||||
let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
|
||||
|
||||
eprintln!(
|
||||
"\nGitHub Copilot authentication is required.\n\
|
||||
Visit: {}\n\
|
||||
Code: {}\n\
|
||||
Waiting for authorization...\n",
|
||||
response.verification_uri, response.user_code
|
||||
);
|
||||
|
||||
while tokio::time::Instant::now() < expires_at {
|
||||
tokio::time::sleep(poll_interval).await;
|
||||
|
||||
let token_response: AccessTokenResponse = self
|
||||
.http
|
||||
.post(GITHUB_ACCESS_TOKEN_URL)
|
||||
.header("Accept", "application/json")
|
||||
.json(&serde_json::json!({
|
||||
"client_id": GITHUB_CLIENT_ID,
|
||||
"device_code": response.device_code,
|
||||
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
|
||||
}))
|
||||
.send()
|
||||
.await?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
if let Some(token) = token_response.access_token {
|
||||
eprintln!("Authentication succeeded.\n");
|
||||
return Ok(token);
|
||||
}
|
||||
|
||||
match token_response.error.as_deref() {
|
||||
Some("slow_down") => {
|
||||
poll_interval += Duration::from_secs(5);
|
||||
}
|
||||
Some("authorization_pending") | None => {}
|
||||
Some("expired_token") => {
|
||||
anyhow::bail!("GitHub device authorization expired")
|
||||
}
|
||||
Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
|
||||
}
|
||||
}
|
||||
|
||||
anyhow::bail!("Timed out waiting for GitHub authorization")
|
||||
}
|
||||
|
||||
/// Exchange a GitHub access token for a Copilot API key.
|
||||
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
|
||||
let mut request = self.http.get(GITHUB_API_KEY_URL);
|
||||
for (header, value) in &Self::COPILOT_HEADERS {
|
||||
request = request.header(*header, *value);
|
||||
}
|
||||
request = request.header("Authorization", format!("token {access_token}"));
|
||||
|
||||
let response = request.send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
let sanitized = super::sanitize_api_error(&body);
|
||||
|
||||
if status.as_u16() == 401 || status.as_u16() == 403 {
|
||||
let access_token_path = self.token_dir.join("access-token");
|
||||
tokio::fs::remove_file(&access_token_path).await.ok();
|
||||
}
|
||||
|
||||
anyhow::bail!(
|
||||
"Failed to get Copilot API key ({status}): {sanitized}. \
|
||||
Ensure your GitHub account has an active Copilot subscription."
|
||||
);
|
||||
}
|
||||
|
||||
let info: ApiKeyInfo = response.json().await?;
|
||||
Ok(info)
|
||||
}
|
||||
|
||||
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
let data = tokio::fs::read_to_string(&path).await.ok()?;
|
||||
serde_json::from_str(&data).ok()
|
||||
}
|
||||
|
||||
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
|
||||
let path = self.token_dir.join("api-key.json");
|
||||
if let Ok(json) = serde_json::to_string_pretty(info) {
|
||||
write_file_secure(&path, &json).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Write a file with 0600 permissions (owner read/write only).
|
||||
/// Uses `spawn_blocking` to avoid blocking the async runtime.
|
||||
async fn write_file_secure(path: &Path, content: &str) {
|
||||
let path = path.to_path_buf();
|
||||
let content = content.to_string();
|
||||
|
||||
let result = tokio::task::spawn_blocking(move || {
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::io::Write;
|
||||
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
|
||||
|
||||
let mut file = std::fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600)
|
||||
.open(&path)?;
|
||||
file.write_all(content.as_bytes())?;
|
||||
|
||||
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
std::fs::write(&path, &content)?;
|
||||
Ok::<(), std::io::Error>(())
|
||||
}
|
||||
})
|
||||
.await;
|
||||
|
||||
match result {
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
|
||||
Err(err) => warn!("Failed to spawn blocking write: {err}"),
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CopilotProvider {
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
system_prompt: Option<&str>,
|
||||
message: &str,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let mut messages = Vec::new();
|
||||
if let Some(system) = system_prompt {
|
||||
messages.push(ApiMessage {
|
||||
role: "system".to_string(),
|
||||
content: Some(system.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
}
|
||||
messages.push(ApiMessage {
|
||||
role: "user".to_string(),
|
||||
content: Some(message.to_string()),
|
||||
tool_call_id: None,
|
||||
tool_calls: None,
|
||||
});
|
||||
|
||||
let response = self
|
||||
.send_chat_request(messages, None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let response = self
|
||||
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
|
||||
.await?;
|
||||
Ok(response.text.unwrap_or_default())
|
||||
}
|
||||
|
||||
async fn chat(
|
||||
&self,
|
||||
request: ProviderChatRequest<'_>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
self.send_chat_request(
|
||||
Self::convert_messages(request.messages),
|
||||
request.tools,
|
||||
model,
|
||||
temperature,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
let _ = self.get_api_key().await?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn new_without_token() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
assert!(provider.github_token.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn new_with_token() {
|
||||
let provider = CopilotProvider::new(Some("ghp_test"));
|
||||
assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn empty_token_treated_as_none() {
|
||||
let provider = CopilotProvider::new(Some(""));
|
||||
assert!(provider.github_token.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cache_starts_empty() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
let cached = provider.refresh_lock.lock().await;
|
||||
assert!(cached.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn copilot_headers_include_required_fields() {
|
||||
let headers = CopilotProvider::COPILOT_HEADERS;
|
||||
assert!(headers
|
||||
.iter()
|
||||
.any(|(header, _)| *header == "Editor-Version"));
|
||||
assert!(headers
|
||||
.iter()
|
||||
.any(|(header, _)| *header == "Editor-Plugin-Version"));
|
||||
assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn default_interval_and_expiry() {
|
||||
assert_eq!(default_interval(), 5);
|
||||
assert_eq!(default_expires_in(), 900);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_native_tools() {
|
||||
let provider = CopilotProvider::new(None);
|
||||
assert!(provider.supports_native_tools());
|
||||
}
|
||||
}
|
||||
|
|
@ -1,5 +1,6 @@
|
|||
pub mod anthropic;
|
||||
pub mod compatible;
|
||||
pub mod copilot;
|
||||
pub mod gemini;
|
||||
pub mod ollama;
|
||||
pub mod openai;
|
||||
|
|
@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
|
|||
|
||||
/// Scrub known secret-like token prefixes from provider error strings.
|
||||
///
|
||||
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`.
|
||||
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
|
||||
/// `ghu_`, and `github_pat_`.
|
||||
pub fn scrub_secret_patterns(input: &str) -> String {
|
||||
const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
|
||||
const PREFIXES: [&str; 7] = [
|
||||
"sk-",
|
||||
"xoxb-",
|
||||
"xoxp-",
|
||||
"ghp_",
|
||||
"gho_",
|
||||
"ghu_",
|
||||
"github_pat_",
|
||||
];
|
||||
|
||||
let mut scrubbed = input.to_string();
|
||||
|
||||
|
|
@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
|
|||
///
|
||||
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
|
||||
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
|
||||
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
||||
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
|
||||
return Some(key.to_string());
|
||||
fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
|
||||
if let Some(raw_override) = credential_override {
|
||||
let trimmed_override = raw_override.trim();
|
||||
if !trimmed_override.is_empty() {
|
||||
return Some(trimmed_override.to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
let provider_env_candidates: Vec<&str> = match name {
|
||||
|
|
@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
|||
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
|
||||
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
|
||||
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
|
||||
"astrai" => vec!["ASTRAI_API_KEY"],
|
||||
_ => vec![],
|
||||
};
|
||||
|
||||
|
|
@ -182,19 +196,28 @@ fn parse_custom_provider_url(
|
|||
}
|
||||
}
|
||||
|
||||
/// Factory: create the right provider from config
|
||||
#[allow(clippy::too_many_lines)]
|
||||
/// Factory: create the right provider from config (without custom URL)
|
||||
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let resolved_key = resolve_api_key(name, api_key);
|
||||
let key = resolved_key.as_deref();
|
||||
create_provider_with_url(name, api_key, None)
|
||||
}
|
||||
|
||||
/// Factory: create the right provider from config with optional custom base URL
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn create_provider_with_url(
|
||||
name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let resolved_credential = resolve_provider_credential(name, api_key);
|
||||
#[allow(clippy::option_as_ref_deref)]
|
||||
let key = resolved_credential.as_ref().map(String::as_str);
|
||||
match name {
|
||||
// ── Primary providers (custom implementations) ───────
|
||||
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
|
||||
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
|
||||
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
|
||||
// Ollama is a local service that doesn't use API keys.
|
||||
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
|
||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
|
||||
// Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
|
||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
|
||||
"gemini" | "google" | "google-gemini" => {
|
||||
Ok(Box::new(gemini::GeminiProvider::new(key)))
|
||||
}
|
||||
|
|
@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
|
||||
"Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
|
||||
|
|
@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
|
||||
))),
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer,
|
||||
"copilot" | "github-copilot" => {
|
||||
Ok(Box::new(copilot::CopilotProvider::new(api_key)))
|
||||
},
|
||||
"lmstudio" | "lm-studio" => {
|
||||
let lm_studio_key = api_key
|
||||
.map(str::trim)
|
||||
.filter(|value| !value.is_empty())
|
||||
.unwrap_or("lm-studio");
|
||||
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"LM Studio",
|
||||
"http://localhost:1234/v1",
|
||||
Some(lm_studio_key),
|
||||
AuthStyle::Bearer,
|
||||
)))
|
||||
}
|
||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
|
||||
OpenAiCompatibleProvider::new(
|
||||
"NVIDIA NIM",
|
||||
"https://integrate.api.nvidia.com/v1",
|
||||
key,
|
||||
AuthStyle::Bearer,
|
||||
),
|
||||
)),
|
||||
|
||||
// ── AI inference routers ─────────────────────────────
|
||||
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
|
||||
))),
|
||||
|
||||
// ── Bring Your Own Provider (custom URL) ───────────
|
||||
|
|
@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
|||
pub fn create_resilient_provider(
|
||||
primary_name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
reliability: &crate::config::ReliabilityConfig,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||
|
||||
providers.push((
|
||||
primary_name.to_string(),
|
||||
create_provider(primary_name, api_key)?,
|
||||
create_provider_with_url(primary_name, api_key, api_url)?,
|
||||
));
|
||||
|
||||
for fallback in &reliability.fallback_providers {
|
||||
|
|
@ -340,21 +386,13 @@ pub fn create_resilient_provider(
|
|||
continue;
|
||||
}
|
||||
|
||||
if api_key.is_some() && fallback != "ollama" {
|
||||
tracing::warn!(
|
||||
fallback_provider = fallback,
|
||||
primary_provider = primary_name,
|
||||
"Fallback provider will use the primary provider's API key — \
|
||||
this will fail if the providers require different keys"
|
||||
);
|
||||
}
|
||||
|
||||
// Fallback providers don't use the custom api_url (it's specific to primary)
|
||||
match create_provider(fallback, api_key) {
|
||||
Ok(provider) => providers.push((fallback.clone(), provider)),
|
||||
Err(e) => {
|
||||
Err(_error) => {
|
||||
tracing::warn!(
|
||||
fallback_provider = fallback,
|
||||
"Ignoring invalid fallback provider: {e}"
|
||||
"Ignoring invalid fallback provider during initialization"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -377,12 +415,13 @@ pub fn create_resilient_provider(
|
|||
pub fn create_routed_provider(
|
||||
primary_name: &str,
|
||||
api_key: Option<&str>,
|
||||
api_url: Option<&str>,
|
||||
reliability: &crate::config::ReliabilityConfig,
|
||||
model_routes: &[crate::config::ModelRouteConfig],
|
||||
default_model: &str,
|
||||
) -> anyhow::Result<Box<dyn Provider>> {
|
||||
if model_routes.is_empty() {
|
||||
return create_resilient_provider(primary_name, api_key, reliability);
|
||||
return create_resilient_provider(primary_name, api_key, api_url, reliability);
|
||||
}
|
||||
|
||||
// Collect unique provider names needed
|
||||
|
|
@ -396,12 +435,19 @@ pub fn create_routed_provider(
|
|||
// Create each provider (with its own resilience wrapper)
|
||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||
for name in &needed {
|
||||
let key = model_routes
|
||||
let routed_credential = model_routes
|
||||
.iter()
|
||||
.find(|r| &r.provider == name)
|
||||
.and_then(|r| r.api_key.as_deref())
|
||||
.or(api_key);
|
||||
match create_resilient_provider(name, key, reliability) {
|
||||
.and_then(|r| {
|
||||
r.api_key.as_ref().and_then(|raw_key| {
|
||||
let trimmed_key = raw_key.trim();
|
||||
(!trimmed_key.is_empty()).then_some(trimmed_key)
|
||||
})
|
||||
});
|
||||
let key = routed_credential.or(api_key);
|
||||
// Only use api_url for the primary provider
|
||||
let url = if name == primary_name { api_url } else { None };
|
||||
match create_resilient_provider(name, key, url, reliability) {
|
||||
Ok(provider) => providers.push((name.clone(), provider)),
|
||||
Err(e) => {
|
||||
if name == primary_name {
|
||||
|
|
@ -409,7 +455,7 @@ pub fn create_routed_provider(
|
|||
}
|
||||
tracing::warn!(
|
||||
provider = name.as_str(),
|
||||
"Ignoring routed provider that failed to create: {e}"
|
||||
"Ignoring routed provider that failed to initialize"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
@ -441,27 +487,27 @@ mod tests {
|
|||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn resolve_api_key_prefers_explicit_argument() {
|
||||
let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
|
||||
assert_eq!(resolved.as_deref(), Some("explicit-key"));
|
||||
fn resolve_provider_credential_prefers_explicit_argument() {
|
||||
let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
|
||||
assert_eq!(resolved, Some("explicit-key".to_string()));
|
||||
}
|
||||
|
||||
// ── Primary providers ────────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_openrouter() {
|
||||
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
|
||||
assert!(create_provider("openrouter", None).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_anthropic() {
|
||||
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_openai() {
|
||||
assert!(create_provider("openai", Some("sk-test")).is_ok());
|
||||
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -556,6 +602,13 @@ mod tests {
|
|||
assert!(create_provider("dashscope-us", Some("key")).is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_lmstudio() {
|
||||
assert!(create_provider("lmstudio", Some("key")).is_ok());
|
||||
assert!(create_provider("lm-studio", Some("key")).is_ok());
|
||||
assert!(create_provider("lmstudio", None).is_ok());
|
||||
}
|
||||
|
||||
// ── Extended ecosystem ───────────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
@ -614,6 +667,13 @@ mod tests {
|
|||
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
|
||||
}
|
||||
|
||||
// ── AI inference routers ─────────────────────────────────
|
||||
|
||||
#[test]
|
||||
fn factory_astrai() {
|
||||
assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok());
|
||||
}
|
||||
|
||||
// ── Custom / BYOP provider ─────────────────────────────
|
||||
|
||||
#[test]
|
||||
|
|
@ -761,17 +821,33 @@ mod tests {
|
|||
scheduler_retries: 2,
|
||||
};
|
||||
|
||||
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
|
||||
let provider = create_resilient_provider(
|
||||
"openrouter",
|
||||
Some("provider-test-credential"),
|
||||
None,
|
||||
&reliability,
|
||||
);
|
||||
assert!(provider.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn resilient_provider_errors_for_invalid_primary() {
|
||||
let reliability = crate::config::ReliabilityConfig::default();
|
||||
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
|
||||
let provider = create_resilient_provider(
|
||||
"totally-invalid",
|
||||
Some("provider-test-credential"),
|
||||
None,
|
||||
&reliability,
|
||||
);
|
||||
assert!(provider.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ollama_with_custom_url() {
|
||||
let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434"));
|
||||
assert!(provider.is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn factory_all_providers_create_successfully() {
|
||||
let providers = [
|
||||
|
|
@ -794,6 +870,7 @@ mod tests {
|
|||
"qwen",
|
||||
"qwen-intl",
|
||||
"qwen-us",
|
||||
"lmstudio",
|
||||
"groq",
|
||||
"mistral",
|
||||
"xai",
|
||||
|
|
@ -888,7 +965,7 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn sanitize_preserves_unicode_boundaries() {
|
||||
let input = format!("{} sk-abcdef123", "こんにちは".repeat(80));
|
||||
let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
|
||||
let result = sanitize_api_error(&input);
|
||||
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
|
||||
assert!(!result.contains("sk-abcdef123"));
|
||||
|
|
@ -900,4 +977,32 @@ mod tests {
|
|||
let result = sanitize_api_error(input);
|
||||
assert_eq!(result, input);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_personal_access_token() {
|
||||
let input = "auth failed with token ghp_abc123def456";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "auth failed with token [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_oauth_token() {
|
||||
let input = "Bearer gho_1234567890abcdef";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "Bearer [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_user_token() {
|
||||
let input = "token ghu_sessiontoken123";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "token [REDACTED]");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn scrub_github_fine_grained_pat() {
|
||||
let input = "failed: github_pat_11AABBC_xyzzy789";
|
||||
let result = scrub_secret_patterns(input);
|
||||
assert_eq!(result, "failed: [REDACTED]");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,6 +8,8 @@ pub struct OllamaProvider {
|
|||
client: Client,
|
||||
}
|
||||
|
||||
// ─── Request Structures ───────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Serialize)]
|
||||
struct ChatRequest {
|
||||
model: String,
|
||||
|
|
@ -27,6 +29,8 @@ struct Options {
|
|||
temperature: f64,
|
||||
}
|
||||
|
||||
// ─── Response Structures ──────────────────────────────────────────────────────
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ApiChatResponse {
|
||||
message: ResponseMessage,
|
||||
|
|
@ -34,9 +38,30 @@ struct ApiChatResponse {
|
|||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct ResponseMessage {
|
||||
#[serde(default)]
|
||||
content: String,
|
||||
#[serde(default)]
|
||||
tool_calls: Vec<OllamaToolCall>,
|
||||
/// Some models return a "thinking" field with internal reasoning
|
||||
#[serde(default)]
|
||||
thinking: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaToolCall {
|
||||
id: Option<String>,
|
||||
function: OllamaFunction,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OllamaFunction {
|
||||
name: String,
|
||||
#[serde(default)]
|
||||
arguments: serde_json::Value,
|
||||
}
|
||||
|
||||
// ─── Implementation ───────────────────────────────────────────────────────────
|
||||
|
||||
impl OllamaProvider {
|
||||
pub fn new(base_url: Option<&str>) -> Self {
|
||||
Self {
|
||||
|
|
@ -45,12 +70,145 @@ impl OllamaProvider {
|
|||
.trim_end_matches('/')
|
||||
.to_string(),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
|
||||
.timeout(std::time::Duration::from_secs(300))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Send a request to Ollama and get the parsed response
|
||||
async fn send_request(
|
||||
&self,
|
||||
messages: Vec<Message>,
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ApiChatResponse> {
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
tracing::debug!(
|
||||
"Ollama request: url={} model={} message_count={} temperature={}",
|
||||
url,
|
||||
model,
|
||||
request.messages.len(),
|
||||
temperature
|
||||
);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
let status = response.status();
|
||||
tracing::debug!("Ollama response status: {}", status);
|
||||
|
||||
let body = response.bytes().await?;
|
||||
tracing::debug!("Ollama response body length: {} bytes", body.len());
|
||||
|
||||
if !status.is_success() {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
let sanitized = super::sanitize_api_error(&raw);
|
||||
tracing::error!(
|
||||
"Ollama error response: status={} body_excerpt={}",
|
||||
status,
|
||||
sanitized
|
||||
);
|
||||
anyhow::bail!(
|
||||
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
|
||||
status,
|
||||
sanitized
|
||||
);
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
let raw = String::from_utf8_lossy(&body);
|
||||
let sanitized = super::sanitize_api_error(&raw);
|
||||
tracing::error!(
|
||||
"Ollama response deserialization failed: {e}. body_excerpt={}",
|
||||
sanitized
|
||||
);
|
||||
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||
}
|
||||
};
|
||||
|
||||
Ok(chat_response)
|
||||
}
|
||||
|
||||
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
|
||||
///
|
||||
/// Handles quirky model behavior where tool calls are wrapped:
|
||||
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
|
||||
/// - `{"name": "tool.shell", "arguments": {...}}`
|
||||
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
|
||||
let formatted_calls: Vec<serde_json::Value> = tool_calls
|
||||
.iter()
|
||||
.map(|tc| {
|
||||
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
|
||||
|
||||
// Arguments must be a JSON string for parse_tool_calls compatibility
|
||||
let args_str =
|
||||
serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
|
||||
|
||||
serde_json::json!({
|
||||
"id": tc.id,
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": tool_name,
|
||||
"arguments": args_str
|
||||
}
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
|
||||
serde_json::json!({
|
||||
"content": "",
|
||||
"tool_calls": formatted_calls
|
||||
})
|
||||
.to_string()
|
||||
}
|
||||
|
||||
/// Extract the actual tool name and arguments from potentially nested structures
|
||||
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
|
||||
let name = &tc.function.name;
|
||||
let args = &tc.function.arguments;
|
||||
|
||||
// Pattern 1: Nested tool_call wrapper (various malformed versions)
|
||||
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
|
||||
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
|
||||
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
|
||||
if name == "tool_call"
|
||||
|| name == "tool.call"
|
||||
|| name.starts_with("tool_call>")
|
||||
|| name.starts_with("tool_call<")
|
||||
{
|
||||
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
|
||||
let nested_args = args
|
||||
.get("arguments")
|
||||
.cloned()
|
||||
.unwrap_or(serde_json::json!({}));
|
||||
tracing::debug!(
|
||||
"Unwrapped nested tool call: {} -> {} with args {:?}",
|
||||
name,
|
||||
nested_name,
|
||||
nested_args
|
||||
);
|
||||
return (nested_name.to_string(), nested_args);
|
||||
}
|
||||
}
|
||||
|
||||
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
|
||||
if let Some(stripped) = name.strip_prefix("tool.") {
|
||||
return (stripped.to_string(), args.clone());
|
||||
}
|
||||
|
||||
// Pattern 3: Normal tool call
|
||||
(name.clone(), args.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
|
|
@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
|
|||
content: message.to_string(),
|
||||
});
|
||||
|
||||
let request = ChatRequest {
|
||||
model: model.to_string(),
|
||||
messages,
|
||||
stream: false,
|
||||
options: Options { temperature },
|
||||
};
|
||||
let response = self.send_request(messages, model, temperature).await?;
|
||||
|
||||
let url = format!("{}/api/chat", self.base_url);
|
||||
|
||||
let response = self.client.post(&url).json(&request).send().await?;
|
||||
|
||||
if !response.status().is_success() {
|
||||
let err = super::api_error("Ollama", response).await;
|
||||
anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)");
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
let chat_response: ApiChatResponse = response.json().await?;
|
||||
Ok(chat_response.message.content)
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
async fn chat_with_history(
|
||||
&self,
|
||||
messages: &[crate::providers::ChatMessage],
|
||||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_messages: Vec<Message> = messages
|
||||
.iter()
|
||||
.map(|m| Message {
|
||||
role: m.role.clone(),
|
||||
content: m.content.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let response = self.send_request(api_messages, model, temperature).await?;
|
||||
|
||||
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||
if !response.message.tool_calls.is_empty() {
|
||||
tracing::debug!(
|
||||
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||
response.message.tool_calls.len()
|
||||
);
|
||||
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||
}
|
||||
|
||||
// Plain text response
|
||||
let content = response.message.content;
|
||||
|
||||
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||
// This is a model quirk - it stopped after reasoning without producing output
|
||||
if content.is_empty() {
|
||||
if let Some(thinking) = &response.message.thinking {
|
||||
tracing::warn!(
|
||||
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||
);
|
||||
// Return a message indicating the model's thought process but no action
|
||||
return Ok(format!(
|
||||
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||
));
|
||||
}
|
||||
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||
}
|
||||
|
||||
Ok(content)
|
||||
}
|
||||
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
||||
// The model may return native tool_calls but we convert them to JSON format
|
||||
// that parse_tool_calls() understands
|
||||
false
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Tests ────────────────────────────────────────────────────────────────────
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
|
@ -125,46 +352,6 @@ mod tests {
|
|||
assert_eq!(p.base_url, "");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_with_system() {
|
||||
let req = ChatRequest {
|
||||
model: "llama3".to_string(),
|
||||
messages: vec![
|
||||
Message {
|
||||
role: "system".to_string(),
|
||||
content: "You are ZeroClaw".to_string(),
|
||||
},
|
||||
Message {
|
||||
role: "user".to_string(),
|
||||
content: "hello".to_string(),
|
||||
},
|
||||
],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.7 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(json.contains("\"stream\":false"));
|
||||
assert!(json.contains("llama3"));
|
||||
assert!(json.contains("system"));
|
||||
assert!(json.contains("\"temperature\":0.7"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_serializes_without_system() {
|
||||
let req = ChatRequest {
|
||||
model: "mistral".to_string(),
|
||||
messages: vec![Message {
|
||||
role: "user".to_string(),
|
||||
content: "test".to_string(),
|
||||
}],
|
||||
stream: false,
|
||||
options: Options { temperature: 0.0 },
|
||||
};
|
||||
let json = serde_json::to_string(&req).unwrap();
|
||||
assert!(!json.contains("\"role\":\"system\""));
|
||||
assert!(json.contains("mistral"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_deserializes() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||
|
|
@ -180,9 +367,98 @@ mod tests {
|
|||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_multiline() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
||||
fn response_with_missing_content_defaults_to_empty() {
|
||||
let json = r#"{"message":{"role":"assistant"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.contains("line1"));
|
||||
assert!(resp.message.content.is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_thinking_field_extracts_content() {
|
||||
let json =
|
||||
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert_eq!(resp.message.content, "hello");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn response_with_tool_calls_parses_correctly() {
|
||||
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
|
||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||
assert!(resp.message.content.is_empty());
|
||||
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_nested_tool_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool_call".into(),
|
||||
arguments: serde_json::json!({
|
||||
"name": "shell",
|
||||
"arguments": {"command": "date"}
|
||||
}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "date");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_prefixed_name() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "tool.shell".into(),
|
||||
arguments: serde_json::json!({"command": "ls"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "shell");
|
||||
assert_eq!(args.get("command").unwrap(), "ls");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn extract_tool_name_handles_normal_call() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tc = OllamaToolCall {
|
||||
id: Some("call_123".into()),
|
||||
function: OllamaFunction {
|
||||
name: "file_read".into(),
|
||||
arguments: serde_json::json!({"path": "/tmp/test"}),
|
||||
},
|
||||
};
|
||||
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||
assert_eq!(name, "file_read");
|
||||
assert_eq!(args.get("path").unwrap(), "/tmp/test");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn format_tool_calls_produces_valid_json() {
|
||||
let provider = OllamaProvider::new(None);
|
||||
let tool_calls = vec![OllamaToolCall {
|
||||
id: Some("call_abc".into()),
|
||||
function: OllamaFunction {
|
||||
name: "shell".into(),
|
||||
arguments: serde_json::json!({"command": "date"}),
|
||||
},
|
||||
}];
|
||||
|
||||
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
|
||||
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
|
||||
|
||||
assert!(parsed.get("tool_calls").is_some());
|
||||
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
|
||||
assert_eq!(calls.len(), 1);
|
||||
|
||||
let func = calls[0].get("function").unwrap();
|
||||
assert_eq!(func.get("name").unwrap(), "shell");
|
||||
// arguments should be a string (JSON-encoded)
|
||||
assert!(func.get("arguments").unwrap().is_string());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenAiProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
|||
}
|
||||
|
||||
impl OpenAiProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
|
|
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||
})?;
|
||||
|
||||
|
|
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://api.openai.com/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.json(&native_request)
|
||||
.send()
|
||||
.await?;
|
||||
|
|
@ -330,20 +330,20 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
|
||||
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
|
||||
let p = OpenAiProvider::new(Some("openai-test-credential"));
|
||||
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let p = OpenAiProvider::new(None);
|
||||
assert!(p.api_key.is_none());
|
||||
assert!(p.credential.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_with_empty_key() {
|
||||
let p = OpenAiProvider::new(Some(""));
|
||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
||||
assert_eq!(p.credential.as_deref(), Some(""));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
|||
use serde::{Deserialize, Serialize};
|
||||
|
||||
pub struct OpenRouterProvider {
|
||||
api_key: Option<String>,
|
||||
credential: Option<String>,
|
||||
client: Client,
|
||||
}
|
||||
|
||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
|||
}
|
||||
|
||||
impl OpenRouterProvider {
|
||||
pub fn new(api_key: Option<&str>) -> Self {
|
||||
pub fn new(credential: Option<&str>) -> Self {
|
||||
Self {
|
||||
api_key: api_key.map(ToString::to_string),
|
||||
credential: credential.map(ToString::to_string),
|
||||
client: Client::builder()
|
||||
.timeout(std::time::Duration::from_secs(120))
|
||||
.connect_timeout(std::time::Duration::from_secs(10))
|
||||
|
|
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
|
|||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
||||
// This prevents the first real chat request from timing out on cold start.
|
||||
if let Some(api_key) = self.api_key.as_ref() {
|
||||
if let Some(credential) = self.credential.as_ref() {
|
||||
self.client
|
||||
.get("https://openrouter.ai/api/v1/auth/key")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.send()
|
||||
.await?
|
||||
.error_for_status()?;
|
||||
|
|
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let mut messages = Vec::new();
|
||||
|
|
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
let api_key = self.api_key.as_ref()
|
||||
let credential = self.credential.as_ref()
|
||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||
|
||||
let api_messages: Vec<Message> = messages
|
||||
|
|
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
|
|
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
|
|||
model: &str,
|
||||
temperature: f64,
|
||||
) -> anyhow::Result<ProviderChatResponse> {
|
||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
||||
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||
anyhow::anyhow!(
|
||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||
)
|
||||
|
|
@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
|
|||
let response = self
|
||||
.client
|
||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||
.header("Authorization", format!("Bearer {api_key}"))
|
||||
.header("Authorization", format!("Bearer {credential}"))
|
||||
.header(
|
||||
"HTTP-Referer",
|
||||
"https://github.com/theonlyhennygod/zeroclaw",
|
||||
|
|
@ -494,14 +494,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn creates_with_key() {
|
||||
let provider = OpenRouterProvider::new(Some("sk-or-123"));
|
||||
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
|
||||
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
|
||||
assert_eq!(
|
||||
provider.credential.as_deref(),
|
||||
Some("openrouter-test-credential")
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn creates_without_key() {
|
||||
let provider = OpenRouterProvider::new(None);
|
||||
assert!(provider.api_key.is_none());
|
||||
assert!(provider.credential.is_none());
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
|
|
|||
|
|
@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
|
|||
async fn warmup(&self) -> anyhow::Result<()> {
|
||||
for (name, provider) in &self.providers {
|
||||
tracing::info!(provider = name, "Warming up provider connection pool");
|
||||
if let Err(e) = provider.warmup().await {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
|
||||
if provider.warmup().await.is_err() {
|
||||
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
|
|
@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
|
|||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
|
@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
|
|||
let non_retryable = is_non_retryable(&e);
|
||||
let rate_limited = is_rate_limited(&e);
|
||||
|
||||
let failure_reason = if rate_limited {
|
||||
"rate_limited"
|
||||
} else if non_retryable {
|
||||
"non_retryable"
|
||||
} else {
|
||||
"retryable"
|
||||
};
|
||||
failures.push(format!(
|
||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
||||
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||
attempt + 1,
|
||||
self.max_retries + 1
|
||||
));
|
||||
|
|
|
|||
|
|
@ -193,6 +193,13 @@ pub enum StreamError {
|
|||
|
||||
#[async_trait]
|
||||
pub trait Provider: Send + Sync {
|
||||
/// Query provider capabilities.
|
||||
///
|
||||
/// Default implementation returns minimal capabilities (no native tool calling).
|
||||
/// Providers should override this to declare their actual capabilities.
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities::default()
|
||||
}
|
||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||
///
|
||||
/// This is the preferred API for non-agentic direct interactions.
|
||||
|
|
@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
|
|||
|
||||
/// Whether provider supports native tool calls over API.
|
||||
fn supports_native_tools(&self) -> bool {
|
||||
false
|
||||
self.capabilities().native_tool_calling
|
||||
}
|
||||
|
||||
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
||||
|
|
@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
|
|||
mod tests {
|
||||
use super::*;
|
||||
|
||||
struct CapabilityMockProvider;
|
||||
|
||||
#[async_trait]
|
||||
impl Provider for CapabilityMockProvider {
|
||||
fn capabilities(&self) -> ProviderCapabilities {
|
||||
ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
}
|
||||
}
|
||||
|
||||
async fn chat_with_system(
|
||||
&self,
|
||||
_system_prompt: Option<&str>,
|
||||
_message: &str,
|
||||
_model: &str,
|
||||
_temperature: f64,
|
||||
) -> anyhow::Result<String> {
|
||||
Ok("ok".into())
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn chat_message_constructors() {
|
||||
let sys = ChatMessage::system("Be helpful");
|
||||
|
|
@ -398,4 +426,32 @@ mod tests {
|
|||
let json = serde_json::to_string(&tool_result).unwrap();
|
||||
assert!(json.contains("\"type\":\"ToolResults\""));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_capabilities_default() {
|
||||
let caps = ProviderCapabilities::default();
|
||||
assert!(!caps.native_tool_calling);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn provider_capabilities_equality() {
|
||||
let caps1 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
};
|
||||
let caps2 = ProviderCapabilities {
|
||||
native_tool_calling: true,
|
||||
};
|
||||
let caps3 = ProviderCapabilities {
|
||||
native_tool_calling: false,
|
||||
};
|
||||
|
||||
assert_eq!(caps1, caps2);
|
||||
assert_ne!(caps1, caps3);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn supports_native_tools_reflects_capabilities_default_mapping() {
|
||||
let provider = CapabilityMockProvider;
|
||||
assert!(provider.supports_native_tools());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -81,14 +81,17 @@ mod tests {
|
|||
|
||||
#[test]
|
||||
fn bubblewrap_sandbox_name() {
|
||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
||||
let sandbox = BubblewrapSandbox;
|
||||
assert_eq!(sandbox.name(), "bubblewrap");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn bubblewrap_is_available_only_if_installed() {
|
||||
// Result depends on whether bwrap is installed
|
||||
let available = BubblewrapSandbox::is_available();
|
||||
let sandbox = BubblewrapSandbox;
|
||||
let _available = sandbox.is_available();
|
||||
|
||||
// Either way, the name should still work
|
||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
||||
assert_eq!(sandbox.name(), "bubblewrap");
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -184,7 +184,7 @@ fn generate_token() -> String {
|
|||
use rand::RngCore;
|
||||
let mut bytes = [0u8; 32];
|
||||
rand::thread_rng().fill_bytes(&mut bytes);
|
||||
format!("zc_{}", hex::encode(&bytes))
|
||||
format!("zc_{}", hex::encode(bytes))
|
||||
}
|
||||
|
||||
/// SHA-256 hash a bearer token for storage. Returns lowercase hex.
|
||||
|
|
|
|||
|
|
@ -343,6 +343,7 @@ impl SecurityPolicy {
|
|||
/// validates each sub-command against the allowlist
|
||||
/// - Blocks single `&` background chaining (`&&` remains supported)
|
||||
/// - Blocks output redirections (`>`, `>>`) that could write outside workspace
|
||||
/// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
|
||||
pub fn is_command_allowed(&self, command: &str) -> bool {
|
||||
if self.autonomy == AutonomyLevel::ReadOnly {
|
||||
return false;
|
||||
|
|
@ -350,7 +351,12 @@ impl SecurityPolicy {
|
|||
|
||||
// Block subshell/expansion operators — these allow hiding arbitrary
|
||||
// commands inside an allowed command (e.g. `echo $(rm -rf /)`)
|
||||
if command.contains('`') || command.contains("$(") || command.contains("${") {
|
||||
if command.contains('`')
|
||||
|| command.contains("$(")
|
||||
|| command.contains("${")
|
||||
|| command.contains("<(")
|
||||
|| command.contains(">(")
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
|
|
@ -359,6 +365,15 @@ impl SecurityPolicy {
|
|||
return false;
|
||||
}
|
||||
|
||||
// Block `tee` — it can write to arbitrary files, bypassing the
|
||||
// redirect check above (e.g. `echo secret | tee /etc/crontab`)
|
||||
if command
|
||||
.split_whitespace()
|
||||
.any(|w| w == "tee" || w.ends_with("/tee"))
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Block background command chaining (`&`), which can hide extra
|
||||
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
|
||||
if contains_single_ampersand(command) {
|
||||
|
|
@ -384,13 +399,9 @@ impl SecurityPolicy {
|
|||
// Strip leading env var assignments (e.g. FOO=bar cmd)
|
||||
let cmd_part = skip_env_assignments(segment);
|
||||
|
||||
let base_cmd = cmd_part
|
||||
.split_whitespace()
|
||||
.next()
|
||||
.unwrap_or("")
|
||||
.rsplit('/')
|
||||
.next()
|
||||
.unwrap_or("");
|
||||
let mut words = cmd_part.split_whitespace();
|
||||
let base_raw = words.next().unwrap_or("");
|
||||
let base_cmd = base_raw.rsplit('/').next().unwrap_or("");
|
||||
|
||||
if base_cmd.is_empty() {
|
||||
continue;
|
||||
|
|
@ -403,6 +414,12 @@ impl SecurityPolicy {
|
|||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
// Validate arguments for the command
|
||||
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
|
||||
if !self.is_args_safe(base_cmd, &args) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
// At least one command must be present
|
||||
|
|
@ -414,6 +431,29 @@ impl SecurityPolicy {
|
|||
has_cmd
|
||||
}
|
||||
|
||||
/// Check for dangerous arguments that allow sub-command execution.
|
||||
fn is_args_safe(&self, base: &str, args: &[String]) -> bool {
|
||||
let base = base.to_ascii_lowercase();
|
||||
match base.as_str() {
|
||||
"find" => {
|
||||
// find -exec and find -ok allow arbitrary command execution
|
||||
!args.iter().any(|arg| arg == "-exec" || arg == "-ok")
|
||||
}
|
||||
"git" => {
|
||||
// git config, alias, and -c can be used to set dangerous options
|
||||
// (e.g. git config core.editor "rm -rf /")
|
||||
!args.iter().any(|arg| {
|
||||
arg == "config"
|
||||
|| arg.starts_with("config.")
|
||||
|| arg == "alias"
|
||||
|| arg.starts_with("alias.")
|
||||
|| arg == "-c"
|
||||
})
|
||||
}
|
||||
_ => true,
|
||||
}
|
||||
}
|
||||
|
||||
/// Check if a file path is allowed (no path traversal, within workspace)
|
||||
pub fn is_path_allowed(&self, path: &str) -> bool {
|
||||
// Block null bytes (can truncate paths in C-backed syscalls)
|
||||
|
|
@ -982,12 +1022,43 @@ mod tests {
|
|||
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_argument_injection_blocked() {
|
||||
let p = default_policy();
|
||||
// find -exec is a common bypass
|
||||
assert!(!p.is_command_allowed("find . -exec rm -rf {} +"));
|
||||
assert!(!p.is_command_allowed("find / -ok cat {} \\;"));
|
||||
// git config/alias can execute commands
|
||||
assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\""));
|
||||
assert!(!p.is_command_allowed("git alias.st status"));
|
||||
assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit"));
|
||||
// Legitimate commands should still work
|
||||
assert!(p.is_command_allowed("find . -name '*.txt'"));
|
||||
assert!(p.is_command_allowed("git status"));
|
||||
assert!(p.is_command_allowed("git add ."));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_injection_dollar_brace_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_injection_tee_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed("echo secret | tee /etc/crontab"));
|
||||
assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile"));
|
||||
assert!(!p.is_command_allowed("tee file.txt"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_injection_process_substitution_blocked() {
|
||||
let p = default_policy();
|
||||
assert!(!p.is_command_allowed("cat <(echo pwned)"));
|
||||
assert!(!p.is_command_allowed("ls >(cat /etc/passwd)"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn command_env_var_prefix_with_allowed_cmd() {
|
||||
let p = default_policy();
|
||||
|
|
|
|||
|
|
@ -854,7 +854,6 @@ impl BrowserTool {
|
|||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
#[async_trait]
|
||||
impl Tool for BrowserTool {
|
||||
fn name(&self) -> &str {
|
||||
|
|
@ -1031,165 +1030,21 @@ impl Tool for BrowserTool {
|
|||
return self.execute_computer_use_action(action_str, &args).await;
|
||||
}
|
||||
|
||||
let action = match action_str {
|
||||
"open" => {
|
||||
let url = args
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
|
||||
BrowserAction::Open { url: url.into() }
|
||||
}
|
||||
"snapshot" => BrowserAction::Snapshot {
|
||||
interactive_only: args
|
||||
.get("interactive_only")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true), // Default to interactive for AI
|
||||
compact: args
|
||||
.get("compact")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
depth: args
|
||||
.get("depth")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
|
||||
},
|
||||
"click" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
|
||||
BrowserAction::Click {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"fill" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
|
||||
BrowserAction::Fill {
|
||||
selector: selector.into(),
|
||||
value: value.into(),
|
||||
}
|
||||
}
|
||||
"type" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
|
||||
let text = args
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
|
||||
BrowserAction::Type {
|
||||
selector: selector.into(),
|
||||
text: text.into(),
|
||||
}
|
||||
}
|
||||
"get_text" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
|
||||
BrowserAction::GetText {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"get_title" => BrowserAction::GetTitle,
|
||||
"get_url" => BrowserAction::GetUrl,
|
||||
"screenshot" => BrowserAction::Screenshot {
|
||||
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
|
||||
full_page: args
|
||||
.get("full_page")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(false),
|
||||
},
|
||||
"wait" => BrowserAction::Wait {
|
||||
selector: args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
ms: args.get("ms").and_then(serde_json::Value::as_u64),
|
||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||
},
|
||||
"press" => {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
|
||||
BrowserAction::Press { key: key.into() }
|
||||
}
|
||||
"hover" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
|
||||
BrowserAction::Hover {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"scroll" => {
|
||||
let direction = args
|
||||
.get("direction")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
|
||||
BrowserAction::Scroll {
|
||||
direction: direction.into(),
|
||||
pixels: args
|
||||
.get("pixels")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
|
||||
}
|
||||
}
|
||||
"is_visible" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
|
||||
BrowserAction::IsVisible {
|
||||
selector: selector.into(),
|
||||
}
|
||||
}
|
||||
"close" => BrowserAction::Close,
|
||||
"find" => {
|
||||
let by = args
|
||||
.get("by")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
|
||||
let action = args
|
||||
.get("find_action")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
|
||||
BrowserAction::Find {
|
||||
by: by.into(),
|
||||
value: value.into(),
|
||||
action: action.into(),
|
||||
fill_value: args
|
||||
.get("fill_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
}
|
||||
}
|
||||
_ => {
|
||||
if is_computer_use_only_action(action_str) {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(unavailable_action_for_backend_error(action_str, backend)),
|
||||
});
|
||||
}
|
||||
|
||||
let action = match parse_browser_action(action_str, &args) {
|
||||
Ok(a) => a,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Action '{action_str}' is unavailable for backend '{}'",
|
||||
match backend {
|
||||
ResolvedBackend::AgentBrowser => "agent_browser",
|
||||
ResolvedBackend::RustNative => "rust_native",
|
||||
ResolvedBackend::ComputerUse => "computer_use",
|
||||
}
|
||||
)),
|
||||
error: Some(e.to_string()),
|
||||
});
|
||||
}
|
||||
};
|
||||
|
|
@ -1871,6 +1726,161 @@ mod native_backend {
|
|||
}
|
||||
}
|
||||
|
||||
// ── Action parsing ──────────────────────────────────────────────
|
||||
|
||||
/// Parse a JSON `args` object into a typed `BrowserAction`.
|
||||
fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<BrowserAction> {
|
||||
match action_str {
|
||||
"open" => {
|
||||
let url = args
|
||||
.get("url")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
|
||||
Ok(BrowserAction::Open { url: url.into() })
|
||||
}
|
||||
"snapshot" => Ok(BrowserAction::Snapshot {
|
||||
interactive_only: args
|
||||
.get("interactive_only")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
compact: args
|
||||
.get("compact")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(true),
|
||||
depth: args
|
||||
.get("depth")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
|
||||
}),
|
||||
"click" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
|
||||
Ok(BrowserAction::Click {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"fill" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
|
||||
Ok(BrowserAction::Fill {
|
||||
selector: selector.into(),
|
||||
value: value.into(),
|
||||
})
|
||||
}
|
||||
"type" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
|
||||
let text = args
|
||||
.get("text")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
|
||||
Ok(BrowserAction::Type {
|
||||
selector: selector.into(),
|
||||
text: text.into(),
|
||||
})
|
||||
}
|
||||
"get_text" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
|
||||
Ok(BrowserAction::GetText {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"get_title" => Ok(BrowserAction::GetTitle),
|
||||
"get_url" => Ok(BrowserAction::GetUrl),
|
||||
"screenshot" => Ok(BrowserAction::Screenshot {
|
||||
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
|
||||
full_page: args
|
||||
.get("full_page")
|
||||
.and_then(serde_json::Value::as_bool)
|
||||
.unwrap_or(false),
|
||||
}),
|
||||
"wait" => Ok(BrowserAction::Wait {
|
||||
selector: args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
ms: args.get("ms").and_then(serde_json::Value::as_u64),
|
||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||
}),
|
||||
"press" => {
|
||||
let key = args
|
||||
.get("key")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
|
||||
Ok(BrowserAction::Press { key: key.into() })
|
||||
}
|
||||
"hover" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
|
||||
Ok(BrowserAction::Hover {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"scroll" => {
|
||||
let direction = args
|
||||
.get("direction")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
|
||||
Ok(BrowserAction::Scroll {
|
||||
direction: direction.into(),
|
||||
pixels: args
|
||||
.get("pixels")
|
||||
.and_then(serde_json::Value::as_u64)
|
||||
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
|
||||
})
|
||||
}
|
||||
"is_visible" => {
|
||||
let selector = args
|
||||
.get("selector")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
|
||||
Ok(BrowserAction::IsVisible {
|
||||
selector: selector.into(),
|
||||
})
|
||||
}
|
||||
"close" => Ok(BrowserAction::Close),
|
||||
"find" => {
|
||||
let by = args
|
||||
.get("by")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
|
||||
let value = args
|
||||
.get("value")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
|
||||
let action = args
|
||||
.get("find_action")
|
||||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
|
||||
Ok(BrowserAction::Find {
|
||||
by: by.into(),
|
||||
value: value.into(),
|
||||
action: action.into(),
|
||||
fill_value: args
|
||||
.get("fill_value")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(String::from),
|
||||
})
|
||||
}
|
||||
other => anyhow::bail!("Unsupported browser action: {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// ── Helper functions ─────────────────────────────────────────────
|
||||
|
||||
fn is_supported_browser_action(action: &str) -> bool {
|
||||
|
|
@ -1901,6 +1911,28 @@ fn is_supported_browser_action(action: &str) -> bool {
|
|||
)
|
||||
}
|
||||
|
||||
fn is_computer_use_only_action(action: &str) -> bool {
|
||||
matches!(
|
||||
action,
|
||||
"mouse_move" | "mouse_click" | "mouse_drag" | "key_type" | "key_press" | "screen_capture"
|
||||
)
|
||||
}
|
||||
|
||||
fn backend_name(backend: ResolvedBackend) -> &'static str {
|
||||
match backend {
|
||||
ResolvedBackend::AgentBrowser => "agent_browser",
|
||||
ResolvedBackend::RustNative => "rust_native",
|
||||
ResolvedBackend::ComputerUse => "computer_use",
|
||||
}
|
||||
}
|
||||
|
||||
fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String {
|
||||
format!(
|
||||
"Action '{action}' is unavailable for backend '{}'",
|
||||
backend_name(backend)
|
||||
)
|
||||
}
|
||||
|
||||
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
|
||||
domains
|
||||
.into_iter()
|
||||
|
|
@ -2342,4 +2374,28 @@ mod tests {
|
|||
let tool = BrowserTool::new(security, vec![], None);
|
||||
assert!(tool.validate_url("https://example.com").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn computer_use_only_action_detection_is_correct() {
|
||||
assert!(is_computer_use_only_action("mouse_move"));
|
||||
assert!(is_computer_use_only_action("mouse_click"));
|
||||
assert!(is_computer_use_only_action("mouse_drag"));
|
||||
assert!(is_computer_use_only_action("key_type"));
|
||||
assert!(is_computer_use_only_action("key_press"));
|
||||
assert!(is_computer_use_only_action("screen_capture"));
|
||||
assert!(!is_computer_use_only_action("open"));
|
||||
assert!(!is_computer_use_only_action("snapshot"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unavailable_action_error_preserves_backend_context() {
|
||||
assert_eq!(
|
||||
unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser),
|
||||
"Action 'mouse_move' is unavailable for backend 'agent_browser'"
|
||||
);
|
||||
assert_eq!(
|
||||
unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative),
|
||||
"Action 'mouse_move' is unavailable for backend 'rust_native'"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -112,12 +112,12 @@ impl ComposioTool {
|
|||
action_name: &str,
|
||||
params: serde_json::Value,
|
||||
entity_id: Option<&str>,
|
||||
connected_account_id: Option<&str>,
|
||||
connected_account_ref: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let tool_slug = normalize_tool_slug(action_name);
|
||||
|
||||
match self
|
||||
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id)
|
||||
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
|
||||
.await
|
||||
{
|
||||
Ok(result) => Ok(result),
|
||||
|
|
@ -130,21 +130,17 @@ impl ComposioTool {
|
|||
}
|
||||
}
|
||||
|
||||
async fn execute_action_v3(
|
||||
&self,
|
||||
fn build_execute_action_v3_request(
|
||||
tool_slug: &str,
|
||||
params: serde_json::Value,
|
||||
entity_id: Option<&str>,
|
||||
connected_account_id: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let url = if let Some(connected_account_id) = connected_account_id
|
||||
.map(str::trim)
|
||||
.filter(|id| !id.is_empty())
|
||||
{
|
||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}")
|
||||
} else {
|
||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
|
||||
};
|
||||
connected_account_ref: Option<&str>,
|
||||
) -> (String, serde_json::Value) {
|
||||
let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
|
||||
let account_ref = connected_account_ref.and_then(|candidate| {
|
||||
let trimmed_candidate = candidate.trim();
|
||||
(!trimmed_candidate.is_empty()).then_some(trimmed_candidate)
|
||||
});
|
||||
|
||||
let mut body = json!({
|
||||
"arguments": params,
|
||||
|
|
@ -153,6 +149,26 @@ impl ComposioTool {
|
|||
if let Some(entity) = entity_id {
|
||||
body["user_id"] = json!(entity);
|
||||
}
|
||||
if let Some(account_ref) = account_ref {
|
||||
body["connected_account_id"] = json!(account_ref);
|
||||
}
|
||||
|
||||
(url, body)
|
||||
}
|
||||
|
||||
async fn execute_action_v3(
|
||||
&self,
|
||||
tool_slug: &str,
|
||||
params: serde_json::Value,
|
||||
entity_id: Option<&str>,
|
||||
connected_account_ref: Option<&str>,
|
||||
) -> anyhow::Result<serde_json::Value> {
|
||||
let (url, body) = Self::build_execute_action_v3_request(
|
||||
tool_slug,
|
||||
params,
|
||||
entity_id,
|
||||
connected_account_ref,
|
||||
);
|
||||
|
||||
let resp = self
|
||||
.client
|
||||
|
|
@ -474,11 +490,11 @@ impl Tool for ComposioTool {
|
|||
})?;
|
||||
|
||||
let params = args.get("params").cloned().unwrap_or(json!({}));
|
||||
let connected_account_id =
|
||||
let connected_account_ref =
|
||||
args.get("connected_account_id").and_then(|v| v.as_str());
|
||||
|
||||
match self
|
||||
.execute_action(action_name, params, Some(entity_id), connected_account_id)
|
||||
.execute_action(action_name, params, Some(entity_id), connected_account_ref)
|
||||
.await
|
||||
{
|
||||
Ok(result) => {
|
||||
|
|
@ -594,9 +610,38 @@ async fn response_error(resp: reqwest::Response) -> String {
|
|||
}
|
||||
|
||||
if let Some(api_error) = extract_api_error_message(&body) {
|
||||
format!("HTTP {}: {api_error}", status.as_u16())
|
||||
return format!(
|
||||
"HTTP {}: {}",
|
||||
status.as_u16(),
|
||||
sanitize_error_message(&api_error)
|
||||
);
|
||||
}
|
||||
|
||||
format!("HTTP {}", status.as_u16())
|
||||
}
|
||||
|
||||
fn sanitize_error_message(message: &str) -> String {
|
||||
let mut sanitized = message.replace('\n', " ");
|
||||
for marker in [
|
||||
"connected_account_id",
|
||||
"connectedAccountId",
|
||||
"entity_id",
|
||||
"entityId",
|
||||
"user_id",
|
||||
"userId",
|
||||
] {
|
||||
sanitized = sanitized.replace(marker, "[redacted]");
|
||||
}
|
||||
|
||||
let max_chars = 240;
|
||||
if sanitized.chars().count() <= max_chars {
|
||||
sanitized
|
||||
} else {
|
||||
format!("HTTP {}: {body}", status.as_u16())
|
||||
let mut end = max_chars;
|
||||
while end > 0 && !sanitized.is_char_boundary(end) {
|
||||
end -= 1;
|
||||
}
|
||||
format!("{}...", &sanitized[..end])
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -948,4 +993,40 @@ mod tests {
|
|||
fn composio_api_base_url_is_v3() {
|
||||
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
|
||||
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||
"gmail-send-email",
|
||||
json!({"to": "test@example.com"}),
|
||||
Some("workspace-user"),
|
||||
Some("account-42"),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
url,
|
||||
"https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
|
||||
);
|
||||
assert_eq!(body["arguments"]["to"], json!("test@example.com"));
|
||||
assert_eq!(body["user_id"], json!("workspace-user"));
|
||||
assert_eq!(body["connected_account_id"], json!("account-42"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn build_execute_action_v3_request_drops_blank_optional_fields() {
|
||||
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||
"github-list-repos",
|
||||
json!({}),
|
||||
None,
|
||||
Some(" "),
|
||||
);
|
||||
|
||||
assert_eq!(
|
||||
url,
|
||||
"https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
|
||||
);
|
||||
assert_eq!(body["arguments"], json!({}));
|
||||
assert!(body.get("connected_account_id").is_none());
|
||||
assert!(body.get("user_id").is_none());
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
|
|||
/// summarization) to purpose-built sub-agents.
|
||||
pub struct DelegateTool {
|
||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||
/// Global API key fallback (from config.api_key)
|
||||
fallback_api_key: Option<String>,
|
||||
/// Global credential fallback (from config.api_key)
|
||||
fallback_credential: Option<String>,
|
||||
/// Depth at which this tool instance lives in the delegation chain.
|
||||
depth: u32,
|
||||
}
|
||||
|
|
@ -25,11 +25,11 @@ pub struct DelegateTool {
|
|||
impl DelegateTool {
|
||||
pub fn new(
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<String>,
|
||||
fallback_credential: Option<String>,
|
||||
) -> Self {
|
||||
Self {
|
||||
agents: Arc::new(agents),
|
||||
fallback_api_key,
|
||||
fallback_credential,
|
||||
depth: 0,
|
||||
}
|
||||
}
|
||||
|
|
@ -39,12 +39,12 @@ impl DelegateTool {
|
|||
/// their DelegateTool via this method with `depth: parent.depth + 1`.
|
||||
pub fn with_depth(
|
||||
agents: HashMap<String, DelegateAgentConfig>,
|
||||
fallback_api_key: Option<String>,
|
||||
fallback_credential: Option<String>,
|
||||
depth: u32,
|
||||
) -> Self {
|
||||
Self {
|
||||
agents: Arc::new(agents),
|
||||
fallback_api_key,
|
||||
fallback_credential,
|
||||
depth,
|
||||
}
|
||||
}
|
||||
|
|
@ -165,13 +165,15 @@ impl Tool for DelegateTool {
|
|||
}
|
||||
|
||||
// Create provider for this agent
|
||||
let api_key = agent_config
|
||||
let provider_credential_owned = agent_config
|
||||
.api_key
|
||||
.as_deref()
|
||||
.or(self.fallback_api_key.as_deref());
|
||||
.clone()
|
||||
.or_else(|| self.fallback_credential.clone());
|
||||
#[allow(clippy::option_as_ref_deref)]
|
||||
let provider_credential = provider_credential_owned.as_ref().map(String::as_str);
|
||||
|
||||
let provider: Box<dyn Provider> =
|
||||
match providers::create_provider(&agent_config.provider, api_key) {
|
||||
match providers::create_provider(&agent_config.provider, provider_credential) {
|
||||
Ok(p) => p,
|
||||
Err(e) => {
|
||||
return Ok(ToolResult {
|
||||
|
|
@ -268,7 +270,7 @@ mod tests {
|
|||
provider: "openrouter".to_string(),
|
||||
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
||||
system_prompt: None,
|
||||
api_key: Some("sk-test".to_string()),
|
||||
api_key: Some("delegate-test-credential".to_string()),
|
||||
temperature: None,
|
||||
max_depth: 2,
|
||||
},
|
||||
|
|
|
|||
|
|
@ -28,13 +28,22 @@ impl GitOperationsTool {
|
|||
if arg_lower.starts_with("--exec=")
|
||||
|| arg_lower.starts_with("--upload-pack=")
|
||||
|| arg_lower.starts_with("--receive-pack=")
|
||||
|| arg_lower.starts_with("--pager=")
|
||||
|| arg_lower.starts_with("--editor=")
|
||||
|| arg_lower == "--no-verify"
|
||||
|| arg_lower.contains("$(")
|
||||
|| arg_lower.contains('`')
|
||||
|| arg.contains('|')
|
||||
|| arg.contains(';')
|
||||
|| arg.contains('>')
|
||||
{
|
||||
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
|
||||
}
|
||||
// Block `-c` config injection (exact match or `-c=...` prefix).
|
||||
// This must not false-positive on `--cached` or `-cached`.
|
||||
if arg_lower == "-c" || arg_lower.starts_with("-c=") {
|
||||
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
|
||||
}
|
||||
result.push(arg.to_string());
|
||||
}
|
||||
Ok(result)
|
||||
|
|
@ -129,6 +138,9 @@ impl GitOperationsTool {
|
|||
.and_then(|v| v.as_bool())
|
||||
.unwrap_or(false);
|
||||
|
||||
// Validate files argument against injection patterns
|
||||
self.sanitize_git_args(files)?;
|
||||
|
||||
let mut git_args = vec!["diff", "--unified=3"];
|
||||
if cached {
|
||||
git_args.push("--cached");
|
||||
|
|
@ -267,6 +279,14 @@ impl GitOperationsTool {
|
|||
})
|
||||
}
|
||||
|
||||
fn truncate_commit_message(message: &str) -> String {
|
||||
if message.chars().count() > 2000 {
|
||||
format!("{}...", message.chars().take(1997).collect::<String>())
|
||||
} else {
|
||||
message.to_string()
|
||||
}
|
||||
}
|
||||
|
||||
async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
let message = args
|
||||
.get("message")
|
||||
|
|
@ -286,11 +306,7 @@ impl GitOperationsTool {
|
|||
}
|
||||
|
||||
// Limit message length
|
||||
let message = if sanitized.len() > 2000 {
|
||||
format!("{}...", &sanitized[..1997])
|
||||
} else {
|
||||
sanitized
|
||||
};
|
||||
let message = Self::truncate_commit_message(&sanitized);
|
||||
|
||||
let output = self.run_git_command(&["commit", "-m", &message]).await;
|
||||
|
||||
|
|
@ -314,6 +330,9 @@ impl GitOperationsTool {
|
|||
.and_then(|v| v.as_str())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?;
|
||||
|
||||
// Validate paths against injection patterns
|
||||
self.sanitize_git_args(paths)?;
|
||||
|
||||
let output = self.run_git_command(&["add", "--", paths]).await;
|
||||
|
||||
match output {
|
||||
|
|
@ -574,6 +593,52 @@ mod tests {
|
|||
assert!(tool.sanitize_git_args("arg; rm file").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_pager_editor_injection() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
assert!(tool.sanitize_git_args("--pager=less").is_err());
|
||||
assert!(tool.sanitize_git_args("--editor=vim").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_config_injection() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
// Exact `-c` flag (config injection)
|
||||
assert!(tool.sanitize_git_args("-c core.sshCommand=evil").is_err());
|
||||
assert!(tool.sanitize_git_args("-c=core.pager=less").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_no_verify() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
assert!(tool.sanitize_git_args("--no-verify").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_blocks_redirect_in_args() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
assert!(tool.sanitize_git_args("file.txt > /tmp/out").is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_cached_not_blocked() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = test_tool(tmp.path());
|
||||
|
||||
// --cached must NOT be blocked by the `-c` check
|
||||
assert!(tool.sanitize_git_args("--cached").is_ok());
|
||||
// Other safe flags starting with -c prefix
|
||||
assert!(tool.sanitize_git_args("-cached").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn sanitize_git_allows_safe() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
|
|
@ -583,6 +648,8 @@ mod tests {
|
|||
assert!(tool.sanitize_git_args("main").is_ok());
|
||||
assert!(tool.sanitize_git_args("feature/test-branch").is_ok());
|
||||
assert!(tool.sanitize_git_args("--cached").is_ok());
|
||||
assert!(tool.sanitize_git_args("src/main.rs").is_ok());
|
||||
assert!(tool.sanitize_git_args(".").is_ok());
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -691,4 +758,12 @@ mod tests {
|
|||
.unwrap_or("")
|
||||
.contains("Unknown operation"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn truncates_multibyte_commit_message_without_panicking() {
|
||||
let long = "🦀".repeat(2500);
|
||||
let truncated = GitOperationsTool::truncate_commit_message(&long);
|
||||
|
||||
assert_eq!(truncated.chars().count(), 2000);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -124,10 +124,11 @@ impl Tool for HardwareBoardInfoTool {
|
|||
});
|
||||
}
|
||||
Err(e) => {
|
||||
output.push_str(&format!(
|
||||
"probe-rs attach failed: {}. Using static info.\n\n",
|
||||
e
|
||||
));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(
|
||||
output,
|
||||
"probe-rs attach failed: {e}. Using static info.\n\n"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -135,13 +136,15 @@ impl Tool for HardwareBoardInfoTool {
|
|||
if let Some(info) = self.static_info_for_board(board) {
|
||||
output.push_str(&info);
|
||||
if let Some(mem) = memory_map_static(board) {
|
||||
output.push_str(&format!("\n\n**Memory map:**\n{}", mem));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(output, "\n\n**Memory map:**\n{mem}");
|
||||
}
|
||||
} else {
|
||||
output.push_str(&format!(
|
||||
"Board '{}' configured. No static info available.",
|
||||
board
|
||||
));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(
|
||||
output,
|
||||
"Board '{board}' configured. No static info available."
|
||||
);
|
||||
}
|
||||
|
||||
Ok(ToolResult {
|
||||
|
|
|
|||
|
|
@ -122,14 +122,16 @@ impl Tool for HardwareMemoryMapTool {
|
|||
|
||||
if !probe_ok {
|
||||
if let Some(map) = self.static_map_for_board(board) {
|
||||
output.push_str(&format!("**{}** (from datasheet):\n{}", board, map));
|
||||
use std::fmt::Write;
|
||||
let _ = write!(output, "**{board}** (from datasheet):\n{map}");
|
||||
} else {
|
||||
use std::fmt::Write;
|
||||
let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect();
|
||||
output.push_str(&format!(
|
||||
"No memory map for board '{}'. Known boards: {}",
|
||||
board,
|
||||
let _ = write!(
|
||||
output,
|
||||
"No memory map for board '{board}'. Known boards: {}",
|
||||
known.join(", ")
|
||||
));
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -94,14 +94,16 @@ impl Tool for HardwareMemoryReadTool {
|
|||
.get("address")
|
||||
.and_then(|v| v.as_str())
|
||||
.unwrap_or("0x20000000");
|
||||
let address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
|
||||
let _address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
|
||||
|
||||
let length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
|
||||
let length = length.min(256).max(1);
|
||||
let requested_length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128);
|
||||
let _length = usize::try_from(requested_length)
|
||||
.unwrap_or(256)
|
||||
.clamp(1, 256);
|
||||
|
||||
#[cfg(feature = "probe")]
|
||||
{
|
||||
match probe_read_memory(chip.unwrap(), address, length) {
|
||||
match probe_read_memory(chip.unwrap(), _address, _length) {
|
||||
Ok(output) => {
|
||||
return Ok(ToolResult {
|
||||
success: true,
|
||||
|
|
|
|||
|
|
@ -749,4 +749,54 @@ mod tests {
|
|||
let _ = HttpRequestTool::redact_headers_for_display(&headers);
|
||||
assert_eq!(headers[0].1, "Bearer real-token");
|
||||
}
|
||||
|
||||
// ── SSRF: alternate IP notation bypass defense-in-depth ─────────
|
||||
//
|
||||
// Rust's IpAddr::parse() rejects non-standard notations (octal, hex,
|
||||
// decimal integer, zero-padded). These tests document that property
|
||||
// so regressions are caught if the parsing strategy ever changes.
|
||||
|
||||
#[test]
|
||||
fn ssrf_octal_loopback_not_parsed_as_ip() {
|
||||
// 0177.0.0.1 is octal for 127.0.0.1 in some languages, but
|
||||
// Rust's IpAddr rejects it — it falls through as a hostname.
|
||||
assert!(!is_private_or_local_host("0177.0.0.1"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_hex_loopback_not_parsed_as_ip() {
|
||||
// 0x7f000001 is hex for 127.0.0.1 in some languages.
|
||||
assert!(!is_private_or_local_host("0x7f000001"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_decimal_loopback_not_parsed_as_ip() {
|
||||
// 2130706433 is decimal for 127.0.0.1 in some languages.
|
||||
assert!(!is_private_or_local_host("2130706433"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_zero_padded_loopback_not_parsed_as_ip() {
|
||||
// 127.000.000.001 uses zero-padded octets.
|
||||
assert!(!is_private_or_local_host("127.000.000.001"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn ssrf_alternate_notations_rejected_by_validate_url() {
|
||||
// Even if is_private_or_local_host doesn't flag these, they
|
||||
// fail the allowlist because they're treated as hostnames.
|
||||
let tool = test_tool(vec!["example.com"]);
|
||||
for notation in [
|
||||
"http://0177.0.0.1",
|
||||
"http://0x7f000001",
|
||||
"http://2130706433",
|
||||
"http://127.000.000.001",
|
||||
] {
|
||||
let err = tool.validate_url(notation).unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("allowed_domains"),
|
||||
"Expected allowlist rejection for {notation}, got: {err}"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -87,7 +87,7 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn forget_existing() {
|
||||
let (_tmp, mem) = test_mem();
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation)
|
||||
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool {
|
|||
.and_then(serde_json::Value::as_u64)
|
||||
.map_or(5, |v| v as usize);
|
||||
|
||||
match self.memory.recall(query, limit).await {
|
||||
match self.memory.recall(query, limit, None).await {
|
||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||
success: true,
|
||||
output: "No memories found matching that query.".into(),
|
||||
|
|
@ -112,10 +112,10 @@ mod tests {
|
|||
#[tokio::test]
|
||||
async fn recall_finds_match() {
|
||||
let (_tmp, mem) = seeded_mem();
|
||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core)
|
||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
mem.store("tz", "Timezone is EST", MemoryCategory::Core)
|
||||
mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -134,6 +134,7 @@ mod tests {
|
|||
&format!("k{i}"),
|
||||
&format!("Rust fact {i}"),
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
|
|||
|
|
@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool {
|
|||
_ => MemoryCategory::Core,
|
||||
};
|
||||
|
||||
match self.memory.store(key, content, category).await {
|
||||
match self.memory.store(key, content, category, None).await {
|
||||
Ok(()) => Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!("Stored memory: {key}"),
|
||||
|
|
|
|||
|
|
@ -19,7 +19,9 @@ pub mod image_info;
|
|||
pub mod memory_forget;
|
||||
pub mod memory_recall;
|
||||
pub mod memory_store;
|
||||
pub mod pushover;
|
||||
pub mod schedule;
|
||||
pub mod schema;
|
||||
pub mod screenshot;
|
||||
pub mod shell;
|
||||
pub mod traits;
|
||||
|
|
@ -45,7 +47,9 @@ pub use image_info::ImageInfoTool;
|
|||
pub use memory_forget::MemoryForgetTool;
|
||||
pub use memory_recall::MemoryRecallTool;
|
||||
pub use memory_store::MemoryStoreTool;
|
||||
pub use pushover::PushoverTool;
|
||||
pub use schedule::ScheduleTool;
|
||||
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||
pub use screenshot::ScreenshotTool;
|
||||
pub use shell::ShellTool;
|
||||
pub use traits::Tool;
|
||||
|
|
@ -141,6 +145,10 @@ pub fn all_tools_with_runtime(
|
|||
security.clone(),
|
||||
workspace_dir.to_path_buf(),
|
||||
)),
|
||||
Box::new(PushoverTool::new(
|
||||
security.clone(),
|
||||
workspace_dir.to_path_buf(),
|
||||
)),
|
||||
];
|
||||
|
||||
if browser_config.enabled {
|
||||
|
|
@ -195,9 +203,13 @@ pub fn all_tools_with_runtime(
|
|||
.iter()
|
||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||
.collect();
|
||||
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||
let trimmed_value = value.trim();
|
||||
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||
});
|
||||
tools.push(Box::new(DelegateTool::new(
|
||||
delegate_agents,
|
||||
fallback_api_key.map(String::from),
|
||||
delegate_fallback_credential,
|
||||
)));
|
||||
}
|
||||
|
||||
|
|
@ -261,6 +273,7 @@ mod tests {
|
|||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(!names.contains(&"browser_open"));
|
||||
assert!(names.contains(&"schedule"));
|
||||
assert!(names.contains(&"pushover"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -298,6 +311,7 @@ mod tests {
|
|||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
assert!(names.contains(&"browser_open"));
|
||||
assert!(names.contains(&"pushover"));
|
||||
}
|
||||
|
||||
#[test]
|
||||
|
|
@ -432,7 +446,7 @@ mod tests {
|
|||
&http,
|
||||
tmp.path(),
|
||||
&agents,
|
||||
Some("sk-test"),
|
||||
Some("delegate-test-credential"),
|
||||
&cfg,
|
||||
);
|
||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||
|
|
|
|||
442
src/tools/pushover.rs
Normal file
442
src/tools/pushover.rs
Normal file
|
|
@ -0,0 +1,442 @@
|
|||
use super::traits::{Tool, ToolResult};
|
||||
use crate::security::SecurityPolicy;
|
||||
use async_trait::async_trait;
|
||||
use reqwest::Client;
|
||||
use serde_json::json;
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json";
|
||||
const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15;
|
||||
|
||||
pub struct PushoverTool {
|
||||
client: Client,
|
||||
security: Arc<SecurityPolicy>,
|
||||
workspace_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl PushoverTool {
|
||||
pub fn new(security: Arc<SecurityPolicy>, workspace_dir: PathBuf) -> Self {
|
||||
let client = Client::builder()
|
||||
.timeout(Duration::from_secs(PUSHOVER_REQUEST_TIMEOUT_SECS))
|
||||
.build()
|
||||
.unwrap_or_else(|_| Client::new());
|
||||
|
||||
Self {
|
||||
client,
|
||||
security,
|
||||
workspace_dir,
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_env_value(raw: &str) -> String {
|
||||
let raw = raw.trim();
|
||||
|
||||
let unquoted = if raw.len() >= 2
|
||||
&& ((raw.starts_with('"') && raw.ends_with('"'))
|
||||
|| (raw.starts_with('\'') && raw.ends_with('\'')))
|
||||
{
|
||||
&raw[1..raw.len() - 1]
|
||||
} else {
|
||||
raw
|
||||
};
|
||||
|
||||
// Keep support for inline comments in unquoted values:
|
||||
// KEY=value # comment
|
||||
unquoted.split_once(" #").map_or_else(
|
||||
|| unquoted.trim().to_string(),
|
||||
|(value, _)| value.trim().to_string(),
|
||||
)
|
||||
}
|
||||
|
||||
fn get_credentials(&self) -> anyhow::Result<(String, String)> {
|
||||
let env_path = self.workspace_dir.join(".env");
|
||||
let content = std::fs::read_to_string(&env_path)
|
||||
.map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?;
|
||||
|
||||
let mut token = None;
|
||||
let mut user_key = None;
|
||||
|
||||
for line in content.lines() {
|
||||
let line = line.trim();
|
||||
if line.starts_with('#') || line.is_empty() {
|
||||
continue;
|
||||
}
|
||||
let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line);
|
||||
if let Some((key, value)) = line.split_once('=') {
|
||||
let key = key.trim();
|
||||
let value = Self::parse_env_value(value);
|
||||
|
||||
if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") {
|
||||
token = Some(value);
|
||||
} else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") {
|
||||
user_key = Some(value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?;
|
||||
let user_key =
|
||||
user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?;
|
||||
|
||||
Ok((token, user_key))
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl Tool for PushoverTool {
|
||||
fn name(&self) -> &str {
|
||||
"pushover"
|
||||
}
|
||||
|
||||
fn description(&self) -> &str {
|
||||
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file."
|
||||
}
|
||||
|
||||
fn parameters_schema(&self) -> serde_json::Value {
|
||||
json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"message": {
|
||||
"type": "string",
|
||||
"description": "The notification message to send"
|
||||
},
|
||||
"title": {
|
||||
"type": "string",
|
||||
"description": "Optional notification title"
|
||||
},
|
||||
"priority": {
|
||||
"type": "integer",
|
||||
"enum": [-2, -1, 0, 1, 2],
|
||||
"description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)"
|
||||
},
|
||||
"sound": {
|
||||
"type": "string",
|
||||
"description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)"
|
||||
}
|
||||
},
|
||||
"required": ["message"]
|
||||
})
|
||||
}
|
||||
|
||||
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||
if !self.security.can_act() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Action blocked: autonomy is read-only".into()),
|
||||
});
|
||||
}
|
||||
|
||||
if !self.security.record_action() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some("Action blocked: rate limit exceeded".into()),
|
||||
});
|
||||
}
|
||||
|
||||
let message = args
|
||||
.get("message")
|
||||
.and_then(|v| v.as_str())
|
||||
.map(str::trim)
|
||||
.filter(|v| !v.is_empty())
|
||||
.ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?
|
||||
.to_string();
|
||||
|
||||
let title = args.get("title").and_then(|v| v.as_str()).map(String::from);
|
||||
|
||||
let priority = match args.get("priority").and_then(|v| v.as_i64()) {
|
||||
Some(value) if (-2..=2).contains(&value) => Some(value),
|
||||
Some(value) => {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: String::new(),
|
||||
error: Some(format!(
|
||||
"Invalid 'priority': {value}. Expected integer in range -2..=2"
|
||||
)),
|
||||
})
|
||||
}
|
||||
None => None,
|
||||
};
|
||||
|
||||
let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from);
|
||||
|
||||
let (token, user_key) = self.get_credentials()?;
|
||||
|
||||
let mut form = reqwest::multipart::Form::new()
|
||||
.text("token", token)
|
||||
.text("user", user_key)
|
||||
.text("message", message);
|
||||
|
||||
if let Some(title) = title {
|
||||
form = form.text("title", title);
|
||||
}
|
||||
|
||||
if let Some(priority) = priority {
|
||||
form = form.text("priority", priority.to_string());
|
||||
}
|
||||
|
||||
if let Some(sound) = sound {
|
||||
form = form.text("sound", sound);
|
||||
}
|
||||
|
||||
let response = self
|
||||
.client
|
||||
.post(PUSHOVER_API_URL)
|
||||
.multipart(form)
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
let status = response.status();
|
||||
let body = response.text().await.unwrap_or_default();
|
||||
|
||||
if !status.is_success() {
|
||||
return Ok(ToolResult {
|
||||
success: false,
|
||||
output: body,
|
||||
error: Some(format!("Pushover API returned status {}", status)),
|
||||
});
|
||||
}
|
||||
|
||||
let api_status = serde_json::from_str::<serde_json::Value>(&body)
|
||||
.ok()
|
||||
.and_then(|json| json.get("status").and_then(|value| value.as_i64()));
|
||||
|
||||
if api_status == Some(1) {
|
||||
Ok(ToolResult {
|
||||
success: true,
|
||||
output: format!(
|
||||
"Pushover notification sent successfully. Response: {}",
|
||||
body
|
||||
),
|
||||
error: None,
|
||||
})
|
||||
} else {
|
||||
Ok(ToolResult {
|
||||
success: false,
|
||||
output: body,
|
||||
error: Some("Pushover API returned an application-level error".into()),
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::security::AutonomyLevel;
|
||||
use std::fs;
|
||||
use tempfile::TempDir;
|
||||
|
||||
fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc<SecurityPolicy> {
|
||||
Arc::new(SecurityPolicy {
|
||||
autonomy: level,
|
||||
max_actions_per_hour,
|
||||
workspace_dir: std::env::temp_dir(),
|
||||
..SecurityPolicy::default()
|
||||
})
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pushover_tool_name() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
assert_eq!(tool.name(), "pushover");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pushover_tool_description() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
assert!(!tool.description().is_empty());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pushover_tool_has_parameters_schema() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
assert_eq!(schema["type"], "object");
|
||||
assert!(schema["properties"].get("message").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pushover_tool_requires_message() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
let required = schema["required"].as_array().unwrap();
|
||||
assert!(required.contains(&serde_json::Value::String("message".to_string())));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_parsed_from_env_file() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let env_path = tmp.path().join(".env");
|
||||
fs::write(
|
||||
&env_path,
|
||||
"PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let result = tool.get_credentials();
|
||||
|
||||
assert!(result.is_ok());
|
||||
let (token, user_key) = result.unwrap();
|
||||
assert_eq!(token, "testtoken123");
|
||||
assert_eq!(user_key, "userkey456");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_fail_without_env_file() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let result = tool.get_credentials();
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_fail_without_token() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let env_path = tmp.path().join(".env");
|
||||
fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap();
|
||||
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let result = tool.get_credentials();
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_fail_without_user_key() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let env_path = tmp.path().join(".env");
|
||||
fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap();
|
||||
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let result = tool.get_credentials();
|
||||
|
||||
assert!(result.is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_ignore_comments() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let env_path = tmp.path().join(".env");
|
||||
fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap();
|
||||
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let result = tool.get_credentials();
|
||||
|
||||
assert!(result.is_ok());
|
||||
let (token, user_key) = result.unwrap();
|
||||
assert_eq!(token, "realtoken");
|
||||
assert_eq!(user_key, "realuser");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pushover_tool_supports_priority() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"].get("priority").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn pushover_tool_supports_sound() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
let schema = tool.parameters_schema();
|
||||
assert!(schema["properties"].get("sound").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn credentials_support_export_and_quoted_values() {
|
||||
let tmp = TempDir::new().unwrap();
|
||||
let env_path = tmp.path().join(".env");
|
||||
fs::write(
|
||||
&env_path,
|
||||
"export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n",
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
tmp.path().to_path_buf(),
|
||||
);
|
||||
let result = tool.get_credentials();
|
||||
|
||||
assert!(result.is_ok());
|
||||
let (token, user_key) = result.unwrap();
|
||||
assert_eq!(token, "quotedtoken");
|
||||
assert_eq!(user_key, "quoteduser");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_blocks_readonly_mode() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::ReadOnly, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
|
||||
let result = tool.execute(json!({"message": "hello"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("read-only"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_blocks_rate_limit() {
|
||||
let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp"));
|
||||
|
||||
let result = tool.execute(json!({"message": "hello"})).await.unwrap();
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("rate limit"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_rejects_priority_out_of_range() {
|
||||
let tool = PushoverTool::new(
|
||||
test_security(AutonomyLevel::Full, 100),
|
||||
PathBuf::from("/tmp"),
|
||||
);
|
||||
|
||||
let result = tool
|
||||
.execute(json!({"message": "hello", "priority": 5}))
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(!result.success);
|
||||
assert!(result.error.unwrap().contains("-2..=2"));
|
||||
}
|
||||
}
|
||||
838
src/tools/schema.rs
Normal file
838
src/tools/schema.rs
Normal file
|
|
@ -0,0 +1,838 @@
|
|||
//! JSON Schema cleaning and validation for LLM tool-calling compatibility.
|
||||
//!
|
||||
//! Different providers support different subsets of JSON Schema. This module
|
||||
//! normalizes tool schemas to improve cross-provider compatibility while
|
||||
//! preserving semantic intent.
|
||||
//!
|
||||
//! ## What this module does
|
||||
//!
|
||||
//! 1. Removes unsupported keywords per provider strategy
|
||||
//! 2. Resolves local `$ref` entries from `$defs` and `definitions`
|
||||
//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum`
|
||||
//! 4. Strips nullable variants from unions and `type` arrays
|
||||
//! 5. Converts `const` to single-value `enum`
|
||||
//! 6. Detects circular references and stops recursion safely
|
||||
//!
|
||||
//! # Example
|
||||
//!
|
||||
//! ```rust
|
||||
//! use serde_json::json;
|
||||
//! use zeroclaw::tools::schema::SchemaCleanr;
|
||||
//!
|
||||
//! let dirty_schema = json!({
|
||||
//! "type": "object",
|
||||
//! "properties": {
|
||||
//! "name": {
|
||||
//! "type": "string",
|
||||
//! "minLength": 1, // Gemini rejects this
|
||||
//! "pattern": "^[a-z]+$" // Gemini rejects this
|
||||
//! },
|
||||
//! "age": {
|
||||
//! "$ref": "#/$defs/Age" // Needs resolution
|
||||
//! }
|
||||
//! },
|
||||
//! "$defs": {
|
||||
//! "Age": {
|
||||
//! "type": "integer",
|
||||
//! "minimum": 0 // Gemini rejects this
|
||||
//! }
|
||||
//! }
|
||||
//! });
|
||||
//!
|
||||
//! let cleaned = SchemaCleanr::clean_for_gemini(dirty_schema);
|
||||
//!
|
||||
//! // Result:
|
||||
//! // {
|
||||
//! // "type": "object",
|
||||
//! // "properties": {
|
||||
//! // "name": { "type": "string" },
|
||||
//! // "age": { "type": "integer" }
|
||||
//! // }
|
||||
//! // }
|
||||
//! ```
|
||||
//!
|
||||
use serde_json::{json, Map, Value};
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
/// Keywords that Gemini rejects for tool schemas.
|
||||
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
|
||||
// Schema composition
|
||||
"$ref",
|
||||
"$schema",
|
||||
"$id",
|
||||
"$defs",
|
||||
"definitions",
|
||||
// Property constraints
|
||||
"additionalProperties",
|
||||
"patternProperties",
|
||||
// String constraints
|
||||
"minLength",
|
||||
"maxLength",
|
||||
"pattern",
|
||||
"format",
|
||||
// Number constraints
|
||||
"minimum",
|
||||
"maximum",
|
||||
"multipleOf",
|
||||
// Array constraints
|
||||
"minItems",
|
||||
"maxItems",
|
||||
"uniqueItems",
|
||||
// Object constraints
|
||||
"minProperties",
|
||||
"maxProperties",
|
||||
// Non-standard
|
||||
"examples", // OpenAPI keyword, not JSON Schema
|
||||
];
|
||||
|
||||
/// Keywords that should be preserved during cleaning (metadata).
|
||||
const SCHEMA_META_KEYS: &[&str] = &["description", "title", "default"];
|
||||
|
||||
/// Schema cleaning strategies for different LLM providers.
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum CleaningStrategy {
|
||||
/// Gemini (Google AI / Vertex AI) - Most restrictive
|
||||
Gemini,
|
||||
/// Anthropic Claude - Moderately permissive
|
||||
Anthropic,
|
||||
/// OpenAI GPT - Most permissive
|
||||
OpenAI,
|
||||
/// Conservative: Remove only universally unsupported keywords
|
||||
Conservative,
|
||||
}
|
||||
|
||||
impl CleaningStrategy {
|
||||
/// Get the list of unsupported keywords for this strategy.
|
||||
pub fn unsupported_keywords(&self) -> &'static [&'static str] {
|
||||
match self {
|
||||
Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
|
||||
Self::Anthropic => &["$ref", "$defs", "definitions"], // Anthropic doesn't resolve refs
|
||||
Self::OpenAI => &[], // OpenAI is most permissive
|
||||
Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"],
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// JSON Schema cleaner optimized for LLM tool calling.
|
||||
pub struct SchemaCleanr;
|
||||
|
||||
impl SchemaCleanr {
|
||||
/// Clean schema for Gemini compatibility (strictest).
|
||||
///
|
||||
/// This is the most aggressive cleaning strategy, removing all keywords
|
||||
/// that Gemini's API rejects.
|
||||
pub fn clean_for_gemini(schema: Value) -> Value {
|
||||
Self::clean(schema, CleaningStrategy::Gemini)
|
||||
}
|
||||
|
||||
/// Clean schema for Anthropic compatibility.
|
||||
pub fn clean_for_anthropic(schema: Value) -> Value {
|
||||
Self::clean(schema, CleaningStrategy::Anthropic)
|
||||
}
|
||||
|
||||
/// Clean schema for OpenAI compatibility (most permissive).
|
||||
pub fn clean_for_openai(schema: Value) -> Value {
|
||||
Self::clean(schema, CleaningStrategy::OpenAI)
|
||||
}
|
||||
|
||||
/// Clean schema with specified strategy.
|
||||
pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
|
||||
// Extract $defs for reference resolution
|
||||
let defs = if let Some(obj) = schema.as_object() {
|
||||
Self::extract_defs(obj)
|
||||
} else {
|
||||
HashMap::new()
|
||||
};
|
||||
|
||||
Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
|
||||
}
|
||||
|
||||
/// Validate that a schema is suitable for LLM tool calling.
|
||||
///
|
||||
/// Returns an error if the schema is invalid or missing required fields.
|
||||
pub fn validate(schema: &Value) -> anyhow::Result<()> {
|
||||
let obj = schema
|
||||
.as_object()
|
||||
.ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;
|
||||
|
||||
// Must have 'type' field
|
||||
if !obj.contains_key("type") {
|
||||
anyhow::bail!("Schema missing required 'type' field");
|
||||
}
|
||||
|
||||
// If type is 'object', should have 'properties'
|
||||
if let Some(Value::String(t)) = obj.get("type") {
|
||||
if t == "object" && !obj.contains_key("properties") {
|
||||
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// --------------------------------------------------------------------
|
||||
// Internal implementation
|
||||
// --------------------------------------------------------------------
|
||||
|
||||
/// Extract $defs and definitions into a flat map for reference resolution.
|
||||
fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
|
||||
let mut defs = HashMap::new();
|
||||
|
||||
// Extract from $defs (JSON Schema 2019-09+)
|
||||
if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
|
||||
for (key, value) in defs_obj {
|
||||
defs.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
// Extract from definitions (JSON Schema draft-07)
|
||||
if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
|
||||
for (key, value) in defs_obj {
|
||||
defs.insert(key.clone(), value.clone());
|
||||
}
|
||||
}
|
||||
|
||||
defs
|
||||
}
|
||||
|
||||
/// Recursively clean a schema value.
|
||||
fn clean_with_defs(
|
||||
schema: Value,
|
||||
defs: &HashMap<String, Value>,
|
||||
strategy: CleaningStrategy,
|
||||
ref_stack: &mut HashSet<String>,
|
||||
) -> Value {
|
||||
match schema {
|
||||
Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
|
||||
Value::Array(arr) => Value::Array(
|
||||
arr.into_iter()
|
||||
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
|
||||
.collect(),
|
||||
),
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean an object schema.
|
||||
fn clean_object(
|
||||
obj: Map<String, Value>,
|
||||
defs: &HashMap<String, Value>,
|
||||
strategy: CleaningStrategy,
|
||||
ref_stack: &mut HashSet<String>,
|
||||
) -> Value {
|
||||
// Handle $ref resolution
|
||||
if let Some(Value::String(ref_value)) = obj.get("$ref") {
|
||||
return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
|
||||
}
|
||||
|
||||
// Handle anyOf/oneOf simplification
|
||||
if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
|
||||
if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
|
||||
return simplified;
|
||||
}
|
||||
}
|
||||
|
||||
// Build cleaned object
|
||||
let mut cleaned = Map::new();
|
||||
let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
|
||||
let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");
|
||||
|
||||
for (key, value) in obj {
|
||||
// Skip unsupported keywords
|
||||
if unsupported.contains(key.as_str()) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Special handling for specific keys
|
||||
match key.as_str() {
|
||||
// Convert const to enum
|
||||
"const" => {
|
||||
cleaned.insert("enum".to_string(), json!([value]));
|
||||
}
|
||||
// Skip type if we have anyOf/oneOf (they define the type)
|
||||
"type" if has_union => {
|
||||
// Skip
|
||||
}
|
||||
// Handle type arrays (remove null)
|
||||
"type" if matches!(value, Value::Array(_)) => {
|
||||
let cleaned_value = Self::clean_type_array(value);
|
||||
cleaned.insert(key, cleaned_value);
|
||||
}
|
||||
// Recursively clean nested schemas
|
||||
"properties" => {
|
||||
let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
|
||||
cleaned.insert(key, cleaned_value);
|
||||
}
|
||||
"items" => {
|
||||
let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
|
||||
cleaned.insert(key, cleaned_value);
|
||||
}
|
||||
"anyOf" | "oneOf" | "allOf" => {
|
||||
let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
|
||||
cleaned.insert(key, cleaned_value);
|
||||
}
|
||||
// Keep all other keys, cleaning nested objects/arrays recursively.
|
||||
_ => {
|
||||
let cleaned_value = match value {
|
||||
Value::Object(_) | Value::Array(_) => {
|
||||
Self::clean_with_defs(value, defs, strategy, ref_stack)
|
||||
}
|
||||
other => other,
|
||||
};
|
||||
cleaned.insert(key, cleaned_value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Value::Object(cleaned)
|
||||
}
|
||||
|
||||
/// Resolve a $ref to its definition.
|
||||
fn resolve_ref(
|
||||
ref_value: &str,
|
||||
obj: &Map<String, Value>,
|
||||
defs: &HashMap<String, Value>,
|
||||
strategy: CleaningStrategy,
|
||||
ref_stack: &mut HashSet<String>,
|
||||
) -> Value {
|
||||
// Prevent circular references
|
||||
if ref_stack.contains(ref_value) {
|
||||
tracing::warn!("Circular $ref detected: {}", ref_value);
|
||||
return Self::preserve_meta(obj, Value::Object(Map::new()));
|
||||
}
|
||||
|
||||
// Try to resolve local ref (#/$defs/Name or #/definitions/Name)
|
||||
if let Some(def_name) = Self::parse_local_ref(ref_value) {
|
||||
if let Some(definition) = defs.get(def_name.as_str()) {
|
||||
ref_stack.insert(ref_value.to_string());
|
||||
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||
ref_stack.remove(ref_value);
|
||||
return Self::preserve_meta(obj, cleaned);
|
||||
}
|
||||
}
|
||||
|
||||
// Can't resolve: return empty object with metadata
|
||||
tracing::warn!("Cannot resolve $ref: {}", ref_value);
|
||||
Self::preserve_meta(obj, Value::Object(Map::new()))
|
||||
}
|
||||
|
||||
/// Parse a local JSON Pointer ref (#/$defs/Name).
|
||||
fn parse_local_ref(ref_value: &str) -> Option<String> {
|
||||
ref_value
|
||||
.strip_prefix("#/$defs/")
|
||||
.or_else(|| ref_value.strip_prefix("#/definitions/"))
|
||||
.map(Self::decode_json_pointer)
|
||||
}
|
||||
|
||||
/// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`).
|
||||
fn decode_json_pointer(segment: &str) -> String {
|
||||
if !segment.contains('~') {
|
||||
return segment.to_string();
|
||||
}
|
||||
|
||||
let mut decoded = String::with_capacity(segment.len());
|
||||
let mut chars = segment.chars().peekable();
|
||||
|
||||
while let Some(ch) = chars.next() {
|
||||
if ch == '~' {
|
||||
match chars.peek().copied() {
|
||||
Some('0') => {
|
||||
chars.next();
|
||||
decoded.push('~');
|
||||
}
|
||||
Some('1') => {
|
||||
chars.next();
|
||||
decoded.push('/');
|
||||
}
|
||||
_ => decoded.push('~'),
|
||||
}
|
||||
} else {
|
||||
decoded.push(ch);
|
||||
}
|
||||
}
|
||||
|
||||
decoded
|
||||
}
|
||||
|
||||
/// Try to simplify anyOf/oneOf to a simpler form.
|
||||
fn try_simplify_union(
|
||||
obj: &Map<String, Value>,
|
||||
defs: &HashMap<String, Value>,
|
||||
strategy: CleaningStrategy,
|
||||
ref_stack: &mut HashSet<String>,
|
||||
) -> Option<Value> {
|
||||
let union_key = if obj.contains_key("anyOf") {
|
||||
"anyOf"
|
||||
} else if obj.contains_key("oneOf") {
|
||||
"oneOf"
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
let variants = obj.get(union_key)?.as_array()?;
|
||||
|
||||
// Clean all variants first
|
||||
let cleaned_variants: Vec<Value> = variants
|
||||
.iter()
|
||||
.map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
|
||||
.collect();
|
||||
|
||||
// Strip null variants
|
||||
let non_null: Vec<Value> = cleaned_variants
|
||||
.into_iter()
|
||||
.filter(|v| !Self::is_null_schema(v))
|
||||
.collect();
|
||||
|
||||
// If only one variant remains after stripping nulls, return it
|
||||
if non_null.len() == 1 {
|
||||
return Some(Self::preserve_meta(obj, non_null[0].clone()));
|
||||
}
|
||||
|
||||
// Try to flatten to enum if all variants are literals
|
||||
if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
|
||||
return Some(Self::preserve_meta(obj, enum_value));
|
||||
}
|
||||
|
||||
None
|
||||
}
|
||||
|
||||
/// Check if a schema represents null type.
|
||||
fn is_null_schema(value: &Value) -> bool {
|
||||
if let Some(obj) = value.as_object() {
|
||||
// { const: null }
|
||||
if let Some(Value::Null) = obj.get("const") {
|
||||
return true;
|
||||
}
|
||||
// { enum: [null] }
|
||||
if let Some(Value::Array(arr)) = obj.get("enum") {
|
||||
if arr.len() == 1 && matches!(arr[0], Value::Null) {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
// { type: "null" }
|
||||
if let Some(Value::String(t)) = obj.get("type") {
|
||||
if t == "null" {
|
||||
return true;
|
||||
}
|
||||
}
|
||||
}
|
||||
false
|
||||
}
|
||||
|
||||
/// Try to flatten anyOf/oneOf with only literal values to enum.
|
||||
///
|
||||
/// Example: `anyOf: [{const: "a"}, {const: "b"}]` -> `{type: "string", enum: ["a", "b"]}`
|
||||
fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
|
||||
if variants.is_empty() {
|
||||
return None;
|
||||
}
|
||||
|
||||
let mut all_values = Vec::new();
|
||||
let mut common_type: Option<String> = None;
|
||||
|
||||
for variant in variants {
|
||||
let obj = variant.as_object()?;
|
||||
|
||||
// Extract literal value from const or single-item enum
|
||||
let literal_value = if let Some(const_val) = obj.get("const") {
|
||||
const_val.clone()
|
||||
} else if let Some(Value::Array(arr)) = obj.get("enum") {
|
||||
if arr.len() == 1 {
|
||||
arr[0].clone()
|
||||
} else {
|
||||
return None;
|
||||
}
|
||||
} else {
|
||||
return None;
|
||||
};
|
||||
|
||||
// Check type consistency
|
||||
let variant_type = obj.get("type")?.as_str()?;
|
||||
match &common_type {
|
||||
None => common_type = Some(variant_type.to_string()),
|
||||
Some(t) if t != variant_type => return None,
|
||||
_ => {}
|
||||
}
|
||||
|
||||
all_values.push(literal_value);
|
||||
}
|
||||
|
||||
common_type.map(|t| {
|
||||
json!({
|
||||
"type": t,
|
||||
"enum": all_values
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
/// Clean type array, removing null.
|
||||
fn clean_type_array(value: Value) -> Value {
|
||||
if let Value::Array(types) = value {
|
||||
let non_null: Vec<Value> = types
|
||||
.into_iter()
|
||||
.filter(|v| v.as_str() != Some("null"))
|
||||
.collect();
|
||||
|
||||
match non_null.len() {
|
||||
0 => Value::String("null".to_string()),
|
||||
1 => non_null
|
||||
.into_iter()
|
||||
.next()
|
||||
.unwrap_or(Value::String("null".to_string())),
|
||||
_ => Value::Array(non_null),
|
||||
}
|
||||
} else {
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean properties object.
|
||||
fn clean_properties(
|
||||
value: Value,
|
||||
defs: &HashMap<String, Value>,
|
||||
strategy: CleaningStrategy,
|
||||
ref_stack: &mut HashSet<String>,
|
||||
) -> Value {
|
||||
if let Value::Object(props) = value {
|
||||
let cleaned: Map<String, Value> = props
|
||||
.into_iter()
|
||||
.map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
|
||||
.collect();
|
||||
Value::Object(cleaned)
|
||||
} else {
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
/// Clean union (anyOf/oneOf/allOf).
|
||||
fn clean_union(
|
||||
value: Value,
|
||||
defs: &HashMap<String, Value>,
|
||||
strategy: CleaningStrategy,
|
||||
ref_stack: &mut HashSet<String>,
|
||||
) -> Value {
|
||||
if let Value::Array(variants) = value {
|
||||
let cleaned: Vec<Value> = variants
|
||||
.into_iter()
|
||||
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
|
||||
.collect();
|
||||
Value::Array(cleaned)
|
||||
} else {
|
||||
value
|
||||
}
|
||||
}
|
||||
|
||||
/// Preserve metadata (description, title, default) from source to target.
|
||||
fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
|
||||
if let Value::Object(target_obj) = &mut target {
|
||||
for &key in SCHEMA_META_KEYS {
|
||||
if let Some(value) = source.get(key) {
|
||||
target_obj.insert(key.to_string(), value.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
target
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
|
||||
#[test]
|
||||
fn test_remove_unsupported_keywords() {
|
||||
let schema = json!({
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"maxLength": 100,
|
||||
"pattern": "^[a-z]+$",
|
||||
"description": "A lowercase string"
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert_eq!(cleaned["description"], "A lowercase string");
|
||||
assert!(cleaned.get("minLength").is_none());
|
||||
assert!(cleaned.get("maxLength").is_none());
|
||||
assert!(cleaned.get("pattern").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_resolve_ref() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"age": {
|
||||
"$ref": "#/$defs/Age"
|
||||
}
|
||||
},
|
||||
"$defs": {
|
||||
"Age": {
|
||||
"type": "integer",
|
||||
"minimum": 0
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["properties"]["age"]["type"], "integer");
|
||||
assert!(cleaned["properties"]["age"].get("minimum").is_none()); // Stripped by Gemini strategy
|
||||
assert!(cleaned.get("$defs").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_flatten_literal_union() {
|
||||
let schema = json!({
|
||||
"anyOf": [
|
||||
{ "const": "admin", "type": "string" },
|
||||
{ "const": "user", "type": "string" },
|
||||
{ "const": "guest", "type": "string" }
|
||||
]
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert!(cleaned["enum"].is_array());
|
||||
let enum_values = cleaned["enum"].as_array().unwrap();
|
||||
assert_eq!(enum_values.len(), 3);
|
||||
assert!(enum_values.contains(&json!("admin")));
|
||||
assert!(enum_values.contains(&json!("user")));
|
||||
assert!(enum_values.contains(&json!("guest")));
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strip_null_from_union() {
|
||||
let schema = json!({
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
// Should simplify to just { type: "string" }
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert!(cleaned.get("oneOf").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_const_to_enum() {
|
||||
let schema = json!({
|
||||
"const": "fixed_value",
|
||||
"description": "A constant"
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["enum"], json!(["fixed_value"]));
|
||||
assert_eq!(cleaned["description"], "A constant");
|
||||
assert!(cleaned.get("const").is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_preserve_metadata() {
|
||||
let schema = json!({
|
||||
"$ref": "#/$defs/Name",
|
||||
"description": "User's name",
|
||||
"title": "Name Field",
|
||||
"default": "Anonymous",
|
||||
"$defs": {
|
||||
"Name": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
assert_eq!(cleaned["description"], "User's name");
|
||||
assert_eq!(cleaned["title"], "Name Field");
|
||||
assert_eq!(cleaned["default"], "Anonymous");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_circular_ref_prevention() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"parent": {
|
||||
"$ref": "#/$defs/Node"
|
||||
}
|
||||
},
|
||||
"$defs": {
|
||||
"Node": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"child": {
|
||||
"$ref": "#/$defs/Node"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
// Should not panic on circular reference
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["properties"]["parent"]["type"], "object");
|
||||
// Circular reference should be broken
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_validate_schema() {
|
||||
let valid = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
});
|
||||
|
||||
assert!(SchemaCleanr::validate(&valid).is_ok());
|
||||
|
||||
let invalid = json!({
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
});
|
||||
|
||||
assert!(SchemaCleanr::validate(&invalid).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_strategy_differences() {
|
||||
let schema = json!({
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"description": "A string field"
|
||||
});
|
||||
|
||||
// Gemini: Most restrictive (removes minLength)
|
||||
let gemini = SchemaCleanr::clean_for_gemini(schema.clone());
|
||||
assert!(gemini.get("minLength").is_none());
|
||||
assert_eq!(gemini["type"], "string");
|
||||
assert_eq!(gemini["description"], "A string field");
|
||||
|
||||
// OpenAI: Most permissive (keeps minLength)
|
||||
let openai = SchemaCleanr::clean_for_openai(schema.clone());
|
||||
assert_eq!(openai["minLength"], 1); // OpenAI allows validation keywords
|
||||
assert_eq!(openai["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_nested_properties() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"minLength": 1
|
||||
}
|
||||
},
|
||||
"additionalProperties": false
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert!(cleaned["properties"]["user"]["properties"]["name"]
|
||||
.get("minLength")
|
||||
.is_none());
|
||||
assert!(cleaned["properties"]["user"]
|
||||
.get("additionalProperties")
|
||||
.is_none());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_type_array_null_removal() {
|
||||
let schema = json!({
|
||||
"type": ["string", "null"]
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
// Should simplify to just "string"
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_type_array_only_null_preserved() {
|
||||
let schema = json!({
|
||||
"type": ["null"]
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "null");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_ref_with_json_pointer_escape() {
|
||||
let schema = json!({
|
||||
"$ref": "#/$defs/Foo~1Bar",
|
||||
"$defs": {
|
||||
"Foo/Bar": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["type"], "string");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_skip_type_when_non_simplifiable_union_exists() {
|
||||
let schema = json!({
|
||||
"type": "object",
|
||||
"oneOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": { "type": "string" }
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"b": { "type": "number" }
|
||||
}
|
||||
}
|
||||
]
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert!(cleaned.get("type").is_none());
|
||||
assert!(cleaned.get("oneOf").is_some());
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn test_clean_nested_unknown_schema_keyword() {
|
||||
let schema = json!({
|
||||
"not": {
|
||||
"$ref": "#/$defs/Age"
|
||||
},
|
||||
"$defs": {
|
||||
"Age": {
|
||||
"type": "integer",
|
||||
"minimum": 0
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||
|
||||
assert_eq!(cleaned["not"]["type"], "integer");
|
||||
assert!(cleaned["not"].get("minimum").is_none());
|
||||
}
|
||||
}
|
||||
|
|
@ -36,6 +36,7 @@ async fn compare_store_speed() {
|
|||
&format!("key_{i}"),
|
||||
&format!("Memory entry number {i} about Rust programming"),
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -49,6 +50,7 @@ async fn compare_store_speed() {
|
|||
&format!("key_{i}"),
|
||||
&format!("Memory entry number {i} about Rust programming"),
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
|
@ -127,8 +129,8 @@ async fn compare_recall_quality() {
|
|||
];
|
||||
|
||||
for (key, content, cat) in &entries {
|
||||
sq.store(key, content, cat.clone()).await.unwrap();
|
||||
md.store(key, content, cat.clone()).await.unwrap();
|
||||
sq.store(key, content, cat.clone(), None).await.unwrap();
|
||||
md.store(key, content, cat.clone(), None).await.unwrap();
|
||||
}
|
||||
|
||||
// Test queries and compare results
|
||||
|
|
@ -145,8 +147,8 @@ async fn compare_recall_quality() {
|
|||
println!("RECALL QUALITY (10 entries seeded):\n");
|
||||
|
||||
for (query, desc) in &queries {
|
||||
let sq_results = sq.recall(query, 10).await.unwrap();
|
||||
let md_results = md.recall(query, 10).await.unwrap();
|
||||
let sq_results = sq.recall(query, 10, None).await.unwrap();
|
||||
let md_results = md.recall(query, 10, None).await.unwrap();
|
||||
|
||||
println!(" Query: \"{query}\" — {desc}");
|
||||
println!(" SQLite: {} results", sq_results.len());
|
||||
|
|
@ -190,21 +192,21 @@ async fn compare_recall_speed() {
|
|||
} else {
|
||||
format!("TypeScript powers modern web apps, entry {i}")
|
||||
};
|
||||
sq.store(&format!("e{i}"), &content, MemoryCategory::Core)
|
||||
sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store(&format!("e{i}"), &content, MemoryCategory::Daily)
|
||||
md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Benchmark recall
|
||||
let start = Instant::now();
|
||||
let sq_results = sq.recall("Rust systems", 10).await.unwrap();
|
||||
let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
|
||||
let sq_dur = start.elapsed();
|
||||
|
||||
let start = Instant::now();
|
||||
let md_results = md.recall("Rust systems", 10).await.unwrap();
|
||||
let md_results = md.recall("Rust systems", 10, None).await.unwrap();
|
||||
let md_dur = start.elapsed();
|
||||
|
||||
println!("\n============================================================");
|
||||
|
|
@ -227,15 +229,25 @@ async fn compare_persistence() {
|
|||
// Store in both, then drop and re-open
|
||||
{
|
||||
let sq = sqlite_backend(tmp_sq.path());
|
||||
sq.store("persist_test", "I should survive", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
sq.store(
|
||||
"persist_test",
|
||||
"I should survive",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
{
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
md.store("persist_test", "I should survive", MemoryCategory::Core)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store(
|
||||
"persist_test",
|
||||
"I should survive",
|
||||
MemoryCategory::Core,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
// Re-open
|
||||
|
|
@ -282,17 +294,17 @@ async fn compare_upsert() {
|
|||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
// Store twice with same key, different content
|
||||
sq.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
sq.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
sq.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
sq.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
md.store("pref", "likes Rust", MemoryCategory::Core)
|
||||
md.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store("pref", "loves Rust", MemoryCategory::Core)
|
||||
md.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -300,7 +312,7 @@ async fn compare_upsert() {
|
|||
let md_count = md.count().await.unwrap();
|
||||
|
||||
let sq_entry = sq.get("pref").await.unwrap();
|
||||
let md_results = md.recall("loves Rust", 5).await.unwrap();
|
||||
let md_results = md.recall("loves Rust", 5, None).await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("UPSERT (store same key twice):");
|
||||
|
|
@ -328,10 +340,10 @@ async fn compare_forget() {
|
|||
let sq = sqlite_backend(tmp_sq.path());
|
||||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
sq.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
||||
sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
||||
md.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
|
@ -372,37 +384,40 @@ async fn compare_category_filter() {
|
|||
let md = markdown_backend(tmp_md.path());
|
||||
|
||||
// Mix of categories
|
||||
sq.store("a", "core fact 1", MemoryCategory::Core)
|
||||
sq.store("a", "core fact 1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
sq.store("b", "core fact 2", MemoryCategory::Core)
|
||||
sq.store("b", "core fact 2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
sq.store("c", "daily note", MemoryCategory::Daily)
|
||||
sq.store("c", "daily note", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
sq.store("d", "convo msg", MemoryCategory::Conversation)
|
||||
sq.store("d", "convo msg", MemoryCategory::Conversation, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
md.store("a", "core fact 1", MemoryCategory::Core)
|
||||
md.store("a", "core fact 1", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store("b", "core fact 2", MemoryCategory::Core)
|
||||
md.store("b", "core fact 2", MemoryCategory::Core, None)
|
||||
.await
|
||||
.unwrap();
|
||||
md.store("c", "daily note", MemoryCategory::Daily)
|
||||
md.store("c", "daily note", MemoryCategory::Daily, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap();
|
||||
let sq_all = sq.list(None).await.unwrap();
|
||||
let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
let sq_conv = sq
|
||||
.list(Some(&MemoryCategory::Conversation), None)
|
||||
.await
|
||||
.unwrap();
|
||||
let sq_all = sq.list(None, None).await.unwrap();
|
||||
|
||||
let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap();
|
||||
let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
||||
let md_all = md.list(None).await.unwrap();
|
||||
let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||
let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||
let md_all = md.list(None, None).await.unwrap();
|
||||
|
||||
println!("\n============================================================");
|
||||
println!("CATEGORY FILTERING:");
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue