Merge branch 'main' into pr-484-clean
This commit is contained in:
commit
ee05d62ce4
90 changed files with 6937 additions and 1403 deletions
62
.env.example
62
.env.example
|
|
@ -1,25 +1,69 @@
|
||||||
# ZeroClaw Environment Variables
|
# ZeroClaw Environment Variables
|
||||||
# Copy this file to .env and fill in your values.
|
# Copy this file to `.env` and fill in your local values.
|
||||||
# NEVER commit .env — it is listed in .gitignore.
|
# Never commit `.env` or any real secrets.
|
||||||
|
|
||||||
# ── Required ──────────────────────────────────────────────────
|
# ── Core Runtime ──────────────────────────────────────────────
|
||||||
# Your LLM provider API key
|
# Provider key resolution at runtime:
|
||||||
# ZEROCLAW_API_KEY=sk-your-key-here
|
# 1) explicit key passed from config/CLI
|
||||||
|
# 2) provider-specific env var (OPENROUTER_API_KEY, OPENAI_API_KEY, ...)
|
||||||
|
# 3) generic fallback env vars below
|
||||||
|
|
||||||
|
# Generic fallback API key (used when provider-specific key is absent)
|
||||||
API_KEY=your-api-key-here
|
API_KEY=your-api-key-here
|
||||||
|
# ZEROCLAW_API_KEY=your-api-key-here
|
||||||
|
|
||||||
# ── Provider & Model ─────────────────────────────────────────
|
# Default provider/model (can be overridden by CLI flags)
|
||||||
# LLM provider: openrouter, openai, anthropic, ollama, glm
|
|
||||||
PROVIDER=openrouter
|
PROVIDER=openrouter
|
||||||
|
# ZEROCLAW_PROVIDER=openrouter
|
||||||
# ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514
|
# ZEROCLAW_MODEL=anthropic/claude-sonnet-4-20250514
|
||||||
# ZEROCLAW_TEMPERATURE=0.7
|
# ZEROCLAW_TEMPERATURE=0.7
|
||||||
|
|
||||||
|
# Workspace directory override
|
||||||
|
# ZEROCLAW_WORKSPACE=/path/to/workspace
|
||||||
|
|
||||||
|
# ── Provider-Specific API Keys ────────────────────────────────
|
||||||
|
# OpenRouter
|
||||||
|
# OPENROUTER_API_KEY=sk-or-v1-...
|
||||||
|
|
||||||
|
# Anthropic
|
||||||
|
# ANTHROPIC_OAUTH_TOKEN=...
|
||||||
|
# ANTHROPIC_API_KEY=sk-ant-...
|
||||||
|
|
||||||
|
# OpenAI / Gemini
|
||||||
|
# OPENAI_API_KEY=sk-...
|
||||||
|
# GEMINI_API_KEY=...
|
||||||
|
# GOOGLE_API_KEY=...
|
||||||
|
|
||||||
|
# Other supported providers
|
||||||
|
# VENICE_API_KEY=...
|
||||||
|
# GROQ_API_KEY=...
|
||||||
|
# MISTRAL_API_KEY=...
|
||||||
|
# DEEPSEEK_API_KEY=...
|
||||||
|
# XAI_API_KEY=...
|
||||||
|
# TOGETHER_API_KEY=...
|
||||||
|
# FIREWORKS_API_KEY=...
|
||||||
|
# PERPLEXITY_API_KEY=...
|
||||||
|
# COHERE_API_KEY=...
|
||||||
|
# MOONSHOT_API_KEY=...
|
||||||
|
# GLM_API_KEY=...
|
||||||
|
# MINIMAX_API_KEY=...
|
||||||
|
# QIANFAN_API_KEY=...
|
||||||
|
# DASHSCOPE_API_KEY=...
|
||||||
|
# ZAI_API_KEY=...
|
||||||
|
# SYNTHETIC_API_KEY=...
|
||||||
|
# OPENCODE_API_KEY=...
|
||||||
|
# VERCEL_API_KEY=...
|
||||||
|
# CLOUDFLARE_API_KEY=...
|
||||||
|
|
||||||
# ── Gateway ──────────────────────────────────────────────────
|
# ── Gateway ──────────────────────────────────────────────────
|
||||||
# ZEROCLAW_GATEWAY_PORT=3000
|
# ZEROCLAW_GATEWAY_PORT=3000
|
||||||
# ZEROCLAW_GATEWAY_HOST=127.0.0.1
|
# ZEROCLAW_GATEWAY_HOST=127.0.0.1
|
||||||
# ZEROCLAW_ALLOW_PUBLIC_BIND=false
|
# ZEROCLAW_ALLOW_PUBLIC_BIND=false
|
||||||
|
|
||||||
# ── Workspace ────────────────────────────────────────────────
|
# ── Optional Integrations ────────────────────────────────────
|
||||||
# ZEROCLAW_WORKSPACE=/path/to/workspace
|
# Pushover notifications (`pushover` tool)
|
||||||
|
# PUSHOVER_TOKEN=your-pushover-app-token
|
||||||
|
# PUSHOVER_USER_KEY=your-pushover-user-key
|
||||||
|
|
||||||
# ── Docker Compose ───────────────────────────────────────────
|
# ── Docker Compose ───────────────────────────────────────────
|
||||||
# Host port mapping (used by docker-compose.yml)
|
# Host port mapping (used by docker-compose.yml)
|
||||||
|
|
|
||||||
8
.githooks/pre-commit
Executable file
8
.githooks/pre-commit
Executable file
|
|
@ -0,0 +1,8 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
if command -v gitleaks >/dev/null 2>&1; then
|
||||||
|
gitleaks protect --staged --redact
|
||||||
|
else
|
||||||
|
echo "warning: gitleaks not found; skipping staged secret scan" >&2
|
||||||
|
fi
|
||||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
|
|
@ -12,7 +12,11 @@ Describe this PR in 2-5 bullets:
|
||||||
- Risk label (`risk: low|medium|high`):
|
- Risk label (`risk: low|medium|high`):
|
||||||
- Size label (`size: XS|S|M|L|XL`, auto-managed/read-only):
|
- Size label (`size: XS|S|M|L|XL`, auto-managed/read-only):
|
||||||
- Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated):
|
- Scope labels (`core|agent|channel|config|cron|daemon|doctor|gateway|health|heartbeat|integration|memory|observability|onboard|provider|runtime|security|service|skillforge|skills|tool|tunnel|docs|dependencies|ci|tests|scripts|dev`, comma-separated):
|
||||||
|
<<<<<<< chore/labeler-spacing-trusted-tier
|
||||||
|
- Module labels (`<module>: <component>`, for example `channel: telegram`, `provider: kimi`, `tool: shell`):
|
||||||
|
=======
|
||||||
- Module labels (`<module>:<component>`, for example `channel:telegram`, `provider:kimi`, `tool:shell`):
|
- Module labels (`<module>:<component>`, for example `channel:telegram`, `provider:kimi`, `tool:shell`):
|
||||||
|
>>>>>>> main
|
||||||
- Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50):
|
- Contributor tier label (`trusted contributor|experienced contributor|principal contributor|distinguished contributor`, auto-managed/read-only; author merged PRs >=5/10/20/50):
|
||||||
- If any auto-label is incorrect, note requested correction:
|
- If any auto-label is incorrect, note requested correction:
|
||||||
|
|
||||||
|
|
|
||||||
1
.github/workflows/auto-response.yml
vendored
1
.github/workflows/auto-response.yml
vendored
|
|
@ -18,6 +18,7 @@ jobs:
|
||||||
runs-on: blacksmith-2vcpu-ubuntu-2404
|
runs-on: blacksmith-2vcpu-ubuntu-2404
|
||||||
permissions:
|
permissions:
|
||||||
issues: write
|
issues: write
|
||||||
|
pull-requests: write
|
||||||
steps:
|
steps:
|
||||||
- name: Apply contributor tier label for issue author
|
- name: Apply contributor tier label for issue author
|
||||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8
|
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8
|
||||||
|
|
|
||||||
8
.github/workflows/docker.yml
vendored
8
.github/workflows/docker.yml
vendored
|
|
@ -35,7 +35,7 @@ jobs:
|
||||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
|
||||||
- name: Setup Blacksmith Builder
|
- name: Setup Blacksmith Builder
|
||||||
uses: useblacksmith/setup-docker-builder@v1
|
uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
|
||||||
|
|
||||||
- name: Extract metadata (tags, labels)
|
- name: Extract metadata (tags, labels)
|
||||||
id: meta
|
id: meta
|
||||||
|
|
@ -46,7 +46,7 @@ jobs:
|
||||||
type=ref,event=pr
|
type=ref,event=pr
|
||||||
|
|
||||||
- name: Build smoke image
|
- name: Build smoke image
|
||||||
uses: useblacksmith/build-push-action@v2
|
uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: false
|
push: false
|
||||||
|
|
@ -71,7 +71,7 @@ jobs:
|
||||||
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
uses: actions/checkout@34e114876b0b11c390a56381ad16ebd13914f8d5 # v4
|
||||||
|
|
||||||
- name: Setup Blacksmith Builder
|
- name: Setup Blacksmith Builder
|
||||||
uses: useblacksmith/setup-docker-builder@v1
|
uses: useblacksmith/setup-docker-builder@ef12d5b165b596e3aa44ea8198d8fde563eab402 # v1
|
||||||
|
|
||||||
- name: Log in to Container Registry
|
- name: Log in to Container Registry
|
||||||
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # v3
|
||||||
|
|
@ -102,7 +102,7 @@ jobs:
|
||||||
echo "tags=${TAGS}" >> "$GITHUB_OUTPUT"
|
echo "tags=${TAGS}" >> "$GITHUB_OUTPUT"
|
||||||
|
|
||||||
- name: Build and push Docker image
|
- name: Build and push Docker image
|
||||||
uses: useblacksmith/build-push-action@v2
|
uses: useblacksmith/build-push-action@30c71162f16ea2c27c3e21523255d209b8b538c1 # v2
|
||||||
with:
|
with:
|
||||||
context: .
|
context: .
|
||||||
push: true
|
push: true
|
||||||
|
|
|
||||||
27
.github/workflows/labeler.yml
vendored
27
.github/workflows/labeler.yml
vendored
|
|
@ -325,13 +325,18 @@ jobs:
|
||||||
return pattern.test(text);
|
return pattern.test(text);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
function formatModuleLabel(prefix, segment) {
|
||||||
|
return `${prefix}: ${segment}`;
|
||||||
|
}
|
||||||
|
|
||||||
function parseModuleLabel(label) {
|
function parseModuleLabel(label) {
|
||||||
const separatorIndex = label.indexOf(":");
|
if (typeof label !== "string") return null;
|
||||||
if (separatorIndex <= 0 || separatorIndex >= label.length - 1) return null;
|
const match = label.match(/^([^:]+):\s*(.+)$/);
|
||||||
return {
|
if (!match) return null;
|
||||||
prefix: label.slice(0, separatorIndex),
|
const prefix = match[1].trim().toLowerCase();
|
||||||
segment: label.slice(separatorIndex + 1),
|
const segment = (match[2] || "").trim().toLowerCase();
|
||||||
};
|
if (!prefix || !segment) return null;
|
||||||
|
return { prefix, segment };
|
||||||
}
|
}
|
||||||
|
|
||||||
function sortByPriority(labels, priorityIndex) {
|
function sortByPriority(labels, priorityIndex) {
|
||||||
|
|
@ -389,7 +394,7 @@ jobs:
|
||||||
for (const [prefix, segments] of segmentsByPrefix) {
|
for (const [prefix, segments] of segmentsByPrefix) {
|
||||||
const hasSpecificSegment = [...segments].some((segment) => segment !== "core");
|
const hasSpecificSegment = [...segments].some((segment) => segment !== "core");
|
||||||
if (hasSpecificSegment) {
|
if (hasSpecificSegment) {
|
||||||
refined.delete(`${prefix}:core`);
|
refined.delete(formatModuleLabel(prefix, "core"));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -418,7 +423,7 @@ jobs:
|
||||||
if (uniqueSegments.length === 0) continue;
|
if (uniqueSegments.length === 0) continue;
|
||||||
|
|
||||||
if (uniqueSegments.length === 1) {
|
if (uniqueSegments.length === 1) {
|
||||||
compactedModuleLabels.add(`${prefix}:${uniqueSegments[0]}`);
|
compactedModuleLabels.add(formatModuleLabel(prefix, uniqueSegments[0]));
|
||||||
} else {
|
} else {
|
||||||
forcePathPrefixes.add(prefix);
|
forcePathPrefixes.add(prefix);
|
||||||
}
|
}
|
||||||
|
|
@ -609,7 +614,7 @@ jobs:
|
||||||
segment = normalizeLabelSegment(segment);
|
segment = normalizeLabelSegment(segment);
|
||||||
if (!segment) continue;
|
if (!segment) continue;
|
||||||
|
|
||||||
detectedModuleLabels.add(`${rule.prefix}:${segment}`);
|
detectedModuleLabels.add(formatModuleLabel(rule.prefix, segment));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -635,7 +640,7 @@ jobs:
|
||||||
|
|
||||||
for (const keyword of providerKeywordHints) {
|
for (const keyword of providerKeywordHints) {
|
||||||
if (containsKeyword(searchableText, keyword)) {
|
if (containsKeyword(searchableText, keyword)) {
|
||||||
detectedModuleLabels.add(`provider:${keyword}`);
|
detectedModuleLabels.add(formatModuleLabel("provider", keyword));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -661,7 +666,7 @@ jobs:
|
||||||
|
|
||||||
for (const keyword of channelKeywordHints) {
|
for (const keyword of channelKeywordHints) {
|
||||||
if (containsKeyword(searchableText, keyword)) {
|
if (containsKeyword(searchableText, keyword)) {
|
||||||
detectedModuleLabels.add(`channel:${keyword}`);
|
detectedModuleLabels.add(formatModuleLabel("channel", keyword));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
22
.gitignore
vendored
22
.gitignore
vendored
|
|
@ -4,6 +4,26 @@ firmware/*/target
|
||||||
*.db-journal
|
*.db-journal
|
||||||
.DS_Store
|
.DS_Store
|
||||||
.wt-pr37/
|
.wt-pr37/
|
||||||
.env
|
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.pyc
|
*.pyc
|
||||||
|
docker-compose.override.yml
|
||||||
|
|
||||||
|
# Environment files (may contain secrets)
|
||||||
|
.env
|
||||||
|
|
||||||
|
# Python virtual environments
|
||||||
|
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
|
||||||
|
# ESP32 build cache (esp-idf-sys managed)
|
||||||
|
|
||||||
|
.embuild/
|
||||||
|
.env.local
|
||||||
|
.env.*.local
|
||||||
|
|
||||||
|
# Secret keys and credentials
|
||||||
|
.secret_key
|
||||||
|
*.key
|
||||||
|
*.pem
|
||||||
|
credentials.json
|
||||||
|
|
|
||||||
|
|
@ -79,6 +79,94 @@ git push --no-verify
|
||||||
|
|
||||||
> **Note:** CI runs the same checks, so skipped hooks will be caught on the PR.
|
> **Note:** CI runs the same checks, so skipped hooks will be caught on the PR.
|
||||||
|
|
||||||
|
## Local Secret Management (Required)
|
||||||
|
|
||||||
|
ZeroClaw supports layered secret management for local development and CI hygiene.
|
||||||
|
|
||||||
|
### Secret Storage Options
|
||||||
|
|
||||||
|
1. **Environment variables** (recommended for local development)
|
||||||
|
- Copy `.env.example` to `.env` and fill in values
|
||||||
|
- `.env` files are Git-ignored and should stay local
|
||||||
|
- Best for temporary/local API keys
|
||||||
|
|
||||||
|
2. **Config file** (`~/.zeroclaw/config.toml`)
|
||||||
|
- Persistent setup for long-term use
|
||||||
|
- When `secrets.encrypt = true` (default), secret values are encrypted before save
|
||||||
|
- Secret key is stored at `~/.zeroclaw/.secret_key` with restricted permissions
|
||||||
|
- Use `zeroclaw onboard` for guided setup
|
||||||
|
|
||||||
|
### Runtime Resolution Rules
|
||||||
|
|
||||||
|
API key resolution follows this order:
|
||||||
|
|
||||||
|
1. Explicit key passed from config/CLI
|
||||||
|
2. Provider-specific env vars (`OPENROUTER_API_KEY`, `OPENAI_API_KEY`, `ANTHROPIC_API_KEY`, ...)
|
||||||
|
3. Generic env vars (`ZEROCLAW_API_KEY`, `API_KEY`)
|
||||||
|
|
||||||
|
Provider/model config overrides:
|
||||||
|
|
||||||
|
- `ZEROCLAW_PROVIDER` / `PROVIDER`
|
||||||
|
- `ZEROCLAW_MODEL`
|
||||||
|
|
||||||
|
See `.env.example` for practical examples and currently supported provider key env vars.
|
||||||
|
|
||||||
|
### Pre-Commit Secret Hygiene (Mandatory)
|
||||||
|
|
||||||
|
Before every commit, verify:
|
||||||
|
|
||||||
|
- [ ] No `.env` files are staged (`.env.example` only)
|
||||||
|
- [ ] No raw API keys/tokens in code, tests, fixtures, examples, logs, or commit messages
|
||||||
|
- [ ] No credentials in debug output or error payloads
|
||||||
|
- [ ] `git diff --cached` has no accidental secret-like strings
|
||||||
|
|
||||||
|
Quick local audit:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Search staged diff for common secret markers
|
||||||
|
git diff --cached | grep -iE '(api[_-]?key|secret|token|password|bearer|sk-)'
|
||||||
|
|
||||||
|
# Confirm no .env file is staged
|
||||||
|
git status --short | grep -E '\.env$'
|
||||||
|
```
|
||||||
|
|
||||||
|
### Optional Local Secret Scanning
|
||||||
|
|
||||||
|
For extra guardrails, install one of:
|
||||||
|
|
||||||
|
- **gitleaks**: [GitHub - gitleaks/gitleaks](https://github.com/gitleaks/gitleaks)
|
||||||
|
- **truffleHog**: [GitHub - trufflesecurity/trufflehog](https://github.com/trufflesecurity/trufflehog)
|
||||||
|
- **git-secrets**: [GitHub - awslabs/git-secrets](https://github.com/awslabs/git-secrets)
|
||||||
|
|
||||||
|
This repo includes `.githooks/pre-commit` to run `gitleaks protect --staged --redact` when gitleaks is installed.
|
||||||
|
|
||||||
|
Enable hooks with:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git config core.hooksPath .githooks
|
||||||
|
```
|
||||||
|
|
||||||
|
If gitleaks is not installed, the pre-commit hook prints a warning and continues.
|
||||||
|
|
||||||
|
### What Must Never Be Committed
|
||||||
|
|
||||||
|
- `.env` files (use `.env.example` only)
|
||||||
|
- API keys, tokens, passwords, or credentials (plain or encrypted)
|
||||||
|
- OAuth tokens or session identifiers
|
||||||
|
- Webhook signing secrets
|
||||||
|
- `~/.zeroclaw/.secret_key` or similar key files
|
||||||
|
- Personal identifiers or real user data in tests/fixtures
|
||||||
|
|
||||||
|
### If a Secret Is Committed Accidentally
|
||||||
|
|
||||||
|
1. Revoke/rotate the credential immediately
|
||||||
|
2. Do not rely only on `git revert` (history still contains the secret)
|
||||||
|
3. Purge history with `git filter-repo` or BFG
|
||||||
|
4. Force-push cleaned history (coordinate with maintainers)
|
||||||
|
5. Ensure the leaked value is removed from PR/issue/discussion/comment history
|
||||||
|
|
||||||
|
Reference: [GitHub guide: removing sensitive data from a repository](https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/removing-sensitive-data-from-a-repository)
|
||||||
|
|
||||||
## Collaboration Tracks (Risk-Based)
|
## Collaboration Tracks (Risk-Based)
|
||||||
|
|
||||||
To keep review throughput high without lowering quality, every PR should map to one track:
|
To keep review throughput high without lowering quality, every PR should map to one track:
|
||||||
|
|
|
||||||
51
Cargo.lock
generated
51
Cargo.lock
generated
|
|
@ -209,6 +209,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
checksum = "8b52af3cb4058c895d37317bb27508dccc8e5f2d39454016b297bf4a400597b8"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"axum-core",
|
"axum-core",
|
||||||
|
"base64",
|
||||||
"bytes",
|
"bytes",
|
||||||
"form_urlencoded",
|
"form_urlencoded",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
|
|
@ -227,8 +228,10 @@ dependencies = [
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_path_to_error",
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
|
"sha1",
|
||||||
"sync_wrapper",
|
"sync_wrapper",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-tungstenite 0.28.0",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
|
|
@ -2057,6 +2060,15 @@ dependencies = [
|
||||||
"hashify",
|
"hashify",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "matchers"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9"
|
||||||
|
dependencies = [
|
||||||
|
"regex-automata",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "matchit"
|
name = "matchit"
|
||||||
version = "0.8.4"
|
version = "0.8.4"
|
||||||
|
|
@ -3747,10 +3759,22 @@ dependencies = [
|
||||||
"rustls-pki-types",
|
"rustls-pki-types",
|
||||||
"tokio",
|
"tokio",
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
"tungstenite",
|
"tungstenite 0.24.0",
|
||||||
"webpki-roots 0.26.11",
|
"webpki-roots 0.26.11",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-tungstenite"
|
||||||
|
version = "0.28.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"log",
|
||||||
|
"tokio",
|
||||||
|
"tungstenite 0.28.0",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-util"
|
name = "tokio-util"
|
||||||
version = "0.7.18"
|
version = "0.7.18"
|
||||||
|
|
@ -3940,9 +3964,13 @@ version = "0.3.22"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
|
checksum = "2f30143827ddab0d256fd843b7a66d164e9f271cfa0dde49142c5ca0ca291f1e"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"matchers",
|
||||||
"nu-ansi-term",
|
"nu-ansi-term",
|
||||||
|
"once_cell",
|
||||||
|
"regex-automata",
|
||||||
"sharded-slab",
|
"sharded-slab",
|
||||||
"thread_local",
|
"thread_local",
|
||||||
|
"tracing",
|
||||||
"tracing-core",
|
"tracing-core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -3978,6 +4006,23 @@ dependencies = [
|
||||||
"utf-8",
|
"utf-8",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tungstenite"
|
||||||
|
version = "0.28.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"data-encoding",
|
||||||
|
"http 1.4.0",
|
||||||
|
"httparse",
|
||||||
|
"log",
|
||||||
|
"rand 0.9.2",
|
||||||
|
"sha1",
|
||||||
|
"thiserror 2.0.18",
|
||||||
|
"utf-8",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "twox-hash"
|
name = "twox-hash"
|
||||||
version = "2.1.2"
|
version = "2.1.2"
|
||||||
|
|
@ -4880,7 +4925,9 @@ dependencies = [
|
||||||
"pdf-extract",
|
"pdf-extract",
|
||||||
"probe-rs",
|
"probe-rs",
|
||||||
"prometheus",
|
"prometheus",
|
||||||
|
"prost",
|
||||||
"rand 0.8.5",
|
"rand 0.8.5",
|
||||||
|
"regex",
|
||||||
"reqwest",
|
"reqwest",
|
||||||
"rppal",
|
"rppal",
|
||||||
"rusqlite",
|
"rusqlite",
|
||||||
|
|
@ -4896,7 +4943,7 @@ dependencies = [
|
||||||
"tokio-rustls",
|
"tokio-rustls",
|
||||||
"tokio-serial",
|
"tokio-serial",
|
||||||
"tokio-test",
|
"tokio-test",
|
||||||
"tokio-tungstenite",
|
"tokio-tungstenite 0.24.0",
|
||||||
"toml",
|
"toml",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-http",
|
"tower-http",
|
||||||
|
|
|
||||||
28
Cargo.toml
28
Cargo.toml
|
|
@ -1,3 +1,7 @@
|
||||||
|
[workspace]
|
||||||
|
members = ["."]
|
||||||
|
resolver = "2"
|
||||||
|
|
||||||
[package]
|
[package]
|
||||||
name = "zeroclaw"
|
name = "zeroclaw"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
@ -31,7 +35,7 @@ shellexpand = "3.1"
|
||||||
|
|
||||||
# Logging - minimal
|
# Logging - minimal
|
||||||
tracing = { version = "0.1", default-features = false }
|
tracing = { version = "0.1", default-features = false }
|
||||||
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi"] }
|
tracing-subscriber = { version = "0.3", default-features = false, features = ["fmt", "ansi", "env-filter"] }
|
||||||
|
|
||||||
# Observability - Prometheus metrics
|
# Observability - Prometheus metrics
|
||||||
prometheus = { version = "0.14", default-features = false }
|
prometheus = { version = "0.14", default-features = false }
|
||||||
|
|
@ -63,12 +67,12 @@ rand = "0.8"
|
||||||
# Fast mutexes that don't poison on panic
|
# Fast mutexes that don't poison on panic
|
||||||
parking_lot = "0.12"
|
parking_lot = "0.12"
|
||||||
|
|
||||||
# Landlock (Linux sandbox) - optional dependency
|
|
||||||
landlock = { version = "0.4", optional = true }
|
|
||||||
|
|
||||||
# Async traits
|
# Async traits
|
||||||
async-trait = "0.1"
|
async-trait = "0.1"
|
||||||
|
|
||||||
|
# Protobuf encode/decode (Feishu WS long-connection frame codec)
|
||||||
|
prost = { version = "0.14", default-features = false }
|
||||||
|
|
||||||
# Memory / persistence
|
# Memory / persistence
|
||||||
rusqlite = { version = "0.38", features = ["bundled"] }
|
rusqlite = { version = "0.38", features = ["bundled"] }
|
||||||
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
|
chrono = { version = "0.4", default-features = false, features = ["clock", "std", "serde"] }
|
||||||
|
|
@ -86,6 +90,7 @@ glob = "0.3"
|
||||||
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
|
tokio-tungstenite = { version = "0.24", features = ["rustls-tls-webpki-roots"] }
|
||||||
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
|
futures-util = { version = "0.3", default-features = false, features = ["sink"] }
|
||||||
futures = "0.3"
|
futures = "0.3"
|
||||||
|
regex = "1.10"
|
||||||
hostname = "0.4.2"
|
hostname = "0.4.2"
|
||||||
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
|
lettre = { version = "0.11.19", default-features = false, features = ["builder", "smtp-transport", "rustls-tls"] }
|
||||||
mail-parser = "0.11.2"
|
mail-parser = "0.11.2"
|
||||||
|
|
@ -95,7 +100,7 @@ tokio-rustls = "0.26.4"
|
||||||
webpki-roots = "1.0.6"
|
webpki-roots = "1.0.6"
|
||||||
|
|
||||||
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
|
# HTTP server (gateway) — replaces raw TCP for proper HTTP/1.1 compliance
|
||||||
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query"] }
|
axum = { version = "0.8", default-features = false, features = ["http1", "json", "tokio", "query", "ws"] }
|
||||||
tower = { version = "0.5", default-features = false }
|
tower = { version = "0.5", default-features = false }
|
||||||
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
|
tower-http = { version = "0.6", default-features = false, features = ["limit", "timeout"] }
|
||||||
http-body-util = "0.1"
|
http-body-util = "0.1"
|
||||||
|
|
@ -117,19 +122,28 @@ probe-rs = { version = "0.30", optional = true }
|
||||||
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
# PDF extraction for datasheet RAG (optional, enable with --features rag-pdf)
|
||||||
pdf-extract = { version = "0.10", optional = true }
|
pdf-extract = { version = "0.10", optional = true }
|
||||||
|
|
||||||
# Raspberry Pi GPIO (Linux/RPi only) — target-specific to avoid compile failure on macOS
|
# Raspberry Pi GPIO / Landlock (Linux only) — target-specific to avoid compile failure on macOS
|
||||||
[target.'cfg(target_os = "linux")'.dependencies]
|
[target.'cfg(target_os = "linux")'.dependencies]
|
||||||
rppal = { version = "0.14", optional = true }
|
rppal = { version = "0.14", optional = true }
|
||||||
|
landlock = { version = "0.4", optional = true }
|
||||||
|
|
||||||
[features]
|
[features]
|
||||||
default = ["hardware"]
|
default = ["hardware"]
|
||||||
hardware = ["nusb", "tokio-serial"]
|
hardware = ["nusb", "tokio-serial"]
|
||||||
peripheral-rpi = ["rppal"]
|
peripheral-rpi = ["rppal"]
|
||||||
|
# Browser backend feature alias used by cfg(feature = "browser-native")
|
||||||
|
browser-native = ["dep:fantoccini"]
|
||||||
|
# Backward-compatible alias for older invocations
|
||||||
|
fantoccini = ["browser-native"]
|
||||||
|
# Sandbox feature aliases used by cfg(feature = "sandbox-*")
|
||||||
|
sandbox-landlock = ["dep:landlock"]
|
||||||
|
sandbox-bubblewrap = []
|
||||||
|
# Backward-compatible alias for older invocations
|
||||||
|
landlock = ["sandbox-landlock"]
|
||||||
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
|
# probe = probe-rs for Nucleo memory read (adds ~50 deps; optional)
|
||||||
probe = ["dep:probe-rs"]
|
probe = ["dep:probe-rs"]
|
||||||
# rag-pdf = PDF ingestion for datasheet RAG
|
# rag-pdf = PDF ingestion for datasheet RAG
|
||||||
rag-pdf = ["dep:pdf-extract"]
|
rag-pdf = ["dep:pdf-extract"]
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
opt-level = "z" # Optimize for size
|
opt-level = "z" # Optimize for size
|
||||||
lto = "thin" # Lower memory use during release builds
|
lto = "thin" # Lower memory use during release builds
|
||||||
|
|
|
||||||
211
LICENSE
211
LICENSE
|
|
@ -1,197 +1,28 @@
|
||||||
Apache License
|
MIT License
|
||||||
Version 2.0, January 2004
|
|
||||||
http://www.apache.org/licenses/
|
|
||||||
|
|
||||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
Copyright (c) 2025 ZeroClaw Labs
|
||||||
|
|
||||||
1. Definitions.
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
"License" shall mean the terms and conditions for use, reproduction,
|
The above copyright notice and this permission notice shall be included in all
|
||||||
and distribution as defined by Sections 1 through 9 of this document.
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
"Licensor" shall mean the copyright owner or entity authorized by
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
the copyright owner that is granting the License.
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
|
|
||||||
"Legal Entity" shall mean the union of the acting entity and all
|
================================================================================
|
||||||
other entities that control, are controlled by, or are under common
|
|
||||||
control with that entity. For the purposes of this definition,
|
|
||||||
"control" means (i) the power, direct or indirect, to cause the
|
|
||||||
direction or management of such entity, whether by contract or
|
|
||||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
|
||||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
|
|
||||||
"You" (or "Your") shall mean an individual or Legal Entity
|
This product includes software developed by ZeroClaw Labs and contributors:
|
||||||
exercising permissions granted by this License.
|
https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
|
||||||
|
|
||||||
"Source" form shall mean the preferred form for making modifications,
|
See NOTICE file for full contributor attribution.
|
||||||
including but not limited to software source code, documentation
|
|
||||||
source, and configuration files.
|
|
||||||
|
|
||||||
"Object" form shall mean any form resulting from mechanical
|
|
||||||
transformation or translation of a Source form, including but
|
|
||||||
not limited to compiled object code, generated documentation,
|
|
||||||
and conversions to other media types.
|
|
||||||
|
|
||||||
"Work" shall mean the work of authorship, whether in Source or
|
|
||||||
Object form, made available under the License, as indicated by a
|
|
||||||
copyright notice that is included in or attached to the work
|
|
||||||
(an example is provided in the Appendix below).
|
|
||||||
|
|
||||||
"Derivative Works" shall mean any work, whether in Source or Object
|
|
||||||
form, that is based on (or derived from) the Work and for which the
|
|
||||||
editorial revisions, annotations, elaborations, or other modifications
|
|
||||||
represent, as a whole, an original work of authorship. For the purposes
|
|
||||||
of this License, Derivative Works shall not include works that remain
|
|
||||||
separable from, or merely link (or bind by name) to the interfaces of,
|
|
||||||
the Work and Derivative Works thereof.
|
|
||||||
|
|
||||||
"Contribution" shall mean any work of authorship, including
|
|
||||||
the original version of the Work and any modifications or additions
|
|
||||||
to that Work or Derivative Works thereof, that is intentionally
|
|
||||||
submitted to the Licensor for inclusion in the Work by the copyright owner
|
|
||||||
or by an individual or Legal Entity authorized to submit on behalf of
|
|
||||||
the copyright owner. For the purposes of this definition, "submitted"
|
|
||||||
means any form of electronic, verbal, or written communication sent
|
|
||||||
to the Licensor or its representatives, including but not limited to
|
|
||||||
communication on electronic mailing lists, source code control systems,
|
|
||||||
and issue tracking systems that are managed by, or on behalf of, the
|
|
||||||
Licensor for the purpose of discussing and improving the Work, but
|
|
||||||
excluding communication that is conspicuously marked or otherwise
|
|
||||||
designated in writing by the copyright owner as "Not a Contribution."
|
|
||||||
|
|
||||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
|
||||||
on behalf of whom a Contribution has been received by Licensor and
|
|
||||||
subsequently incorporated within the Work.
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
copyright license to reproduce, prepare Derivative Works of,
|
|
||||||
publicly display, publicly perform, sublicense, and distribute the
|
|
||||||
Work and such Derivative Works in Source or Object form.
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of
|
|
||||||
this License, each Contributor hereby grants to You a perpetual,
|
|
||||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
|
||||||
(except as stated in this section) patent license to make, have made,
|
|
||||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
|
||||||
where such license applies only to those patent claims licensable
|
|
||||||
by such Contributor that are necessarily infringed by their
|
|
||||||
Contribution(s) alone or by combination of their Contribution(s)
|
|
||||||
with the Work to which such Contribution(s) was submitted. If You
|
|
||||||
institute patent litigation against any entity (including a
|
|
||||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
|
||||||
or a Contribution incorporated within the Work constitutes direct
|
|
||||||
or contributory patent infringement, then any patent licenses
|
|
||||||
granted to You under this License for that Work shall terminate
|
|
||||||
as of the date such litigation is filed.
|
|
||||||
|
|
||||||
4. Redistribution. You may reproduce and distribute copies of the
|
|
||||||
Work or Derivative Works thereof in any medium, with or without
|
|
||||||
modifications, and in Source or Object form, provided that You
|
|
||||||
meet the following conditions:
|
|
||||||
|
|
||||||
(a) You must give any other recipients of the Work or
|
|
||||||
Derivative Works a copy of this License; and
|
|
||||||
|
|
||||||
(b) You must cause any modified files to carry prominent notices
|
|
||||||
stating that You changed the files; and
|
|
||||||
|
|
||||||
(c) You must retain, in the Source form of any Derivative Works
|
|
||||||
that You distribute, all copyright, patent, trademark, and
|
|
||||||
attribution notices from the Source form of the Work,
|
|
||||||
excluding those notices that do not pertain to any part of
|
|
||||||
the Derivative Works; and
|
|
||||||
|
|
||||||
(d) If the Work includes a "NOTICE" text file as part of its
|
|
||||||
distribution, then any Derivative Works that You distribute must
|
|
||||||
include a readable copy of the attribution notices contained
|
|
||||||
within such NOTICE file, excluding those notices that do not
|
|
||||||
pertain to any part of the Derivative Works, in at least one
|
|
||||||
of the following places: within a NOTICE text file distributed
|
|
||||||
as part of the Derivative Works; within the Source form or
|
|
||||||
documentation, if provided along with the Derivative Works; or,
|
|
||||||
within a display generated by the Derivative Works, if and
|
|
||||||
wherever such third-party notices normally appear. The contents
|
|
||||||
of the NOTICE file are for informational purposes only and
|
|
||||||
do not modify the License. You may add Your own attribution
|
|
||||||
notices within Derivative Works that You distribute, alongside
|
|
||||||
or as an addendum to the NOTICE text from the Work, provided
|
|
||||||
that such additional attribution notices cannot be construed
|
|
||||||
as modifying the License.
|
|
||||||
|
|
||||||
You may add Your own copyright statement to Your modifications and
|
|
||||||
may provide additional or different license terms and conditions
|
|
||||||
for use, reproduction, or distribution of Your modifications, or
|
|
||||||
for any such Derivative Works as a whole, provided Your use,
|
|
||||||
reproduction, and distribution of the Work otherwise complies with
|
|
||||||
the conditions stated in this License.
|
|
||||||
|
|
||||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
|
||||||
any Contribution intentionally submitted for inclusion in the Work
|
|
||||||
by You to the Licensor shall be under the terms and conditions of
|
|
||||||
this License, without any additional terms or conditions.
|
|
||||||
Notwithstanding the above, nothing herein shall supersede or modify
|
|
||||||
the terms of any separate license agreement you may have executed
|
|
||||||
with Licensor regarding such Contributions.
|
|
||||||
|
|
||||||
6. Trademarks. This License does not grant permission to use the trade
|
|
||||||
names, trademarks, service marks, or product names of the Licensor,
|
|
||||||
except as required for reasonable and customary use in describing the
|
|
||||||
origin of the Work and reproducing the content of the NOTICE file.
|
|
||||||
|
|
||||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
|
||||||
agreed to in writing, Licensor provides the Work (and each
|
|
||||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
|
||||||
implied, including, without limitation, any warranties or conditions
|
|
||||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
|
||||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
|
||||||
appropriateness of using or redistributing the Work and assume any
|
|
||||||
risks associated with Your exercise of permissions under this License.
|
|
||||||
|
|
||||||
8. Limitation of Liability. In no event and under no legal theory,
|
|
||||||
whether in tort (including negligence), contract, or otherwise,
|
|
||||||
unless required by applicable law (such as deliberate and grossly
|
|
||||||
negligent acts) or agreed to in writing, shall any Contributor be
|
|
||||||
liable to You for damages, including any direct, indirect, special,
|
|
||||||
incidental, or consequential damages of any character arising as a
|
|
||||||
result of this License or out of the use or inability to use the
|
|
||||||
Work (including but not limited to damages for loss of goodwill,
|
|
||||||
work stoppage, computer failure or malfunction, or any and all
|
|
||||||
other commercial damages or losses), even if such Contributor
|
|
||||||
has been advised of the possibility of such damages.
|
|
||||||
|
|
||||||
9. Accepting Warranty or Additional Liability. While redistributing
|
|
||||||
the Work or Derivative Works thereof, You may choose to offer,
|
|
||||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
|
||||||
or other liability obligations and/or rights consistent with this
|
|
||||||
License. However, in accepting such obligations, You may act only
|
|
||||||
on Your own behalf and on Your sole responsibility, not on behalf
|
|
||||||
of any other Contributor, and only if You agree to indemnify,
|
|
||||||
defend, and hold each Contributor harmless for any liability
|
|
||||||
incurred by, or claims asserted against, such Contributor by reason
|
|
||||||
of your accepting any such warranty or additional liability.
|
|
||||||
|
|
||||||
END OF TERMS AND CONDITIONS
|
|
||||||
|
|
||||||
Copyright 2025-2026 Argenis Delarosa
|
|
||||||
|
|
||||||
Licensed under the Apache License, Version 2.0 (the "License");
|
|
||||||
you may not use this file except in compliance with the License.
|
|
||||||
You may obtain a copy of the License at
|
|
||||||
|
|
||||||
http://www.apache.org/licenses/LICENSE-2.0
|
|
||||||
|
|
||||||
Unless required by applicable law or agreed to in writing, software
|
|
||||||
distributed under the License is distributed on an "AS IS" BASIS,
|
|
||||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
||||||
See the License for the specific language governing permissions and
|
|
||||||
limitations under the License.
|
|
||||||
|
|
||||||
===============================================================================
|
|
||||||
|
|
||||||
This product includes software developed by ZeroClaw Labs and contributors:
|
|
||||||
https://github.com/zeroclaw-labs/zeroclaw/graphs/contributors
|
|
||||||
|
|
||||||
See NOTICE file for full contributor attribution.
|
|
||||||
|
|
|
||||||
27
README.md
27
README.md
|
|
@ -10,14 +10,14 @@
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<a href="LICENSE"><img src="https://img.shields.io/badge/license-Apache%202.0-blue.svg" alt="License: Apache 2.0" /></a>
|
<a href="LICENSE"><img src="https://img.shields.io/badge/license-MIT-blue.svg" alt="License: MIT" /></a>
|
||||||
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
|
<a href="NOTICE"><img src="https://img.shields.io/badge/contributors-27+-green.svg" alt="Contributors" /></a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
|
Fast, small, and fully autonomous AI assistant infrastructure — deploy anywhere, swap anything.
|
||||||
|
|
||||||
```
|
```
|
||||||
~3.4MB binary · <10ms startup · 1,017 tests · 22+ providers · 8 traits · Pluggable everything
|
~3.4MB binary · <10ms startup · 1,017 tests · 23+ providers · 8 traits · Pluggable everything
|
||||||
```
|
```
|
||||||
|
|
||||||
### ✨ Features
|
### ✨ Features
|
||||||
|
|
@ -132,6 +132,9 @@ cd zeroclaw
|
||||||
cargo build --release --locked
|
cargo build --release --locked
|
||||||
cargo install --path . --force --locked
|
cargo install --path . --force --locked
|
||||||
|
|
||||||
|
# Ensure ~/.cargo/bin is in your PATH
|
||||||
|
export PATH="$HOME/.cargo/bin:$PATH"
|
||||||
|
|
||||||
# Quick setup (no prompts)
|
# Quick setup (no prompts)
|
||||||
zeroclaw onboard --api-key sk-... --provider openrouter
|
zeroclaw onboard --api-key sk-... --provider openrouter
|
||||||
|
|
||||||
|
|
@ -187,7 +190,7 @@ Every subsystem is a **trait** — swap implementations with a config change, ze
|
||||||
|
|
||||||
| Subsystem | Trait | Ships with | Extend |
|
| Subsystem | Trait | Ships with | Extend |
|
||||||
|-----------|-------|------------|--------|
|
|-----------|-------|------------|--------|
|
||||||
| **AI Models** | `Provider` | 22+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
|
| **AI Models** | `Provider` | 23+ providers (OpenRouter, Anthropic, OpenAI, Ollama, Venice, Groq, Mistral, xAI, DeepSeek, Together, Fireworks, Perplexity, Cohere, Bedrock, Astrai, etc.) | `custom:https://your-api.com` — any OpenAI-compatible API |
|
||||||
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
|
| **Channels** | `Channel` | CLI, Telegram, Discord, Slack, iMessage, Matrix, WhatsApp, Webhook | Any messaging API |
|
||||||
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend |
|
| **Memory** | `Memory` | SQLite with hybrid search (FTS5 + vector cosine similarity), Lucid bridge (CLI sync + SQLite fallback), Markdown | Any persistence backend |
|
||||||
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability |
|
| **Tools** | `Tool` | shell, file_read, file_write, memory_store, memory_recall, memory_forget, browser_open (Brave + allowlist), browser (agent-browser / rust-native), composio (optional) | Any capability |
|
||||||
|
|
@ -287,6 +290,21 @@ rerun channel setup only:
|
||||||
zeroclaw onboard --channels-only
|
zeroclaw onboard --channels-only
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Telegram media replies
|
||||||
|
|
||||||
|
Telegram routing now replies to the source **chat ID** from incoming updates (instead of usernames),
|
||||||
|
which avoids `Bad Request: chat not found` failures.
|
||||||
|
|
||||||
|
For non-text replies, ZeroClaw can send Telegram attachments when the assistant includes markers:
|
||||||
|
|
||||||
|
- `[IMAGE:<path-or-url>]`
|
||||||
|
- `[DOCUMENT:<path-or-url>]`
|
||||||
|
- `[VIDEO:<path-or-url>]`
|
||||||
|
- `[AUDIO:<path-or-url>]`
|
||||||
|
- `[VOICE:<path-or-url>]`
|
||||||
|
|
||||||
|
Paths can be local files (for example `/tmp/screenshot.png`) or HTTPS URLs.
|
||||||
|
|
||||||
### WhatsApp Business Cloud API Setup
|
### WhatsApp Business Cloud API Setup
|
||||||
|
|
||||||
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
|
WhatsApp uses Meta's Cloud API with webhooks (push-based, not polling):
|
||||||
|
|
@ -610,7 +628,7 @@ We're building in the open because the best ideas come from everywhere. If you'r
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|
||||||
Apache 2.0 — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
|
MIT — see [LICENSE](LICENSE) and [NOTICE](NOTICE) for contributor attribution
|
||||||
|
|
||||||
## Contributing
|
## Contributing
|
||||||
|
|
||||||
|
|
@ -624,7 +642,6 @@ See [CONTRIBUTING.md](CONTRIBUTING.md). Implement a trait, submit a PR:
|
||||||
- New `Tunnel` → `src/tunnel/`
|
- New `Tunnel` → `src/tunnel/`
|
||||||
- New `Skill` → `~/.zeroclaw/workspace/skills/<name>/`
|
- New `Skill` → `~/.zeroclaw/workspace/skills/<name>/`
|
||||||
|
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
**ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀
|
**ZeroClaw** — Zero overhead. Zero compromise. Deploy anywhere. Swap anything. 🦀
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,4 @@
|
||||||
FROM ubuntu:22.04
|
FROM ubuntu:22.04@sha256:c7eb020043d8fc2ae0793fb35a37bff1cf33f156d4d4b12ccc7f3ef8706c38b1
|
||||||
|
|
||||||
# Prevent interactive prompts during package installation
|
# Prevent interactive prompts during package installation
|
||||||
ENV DEBIAN_FRONTEND=noninteractive
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ Merge-blocking checks should stay small and deterministic. Optional checks are u
|
||||||
### Optional Repository Automation
|
### Optional Repository Automation
|
||||||
|
|
||||||
- `.github/workflows/labeler.yml` (`PR Labeler`)
|
- `.github/workflows/labeler.yml` (`PR Labeler`)
|
||||||
- Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>:<component>`)
|
- Purpose: scope/path labels + size/risk labels + fine-grained module labels (`<module>: <component>`)
|
||||||
- Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule
|
- Additional behavior: label descriptions are auto-managed as hover tooltips to explain each auto-judgment rule
|
||||||
- Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`)
|
- Additional behavior: provider-related keywords in provider/config/onboard/integration changes are promoted to `provider:*` labels (for example `provider:kimi`, `provider:deepseek`)
|
||||||
- Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`)
|
- Additional behavior: hierarchical de-duplication keeps only the most specific scope labels (for example `tool:composio` suppresses `tool:core` and `tool`)
|
||||||
|
|
|
||||||
|
|
@ -244,7 +244,7 @@ Label discipline:
|
||||||
- Path labels identify subsystem ownership quickly.
|
- Path labels identify subsystem ownership quickly.
|
||||||
- Size labels drive batching strategy.
|
- Size labels drive batching strategy.
|
||||||
- Risk labels drive review depth (`risk: low/medium/high`).
|
- Risk labels drive review depth (`risk: low/medium/high`).
|
||||||
- Module labels (`<module>:<component>`) improve reviewer routing for integration-specific changes and future newly-added modules.
|
- Module labels (`<module>: <component>`) improve reviewer routing for integration-specific changes and future newly-added modules.
|
||||||
- `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context.
|
- `risk: manual` allows maintainers to preserve a human risk judgment when automation lacks context.
|
||||||
- `no-stale` is reserved for accepted-but-blocked work.
|
- `no-stale` is reserved for accepted-but-blocked work.
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -14,7 +14,7 @@ Use it to reduce review latency without reducing quality.
|
||||||
For every new PR, do a fast intake pass:
|
For every new PR, do a fast intake pass:
|
||||||
|
|
||||||
1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`).
|
1. Confirm template completeness (`summary`, `validation`, `security`, `rollback`).
|
||||||
2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel:*`/`provider:*`/`tool:*`, and contributor tier labels when applicable) are present and plausible.
|
2. Confirm labels (`size:*`, `risk:*`, scope labels such as `provider`/`channel`/`security`, module-scoped labels such as `channel: *`/`provider: *`/`tool: *`, and contributor tier labels when applicable) are present and plausible.
|
||||||
3. Confirm CI signal status (`CI Required Gate`).
|
3. Confirm CI signal status (`CI Required Gate`).
|
||||||
4. Confirm scope is one concern (reject mixed mega-PRs unless justified).
|
4. Confirm scope is one concern (reject mixed mega-PRs unless justified).
|
||||||
5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied.
|
5. Confirm privacy/data-hygiene and neutral test wording requirements are satisfied.
|
||||||
|
|
|
||||||
|
|
@ -12,6 +12,8 @@ use tokio::sync::mpsc;
|
||||||
pub struct ChannelMessage {
|
pub struct ChannelMessage {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub sender: String,
|
pub sender: String,
|
||||||
|
/// Channel-specific reply address (e.g. Telegram chat_id, Discord channel_id).
|
||||||
|
pub reply_to: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub channel: String,
|
pub channel: String,
|
||||||
pub timestamp: u64,
|
pub timestamp: u64,
|
||||||
|
|
@ -90,9 +92,12 @@ impl Channel for TelegramChannel {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let chat_id = msg["chat"]["id"].to_string();
|
||||||
|
|
||||||
let channel_msg = ChannelMessage {
|
let channel_msg = ChannelMessage {
|
||||||
id: msg["message_id"].to_string(),
|
id: msg["message_id"].to_string(),
|
||||||
sender,
|
sender,
|
||||||
|
reply_to: chat_id,
|
||||||
content: msg["text"].as_str().unwrap_or("").to_string(),
|
content: msg["text"].as_str().unwrap_or("").to_string(),
|
||||||
channel: "telegram".into(),
|
channel: "telegram".into(),
|
||||||
timestamp: msg["date"].as_u64().unwrap_or(0),
|
timestamp: msg["date"].as_u64().unwrap_or(0),
|
||||||
|
|
|
||||||
|
|
@ -2,4 +2,10 @@
|
||||||
target = "riscv32imc-esp-espidf"
|
target = "riscv32imc-esp-espidf"
|
||||||
|
|
||||||
[target.riscv32imc-esp-espidf]
|
[target.riscv32imc-esp-espidf]
|
||||||
|
linker = "ldproxy"
|
||||||
runner = "espflash flash --monitor"
|
runner = "espflash flash --monitor"
|
||||||
|
# ESP-IDF 5.x uses 64-bit time_t
|
||||||
|
rustflags = ["-C", "default-linker-libraries", "--cfg", "espidf_time64"]
|
||||||
|
|
||||||
|
[unstable]
|
||||||
|
build-std = ["std", "panic_abort"]
|
||||||
|
|
|
||||||
106
firmware/zeroclaw-esp32/Cargo.lock
generated
106
firmware/zeroclaw-esp32/Cargo.lock
generated
|
|
@ -58,24 +58,22 @@ checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bindgen"
|
name = "bindgen"
|
||||||
version = "0.63.0"
|
version = "0.71.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "36d860121800b2a9a94f9b5604b332d5cffb234ce17609ea479d723dbc9d3885"
|
checksum = "5f58bf3d7db68cfbac37cfc485a8d711e87e064c3d0fe0435b92f7a407f9d6b3"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags 1.3.2",
|
"bitflags 2.11.0",
|
||||||
"cexpr",
|
"cexpr",
|
||||||
"clang-sys",
|
"clang-sys",
|
||||||
"lazy_static",
|
"itertools",
|
||||||
"lazycell",
|
|
||||||
"log",
|
"log",
|
||||||
"peeking_take_while",
|
"prettyplease",
|
||||||
"proc-macro2",
|
"proc-macro2",
|
||||||
"quote",
|
"quote",
|
||||||
"regex",
|
"regex",
|
||||||
"rustc-hash",
|
"rustc-hash",
|
||||||
"shlex",
|
"shlex",
|
||||||
"syn 1.0.109",
|
"syn 2.0.116",
|
||||||
"which",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -374,14 +372,15 @@ checksum = "dc2d050bdc5c21e0862a89256ed8029ae6c290a93aecefc73084b3002cdebb01"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "embassy-sync"
|
name = "embassy-sync"
|
||||||
version = "0.5.0"
|
version = "0.7.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "dd938f25c0798db4280fcd8026bf4c2f48789aebf8f77b6e5cf8a7693ba114ec"
|
checksum = "73974a3edbd0bd286759b3d483540f0ebef705919a5f56f4fc7709066f71689b"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"cfg-if",
|
"cfg-if",
|
||||||
"critical-section",
|
"critical-section",
|
||||||
"embedded-io-async",
|
"embedded-io-async",
|
||||||
"futures-util",
|
"futures-core",
|
||||||
|
"futures-sink",
|
||||||
"heapless",
|
"heapless",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
@ -446,16 +445,15 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "embedded-svc"
|
name = "embedded-svc"
|
||||||
version = "0.27.1"
|
version = "0.28.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "ac6f87e7654f28018340aa55f933803017aefabaa5417820a3b2f808033c7bbc"
|
checksum = "a7770e30ab55cfbf954c00019522490d6ce26a3334bede05a732ba61010e98e0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"defmt 0.3.100",
|
"defmt 0.3.100",
|
||||||
"embedded-io",
|
"embedded-io",
|
||||||
"embedded-io-async",
|
"embedded-io-async",
|
||||||
"enumset",
|
"enumset",
|
||||||
"heapless",
|
"heapless",
|
||||||
"no-std-net",
|
|
||||||
"num_enum",
|
"num_enum",
|
||||||
"serde",
|
"serde",
|
||||||
"strum 0.25.0",
|
"strum 0.25.0",
|
||||||
|
|
@ -463,9 +461,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "embuild"
|
name = "embuild"
|
||||||
version = "0.31.4"
|
version = "0.33.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "4caa4f198bb9152a55c0103efb83fa4edfcbb8625f4c9e94ae8ec8e23827c563"
|
checksum = "e188ad2bbe82afa841ea4a29880651e53ab86815db036b2cb9f8de3ac32dad75"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"bindgen",
|
"bindgen",
|
||||||
|
|
@ -475,6 +473,7 @@ dependencies = [
|
||||||
"globwalk",
|
"globwalk",
|
||||||
"home",
|
"home",
|
||||||
"log",
|
"log",
|
||||||
|
"regex",
|
||||||
"remove_dir_all",
|
"remove_dir_all",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
|
|
@ -533,9 +532,8 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "esp-idf-hal"
|
name = "esp-idf-hal"
|
||||||
version = "0.43.1"
|
version = "0.45.2"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "git+https://github.com/esp-rs/esp-idf-hal#bc48639bd626c72afc1e25e5d497b5c639161d30"
|
||||||
checksum = "f7adf3fb19a9ca016cbea1ab8a7b852ac69df8fcde4923c23d3b155efbc42a74"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"atomic-waker",
|
"atomic-waker",
|
||||||
"embassy-sync",
|
"embassy-sync",
|
||||||
|
|
@ -552,14 +550,12 @@ dependencies = [
|
||||||
"heapless",
|
"heapless",
|
||||||
"log",
|
"log",
|
||||||
"nb 1.1.0",
|
"nb 1.1.0",
|
||||||
"num_enum",
|
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "esp-idf-svc"
|
name = "esp-idf-svc"
|
||||||
version = "0.48.1"
|
version = "0.51.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "git+https://github.com/esp-rs/esp-idf-svc#dee202f146c7681e54eabbf118a216fc0195d203"
|
||||||
checksum = "2180642ca122a7fec1ec417a9b1a77aa66aaa067fdf1daae683dd8caba84f26b"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"embassy-futures",
|
"embassy-futures",
|
||||||
"embedded-hal-async",
|
"embedded-hal-async",
|
||||||
|
|
@ -567,6 +563,7 @@ dependencies = [
|
||||||
"embuild",
|
"embuild",
|
||||||
"enumset",
|
"enumset",
|
||||||
"esp-idf-hal",
|
"esp-idf-hal",
|
||||||
|
"futures-io",
|
||||||
"heapless",
|
"heapless",
|
||||||
"log",
|
"log",
|
||||||
"num_enum",
|
"num_enum",
|
||||||
|
|
@ -575,14 +572,13 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "esp-idf-sys"
|
name = "esp-idf-sys"
|
||||||
version = "0.34.1"
|
version = "0.36.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "git+https://github.com/esp-rs/esp-idf-sys#64667a38fb8004e1fc3b032488af6857ca3cd849"
|
||||||
checksum = "2e148f97c04ed3e9181a08bcdc9560a515aad939b0ba7f50a0022e294665e0af"
|
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"anyhow",
|
"anyhow",
|
||||||
"bindgen",
|
|
||||||
"build-time",
|
"build-time",
|
||||||
"cargo_metadata",
|
"cargo_metadata",
|
||||||
|
"cmake",
|
||||||
"const_format",
|
"const_format",
|
||||||
"embuild",
|
"embuild",
|
||||||
"envy",
|
"envy",
|
||||||
|
|
@ -649,21 +645,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
checksum = "7e3450815272ef58cec6d564423f6e755e25379b217b0bc688e295ba24df6b1d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-task"
|
name = "futures-io"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "037711b3d59c33004d3856fbdc83b99d4ff37a24768fa1be9ce3538a1cde4393"
|
checksum = "cecba35d7ad927e23624b22ad55235f2239cfa44fd10428eecbeba6d6a717718"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-util"
|
name = "futures-sink"
|
||||||
version = "0.3.32"
|
version = "0.3.32"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "389ca41296e6190b48053de0321d02a77f32f8a5d2461dd38762c0593805c6d6"
|
checksum = "c39754e157331b013978ec91992bde1ac089843443c49cbc7f46150b0fad0893"
|
||||||
dependencies = [
|
|
||||||
"futures-core",
|
|
||||||
"futures-task",
|
|
||||||
"pin-project-lite",
|
|
||||||
]
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "getrandom"
|
name = "getrandom"
|
||||||
|
|
@ -827,6 +818,15 @@ dependencies = [
|
||||||
"serde_core",
|
"serde_core",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "itertools"
|
||||||
|
version = "0.13.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186"
|
||||||
|
dependencies = [
|
||||||
|
"either",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "itoa"
|
name = "itoa"
|
||||||
version = "1.0.17"
|
version = "1.0.17"
|
||||||
|
|
@ -843,18 +843,6 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "lazy_static"
|
|
||||||
version = "1.5.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "lazycell"
|
|
||||||
version = "1.3.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "leb128fmt"
|
name = "leb128fmt"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|
@ -945,12 +933,6 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "no-std-net"
|
|
||||||
version = "0.5.0"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "1bcece43b12349917e096cddfa66107277f123e6c96a5aea78711dc601a47152"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "nom"
|
name = "nom"
|
||||||
version = "7.1.3"
|
version = "7.1.3"
|
||||||
|
|
@ -1007,18 +989,6 @@ version = "1.21.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d"
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "peeking_take_while"
|
|
||||||
version = "0.1.2"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "19b17cddbe7ec3f8bc800887bab5e717348c95ea2ca0b1bf0837fb964dc67099"
|
|
||||||
|
|
||||||
[[package]]
|
|
||||||
name = "pin-project-lite"
|
|
||||||
version = "0.2.16"
|
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
|
||||||
checksum = "3b3cff922bd51709b605d9ead9aa71031d81447142d828eb4a6eba76fe619f9b"
|
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "prettyplease"
|
name = "prettyplease"
|
||||||
version = "0.2.37"
|
version = "0.2.37"
|
||||||
|
|
@ -1138,9 +1108,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustc-hash"
|
name = "rustc-hash"
|
||||||
version = "1.1.0"
|
version = "2.1.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2"
|
checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "rustix"
|
name = "rustix"
|
||||||
|
|
|
||||||
|
|
@ -14,15 +14,21 @@ edition = "2021"
|
||||||
license = "MIT"
|
license = "MIT"
|
||||||
description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial"
|
description = "ZeroClaw ESP32 peripheral firmware — GPIO over JSON serial"
|
||||||
|
|
||||||
|
[patch.crates-io]
|
||||||
|
# Use latest esp-rs crates to fix u8/i8 char pointer compatibility with ESP-IDF 5.x
|
||||||
|
esp-idf-sys = { git = "https://github.com/esp-rs/esp-idf-sys" }
|
||||||
|
esp-idf-hal = { git = "https://github.com/esp-rs/esp-idf-hal" }
|
||||||
|
esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
esp-idf-svc = "0.48"
|
esp-idf-svc = { git = "https://github.com/esp-rs/esp-idf-svc" }
|
||||||
log = "0.4"
|
log = "0.4"
|
||||||
anyhow = "1.0"
|
anyhow = "1.0"
|
||||||
serde = { version = "1.0", features = ["derive"] }
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
serde_json = "1.0"
|
serde_json = "1.0"
|
||||||
|
|
||||||
[build-dependencies]
|
[build-dependencies]
|
||||||
embuild = "0.31"
|
embuild = { version = "0.33", features = ["espidf"] }
|
||||||
|
|
||||||
[profile.release]
|
[profile.release]
|
||||||
opt-level = "s"
|
opt-level = "s"
|
||||||
|
|
|
||||||
|
|
@ -2,8 +2,11 @@
|
||||||
|
|
||||||
Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial.
|
Peripheral firmware for ESP32 — speaks the same JSON-over-serial protocol as the STM32 firmware. Flash this to your ESP32, then configure ZeroClaw on the host to connect via serial.
|
||||||
|
|
||||||
|
**New to this?** See [SETUP.md](SETUP.md) for step-by-step commands and troubleshooting.
|
||||||
|
|
||||||
## Protocol
|
## Protocol
|
||||||
|
|
||||||
|
|
||||||
- **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n`
|
- **Request** (host → ESP32): `{"id":"1","cmd":"gpio_write","args":{"pin":13,"value":1}}\n`
|
||||||
- **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n`
|
- **Response** (ESP32 → host): `{"id":"1","ok":true,"result":"done"}\n`
|
||||||
|
|
||||||
|
|
@ -11,19 +14,44 @@ Commands: `gpio_read`, `gpio_write`.
|
||||||
|
|
||||||
## Prerequisites
|
## Prerequisites
|
||||||
|
|
||||||
1. **ESP toolchain** (espup):
|
1. **RISC-V ESP-IDF** (ESP32-C2/C3): Uses nightly Rust with `build-std`.
|
||||||
|
|
||||||
|
**Python**: ESP-IDF requires Python 3.10–3.13 (not 3.14). If you have Python 3.14:
|
||||||
|
```sh
|
||||||
|
brew install python@3.12
|
||||||
|
```
|
||||||
|
|
||||||
|
**virtualenv** (needed by ESP-IDF tools; PEP 668 workaround on macOS):
|
||||||
|
```sh
|
||||||
|
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||||
|
```
|
||||||
|
|
||||||
|
**Rust tools**:
|
||||||
|
```sh
|
||||||
|
cargo install espflash ldproxy
|
||||||
|
```
|
||||||
|
|
||||||
|
The project's `rust-toolchain.toml` pins nightly + rust-src. `esp-idf-sys` downloads ESP-IDF automatically on first build. Use Python 3.12 for the build:
|
||||||
|
```sh
|
||||||
|
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **Xtensa targets** (ESP32, ESP32-S2, ESP32-S3): Use espup instead:
|
||||||
```sh
|
```sh
|
||||||
cargo install espup espflash
|
cargo install espup espflash
|
||||||
espup install
|
espup install
|
||||||
source ~/export-esp.sh # or ~/export-esp.fish for Fish
|
source ~/export-esp.sh
|
||||||
```
|
```
|
||||||
|
Then edit `.cargo/config.toml` to change the target (e.g. `xtensa-esp32-espidf`).
|
||||||
2. **Target**: ESP32-C3 (RISC-V) by default. Edit `.cargo/config.toml` for other targets (e.g. `xtensa-esp32-espidf` for original ESP32).
|
|
||||||
|
|
||||||
## Build & Flash
|
## Build & Flash
|
||||||
|
|
||||||
```sh
|
```sh
|
||||||
cd firmware/zeroclaw-esp32
|
cd firmware/zeroclaw-esp32
|
||||||
|
# Use Python 3.12 (required if you have 3.14)
|
||||||
|
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||||
|
# Optional: pin MCU (esp32c3 or esp32c2)
|
||||||
|
export MCU=esp32c3
|
||||||
cargo build --release
|
cargo build --release
|
||||||
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
|
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
|
||||||
```
|
```
|
||||||
|
|
|
||||||
156
firmware/zeroclaw-esp32/SETUP.md
Normal file
156
firmware/zeroclaw-esp32/SETUP.md
Normal file
|
|
@ -0,0 +1,156 @@
|
||||||
|
# ESP32 Firmware Setup Guide
|
||||||
|
|
||||||
|
Step-by-step setup for building the ZeroClaw ESP32 firmware. Follow this if you run into issues.
|
||||||
|
|
||||||
|
## Quick Start (copy-paste)
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# 1. Install Python 3.12 (ESP-IDF needs 3.10–3.13, not 3.14)
|
||||||
|
brew install python@3.12
|
||||||
|
|
||||||
|
# 2. Install virtualenv (PEP 668 workaround on macOS)
|
||||||
|
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||||
|
|
||||||
|
# 3. Install Rust tools
|
||||||
|
cargo install espflash ldproxy
|
||||||
|
|
||||||
|
# 4. Build
|
||||||
|
cd firmware/zeroclaw-esp32
|
||||||
|
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||||
|
cargo build --release
|
||||||
|
|
||||||
|
# 5. Flash (connect ESP32 via USB)
|
||||||
|
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Detailed Steps
|
||||||
|
|
||||||
|
### 1. Python
|
||||||
|
|
||||||
|
ESP-IDF requires Python 3.10–3.13. **Python 3.14 is not supported.**
|
||||||
|
|
||||||
|
```sh
|
||||||
|
brew install python@3.12
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. virtualenv
|
||||||
|
|
||||||
|
ESP-IDF tools need `virtualenv`. On macOS with Homebrew Python, PEP 668 blocks `pip install`; use:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Rust Tools
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cargo install espflash ldproxy
|
||||||
|
```
|
||||||
|
|
||||||
|
- **espflash**: flash and monitor
|
||||||
|
- **ldproxy**: linker for ESP-IDF builds
|
||||||
|
|
||||||
|
### 4. Use Python 3.12 for Builds
|
||||||
|
|
||||||
|
Before every build (or add to `~/.zshrc`):
|
||||||
|
|
||||||
|
```sh
|
||||||
|
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. Build
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cd firmware/zeroclaw-esp32
|
||||||
|
cargo build --release
|
||||||
|
```
|
||||||
|
|
||||||
|
First build downloads and compiles ESP-IDF (~5–15 min).
|
||||||
|
|
||||||
|
### 6. Flash
|
||||||
|
|
||||||
|
```sh
|
||||||
|
espflash flash target/riscv32imc-esp-espidf/release/zeroclaw-esp32 --monitor
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
### "No space left on device"
|
||||||
|
|
||||||
|
Free disk space. Common targets:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# Cargo cache (often 5–20 GB)
|
||||||
|
rm -rf ~/.cargo/registry/cache ~/.cargo/registry/src
|
||||||
|
|
||||||
|
# Unused Rust toolchains
|
||||||
|
rustup toolchain list
|
||||||
|
rustup toolchain uninstall <name>
|
||||||
|
|
||||||
|
# iOS Simulator runtimes (~35 GB)
|
||||||
|
xcrun simctl delete unavailable
|
||||||
|
|
||||||
|
# Temp files
|
||||||
|
rm -rf /var/folders/*/T/cargo-install*
|
||||||
|
```
|
||||||
|
|
||||||
|
### "can't find crate for `core`" / "riscv32imc-esp-espidf target may not be installed"
|
||||||
|
|
||||||
|
This project uses **nightly Rust with build-std**, not espup. Ensure:
|
||||||
|
|
||||||
|
- `rust-toolchain.toml` exists (pins nightly + rust-src)
|
||||||
|
- You are **not** sourcing `~/export-esp.sh` (that's for Xtensa targets)
|
||||||
|
- Run `cargo build` from `firmware/zeroclaw-esp32`
|
||||||
|
|
||||||
|
### "externally-managed-environment" / "No module named 'virtualenv'"
|
||||||
|
|
||||||
|
Install virtualenv with the PEP 668 workaround:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
/opt/homebrew/opt/python@3.12/bin/python3.12 -m pip install virtualenv --break-system-packages
|
||||||
|
```
|
||||||
|
|
||||||
|
### "expected `i64`, found `i32`" (time_t mismatch)
|
||||||
|
|
||||||
|
Already fixed in `.cargo/config.toml` with `espidf_time64` for ESP-IDF 5.x. If you use ESP-IDF 4.4, switch to `espidf_time32`.
|
||||||
|
|
||||||
|
### "expected `*const u8`, found `*const i8`" (esp-idf-svc)
|
||||||
|
|
||||||
|
Already fixed via `[patch.crates-io]` in `Cargo.toml` using esp-rs crates from git. Do not remove the patch.
|
||||||
|
|
||||||
|
### 10,000+ files in `git status`
|
||||||
|
|
||||||
|
The `.embuild/` directory (ESP-IDF cache) has ~100k+ files. It is in `.gitignore`. If you see them, ensure `.gitignore` contains:
|
||||||
|
|
||||||
|
```
|
||||||
|
.embuild/
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Optional: Auto-load Python 3.12
|
||||||
|
|
||||||
|
Add to `~/.zshrc`:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
# ESP32 firmware build
|
||||||
|
export PATH="/opt/homebrew/opt/python@3.12/libexec/bin:$PATH"
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Xtensa Targets (ESP32, ESP32-S2, ESP32-S3)
|
||||||
|
|
||||||
|
For non–RISC-V chips, use espup instead:
|
||||||
|
|
||||||
|
```sh
|
||||||
|
cargo install espup espflash
|
||||||
|
espup install
|
||||||
|
source ~/export-esp.sh
|
||||||
|
```
|
||||||
|
|
||||||
|
Then edit `.cargo/config.toml` to use `xtensa-esp32-espidf` (or the correct target).
|
||||||
3
firmware/zeroclaw-esp32/rust-toolchain.toml
Normal file
3
firmware/zeroclaw-esp32/rust-toolchain.toml
Normal file
|
|
@ -0,0 +1,3 @@
|
||||||
|
[toolchain]
|
||||||
|
channel = "nightly"
|
||||||
|
components = ["rust-src"]
|
||||||
|
|
@ -6,8 +6,9 @@
|
||||||
//! Protocol: same as STM32 — see docs/hardware-peripherals-design.md
|
//! Protocol: same as STM32 — see docs/hardware-peripherals-design.md
|
||||||
|
|
||||||
use esp_idf_svc::hal::gpio::PinDriver;
|
use esp_idf_svc::hal::gpio::PinDriver;
|
||||||
use esp_idf_svc::hal::prelude::*;
|
use esp_idf_svc::hal::peripherals::Peripherals;
|
||||||
use esp_idf_svc::hal::uart::*;
|
use esp_idf_svc::hal::uart::{UartConfig, UartDriver};
|
||||||
|
use esp_idf_svc::hal::units::Hertz;
|
||||||
use log::info;
|
use log::info;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
|
@ -36,9 +37,13 @@ fn main() -> anyhow::Result<()> {
|
||||||
let peripherals = Peripherals::take()?;
|
let peripherals = Peripherals::take()?;
|
||||||
let pins = peripherals.pins;
|
let pins = peripherals.pins;
|
||||||
|
|
||||||
|
// Create GPIO output drivers first (they take ownership of pins)
|
||||||
|
let mut gpio2 = PinDriver::output(pins.gpio2)?;
|
||||||
|
let mut gpio13 = PinDriver::output(pins.gpio13)?;
|
||||||
|
|
||||||
// UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board
|
// UART0: TX=21, RX=20 (ESP32) — ESP32-C3 may use different pins; adjust for your board
|
||||||
let config = UartConfig::new().baudrate(Hertz(115_200));
|
let config = UartConfig::new().baudrate(Hertz(115_200));
|
||||||
let mut uart = UartDriver::new(
|
let uart = UartDriver::new(
|
||||||
peripherals.uart0,
|
peripherals.uart0,
|
||||||
pins.gpio21,
|
pins.gpio21,
|
||||||
pins.gpio20,
|
pins.gpio20,
|
||||||
|
|
@ -60,7 +65,8 @@ fn main() -> anyhow::Result<()> {
|
||||||
if b == b'\n' {
|
if b == b'\n' {
|
||||||
if !line.is_empty() {
|
if !line.is_empty() {
|
||||||
if let Ok(line_str) = std::str::from_utf8(&line) {
|
if let Ok(line_str) = std::str::from_utf8(&line) {
|
||||||
if let Ok(resp) = handle_request(line_str, &peripherals) {
|
if let Ok(resp) = handle_request(line_str, &mut gpio2, &mut gpio13)
|
||||||
|
{
|
||||||
let out = serde_json::to_string(&resp).unwrap_or_default();
|
let out = serde_json::to_string(&resp).unwrap_or_default();
|
||||||
let _ = uart.write(format!("{}\n", out).as_bytes());
|
let _ = uart.write(format!("{}\n", out).as_bytes());
|
||||||
}
|
}
|
||||||
|
|
@ -80,10 +86,15 @@ fn main() -> anyhow::Result<()> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn handle_request(
|
fn handle_request<G2, G13>(
|
||||||
line: &str,
|
line: &str,
|
||||||
peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
|
gpio2: &mut PinDriver<'_, G2>,
|
||||||
) -> anyhow::Result<Response> {
|
gpio13: &mut PinDriver<'_, G13>,
|
||||||
|
) -> anyhow::Result<Response>
|
||||||
|
where
|
||||||
|
G2: esp_idf_svc::hal::gpio::OutputMode,
|
||||||
|
G13: esp_idf_svc::hal::gpio::OutputMode,
|
||||||
|
{
|
||||||
let req: Request = serde_json::from_str(line.trim())?;
|
let req: Request = serde_json::from_str(line.trim())?;
|
||||||
let id = req.id.clone();
|
let id = req.id.clone();
|
||||||
|
|
||||||
|
|
@ -98,13 +109,13 @@ fn handle_request(
|
||||||
}
|
}
|
||||||
"gpio_read" => {
|
"gpio_read" => {
|
||||||
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
|
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
|
||||||
let value = gpio_read(peripherals, pin_num)?;
|
let value = gpio_read(pin_num)?;
|
||||||
Ok(value.to_string())
|
Ok(value.to_string())
|
||||||
}
|
}
|
||||||
"gpio_write" => {
|
"gpio_write" => {
|
||||||
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
|
let pin_num = req.args.get("pin").and_then(|v| v.as_u64()).unwrap_or(0) as i32;
|
||||||
let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0);
|
let value = req.args.get("value").and_then(|v| v.as_u64()).unwrap_or(0);
|
||||||
gpio_write(peripherals, pin_num, value)?;
|
gpio_write(gpio2, gpio13, pin_num, value)?;
|
||||||
Ok("done".into())
|
Ok("done".into())
|
||||||
}
|
}
|
||||||
_ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)),
|
_ => Err(anyhow::anyhow!("Unknown command: {}", req.cmd)),
|
||||||
|
|
@ -126,28 +137,26 @@ fn handle_request(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gpio_read(_peripherals: &esp_idf_svc::hal::peripherals::Peripherals, _pin: i32) -> anyhow::Result<u8> {
|
fn gpio_read(_pin: i32) -> anyhow::Result<u8> {
|
||||||
// TODO: implement input pin read — requires storing InputPin drivers per pin
|
// TODO: implement input pin read — requires storing InputPin drivers per pin
|
||||||
Ok(0)
|
Ok(0)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn gpio_write(
|
fn gpio_write<G2, G13>(
|
||||||
peripherals: &esp_idf_svc::hal::peripherals::Peripherals,
|
gpio2: &mut PinDriver<'_, G2>,
|
||||||
|
gpio13: &mut PinDriver<'_, G13>,
|
||||||
pin: i32,
|
pin: i32,
|
||||||
value: u64,
|
value: u64,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()>
|
||||||
let pins = peripherals.pins;
|
where
|
||||||
let level = value != 0;
|
G2: esp_idf_svc::hal::gpio::OutputMode,
|
||||||
|
G13: esp_idf_svc::hal::gpio::OutputMode,
|
||||||
|
{
|
||||||
|
let level = esp_idf_svc::hal::gpio::Level::from(value != 0);
|
||||||
|
|
||||||
match pin {
|
match pin {
|
||||||
2 => {
|
2 => gpio2.set_level(level)?,
|
||||||
let mut out = PinDriver::output(pins.gpio2)?;
|
13 => gpio13.set_level(level)?,
|
||||||
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
|
|
||||||
}
|
|
||||||
13 => {
|
|
||||||
let mut out = PinDriver::output(pins.gpio13)?;
|
|
||||||
out.set_level(esp_idf_svc::hal::gpio::Level::from(level))?;
|
|
||||||
}
|
|
||||||
_ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin),
|
_ => anyhow::bail!("Pin {} not configured (add to gpio_write)", pin),
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
||||||
324
scripts/recompute_contributor_tiers.sh
Executable file
324
scripts/recompute_contributor_tiers.sh
Executable file
|
|
@ -0,0 +1,324 @@
|
||||||
|
#!/usr/bin/env bash
|
||||||
|
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
SCRIPT_NAME="$(basename "$0")"
|
||||||
|
|
||||||
|
usage() {
|
||||||
|
cat <<USAGE
|
||||||
|
Recompute contributor tier labels for historical PRs/issues.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
./$SCRIPT_NAME [options]
|
||||||
|
|
||||||
|
Options:
|
||||||
|
--repo <owner/repo> Target repository (default: current gh repo)
|
||||||
|
--kind <both|prs|issues>
|
||||||
|
Target objects (default: both)
|
||||||
|
--state <all|open|closed>
|
||||||
|
State filter for listing objects (default: all)
|
||||||
|
--limit <N> Limit processed objects after fetch (default: 0 = no limit)
|
||||||
|
--apply Apply label updates (default is dry-run)
|
||||||
|
--dry-run Preview only (default)
|
||||||
|
-h, --help Show this help
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --limit 50
|
||||||
|
./$SCRIPT_NAME --repo zeroclaw-labs/zeroclaw --kind prs --state open --apply
|
||||||
|
USAGE
|
||||||
|
}
|
||||||
|
|
||||||
|
die() {
|
||||||
|
echo "[$SCRIPT_NAME] ERROR: $*" >&2
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
|
||||||
|
require_cmd() {
|
||||||
|
if ! command -v "$1" >/dev/null 2>&1; then
|
||||||
|
die "Required command not found: $1"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
urlencode() {
|
||||||
|
jq -nr --arg value "$1" '$value|@uri'
|
||||||
|
}
|
||||||
|
|
||||||
|
select_contributor_tier() {
|
||||||
|
local merged_count="$1"
|
||||||
|
if (( merged_count >= 50 )); then
|
||||||
|
echo "distinguished contributor"
|
||||||
|
elif (( merged_count >= 20 )); then
|
||||||
|
echo "principal contributor"
|
||||||
|
elif (( merged_count >= 10 )); then
|
||||||
|
echo "experienced contributor"
|
||||||
|
elif (( merged_count >= 5 )); then
|
||||||
|
echo "trusted contributor"
|
||||||
|
else
|
||||||
|
echo ""
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
DRY_RUN=1
|
||||||
|
KIND="both"
|
||||||
|
STATE="all"
|
||||||
|
LIMIT=0
|
||||||
|
REPO=""
|
||||||
|
|
||||||
|
while (($# > 0)); do
|
||||||
|
case "$1" in
|
||||||
|
--repo)
|
||||||
|
[[ $# -ge 2 ]] || die "Missing value for --repo"
|
||||||
|
REPO="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--kind)
|
||||||
|
[[ $# -ge 2 ]] || die "Missing value for --kind"
|
||||||
|
KIND="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--state)
|
||||||
|
[[ $# -ge 2 ]] || die "Missing value for --state"
|
||||||
|
STATE="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--limit)
|
||||||
|
[[ $# -ge 2 ]] || die "Missing value for --limit"
|
||||||
|
LIMIT="$2"
|
||||||
|
shift 2
|
||||||
|
;;
|
||||||
|
--apply)
|
||||||
|
DRY_RUN=0
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
--dry-run)
|
||||||
|
DRY_RUN=1
|
||||||
|
shift
|
||||||
|
;;
|
||||||
|
-h|--help)
|
||||||
|
usage
|
||||||
|
exit 0
|
||||||
|
;;
|
||||||
|
*)
|
||||||
|
die "Unknown option: $1"
|
||||||
|
;;
|
||||||
|
esac
|
||||||
|
done
|
||||||
|
|
||||||
|
case "$KIND" in
|
||||||
|
both|prs|issues) ;;
|
||||||
|
*) die "--kind must be one of: both, prs, issues" ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
case "$STATE" in
|
||||||
|
all|open|closed) ;;
|
||||||
|
*) die "--state must be one of: all, open, closed" ;;
|
||||||
|
esac
|
||||||
|
|
||||||
|
if ! [[ "$LIMIT" =~ ^[0-9]+$ ]]; then
|
||||||
|
die "--limit must be a non-negative integer"
|
||||||
|
fi
|
||||||
|
|
||||||
|
require_cmd gh
|
||||||
|
require_cmd jq
|
||||||
|
|
||||||
|
if ! gh auth status >/dev/null 2>&1; then
|
||||||
|
die "gh CLI is not authenticated. Run: gh auth login"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ -z "$REPO" ]]; then
|
||||||
|
REPO="$(gh repo view --json nameWithOwner --jq '.nameWithOwner' 2>/dev/null || true)"
|
||||||
|
[[ -n "$REPO" ]] || die "Unable to infer repo. Pass --repo <owner/repo>."
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "[$SCRIPT_NAME] Repo: $REPO"
|
||||||
|
echo "[$SCRIPT_NAME] Mode: $([[ "$DRY_RUN" -eq 1 ]] && echo "dry-run" || echo "apply")"
|
||||||
|
echo "[$SCRIPT_NAME] Kind: $KIND | State: $STATE | Limit: $LIMIT"
|
||||||
|
|
||||||
|
TIERS_JSON='["trusted contributor","experienced contributor","principal contributor","distinguished contributor"]'
|
||||||
|
|
||||||
|
TMP_FILES=()
|
||||||
|
cleanup() {
|
||||||
|
if ((${#TMP_FILES[@]} > 0)); then
|
||||||
|
rm -f "${TMP_FILES[@]}"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
trap cleanup EXIT
|
||||||
|
|
||||||
|
new_tmp_file() {
|
||||||
|
local tmp
|
||||||
|
tmp="$(mktemp)"
|
||||||
|
TMP_FILES+=("$tmp")
|
||||||
|
echo "$tmp"
|
||||||
|
}
|
||||||
|
|
||||||
|
targets_file="$(new_tmp_file)"
|
||||||
|
|
||||||
|
if [[ "$KIND" == "both" || "$KIND" == "prs" ]]; then
|
||||||
|
gh api --paginate "repos/$REPO/pulls?state=$STATE&per_page=100" \
|
||||||
|
--jq '.[] | {
|
||||||
|
kind: "pr",
|
||||||
|
number: .number,
|
||||||
|
author: (.user.login // ""),
|
||||||
|
author_type: (.user.type // ""),
|
||||||
|
labels: [(.labels[]?.name // empty)]
|
||||||
|
}' >> "$targets_file"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$KIND" == "both" || "$KIND" == "issues" ]]; then
|
||||||
|
gh api --paginate "repos/$REPO/issues?state=$STATE&per_page=100" \
|
||||||
|
--jq '.[] | select(.pull_request | not) | {
|
||||||
|
kind: "issue",
|
||||||
|
number: .number,
|
||||||
|
author: (.user.login // ""),
|
||||||
|
author_type: (.user.type // ""),
|
||||||
|
labels: [(.labels[]?.name // empty)]
|
||||||
|
}' >> "$targets_file"
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$LIMIT" -gt 0 ]]; then
|
||||||
|
limited_file="$(new_tmp_file)"
|
||||||
|
head -n "$LIMIT" "$targets_file" > "$limited_file"
|
||||||
|
mv "$limited_file" "$targets_file"
|
||||||
|
fi
|
||||||
|
|
||||||
|
target_count="$(wc -l < "$targets_file" | tr -d ' ')"
|
||||||
|
if [[ "$target_count" -eq 0 ]]; then
|
||||||
|
echo "[$SCRIPT_NAME] No targets found."
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "[$SCRIPT_NAME] Targets fetched: $target_count"
|
||||||
|
|
||||||
|
# Ensure tier labels exist (trusted contributor might be new).
|
||||||
|
label_color=""
|
||||||
|
for probe_label in "experienced contributor" "principal contributor" "distinguished contributor" "trusted contributor"; do
|
||||||
|
encoded_label="$(urlencode "$probe_label")"
|
||||||
|
if color_candidate="$(gh api "repos/$REPO/labels/$encoded_label" --jq '.color' 2>/dev/null || true)"; then
|
||||||
|
if [[ -n "$color_candidate" ]]; then
|
||||||
|
label_color="$(echo "$color_candidate" | tr '[:lower:]' '[:upper:]')"
|
||||||
|
break
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
[[ -n "$label_color" ]] || label_color="C5D7A2"
|
||||||
|
|
||||||
|
while IFS= read -r tier_label; do
|
||||||
|
[[ -n "$tier_label" ]] || continue
|
||||||
|
encoded_label="$(urlencode "$tier_label")"
|
||||||
|
if gh api "repos/$REPO/labels/$encoded_label" >/dev/null 2>&1; then
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$DRY_RUN" -eq 1 ]]; then
|
||||||
|
echo "[dry-run] Would create missing label: $tier_label (color=$label_color)"
|
||||||
|
else
|
||||||
|
gh api -X POST "repos/$REPO/labels" \
|
||||||
|
-f name="$tier_label" \
|
||||||
|
-f color="$label_color" >/dev/null
|
||||||
|
echo "[apply] Created missing label: $tier_label"
|
||||||
|
fi
|
||||||
|
done < <(jq -r '.[]' <<<"$TIERS_JSON")
|
||||||
|
|
||||||
|
# Build merged PR count cache by unique human authors.
|
||||||
|
authors_file="$(new_tmp_file)"
|
||||||
|
jq -r 'select(.author != "" and .author_type != "Bot") | .author' "$targets_file" | sort -u > "$authors_file"
|
||||||
|
author_count="$(wc -l < "$authors_file" | tr -d ' ')"
|
||||||
|
echo "[$SCRIPT_NAME] Unique human authors: $author_count"
|
||||||
|
|
||||||
|
author_counts_file="$(new_tmp_file)"
|
||||||
|
while IFS= read -r author; do
|
||||||
|
[[ -n "$author" ]] || continue
|
||||||
|
query="repo:$REPO is:pr is:merged author:$author"
|
||||||
|
merged_count="$(gh api search/issues -f q="$query" -F per_page=1 --jq '.total_count' 2>/dev/null || true)"
|
||||||
|
if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
|
||||||
|
merged_count=0
|
||||||
|
fi
|
||||||
|
printf '%s\t%s\n' "$author" "$merged_count" >> "$author_counts_file"
|
||||||
|
done < "$authors_file"
|
||||||
|
|
||||||
|
updated=0
|
||||||
|
unchanged=0
|
||||||
|
skipped=0
|
||||||
|
failed=0
|
||||||
|
|
||||||
|
while IFS= read -r target_json; do
|
||||||
|
[[ -n "$target_json" ]] || continue
|
||||||
|
|
||||||
|
number="$(jq -r '.number' <<<"$target_json")"
|
||||||
|
kind="$(jq -r '.kind' <<<"$target_json")"
|
||||||
|
author="$(jq -r '.author' <<<"$target_json")"
|
||||||
|
author_type="$(jq -r '.author_type' <<<"$target_json")"
|
||||||
|
current_labels_json="$(jq -c '.labels // []' <<<"$target_json")"
|
||||||
|
|
||||||
|
if [[ -z "$author" || "$author_type" == "Bot" ]]; then
|
||||||
|
skipped=$((skipped + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
merged_count="$(awk -F '\t' -v key="$author" '$1 == key { print $2; exit }' "$author_counts_file")"
|
||||||
|
if ! [[ "$merged_count" =~ ^[0-9]+$ ]]; then
|
||||||
|
merged_count=0
|
||||||
|
fi
|
||||||
|
desired_tier="$(select_contributor_tier "$merged_count")"
|
||||||
|
|
||||||
|
if ! current_tier="$(jq -r --argjson tiers "$TIERS_JSON" '[.[] | select(. as $label | ($tiers | index($label)) != null)][0] // ""' <<<"$current_labels_json" 2>/dev/null)"; then
|
||||||
|
echo "[warn] Skipping ${kind} #${number}: cannot parse current labels JSON" >&2
|
||||||
|
failed=$((failed + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! next_labels_json="$(jq -c --arg desired "$desired_tier" --argjson tiers "$TIERS_JSON" '
|
||||||
|
(. // [])
|
||||||
|
| map(select(. as $label | ($tiers | index($label)) == null))
|
||||||
|
| if $desired != "" then . + [$desired] else . end
|
||||||
|
| unique
|
||||||
|
' <<<"$current_labels_json" 2>/dev/null)"; then
|
||||||
|
echo "[warn] Skipping ${kind} #${number}: cannot compute next labels" >&2
|
||||||
|
failed=$((failed + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! normalized_current="$(jq -c 'unique | sort' <<<"$current_labels_json" 2>/dev/null)"; then
|
||||||
|
echo "[warn] Skipping ${kind} #${number}: cannot normalize current labels" >&2
|
||||||
|
failed=$((failed + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
if ! normalized_next="$(jq -c 'unique | sort' <<<"$next_labels_json" 2>/dev/null)"; then
|
||||||
|
echo "[warn] Skipping ${kind} #${number}: cannot normalize next labels" >&2
|
||||||
|
failed=$((failed + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$normalized_current" == "$normalized_next" ]]; then
|
||||||
|
unchanged=$((unchanged + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "$DRY_RUN" -eq 1 ]]; then
|
||||||
|
echo "[dry-run] ${kind} #${number} @${author} merged=${merged_count} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
|
||||||
|
updated=$((updated + 1))
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
payload="$(jq -cn --argjson labels "$next_labels_json" '{labels: $labels}')"
|
||||||
|
if gh api -X PUT "repos/$REPO/issues/$number/labels" --input - <<<"$payload" >/dev/null; then
|
||||||
|
echo "[apply] Updated ${kind} #${number} @${author} tier: '${current_tier:-none}' -> '${desired_tier:-none}'"
|
||||||
|
updated=$((updated + 1))
|
||||||
|
else
|
||||||
|
echo "[apply] FAILED ${kind} #${number}" >&2
|
||||||
|
failed=$((failed + 1))
|
||||||
|
fi
|
||||||
|
done < "$targets_file"
|
||||||
|
|
||||||
|
echo ""
|
||||||
|
echo "[$SCRIPT_NAME] Summary"
|
||||||
|
echo " Targets: $target_count"
|
||||||
|
echo " Updated: $updated"
|
||||||
|
echo " Unchanged: $unchanged"
|
||||||
|
echo " Skipped: $skipped"
|
||||||
|
echo " Failed: $failed"
|
||||||
|
|
||||||
|
if [[ "$failed" -gt 0 ]]; then
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
@ -251,6 +251,7 @@ impl Agent {
|
||||||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||||
provider_name,
|
provider_name,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
|
config.api_url.as_deref(),
|
||||||
&config.reliability,
|
&config.reliability,
|
||||||
&config.model_routes,
|
&config.model_routes,
|
||||||
&model_name,
|
&model_name,
|
||||||
|
|
@ -388,7 +389,7 @@ impl Agent {
|
||||||
if self.auto_save {
|
if self.auto_save {
|
||||||
let _ = self
|
let _ = self
|
||||||
.memory
|
.memory
|
||||||
.store("user_msg", user_message, MemoryCategory::Conversation)
|
.store("user_msg", user_message, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -447,7 +448,7 @@ impl Agent {
|
||||||
let summary = truncate_with_ellipsis(&final_text, 100);
|
let summary = truncate_with_ellipsis(&final_text, 100);
|
||||||
let _ = self
|
let _ = self
|
||||||
.memory
|
.memory
|
||||||
.store("assistant_resp", &summary, MemoryCategory::Daily)
|
.store("assistant_resp", &summary, MemoryCategory::Daily, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -557,6 +558,7 @@ pub async fn run(
|
||||||
agent.observer.record_event(&ObserverEvent::AgentEnd {
|
agent.observer.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration: start.elapsed(),
|
duration: start.elapsed(),
|
||||||
tokens_used: None,
|
tokens_used: None,
|
||||||
|
cost_usd: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
|
||||||
|
|
@ -7,14 +7,70 @@ use crate::security::SecurityPolicy;
|
||||||
use crate::tools::{self, Tool};
|
use crate::tools::{self, Tool};
|
||||||
use crate::util::truncate_with_ellipsis;
|
use crate::util::truncate_with_ellipsis;
|
||||||
use anyhow::Result;
|
use anyhow::Result;
|
||||||
|
use regex::{Regex, RegexSet};
|
||||||
use std::fmt::Write;
|
use std::fmt::Write;
|
||||||
use std::io::Write as _;
|
use std::io::Write as _;
|
||||||
use std::sync::Arc;
|
use std::sync::{Arc, LazyLock};
|
||||||
use std::time::Instant;
|
use std::time::Instant;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
/// Maximum agentic tool-use iterations per user message to prevent runaway loops.
|
||||||
const MAX_TOOL_ITERATIONS: usize = 10;
|
const MAX_TOOL_ITERATIONS: usize = 10;
|
||||||
|
|
||||||
|
static SENSITIVE_KEY_PATTERNS: LazyLock<RegexSet> = LazyLock::new(|| {
|
||||||
|
RegexSet::new([
|
||||||
|
r"(?i)token",
|
||||||
|
r"(?i)api[_-]?key",
|
||||||
|
r"(?i)password",
|
||||||
|
r"(?i)secret",
|
||||||
|
r"(?i)user[_-]?key",
|
||||||
|
r"(?i)bearer",
|
||||||
|
r"(?i)credential",
|
||||||
|
])
|
||||||
|
.unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
static SENSITIVE_KV_REGEX: LazyLock<Regex> = LazyLock::new(|| {
|
||||||
|
Regex::new(r#"(?i)(token|api[_-]?key|password|secret|user[_-]?key|bearer|credential)["']?\s*[:=]\s*(?:"([^"]{8,})"|'([^']{8,})'|([a-zA-Z0-9_\-\.]{8,}))"#).unwrap()
|
||||||
|
});
|
||||||
|
|
||||||
|
/// Scrub credentials from tool output to prevent accidental exfiltration.
|
||||||
|
/// Replaces known credential patterns with a redacted placeholder while preserving
|
||||||
|
/// a small prefix for context.
|
||||||
|
fn scrub_credentials(input: &str) -> String {
|
||||||
|
SENSITIVE_KV_REGEX
|
||||||
|
.replace_all(input, |caps: ®ex::Captures| {
|
||||||
|
let full_match = &caps[0];
|
||||||
|
let key = &caps[1];
|
||||||
|
let val = caps
|
||||||
|
.get(2)
|
||||||
|
.or(caps.get(3))
|
||||||
|
.or(caps.get(4))
|
||||||
|
.map(|m| m.as_str())
|
||||||
|
.unwrap_or("");
|
||||||
|
|
||||||
|
// Preserve first 4 chars for context, then redact
|
||||||
|
let prefix = if val.len() > 4 { &val[..4] } else { "" };
|
||||||
|
|
||||||
|
if full_match.contains(':') {
|
||||||
|
if full_match.contains('"') {
|
||||||
|
format!("\"{}\": \"{}*[REDACTED]\"", key, prefix)
|
||||||
|
} else {
|
||||||
|
format!("{}: {}*[REDACTED]", key, prefix)
|
||||||
|
}
|
||||||
|
} else if full_match.contains('=') {
|
||||||
|
if full_match.contains('"') {
|
||||||
|
format!("{}=\"{}*[REDACTED]\"", key, prefix)
|
||||||
|
} else {
|
||||||
|
format!("{}={}*[REDACTED]", key, prefix)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
format!("{}: {}*[REDACTED]", key, prefix)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
/// Trigger auto-compaction when non-system message count exceeds this threshold.
|
/// Trigger auto-compaction when non-system message count exceeds this threshold.
|
||||||
const MAX_HISTORY_MESSAGES: usize = 50;
|
const MAX_HISTORY_MESSAGES: usize = 50;
|
||||||
|
|
||||||
|
|
@ -145,7 +201,7 @@ async fn build_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||||
let mut context = String::new();
|
let mut context = String::new();
|
||||||
|
|
||||||
// Pull relevant memories for this message
|
// Pull relevant memories for this message
|
||||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||||
if !entries.is_empty() {
|
if !entries.is_empty() {
|
||||||
context.push_str("[Memory context]\n");
|
context.push_str("[Memory context]\n");
|
||||||
for entry in &entries {
|
for entry in &entries {
|
||||||
|
|
@ -436,6 +492,7 @@ struct ParsedToolCall {
|
||||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||||
/// execute tools, and loop until the LLM produces a final text response.
|
/// execute tools, and loop until the LLM produces a final text response.
|
||||||
/// When `silent` is true, suppresses stdout (for channel use).
|
/// When `silent` is true, suppresses stdout (for channel use).
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) async fn agent_turn(
|
pub(crate) async fn agent_turn(
|
||||||
provider: &dyn Provider,
|
provider: &dyn Provider,
|
||||||
history: &mut Vec<ChatMessage>,
|
history: &mut Vec<ChatMessage>,
|
||||||
|
|
@ -461,6 +518,7 @@ pub(crate) async fn agent_turn(
|
||||||
|
|
||||||
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
/// Execute a single turn of the agent loop: send messages, parse tool calls,
|
||||||
/// execute tools, and loop until the LLM produces a final text response.
|
/// execute tools, and loop until the LLM produces a final text response.
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
pub(crate) async fn run_tool_call_loop(
|
pub(crate) async fn run_tool_call_loop(
|
||||||
provider: &dyn Provider,
|
provider: &dyn Provider,
|
||||||
history: &mut Vec<ChatMessage>,
|
history: &mut Vec<ChatMessage>,
|
||||||
|
|
@ -606,7 +664,7 @@ pub(crate) async fn run_tool_call_loop(
|
||||||
success: r.success,
|
success: r.success,
|
||||||
});
|
});
|
||||||
if r.success {
|
if r.success {
|
||||||
r.output
|
scrub_credentials(&r.output)
|
||||||
} else {
|
} else {
|
||||||
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
|
format!("Error: {}", r.error.unwrap_or_else(|| r.output))
|
||||||
}
|
}
|
||||||
|
|
@ -749,6 +807,7 @@ pub async fn run(
|
||||||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||||
provider_name,
|
provider_name,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
|
config.api_url.as_deref(),
|
||||||
&config.reliability,
|
&config.reliability,
|
||||||
&config.model_routes,
|
&config.model_routes,
|
||||||
model_name,
|
model_name,
|
||||||
|
|
@ -912,7 +971,7 @@ pub async fn run(
|
||||||
if config.memory.auto_save {
|
if config.memory.auto_save {
|
||||||
let user_key = autosave_memory_key("user_msg");
|
let user_key = autosave_memory_key("user_msg");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&user_key, &msg, MemoryCategory::Conversation)
|
.store(&user_key, &msg, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -955,7 +1014,7 @@ pub async fn run(
|
||||||
let summary = truncate_with_ellipsis(&response, 100);
|
let summary = truncate_with_ellipsis(&response, 100);
|
||||||
let response_key = autosave_memory_key("assistant_resp");
|
let response_key = autosave_memory_key("assistant_resp");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -978,7 +1037,7 @@ pub async fn run(
|
||||||
if config.memory.auto_save {
|
if config.memory.auto_save {
|
||||||
let user_key = autosave_memory_key("user_msg");
|
let user_key = autosave_memory_key("user_msg");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&user_key, &msg.content, MemoryCategory::Conversation)
|
.store(&user_key, &msg.content, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1036,7 +1095,7 @@ pub async fn run(
|
||||||
let summary = truncate_with_ellipsis(&response, 100);
|
let summary = truncate_with_ellipsis(&response, 100);
|
||||||
let response_key = autosave_memory_key("assistant_resp");
|
let response_key = autosave_memory_key("assistant_resp");
|
||||||
let _ = mem
|
let _ = mem
|
||||||
.store(&response_key, &summary, MemoryCategory::Daily)
|
.store(&response_key, &summary, MemoryCategory::Daily, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -1048,6 +1107,7 @@ pub async fn run(
|
||||||
observer.record_event(&ObserverEvent::AgentEnd {
|
observer.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration,
|
duration,
|
||||||
tokens_used: None,
|
tokens_used: None,
|
||||||
|
cost_usd: None,
|
||||||
});
|
});
|
||||||
|
|
||||||
Ok(final_output)
|
Ok(final_output)
|
||||||
|
|
@ -1104,6 +1164,7 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||||
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
let provider: Box<dyn Provider> = providers::create_routed_provider(
|
||||||
provider_name,
|
provider_name,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
|
config.api_url.as_deref(),
|
||||||
&config.reliability,
|
&config.reliability,
|
||||||
&config.model_routes,
|
&config.model_routes,
|
||||||
&model_name,
|
&model_name,
|
||||||
|
|
@ -1217,6 +1278,25 @@ pub async fn process_message(config: Config, message: &str) -> Result<String> {
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_scrub_credentials() {
|
||||||
|
let input = "API_KEY=sk-1234567890abcdef; token: 1234567890; password=\"secret123456\"";
|
||||||
|
let scrubbed = scrub_credentials(input);
|
||||||
|
assert!(scrubbed.contains("API_KEY=sk-1*[REDACTED]"));
|
||||||
|
assert!(scrubbed.contains("token: 1234*[REDACTED]"));
|
||||||
|
assert!(scrubbed.contains("password=\"secr*[REDACTED]\""));
|
||||||
|
assert!(!scrubbed.contains("abcdef"));
|
||||||
|
assert!(!scrubbed.contains("secret123456"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_scrub_credentials_json() {
|
||||||
|
let input = r#"{"api_key": "sk-1234567890", "other": "public"}"#;
|
||||||
|
let scrubbed = scrub_credentials(input);
|
||||||
|
assert!(scrubbed.contains("\"api_key\": \"sk-1*[REDACTED]\""));
|
||||||
|
assert!(scrubbed.contains("public"));
|
||||||
|
}
|
||||||
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
use crate::memory::{Memory, MemoryCategory, SqliteMemory};
|
||||||
use tempfile::TempDir;
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
|
@ -1496,16 +1576,16 @@ I will now call the tool with this payload:
|
||||||
let key1 = autosave_memory_key("user_msg");
|
let key1 = autosave_memory_key("user_msg");
|
||||||
let key2 = autosave_memory_key("user_msg");
|
let key2 = autosave_memory_key("user_msg");
|
||||||
|
|
||||||
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation)
|
mem.store(&key1, "I'm Paul", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store(&key2, "I'm 45", MemoryCategory::Conversation)
|
mem.store(&key2, "I'm 45", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(mem.count().await.unwrap(), 2);
|
assert_eq!(mem.count().await.unwrap(), 2);
|
||||||
|
|
||||||
let recalled = mem.recall("45", 5).await.unwrap();
|
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -33,7 +33,7 @@ impl MemoryLoader for DefaultMemoryLoader {
|
||||||
memory: &dyn Memory,
|
memory: &dyn Memory,
|
||||||
user_message: &str,
|
user_message: &str,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let entries = memory.recall(user_message, self.limit).await?;
|
let entries = memory.recall(user_message, self.limit, None).await?;
|
||||||
if entries.is_empty() {
|
if entries.is_empty() {
|
||||||
return Ok(String::new());
|
return Ok(String::new());
|
||||||
}
|
}
|
||||||
|
|
@ -61,11 +61,17 @@ mod tests {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
if limit == 0 {
|
if limit == 0 {
|
||||||
return Ok(vec![]);
|
return Ok(vec![]);
|
||||||
}
|
}
|
||||||
|
|
@ -87,6 +93,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&MemoryCategory>,
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(vec![])
|
Ok(vec![])
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -40,6 +40,7 @@ impl Channel for CliChannel {
|
||||||
let msg = ChannelMessage {
|
let msg = ChannelMessage {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
sender: "user".to_string(),
|
sender: "user".to_string(),
|
||||||
|
reply_target: "user".to_string(),
|
||||||
content: line,
|
content: line,
|
||||||
channel: "cli".to_string(),
|
channel: "cli".to_string(),
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
|
|
@ -90,12 +91,14 @@ mod tests {
|
||||||
let msg = ChannelMessage {
|
let msg = ChannelMessage {
|
||||||
id: "test-id".into(),
|
id: "test-id".into(),
|
||||||
sender: "user".into(),
|
sender: "user".into(),
|
||||||
|
reply_target: "user".into(),
|
||||||
content: "hello".into(),
|
content: "hello".into(),
|
||||||
channel: "cli".into(),
|
channel: "cli".into(),
|
||||||
timestamp: 1_234_567_890,
|
timestamp: 1_234_567_890,
|
||||||
};
|
};
|
||||||
assert_eq!(msg.id, "test-id");
|
assert_eq!(msg.id, "test-id");
|
||||||
assert_eq!(msg.sender, "user");
|
assert_eq!(msg.sender, "user");
|
||||||
|
assert_eq!(msg.reply_target, "user");
|
||||||
assert_eq!(msg.content, "hello");
|
assert_eq!(msg.content, "hello");
|
||||||
assert_eq!(msg.channel, "cli");
|
assert_eq!(msg.channel, "cli");
|
||||||
assert_eq!(msg.timestamp, 1_234_567_890);
|
assert_eq!(msg.timestamp, 1_234_567_890);
|
||||||
|
|
@ -106,6 +109,7 @@ mod tests {
|
||||||
let msg = ChannelMessage {
|
let msg = ChannelMessage {
|
||||||
id: "id".into(),
|
id: "id".into(),
|
||||||
sender: "s".into(),
|
sender: "s".into(),
|
||||||
|
reply_target: "s".into(),
|
||||||
content: "c".into(),
|
content: "c".into(),
|
||||||
channel: "ch".into(),
|
channel: "ch".into(),
|
||||||
timestamp: 0,
|
timestamp: 0,
|
||||||
|
|
|
||||||
|
|
@ -7,7 +7,7 @@ use tokio::sync::RwLock;
|
||||||
use tokio_tungstenite::tungstenite::Message;
|
use tokio_tungstenite::tungstenite::Message;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
/// DingTalk (钉钉) channel — connects via Stream Mode WebSocket for real-time messages.
|
/// DingTalk channel — connects via Stream Mode WebSocket for real-time messages.
|
||||||
/// Replies are sent through per-message session webhook URLs.
|
/// Replies are sent through per-message session webhook URLs.
|
||||||
pub struct DingTalkChannel {
|
pub struct DingTalkChannel {
|
||||||
client_id: String,
|
client_id: String,
|
||||||
|
|
@ -64,6 +64,18 @@ impl DingTalkChannel {
|
||||||
let gw: GatewayResponse = resp.json().await?;
|
let gw: GatewayResponse = resp.json().await?;
|
||||||
Ok(gw)
|
Ok(gw)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn resolve_reply_target(
|
||||||
|
sender_id: &str,
|
||||||
|
conversation_type: &str,
|
||||||
|
conversation_id: Option<&str>,
|
||||||
|
) -> String {
|
||||||
|
if conversation_type == "1" {
|
||||||
|
sender_id.to_string()
|
||||||
|
} else {
|
||||||
|
conversation_id.unwrap_or(sender_id).to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -193,14 +205,11 @@ impl Channel for DingTalkChannel {
|
||||||
.unwrap_or("1");
|
.unwrap_or("1");
|
||||||
|
|
||||||
// Private chat uses sender ID, group chat uses conversation ID
|
// Private chat uses sender ID, group chat uses conversation ID
|
||||||
let chat_id = if conversation_type == "1" {
|
let chat_id = Self::resolve_reply_target(
|
||||||
sender_id.to_string()
|
sender_id,
|
||||||
} else {
|
conversation_type,
|
||||||
data.get("conversationId")
|
data.get("conversationId").and_then(|c| c.as_str()),
|
||||||
.and_then(|c| c.as_str())
|
);
|
||||||
.unwrap_or(sender_id)
|
|
||||||
.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
// Store session webhook for later replies
|
// Store session webhook for later replies
|
||||||
if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
|
if let Some(webhook) = data.get("sessionWebhook").and_then(|w| w.as_str()) {
|
||||||
|
|
@ -229,6 +238,7 @@ impl Channel for DingTalkChannel {
|
||||||
let channel_msg = ChannelMessage {
|
let channel_msg = ChannelMessage {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
sender: sender_id.to_string(),
|
sender: sender_id.to_string(),
|
||||||
|
reply_target: chat_id,
|
||||||
content: content.to_string(),
|
content: content.to_string(),
|
||||||
channel: "dingtalk".to_string(),
|
channel: "dingtalk".to_string(),
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
|
|
@ -305,4 +315,22 @@ client_secret = "secret"
|
||||||
let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
|
let config: crate::config::schema::DingTalkConfig = toml::from_str(toml_str).unwrap();
|
||||||
assert!(config.allowed_users.is_empty());
|
assert!(config.allowed_users.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_reply_target_private_chat_uses_sender_id() {
|
||||||
|
let target = DingTalkChannel::resolve_reply_target("staff_1", "1", Some("conv_1"));
|
||||||
|
assert_eq!(target, "staff_1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_reply_target_group_chat_uses_conversation_id() {
|
||||||
|
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", Some("conv_1"));
|
||||||
|
assert_eq!(target, "conv_1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_reply_target_group_chat_falls_back_to_sender_id() {
|
||||||
|
let target = DingTalkChannel::resolve_reply_target("staff_1", "2", None);
|
||||||
|
assert_eq!(target, "staff_1");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ pub struct DiscordChannel {
|
||||||
guild_id: Option<String>,
|
guild_id: Option<String>,
|
||||||
allowed_users: Vec<String>,
|
allowed_users: Vec<String>,
|
||||||
listen_to_bots: bool,
|
listen_to_bots: bool,
|
||||||
|
mention_only: bool,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
|
typing_handle: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
|
||||||
}
|
}
|
||||||
|
|
@ -21,12 +22,14 @@ impl DiscordChannel {
|
||||||
guild_id: Option<String>,
|
guild_id: Option<String>,
|
||||||
allowed_users: Vec<String>,
|
allowed_users: Vec<String>,
|
||||||
listen_to_bots: bool,
|
listen_to_bots: bool,
|
||||||
|
mention_only: bool,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
bot_token,
|
bot_token,
|
||||||
guild_id,
|
guild_id,
|
||||||
allowed_users,
|
allowed_users,
|
||||||
listen_to_bots,
|
listen_to_bots,
|
||||||
|
mention_only,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
typing_handle: std::sync::Mutex::new(None),
|
typing_handle: std::sync::Mutex::new(None),
|
||||||
}
|
}
|
||||||
|
|
@ -343,6 +346,22 @@ impl Channel for DiscordChannel {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Skip messages that don't @-mention the bot (when mention_only is enabled)
|
||||||
|
if self.mention_only {
|
||||||
|
let mention_tag = format!("<@{bot_user_id}>");
|
||||||
|
if !content.contains(&mention_tag) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Strip the bot mention from content so the agent sees clean text
|
||||||
|
let clean_content = if self.mention_only {
|
||||||
|
let mention_tag = format!("<@{bot_user_id}>");
|
||||||
|
content.replace(&mention_tag, "").trim().to_string()
|
||||||
|
} else {
|
||||||
|
content.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
let message_id = d.get("id").and_then(|i| i.as_str()).unwrap_or("");
|
||||||
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
let channel_id = d.get("channel_id").and_then(|c| c.as_str()).unwrap_or("").to_string();
|
||||||
|
|
||||||
|
|
@ -353,6 +372,11 @@ impl Channel for DiscordChannel {
|
||||||
format!("discord_{message_id}")
|
format!("discord_{message_id}")
|
||||||
},
|
},
|
||||||
sender: author_id.to_string(),
|
sender: author_id.to_string(),
|
||||||
|
reply_target: if channel_id.is_empty() {
|
||||||
|
author_id.to_string()
|
||||||
|
} else {
|
||||||
|
channel_id
|
||||||
|
},
|
||||||
content: content.to_string(),
|
content: content.to_string(),
|
||||||
channel: channel_id,
|
channel: channel_id,
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
|
|
@ -423,7 +447,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn discord_channel_name() {
|
fn discord_channel_name() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
assert_eq!(ch.name(), "discord");
|
assert_eq!(ch.name(), "discord");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -444,21 +468,27 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn empty_allowlist_denies_everyone() {
|
fn empty_allowlist_denies_everyone() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
assert!(!ch.is_user_allowed("12345"));
|
assert!(!ch.is_user_allowed("12345"));
|
||||||
assert!(!ch.is_user_allowed("anyone"));
|
assert!(!ch.is_user_allowed("anyone"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn wildcard_allows_everyone() {
|
fn wildcard_allows_everyone() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["*".into()], false, false);
|
||||||
assert!(ch.is_user_allowed("12345"));
|
assert!(ch.is_user_allowed("12345"));
|
||||||
assert!(ch.is_user_allowed("anyone"));
|
assert!(ch.is_user_allowed("anyone"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn specific_allowlist_filters() {
|
fn specific_allowlist_filters() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "222".into()], false);
|
let ch = DiscordChannel::new(
|
||||||
|
"fake".into(),
|
||||||
|
None,
|
||||||
|
vec!["111".into(), "222".into()],
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
);
|
||||||
assert!(ch.is_user_allowed("111"));
|
assert!(ch.is_user_allowed("111"));
|
||||||
assert!(ch.is_user_allowed("222"));
|
assert!(ch.is_user_allowed("222"));
|
||||||
assert!(!ch.is_user_allowed("333"));
|
assert!(!ch.is_user_allowed("333"));
|
||||||
|
|
@ -467,7 +497,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_is_exact_match_not_substring() {
|
fn allowlist_is_exact_match_not_substring() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||||
assert!(!ch.is_user_allowed("1111"));
|
assert!(!ch.is_user_allowed("1111"));
|
||||||
assert!(!ch.is_user_allowed("11"));
|
assert!(!ch.is_user_allowed("11"));
|
||||||
assert!(!ch.is_user_allowed("0111"));
|
assert!(!ch.is_user_allowed("0111"));
|
||||||
|
|
@ -475,20 +505,26 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_empty_string_user_id() {
|
fn allowlist_empty_string_user_id() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into()], false, false);
|
||||||
assert!(!ch.is_user_allowed(""));
|
assert!(!ch.is_user_allowed(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_with_wildcard_and_specific() {
|
fn allowlist_with_wildcard_and_specific() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["111".into(), "*".into()], false);
|
let ch = DiscordChannel::new(
|
||||||
|
"fake".into(),
|
||||||
|
None,
|
||||||
|
vec!["111".into(), "*".into()],
|
||||||
|
false,
|
||||||
|
false,
|
||||||
|
);
|
||||||
assert!(ch.is_user_allowed("111"));
|
assert!(ch.is_user_allowed("111"));
|
||||||
assert!(ch.is_user_allowed("anyone_else"));
|
assert!(ch.is_user_allowed("anyone_else"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_case_sensitive() {
|
fn allowlist_case_sensitive() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec!["ABC".into()], false, false);
|
||||||
assert!(ch.is_user_allowed("ABC"));
|
assert!(ch.is_user_allowed("ABC"));
|
||||||
assert!(!ch.is_user_allowed("abc"));
|
assert!(!ch.is_user_allowed("abc"));
|
||||||
assert!(!ch.is_user_allowed("Abc"));
|
assert!(!ch.is_user_allowed("Abc"));
|
||||||
|
|
@ -663,14 +699,14 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn typing_handle_starts_as_none() {
|
fn typing_handle_starts_as_none() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
assert!(guard.is_none());
|
assert!(guard.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn start_typing_sets_handle() {
|
async fn start_typing_sets_handle() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let _ = ch.start_typing("123456").await;
|
let _ = ch.start_typing("123456").await;
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
assert!(guard.is_some());
|
assert!(guard.is_some());
|
||||||
|
|
@ -678,7 +714,7 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stop_typing_clears_handle() {
|
async fn stop_typing_clears_handle() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let _ = ch.start_typing("123456").await;
|
let _ = ch.start_typing("123456").await;
|
||||||
let _ = ch.stop_typing("123456").await;
|
let _ = ch.stop_typing("123456").await;
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
|
|
@ -687,14 +723,14 @@ mod tests {
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn stop_typing_is_idempotent() {
|
async fn stop_typing_is_idempotent() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
assert!(ch.stop_typing("123456").await.is_ok());
|
assert!(ch.stop_typing("123456").await.is_ok());
|
||||||
assert!(ch.stop_typing("123456").await.is_ok());
|
assert!(ch.stop_typing("123456").await.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn start_typing_replaces_existing_task() {
|
async fn start_typing_replaces_existing_task() {
|
||||||
let ch = DiscordChannel::new("fake".into(), None, vec![], false);
|
let ch = DiscordChannel::new("fake".into(), None, vec![], false, false);
|
||||||
let _ = ch.start_typing("111").await;
|
let _ = ch.start_typing("111").await;
|
||||||
let _ = ch.start_typing("222").await;
|
let _ = ch.start_typing("222").await;
|
||||||
let guard = ch.typing_handle.lock().unwrap();
|
let guard = ch.typing_handle.lock().unwrap();
|
||||||
|
|
|
||||||
|
|
@ -10,6 +10,7 @@
|
||||||
|
|
||||||
use anyhow::{anyhow, Result};
|
use anyhow::{anyhow, Result};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use lettre::message::SinglePart;
|
||||||
use lettre::transport::smtp::authentication::Credentials;
|
use lettre::transport::smtp::authentication::Credentials;
|
||||||
use lettre::{Message, SmtpTransport, Transport};
|
use lettre::{Message, SmtpTransport, Transport};
|
||||||
use mail_parser::{MessageParser, MimeHeaders};
|
use mail_parser::{MessageParser, MimeHeaders};
|
||||||
|
|
@ -39,7 +40,7 @@ pub struct EmailConfig {
|
||||||
pub imap_folder: String,
|
pub imap_folder: String,
|
||||||
/// SMTP server hostname
|
/// SMTP server hostname
|
||||||
pub smtp_host: String,
|
pub smtp_host: String,
|
||||||
/// SMTP server port (default: 587 for STARTTLS)
|
/// SMTP server port (default: 465 for TLS)
|
||||||
#[serde(default = "default_smtp_port")]
|
#[serde(default = "default_smtp_port")]
|
||||||
pub smtp_port: u16,
|
pub smtp_port: u16,
|
||||||
/// Use TLS for SMTP (default: true)
|
/// Use TLS for SMTP (default: true)
|
||||||
|
|
@ -63,7 +64,7 @@ fn default_imap_port() -> u16 {
|
||||||
993
|
993
|
||||||
}
|
}
|
||||||
fn default_smtp_port() -> u16 {
|
fn default_smtp_port() -> u16 {
|
||||||
587
|
465
|
||||||
}
|
}
|
||||||
fn default_imap_folder() -> String {
|
fn default_imap_folder() -> String {
|
||||||
"INBOX".into()
|
"INBOX".into()
|
||||||
|
|
@ -389,7 +390,7 @@ impl Channel for EmailChannel {
|
||||||
.from(self.config.from_address.parse()?)
|
.from(self.config.from_address.parse()?)
|
||||||
.to(recipient.parse()?)
|
.to(recipient.parse()?)
|
||||||
.subject(subject)
|
.subject(subject)
|
||||||
.body(body.to_string())?;
|
.singlepart(SinglePart::plain(body.to_string()))?;
|
||||||
|
|
||||||
let transport = self.create_smtp_transport()?;
|
let transport = self.create_smtp_transport()?;
|
||||||
transport.send(&email)?;
|
transport.send(&email)?;
|
||||||
|
|
@ -427,6 +428,7 @@ impl Channel for EmailChannel {
|
||||||
} // MutexGuard dropped before await
|
} // MutexGuard dropped before await
|
||||||
let msg = ChannelMessage {
|
let msg = ChannelMessage {
|
||||||
id,
|
id,
|
||||||
|
reply_target: sender.clone(),
|
||||||
sender,
|
sender,
|
||||||
content,
|
content,
|
||||||
channel: "email".to_string(),
|
channel: "email".to_string(),
|
||||||
|
|
@ -464,6 +466,18 @@ impl Channel for EmailChannel {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_smtp_port_uses_tls_port() {
|
||||||
|
assert_eq!(default_smtp_port(), 465);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn email_config_default_uses_tls_smtp_defaults() {
|
||||||
|
let config = EmailConfig::default();
|
||||||
|
assert_eq!(config.smtp_port, 465);
|
||||||
|
assert!(config.smtp_tls);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn build_imap_tls_config_succeeds() {
|
fn build_imap_tls_config_succeeds() {
|
||||||
let tls_config =
|
let tls_config =
|
||||||
|
|
@ -504,7 +518,7 @@ mod tests {
|
||||||
assert_eq!(config.imap_port, 993);
|
assert_eq!(config.imap_port, 993);
|
||||||
assert_eq!(config.imap_folder, "INBOX");
|
assert_eq!(config.imap_folder, "INBOX");
|
||||||
assert_eq!(config.smtp_host, "");
|
assert_eq!(config.smtp_host, "");
|
||||||
assert_eq!(config.smtp_port, 587);
|
assert_eq!(config.smtp_port, 465);
|
||||||
assert!(config.smtp_tls);
|
assert!(config.smtp_tls);
|
||||||
assert_eq!(config.username, "");
|
assert_eq!(config.username, "");
|
||||||
assert_eq!(config.password, "");
|
assert_eq!(config.password, "");
|
||||||
|
|
@ -765,8 +779,8 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn default_smtp_port_returns_587() {
|
fn default_smtp_port_returns_465() {
|
||||||
assert_eq!(default_smtp_port(), 587);
|
assert_eq!(default_smtp_port(), 465);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -822,7 +836,7 @@ mod tests {
|
||||||
|
|
||||||
let config: EmailConfig = serde_json::from_str(json).unwrap();
|
let config: EmailConfig = serde_json::from_str(json).unwrap();
|
||||||
assert_eq!(config.imap_port, 993); // default
|
assert_eq!(config.imap_port, 993); // default
|
||||||
assert_eq!(config.smtp_port, 587); // default
|
assert_eq!(config.smtp_port, 465); // default
|
||||||
assert!(config.smtp_tls); // default
|
assert!(config.smtp_tls); // default
|
||||||
assert_eq!(config.poll_interval_secs, 60); // default
|
assert_eq!(config.poll_interval_secs, 60); // default
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -172,6 +172,7 @@ end tell"#
|
||||||
let msg = ChannelMessage {
|
let msg = ChannelMessage {
|
||||||
id: rowid.to_string(),
|
id: rowid.to_string(),
|
||||||
sender: sender.clone(),
|
sender: sender.clone(),
|
||||||
|
reply_target: sender.clone(),
|
||||||
content: text,
|
content: text,
|
||||||
channel: "imessage".to_string(),
|
channel: "imessage".to_string(),
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
|
|
|
||||||
|
|
@ -220,32 +220,34 @@ fn split_message(message: &str, max_bytes: usize) -> Vec<String> {
|
||||||
chunks
|
chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Configuration for constructing an `IrcChannel`.
|
||||||
|
pub struct IrcChannelConfig {
|
||||||
|
pub server: String,
|
||||||
|
pub port: u16,
|
||||||
|
pub nickname: String,
|
||||||
|
pub username: Option<String>,
|
||||||
|
pub channels: Vec<String>,
|
||||||
|
pub allowed_users: Vec<String>,
|
||||||
|
pub server_password: Option<String>,
|
||||||
|
pub nickserv_password: Option<String>,
|
||||||
|
pub sasl_password: Option<String>,
|
||||||
|
pub verify_tls: bool,
|
||||||
|
}
|
||||||
|
|
||||||
impl IrcChannel {
|
impl IrcChannel {
|
||||||
#[allow(clippy::too_many_arguments)]
|
pub fn new(cfg: IrcChannelConfig) -> Self {
|
||||||
pub fn new(
|
let username = cfg.username.unwrap_or_else(|| cfg.nickname.clone());
|
||||||
server: String,
|
|
||||||
port: u16,
|
|
||||||
nickname: String,
|
|
||||||
username: Option<String>,
|
|
||||||
channels: Vec<String>,
|
|
||||||
allowed_users: Vec<String>,
|
|
||||||
server_password: Option<String>,
|
|
||||||
nickserv_password: Option<String>,
|
|
||||||
sasl_password: Option<String>,
|
|
||||||
verify_tls: bool,
|
|
||||||
) -> Self {
|
|
||||||
let username = username.unwrap_or_else(|| nickname.clone());
|
|
||||||
Self {
|
Self {
|
||||||
server,
|
server: cfg.server,
|
||||||
port,
|
port: cfg.port,
|
||||||
nickname,
|
nickname: cfg.nickname,
|
||||||
username,
|
username,
|
||||||
channels,
|
channels: cfg.channels,
|
||||||
allowed_users,
|
allowed_users: cfg.allowed_users,
|
||||||
server_password,
|
server_password: cfg.server_password,
|
||||||
nickserv_password,
|
nickserv_password: cfg.nickserv_password,
|
||||||
sasl_password,
|
sasl_password: cfg.sasl_password,
|
||||||
verify_tls,
|
verify_tls: cfg.verify_tls,
|
||||||
writer: Arc::new(Mutex::new(None)),
|
writer: Arc::new(Mutex::new(None)),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -563,7 +565,8 @@ impl Channel for IrcChannel {
|
||||||
let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed);
|
let seq = MSG_SEQ.fetch_add(1, Ordering::Relaxed);
|
||||||
let channel_msg = ChannelMessage {
|
let channel_msg = ChannelMessage {
|
||||||
id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()),
|
id: format!("irc_{}_{seq}", chrono::Utc::now().timestamp_millis()),
|
||||||
sender: reply_to,
|
sender: sender_nick.to_string(),
|
||||||
|
reply_target: reply_to,
|
||||||
content,
|
content,
|
||||||
channel: "irc".to_string(),
|
channel: "irc".to_string(),
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
|
|
@ -807,18 +810,18 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn specific_user_allowed() {
|
fn specific_user_allowed() {
|
||||||
let ch = IrcChannel::new(
|
let ch = IrcChannel::new(IrcChannelConfig {
|
||||||
"irc.test".into(),
|
server: "irc.test".into(),
|
||||||
6697,
|
port: 6697,
|
||||||
"bot".into(),
|
nickname: "bot".into(),
|
||||||
None,
|
username: None,
|
||||||
vec![],
|
channels: vec![],
|
||||||
vec!["alice".into(), "bob".into()],
|
allowed_users: vec!["alice".into(), "bob".into()],
|
||||||
None,
|
server_password: None,
|
||||||
None,
|
nickserv_password: None,
|
||||||
None,
|
sasl_password: None,
|
||||||
true,
|
verify_tls: true,
|
||||||
);
|
});
|
||||||
assert!(ch.is_user_allowed("alice"));
|
assert!(ch.is_user_allowed("alice"));
|
||||||
assert!(ch.is_user_allowed("bob"));
|
assert!(ch.is_user_allowed("bob"));
|
||||||
assert!(!ch.is_user_allowed("eve"));
|
assert!(!ch.is_user_allowed("eve"));
|
||||||
|
|
@ -826,18 +829,18 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn allowlist_case_insensitive() {
|
fn allowlist_case_insensitive() {
|
||||||
let ch = IrcChannel::new(
|
let ch = IrcChannel::new(IrcChannelConfig {
|
||||||
"irc.test".into(),
|
server: "irc.test".into(),
|
||||||
6697,
|
port: 6697,
|
||||||
"bot".into(),
|
nickname: "bot".into(),
|
||||||
None,
|
username: None,
|
||||||
vec![],
|
channels: vec![],
|
||||||
vec!["Alice".into()],
|
allowed_users: vec!["Alice".into()],
|
||||||
None,
|
server_password: None,
|
||||||
None,
|
nickserv_password: None,
|
||||||
None,
|
sasl_password: None,
|
||||||
true,
|
verify_tls: true,
|
||||||
);
|
});
|
||||||
assert!(ch.is_user_allowed("alice"));
|
assert!(ch.is_user_allowed("alice"));
|
||||||
assert!(ch.is_user_allowed("ALICE"));
|
assert!(ch.is_user_allowed("ALICE"));
|
||||||
assert!(ch.is_user_allowed("Alice"));
|
assert!(ch.is_user_allowed("Alice"));
|
||||||
|
|
@ -845,18 +848,18 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn empty_allowlist_denies_all() {
|
fn empty_allowlist_denies_all() {
|
||||||
let ch = IrcChannel::new(
|
let ch = IrcChannel::new(IrcChannelConfig {
|
||||||
"irc.test".into(),
|
server: "irc.test".into(),
|
||||||
6697,
|
port: 6697,
|
||||||
"bot".into(),
|
nickname: "bot".into(),
|
||||||
None,
|
username: None,
|
||||||
vec![],
|
channels: vec![],
|
||||||
vec![],
|
allowed_users: vec![],
|
||||||
None,
|
server_password: None,
|
||||||
None,
|
nickserv_password: None,
|
||||||
None,
|
sasl_password: None,
|
||||||
true,
|
verify_tls: true,
|
||||||
);
|
});
|
||||||
assert!(!ch.is_user_allowed("anyone"));
|
assert!(!ch.is_user_allowed("anyone"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -864,35 +867,35 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn new_defaults_username_to_nickname() {
|
fn new_defaults_username_to_nickname() {
|
||||||
let ch = IrcChannel::new(
|
let ch = IrcChannel::new(IrcChannelConfig {
|
||||||
"irc.test".into(),
|
server: "irc.test".into(),
|
||||||
6697,
|
port: 6697,
|
||||||
"mybot".into(),
|
nickname: "mybot".into(),
|
||||||
None,
|
username: None,
|
||||||
vec![],
|
channels: vec![],
|
||||||
vec![],
|
allowed_users: vec![],
|
||||||
None,
|
server_password: None,
|
||||||
None,
|
nickserv_password: None,
|
||||||
None,
|
sasl_password: None,
|
||||||
true,
|
verify_tls: true,
|
||||||
);
|
});
|
||||||
assert_eq!(ch.username, "mybot");
|
assert_eq!(ch.username, "mybot");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn new_uses_explicit_username() {
|
fn new_uses_explicit_username() {
|
||||||
let ch = IrcChannel::new(
|
let ch = IrcChannel::new(IrcChannelConfig {
|
||||||
"irc.test".into(),
|
server: "irc.test".into(),
|
||||||
6697,
|
port: 6697,
|
||||||
"mybot".into(),
|
nickname: "mybot".into(),
|
||||||
Some("customuser".into()),
|
username: Some("customuser".into()),
|
||||||
vec![],
|
channels: vec![],
|
||||||
vec![],
|
allowed_users: vec![],
|
||||||
None,
|
server_password: None,
|
||||||
None,
|
nickserv_password: None,
|
||||||
None,
|
sasl_password: None,
|
||||||
true,
|
verify_tls: true,
|
||||||
);
|
});
|
||||||
assert_eq!(ch.username, "customuser");
|
assert_eq!(ch.username, "customuser");
|
||||||
assert_eq!(ch.nickname, "mybot");
|
assert_eq!(ch.nickname, "mybot");
|
||||||
}
|
}
|
||||||
|
|
@ -905,18 +908,18 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn new_stores_all_fields() {
|
fn new_stores_all_fields() {
|
||||||
let ch = IrcChannel::new(
|
let ch = IrcChannel::new(IrcChannelConfig {
|
||||||
"irc.example.com".into(),
|
server: "irc.example.com".into(),
|
||||||
6697,
|
port: 6697,
|
||||||
"zcbot".into(),
|
nickname: "zcbot".into(),
|
||||||
Some("zeroclaw".into()),
|
username: Some("zeroclaw".into()),
|
||||||
vec!["#test".into()],
|
channels: vec!["#test".into()],
|
||||||
vec!["alice".into()],
|
allowed_users: vec!["alice".into()],
|
||||||
Some("serverpass".into()),
|
server_password: Some("serverpass".into()),
|
||||||
Some("nspass".into()),
|
nickserv_password: Some("nspass".into()),
|
||||||
Some("saslpass".into()),
|
sasl_password: Some("saslpass".into()),
|
||||||
false,
|
verify_tls: false,
|
||||||
);
|
});
|
||||||
assert_eq!(ch.server, "irc.example.com");
|
assert_eq!(ch.server, "irc.example.com");
|
||||||
assert_eq!(ch.port, 6697);
|
assert_eq!(ch.port, 6697);
|
||||||
assert_eq!(ch.nickname, "zcbot");
|
assert_eq!(ch.nickname, "zcbot");
|
||||||
|
|
@ -995,17 +998,17 @@ nickname = "bot"
|
||||||
// ── Helpers ─────────────────────────────────────────────
|
// ── Helpers ─────────────────────────────────────────────
|
||||||
|
|
||||||
fn make_channel() -> IrcChannel {
|
fn make_channel() -> IrcChannel {
|
||||||
IrcChannel::new(
|
IrcChannel::new(IrcChannelConfig {
|
||||||
"irc.example.com".into(),
|
server: "irc.example.com".into(),
|
||||||
6697,
|
port: 6697,
|
||||||
"zcbot".into(),
|
nickname: "zcbot".into(),
|
||||||
None,
|
username: None,
|
||||||
vec!["#zeroclaw".into()],
|
channels: vec!["#zeroclaw".into()],
|
||||||
vec!["*".into()],
|
allowed_users: vec!["*".into()],
|
||||||
None,
|
server_password: None,
|
||||||
None,
|
nickserv_password: None,
|
||||||
None,
|
sasl_password: None,
|
||||||
true,
|
verify_tls: true,
|
||||||
)
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,21 +1,152 @@
|
||||||
use super::traits::{Channel, ChannelMessage};
|
use super::traits::{Channel, ChannelMessage};
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
|
use futures_util::{SinkExt, StreamExt};
|
||||||
|
use prost::Message as ProstMessage;
|
||||||
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
use std::time::{Duration, Instant};
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
|
use tokio_tungstenite::tungstenite::Message as WsMsg;
|
||||||
use uuid::Uuid;
|
use uuid::Uuid;
|
||||||
|
|
||||||
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
|
const FEISHU_BASE_URL: &str = "https://open.feishu.cn/open-apis";
|
||||||
|
const FEISHU_WS_BASE_URL: &str = "https://open.feishu.cn";
|
||||||
|
const LARK_BASE_URL: &str = "https://open.larksuite.com/open-apis";
|
||||||
|
const LARK_WS_BASE_URL: &str = "https://open.larksuite.com";
|
||||||
|
|
||||||
/// Lark/Feishu channel — receives events via HTTP callback, sends via Open API
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
// Feishu WebSocket long-connection: pbbp2.proto frame codec
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Clone, PartialEq, prost::Message)]
|
||||||
|
struct PbHeader {
|
||||||
|
#[prost(string, tag = "1")]
|
||||||
|
pub key: String,
|
||||||
|
#[prost(string, tag = "2")]
|
||||||
|
pub value: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Feishu WS frame (pbbp2.proto).
|
||||||
|
/// method=0 → CONTROL (ping/pong) method=1 → DATA (events)
|
||||||
|
#[derive(Clone, PartialEq, prost::Message)]
|
||||||
|
struct PbFrame {
|
||||||
|
#[prost(uint64, tag = "1")]
|
||||||
|
pub seq_id: u64,
|
||||||
|
#[prost(uint64, tag = "2")]
|
||||||
|
pub log_id: u64,
|
||||||
|
#[prost(int32, tag = "3")]
|
||||||
|
pub service: i32,
|
||||||
|
#[prost(int32, tag = "4")]
|
||||||
|
pub method: i32,
|
||||||
|
#[prost(message, repeated, tag = "5")]
|
||||||
|
pub headers: Vec<PbHeader>,
|
||||||
|
#[prost(bytes = "vec", optional, tag = "8")]
|
||||||
|
pub payload: Option<Vec<u8>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PbFrame {
|
||||||
|
fn header_value<'a>(&'a self, key: &str) -> &'a str {
|
||||||
|
self.headers
|
||||||
|
.iter()
|
||||||
|
.find(|h| h.key == key)
|
||||||
|
.map(|h| h.value.as_str())
|
||||||
|
.unwrap_or("")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Server-sent client config (parsed from pong payload)
|
||||||
|
#[derive(Debug, serde::Deserialize, Default, Clone)]
|
||||||
|
struct WsClientConfig {
|
||||||
|
#[serde(rename = "PingInterval")]
|
||||||
|
ping_interval: Option<u64>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /callback/ws/endpoint response
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct WsEndpointResp {
|
||||||
|
code: i32,
|
||||||
|
#[serde(default)]
|
||||||
|
msg: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
data: Option<WsEndpoint>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct WsEndpoint {
|
||||||
|
#[serde(rename = "URL")]
|
||||||
|
url: String,
|
||||||
|
#[serde(rename = "ClientConfig")]
|
||||||
|
client_config: Option<WsClientConfig>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// LarkEvent envelope (method=1 / type=event payload)
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct LarkEvent {
|
||||||
|
header: LarkEventHeader,
|
||||||
|
event: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct LarkEventHeader {
|
||||||
|
event_type: String,
|
||||||
|
#[allow(dead_code)]
|
||||||
|
event_id: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct MsgReceivePayload {
|
||||||
|
sender: LarkSender,
|
||||||
|
message: LarkMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct LarkSender {
|
||||||
|
sender_id: LarkSenderId,
|
||||||
|
#[serde(default)]
|
||||||
|
sender_type: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize, Default)]
|
||||||
|
struct LarkSenderId {
|
||||||
|
open_id: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
struct LarkMessage {
|
||||||
|
message_id: String,
|
||||||
|
chat_id: String,
|
||||||
|
chat_type: String,
|
||||||
|
message_type: String,
|
||||||
|
#[serde(default)]
|
||||||
|
content: String,
|
||||||
|
#[serde(default)]
|
||||||
|
mentions: Vec<serde_json::Value>,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Heartbeat timeout for WS connection — must be larger than ping_interval (default 120 s).
|
||||||
|
/// If no binary frame (pong or event) is received within this window, reconnect.
|
||||||
|
const WS_HEARTBEAT_TIMEOUT: Duration = Duration::from_secs(300);
|
||||||
|
|
||||||
|
/// Lark/Feishu channel.
|
||||||
|
///
|
||||||
|
/// Supports two receive modes (configured via `receive_mode` in config):
|
||||||
|
/// - **`websocket`** (default): persistent WSS long-connection; no public URL needed.
|
||||||
|
/// - **`webhook`**: HTTP callback server; requires a public HTTPS endpoint.
|
||||||
pub struct LarkChannel {
|
pub struct LarkChannel {
|
||||||
app_id: String,
|
app_id: String,
|
||||||
app_secret: String,
|
app_secret: String,
|
||||||
verification_token: String,
|
verification_token: String,
|
||||||
port: u16,
|
port: Option<u16>,
|
||||||
allowed_users: Vec<String>,
|
allowed_users: Vec<String>,
|
||||||
|
/// When true, use Feishu (CN) endpoints; when false, use Lark (international).
|
||||||
|
use_feishu: bool,
|
||||||
|
/// How to receive events: WebSocket long-connection or HTTP webhook.
|
||||||
|
receive_mode: crate::config::schema::LarkReceiveMode,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
/// Cached tenant access token
|
/// Cached tenant access token
|
||||||
tenant_token: Arc<RwLock<Option<String>>>,
|
tenant_token: Arc<RwLock<Option<String>>>,
|
||||||
|
/// Dedup set: WS message_ids seen in last ~30 min to prevent double-dispatch
|
||||||
|
ws_seen_ids: Arc<RwLock<HashMap<String, Instant>>>,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl LarkChannel {
|
impl LarkChannel {
|
||||||
|
|
@ -23,7 +154,7 @@ impl LarkChannel {
|
||||||
app_id: String,
|
app_id: String,
|
||||||
app_secret: String,
|
app_secret: String,
|
||||||
verification_token: String,
|
verification_token: String,
|
||||||
port: u16,
|
port: Option<u16>,
|
||||||
allowed_users: Vec<String>,
|
allowed_users: Vec<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -32,11 +163,310 @@ impl LarkChannel {
|
||||||
verification_token,
|
verification_token,
|
||||||
port,
|
port,
|
||||||
allowed_users,
|
allowed_users,
|
||||||
|
use_feishu: true,
|
||||||
|
receive_mode: crate::config::schema::LarkReceiveMode::default(),
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
tenant_token: Arc::new(RwLock::new(None)),
|
tenant_token: Arc::new(RwLock::new(None)),
|
||||||
|
ws_seen_ids: Arc::new(RwLock::new(HashMap::new())),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Build from `LarkConfig` (preserves `use_feishu` and `receive_mode`).
|
||||||
|
pub fn from_config(config: &crate::config::schema::LarkConfig) -> Self {
|
||||||
|
let mut ch = Self::new(
|
||||||
|
config.app_id.clone(),
|
||||||
|
config.app_secret.clone(),
|
||||||
|
config.verification_token.clone().unwrap_or_default(),
|
||||||
|
config.port,
|
||||||
|
config.allowed_users.clone(),
|
||||||
|
);
|
||||||
|
ch.use_feishu = config.use_feishu;
|
||||||
|
ch.receive_mode = config.receive_mode.clone();
|
||||||
|
ch
|
||||||
|
}
|
||||||
|
|
||||||
|
fn api_base(&self) -> &'static str {
|
||||||
|
if self.use_feishu {
|
||||||
|
FEISHU_BASE_URL
|
||||||
|
} else {
|
||||||
|
LARK_BASE_URL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn ws_base(&self) -> &'static str {
|
||||||
|
if self.use_feishu {
|
||||||
|
FEISHU_WS_BASE_URL
|
||||||
|
} else {
|
||||||
|
LARK_WS_BASE_URL
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn tenant_access_token_url(&self) -> String {
|
||||||
|
format!("{}/auth/v3/tenant_access_token/internal", self.api_base())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn send_message_url(&self) -> String {
|
||||||
|
format!("{}/im/v1/messages?receive_id_type=chat_id", self.api_base())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// POST /callback/ws/endpoint → (wss_url, client_config)
|
||||||
|
async fn get_ws_endpoint(&self) -> anyhow::Result<(String, WsClientConfig)> {
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(format!("{}/callback/ws/endpoint", self.ws_base()))
|
||||||
|
.header("locale", if self.use_feishu { "zh" } else { "en" })
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"AppID": self.app_id,
|
||||||
|
"AppSecret": self.app_secret,
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json::<WsEndpointResp>()
|
||||||
|
.await?;
|
||||||
|
if resp.code != 0 {
|
||||||
|
anyhow::bail!(
|
||||||
|
"Lark WS endpoint failed: code={} msg={}",
|
||||||
|
resp.code,
|
||||||
|
resp.msg.as_deref().unwrap_or("(none)")
|
||||||
|
);
|
||||||
|
}
|
||||||
|
let ep = resp
|
||||||
|
.data
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Lark WS endpoint: empty data"))?;
|
||||||
|
Ok((ep.url, ep.client_config.unwrap_or_default()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// WS long-connection event loop. Returns Ok(()) when the connection closes
|
||||||
|
/// (the caller reconnects).
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
async fn listen_ws(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||||
|
let (wss_url, client_config) = self.get_ws_endpoint().await?;
|
||||||
|
let service_id = wss_url
|
||||||
|
.split('?')
|
||||||
|
.nth(1)
|
||||||
|
.and_then(|qs| {
|
||||||
|
qs.split('&')
|
||||||
|
.find(|kv| kv.starts_with("service_id="))
|
||||||
|
.and_then(|kv| kv.split('=').nth(1))
|
||||||
|
.and_then(|v| v.parse::<i32>().ok())
|
||||||
|
})
|
||||||
|
.unwrap_or(0);
|
||||||
|
tracing::info!("Lark: connecting to {wss_url}");
|
||||||
|
|
||||||
|
let (ws_stream, _) = tokio_tungstenite::connect_async(&wss_url).await?;
|
||||||
|
let (mut write, mut read) = ws_stream.split();
|
||||||
|
tracing::info!("Lark: WS connected (service_id={service_id})");
|
||||||
|
|
||||||
|
let mut ping_secs = client_config.ping_interval.unwrap_or(120).max(10);
|
||||||
|
let mut hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
|
||||||
|
let mut timeout_check = tokio::time::interval(Duration::from_secs(10));
|
||||||
|
hb_interval.tick().await; // consume immediate tick
|
||||||
|
|
||||||
|
let mut seq: u64 = 0;
|
||||||
|
let mut last_recv = Instant::now();
|
||||||
|
|
||||||
|
// Send initial ping immediately (like the official SDK) so the server
|
||||||
|
// starts responding with pongs and we can calibrate the ping_interval.
|
||||||
|
seq = seq.wrapping_add(1);
|
||||||
|
let initial_ping = PbFrame {
|
||||||
|
seq_id: seq,
|
||||||
|
log_id: 0,
|
||||||
|
service: service_id,
|
||||||
|
method: 0,
|
||||||
|
headers: vec![PbHeader {
|
||||||
|
key: "type".into(),
|
||||||
|
value: "ping".into(),
|
||||||
|
}],
|
||||||
|
payload: None,
|
||||||
|
};
|
||||||
|
if write
|
||||||
|
.send(WsMsg::Binary(initial_ping.encode_to_vec()))
|
||||||
|
.await
|
||||||
|
.is_err()
|
||||||
|
{
|
||||||
|
anyhow::bail!("Lark: initial ping failed");
|
||||||
|
}
|
||||||
|
// message_id → (fragment_slots, created_at) for multi-part reassembly
|
||||||
|
type FragEntry = (Vec<Option<Vec<u8>>>, Instant);
|
||||||
|
let mut frag_cache: HashMap<String, FragEntry> = HashMap::new();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
biased;
|
||||||
|
|
||||||
|
_ = hb_interval.tick() => {
|
||||||
|
seq = seq.wrapping_add(1);
|
||||||
|
let ping = PbFrame {
|
||||||
|
seq_id: seq, log_id: 0, service: service_id, method: 0,
|
||||||
|
headers: vec![PbHeader { key: "type".into(), value: "ping".into() }],
|
||||||
|
payload: None,
|
||||||
|
};
|
||||||
|
if write.send(WsMsg::Binary(ping.encode_to_vec())).await.is_err() {
|
||||||
|
tracing::warn!("Lark: ping failed, reconnecting");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
// GC stale fragments > 5 min
|
||||||
|
let cutoff = Instant::now().checked_sub(Duration::from_secs(300)).unwrap_or(Instant::now());
|
||||||
|
frag_cache.retain(|_, (_, ts)| *ts > cutoff);
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = timeout_check.tick() => {
|
||||||
|
if last_recv.elapsed() > WS_HEARTBEAT_TIMEOUT {
|
||||||
|
tracing::warn!("Lark: heartbeat timeout, reconnecting");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
msg = read.next() => {
|
||||||
|
let raw = match msg {
|
||||||
|
Some(Ok(WsMsg::Binary(b))) => { last_recv = Instant::now(); b }
|
||||||
|
Some(Ok(WsMsg::Ping(d))) => { let _ = write.send(WsMsg::Pong(d)).await; continue; }
|
||||||
|
Some(Ok(WsMsg::Close(_))) | None => { tracing::info!("Lark: WS closed — reconnecting"); break; }
|
||||||
|
Some(Err(e)) => { tracing::error!("Lark: WS read error: {e}"); break; }
|
||||||
|
_ => continue,
|
||||||
|
};
|
||||||
|
|
||||||
|
let frame = match PbFrame::decode(&raw[..]) {
|
||||||
|
Ok(f) => f,
|
||||||
|
Err(e) => { tracing::error!("Lark: proto decode: {e}"); continue; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// CONTROL frame
|
||||||
|
if frame.method == 0 {
|
||||||
|
if frame.header_value("type") == "pong" {
|
||||||
|
if let Some(p) = &frame.payload {
|
||||||
|
if let Ok(cfg) = serde_json::from_slice::<WsClientConfig>(p) {
|
||||||
|
if let Some(secs) = cfg.ping_interval {
|
||||||
|
let secs = secs.max(10);
|
||||||
|
if secs != ping_secs {
|
||||||
|
ping_secs = secs;
|
||||||
|
hb_interval = tokio::time::interval(Duration::from_secs(ping_secs));
|
||||||
|
tracing::info!("Lark: ping_interval → {ping_secs}s");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// DATA frame
|
||||||
|
let msg_type = frame.header_value("type").to_string();
|
||||||
|
let msg_id = frame.header_value("message_id").to_string();
|
||||||
|
let sum = frame.header_value("sum").parse::<usize>().unwrap_or(1);
|
||||||
|
let seq_num = frame.header_value("seq").parse::<usize>().unwrap_or(0);
|
||||||
|
|
||||||
|
// ACK immediately (Feishu requires within 3 s)
|
||||||
|
{
|
||||||
|
let mut ack = frame.clone();
|
||||||
|
ack.payload = Some(br#"{"code":200,"headers":{},"data":[]}"#.to_vec());
|
||||||
|
ack.headers.push(PbHeader { key: "biz_rt".into(), value: "0".into() });
|
||||||
|
let _ = write.send(WsMsg::Binary(ack.encode_to_vec())).await;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fragment reassembly
|
||||||
|
let sum = if sum == 0 { 1 } else { sum };
|
||||||
|
let payload: Vec<u8> = if sum == 1 || msg_id.is_empty() || seq_num >= sum {
|
||||||
|
frame.payload.clone().unwrap_or_default()
|
||||||
|
} else {
|
||||||
|
let entry = frag_cache.entry(msg_id.clone())
|
||||||
|
.or_insert_with(|| (vec![None; sum], Instant::now()));
|
||||||
|
if entry.0.len() != sum { *entry = (vec![None; sum], Instant::now()); }
|
||||||
|
entry.0[seq_num] = frame.payload.clone();
|
||||||
|
if entry.0.iter().all(|s| s.is_some()) {
|
||||||
|
let full: Vec<u8> = entry.0.iter()
|
||||||
|
.flat_map(|s| s.as_deref().unwrap_or(&[]))
|
||||||
|
.copied().collect();
|
||||||
|
frag_cache.remove(&msg_id);
|
||||||
|
full
|
||||||
|
} else { continue; }
|
||||||
|
};
|
||||||
|
|
||||||
|
if msg_type != "event" { continue; }
|
||||||
|
|
||||||
|
let event: LarkEvent = match serde_json::from_slice(&payload) {
|
||||||
|
Ok(e) => e,
|
||||||
|
Err(e) => { tracing::error!("Lark: event JSON: {e}"); continue; }
|
||||||
|
};
|
||||||
|
if event.header.event_type != "im.message.receive_v1" { continue; }
|
||||||
|
|
||||||
|
let recv: MsgReceivePayload = match serde_json::from_value(event.event) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => { tracing::error!("Lark: payload parse: {e}"); continue; }
|
||||||
|
};
|
||||||
|
|
||||||
|
if recv.sender.sender_type == "app" || recv.sender.sender_type == "bot" { continue; }
|
||||||
|
|
||||||
|
let sender_open_id = recv.sender.sender_id.open_id.as_deref().unwrap_or("");
|
||||||
|
if !self.is_user_allowed(sender_open_id) {
|
||||||
|
tracing::warn!("Lark WS: ignoring {sender_open_id} (not in allowed_users)");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let lark_msg = &recv.message;
|
||||||
|
|
||||||
|
// Dedup
|
||||||
|
{
|
||||||
|
let now = Instant::now();
|
||||||
|
let mut seen = self.ws_seen_ids.write().await;
|
||||||
|
// GC
|
||||||
|
seen.retain(|_, t| now.duration_since(*t) < Duration::from_secs(30 * 60));
|
||||||
|
if seen.contains_key(&lark_msg.message_id) {
|
||||||
|
tracing::debug!("Lark WS: dup {}", lark_msg.message_id);
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
seen.insert(lark_msg.message_id.clone(), now);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decode content by type (mirrors clawdbot-feishu parsing)
|
||||||
|
let text = match lark_msg.message_type.as_str() {
|
||||||
|
"text" => {
|
||||||
|
let v: serde_json::Value = match serde_json::from_str(&lark_msg.content) {
|
||||||
|
Ok(v) => v,
|
||||||
|
Err(_) => continue,
|
||||||
|
};
|
||||||
|
match v.get("text").and_then(|t| t.as_str()).filter(|s| !s.is_empty()) {
|
||||||
|
Some(t) => t.to_string(),
|
||||||
|
None => continue,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"post" => match parse_post_content(&lark_msg.content) {
|
||||||
|
Some(t) => t,
|
||||||
|
None => continue,
|
||||||
|
},
|
||||||
|
_ => { tracing::debug!("Lark WS: skipping unsupported type '{}'", lark_msg.message_type); continue; }
|
||||||
|
};
|
||||||
|
|
||||||
|
// Strip @_user_N placeholders
|
||||||
|
let text = strip_at_placeholders(&text);
|
||||||
|
let text = text.trim().to_string();
|
||||||
|
if text.is_empty() { continue; }
|
||||||
|
|
||||||
|
// Group-chat: only respond when explicitly @-mentioned
|
||||||
|
if lark_msg.chat_type == "group" && !should_respond_in_group(&lark_msg.mentions) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let channel_msg = ChannelMessage {
|
||||||
|
id: Uuid::new_v4().to_string(),
|
||||||
|
sender: lark_msg.chat_id.clone(),
|
||||||
|
reply_target: lark_msg.chat_id.clone(),
|
||||||
|
content: text,
|
||||||
|
channel: "lark".to_string(),
|
||||||
|
timestamp: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
};
|
||||||
|
|
||||||
|
tracing::debug!("Lark WS: message in {}", lark_msg.chat_id);
|
||||||
|
if tx.send(channel_msg).await.is_err() { break; }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if a user open_id is allowed
|
/// Check if a user open_id is allowed
|
||||||
fn is_user_allowed(&self, open_id: &str) -> bool {
|
fn is_user_allowed(&self, open_id: &str) -> bool {
|
||||||
self.allowed_users.iter().any(|u| u == "*" || u == open_id)
|
self.allowed_users.iter().any(|u| u == "*" || u == open_id)
|
||||||
|
|
@ -52,7 +482,7 @@ impl LarkChannel {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let url = format!("{FEISHU_BASE_URL}/auth/v3/tenant_access_token/internal");
|
let url = self.tenant_access_token_url();
|
||||||
let body = serde_json::json!({
|
let body = serde_json::json!({
|
||||||
"app_id": self.app_id,
|
"app_id": self.app_id,
|
||||||
"app_secret": self.app_secret,
|
"app_secret": self.app_secret,
|
||||||
|
|
@ -127,31 +557,41 @@ impl LarkChannel {
|
||||||
return messages;
|
return messages;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract message content (text only)
|
// Extract message content (text and post supported)
|
||||||
let msg_type = event
|
let msg_type = event
|
||||||
.pointer("/message/message_type")
|
.pointer("/message/message_type")
|
||||||
.and_then(|t| t.as_str())
|
.and_then(|t| t.as_str())
|
||||||
.unwrap_or("");
|
.unwrap_or("");
|
||||||
|
|
||||||
if msg_type != "text" {
|
|
||||||
tracing::debug!("Lark: skipping non-text message type: {msg_type}");
|
|
||||||
return messages;
|
|
||||||
}
|
|
||||||
|
|
||||||
let content_str = event
|
let content_str = event
|
||||||
.pointer("/message/content")
|
.pointer("/message/content")
|
||||||
.and_then(|c| c.as_str())
|
.and_then(|c| c.as_str())
|
||||||
.unwrap_or("");
|
.unwrap_or("");
|
||||||
|
|
||||||
// content is a JSON string like "{\"text\":\"hello\"}"
|
let text: String = match msg_type {
|
||||||
let text = serde_json::from_str::<serde_json::Value>(content_str)
|
"text" => {
|
||||||
.ok()
|
let extracted = serde_json::from_str::<serde_json::Value>(content_str)
|
||||||
.and_then(|v| v.get("text").and_then(|t| t.as_str()).map(String::from))
|
.ok()
|
||||||
.unwrap_or_default();
|
.and_then(|v| {
|
||||||
|
v.get("text")
|
||||||
if text.is_empty() {
|
.and_then(|t| t.as_str())
|
||||||
return messages;
|
.filter(|s| !s.is_empty())
|
||||||
}
|
.map(String::from)
|
||||||
|
});
|
||||||
|
match extracted {
|
||||||
|
Some(t) => t,
|
||||||
|
None => return messages,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"post" => match parse_post_content(content_str) {
|
||||||
|
Some(t) => t,
|
||||||
|
None => return messages,
|
||||||
|
},
|
||||||
|
_ => {
|
||||||
|
tracing::debug!("Lark: skipping unsupported message type: {msg_type}");
|
||||||
|
return messages;
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
let timestamp = event
|
let timestamp = event
|
||||||
.pointer("/message/create_time")
|
.pointer("/message/create_time")
|
||||||
|
|
@ -174,6 +614,7 @@ impl LarkChannel {
|
||||||
messages.push(ChannelMessage {
|
messages.push(ChannelMessage {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
sender: chat_id.to_string(),
|
sender: chat_id.to_string(),
|
||||||
|
reply_target: chat_id.to_string(),
|
||||||
content: text,
|
content: text,
|
||||||
channel: "lark".to_string(),
|
channel: "lark".to_string(),
|
||||||
timestamp,
|
timestamp,
|
||||||
|
|
@ -191,7 +632,7 @@ impl Channel for LarkChannel {
|
||||||
|
|
||||||
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
|
async fn send(&self, message: &str, recipient: &str) -> anyhow::Result<()> {
|
||||||
let token = self.get_tenant_access_token().await?;
|
let token = self.get_tenant_access_token().await?;
|
||||||
let url = format!("{FEISHU_BASE_URL}/im/v1/messages?receive_id_type=chat_id");
|
let url = self.send_message_url();
|
||||||
|
|
||||||
let content = serde_json::json!({ "text": message }).to_string();
|
let content = serde_json::json!({ "text": message }).to_string();
|
||||||
let body = serde_json::json!({
|
let body = serde_json::json!({
|
||||||
|
|
@ -238,6 +679,25 @@ impl Channel for LarkChannel {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||||
|
use crate::config::schema::LarkReceiveMode;
|
||||||
|
match self.receive_mode {
|
||||||
|
LarkReceiveMode::Websocket => self.listen_ws(tx).await,
|
||||||
|
LarkReceiveMode::Webhook => self.listen_http(tx).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn health_check(&self) -> bool {
|
||||||
|
self.get_tenant_access_token().await.is_ok()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl LarkChannel {
|
||||||
|
/// HTTP callback server (legacy — requires a public endpoint).
|
||||||
|
/// Use `listen()` (WS long-connection) for new deployments.
|
||||||
|
pub async fn listen_http(
|
||||||
|
&self,
|
||||||
|
tx: tokio::sync::mpsc::Sender<ChannelMessage>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
use axum::{extract::State, routing::post, Json, Router};
|
use axum::{extract::State, routing::post, Json, Router};
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
|
|
@ -282,13 +742,17 @@ impl Channel for LarkChannel {
|
||||||
(StatusCode::OK, "ok").into_response()
|
(StatusCode::OK, "ok").into_response()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
let port = self.port.ok_or_else(|| {
|
||||||
|
anyhow::anyhow!("Lark webhook mode requires `port` to be set in [channels_config.lark]")
|
||||||
|
})?;
|
||||||
|
|
||||||
let state = AppState {
|
let state = AppState {
|
||||||
verification_token: self.verification_token.clone(),
|
verification_token: self.verification_token.clone(),
|
||||||
channel: Arc::new(LarkChannel::new(
|
channel: Arc::new(LarkChannel::new(
|
||||||
self.app_id.clone(),
|
self.app_id.clone(),
|
||||||
self.app_secret.clone(),
|
self.app_secret.clone(),
|
||||||
self.verification_token.clone(),
|
self.verification_token.clone(),
|
||||||
self.port,
|
None,
|
||||||
self.allowed_users.clone(),
|
self.allowed_users.clone(),
|
||||||
)),
|
)),
|
||||||
tx,
|
tx,
|
||||||
|
|
@ -298,7 +762,7 @@ impl Channel for LarkChannel {
|
||||||
.route("/lark", post(handle_event))
|
.route("/lark", post(handle_event))
|
||||||
.with_state(state);
|
.with_state(state);
|
||||||
|
|
||||||
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], self.port));
|
let addr = std::net::SocketAddr::from(([0, 0, 0, 0], port));
|
||||||
tracing::info!("Lark event callback server listening on {addr}");
|
tracing::info!("Lark event callback server listening on {addr}");
|
||||||
|
|
||||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
|
|
@ -306,10 +770,110 @@ impl Channel for LarkChannel {
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn health_check(&self) -> bool {
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
self.get_tenant_access_token().await.is_ok()
|
// WS helper functions
|
||||||
|
// ─────────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Flatten a Feishu `post` rich-text message to plain text.
|
||||||
|
///
|
||||||
|
/// Returns `None` when the content cannot be parsed or yields no usable text,
|
||||||
|
/// so callers can simply `continue` rather than forwarding a meaningless
|
||||||
|
/// placeholder string to the agent.
|
||||||
|
fn parse_post_content(content: &str) -> Option<String> {
|
||||||
|
let parsed = serde_json::from_str::<serde_json::Value>(content).ok()?;
|
||||||
|
let locale = parsed
|
||||||
|
.get("zh_cn")
|
||||||
|
.or_else(|| parsed.get("en_us"))
|
||||||
|
.or_else(|| {
|
||||||
|
parsed
|
||||||
|
.as_object()
|
||||||
|
.and_then(|m| m.values().find(|v| v.is_object()))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let mut text = String::new();
|
||||||
|
|
||||||
|
if let Some(title) = locale
|
||||||
|
.get("title")
|
||||||
|
.and_then(|t| t.as_str())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
{
|
||||||
|
text.push_str(title);
|
||||||
|
text.push_str("\n\n");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if let Some(paragraphs) = locale.get("content").and_then(|c| c.as_array()) {
|
||||||
|
for para in paragraphs {
|
||||||
|
if let Some(elements) = para.as_array() {
|
||||||
|
for el in elements {
|
||||||
|
match el.get("tag").and_then(|t| t.as_str()).unwrap_or("") {
|
||||||
|
"text" => {
|
||||||
|
if let Some(t) = el.get("text").and_then(|t| t.as_str()) {
|
||||||
|
text.push_str(t);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"a" => {
|
||||||
|
text.push_str(
|
||||||
|
el.get("text")
|
||||||
|
.and_then(|t| t.as_str())
|
||||||
|
.filter(|s| !s.is_empty())
|
||||||
|
.or_else(|| el.get("href").and_then(|h| h.as_str()))
|
||||||
|
.unwrap_or(""),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
"at" => {
|
||||||
|
let n = el
|
||||||
|
.get("user_name")
|
||||||
|
.and_then(|n| n.as_str())
|
||||||
|
.or_else(|| el.get("user_id").and_then(|i| i.as_str()))
|
||||||
|
.unwrap_or("user");
|
||||||
|
text.push('@');
|
||||||
|
text.push_str(n);
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
text.push('\n');
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let result = text.trim().to_string();
|
||||||
|
if result.is_empty() {
|
||||||
|
None
|
||||||
|
} else {
|
||||||
|
Some(result)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Remove `@_user_N` placeholder tokens injected by Feishu in group chats.
|
||||||
|
fn strip_at_placeholders(text: &str) -> String {
|
||||||
|
let mut result = String::with_capacity(text.len());
|
||||||
|
let mut chars = text.char_indices().peekable();
|
||||||
|
while let Some((_, ch)) = chars.next() {
|
||||||
|
if ch == '@' {
|
||||||
|
let rest: String = chars.clone().map(|(_, c)| c).collect();
|
||||||
|
if let Some(after) = rest.strip_prefix("_user_") {
|
||||||
|
let skip =
|
||||||
|
"_user_".len() + after.chars().take_while(|c| c.is_ascii_digit()).count();
|
||||||
|
for _ in 0..=skip {
|
||||||
|
chars.next();
|
||||||
|
}
|
||||||
|
if chars.peek().map(|(_, c)| *c == ' ').unwrap_or(false) {
|
||||||
|
chars.next();
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
result.push(ch);
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
/// In group chats, only respond when the bot is explicitly @-mentioned.
|
||||||
|
fn should_respond_in_group(mentions: &[serde_json::Value]) -> bool {
|
||||||
|
!mentions.is_empty()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
@ -321,7 +885,7 @@ mod tests {
|
||||||
"cli_test_app_id".into(),
|
"cli_test_app_id".into(),
|
||||||
"test_app_secret".into(),
|
"test_app_secret".into(),
|
||||||
"test_verification_token".into(),
|
"test_verification_token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["ou_testuser123".into()],
|
vec!["ou_testuser123".into()],
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
@ -345,7 +909,7 @@ mod tests {
|
||||||
"id".into(),
|
"id".into(),
|
||||||
"secret".into(),
|
"secret".into(),
|
||||||
"token".into(),
|
"token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["*".into()],
|
vec!["*".into()],
|
||||||
);
|
);
|
||||||
assert!(ch.is_user_allowed("ou_anyone"));
|
assert!(ch.is_user_allowed("ou_anyone"));
|
||||||
|
|
@ -353,7 +917,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lark_user_denied_empty() {
|
fn lark_user_denied_empty() {
|
||||||
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), 9898, vec![]);
|
let ch = LarkChannel::new("id".into(), "secret".into(), "token".into(), None, vec![]);
|
||||||
assert!(!ch.is_user_allowed("ou_anyone"));
|
assert!(!ch.is_user_allowed("ou_anyone"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -426,7 +990,7 @@ mod tests {
|
||||||
"id".into(),
|
"id".into(),
|
||||||
"secret".into(),
|
"secret".into(),
|
||||||
"token".into(),
|
"token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["*".into()],
|
vec!["*".into()],
|
||||||
);
|
);
|
||||||
let payload = serde_json::json!({
|
let payload = serde_json::json!({
|
||||||
|
|
@ -451,7 +1015,7 @@ mod tests {
|
||||||
"id".into(),
|
"id".into(),
|
||||||
"secret".into(),
|
"secret".into(),
|
||||||
"token".into(),
|
"token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["*".into()],
|
vec!["*".into()],
|
||||||
);
|
);
|
||||||
let payload = serde_json::json!({
|
let payload = serde_json::json!({
|
||||||
|
|
@ -488,7 +1052,7 @@ mod tests {
|
||||||
"id".into(),
|
"id".into(),
|
||||||
"secret".into(),
|
"secret".into(),
|
||||||
"token".into(),
|
"token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["*".into()],
|
vec!["*".into()],
|
||||||
);
|
);
|
||||||
let payload = serde_json::json!({
|
let payload = serde_json::json!({
|
||||||
|
|
@ -512,7 +1076,7 @@ mod tests {
|
||||||
"id".into(),
|
"id".into(),
|
||||||
"secret".into(),
|
"secret".into(),
|
||||||
"token".into(),
|
"token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["*".into()],
|
vec!["*".into()],
|
||||||
);
|
);
|
||||||
let payload = serde_json::json!({
|
let payload = serde_json::json!({
|
||||||
|
|
@ -550,7 +1114,7 @@ mod tests {
|
||||||
"id".into(),
|
"id".into(),
|
||||||
"secret".into(),
|
"secret".into(),
|
||||||
"token".into(),
|
"token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["*".into()],
|
vec!["*".into()],
|
||||||
);
|
);
|
||||||
let payload = serde_json::json!({
|
let payload = serde_json::json!({
|
||||||
|
|
@ -571,7 +1135,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lark_config_serde() {
|
fn lark_config_serde() {
|
||||||
use crate::config::schema::LarkConfig;
|
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||||
let lc = LarkConfig {
|
let lc = LarkConfig {
|
||||||
app_id: "cli_app123".into(),
|
app_id: "cli_app123".into(),
|
||||||
app_secret: "secret456".into(),
|
app_secret: "secret456".into(),
|
||||||
|
|
@ -579,6 +1143,8 @@ mod tests {
|
||||||
verification_token: Some("vtoken789".into()),
|
verification_token: Some("vtoken789".into()),
|
||||||
allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
|
allowed_users: vec!["ou_user1".into(), "ou_user2".into()],
|
||||||
use_feishu: false,
|
use_feishu: false,
|
||||||
|
receive_mode: LarkReceiveMode::default(),
|
||||||
|
port: None,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&lc).unwrap();
|
let json = serde_json::to_string(&lc).unwrap();
|
||||||
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
@ -590,7 +1156,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lark_config_toml_roundtrip() {
|
fn lark_config_toml_roundtrip() {
|
||||||
use crate::config::schema::LarkConfig;
|
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||||
let lc = LarkConfig {
|
let lc = LarkConfig {
|
||||||
app_id: "app".into(),
|
app_id: "app".into(),
|
||||||
app_secret: "secret".into(),
|
app_secret: "secret".into(),
|
||||||
|
|
@ -598,6 +1164,8 @@ mod tests {
|
||||||
verification_token: Some("tok".into()),
|
verification_token: Some("tok".into()),
|
||||||
allowed_users: vec!["*".into()],
|
allowed_users: vec!["*".into()],
|
||||||
use_feishu: false,
|
use_feishu: false,
|
||||||
|
receive_mode: LarkReceiveMode::Webhook,
|
||||||
|
port: Some(9898),
|
||||||
};
|
};
|
||||||
let toml_str = toml::to_string(&lc).unwrap();
|
let toml_str = toml::to_string(&lc).unwrap();
|
||||||
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||||
|
|
@ -608,11 +1176,36 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn lark_config_defaults_optional_fields() {
|
fn lark_config_defaults_optional_fields() {
|
||||||
use crate::config::schema::LarkConfig;
|
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||||
let json = r#"{"app_id":"a","app_secret":"s"}"#;
|
let json = r#"{"app_id":"a","app_secret":"s"}"#;
|
||||||
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||||
assert!(parsed.verification_token.is_none());
|
assert!(parsed.verification_token.is_none());
|
||||||
assert!(parsed.allowed_users.is_empty());
|
assert!(parsed.allowed_users.is_empty());
|
||||||
|
assert_eq!(parsed.receive_mode, LarkReceiveMode::Websocket);
|
||||||
|
assert!(parsed.port.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lark_from_config_preserves_mode_and_region() {
|
||||||
|
use crate::config::schema::{LarkConfig, LarkReceiveMode};
|
||||||
|
|
||||||
|
let cfg = LarkConfig {
|
||||||
|
app_id: "cli_app123".into(),
|
||||||
|
app_secret: "secret456".into(),
|
||||||
|
encrypt_key: None,
|
||||||
|
verification_token: Some("vtoken789".into()),
|
||||||
|
allowed_users: vec!["*".into()],
|
||||||
|
use_feishu: false,
|
||||||
|
receive_mode: LarkReceiveMode::Webhook,
|
||||||
|
port: Some(9898),
|
||||||
|
};
|
||||||
|
|
||||||
|
let ch = LarkChannel::from_config(&cfg);
|
||||||
|
|
||||||
|
assert_eq!(ch.api_base(), LARK_BASE_URL);
|
||||||
|
assert_eq!(ch.ws_base(), LARK_WS_BASE_URL);
|
||||||
|
assert_eq!(ch.receive_mode, LarkReceiveMode::Webhook);
|
||||||
|
assert_eq!(ch.port, Some(9898));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -622,7 +1215,7 @@ mod tests {
|
||||||
"id".into(),
|
"id".into(),
|
||||||
"secret".into(),
|
"secret".into(),
|
||||||
"token".into(),
|
"token".into(),
|
||||||
9898,
|
None,
|
||||||
vec!["*".into()],
|
vec!["*".into()],
|
||||||
);
|
);
|
||||||
let payload = serde_json::json!({
|
let payload = serde_json::json!({
|
||||||
|
|
|
||||||
|
|
@ -230,6 +230,7 @@ impl Channel for MatrixChannel {
|
||||||
let msg = ChannelMessage {
|
let msg = ChannelMessage {
|
||||||
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
|
id: format!("mx_{}", chrono::Utc::now().timestamp_millis()),
|
||||||
sender: event.sender.clone(),
|
sender: event.sender.clone(),
|
||||||
|
reply_target: event.sender.clone(),
|
||||||
content: body.clone(),
|
content: body.clone(),
|
||||||
channel: "matrix".to_string(),
|
channel: "matrix".to_string(),
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
|
|
|
||||||
|
|
@ -69,10 +69,19 @@ fn conversation_memory_key(msg: &traits::ChannelMessage) -> String {
|
||||||
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
format!("{}_{}_{}", msg.channel, msg.sender, msg.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn channel_delivery_instructions(channel_name: &str) -> Option<&'static str> {
|
||||||
|
match channel_name {
|
||||||
|
"telegram" => Some(
|
||||||
|
"When responding on Telegram, include media markers for files or URLs that should be sent as attachments. Use one marker per attachment with this exact syntax: [IMAGE:<path-or-url>], [DOCUMENT:<path-or-url>], [VIDEO:<path-or-url>], [AUDIO:<path-or-url>], or [VOICE:<path-or-url>]. Keep normal user-facing text outside markers and never wrap markers in code fences.",
|
||||||
|
),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
|
async fn build_memory_context(mem: &dyn Memory, user_msg: &str) -> String {
|
||||||
let mut context = String::new();
|
let mut context = String::new();
|
||||||
|
|
||||||
if let Ok(entries) = mem.recall(user_msg, 5).await {
|
if let Ok(entries) = mem.recall(user_msg, 5, None).await {
|
||||||
if !entries.is_empty() {
|
if !entries.is_empty() {
|
||||||
context.push_str("[Memory context]\n");
|
context.push_str("[Memory context]\n");
|
||||||
for entry in &entries {
|
for entry in &entries {
|
||||||
|
|
@ -158,6 +167,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
&autosave_key,
|
&autosave_key,
|
||||||
&msg.content,
|
&msg.content,
|
||||||
crate::memory::MemoryCategory::Conversation,
|
crate::memory::MemoryCategory::Conversation,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
@ -171,7 +181,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
let target_channel = ctx.channels_by_name.get(&msg.channel).cloned();
|
||||||
|
|
||||||
if let Some(channel) = target_channel.as_ref() {
|
if let Some(channel) = target_channel.as_ref() {
|
||||||
if let Err(e) = channel.start_typing(&msg.sender).await {
|
if let Err(e) = channel.start_typing(&msg.reply_target).await {
|
||||||
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
tracing::debug!("Failed to start typing on {}: {e}", channel.name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -184,6 +194,10 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
ChatMessage::user(&enriched_message),
|
ChatMessage::user(&enriched_message),
|
||||||
];
|
];
|
||||||
|
|
||||||
|
if let Some(instructions) = channel_delivery_instructions(&msg.channel) {
|
||||||
|
history.push(ChatMessage::system(instructions));
|
||||||
|
}
|
||||||
|
|
||||||
let llm_result = tokio::time::timeout(
|
let llm_result = tokio::time::timeout(
|
||||||
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
Duration::from_secs(CHANNEL_MESSAGE_TIMEOUT_SECS),
|
||||||
run_tool_call_loop(
|
run_tool_call_loop(
|
||||||
|
|
@ -200,7 +214,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
if let Some(channel) = target_channel.as_ref() {
|
if let Some(channel) = target_channel.as_ref() {
|
||||||
if let Err(e) = channel.stop_typing(&msg.sender).await {
|
if let Err(e) = channel.stop_typing(&msg.reply_target).await {
|
||||||
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
tracing::debug!("Failed to stop typing on {}: {e}", channel.name());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -224,7 +238,9 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
started_at.elapsed().as_millis()
|
started_at.elapsed().as_millis()
|
||||||
);
|
);
|
||||||
if let Some(channel) = target_channel.as_ref() {
|
if let Some(channel) = target_channel.as_ref() {
|
||||||
let _ = channel.send(&format!("⚠️ Error: {e}"), &msg.sender).await;
|
let _ = channel
|
||||||
|
.send(&format!("⚠️ Error: {e}"), &msg.reply_target)
|
||||||
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Err(_) => {
|
Err(_) => {
|
||||||
|
|
@ -241,7 +257,7 @@ async fn process_channel_message(ctx: Arc<ChannelRuntimeContext>, msg: traits::C
|
||||||
let _ = channel
|
let _ = channel
|
||||||
.send(
|
.send(
|
||||||
"⚠️ Request timed out while waiting for the model. Please try again.",
|
"⚠️ Request timed out while waiting for the model. Please try again.",
|
||||||
&msg.sender,
|
&msg.reply_target,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
@ -483,6 +499,16 @@ pub fn build_system_prompt(
|
||||||
std::env::consts::OS,
|
std::env::consts::OS,
|
||||||
);
|
);
|
||||||
|
|
||||||
|
// ── 8. Channel Capabilities ─────────────────────────────────────
|
||||||
|
prompt.push_str("## Channel Capabilities\n\n");
|
||||||
|
prompt.push_str(
|
||||||
|
"- You are running as a Discord bot. You CAN and do send messages to Discord channels.\n",
|
||||||
|
);
|
||||||
|
prompt.push_str("- When someone messages you on Discord, your response is automatically sent back to Discord.\n");
|
||||||
|
prompt.push_str("- You do NOT need to ask permission to respond — just respond directly.\n");
|
||||||
|
prompt.push_str("- NEVER repeat, describe, or echo credentials, tokens, API keys, or secrets in your responses.\n");
|
||||||
|
prompt.push_str("- If a tool output contains credentials, they have already been redacted — do not mention them.\n\n");
|
||||||
|
|
||||||
if prompt.is_empty() {
|
if prompt.is_empty() {
|
||||||
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
|
"You are ZeroClaw, a fast and efficient AI assistant built in Rust. Be helpful, concise, and direct.".to_string()
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -619,6 +645,7 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
||||||
dc.guild_id.clone(),
|
dc.guild_id.clone(),
|
||||||
dc.allowed_users.clone(),
|
dc.allowed_users.clone(),
|
||||||
dc.listen_to_bots,
|
dc.listen_to_bots,
|
||||||
|
dc.mention_only,
|
||||||
)),
|
)),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
@ -672,32 +699,23 @@ pub async fn doctor_channels(config: Config) -> Result<()> {
|
||||||
if let Some(ref irc) = config.channels_config.irc {
|
if let Some(ref irc) = config.channels_config.irc {
|
||||||
channels.push((
|
channels.push((
|
||||||
"IRC",
|
"IRC",
|
||||||
Arc::new(IrcChannel::new(
|
Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||||||
irc.server.clone(),
|
server: irc.server.clone(),
|
||||||
irc.port,
|
port: irc.port,
|
||||||
irc.nickname.clone(),
|
nickname: irc.nickname.clone(),
|
||||||
irc.username.clone(),
|
username: irc.username.clone(),
|
||||||
irc.channels.clone(),
|
channels: irc.channels.clone(),
|
||||||
irc.allowed_users.clone(),
|
allowed_users: irc.allowed_users.clone(),
|
||||||
irc.server_password.clone(),
|
server_password: irc.server_password.clone(),
|
||||||
irc.nickserv_password.clone(),
|
nickserv_password: irc.nickserv_password.clone(),
|
||||||
irc.sasl_password.clone(),
|
sasl_password: irc.sasl_password.clone(),
|
||||||
irc.verify_tls.unwrap_or(true),
|
verify_tls: irc.verify_tls.unwrap_or(true),
|
||||||
)),
|
})),
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref lk) = config.channels_config.lark {
|
if let Some(ref lk) = config.channels_config.lark {
|
||||||
channels.push((
|
channels.push(("Lark", Arc::new(LarkChannel::from_config(lk))));
|
||||||
"Lark",
|
|
||||||
Arc::new(LarkChannel::new(
|
|
||||||
lk.app_id.clone(),
|
|
||||||
lk.app_secret.clone(),
|
|
||||||
lk.verification_token.clone().unwrap_or_default(),
|
|
||||||
9898,
|
|
||||||
lk.allowed_users.clone(),
|
|
||||||
)),
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||||
|
|
@ -762,6 +780,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||||
&provider_name,
|
&provider_name,
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
|
config.api_url.as_deref(),
|
||||||
&config.reliability,
|
&config.reliability,
|
||||||
)?);
|
)?);
|
||||||
|
|
||||||
|
|
@ -860,6 +879,10 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
"schedule",
|
"schedule",
|
||||||
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
|
"Manage scheduled tasks (create/list/get/cancel/pause/resume). Supports recurring cron and one-shot delays.",
|
||||||
));
|
));
|
||||||
|
tool_descs.push((
|
||||||
|
"pushover",
|
||||||
|
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file.",
|
||||||
|
));
|
||||||
if !config.agents.is_empty() {
|
if !config.agents.is_empty() {
|
||||||
tool_descs.push((
|
tool_descs.push((
|
||||||
"delegate",
|
"delegate",
|
||||||
|
|
@ -909,6 +932,7 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
dc.guild_id.clone(),
|
dc.guild_id.clone(),
|
||||||
dc.allowed_users.clone(),
|
dc.allowed_users.clone(),
|
||||||
dc.listen_to_bots,
|
dc.listen_to_bots,
|
||||||
|
dc.mention_only,
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -947,28 +971,22 @@ pub async fn start_channels(config: Config) -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref irc) = config.channels_config.irc {
|
if let Some(ref irc) = config.channels_config.irc {
|
||||||
channels.push(Arc::new(IrcChannel::new(
|
channels.push(Arc::new(IrcChannel::new(irc::IrcChannelConfig {
|
||||||
irc.server.clone(),
|
server: irc.server.clone(),
|
||||||
irc.port,
|
port: irc.port,
|
||||||
irc.nickname.clone(),
|
nickname: irc.nickname.clone(),
|
||||||
irc.username.clone(),
|
username: irc.username.clone(),
|
||||||
irc.channels.clone(),
|
channels: irc.channels.clone(),
|
||||||
irc.allowed_users.clone(),
|
allowed_users: irc.allowed_users.clone(),
|
||||||
irc.server_password.clone(),
|
server_password: irc.server_password.clone(),
|
||||||
irc.nickserv_password.clone(),
|
nickserv_password: irc.nickserv_password.clone(),
|
||||||
irc.sasl_password.clone(),
|
sasl_password: irc.sasl_password.clone(),
|
||||||
irc.verify_tls.unwrap_or(true),
|
verify_tls: irc.verify_tls.unwrap_or(true),
|
||||||
)));
|
})));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref lk) = config.channels_config.lark {
|
if let Some(ref lk) = config.channels_config.lark {
|
||||||
channels.push(Arc::new(LarkChannel::new(
|
channels.push(Arc::new(LarkChannel::from_config(lk)));
|
||||||
lk.app_id.clone(),
|
|
||||||
lk.app_secret.clone(),
|
|
||||||
lk.verification_token.clone().unwrap_or_default(),
|
|
||||||
9898,
|
|
||||||
lk.allowed_users.clone(),
|
|
||||||
)));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(ref dt) = config.channels_config.dingtalk {
|
if let Some(ref dt) = config.channels_config.dingtalk {
|
||||||
|
|
@ -1242,6 +1260,7 @@ mod tests {
|
||||||
traits::ChannelMessage {
|
traits::ChannelMessage {
|
||||||
id: "msg-1".to_string(),
|
id: "msg-1".to_string(),
|
||||||
sender: "alice".to_string(),
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "chat-42".to_string(),
|
||||||
content: "What is the BTC price now?".to_string(),
|
content: "What is the BTC price now?".to_string(),
|
||||||
channel: "test-channel".to_string(),
|
channel: "test-channel".to_string(),
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
|
|
@ -1251,6 +1270,7 @@ mod tests {
|
||||||
|
|
||||||
let sent_messages = channel_impl.sent_messages.lock().await;
|
let sent_messages = channel_impl.sent_messages.lock().await;
|
||||||
assert_eq!(sent_messages.len(), 1);
|
assert_eq!(sent_messages.len(), 1);
|
||||||
|
assert!(sent_messages[0].starts_with("chat-42:"));
|
||||||
assert!(sent_messages[0].contains("BTC is currently around"));
|
assert!(sent_messages[0].contains("BTC is currently around"));
|
||||||
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
assert!(!sent_messages[0].contains("\"tool_calls\""));
|
||||||
assert!(!sent_messages[0].contains("mock_price"));
|
assert!(!sent_messages[0].contains("mock_price"));
|
||||||
|
|
@ -1269,6 +1289,7 @@ mod tests {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: crate::memory::MemoryCategory,
|
_category: crate::memory::MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
@ -1277,6 +1298,7 @@ mod tests {
|
||||||
&self,
|
&self,
|
||||||
_query: &str,
|
_query: &str,
|
||||||
_limit: usize,
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
@ -1288,6 +1310,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&crate::memory::MemoryCategory>,
|
_category: Option<&crate::memory::MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
) -> anyhow::Result<Vec<crate::memory::MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
@ -1331,6 +1354,7 @@ mod tests {
|
||||||
tx.send(traits::ChannelMessage {
|
tx.send(traits::ChannelMessage {
|
||||||
id: "1".to_string(),
|
id: "1".to_string(),
|
||||||
sender: "alice".to_string(),
|
sender: "alice".to_string(),
|
||||||
|
reply_target: "alice".to_string(),
|
||||||
content: "hello".to_string(),
|
content: "hello".to_string(),
|
||||||
channel: "test-channel".to_string(),
|
channel: "test-channel".to_string(),
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
|
|
@ -1340,6 +1364,7 @@ mod tests {
|
||||||
tx.send(traits::ChannelMessage {
|
tx.send(traits::ChannelMessage {
|
||||||
id: "2".to_string(),
|
id: "2".to_string(),
|
||||||
sender: "bob".to_string(),
|
sender: "bob".to_string(),
|
||||||
|
reply_target: "bob".to_string(),
|
||||||
content: "world".to_string(),
|
content: "world".to_string(),
|
||||||
channel: "test-channel".to_string(),
|
channel: "test-channel".to_string(),
|
||||||
timestamp: 2,
|
timestamp: 2,
|
||||||
|
|
@ -1570,6 +1595,25 @@ mod tests {
|
||||||
assert!(truncated.is_char_boundary(truncated.len()));
|
assert!(truncated.is_char_boundary(truncated.len()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn prompt_contains_channel_capabilities() {
|
||||||
|
let ws = make_workspace();
|
||||||
|
let prompt = build_system_prompt(ws.path(), "model", &[], &[], None, None);
|
||||||
|
|
||||||
|
assert!(
|
||||||
|
prompt.contains("## Channel Capabilities"),
|
||||||
|
"missing Channel Capabilities section"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
prompt.contains("running as a Discord bot"),
|
||||||
|
"missing Discord context"
|
||||||
|
);
|
||||||
|
assert!(
|
||||||
|
prompt.contains("NEVER repeat, describe, or echo credentials"),
|
||||||
|
"missing security instruction"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn prompt_workspace_path() {
|
fn prompt_workspace_path() {
|
||||||
let ws = make_workspace();
|
let ws = make_workspace();
|
||||||
|
|
@ -1583,6 +1627,7 @@ mod tests {
|
||||||
let msg = traits::ChannelMessage {
|
let msg = traits::ChannelMessage {
|
||||||
id: "msg_abc123".into(),
|
id: "msg_abc123".into(),
|
||||||
sender: "U123".into(),
|
sender: "U123".into(),
|
||||||
|
reply_target: "C456".into(),
|
||||||
content: "hello".into(),
|
content: "hello".into(),
|
||||||
channel: "slack".into(),
|
channel: "slack".into(),
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
|
|
@ -1596,6 +1641,7 @@ mod tests {
|
||||||
let msg1 = traits::ChannelMessage {
|
let msg1 = traits::ChannelMessage {
|
||||||
id: "msg_1".into(),
|
id: "msg_1".into(),
|
||||||
sender: "U123".into(),
|
sender: "U123".into(),
|
||||||
|
reply_target: "C456".into(),
|
||||||
content: "first".into(),
|
content: "first".into(),
|
||||||
channel: "slack".into(),
|
channel: "slack".into(),
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
|
|
@ -1603,6 +1649,7 @@ mod tests {
|
||||||
let msg2 = traits::ChannelMessage {
|
let msg2 = traits::ChannelMessage {
|
||||||
id: "msg_2".into(),
|
id: "msg_2".into(),
|
||||||
sender: "U123".into(),
|
sender: "U123".into(),
|
||||||
|
reply_target: "C456".into(),
|
||||||
content: "second".into(),
|
content: "second".into(),
|
||||||
channel: "slack".into(),
|
channel: "slack".into(),
|
||||||
timestamp: 2,
|
timestamp: 2,
|
||||||
|
|
@ -1622,6 +1669,7 @@ mod tests {
|
||||||
let msg1 = traits::ChannelMessage {
|
let msg1 = traits::ChannelMessage {
|
||||||
id: "msg_1".into(),
|
id: "msg_1".into(),
|
||||||
sender: "U123".into(),
|
sender: "U123".into(),
|
||||||
|
reply_target: "C456".into(),
|
||||||
content: "I'm Paul".into(),
|
content: "I'm Paul".into(),
|
||||||
channel: "slack".into(),
|
channel: "slack".into(),
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
|
|
@ -1629,6 +1677,7 @@ mod tests {
|
||||||
let msg2 = traits::ChannelMessage {
|
let msg2 = traits::ChannelMessage {
|
||||||
id: "msg_2".into(),
|
id: "msg_2".into(),
|
||||||
sender: "U123".into(),
|
sender: "U123".into(),
|
||||||
|
reply_target: "C456".into(),
|
||||||
content: "I'm 45".into(),
|
content: "I'm 45".into(),
|
||||||
channel: "slack".into(),
|
channel: "slack".into(),
|
||||||
timestamp: 2,
|
timestamp: 2,
|
||||||
|
|
@ -1638,6 +1687,7 @@ mod tests {
|
||||||
&conversation_memory_key(&msg1),
|
&conversation_memory_key(&msg1),
|
||||||
&msg1.content,
|
&msg1.content,
|
||||||
MemoryCategory::Conversation,
|
MemoryCategory::Conversation,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -1645,13 +1695,14 @@ mod tests {
|
||||||
&conversation_memory_key(&msg2),
|
&conversation_memory_key(&msg2),
|
||||||
&msg2.content,
|
&msg2.content,
|
||||||
MemoryCategory::Conversation,
|
MemoryCategory::Conversation,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
assert_eq!(mem.count().await.unwrap(), 2);
|
assert_eq!(mem.count().await.unwrap(), 2);
|
||||||
|
|
||||||
let recalled = mem.recall("45", 5).await.unwrap();
|
let recalled = mem.recall("45", 5, None).await.unwrap();
|
||||||
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
assert!(recalled.iter().any(|entry| entry.content.contains("45")));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1659,7 +1710,7 @@ mod tests {
|
||||||
async fn build_memory_context_includes_recalled_entries() {
|
async fn build_memory_context_includes_recalled_entries() {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation)
|
mem.store("age_fact", "Age is 45", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -161,6 +161,7 @@ impl Channel for SlackChannel {
|
||||||
let channel_msg = ChannelMessage {
|
let channel_msg = ChannelMessage {
|
||||||
id: format!("slack_{channel_id}_{ts}"),
|
id: format!("slack_{channel_id}_{ts}"),
|
||||||
sender: user.to_string(),
|
sender: user.to_string(),
|
||||||
|
reply_target: channel_id.clone(),
|
||||||
content: text.to_string(),
|
content: text.to_string(),
|
||||||
channel: "slack".to_string(),
|
channel: "slack".to_string(),
|
||||||
timestamp: std::time::SystemTime::now()
|
timestamp: std::time::SystemTime::now()
|
||||||
|
|
|
||||||
|
|
@ -51,6 +51,133 @@ fn split_message_for_telegram(message: &str) -> Vec<String> {
|
||||||
chunks
|
chunks
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
enum TelegramAttachmentKind {
|
||||||
|
Image,
|
||||||
|
Document,
|
||||||
|
Video,
|
||||||
|
Audio,
|
||||||
|
Voice,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||||
|
struct TelegramAttachment {
|
||||||
|
kind: TelegramAttachmentKind,
|
||||||
|
target: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl TelegramAttachmentKind {
|
||||||
|
fn from_marker(marker: &str) -> Option<Self> {
|
||||||
|
match marker.trim().to_ascii_uppercase().as_str() {
|
||||||
|
"IMAGE" | "PHOTO" => Some(Self::Image),
|
||||||
|
"DOCUMENT" | "FILE" => Some(Self::Document),
|
||||||
|
"VIDEO" => Some(Self::Video),
|
||||||
|
"AUDIO" => Some(Self::Audio),
|
||||||
|
"VOICE" => Some(Self::Voice),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn is_http_url(target: &str) -> bool {
|
||||||
|
target.starts_with("http://") || target.starts_with("https://")
|
||||||
|
}
|
||||||
|
|
||||||
|
fn infer_attachment_kind_from_target(target: &str) -> Option<TelegramAttachmentKind> {
|
||||||
|
let normalized = target
|
||||||
|
.split('?')
|
||||||
|
.next()
|
||||||
|
.unwrap_or(target)
|
||||||
|
.split('#')
|
||||||
|
.next()
|
||||||
|
.unwrap_or(target);
|
||||||
|
|
||||||
|
let extension = Path::new(normalized)
|
||||||
|
.extension()
|
||||||
|
.and_then(|ext| ext.to_str())?
|
||||||
|
.to_ascii_lowercase();
|
||||||
|
|
||||||
|
match extension.as_str() {
|
||||||
|
"png" | "jpg" | "jpeg" | "gif" | "webp" | "bmp" => Some(TelegramAttachmentKind::Image),
|
||||||
|
"mp4" | "mov" | "mkv" | "avi" | "webm" => Some(TelegramAttachmentKind::Video),
|
||||||
|
"mp3" | "m4a" | "wav" | "flac" => Some(TelegramAttachmentKind::Audio),
|
||||||
|
"ogg" | "oga" | "opus" => Some(TelegramAttachmentKind::Voice),
|
||||||
|
"pdf" | "txt" | "md" | "csv" | "json" | "zip" | "tar" | "gz" | "doc" | "docx" | "xls"
|
||||||
|
| "xlsx" | "ppt" | "pptx" => Some(TelegramAttachmentKind::Document),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_path_only_attachment(message: &str) -> Option<TelegramAttachment> {
|
||||||
|
let trimmed = message.trim();
|
||||||
|
if trimmed.is_empty() || trimmed.contains('\n') {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let candidate = trimmed.trim_matches(|c| matches!(c, '`' | '"' | '\''));
|
||||||
|
if candidate.chars().any(char::is_whitespace) {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let candidate = candidate.strip_prefix("file://").unwrap_or(candidate);
|
||||||
|
let kind = infer_attachment_kind_from_target(candidate)?;
|
||||||
|
|
||||||
|
if !is_http_url(candidate) && !Path::new(candidate).exists() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
Some(TelegramAttachment {
|
||||||
|
kind,
|
||||||
|
target: candidate.to_string(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_attachment_markers(message: &str) -> (String, Vec<TelegramAttachment>) {
|
||||||
|
let mut cleaned = String::with_capacity(message.len());
|
||||||
|
let mut attachments = Vec::new();
|
||||||
|
let mut cursor = 0;
|
||||||
|
|
||||||
|
while cursor < message.len() {
|
||||||
|
let Some(open_rel) = message[cursor..].find('[') else {
|
||||||
|
cleaned.push_str(&message[cursor..]);
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
|
let open = cursor + open_rel;
|
||||||
|
cleaned.push_str(&message[cursor..open]);
|
||||||
|
|
||||||
|
let Some(close_rel) = message[open..].find(']') else {
|
||||||
|
cleaned.push_str(&message[open..]);
|
||||||
|
break;
|
||||||
|
};
|
||||||
|
|
||||||
|
let close = open + close_rel;
|
||||||
|
let marker = &message[open + 1..close];
|
||||||
|
|
||||||
|
let parsed = marker.split_once(':').and_then(|(kind, target)| {
|
||||||
|
let kind = TelegramAttachmentKind::from_marker(kind)?;
|
||||||
|
let target = target.trim();
|
||||||
|
if target.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
Some(TelegramAttachment {
|
||||||
|
kind,
|
||||||
|
target: target.to_string(),
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(attachment) = parsed {
|
||||||
|
attachments.push(attachment);
|
||||||
|
} else {
|
||||||
|
cleaned.push_str(&message[open..=close]);
|
||||||
|
}
|
||||||
|
|
||||||
|
cursor = close + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
(cleaned.trim().to_string(), attachments)
|
||||||
|
}
|
||||||
|
|
||||||
/// Telegram channel — long-polls the Bot API for updates
|
/// Telegram channel — long-polls the Bot API for updates
|
||||||
pub struct TelegramChannel {
|
pub struct TelegramChannel {
|
||||||
bot_token: String,
|
bot_token: String,
|
||||||
|
|
@ -82,6 +209,216 @@ impl TelegramChannel {
|
||||||
identities.into_iter().any(|id| self.is_user_allowed(id))
|
identities.into_iter().any(|id| self.is_user_allowed(id))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn parse_update_message(&self, update: &serde_json::Value) -> Option<ChannelMessage> {
|
||||||
|
let message = update.get("message")?;
|
||||||
|
|
||||||
|
let text = message.get("text").and_then(serde_json::Value::as_str)?;
|
||||||
|
|
||||||
|
let username = message
|
||||||
|
.get("from")
|
||||||
|
.and_then(|from| from.get("username"))
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.unwrap_or("unknown")
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let user_id = message
|
||||||
|
.get("from")
|
||||||
|
.and_then(|from| from.get("id"))
|
||||||
|
.and_then(serde_json::Value::as_i64)
|
||||||
|
.map(|id| id.to_string());
|
||||||
|
|
||||||
|
let sender_identity = if username == "unknown" {
|
||||||
|
user_id.clone().unwrap_or_else(|| "unknown".to_string())
|
||||||
|
} else {
|
||||||
|
username.clone()
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut identities = vec![username.as_str()];
|
||||||
|
if let Some(id) = user_id.as_deref() {
|
||||||
|
identities.push(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.is_any_user_allowed(identities.iter().copied()) {
|
||||||
|
tracing::warn!(
|
||||||
|
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
|
||||||
|
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
|
||||||
|
user_id.as_deref().unwrap_or("unknown")
|
||||||
|
);
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_id = message
|
||||||
|
.get("chat")
|
||||||
|
.and_then(|chat| chat.get("id"))
|
||||||
|
.and_then(serde_json::Value::as_i64)
|
||||||
|
.map(|id| id.to_string())?;
|
||||||
|
|
||||||
|
let message_id = message
|
||||||
|
.get("message_id")
|
||||||
|
.and_then(serde_json::Value::as_i64)
|
||||||
|
.unwrap_or(0);
|
||||||
|
|
||||||
|
Some(ChannelMessage {
|
||||||
|
id: format!("telegram_{chat_id}_{message_id}"),
|
||||||
|
sender: sender_identity,
|
||||||
|
reply_target: chat_id,
|
||||||
|
content: text.to_string(),
|
||||||
|
channel: "telegram".to_string(),
|
||||||
|
timestamp: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_text_chunks(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
||||||
|
let chunks = split_message_for_telegram(message);
|
||||||
|
|
||||||
|
for (index, chunk) in chunks.iter().enumerate() {
|
||||||
|
let text = if chunks.len() > 1 {
|
||||||
|
if index == 0 {
|
||||||
|
format!("{chunk}\n\n(continues...)")
|
||||||
|
} else if index == chunks.len() - 1 {
|
||||||
|
format!("(continued)\n\n{chunk}")
|
||||||
|
} else {
|
||||||
|
format!("(continued)\n\n{chunk}\n\n(continues...)")
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
chunk.to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
let markdown_body = serde_json::json!({
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"text": text,
|
||||||
|
"parse_mode": "Markdown"
|
||||||
|
});
|
||||||
|
|
||||||
|
let markdown_resp = self
|
||||||
|
.client
|
||||||
|
.post(self.api_url("sendMessage"))
|
||||||
|
.json(&markdown_body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if markdown_resp.status().is_success() {
|
||||||
|
if index < chunks.len() - 1 {
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let markdown_status = markdown_resp.status();
|
||||||
|
let markdown_err = markdown_resp.text().await.unwrap_or_default();
|
||||||
|
tracing::warn!(
|
||||||
|
status = ?markdown_status,
|
||||||
|
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
|
||||||
|
);
|
||||||
|
|
||||||
|
let plain_body = serde_json::json!({
|
||||||
|
"chat_id": chat_id,
|
||||||
|
"text": text,
|
||||||
|
});
|
||||||
|
let plain_resp = self
|
||||||
|
.client
|
||||||
|
.post(self.api_url("sendMessage"))
|
||||||
|
.json(&plain_body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !plain_resp.status().is_success() {
|
||||||
|
let plain_status = plain_resp.status();
|
||||||
|
let plain_err = plain_resp.text().await.unwrap_or_default();
|
||||||
|
anyhow::bail!(
|
||||||
|
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
|
||||||
|
markdown_status,
|
||||||
|
markdown_err,
|
||||||
|
plain_status,
|
||||||
|
plain_err
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
if index < chunks.len() - 1 {
|
||||||
|
tokio::time::sleep(Duration::from_millis(100)).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_media_by_url(
|
||||||
|
&self,
|
||||||
|
method: &str,
|
||||||
|
media_field: &str,
|
||||||
|
chat_id: &str,
|
||||||
|
url: &str,
|
||||||
|
caption: Option<&str>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let mut body = serde_json::json!({
|
||||||
|
"chat_id": chat_id,
|
||||||
|
});
|
||||||
|
body[media_field] = serde_json::Value::String(url.to_string());
|
||||||
|
|
||||||
|
if let Some(cap) = caption {
|
||||||
|
body["caption"] = serde_json::Value::String(cap.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
let resp = self
|
||||||
|
.client
|
||||||
|
.post(self.api_url(method))
|
||||||
|
.json(&body)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if !resp.status().is_success() {
|
||||||
|
let err = resp.text().await?;
|
||||||
|
anyhow::bail!("Telegram {method} by URL failed: {err}");
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::info!("Telegram {method} sent to {chat_id}: {url}");
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn send_attachment(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
attachment: &TelegramAttachment,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let target = attachment.target.trim();
|
||||||
|
|
||||||
|
if is_http_url(target) {
|
||||||
|
return match attachment.kind {
|
||||||
|
TelegramAttachmentKind::Image => {
|
||||||
|
self.send_photo_by_url(chat_id, target, None).await
|
||||||
|
}
|
||||||
|
TelegramAttachmentKind::Document => {
|
||||||
|
self.send_document_by_url(chat_id, target, None).await
|
||||||
|
}
|
||||||
|
TelegramAttachmentKind::Video => {
|
||||||
|
self.send_video_by_url(chat_id, target, None).await
|
||||||
|
}
|
||||||
|
TelegramAttachmentKind::Audio => {
|
||||||
|
self.send_audio_by_url(chat_id, target, None).await
|
||||||
|
}
|
||||||
|
TelegramAttachmentKind::Voice => {
|
||||||
|
self.send_voice_by_url(chat_id, target, None).await
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
let path = Path::new(target);
|
||||||
|
if !path.exists() {
|
||||||
|
anyhow::bail!("Telegram attachment path not found: {target}");
|
||||||
|
}
|
||||||
|
|
||||||
|
match attachment.kind {
|
||||||
|
TelegramAttachmentKind::Image => self.send_photo(chat_id, path, None).await,
|
||||||
|
TelegramAttachmentKind::Document => self.send_document(chat_id, path, None).await,
|
||||||
|
TelegramAttachmentKind::Video => self.send_video(chat_id, path, None).await,
|
||||||
|
TelegramAttachmentKind::Audio => self.send_audio(chat_id, path, None).await,
|
||||||
|
TelegramAttachmentKind::Voice => self.send_voice(chat_id, path, None).await,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Send a document/file to a Telegram chat
|
/// Send a document/file to a Telegram chat
|
||||||
pub async fn send_document(
|
pub async fn send_document(
|
||||||
&self,
|
&self,
|
||||||
|
|
@ -408,6 +745,39 @@ impl TelegramChannel {
|
||||||
tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}");
|
tracing::info!("Telegram photo (URL) sent to {chat_id}: {url}");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Send a video by URL (Telegram will download it)
|
||||||
|
pub async fn send_video_by_url(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
url: &str,
|
||||||
|
caption: Option<&str>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
self.send_media_by_url("sendVideo", "video", chat_id, url, caption)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send an audio file by URL (Telegram will download it)
|
||||||
|
pub async fn send_audio_by_url(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
url: &str,
|
||||||
|
caption: Option<&str>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
self.send_media_by_url("sendAudio", "audio", chat_id, url, caption)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a voice message by URL (Telegram will download it)
|
||||||
|
pub async fn send_voice_by_url(
|
||||||
|
&self,
|
||||||
|
chat_id: &str,
|
||||||
|
url: &str,
|
||||||
|
caption: Option<&str>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
self.send_media_by_url("sendVoice", "voice", chat_id, url, caption)
|
||||||
|
.await
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -417,82 +787,27 @@ impl Channel for TelegramChannel {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
async fn send(&self, message: &str, chat_id: &str) -> anyhow::Result<()> {
|
||||||
// Split message if it exceeds Telegram's 4096 character limit
|
let (text_without_markers, attachments) = parse_attachment_markers(message);
|
||||||
let chunks = split_message_for_telegram(message);
|
|
||||||
|
|
||||||
for (i, chunk) in chunks.iter().enumerate() {
|
if !attachments.is_empty() {
|
||||||
// Add continuation marker for multi-part messages
|
if !text_without_markers.is_empty() {
|
||||||
let text = if chunks.len() > 1 {
|
self.send_text_chunks(&text_without_markers, chat_id)
|
||||||
if i == 0 {
|
.await?;
|
||||||
format!("{chunk}\n\n(continues...)")
|
|
||||||
} else if i == chunks.len() - 1 {
|
|
||||||
format!("(continued)\n\n{chunk}")
|
|
||||||
} else {
|
|
||||||
format!("(continued)\n\n{chunk}\n\n(continues...)")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
chunk.to_string()
|
|
||||||
};
|
|
||||||
|
|
||||||
let markdown_body = serde_json::json!({
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"text": text,
|
|
||||||
"parse_mode": "Markdown"
|
|
||||||
});
|
|
||||||
|
|
||||||
let markdown_resp = self
|
|
||||||
.client
|
|
||||||
.post(self.api_url("sendMessage"))
|
|
||||||
.json(&markdown_body)
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if markdown_resp.status().is_success() {
|
|
||||||
// Small delay between chunks to avoid rate limiting
|
|
||||||
if i < chunks.len() - 1 {
|
|
||||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
|
||||||
}
|
|
||||||
continue;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let markdown_status = markdown_resp.status();
|
for attachment in &attachments {
|
||||||
let markdown_err = markdown_resp.text().await.unwrap_or_default();
|
self.send_attachment(chat_id, attachment).await?;
|
||||||
tracing::warn!(
|
|
||||||
status = ?markdown_status,
|
|
||||||
"Telegram sendMessage with Markdown failed; retrying without parse_mode"
|
|
||||||
);
|
|
||||||
|
|
||||||
// Retry without parse_mode as a compatibility fallback.
|
|
||||||
let plain_body = serde_json::json!({
|
|
||||||
"chat_id": chat_id,
|
|
||||||
"text": text,
|
|
||||||
});
|
|
||||||
let plain_resp = self
|
|
||||||
.client
|
|
||||||
.post(self.api_url("sendMessage"))
|
|
||||||
.json(&plain_body)
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
if !plain_resp.status().is_success() {
|
|
||||||
let plain_status = plain_resp.status();
|
|
||||||
let plain_err = plain_resp.text().await.unwrap_or_default();
|
|
||||||
anyhow::bail!(
|
|
||||||
"Telegram sendMessage failed (markdown {}: {}; plain {}: {})",
|
|
||||||
markdown_status,
|
|
||||||
markdown_err,
|
|
||||||
plain_status,
|
|
||||||
plain_err
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Small delay between chunks to avoid rate limiting
|
return Ok(());
|
||||||
if i < chunks.len() - 1 {
|
|
||||||
tokio::time::sleep(Duration::from_millis(100)).await;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
if let Some(attachment) = parse_path_only_attachment(message) {
|
||||||
|
self.send_attachment(chat_id, &attachment).await?;
|
||||||
|
return Ok(());
|
||||||
|
}
|
||||||
|
|
||||||
|
self.send_text_chunks(message, chat_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
async fn listen(&self, tx: tokio::sync::mpsc::Sender<ChannelMessage>) -> anyhow::Result<()> {
|
||||||
|
|
@ -533,59 +848,13 @@ impl Channel for TelegramChannel {
|
||||||
offset = uid + 1;
|
offset = uid + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
let Some(message) = update.get("message") else {
|
let Some(msg) = self.parse_update_message(update) else {
|
||||||
continue;
|
continue;
|
||||||
};
|
};
|
||||||
|
|
||||||
let Some(text) = message.get("text").and_then(serde_json::Value::as_str) else {
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
let username_opt = message
|
|
||||||
.get("from")
|
|
||||||
.and_then(|f| f.get("username"))
|
|
||||||
.and_then(|u| u.as_str());
|
|
||||||
let username = username_opt.unwrap_or("unknown");
|
|
||||||
|
|
||||||
let user_id = message
|
|
||||||
.get("from")
|
|
||||||
.and_then(|f| f.get("id"))
|
|
||||||
.and_then(serde_json::Value::as_i64);
|
|
||||||
let user_id_str = user_id.map(|id| id.to_string());
|
|
||||||
|
|
||||||
let mut identities = vec![username];
|
|
||||||
if let Some(ref id) = user_id_str {
|
|
||||||
identities.push(id.as_str());
|
|
||||||
}
|
|
||||||
|
|
||||||
if !self.is_any_user_allowed(identities.iter().copied()) {
|
|
||||||
tracing::warn!(
|
|
||||||
"Telegram: ignoring message from unauthorized user: username={username}, user_id={}. \
|
|
||||||
Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --channels-only`.",
|
|
||||||
user_id_str.as_deref().unwrap_or("unknown")
|
|
||||||
);
|
|
||||||
continue;
|
|
||||||
}
|
|
||||||
|
|
||||||
let chat_id = message
|
|
||||||
.get("chat")
|
|
||||||
.and_then(|c| c.get("id"))
|
|
||||||
.and_then(serde_json::Value::as_i64)
|
|
||||||
.map(|id| id.to_string());
|
|
||||||
|
|
||||||
let Some(chat_id) = chat_id else {
|
|
||||||
tracing::warn!("Telegram: missing chat_id in message, skipping");
|
|
||||||
continue;
|
|
||||||
};
|
|
||||||
|
|
||||||
let message_id = message
|
|
||||||
.get("message_id")
|
|
||||||
.and_then(|v| v.as_i64())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
// Send "typing" indicator immediately when we receive a message
|
// Send "typing" indicator immediately when we receive a message
|
||||||
let typing_body = serde_json::json!({
|
let typing_body = serde_json::json!({
|
||||||
"chat_id": &chat_id,
|
"chat_id": &msg.reply_target,
|
||||||
"action": "typing"
|
"action": "typing"
|
||||||
});
|
});
|
||||||
let _ = self
|
let _ = self
|
||||||
|
|
@ -595,17 +864,6 @@ Allowlist Telegram @username or numeric user ID, then run `zeroclaw onboard --ch
|
||||||
.send()
|
.send()
|
||||||
.await; // Ignore errors for typing indicator
|
.await; // Ignore errors for typing indicator
|
||||||
|
|
||||||
let msg = ChannelMessage {
|
|
||||||
id: format!("telegram_{chat_id}_{message_id}"),
|
|
||||||
sender: username.to_string(),
|
|
||||||
content: text.to_string(),
|
|
||||||
channel: "telegram".to_string(),
|
|
||||||
timestamp: std::time::SystemTime::now()
|
|
||||||
.duration_since(std::time::UNIX_EPOCH)
|
|
||||||
.unwrap_or_default()
|
|
||||||
.as_secs(),
|
|
||||||
};
|
|
||||||
|
|
||||||
if tx.send(msg).await.is_err() {
|
if tx.send(msg).await.is_err() {
|
||||||
return Ok(());
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
@ -716,6 +974,107 @@ mod tests {
|
||||||
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
|
assert!(!ch.is_any_user_allowed(["unknown", "123456789"]));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_attachment_markers_extracts_multiple_types() {
|
||||||
|
let message = "Here are files [IMAGE:/tmp/a.png] and [DOCUMENT:https://example.com/a.pdf]";
|
||||||
|
let (cleaned, attachments) = parse_attachment_markers(message);
|
||||||
|
|
||||||
|
assert_eq!(cleaned, "Here are files and");
|
||||||
|
assert_eq!(attachments.len(), 2);
|
||||||
|
assert_eq!(attachments[0].kind, TelegramAttachmentKind::Image);
|
||||||
|
assert_eq!(attachments[0].target, "/tmp/a.png");
|
||||||
|
assert_eq!(attachments[1].kind, TelegramAttachmentKind::Document);
|
||||||
|
assert_eq!(attachments[1].target, "https://example.com/a.pdf");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_attachment_markers_keeps_invalid_markers_in_text() {
|
||||||
|
let message = "Report [UNKNOWN:/tmp/a.bin]";
|
||||||
|
let (cleaned, attachments) = parse_attachment_markers(message);
|
||||||
|
|
||||||
|
assert_eq!(cleaned, "Report [UNKNOWN:/tmp/a.bin]");
|
||||||
|
assert!(attachments.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_path_only_attachment_detects_existing_file() {
|
||||||
|
let dir = tempfile::tempdir().unwrap();
|
||||||
|
let image_path = dir.path().join("snap.png");
|
||||||
|
std::fs::write(&image_path, b"fake-png").unwrap();
|
||||||
|
|
||||||
|
let parsed = parse_path_only_attachment(image_path.to_string_lossy().as_ref())
|
||||||
|
.expect("expected attachment");
|
||||||
|
|
||||||
|
assert_eq!(parsed.kind, TelegramAttachmentKind::Image);
|
||||||
|
assert_eq!(parsed.target, image_path.to_string_lossy());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_path_only_attachment_rejects_sentence_text() {
|
||||||
|
assert!(parse_path_only_attachment("Screenshot saved to /tmp/snap.png").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn infer_attachment_kind_from_target_detects_document_extension() {
|
||||||
|
assert_eq!(
|
||||||
|
infer_attachment_kind_from_target("https://example.com/files/specs.pdf?download=1"),
|
||||||
|
Some(TelegramAttachmentKind::Document)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_update_message_uses_chat_id_as_reply_target() {
|
||||||
|
let ch = TelegramChannel::new("token".into(), vec!["*".into()]);
|
||||||
|
let update = serde_json::json!({
|
||||||
|
"update_id": 1,
|
||||||
|
"message": {
|
||||||
|
"message_id": 33,
|
||||||
|
"text": "hello",
|
||||||
|
"from": {
|
||||||
|
"id": 555,
|
||||||
|
"username": "alice"
|
||||||
|
},
|
||||||
|
"chat": {
|
||||||
|
"id": -100200300
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let msg = ch
|
||||||
|
.parse_update_message(&update)
|
||||||
|
.expect("message should parse");
|
||||||
|
|
||||||
|
assert_eq!(msg.sender, "alice");
|
||||||
|
assert_eq!(msg.reply_target, "-100200300");
|
||||||
|
assert_eq!(msg.content, "hello");
|
||||||
|
assert_eq!(msg.id, "telegram_-100200300_33");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn parse_update_message_allows_numeric_id_without_username() {
|
||||||
|
let ch = TelegramChannel::new("token".into(), vec!["555".into()]);
|
||||||
|
let update = serde_json::json!({
|
||||||
|
"update_id": 2,
|
||||||
|
"message": {
|
||||||
|
"message_id": 9,
|
||||||
|
"text": "ping",
|
||||||
|
"from": {
|
||||||
|
"id": 555
|
||||||
|
},
|
||||||
|
"chat": {
|
||||||
|
"id": 12345
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let msg = ch
|
||||||
|
.parse_update_message(&update)
|
||||||
|
.expect("numeric allowlist should pass");
|
||||||
|
|
||||||
|
assert_eq!(msg.sender, "555");
|
||||||
|
assert_eq!(msg.reply_target, "12345");
|
||||||
|
}
|
||||||
|
|
||||||
// ── File sending API URL tests ──────────────────────────────────
|
// ── File sending API URL tests ──────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,7 @@ use async_trait::async_trait;
|
||||||
pub struct ChannelMessage {
|
pub struct ChannelMessage {
|
||||||
pub id: String,
|
pub id: String,
|
||||||
pub sender: String,
|
pub sender: String,
|
||||||
|
pub reply_target: String,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub channel: String,
|
pub channel: String,
|
||||||
pub timestamp: u64,
|
pub timestamp: u64,
|
||||||
|
|
@ -62,6 +63,7 @@ mod tests {
|
||||||
tx.send(ChannelMessage {
|
tx.send(ChannelMessage {
|
||||||
id: "1".into(),
|
id: "1".into(),
|
||||||
sender: "tester".into(),
|
sender: "tester".into(),
|
||||||
|
reply_target: "tester".into(),
|
||||||
content: "hello".into(),
|
content: "hello".into(),
|
||||||
channel: "dummy".into(),
|
channel: "dummy".into(),
|
||||||
timestamp: 123,
|
timestamp: 123,
|
||||||
|
|
@ -76,6 +78,7 @@ mod tests {
|
||||||
let message = ChannelMessage {
|
let message = ChannelMessage {
|
||||||
id: "42".into(),
|
id: "42".into(),
|
||||||
sender: "alice".into(),
|
sender: "alice".into(),
|
||||||
|
reply_target: "alice".into(),
|
||||||
content: "ping".into(),
|
content: "ping".into(),
|
||||||
channel: "dummy".into(),
|
channel: "dummy".into(),
|
||||||
timestamp: 999,
|
timestamp: 999,
|
||||||
|
|
@ -84,6 +87,7 @@ mod tests {
|
||||||
let cloned = message.clone();
|
let cloned = message.clone();
|
||||||
assert_eq!(cloned.id, "42");
|
assert_eq!(cloned.id, "42");
|
||||||
assert_eq!(cloned.sender, "alice");
|
assert_eq!(cloned.sender, "alice");
|
||||||
|
assert_eq!(cloned.reply_target, "alice");
|
||||||
assert_eq!(cloned.content, "ping");
|
assert_eq!(cloned.content, "ping");
|
||||||
assert_eq!(cloned.channel, "dummy");
|
assert_eq!(cloned.channel, "dummy");
|
||||||
assert_eq!(cloned.timestamp, 999);
|
assert_eq!(cloned.timestamp, 999);
|
||||||
|
|
|
||||||
|
|
@ -10,7 +10,7 @@ use uuid::Uuid;
|
||||||
/// happens in the gateway when Meta sends webhook events.
|
/// happens in the gateway when Meta sends webhook events.
|
||||||
pub struct WhatsAppChannel {
|
pub struct WhatsAppChannel {
|
||||||
access_token: String,
|
access_token: String,
|
||||||
phone_number_id: String,
|
endpoint_id: String,
|
||||||
verify_token: String,
|
verify_token: String,
|
||||||
allowed_numbers: Vec<String>,
|
allowed_numbers: Vec<String>,
|
||||||
client: reqwest::Client,
|
client: reqwest::Client,
|
||||||
|
|
@ -19,13 +19,13 @@ pub struct WhatsAppChannel {
|
||||||
impl WhatsAppChannel {
|
impl WhatsAppChannel {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
access_token: String,
|
access_token: String,
|
||||||
phone_number_id: String,
|
endpoint_id: String,
|
||||||
verify_token: String,
|
verify_token: String,
|
||||||
allowed_numbers: Vec<String>,
|
allowed_numbers: Vec<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
access_token,
|
access_token,
|
||||||
phone_number_id,
|
endpoint_id,
|
||||||
verify_token,
|
verify_token,
|
||||||
allowed_numbers,
|
allowed_numbers,
|
||||||
client: reqwest::Client::new(),
|
client: reqwest::Client::new(),
|
||||||
|
|
@ -119,6 +119,7 @@ impl WhatsAppChannel {
|
||||||
|
|
||||||
messages.push(ChannelMessage {
|
messages.push(ChannelMessage {
|
||||||
id: Uuid::new_v4().to_string(),
|
id: Uuid::new_v4().to_string(),
|
||||||
|
reply_target: normalized_from.clone(),
|
||||||
sender: normalized_from,
|
sender: normalized_from,
|
||||||
content,
|
content,
|
||||||
channel: "whatsapp".to_string(),
|
channel: "whatsapp".to_string(),
|
||||||
|
|
@ -142,7 +143,7 @@ impl Channel for WhatsAppChannel {
|
||||||
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
|
// WhatsApp Cloud API: POST to /v18.0/{phone_number_id}/messages
|
||||||
let url = format!(
|
let url = format!(
|
||||||
"https://graph.facebook.com/v18.0/{}/messages",
|
"https://graph.facebook.com/v18.0/{}/messages",
|
||||||
self.phone_number_id
|
self.endpoint_id
|
||||||
);
|
);
|
||||||
|
|
||||||
// Normalize recipient (remove leading + if present for API)
|
// Normalize recipient (remove leading + if present for API)
|
||||||
|
|
@ -162,7 +163,7 @@ impl Channel for WhatsAppChannel {
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
.post(&url)
|
.post(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
.bearer_auth(&self.access_token)
|
||||||
.header("Content-Type", "application/json")
|
.header("Content-Type", "application/json")
|
||||||
.json(&body)
|
.json(&body)
|
||||||
.send()
|
.send()
|
||||||
|
|
@ -195,11 +196,11 @@ impl Channel for WhatsAppChannel {
|
||||||
|
|
||||||
async fn health_check(&self) -> bool {
|
async fn health_check(&self) -> bool {
|
||||||
// Check if we can reach the WhatsApp API
|
// Check if we can reach the WhatsApp API
|
||||||
let url = format!("https://graph.facebook.com/v18.0/{}", self.phone_number_id);
|
let url = format!("https://graph.facebook.com/v18.0/{}", self.endpoint_id);
|
||||||
|
|
||||||
self.client
|
self.client
|
||||||
.get(&url)
|
.get(&url)
|
||||||
.header("Authorization", format!("Bearer {}", self.access_token))
|
.bearer_auth(&self.access_token)
|
||||||
.send()
|
.send()
|
||||||
.await
|
.await
|
||||||
.map(|r| r.status().is_success())
|
.map(|r| r.status().is_success())
|
||||||
|
|
|
||||||
|
|
@ -37,9 +37,22 @@ mod tests {
|
||||||
guild_id: Some("123".into()),
|
guild_id: Some("123".into()),
|
||||||
allowed_users: vec![],
|
allowed_users: vec![],
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
let lark = LarkConfig {
|
||||||
|
app_id: "app-id".into(),
|
||||||
|
app_secret: "app-secret".into(),
|
||||||
|
encrypt_key: None,
|
||||||
|
verification_token: None,
|
||||||
|
allowed_users: vec![],
|
||||||
|
use_feishu: false,
|
||||||
|
receive_mode: crate::config::schema::LarkReceiveMode::Websocket,
|
||||||
|
port: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
assert_eq!(telegram.allowed_users.len(), 1);
|
assert_eq!(telegram.allowed_users.len(), 1);
|
||||||
assert_eq!(discord.guild_id.as_deref(), Some("123"));
|
assert_eq!(discord.guild_id.as_deref(), Some("123"));
|
||||||
|
assert_eq!(lark.app_id, "app-id");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -18,6 +18,8 @@ pub struct Config {
|
||||||
#[serde(skip)]
|
#[serde(skip)]
|
||||||
pub config_path: PathBuf,
|
pub config_path: PathBuf,
|
||||||
pub api_key: Option<String>,
|
pub api_key: Option<String>,
|
||||||
|
/// Base URL override for provider API (e.g. "http://10.0.0.1:11434" for remote Ollama)
|
||||||
|
pub api_url: Option<String>,
|
||||||
pub default_provider: Option<String>,
|
pub default_provider: Option<String>,
|
||||||
pub default_model: Option<String>,
|
pub default_model: Option<String>,
|
||||||
pub default_temperature: f64,
|
pub default_temperature: f64,
|
||||||
|
|
@ -1317,6 +1319,10 @@ pub struct DiscordConfig {
|
||||||
/// The bot still ignores its own messages to prevent feedback loops.
|
/// The bot still ignores its own messages to prevent feedback loops.
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub listen_to_bots: bool,
|
pub listen_to_bots: bool,
|
||||||
|
/// When true, only respond to messages that @-mention the bot.
|
||||||
|
/// Other messages in the guild are silently ignored.
|
||||||
|
#[serde(default)]
|
||||||
|
pub mention_only: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
|
@ -1395,8 +1401,20 @@ fn default_irc_port() -> u16 {
|
||||||
6697
|
6697
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Lark/Feishu configuration for messaging integration
|
/// How ZeroClaw receives events from Feishu / Lark.
|
||||||
/// Lark is the international version, Feishu is the Chinese version
|
///
|
||||||
|
/// - `websocket` (default) — persistent WSS long-connection; no public URL required.
|
||||||
|
/// - `webhook` — HTTP callback server; requires a public HTTPS endpoint.
|
||||||
|
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, Default)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum LarkReceiveMode {
|
||||||
|
#[default]
|
||||||
|
Websocket,
|
||||||
|
Webhook,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Lark/Feishu configuration for messaging integration.
|
||||||
|
/// Lark is the international version; Feishu is the Chinese version.
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct LarkConfig {
|
pub struct LarkConfig {
|
||||||
/// App ID from Lark/Feishu developer console
|
/// App ID from Lark/Feishu developer console
|
||||||
|
|
@ -1415,6 +1433,13 @@ pub struct LarkConfig {
|
||||||
/// Whether to use the Feishu (Chinese) endpoint instead of Lark (International)
|
/// Whether to use the Feishu (Chinese) endpoint instead of Lark (International)
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
pub use_feishu: bool,
|
pub use_feishu: bool,
|
||||||
|
/// Event receive mode: "websocket" (default) or "webhook"
|
||||||
|
#[serde(default)]
|
||||||
|
pub receive_mode: LarkReceiveMode,
|
||||||
|
/// HTTP port for webhook mode only. Must be set when receive_mode = "webhook".
|
||||||
|
/// Not required (and ignored) for websocket mode.
|
||||||
|
#[serde(default)]
|
||||||
|
pub port: Option<u16>,
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Security Config ─────────────────────────────────────────────────
|
// ── Security Config ─────────────────────────────────────────────────
|
||||||
|
|
@ -1594,6 +1619,7 @@ impl Default for Config {
|
||||||
workspace_dir: zeroclaw_dir.join("workspace"),
|
workspace_dir: zeroclaw_dir.join("workspace"),
|
||||||
config_path: zeroclaw_dir.join("config.toml"),
|
config_path: zeroclaw_dir.join("config.toml"),
|
||||||
api_key: None,
|
api_key: None,
|
||||||
|
api_url: None,
|
||||||
default_provider: Some("openrouter".to_string()),
|
default_provider: Some("openrouter".to_string()),
|
||||||
default_model: Some("anthropic/claude-sonnet-4".to_string()),
|
default_model: Some("anthropic/claude-sonnet-4".to_string()),
|
||||||
default_temperature: 0.7,
|
default_temperature: 0.7,
|
||||||
|
|
@ -1623,35 +1649,146 @@ impl Default for Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Config {
|
fn default_config_and_workspace_dirs() -> Result<(PathBuf, PathBuf)> {
|
||||||
pub fn load_or_init() -> Result<Self> {
|
let home = UserDirs::new()
|
||||||
let home = UserDirs::new()
|
.map(|u| u.home_dir().to_path_buf())
|
||||||
.map(|u| u.home_dir().to_path_buf())
|
.context("Could not find home directory")?;
|
||||||
.context("Could not find home directory")?;
|
let config_dir = home.join(".zeroclaw");
|
||||||
let zeroclaw_dir = home.join(".zeroclaw");
|
Ok((config_dir.clone(), config_dir.join("workspace")))
|
||||||
let config_path = zeroclaw_dir.join("config.toml");
|
}
|
||||||
|
|
||||||
if !zeroclaw_dir.exists() {
|
fn resolve_config_dir_for_workspace(workspace_dir: &Path) -> PathBuf {
|
||||||
fs::create_dir_all(&zeroclaw_dir).context("Failed to create .zeroclaw directory")?;
|
let workspace_config_dir = workspace_dir.to_path_buf();
|
||||||
fs::create_dir_all(zeroclaw_dir.join("workspace"))
|
if workspace_config_dir.join("config.toml").exists() {
|
||||||
.context("Failed to create workspace directory")?;
|
return workspace_config_dir;
|
||||||
|
}
|
||||||
|
|
||||||
|
let legacy_config_dir = workspace_dir
|
||||||
|
.parent()
|
||||||
|
.map(|parent| parent.join(".zeroclaw"));
|
||||||
|
if let Some(legacy_dir) = legacy_config_dir {
|
||||||
|
if legacy_dir.join("config.toml").exists() {
|
||||||
|
return legacy_dir;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if workspace_dir
|
||||||
|
.file_name()
|
||||||
|
.is_some_and(|name| name == std::ffi::OsStr::new("workspace"))
|
||||||
|
{
|
||||||
|
return legacy_dir;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
workspace_config_dir
|
||||||
|
}
|
||||||
|
|
||||||
|
fn decrypt_optional_secret(
|
||||||
|
store: &crate::security::SecretStore,
|
||||||
|
value: &mut Option<String>,
|
||||||
|
field_name: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
if let Some(raw) = value.clone() {
|
||||||
|
if crate::security::SecretStore::is_encrypted(&raw) {
|
||||||
|
*value = Some(
|
||||||
|
store
|
||||||
|
.decrypt(&raw)
|
||||||
|
.with_context(|| format!("Failed to decrypt {field_name}"))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn encrypt_optional_secret(
|
||||||
|
store: &crate::security::SecretStore,
|
||||||
|
value: &mut Option<String>,
|
||||||
|
field_name: &str,
|
||||||
|
) -> Result<()> {
|
||||||
|
if let Some(raw) = value.clone() {
|
||||||
|
if !crate::security::SecretStore::is_encrypted(&raw) {
|
||||||
|
*value = Some(
|
||||||
|
store
|
||||||
|
.encrypt(&raw)
|
||||||
|
.with_context(|| format!("Failed to encrypt {field_name}"))?,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Config {
|
||||||
|
pub fn load_or_init() -> Result<Self> {
|
||||||
|
// Resolve workspace first so config loading can follow ZEROCLAW_WORKSPACE.
|
||||||
|
let (zeroclaw_dir, workspace_dir) = match std::env::var("ZEROCLAW_WORKSPACE") {
|
||||||
|
Ok(custom_workspace) if !custom_workspace.is_empty() => {
|
||||||
|
let workspace = PathBuf::from(custom_workspace);
|
||||||
|
(resolve_config_dir_for_workspace(&workspace), workspace)
|
||||||
|
}
|
||||||
|
_ => default_config_and_workspace_dirs()?,
|
||||||
|
};
|
||||||
|
|
||||||
|
let config_path = zeroclaw_dir.join("config.toml");
|
||||||
|
|
||||||
|
fs::create_dir_all(&zeroclaw_dir).context("Failed to create config directory")?;
|
||||||
|
fs::create_dir_all(&workspace_dir).context("Failed to create workspace directory")?;
|
||||||
|
|
||||||
if config_path.exists() {
|
if config_path.exists() {
|
||||||
|
// Warn if config file is world-readable (may contain API keys)
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
if let Ok(meta) = fs::metadata(&config_path) {
|
||||||
|
if meta.permissions().mode() & 0o004 != 0 {
|
||||||
|
tracing::warn!(
|
||||||
|
"Config file {:?} is world-readable (mode {:o}). \
|
||||||
|
Consider restricting with: chmod 600 {:?}",
|
||||||
|
config_path,
|
||||||
|
meta.permissions().mode() & 0o777,
|
||||||
|
config_path,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let contents =
|
let contents =
|
||||||
fs::read_to_string(&config_path).context("Failed to read config file")?;
|
fs::read_to_string(&config_path).context("Failed to read config file")?;
|
||||||
let mut config: Config =
|
let mut config: Config =
|
||||||
toml::from_str(&contents).context("Failed to parse config file")?;
|
toml::from_str(&contents).context("Failed to parse config file")?;
|
||||||
// Set computed paths that are skipped during serialization
|
// Set computed paths that are skipped during serialization
|
||||||
config.config_path = config_path.clone();
|
config.config_path = config_path.clone();
|
||||||
config.workspace_dir = zeroclaw_dir.join("workspace");
|
config.workspace_dir = workspace_dir;
|
||||||
|
let store = crate::security::SecretStore::new(&zeroclaw_dir, config.secrets.encrypt);
|
||||||
|
decrypt_optional_secret(&store, &mut config.api_key, "config.api_key")?;
|
||||||
|
decrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config.composio.api_key,
|
||||||
|
"config.composio.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
decrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config.browser.computer_use.api_key,
|
||||||
|
"config.browser.computer_use.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
for agent in config.agents.values_mut() {
|
||||||
|
decrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||||
|
}
|
||||||
config.apply_env_overrides();
|
config.apply_env_overrides();
|
||||||
Ok(config)
|
Ok(config)
|
||||||
} else {
|
} else {
|
||||||
let mut config = Config::default();
|
let mut config = Config::default();
|
||||||
config.config_path = config_path.clone();
|
config.config_path = config_path.clone();
|
||||||
config.workspace_dir = zeroclaw_dir.join("workspace");
|
config.workspace_dir = workspace_dir;
|
||||||
config.save()?;
|
config.save()?;
|
||||||
|
|
||||||
|
// Restrict permissions on newly created config file (may contain API keys)
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
let _ = fs::set_permissions(&config_path, fs::Permissions::from_mode(0o600));
|
||||||
|
}
|
||||||
|
|
||||||
config.apply_env_overrides();
|
config.apply_env_overrides();
|
||||||
Ok(config)
|
Ok(config)
|
||||||
}
|
}
|
||||||
|
|
@ -1732,23 +1869,29 @@ impl Config {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn save(&self) -> Result<()> {
|
pub fn save(&self) -> Result<()> {
|
||||||
// Encrypt agent API keys before serialization
|
// Encrypt secrets before serialization
|
||||||
let mut config_to_save = self.clone();
|
let mut config_to_save = self.clone();
|
||||||
let zeroclaw_dir = self
|
let zeroclaw_dir = self
|
||||||
.config_path
|
.config_path
|
||||||
.parent()
|
.parent()
|
||||||
.context("Config path must have a parent directory")?;
|
.context("Config path must have a parent directory")?;
|
||||||
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
let store = crate::security::SecretStore::new(zeroclaw_dir, self.secrets.encrypt);
|
||||||
|
|
||||||
|
encrypt_optional_secret(&store, &mut config_to_save.api_key, "config.api_key")?;
|
||||||
|
encrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config_to_save.composio.api_key,
|
||||||
|
"config.composio.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
|
encrypt_optional_secret(
|
||||||
|
&store,
|
||||||
|
&mut config_to_save.browser.computer_use.api_key,
|
||||||
|
"config.browser.computer_use.api_key",
|
||||||
|
)?;
|
||||||
|
|
||||||
for agent in config_to_save.agents.values_mut() {
|
for agent in config_to_save.agents.values_mut() {
|
||||||
if let Some(ref plaintext_key) = agent.api_key {
|
encrypt_optional_secret(&store, &mut agent.api_key, "config.agents.*.api_key")?;
|
||||||
if !crate::security::SecretStore::is_encrypted(plaintext_key) {
|
|
||||||
agent.api_key = Some(
|
|
||||||
store
|
|
||||||
.encrypt(plaintext_key)
|
|
||||||
.context("Failed to encrypt agent API key")?,
|
|
||||||
);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let toml_str =
|
let toml_str =
|
||||||
|
|
@ -1949,6 +2092,7 @@ default_temperature = 0.7
|
||||||
workspace_dir: PathBuf::from("/tmp/test/workspace"),
|
workspace_dir: PathBuf::from("/tmp/test/workspace"),
|
||||||
config_path: PathBuf::from("/tmp/test/config.toml"),
|
config_path: PathBuf::from("/tmp/test/config.toml"),
|
||||||
api_key: Some("sk-test-key".into()),
|
api_key: Some("sk-test-key".into()),
|
||||||
|
api_url: None,
|
||||||
default_provider: Some("openrouter".into()),
|
default_provider: Some("openrouter".into()),
|
||||||
default_model: Some("gpt-4o".into()),
|
default_model: Some("gpt-4o".into()),
|
||||||
default_temperature: 0.5,
|
default_temperature: 0.5,
|
||||||
|
|
@ -2091,6 +2235,7 @@ tool_dispatcher = "xml"
|
||||||
workspace_dir: dir.join("workspace"),
|
workspace_dir: dir.join("workspace"),
|
||||||
config_path: config_path.clone(),
|
config_path: config_path.clone(),
|
||||||
api_key: Some("sk-roundtrip".into()),
|
api_key: Some("sk-roundtrip".into()),
|
||||||
|
api_url: None,
|
||||||
default_provider: Some("openrouter".into()),
|
default_provider: Some("openrouter".into()),
|
||||||
default_model: Some("test-model".into()),
|
default_model: Some("test-model".into()),
|
||||||
default_temperature: 0.9,
|
default_temperature: 0.9,
|
||||||
|
|
@ -2123,13 +2268,82 @@ tool_dispatcher = "xml"
|
||||||
|
|
||||||
let contents = fs::read_to_string(&config_path).unwrap();
|
let contents = fs::read_to_string(&config_path).unwrap();
|
||||||
let loaded: Config = toml::from_str(&contents).unwrap();
|
let loaded: Config = toml::from_str(&contents).unwrap();
|
||||||
assert_eq!(loaded.api_key.as_deref(), Some("sk-roundtrip"));
|
assert!(loaded
|
||||||
|
.api_key
|
||||||
|
.as_deref()
|
||||||
|
.is_some_and(crate::security::SecretStore::is_encrypted));
|
||||||
|
let store = crate::security::SecretStore::new(&dir, true);
|
||||||
|
let decrypted = store.decrypt(loaded.api_key.as_deref().unwrap()).unwrap();
|
||||||
|
assert_eq!(decrypted, "sk-roundtrip");
|
||||||
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
|
assert_eq!(loaded.default_model.as_deref(), Some("test-model"));
|
||||||
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
|
assert!((loaded.default_temperature - 0.9).abs() < f64::EPSILON);
|
||||||
|
|
||||||
let _ = fs::remove_dir_all(&dir);
|
let _ = fs::remove_dir_all(&dir);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn config_save_encrypts_nested_credentials() {
|
||||||
|
let dir = std::env::temp_dir().join(format!(
|
||||||
|
"zeroclaw_test_nested_credentials_{}",
|
||||||
|
uuid::Uuid::new_v4()
|
||||||
|
));
|
||||||
|
fs::create_dir_all(&dir).unwrap();
|
||||||
|
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.workspace_dir = dir.join("workspace");
|
||||||
|
config.config_path = dir.join("config.toml");
|
||||||
|
config.api_key = Some("root-credential".into());
|
||||||
|
config.composio.api_key = Some("composio-credential".into());
|
||||||
|
config.browser.computer_use.api_key = Some("browser-credential".into());
|
||||||
|
|
||||||
|
config.agents.insert(
|
||||||
|
"worker".into(),
|
||||||
|
DelegateAgentConfig {
|
||||||
|
provider: "openrouter".into(),
|
||||||
|
model: "model-test".into(),
|
||||||
|
system_prompt: None,
|
||||||
|
api_key: Some("agent-credential".into()),
|
||||||
|
temperature: None,
|
||||||
|
max_depth: 3,
|
||||||
|
},
|
||||||
|
);
|
||||||
|
|
||||||
|
config.save().unwrap();
|
||||||
|
|
||||||
|
let contents = fs::read_to_string(config.config_path.clone()).unwrap();
|
||||||
|
let stored: Config = toml::from_str(&contents).unwrap();
|
||||||
|
let store = crate::security::SecretStore::new(&dir, true);
|
||||||
|
|
||||||
|
let root_encrypted = stored.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(root_encrypted));
|
||||||
|
assert_eq!(store.decrypt(root_encrypted).unwrap(), "root-credential");
|
||||||
|
|
||||||
|
let composio_encrypted = stored.composio.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(
|
||||||
|
composio_encrypted
|
||||||
|
));
|
||||||
|
assert_eq!(
|
||||||
|
store.decrypt(composio_encrypted).unwrap(),
|
||||||
|
"composio-credential"
|
||||||
|
);
|
||||||
|
|
||||||
|
let browser_encrypted = stored.browser.computer_use.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(
|
||||||
|
browser_encrypted
|
||||||
|
));
|
||||||
|
assert_eq!(
|
||||||
|
store.decrypt(browser_encrypted).unwrap(),
|
||||||
|
"browser-credential"
|
||||||
|
);
|
||||||
|
|
||||||
|
let worker = stored.agents.get("worker").unwrap();
|
||||||
|
let worker_encrypted = worker.api_key.as_deref().unwrap();
|
||||||
|
assert!(crate::security::SecretStore::is_encrypted(worker_encrypted));
|
||||||
|
assert_eq!(store.decrypt(worker_encrypted).unwrap(), "agent-credential");
|
||||||
|
|
||||||
|
let _ = fs::remove_dir_all(&dir);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn config_save_atomic_cleanup() {
|
fn config_save_atomic_cleanup() {
|
||||||
let dir =
|
let dir =
|
||||||
|
|
@ -2182,6 +2396,7 @@ tool_dispatcher = "xml"
|
||||||
guild_id: Some("12345".into()),
|
guild_id: Some("12345".into()),
|
||||||
allowed_users: vec![],
|
allowed_users: vec![],
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&dc).unwrap();
|
let json = serde_json::to_string(&dc).unwrap();
|
||||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
@ -2196,6 +2411,7 @@ tool_dispatcher = "xml"
|
||||||
guild_id: None,
|
guild_id: None,
|
||||||
allowed_users: vec![],
|
allowed_users: vec![],
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
};
|
};
|
||||||
let json = serde_json::to_string(&dc).unwrap();
|
let json = serde_json::to_string(&dc).unwrap();
|
||||||
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
let parsed: DiscordConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
|
@ -2818,6 +3034,96 @@ default_temperature = 0.7
|
||||||
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_or_init_workspace_override_uses_workspace_root_for_config() {
|
||||||
|
let _env_guard = env_override_test_guard();
|
||||||
|
let temp_home =
|
||||||
|
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||||
|
let workspace_dir = temp_home.join("profile-a");
|
||||||
|
|
||||||
|
let original_home = std::env::var("HOME").ok();
|
||||||
|
std::env::set_var("HOME", &temp_home);
|
||||||
|
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||||
|
|
||||||
|
let config = Config::load_or_init().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.workspace_dir, workspace_dir);
|
||||||
|
assert_eq!(config.config_path, workspace_dir.join("config.toml"));
|
||||||
|
assert!(workspace_dir.join("config.toml").exists());
|
||||||
|
|
||||||
|
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||||
|
if let Some(home) = original_home {
|
||||||
|
std::env::set_var("HOME", home);
|
||||||
|
} else {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
}
|
||||||
|
let _ = fs::remove_dir_all(temp_home);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_or_init_workspace_suffix_uses_legacy_config_layout() {
|
||||||
|
let _env_guard = env_override_test_guard();
|
||||||
|
let temp_home =
|
||||||
|
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||||
|
let workspace_dir = temp_home.join("workspace");
|
||||||
|
let legacy_config_path = temp_home.join(".zeroclaw").join("config.toml");
|
||||||
|
|
||||||
|
let original_home = std::env::var("HOME").ok();
|
||||||
|
std::env::set_var("HOME", &temp_home);
|
||||||
|
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||||
|
|
||||||
|
let config = Config::load_or_init().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.workspace_dir, workspace_dir);
|
||||||
|
assert_eq!(config.config_path, legacy_config_path);
|
||||||
|
assert!(config.config_path.exists());
|
||||||
|
|
||||||
|
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||||
|
if let Some(home) = original_home {
|
||||||
|
std::env::set_var("HOME", home);
|
||||||
|
} else {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
}
|
||||||
|
let _ = fs::remove_dir_all(temp_home);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn load_or_init_workspace_override_keeps_existing_legacy_config() {
|
||||||
|
let _env_guard = env_override_test_guard();
|
||||||
|
let temp_home =
|
||||||
|
std::env::temp_dir().join(format!("zeroclaw_test_home_{}", uuid::Uuid::new_v4()));
|
||||||
|
let workspace_dir = temp_home.join("custom-workspace");
|
||||||
|
let legacy_config_dir = temp_home.join(".zeroclaw");
|
||||||
|
let legacy_config_path = legacy_config_dir.join("config.toml");
|
||||||
|
|
||||||
|
fs::create_dir_all(&legacy_config_dir).unwrap();
|
||||||
|
fs::write(
|
||||||
|
&legacy_config_path,
|
||||||
|
r#"default_temperature = 0.7
|
||||||
|
default_model = "legacy-model"
|
||||||
|
"#,
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let original_home = std::env::var("HOME").ok();
|
||||||
|
std::env::set_var("HOME", &temp_home);
|
||||||
|
std::env::set_var("ZEROCLAW_WORKSPACE", &workspace_dir);
|
||||||
|
|
||||||
|
let config = Config::load_or_init().unwrap();
|
||||||
|
|
||||||
|
assert_eq!(config.workspace_dir, workspace_dir);
|
||||||
|
assert_eq!(config.config_path, legacy_config_path);
|
||||||
|
assert_eq!(config.default_model.as_deref(), Some("legacy-model"));
|
||||||
|
|
||||||
|
std::env::remove_var("ZEROCLAW_WORKSPACE");
|
||||||
|
if let Some(home) = original_home {
|
||||||
|
std::env::set_var("HOME", home);
|
||||||
|
} else {
|
||||||
|
std::env::remove_var("HOME");
|
||||||
|
}
|
||||||
|
let _ = fs::remove_dir_all(temp_home);
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn env_override_empty_values_ignored() {
|
fn env_override_empty_values_ignored() {
|
||||||
let _env_guard = env_override_test_guard();
|
let _env_guard = env_override_test_guard();
|
||||||
|
|
@ -2975,4 +3281,118 @@ default_temperature = 0.7
|
||||||
assert_eq!(parsed.boards[0].board, "nucleo-f401re");
|
assert_eq!(parsed.boards[0].board, "nucleo-f401re");
|
||||||
assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0"));
|
assert_eq!(parsed.boards[0].path.as_deref(), Some("/dev/ttyACM0"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lark_config_serde() {
|
||||||
|
let lc = LarkConfig {
|
||||||
|
app_id: "cli_123456".into(),
|
||||||
|
app_secret: "secret_abc".into(),
|
||||||
|
encrypt_key: Some("encrypt_key".into()),
|
||||||
|
verification_token: Some("verify_token".into()),
|
||||||
|
allowed_users: vec!["user_123".into(), "user_456".into()],
|
||||||
|
use_feishu: true,
|
||||||
|
receive_mode: LarkReceiveMode::Websocket,
|
||||||
|
port: None,
|
||||||
|
};
|
||||||
|
let json = serde_json::to_string(&lc).unwrap();
|
||||||
|
let parsed: LarkConfig = serde_json::from_str(&json).unwrap();
|
||||||
|
assert_eq!(parsed.app_id, "cli_123456");
|
||||||
|
assert_eq!(parsed.app_secret, "secret_abc");
|
||||||
|
assert_eq!(parsed.encrypt_key.as_deref(), Some("encrypt_key"));
|
||||||
|
assert_eq!(parsed.verification_token.as_deref(), Some("verify_token"));
|
||||||
|
assert_eq!(parsed.allowed_users.len(), 2);
|
||||||
|
assert!(parsed.use_feishu);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lark_config_toml_roundtrip() {
|
||||||
|
let lc = LarkConfig {
|
||||||
|
app_id: "cli_123456".into(),
|
||||||
|
app_secret: "secret_abc".into(),
|
||||||
|
encrypt_key: Some("encrypt_key".into()),
|
||||||
|
verification_token: Some("verify_token".into()),
|
||||||
|
allowed_users: vec!["*".into()],
|
||||||
|
use_feishu: false,
|
||||||
|
receive_mode: LarkReceiveMode::Webhook,
|
||||||
|
port: Some(9898),
|
||||||
|
};
|
||||||
|
let toml_str = toml::to_string(&lc).unwrap();
|
||||||
|
let parsed: LarkConfig = toml::from_str(&toml_str).unwrap();
|
||||||
|
assert_eq!(parsed.app_id, "cli_123456");
|
||||||
|
assert_eq!(parsed.app_secret, "secret_abc");
|
||||||
|
assert!(!parsed.use_feishu);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lark_config_deserializes_without_optional_fields() {
|
||||||
|
let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
|
||||||
|
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(parsed.encrypt_key.is_none());
|
||||||
|
assert!(parsed.verification_token.is_none());
|
||||||
|
assert!(parsed.allowed_users.is_empty());
|
||||||
|
assert!(!parsed.use_feishu);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lark_config_defaults_to_lark_endpoint() {
|
||||||
|
let json = r#"{"app_id":"cli_123","app_secret":"secret"}"#;
|
||||||
|
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(
|
||||||
|
!parsed.use_feishu,
|
||||||
|
"use_feishu should default to false (Lark)"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn lark_config_with_wildcard_allowed_users() {
|
||||||
|
let json = r#"{"app_id":"cli_123","app_secret":"secret","allowed_users":["*"]}"#;
|
||||||
|
let parsed: LarkConfig = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(parsed.allowed_users, vec!["*"]);
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Config file permission hardening (Unix only) ───────────────
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[test]
|
||||||
|
fn new_config_file_has_restricted_permissions() {
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let config_path = tmp.path().join("config.toml");
|
||||||
|
|
||||||
|
// Create a config and save it
|
||||||
|
let mut config = Config::default();
|
||||||
|
config.config_path = config_path.clone();
|
||||||
|
config.save().unwrap();
|
||||||
|
|
||||||
|
// Apply the same permission logic as load_or_init
|
||||||
|
let _ = std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o600));
|
||||||
|
|
||||||
|
let meta = std::fs::metadata(&config_path).unwrap();
|
||||||
|
let mode = meta.permissions().mode() & 0o777;
|
||||||
|
assert_eq!(
|
||||||
|
mode, 0o600,
|
||||||
|
"New config file should be owner-only (0600), got {mode:o}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(unix)]
|
||||||
|
#[test]
|
||||||
|
fn world_readable_config_is_detectable() {
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
|
||||||
|
let tmp = tempfile::TempDir::new().unwrap();
|
||||||
|
let config_path = tmp.path().join("config.toml");
|
||||||
|
|
||||||
|
// Create a config file with intentionally loose permissions
|
||||||
|
std::fs::write(&config_path, "# test config").unwrap();
|
||||||
|
std::fs::set_permissions(&config_path, std::fs::Permissions::from_mode(0o644)).unwrap();
|
||||||
|
|
||||||
|
let meta = std::fs::metadata(&config_path).unwrap();
|
||||||
|
let mode = meta.permissions().mode();
|
||||||
|
assert!(
|
||||||
|
mode & 0o004 != 0,
|
||||||
|
"Test setup: file should be world-readable (mode {mode:o})"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -245,6 +245,7 @@ async fn deliver_if_configured(config: &Config, job: &CronJob, output: &str) ->
|
||||||
dc.guild_id.clone(),
|
dc.guild_id.clone(),
|
||||||
dc.allowed_users.clone(),
|
dc.allowed_users.clone(),
|
||||||
dc.listen_to_bots,
|
dc.listen_to_bots,
|
||||||
|
dc.mention_only,
|
||||||
);
|
);
|
||||||
channel.send(output, target).await?;
|
channel.send(output, target).await?;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -216,6 +216,7 @@ fn has_supervised_channels(config: &Config) -> bool {
|
||||||
|| config.channels_config.matrix.is_some()
|
|| config.channels_config.matrix.is_some()
|
||||||
|| config.channels_config.whatsapp.is_some()
|
|| config.channels_config.whatsapp.is_some()
|
||||||
|| config.channels_config.email.is_some()
|
|| config.channels_config.email.is_some()
|
||||||
|
|| config.channels_config.lark.is_some()
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,13 @@ fn whatsapp_memory_key(msg: &crate::channels::traits::ChannelMessage) -> String
|
||||||
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
format!("whatsapp_{}_{}", msg.sender, msg.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn hash_webhook_secret(value: &str) -> String {
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
|
||||||
|
let digest = Sha256::digest(value.as_bytes());
|
||||||
|
hex::encode(digest)
|
||||||
|
}
|
||||||
|
|
||||||
/// How often the rate limiter sweeps stale IP entries from its map.
|
/// How often the rate limiter sweeps stale IP entries from its map.
|
||||||
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
const RATE_LIMITER_SWEEP_INTERVAL_SECS: u64 = 300; // 5 minutes
|
||||||
|
|
||||||
|
|
@ -178,7 +185,8 @@ pub struct AppState {
|
||||||
pub temperature: f64,
|
pub temperature: f64,
|
||||||
pub mem: Arc<dyn Memory>,
|
pub mem: Arc<dyn Memory>,
|
||||||
pub auto_save: bool,
|
pub auto_save: bool,
|
||||||
pub webhook_secret: Option<Arc<str>>,
|
/// SHA-256 hash of `X-Webhook-Secret` (hex-encoded), never plaintext.
|
||||||
|
pub webhook_secret_hash: Option<Arc<str>>,
|
||||||
pub pairing: Arc<PairingGuard>,
|
pub pairing: Arc<PairingGuard>,
|
||||||
pub rate_limiter: Arc<GatewayRateLimiter>,
|
pub rate_limiter: Arc<GatewayRateLimiter>,
|
||||||
pub idempotency_store: Arc<IdempotencyStore>,
|
pub idempotency_store: Arc<IdempotencyStore>,
|
||||||
|
|
@ -208,6 +216,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
let provider: Arc<dyn Provider> = Arc::from(providers::create_resilient_provider(
|
||||||
config.default_provider.as_deref().unwrap_or("openrouter"),
|
config.default_provider.as_deref().unwrap_or("openrouter"),
|
||||||
config.api_key.as_deref(),
|
config.api_key.as_deref(),
|
||||||
|
config.api_url.as_deref(),
|
||||||
&config.reliability,
|
&config.reliability,
|
||||||
)?);
|
)?);
|
||||||
let model = config
|
let model = config
|
||||||
|
|
@ -251,12 +260,14 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
&config,
|
&config,
|
||||||
));
|
));
|
||||||
// Extract webhook secret for authentication
|
// Extract webhook secret for authentication
|
||||||
let webhook_secret: Option<Arc<str>> = config
|
let webhook_secret_hash: Option<Arc<str>> =
|
||||||
.channels_config
|
config.channels_config.webhook.as_ref().and_then(|webhook| {
|
||||||
.webhook
|
webhook.secret.as_ref().and_then(|raw_secret| {
|
||||||
.as_ref()
|
let trimmed_secret = raw_secret.trim();
|
||||||
.and_then(|w| w.secret.as_deref())
|
(!trimmed_secret.is_empty())
|
||||||
.map(Arc::from);
|
.then(|| Arc::<str>::from(hash_webhook_secret(trimmed_secret)))
|
||||||
|
})
|
||||||
|
});
|
||||||
|
|
||||||
// WhatsApp channel (if configured)
|
// WhatsApp channel (if configured)
|
||||||
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
|
let whatsapp_channel: Option<Arc<WhatsAppChannel>> =
|
||||||
|
|
@ -342,9 +353,6 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
} else {
|
} else {
|
||||||
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
println!(" ⚠️ Pairing: DISABLED (all requests accepted)");
|
||||||
}
|
}
|
||||||
if webhook_secret.is_some() {
|
|
||||||
println!(" 🔒 Webhook secret: ENABLED");
|
|
||||||
}
|
|
||||||
println!(" Press Ctrl+C to stop.\n");
|
println!(" Press Ctrl+C to stop.\n");
|
||||||
|
|
||||||
crate::health::mark_component_ok("gateway");
|
crate::health::mark_component_ok("gateway");
|
||||||
|
|
@ -356,7 +364,7 @@ pub async fn run_gateway(host: &str, port: u16, config: Config) -> Result<()> {
|
||||||
temperature,
|
temperature,
|
||||||
mem,
|
mem,
|
||||||
auto_save: config.memory.auto_save,
|
auto_save: config.memory.auto_save,
|
||||||
webhook_secret,
|
webhook_secret_hash,
|
||||||
pairing,
|
pairing,
|
||||||
rate_limiter,
|
rate_limiter,
|
||||||
idempotency_store,
|
idempotency_store,
|
||||||
|
|
@ -482,12 +490,15 @@ async fn handle_webhook(
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Webhook secret auth (optional, additional layer) ──
|
// ── Webhook secret auth (optional, additional layer) ──
|
||||||
if let Some(ref secret) = state.webhook_secret {
|
if let Some(ref secret_hash) = state.webhook_secret_hash {
|
||||||
let header_val = headers
|
let header_hash = headers
|
||||||
.get("X-Webhook-Secret")
|
.get("X-Webhook-Secret")
|
||||||
.and_then(|v| v.to_str().ok());
|
.and_then(|v| v.to_str().ok())
|
||||||
match header_val {
|
.map(str::trim)
|
||||||
Some(val) if constant_time_eq(val, secret.as_ref()) => {}
|
.filter(|value| !value.is_empty())
|
||||||
|
.map(hash_webhook_secret);
|
||||||
|
match header_hash {
|
||||||
|
Some(val) if constant_time_eq(&val, secret_hash.as_ref()) => {}
|
||||||
_ => {
|
_ => {
|
||||||
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
tracing::warn!("Webhook: rejected request — invalid or missing X-Webhook-Secret");
|
||||||
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
let err = serde_json::json!({"error": "Unauthorized — invalid or missing X-Webhook-Secret header"});
|
||||||
|
|
@ -532,7 +543,7 @@ async fn handle_webhook(
|
||||||
let key = webhook_memory_key();
|
let key = webhook_memory_key();
|
||||||
let _ = state
|
let _ = state
|
||||||
.mem
|
.mem
|
||||||
.store(&key, message, MemoryCategory::Conversation)
|
.store(&key, message, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -685,7 +696,7 @@ async fn handle_whatsapp_message(
|
||||||
let key = whatsapp_memory_key(msg);
|
let key = whatsapp_memory_key(msg);
|
||||||
let _ = state
|
let _ = state
|
||||||
.mem
|
.mem
|
||||||
.store(&key, &msg.content, MemoryCategory::Conversation)
|
.store(&key, &msg.content, MemoryCategory::Conversation, None)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -697,7 +708,7 @@ async fn handle_whatsapp_message(
|
||||||
{
|
{
|
||||||
Ok(response) => {
|
Ok(response) => {
|
||||||
// Send reply via WhatsApp
|
// Send reply via WhatsApp
|
||||||
if let Err(e) = wa.send(&response, &msg.sender).await {
|
if let Err(e) = wa.send(&response, &msg.reply_target).await {
|
||||||
tracing::error!("Failed to send WhatsApp reply: {e}");
|
tracing::error!("Failed to send WhatsApp reply: {e}");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -706,7 +717,7 @@ async fn handle_whatsapp_message(
|
||||||
let _ = wa
|
let _ = wa
|
||||||
.send(
|
.send(
|
||||||
"Sorry, I couldn't process your message right now.",
|
"Sorry, I couldn't process your message right now.",
|
||||||
&msg.sender,
|
&msg.reply_target,
|
||||||
)
|
)
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
|
|
@ -798,7 +809,9 @@ mod tests {
|
||||||
.requests
|
.requests
|
||||||
.lock()
|
.lock()
|
||||||
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
.unwrap_or_else(std::sync::PoisonError::into_inner);
|
||||||
guard.1 = Instant::now() - Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1);
|
guard.1 = Instant::now()
|
||||||
|
.checked_sub(Duration::from_secs(RATE_LIMITER_SWEEP_INTERVAL_SECS + 1))
|
||||||
|
.unwrap();
|
||||||
// Clear timestamps for ip-2 and ip-3 to simulate stale entries
|
// Clear timestamps for ip-2 and ip-3 to simulate stale entries
|
||||||
guard.0.get_mut("ip-2").unwrap().clear();
|
guard.0.get_mut("ip-2").unwrap().clear();
|
||||||
guard.0.get_mut("ip-3").unwrap().clear();
|
guard.0.get_mut("ip-3").unwrap().clear();
|
||||||
|
|
@ -848,6 +861,7 @@ mod tests {
|
||||||
let msg = ChannelMessage {
|
let msg = ChannelMessage {
|
||||||
id: "wamid-123".into(),
|
id: "wamid-123".into(),
|
||||||
sender: "+1234567890".into(),
|
sender: "+1234567890".into(),
|
||||||
|
reply_target: "+1234567890".into(),
|
||||||
content: "hello".into(),
|
content: "hello".into(),
|
||||||
channel: "whatsapp".into(),
|
channel: "whatsapp".into(),
|
||||||
timestamp: 1,
|
timestamp: 1,
|
||||||
|
|
@ -871,11 +885,17 @@ mod tests {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -886,6 +906,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&MemoryCategory>,
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
@ -938,6 +959,7 @@ mod tests {
|
||||||
key: &str,
|
key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
self.keys
|
self.keys
|
||||||
.lock()
|
.lock()
|
||||||
|
|
@ -946,7 +968,12 @@ mod tests {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -957,6 +984,7 @@ mod tests {
|
||||||
async fn list(
|
async fn list(
|
||||||
&self,
|
&self,
|
||||||
_category: Option<&MemoryCategory>,
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<Vec<MemoryEntry>> {
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
@ -991,7 +1019,7 @@ mod tests {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
mem: memory,
|
mem: memory,
|
||||||
auto_save: false,
|
auto_save: false,
|
||||||
webhook_secret: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
|
@ -1039,7 +1067,7 @@ mod tests {
|
||||||
temperature: 0.0,
|
temperature: 0.0,
|
||||||
mem: memory,
|
mem: memory,
|
||||||
auto_save: true,
|
auto_save: true,
|
||||||
webhook_secret: None,
|
webhook_secret_hash: None,
|
||||||
pairing: Arc::new(PairingGuard::new(false, &[])),
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
|
@ -1077,6 +1105,125 @@ mod tests {
|
||||||
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn webhook_secret_hash_is_deterministic_and_nonempty() {
|
||||||
|
let one = hash_webhook_secret("secret-value");
|
||||||
|
let two = hash_webhook_secret("secret-value");
|
||||||
|
let other = hash_webhook_secret("other-value");
|
||||||
|
|
||||||
|
assert_eq!(one, two);
|
||||||
|
assert_ne!(one, other);
|
||||||
|
assert_eq!(one.len(), 64);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_secret_hash_rejects_missing_header() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let response = handle_webhook(
|
||||||
|
State(state),
|
||||||
|
HeaderMap::new(),
|
||||||
|
Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_secret_hash_rejects_invalid_header() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Webhook-Secret", HeaderValue::from_static("wrong-secret"));
|
||||||
|
|
||||||
|
let response = handle_webhook(
|
||||||
|
State(state),
|
||||||
|
headers,
|
||||||
|
Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::UNAUTHORIZED);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 0);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn webhook_secret_hash_accepts_valid_header() {
|
||||||
|
let provider_impl = Arc::new(MockProvider::default());
|
||||||
|
let provider: Arc<dyn Provider> = provider_impl.clone();
|
||||||
|
let memory: Arc<dyn Memory> = Arc::new(MockMemory);
|
||||||
|
|
||||||
|
let state = AppState {
|
||||||
|
provider,
|
||||||
|
model: "test-model".into(),
|
||||||
|
temperature: 0.0,
|
||||||
|
mem: memory,
|
||||||
|
auto_save: false,
|
||||||
|
webhook_secret_hash: Some(Arc::from(hash_webhook_secret("super-secret"))),
|
||||||
|
pairing: Arc::new(PairingGuard::new(false, &[])),
|
||||||
|
rate_limiter: Arc::new(GatewayRateLimiter::new(100, 100)),
|
||||||
|
idempotency_store: Arc::new(IdempotencyStore::new(Duration::from_secs(300))),
|
||||||
|
whatsapp: None,
|
||||||
|
whatsapp_app_secret: None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut headers = HeaderMap::new();
|
||||||
|
headers.insert("X-Webhook-Secret", HeaderValue::from_static("super-secret"));
|
||||||
|
|
||||||
|
let response = handle_webhook(
|
||||||
|
State(state),
|
||||||
|
headers,
|
||||||
|
Ok(Json(WebhookBody {
|
||||||
|
message: "hello".into(),
|
||||||
|
})),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.into_response();
|
||||||
|
|
||||||
|
assert_eq!(response.status(), StatusCode::OK);
|
||||||
|
assert_eq!(provider_impl.calls.load(Ordering::SeqCst), 1);
|
||||||
|
}
|
||||||
|
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
// WhatsApp Signature Verification Tests (CWE-345 Prevention)
|
||||||
// ══════════════════════════════════════════════════════════
|
// ══════════════════════════════════════════════════════════
|
||||||
|
|
|
||||||
40
src/main.rs
40
src/main.rs
|
|
@ -34,8 +34,8 @@
|
||||||
|
|
||||||
use anyhow::{bail, Result};
|
use anyhow::{bail, Result};
|
||||||
use clap::{Parser, Subcommand};
|
use clap::{Parser, Subcommand};
|
||||||
use tracing::{info, Level};
|
use tracing::info;
|
||||||
use tracing_subscriber::FmtSubscriber;
|
use tracing_subscriber::{fmt, EnvFilter};
|
||||||
|
|
||||||
mod agent;
|
mod agent;
|
||||||
mod channels;
|
mod channels;
|
||||||
|
|
@ -147,24 +147,24 @@ enum Commands {
|
||||||
|
|
||||||
/// Start the gateway server (webhooks, websockets)
|
/// Start the gateway server (webhooks, websockets)
|
||||||
Gateway {
|
Gateway {
|
||||||
/// Port to listen on (use 0 for random available port)
|
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
|
||||||
#[arg(short, long, default_value = "8080")]
|
#[arg(short, long)]
|
||||||
port: u16,
|
port: Option<u16>,
|
||||||
|
|
||||||
/// Host to bind to
|
/// Host to bind to; defaults to config gateway.host
|
||||||
#[arg(long, default_value = "127.0.0.1")]
|
#[arg(long)]
|
||||||
host: String,
|
host: Option<String>,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
|
/// Start long-running autonomous runtime (gateway + channels + heartbeat + scheduler)
|
||||||
Daemon {
|
Daemon {
|
||||||
/// Port to listen on (use 0 for random available port)
|
/// Port to listen on (use 0 for random available port); defaults to config gateway.port
|
||||||
#[arg(short, long, default_value = "8080")]
|
#[arg(short, long)]
|
||||||
port: u16,
|
port: Option<u16>,
|
||||||
|
|
||||||
/// Host to bind to
|
/// Host to bind to; defaults to config gateway.host
|
||||||
#[arg(long, default_value = "127.0.0.1")]
|
#[arg(long)]
|
||||||
host: String,
|
host: Option<String>,
|
||||||
},
|
},
|
||||||
|
|
||||||
/// Manage OS service lifecycle (launchd/systemd user service)
|
/// Manage OS service lifecycle (launchd/systemd user service)
|
||||||
|
|
@ -367,9 +367,11 @@ async fn main() -> Result<()> {
|
||||||
|
|
||||||
let cli = Cli::parse();
|
let cli = Cli::parse();
|
||||||
|
|
||||||
// Initialize logging
|
// Initialize logging - respects RUST_LOG env var, defaults to INFO
|
||||||
let subscriber = FmtSubscriber::builder()
|
let subscriber = fmt::Subscriber::builder()
|
||||||
.with_max_level(Level::INFO)
|
.with_env_filter(
|
||||||
|
EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")),
|
||||||
|
)
|
||||||
.finish();
|
.finish();
|
||||||
|
|
||||||
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed");
|
||||||
|
|
@ -434,6 +436,8 @@ async fn main() -> Result<()> {
|
||||||
.map(|_| ()),
|
.map(|_| ()),
|
||||||
|
|
||||||
Commands::Gateway { port, host } => {
|
Commands::Gateway { port, host } => {
|
||||||
|
let port = port.unwrap_or(config.gateway.port);
|
||||||
|
let host = host.unwrap_or_else(|| config.gateway.host.clone());
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
info!("🚀 Starting ZeroClaw Gateway on {host} (random port)");
|
info!("🚀 Starting ZeroClaw Gateway on {host} (random port)");
|
||||||
} else {
|
} else {
|
||||||
|
|
@ -443,6 +447,8 @@ async fn main() -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
Commands::Daemon { port, host } => {
|
Commands::Daemon { port, host } => {
|
||||||
|
let port = port.unwrap_or(config.gateway.port);
|
||||||
|
let host = host.unwrap_or_else(|| config.gateway.host.clone());
|
||||||
if port == 0 {
|
if port == 0 {
|
||||||
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
|
info!("🧠 Starting ZeroClaw Daemon on {host} (random port)");
|
||||||
} else {
|
} else {
|
||||||
|
|
|
||||||
|
|
@ -7,6 +7,7 @@ pub enum MemoryBackendKind {
|
||||||
Unknown,
|
Unknown,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
|
||||||
pub struct MemoryBackendProfile {
|
pub struct MemoryBackendProfile {
|
||||||
pub key: &'static str,
|
pub key: &'static str,
|
||||||
|
|
|
||||||
|
|
@ -502,10 +502,10 @@ mod tests {
|
||||||
let workspace = tmp.path();
|
let workspace = tmp.path();
|
||||||
|
|
||||||
let mem = SqliteMemory::new(workspace).unwrap();
|
let mem = SqliteMemory::new(workspace).unwrap();
|
||||||
mem.store("conv_old", "outdated", MemoryCategory::Conversation)
|
mem.store("conv_old", "outdated", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("core_keep", "durable", MemoryCategory::Core)
|
mem.store("core_keep", "durable", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
drop(mem);
|
drop(mem);
|
||||||
|
|
|
||||||
|
|
@ -24,7 +24,9 @@ pub struct LucidMemory {
|
||||||
impl LucidMemory {
|
impl LucidMemory {
|
||||||
const DEFAULT_LUCID_CMD: &'static str = "lucid";
|
const DEFAULT_LUCID_CMD: &'static str = "lucid";
|
||||||
const DEFAULT_TOKEN_BUDGET: usize = 200;
|
const DEFAULT_TOKEN_BUDGET: usize = 200;
|
||||||
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 120;
|
// Lucid CLI cold start can exceed 120ms on slower machines, which causes
|
||||||
|
// avoidable fallback to local-only memory and premature cooldown.
|
||||||
|
const DEFAULT_RECALL_TIMEOUT_MS: u64 = 500;
|
||||||
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
|
const DEFAULT_STORE_TIMEOUT_MS: u64 = 800;
|
||||||
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
|
const DEFAULT_LOCAL_HIT_THRESHOLD: usize = 3;
|
||||||
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
|
const DEFAULT_FAILURE_COOLDOWN_MS: u64 = 15_000;
|
||||||
|
|
@ -74,6 +76,7 @@ impl LucidMemory {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
|
#[allow(clippy::too_many_arguments)]
|
||||||
fn with_options(
|
fn with_options(
|
||||||
workspace_dir: &Path,
|
workspace_dir: &Path,
|
||||||
local: SqliteMemory,
|
local: SqliteMemory,
|
||||||
|
|
@ -307,14 +310,22 @@ impl Memory for LucidMemory {
|
||||||
key: &str,
|
key: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
self.local.store(key, content, category.clone()).await?;
|
self.local
|
||||||
|
.store(key, content, category.clone(), session_id)
|
||||||
|
.await?;
|
||||||
self.sync_to_lucid_async(key, content, &category).await;
|
self.sync_to_lucid_async(key, content, &category).await;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
let local_results = self.local.recall(query, limit).await?;
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
let local_results = self.local.recall(query, limit, session_id).await?;
|
||||||
if limit == 0
|
if limit == 0
|
||||||
|| local_results.len() >= limit
|
|| local_results.len() >= limit
|
||||||
|| local_results.len() >= self.local_hit_threshold
|
|| local_results.len() >= self.local_hit_threshold
|
||||||
|
|
@ -351,8 +362,12 @@ impl Memory for LucidMemory {
|
||||||
self.local.get(key).await
|
self.local.get(key).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
self.local.list(category).await
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
|
self.local.list(category, session_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
async fn forget(&self, key: &str) -> anyhow::Result<bool> {
|
||||||
|
|
@ -396,6 +411,38 @@ EOF
|
||||||
exit 0
|
exit 0
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
echo "unsupported command" >&2
|
||||||
|
exit 1
|
||||||
|
"#;
|
||||||
|
|
||||||
|
fs::write(&script_path, script).unwrap();
|
||||||
|
let mut perms = fs::metadata(&script_path).unwrap().permissions();
|
||||||
|
perms.set_mode(0o755);
|
||||||
|
fs::set_permissions(&script_path, perms).unwrap();
|
||||||
|
script_path.display().to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
fn write_delayed_lucid_script(dir: &Path) -> String {
|
||||||
|
let script_path = dir.join("delayed-lucid.sh");
|
||||||
|
let script = r#"#!/usr/bin/env bash
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
if [[ "${1:-}" == "store" ]]; then
|
||||||
|
echo '{"success":true,"id":"mem_1"}'
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
|
if [[ "${1:-}" == "context" ]]; then
|
||||||
|
# Simulate a cold start that is slower than 120ms but below the 500ms timeout.
|
||||||
|
sleep 0.2
|
||||||
|
cat <<'EOF'
|
||||||
|
<lucid-context>
|
||||||
|
- [decision] Delayed token refresh guidance
|
||||||
|
</lucid-context>
|
||||||
|
EOF
|
||||||
|
exit 0
|
||||||
|
fi
|
||||||
|
|
||||||
echo "unsupported command" >&2
|
echo "unsupported command" >&2
|
||||||
exit 1
|
exit 1
|
||||||
"#;
|
"#;
|
||||||
|
|
@ -449,7 +496,7 @@ exit 1
|
||||||
cmd,
|
cmd,
|
||||||
200,
|
200,
|
||||||
3,
|
3,
|
||||||
Duration::from_millis(120),
|
Duration::from_millis(500),
|
||||||
Duration::from_millis(400),
|
Duration::from_millis(400),
|
||||||
Duration::from_secs(2),
|
Duration::from_secs(2),
|
||||||
)
|
)
|
||||||
|
|
@ -468,7 +515,7 @@ exit 1
|
||||||
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
|
let memory = test_memory(tmp.path(), "nonexistent-lucid-binary".to_string());
|
||||||
|
|
||||||
memory
|
memory
|
||||||
.store("lang", "User prefers Rust", MemoryCategory::Core)
|
.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -483,6 +530,30 @@ exit 1
|
||||||
let fake_cmd = write_fake_lucid_script(tmp.path());
|
let fake_cmd = write_fake_lucid_script(tmp.path());
|
||||||
let memory = test_memory(tmp.path(), fake_cmd);
|
let memory = test_memory(tmp.path(), fake_cmd);
|
||||||
|
|
||||||
|
memory
|
||||||
|
.store(
|
||||||
|
"local_note",
|
||||||
|
"Local sqlite auth fallback note",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let entries = memory.recall("auth", 5, None).await.unwrap();
|
||||||
|
|
||||||
|
assert!(entries
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||||
|
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_handles_lucid_cold_start_delay_within_timeout() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let delayed_cmd = write_delayed_lucid_script(tmp.path());
|
||||||
|
let memory = test_memory(tmp.path(), delayed_cmd);
|
||||||
|
|
||||||
memory
|
memory
|
||||||
.store(
|
.store(
|
||||||
"local_note",
|
"local_note",
|
||||||
|
|
@ -497,7 +568,9 @@ exit 1
|
||||||
assert!(entries
|
assert!(entries
|
||||||
.iter()
|
.iter()
|
||||||
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
.any(|e| e.content.contains("Local sqlite auth fallback note")));
|
||||||
assert!(entries.iter().any(|e| e.content.contains("token refresh")));
|
assert!(entries
|
||||||
|
.iter()
|
||||||
|
.any(|e| e.content.contains("Delayed token refresh guidance")));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
@ -513,17 +586,22 @@ exit 1
|
||||||
probe_cmd,
|
probe_cmd,
|
||||||
200,
|
200,
|
||||||
1,
|
1,
|
||||||
Duration::from_millis(120),
|
Duration::from_millis(500),
|
||||||
Duration::from_millis(400),
|
Duration::from_millis(400),
|
||||||
Duration::from_secs(2),
|
Duration::from_secs(2),
|
||||||
);
|
);
|
||||||
|
|
||||||
memory
|
memory
|
||||||
.store("pref", "Rust should stay local-first", MemoryCategory::Core)
|
.store(
|
||||||
|
"pref",
|
||||||
|
"Rust should stay local-first",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let entries = memory.recall("rust", 5).await.unwrap();
|
let entries = memory.recall("rust", 5, None).await.unwrap();
|
||||||
assert!(entries
|
assert!(entries
|
||||||
.iter()
|
.iter()
|
||||||
.any(|e| e.content.contains("Rust should stay local-first")));
|
.any(|e| e.content.contains("Rust should stay local-first")));
|
||||||
|
|
@ -578,13 +656,13 @@ exit 1
|
||||||
failing_cmd,
|
failing_cmd,
|
||||||
200,
|
200,
|
||||||
99,
|
99,
|
||||||
Duration::from_millis(120),
|
Duration::from_millis(500),
|
||||||
Duration::from_millis(400),
|
Duration::from_millis(400),
|
||||||
Duration::from_secs(5),
|
Duration::from_secs(5),
|
||||||
);
|
);
|
||||||
|
|
||||||
let first = memory.recall("auth", 5).await.unwrap();
|
let first = memory.recall("auth", 5, None).await.unwrap();
|
||||||
let second = memory.recall("auth", 5).await.unwrap();
|
let second = memory.recall("auth", 5, None).await.unwrap();
|
||||||
|
|
||||||
assert!(first.is_empty());
|
assert!(first.is_empty());
|
||||||
assert!(second.is_empty());
|
assert!(second.is_empty());
|
||||||
|
|
|
||||||
|
|
@ -143,6 +143,7 @@ impl Memory for MarkdownMemory {
|
||||||
key: &str,
|
key: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
let entry = format!("- **{key}**: {content}");
|
let entry = format!("- **{key}**: {content}");
|
||||||
let path = match category {
|
let path = match category {
|
||||||
|
|
@ -152,7 +153,12 @@ impl Memory for MarkdownMemory {
|
||||||
self.append_to_file(&path, &entry).await
|
self.append_to_file(&path, &entry).await
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let all = self.read_all_entries().await?;
|
let all = self.read_all_entries().await?;
|
||||||
let query_lower = query.to_lowercase();
|
let query_lower = query.to_lowercase();
|
||||||
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
let keywords: Vec<&str> = query_lower.split_whitespace().collect();
|
||||||
|
|
@ -192,7 +198,11 @@ impl Memory for MarkdownMemory {
|
||||||
.find(|e| e.key == key || e.content.contains(key)))
|
.find(|e| e.key == key || e.content.contains(key)))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let all = self.read_all_entries().await?;
|
let all = self.read_all_entries().await?;
|
||||||
match category {
|
match category {
|
||||||
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
Some(cat) => Ok(all.into_iter().filter(|e| &e.category == cat).collect()),
|
||||||
|
|
@ -243,7 +253,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_store_core() {
|
async fn markdown_store_core() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("pref", "User likes Rust", MemoryCategory::Core)
|
mem.store("pref", "User likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
let content = sync_fs::read_to_string(mem.core_path()).unwrap();
|
||||||
|
|
@ -253,7 +263,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_store_daily() {
|
async fn markdown_store_daily() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("note", "Finished tests", MemoryCategory::Daily)
|
mem.store("note", "Finished tests", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let path = mem.daily_path();
|
let path = mem.daily_path();
|
||||||
|
|
@ -264,17 +274,17 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_recall_keyword() {
|
async fn markdown_recall_keyword() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Python is slow", MemoryCategory::Core)
|
mem.store("b", "Python is slow", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("c", "Rust and safety", MemoryCategory::Core)
|
mem.store("c", "Rust and safety", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("Rust", 10).await.unwrap();
|
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||||
assert!(results.len() >= 2);
|
assert!(results.len() >= 2);
|
||||||
assert!(results
|
assert!(results
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -284,18 +294,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_recall_no_match() {
|
async fn markdown_recall_no_match() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "Rust is great", MemoryCategory::Core)
|
mem.store("a", "Rust is great", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("javascript", 10).await.unwrap();
|
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_count() {
|
async fn markdown_count() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "first", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "first", MemoryCategory::Core, None)
|
||||||
mem.store("b", "second", MemoryCategory::Core)
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("b", "second", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let count = mem.count().await.unwrap();
|
let count = mem.count().await.unwrap();
|
||||||
|
|
@ -305,24 +317,24 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_list_by_category() {
|
async fn markdown_list_by_category() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "core fact", MemoryCategory::Core)
|
mem.store("a", "core fact", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "daily note", MemoryCategory::Daily)
|
mem.store("b", "daily note", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
assert!(core.iter().all(|e| e.category == MemoryCategory::Core));
|
||||||
|
|
||||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
assert!(daily.iter().all(|e| e.category == MemoryCategory::Daily));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_forget_is_noop() {
|
async fn markdown_forget_is_noop() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
mem.store("a", "permanent", MemoryCategory::Core)
|
mem.store("a", "permanent", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let removed = mem.forget("a").await.unwrap();
|
let removed = mem.forget("a").await.unwrap();
|
||||||
|
|
@ -332,7 +344,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn markdown_empty_recall() {
|
async fn markdown_empty_recall() {
|
||||||
let (_tmp, mem) = temp_workspace();
|
let (_tmp, mem) = temp_workspace();
|
||||||
let results = mem.recall("anything", 10).await.unwrap();
|
let results = mem.recall("anything", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -25,11 +25,17 @@ impl Memory for NoneMemory {
|
||||||
_key: &str,
|
_key: &str,
|
||||||
_content: &str,
|
_content: &str,
|
||||||
_category: MemoryCategory,
|
_category: MemoryCategory,
|
||||||
|
_session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, _query: &str, _limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
_query: &str,
|
||||||
|
_limit: usize,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -37,7 +43,11 @@ impl Memory for NoneMemory {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, _category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
|
&self,
|
||||||
|
_category: Option<&MemoryCategory>,
|
||||||
|
_session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
Ok(Vec::new())
|
Ok(Vec::new())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -62,11 +72,14 @@ mod tests {
|
||||||
async fn none_memory_is_noop() {
|
async fn none_memory_is_noop() {
|
||||||
let memory = NoneMemory::new();
|
let memory = NoneMemory::new();
|
||||||
|
|
||||||
memory.store("k", "v", MemoryCategory::Core).await.unwrap();
|
memory
|
||||||
|
.store("k", "v", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
assert!(memory.get("k").await.unwrap().is_none());
|
assert!(memory.get("k").await.unwrap().is_none());
|
||||||
assert!(memory.recall("k", 10).await.unwrap().is_empty());
|
assert!(memory.recall("k", 10, None).await.unwrap().is_empty());
|
||||||
assert!(memory.list(None).await.unwrap().is_empty());
|
assert!(memory.list(None, None).await.unwrap().is_empty());
|
||||||
assert!(!memory.forget("k").await.unwrap());
|
assert!(!memory.forget("k").await.unwrap());
|
||||||
assert_eq!(memory.count().await.unwrap(), 0);
|
assert_eq!(memory.count().await.unwrap(), 0);
|
||||||
assert!(memory.health_check().await);
|
assert!(memory.health_check().await);
|
||||||
|
|
|
||||||
|
|
@ -157,7 +157,7 @@ impl ResponseCache {
|
||||||
|row| row.get(0),
|
|row| row.get(0),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
#[allow(clippy::cast_sign_loss)]
|
#[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)]
|
||||||
Ok((count as usize, hits as u64, tokens_saved as u64))
|
Ok((count as usize, hits as u64, tokens_saved as u64))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -124,6 +124,19 @@ impl SqliteMemory {
|
||||||
);
|
);
|
||||||
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
|
CREATE INDEX IF NOT EXISTS idx_cache_accessed ON embedding_cache(accessed_at);",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
|
// Migration: add session_id column if not present (safe to run repeatedly)
|
||||||
|
let has_session_id: bool = conn
|
||||||
|
.prepare("SELECT sql FROM sqlite_master WHERE type='table' AND name='memories'")?
|
||||||
|
.query_row([], |row| row.get::<_, String>(0))?
|
||||||
|
.contains("session_id");
|
||||||
|
if !has_session_id {
|
||||||
|
conn.execute_batch(
|
||||||
|
"ALTER TABLE memories ADD COLUMN session_id TEXT;
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_memories_session ON memories(session_id);",
|
||||||
|
)?;
|
||||||
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -361,6 +374,7 @@ impl Memory for SqliteMemory {
|
||||||
key: &str,
|
key: &str,
|
||||||
content: &str,
|
content: &str,
|
||||||
category: MemoryCategory,
|
category: MemoryCategory,
|
||||||
|
session_id: Option<&str>,
|
||||||
) -> anyhow::Result<()> {
|
) -> anyhow::Result<()> {
|
||||||
// Compute embedding (async, before lock)
|
// Compute embedding (async, before lock)
|
||||||
let embedding_bytes = self
|
let embedding_bytes = self
|
||||||
|
|
@ -377,20 +391,26 @@ impl Memory for SqliteMemory {
|
||||||
let id = Uuid::new_v4().to_string();
|
let id = Uuid::new_v4().to_string();
|
||||||
|
|
||||||
conn.execute(
|
conn.execute(
|
||||||
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at)
|
"INSERT INTO memories (id, key, content, category, embedding, created_at, updated_at, session_id)
|
||||||
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7)
|
VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8)
|
||||||
ON CONFLICT(key) DO UPDATE SET
|
ON CONFLICT(key) DO UPDATE SET
|
||||||
content = excluded.content,
|
content = excluded.content,
|
||||||
category = excluded.category,
|
category = excluded.category,
|
||||||
embedding = excluded.embedding,
|
embedding = excluded.embedding,
|
||||||
updated_at = excluded.updated_at",
|
updated_at = excluded.updated_at,
|
||||||
params![id, key, content, cat, embedding_bytes, now, now],
|
session_id = excluded.session_id",
|
||||||
|
params![id, key, content, cat, embedding_bytes, now, now, session_id],
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
if query.trim().is_empty() {
|
if query.trim().is_empty() {
|
||||||
return Ok(Vec::new());
|
return Ok(Vec::new());
|
||||||
}
|
}
|
||||||
|
|
@ -439,7 +459,7 @@ impl Memory for SqliteMemory {
|
||||||
let mut results = Vec::new();
|
let mut results = Vec::new();
|
||||||
for scored in &merged {
|
for scored in &merged {
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories WHERE id = ?1",
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE id = ?1",
|
||||||
)?;
|
)?;
|
||||||
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
if let Ok(entry) = stmt.query_row(params![scored.id], |row| {
|
||||||
Ok(MemoryEntry {
|
Ok(MemoryEntry {
|
||||||
|
|
@ -448,10 +468,16 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: Some(f64::from(scored.final_score)),
|
score: Some(f64::from(scored.final_score)),
|
||||||
})
|
})
|
||||||
}) {
|
}) {
|
||||||
|
// Filter by session_id if requested
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
results.push(entry);
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -470,7 +496,7 @@ impl Memory for SqliteMemory {
|
||||||
.collect();
|
.collect();
|
||||||
let where_clause = conditions.join(" OR ");
|
let where_clause = conditions.join(" OR ");
|
||||||
let sql = format!(
|
let sql = format!(
|
||||||
"SELECT id, key, content, category, created_at FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
WHERE {where_clause}
|
WHERE {where_clause}
|
||||||
ORDER BY updated_at DESC
|
ORDER BY updated_at DESC
|
||||||
LIMIT ?{}",
|
LIMIT ?{}",
|
||||||
|
|
@ -493,12 +519,18 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: Some(1.0),
|
score: Some(1.0),
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
results.push(row?);
|
let entry = row?;
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -514,7 +546,7 @@ impl Memory for SqliteMemory {
|
||||||
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
.map_err(|e| anyhow::anyhow!("Lock error: {e}"))?;
|
||||||
|
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories WHERE key = ?1",
|
"SELECT id, key, content, category, created_at, session_id FROM memories WHERE key = ?1",
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut rows = stmt.query_map(params![key], |row| {
|
let mut rows = stmt.query_map(params![key], |row| {
|
||||||
|
|
@ -524,7 +556,7 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: None,
|
score: None,
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
@ -535,7 +567,11 @@ impl Memory for SqliteMemory {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>> {
|
async fn list(
|
||||||
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>> {
|
||||||
let conn = self
|
let conn = self
|
||||||
.conn
|
.conn
|
||||||
.lock()
|
.lock()
|
||||||
|
|
@ -550,7 +586,7 @@ impl Memory for SqliteMemory {
|
||||||
content: row.get(2)?,
|
content: row.get(2)?,
|
||||||
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
category: Self::str_to_category(&row.get::<_, String>(3)?),
|
||||||
timestamp: row.get(4)?,
|
timestamp: row.get(4)?,
|
||||||
session_id: None,
|
session_id: row.get(5)?,
|
||||||
score: None,
|
score: None,
|
||||||
})
|
})
|
||||||
};
|
};
|
||||||
|
|
@ -558,21 +594,33 @@ impl Memory for SqliteMemory {
|
||||||
if let Some(cat) = category {
|
if let Some(cat) = category {
|
||||||
let cat_str = Self::category_to_str(cat);
|
let cat_str = Self::category_to_str(cat);
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
WHERE category = ?1 ORDER BY updated_at DESC",
|
WHERE category = ?1 ORDER BY updated_at DESC",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
let rows = stmt.query_map(params![cat_str], row_mapper)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
results.push(row?);
|
let entry = row?;
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
let mut stmt = conn.prepare(
|
let mut stmt = conn.prepare(
|
||||||
"SELECT id, key, content, category, created_at FROM memories
|
"SELECT id, key, content, category, created_at, session_id FROM memories
|
||||||
ORDER BY updated_at DESC",
|
ORDER BY updated_at DESC",
|
||||||
)?;
|
)?;
|
||||||
let rows = stmt.query_map([], row_mapper)?;
|
let rows = stmt.query_map([], row_mapper)?;
|
||||||
for row in rows {
|
for row in rows {
|
||||||
results.push(row?);
|
let entry = row?;
|
||||||
|
if let Some(sid) = session_id {
|
||||||
|
if entry.session_id.as_deref() != Some(sid) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
results.push(entry);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -632,7 +680,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_store_and_get() {
|
async fn sqlite_store_and_get() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core)
|
mem.store("user_lang", "Prefers Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -647,10 +695,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_store_upsert() {
|
async fn sqlite_store_upsert() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("pref", "likes Rust", MemoryCategory::Core)
|
mem.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("pref", "loves Rust", MemoryCategory::Core)
|
mem.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -662,17 +710,22 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_recall_keyword() {
|
async fn sqlite_recall_keyword() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "Rust is fast and safe", MemoryCategory::Core)
|
mem.store("a", "Rust is fast and safe", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Python is interpreted", MemoryCategory::Core)
|
mem.store("b", "Python is interpreted", MemoryCategory::Core, None)
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
mem.store("c", "Rust has zero-cost abstractions", MemoryCategory::Core)
|
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
mem.store(
|
||||||
|
"c",
|
||||||
|
"Rust has zero-cost abstractions",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("Rust", 10).await.unwrap();
|
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 2);
|
assert_eq!(results.len(), 2);
|
||||||
assert!(results
|
assert!(results
|
||||||
.iter()
|
.iter()
|
||||||
|
|
@ -682,14 +735,14 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_recall_multi_keyword() {
|
async fn sqlite_recall_multi_keyword() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "Rust is fast", MemoryCategory::Core)
|
mem.store("a", "Rust is fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Rust is safe and fast", MemoryCategory::Core)
|
mem.store("b", "Rust is safe and fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("fast safe", 10).await.unwrap();
|
let results = mem.recall("fast safe", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
// Entry with both keywords should score higher
|
// Entry with both keywords should score higher
|
||||||
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
assert!(results[0].content.contains("safe") && results[0].content.contains("fast"));
|
||||||
|
|
@ -698,17 +751,17 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_recall_no_match() {
|
async fn sqlite_recall_no_match() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "Rust rocks", MemoryCategory::Core)
|
mem.store("a", "Rust rocks", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("javascript", 10).await.unwrap();
|
let results = mem.recall("javascript", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_forget() {
|
async fn sqlite_forget() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("temp", "temporary data", MemoryCategory::Conversation)
|
mem.store("temp", "temporary data", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(mem.count().await.unwrap(), 1);
|
assert_eq!(mem.count().await.unwrap(), 1);
|
||||||
|
|
@ -728,29 +781,37 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_list_all() {
|
async fn sqlite_list_all() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "one", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "one", MemoryCategory::Core, None)
|
||||||
mem.store("b", "two", MemoryCategory::Daily).await.unwrap();
|
.await
|
||||||
mem.store("c", "three", MemoryCategory::Conversation)
|
.unwrap();
|
||||||
|
mem.store("b", "two", MemoryCategory::Daily, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("c", "three", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let all = mem.list(None).await.unwrap();
|
let all = mem.list(None, None).await.unwrap();
|
||||||
assert_eq!(all.len(), 3);
|
assert_eq!(all.len(), 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn sqlite_list_by_category() {
|
async fn sqlite_list_by_category() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "core1", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "core1", MemoryCategory::Core, None)
|
||||||
mem.store("b", "core2", MemoryCategory::Core).await.unwrap();
|
.await
|
||||||
mem.store("c", "daily1", MemoryCategory::Daily)
|
.unwrap();
|
||||||
|
mem.store("b", "core2", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("c", "daily1", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let core = mem.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let core = mem.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
assert_eq!(core.len(), 2);
|
assert_eq!(core.len(), 2);
|
||||||
|
|
||||||
let daily = mem.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let daily = mem.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
assert_eq!(daily.len(), 1);
|
assert_eq!(daily.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -772,7 +833,7 @@ mod tests {
|
||||||
|
|
||||||
{
|
{
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
mem.store("persist", "I survive restarts", MemoryCategory::Core)
|
mem.store("persist", "I survive restarts", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -795,7 +856,7 @@ mod tests {
|
||||||
];
|
];
|
||||||
|
|
||||||
for (i, cat) in categories.iter().enumerate() {
|
for (i, cat) in categories.iter().enumerate() {
|
||||||
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone())
|
mem.store(&format!("k{i}"), &format!("v{i}"), cat.clone(), None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
@ -815,21 +876,28 @@ mod tests {
|
||||||
"a",
|
"a",
|
||||||
"Rust is a systems programming language",
|
"Rust is a systems programming language",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store(
|
||||||
|
"b",
|
||||||
|
"Python is great for scripting",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "Python is great for scripting", MemoryCategory::Core)
|
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
mem.store(
|
mem.store(
|
||||||
"c",
|
"c",
|
||||||
"Rust and Rust and Rust everywhere",
|
"Rust and Rust and Rust everywhere",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("Rust", 10).await.unwrap();
|
let results = mem.recall("Rust", 10, None).await.unwrap();
|
||||||
assert!(results.len() >= 2);
|
assert!(results.len() >= 2);
|
||||||
// All results should contain "Rust"
|
// All results should contain "Rust"
|
||||||
for r in &results {
|
for r in &results {
|
||||||
|
|
@ -844,17 +912,17 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_multi_word_query() {
|
async fn fts5_multi_word_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core)
|
mem.store("a", "The quick brown fox jumps", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core)
|
mem.store("b", "A lazy dog sleeps", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("c", "The quick dog runs fast", MemoryCategory::Core)
|
mem.store("c", "The quick dog runs fast", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("quick dog", 10).await.unwrap();
|
let results = mem.recall("quick dog", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
// "The quick dog runs fast" matches both terms
|
// "The quick dog runs fast" matches both terms
|
||||||
assert!(results[0].content.contains("quick"));
|
assert!(results[0].content.contains("quick"));
|
||||||
|
|
@ -863,16 +931,20 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_empty_query_returns_empty() {
|
async fn recall_empty_query_returns_empty() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "data", MemoryCategory::Core, None)
|
||||||
let results = mem.recall("", 10).await.unwrap();
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall("", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_whitespace_query_returns_empty() {
|
async fn recall_whitespace_query_returns_empty() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "data", MemoryCategory::Core).await.unwrap();
|
mem.store("a", "data", MemoryCategory::Core, None)
|
||||||
let results = mem.recall(" ", 10).await.unwrap();
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let results = mem.recall(" ", 10, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -937,9 +1009,14 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_syncs_on_insert() {
|
async fn fts5_syncs_on_insert() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("test_key", "unique_searchterm_xyz", MemoryCategory::Core)
|
mem.store(
|
||||||
.await
|
"test_key",
|
||||||
.unwrap();
|
"unique_searchterm_xyz",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
let conn = mem.conn.lock();
|
let conn = mem.conn.lock();
|
||||||
let count: i64 = conn
|
let count: i64 = conn
|
||||||
|
|
@ -955,9 +1032,14 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_syncs_on_delete() {
|
async fn fts5_syncs_on_delete() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("del_key", "deletable_content_abc", MemoryCategory::Core)
|
mem.store(
|
||||||
.await
|
"del_key",
|
||||||
.unwrap();
|
"deletable_content_abc",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
mem.forget("del_key").await.unwrap();
|
mem.forget("del_key").await.unwrap();
|
||||||
|
|
||||||
let conn = mem.conn.lock();
|
let conn = mem.conn.lock();
|
||||||
|
|
@ -974,10 +1056,15 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn fts5_syncs_on_update() {
|
async fn fts5_syncs_on_update() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("upd_key", "original_content_111", MemoryCategory::Core)
|
mem.store(
|
||||||
.await
|
"upd_key",
|
||||||
.unwrap();
|
"original_content_111",
|
||||||
mem.store("upd_key", "updated_content_222", MemoryCategory::Core)
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("upd_key", "updated_content_222", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -1019,10 +1106,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn reindex_rebuilds_fts() {
|
async fn reindex_rebuilds_fts() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("r1", "reindex test alpha", MemoryCategory::Core)
|
mem.store("r1", "reindex test alpha", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("r2", "reindex test beta", MemoryCategory::Core)
|
mem.store("r2", "reindex test beta", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -1031,7 +1118,7 @@ mod tests {
|
||||||
assert_eq!(count, 0);
|
assert_eq!(count, 0);
|
||||||
|
|
||||||
// FTS should still work after rebuild
|
// FTS should still work after rebuild
|
||||||
let results = mem.recall("reindex", 10).await.unwrap();
|
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 2);
|
assert_eq!(results.len(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1045,12 +1132,13 @@ mod tests {
|
||||||
&format!("k{i}"),
|
&format!("k{i}"),
|
||||||
&format!("common keyword item {i}"),
|
&format!("common keyword item {i}"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let results = mem.recall("common keyword", 5).await.unwrap();
|
let results = mem.recall("common keyword", 5, None).await.unwrap();
|
||||||
assert!(results.len() <= 5);
|
assert!(results.len() <= 5);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1059,11 +1147,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_results_have_scores() {
|
async fn recall_results_have_scores() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("s1", "scored result test", MemoryCategory::Core)
|
mem.store("s1", "scored result test", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let results = mem.recall("scored", 10).await.unwrap();
|
let results = mem.recall("scored", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
for r in &results {
|
for r in &results {
|
||||||
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
assert!(r.score.is_some(), "Expected score on result: {:?}", r.key);
|
||||||
|
|
@ -1075,11 +1163,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_quotes_in_query() {
|
async fn recall_with_quotes_in_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("q1", "He said hello world", MemoryCategory::Core)
|
mem.store("q1", "He said hello world", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Quotes in query should not crash FTS5
|
// Quotes in query should not crash FTS5
|
||||||
let results = mem.recall("\"hello\"", 10).await.unwrap();
|
let results = mem.recall("\"hello\"", 10, None).await.unwrap();
|
||||||
// May or may not match depending on FTS5 escaping, but must not error
|
// May or may not match depending on FTS5 escaping, but must not error
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
@ -1087,31 +1175,34 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_asterisk_in_query() {
|
async fn recall_with_asterisk_in_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a1", "wildcard test content", MemoryCategory::Core)
|
mem.store("a1", "wildcard test content", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("wild*", 10).await.unwrap();
|
let results = mem.recall("wild*", 10, None).await.unwrap();
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_parentheses_in_query() {
|
async fn recall_with_parentheses_in_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("p1", "function call test", MemoryCategory::Core)
|
mem.store("p1", "function call test", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("function()", 10).await.unwrap();
|
let results = mem.recall("function()", 10, None).await.unwrap();
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_with_sql_injection_attempt() {
|
async fn recall_with_sql_injection_attempt() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("safe", "normal content", MemoryCategory::Core)
|
mem.store("safe", "normal content", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Should not crash or leak data
|
// Should not crash or leak data
|
||||||
let results = mem.recall("'; DROP TABLE memories; --", 10).await.unwrap();
|
let results = mem
|
||||||
|
.recall("'; DROP TABLE memories; --", 10, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
// Table should still exist
|
// Table should still exist
|
||||||
assert_eq!(mem.count().await.unwrap(), 1);
|
assert_eq!(mem.count().await.unwrap(), 1);
|
||||||
|
|
@ -1122,7 +1213,9 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn store_empty_content() {
|
async fn store_empty_content() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("empty", "", MemoryCategory::Core).await.unwrap();
|
mem.store("empty", "", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let entry = mem.get("empty").await.unwrap().unwrap();
|
let entry = mem.get("empty").await.unwrap().unwrap();
|
||||||
assert_eq!(entry.content, "");
|
assert_eq!(entry.content, "");
|
||||||
}
|
}
|
||||||
|
|
@ -1130,7 +1223,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn store_empty_key() {
|
async fn store_empty_key() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("", "content for empty key", MemoryCategory::Core)
|
mem.store("", "content for empty key", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("").await.unwrap().unwrap();
|
let entry = mem.get("").await.unwrap().unwrap();
|
||||||
|
|
@ -1141,7 +1234,7 @@ mod tests {
|
||||||
async fn store_very_long_content() {
|
async fn store_very_long_content() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let long_content = "x".repeat(100_000);
|
let long_content = "x".repeat(100_000);
|
||||||
mem.store("long", &long_content, MemoryCategory::Core)
|
mem.store("long", &long_content, MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("long").await.unwrap().unwrap();
|
let entry = mem.get("long").await.unwrap().unwrap();
|
||||||
|
|
@ -1151,9 +1244,14 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn store_unicode_and_emoji() {
|
async fn store_unicode_and_emoji() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("emoji_key_🦀", "こんにちは 🚀 Ñoño", MemoryCategory::Core)
|
mem.store(
|
||||||
.await
|
"emoji_key_🦀",
|
||||||
.unwrap();
|
"こんにちは 🚀 Ñoño",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
let entry = mem.get("emoji_key_🦀").await.unwrap().unwrap();
|
||||||
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
|
assert_eq!(entry.content, "こんにちは 🚀 Ñoño");
|
||||||
}
|
}
|
||||||
|
|
@ -1162,7 +1260,7 @@ mod tests {
|
||||||
async fn store_content_with_newlines_and_tabs() {
|
async fn store_content_with_newlines_and_tabs() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
let content = "line1\nline2\ttab\rcarriage\n\nnewparagraph";
|
||||||
mem.store("whitespace", content, MemoryCategory::Core)
|
mem.store("whitespace", content, MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
let entry = mem.get("whitespace").await.unwrap().unwrap();
|
||||||
|
|
@ -1174,11 +1272,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_single_character_query() {
|
async fn recall_single_character_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "x marks the spot", MemoryCategory::Core)
|
mem.store("a", "x marks the spot", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// Single char may not match FTS5 but LIKE fallback should work
|
// Single char may not match FTS5 but LIKE fallback should work
|
||||||
let results = mem.recall("x", 10).await.unwrap();
|
let results = mem.recall("x", 10, None).await.unwrap();
|
||||||
// Should not crash; may or may not find results
|
// Should not crash; may or may not find results
|
||||||
assert!(results.len() <= 10);
|
assert!(results.len() <= 10);
|
||||||
}
|
}
|
||||||
|
|
@ -1186,23 +1284,23 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_limit_zero() {
|
async fn recall_limit_zero() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "some content", MemoryCategory::Core)
|
mem.store("a", "some content", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("some", 0).await.unwrap();
|
let results = mem.recall("some", 0, None).await.unwrap();
|
||||||
assert!(results.is_empty());
|
assert!(results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_limit_one() {
|
async fn recall_limit_one() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("a", "matching content alpha", MemoryCategory::Core)
|
mem.store("a", "matching content alpha", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("b", "matching content beta", MemoryCategory::Core)
|
mem.store("b", "matching content beta", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("matching content", 1).await.unwrap();
|
let results = mem.recall("matching content", 1, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1213,21 +1311,22 @@ mod tests {
|
||||||
"rust_preferences",
|
"rust_preferences",
|
||||||
"User likes systems programming",
|
"User likes systems programming",
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
// "rust" appears in key but not content — LIKE fallback checks key too
|
// "rust" appears in key but not content — LIKE fallback checks key too
|
||||||
let results = mem.recall("rust", 10).await.unwrap();
|
let results = mem.recall("rust", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty(), "Should match by key");
|
assert!(!results.is_empty(), "Should match by key");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_unicode_query() {
|
async fn recall_unicode_query() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("jp", "日本語のテスト", MemoryCategory::Core)
|
mem.store("jp", "日本語のテスト", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let results = mem.recall("日本語", 10).await.unwrap();
|
let results = mem.recall("日本語", 10, None).await.unwrap();
|
||||||
assert!(!results.is_empty());
|
assert!(!results.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1238,7 +1337,9 @@ mod tests {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
{
|
{
|
||||||
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
mem.store("k1", "v1", MemoryCategory::Core).await.unwrap();
|
mem.store("k1", "v1", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
// Open again — init_schema runs again on existing DB
|
// Open again — init_schema runs again on existing DB
|
||||||
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
let mem2 = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
|
@ -1246,7 +1347,9 @@ mod tests {
|
||||||
assert!(entry.is_some());
|
assert!(entry.is_some());
|
||||||
assert_eq!(entry.unwrap().content, "v1");
|
assert_eq!(entry.unwrap().content, "v1");
|
||||||
// Store more data — should work fine
|
// Store more data — should work fine
|
||||||
mem2.store("k2", "v2", MemoryCategory::Daily).await.unwrap();
|
mem2.store("k2", "v2", MemoryCategory::Daily, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
assert_eq!(mem2.count().await.unwrap(), 2);
|
assert_eq!(mem2.count().await.unwrap(), 2);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1264,11 +1367,16 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn forget_then_recall_no_ghost_results() {
|
async fn forget_then_recall_no_ghost_results() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("ghost", "phantom memory content", MemoryCategory::Core)
|
mem.store(
|
||||||
.await
|
"ghost",
|
||||||
.unwrap();
|
"phantom memory content",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
mem.forget("ghost").await.unwrap();
|
mem.forget("ghost").await.unwrap();
|
||||||
let results = mem.recall("phantom memory", 10).await.unwrap();
|
let results = mem.recall("phantom memory", 10, None).await.unwrap();
|
||||||
assert!(
|
assert!(
|
||||||
results.is_empty(),
|
results.is_empty(),
|
||||||
"Deleted memory should not appear in recall"
|
"Deleted memory should not appear in recall"
|
||||||
|
|
@ -1278,11 +1386,11 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn forget_and_re_store_same_key() {
|
async fn forget_and_re_store_same_key() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("cycle", "version 1", MemoryCategory::Core)
|
mem.store("cycle", "version 1", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.forget("cycle").await.unwrap();
|
mem.forget("cycle").await.unwrap();
|
||||||
mem.store("cycle", "version 2", MemoryCategory::Core)
|
mem.store("cycle", "version 2", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
let entry = mem.get("cycle").await.unwrap().unwrap();
|
let entry = mem.get("cycle").await.unwrap().unwrap();
|
||||||
|
|
@ -1302,14 +1410,14 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn reindex_twice_is_safe() {
|
async fn reindex_twice_is_safe() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("r1", "reindex data", MemoryCategory::Core)
|
mem.store("r1", "reindex data", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.reindex().await.unwrap();
|
mem.reindex().await.unwrap();
|
||||||
let count = mem.reindex().await.unwrap();
|
let count = mem.reindex().await.unwrap();
|
||||||
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
assert_eq!(count, 0); // Noop embedder → nothing to re-embed
|
||||||
// Data should still be intact
|
// Data should still be intact
|
||||||
let results = mem.recall("reindex", 10).await.unwrap();
|
let results = mem.recall("reindex", 10, None).await.unwrap();
|
||||||
assert_eq!(results.len(), 1);
|
assert_eq!(results.len(), 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -1363,18 +1471,28 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn list_custom_category() {
|
async fn list_custom_category() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
mem.store("c1", "custom1", MemoryCategory::Custom("project".into()))
|
mem.store(
|
||||||
.await
|
"c1",
|
||||||
.unwrap();
|
"custom1",
|
||||||
mem.store("c2", "custom2", MemoryCategory::Custom("project".into()))
|
MemoryCategory::Custom("project".into()),
|
||||||
.await
|
None,
|
||||||
.unwrap();
|
)
|
||||||
mem.store("c3", "other", MemoryCategory::Core)
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store(
|
||||||
|
"c2",
|
||||||
|
"custom2",
|
||||||
|
MemoryCategory::Custom("project".into()),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("c3", "other", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let project = mem
|
let project = mem
|
||||||
.list(Some(&MemoryCategory::Custom("project".into())))
|
.list(Some(&MemoryCategory::Custom("project".into())), None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
assert_eq!(project.len(), 2);
|
assert_eq!(project.len(), 2);
|
||||||
|
|
@ -1383,7 +1501,122 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn list_empty_db() {
|
async fn list_empty_db() {
|
||||||
let (_tmp, mem) = temp_sqlite();
|
let (_tmp, mem) = temp_sqlite();
|
||||||
let all = mem.list(None).await.unwrap();
|
let all = mem.list(None, None).await.unwrap();
|
||||||
assert!(all.is_empty());
|
assert!(all.is_empty());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Session isolation ─────────────────────────────────────────
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn store_and_recall_with_session_id() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("k1", "session A fact", MemoryCategory::Core, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k2", "session B fact", MemoryCategory::Core, Some("sess-b"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k3", "no session fact", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Recall with session-a filter returns only session-a entry
|
||||||
|
let results = mem.recall("fact", 10, Some("sess-a")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].key, "k1");
|
||||||
|
assert_eq!(results[0].session_id.as_deref(), Some("sess-a"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn recall_no_session_filter_returns_all() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("k1", "alpha fact", MemoryCategory::Core, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k2", "beta fact", MemoryCategory::Core, Some("sess-b"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k3", "gamma fact", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Recall without session filter returns all matching entries
|
||||||
|
let results = mem.recall("fact", 10, None).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn cross_session_recall_isolation() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store(
|
||||||
|
"secret",
|
||||||
|
"session A secret data",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
Some("sess-a"),
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Session B cannot see session A data
|
||||||
|
let results = mem.recall("secret", 10, Some("sess-b")).await.unwrap();
|
||||||
|
assert!(results.is_empty());
|
||||||
|
|
||||||
|
// Session A can see its own data
|
||||||
|
let results = mem.recall("secret", 10, Some("sess-a")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn list_with_session_filter() {
|
||||||
|
let (_tmp, mem) = temp_sqlite();
|
||||||
|
mem.store("k1", "a1", MemoryCategory::Core, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k2", "a2", MemoryCategory::Conversation, Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k3", "b1", MemoryCategory::Core, Some("sess-b"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
mem.store("k4", "none1", MemoryCategory::Core, None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// List with session-a filter
|
||||||
|
let results = mem.list(None, Some("sess-a")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 2);
|
||||||
|
assert!(results
|
||||||
|
.iter()
|
||||||
|
.all(|e| e.session_id.as_deref() == Some("sess-a")));
|
||||||
|
|
||||||
|
// List with session-a + category filter
|
||||||
|
let results = mem
|
||||||
|
.list(Some(&MemoryCategory::Core), Some("sess-a"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].key, "k1");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn schema_migration_idempotent_on_reopen() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
|
||||||
|
// First open: creates schema + migration
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
mem.store("k1", "before reopen", MemoryCategory::Core, Some("sess-x"))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Second open: migration runs again but is idempotent
|
||||||
|
{
|
||||||
|
let mem = SqliteMemory::new(tmp.path()).unwrap();
|
||||||
|
let results = mem.recall("reopen", 10, Some("sess-x")).await.unwrap();
|
||||||
|
assert_eq!(results.len(), 1);
|
||||||
|
assert_eq!(results[0].key, "k1");
|
||||||
|
assert_eq!(results[0].session_id.as_deref(), Some("sess-x"));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -44,18 +44,32 @@ pub trait Memory: Send + Sync {
|
||||||
/// Backend name
|
/// Backend name
|
||||||
fn name(&self) -> &str;
|
fn name(&self) -> &str;
|
||||||
|
|
||||||
/// Store a memory entry
|
/// Store a memory entry, optionally scoped to a session
|
||||||
async fn store(&self, key: &str, content: &str, category: MemoryCategory)
|
async fn store(
|
||||||
-> anyhow::Result<()>;
|
&self,
|
||||||
|
key: &str,
|
||||||
|
content: &str,
|
||||||
|
category: MemoryCategory,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<()>;
|
||||||
|
|
||||||
/// Recall memories matching a query (keyword search)
|
/// Recall memories matching a query (keyword search), optionally scoped to a session
|
||||||
async fn recall(&self, query: &str, limit: usize) -> anyhow::Result<Vec<MemoryEntry>>;
|
async fn recall(
|
||||||
|
&self,
|
||||||
|
query: &str,
|
||||||
|
limit: usize,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||||
|
|
||||||
/// Get a specific memory by key
|
/// Get a specific memory by key
|
||||||
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
async fn get(&self, key: &str) -> anyhow::Result<Option<MemoryEntry>>;
|
||||||
|
|
||||||
/// List all memory keys, optionally filtered by category
|
/// List all memory keys, optionally filtered by category and/or session
|
||||||
async fn list(&self, category: Option<&MemoryCategory>) -> anyhow::Result<Vec<MemoryEntry>>;
|
async fn list(
|
||||||
|
&self,
|
||||||
|
category: Option<&MemoryCategory>,
|
||||||
|
session_id: Option<&str>,
|
||||||
|
) -> anyhow::Result<Vec<MemoryEntry>>;
|
||||||
|
|
||||||
/// Remove a memory by key
|
/// Remove a memory by key
|
||||||
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
async fn forget(&self, key: &str) -> anyhow::Result<bool>;
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,9 @@ async fn migrate_openclaw_memory(
|
||||||
stats.renamed_conflicts += 1;
|
stats.renamed_conflicts += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
memory.store(&key, &entry.content, entry.category).await?;
|
memory
|
||||||
|
.store(&key, &entry.content, entry.category, None)
|
||||||
|
.await?;
|
||||||
stats.imported += 1;
|
stats.imported += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -488,7 +490,7 @@ mod tests {
|
||||||
// Existing target memory
|
// Existing target memory
|
||||||
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
let target_mem = SqliteMemory::new(target.path()).unwrap();
|
||||||
target_mem
|
target_mem
|
||||||
.store("k", "new value", MemoryCategory::Core)
|
.store("k", "new value", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -510,7 +512,7 @@ mod tests {
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let all = target_mem.list(None).await.unwrap();
|
let all = target_mem.list(None, None).await.unwrap();
|
||||||
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
|
assert!(all.iter().any(|e| e.key == "k" && e.content == "new value"));
|
||||||
assert!(all
|
assert!(all
|
||||||
.iter()
|
.iter()
|
||||||
|
|
|
||||||
|
|
@ -48,9 +48,10 @@ impl Observer for LogObserver {
|
||||||
ObserverEvent::AgentEnd {
|
ObserverEvent::AgentEnd {
|
||||||
duration,
|
duration,
|
||||||
tokens_used,
|
tokens_used,
|
||||||
|
cost_usd,
|
||||||
} => {
|
} => {
|
||||||
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
|
let ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
|
||||||
info!(duration_ms = ms, tokens = ?tokens_used, "agent.end");
|
info!(duration_ms = ms, tokens = ?tokens_used, cost_usd = ?cost_usd, "agent.end");
|
||||||
}
|
}
|
||||||
ObserverEvent::ToolCallStart { tool } => {
|
ObserverEvent::ToolCallStart { tool } => {
|
||||||
info!(tool = %tool, "tool.start");
|
info!(tool = %tool, "tool.start");
|
||||||
|
|
@ -133,10 +134,12 @@ mod tests {
|
||||||
obs.record_event(&ObserverEvent::AgentEnd {
|
obs.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration: Duration::from_millis(500),
|
duration: Duration::from_millis(500),
|
||||||
tokens_used: Some(100),
|
tokens_used: Some(100),
|
||||||
|
cost_usd: Some(0.0015),
|
||||||
});
|
});
|
||||||
obs.record_event(&ObserverEvent::AgentEnd {
|
obs.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration: Duration::ZERO,
|
duration: Duration::ZERO,
|
||||||
tokens_used: None,
|
tokens_used: None,
|
||||||
|
cost_usd: None,
|
||||||
});
|
});
|
||||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||||
tool: "shell".into(),
|
tool: "shell".into(),
|
||||||
|
|
|
||||||
|
|
@ -48,10 +48,12 @@ mod tests {
|
||||||
obs.record_event(&ObserverEvent::AgentEnd {
|
obs.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration: Duration::from_millis(100),
|
duration: Duration::from_millis(100),
|
||||||
tokens_used: Some(42),
|
tokens_used: Some(42),
|
||||||
|
cost_usd: Some(0.001),
|
||||||
});
|
});
|
||||||
obs.record_event(&ObserverEvent::AgentEnd {
|
obs.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration: Duration::ZERO,
|
duration: Duration::ZERO,
|
||||||
tokens_used: None,
|
tokens_used: None,
|
||||||
|
cost_usd: None,
|
||||||
});
|
});
|
||||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||||
tool: "shell".into(),
|
tool: "shell".into(),
|
||||||
|
|
|
||||||
|
|
@ -227,6 +227,7 @@ impl Observer for OtelObserver {
|
||||||
ObserverEvent::AgentEnd {
|
ObserverEvent::AgentEnd {
|
||||||
duration,
|
duration,
|
||||||
tokens_used,
|
tokens_used,
|
||||||
|
cost_usd,
|
||||||
} => {
|
} => {
|
||||||
let secs = duration.as_secs_f64();
|
let secs = duration.as_secs_f64();
|
||||||
let start_time = SystemTime::now()
|
let start_time = SystemTime::now()
|
||||||
|
|
@ -243,6 +244,9 @@ impl Observer for OtelObserver {
|
||||||
if let Some(t) = tokens_used {
|
if let Some(t) = tokens_used {
|
||||||
span.set_attribute(KeyValue::new("tokens_used", *t as i64));
|
span.set_attribute(KeyValue::new("tokens_used", *t as i64));
|
||||||
}
|
}
|
||||||
|
if let Some(c) = cost_usd {
|
||||||
|
span.set_attribute(KeyValue::new("cost_usd", *c));
|
||||||
|
}
|
||||||
span.end();
|
span.end();
|
||||||
|
|
||||||
self.agent_duration.record(secs, &[]);
|
self.agent_duration.record(secs, &[]);
|
||||||
|
|
@ -394,10 +398,12 @@ mod tests {
|
||||||
obs.record_event(&ObserverEvent::AgentEnd {
|
obs.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration: Duration::from_millis(500),
|
duration: Duration::from_millis(500),
|
||||||
tokens_used: Some(100),
|
tokens_used: Some(100),
|
||||||
|
cost_usd: Some(0.0015),
|
||||||
});
|
});
|
||||||
obs.record_event(&ObserverEvent::AgentEnd {
|
obs.record_event(&ObserverEvent::AgentEnd {
|
||||||
duration: Duration::ZERO,
|
duration: Duration::ZERO,
|
||||||
tokens_used: None,
|
tokens_used: None,
|
||||||
|
cost_usd: None,
|
||||||
});
|
});
|
||||||
obs.record_event(&ObserverEvent::ToolCallStart {
|
obs.record_event(&ObserverEvent::ToolCallStart {
|
||||||
tool: "shell".into(),
|
tool: "shell".into(),
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,7 @@ pub enum ObserverEvent {
|
||||||
AgentEnd {
|
AgentEnd {
|
||||||
duration: Duration,
|
duration: Duration,
|
||||||
tokens_used: Option<u64>,
|
tokens_used: Option<u64>,
|
||||||
|
cost_usd: Option<f64>,
|
||||||
},
|
},
|
||||||
/// A tool call is about to be executed.
|
/// A tool call is about to be executed.
|
||||||
ToolCallStart {
|
ToolCallStart {
|
||||||
|
|
|
||||||
|
|
@ -106,6 +106,7 @@ pub fn run_wizard() -> Result<Config> {
|
||||||
} else {
|
} else {
|
||||||
Some(api_key)
|
Some(api_key)
|
||||||
},
|
},
|
||||||
|
api_url: None,
|
||||||
default_provider: Some(provider),
|
default_provider: Some(provider),
|
||||||
default_model: Some(model),
|
default_model: Some(model),
|
||||||
default_temperature: 0.7,
|
default_temperature: 0.7,
|
||||||
|
|
@ -284,7 +285,7 @@ fn memory_config_defaults_for_backend(backend: &str) -> MemoryConfig {
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
#[allow(clippy::too_many_lines)]
|
||||||
pub fn run_quick_setup(
|
pub fn run_quick_setup(
|
||||||
api_key: Option<&str>,
|
credential_override: Option<&str>,
|
||||||
provider: Option<&str>,
|
provider: Option<&str>,
|
||||||
memory_backend: Option<&str>,
|
memory_backend: Option<&str>,
|
||||||
) -> Result<Config> {
|
) -> Result<Config> {
|
||||||
|
|
@ -318,7 +319,8 @@ pub fn run_quick_setup(
|
||||||
let config = Config {
|
let config = Config {
|
||||||
workspace_dir: workspace_dir.clone(),
|
workspace_dir: workspace_dir.clone(),
|
||||||
config_path: config_path.clone(),
|
config_path: config_path.clone(),
|
||||||
api_key: api_key.map(String::from),
|
api_key: credential_override.map(String::from),
|
||||||
|
api_url: None,
|
||||||
default_provider: Some(provider_name.clone()),
|
default_provider: Some(provider_name.clone()),
|
||||||
default_model: Some(model.clone()),
|
default_model: Some(model.clone()),
|
||||||
default_temperature: 0.7,
|
default_temperature: 0.7,
|
||||||
|
|
@ -377,7 +379,7 @@ pub fn run_quick_setup(
|
||||||
println!(
|
println!(
|
||||||
" {} API Key: {}",
|
" {} API Key: {}",
|
||||||
style("✓").green().bold(),
|
style("✓").green().bold(),
|
||||||
if api_key.is_some() {
|
if credential_override.is_some() {
|
||||||
style("set").green()
|
style("set").green()
|
||||||
} else {
|
} else {
|
||||||
style("not set (use --api-key or edit config.toml)").yellow()
|
style("not set (use --api-key or edit config.toml)").yellow()
|
||||||
|
|
@ -426,7 +428,7 @@ pub fn run_quick_setup(
|
||||||
);
|
);
|
||||||
println!();
|
println!();
|
||||||
println!(" {}", style("Next steps:").white().bold());
|
println!(" {}", style("Next steps:").white().bold());
|
||||||
if api_key.is_none() {
|
if credential_override.is_none() {
|
||||||
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
|
println!(" 1. Set your API key: export OPENROUTER_API_KEY=\"sk-...\"");
|
||||||
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
|
println!(" 2. Or edit: ~/.zeroclaw/config.toml");
|
||||||
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
|
println!(" 3. Chat: zeroclaw agent -m \"Hello!\"");
|
||||||
|
|
@ -2269,14 +2271,11 @@ fn setup_memory() -> Result<MemoryConfig> {
|
||||||
let backend = backend_key_from_choice(choice);
|
let backend = backend_key_from_choice(choice);
|
||||||
let profile = memory_backend_profile(backend);
|
let profile = memory_backend_profile(backend);
|
||||||
|
|
||||||
let auto_save = if !profile.auto_save_default {
|
let auto_save = profile.auto_save_default
|
||||||
false
|
&& Confirm::new()
|
||||||
} else {
|
|
||||||
Confirm::new()
|
|
||||||
.with_prompt(" Auto-save conversations to memory?")
|
.with_prompt(" Auto-save conversations to memory?")
|
||||||
.default(true)
|
.default(true)
|
||||||
.interact()?
|
.interact()?;
|
||||||
};
|
|
||||||
|
|
||||||
println!(
|
println!(
|
||||||
" {} Memory: {} (auto-save: {})",
|
" {} Memory: {} (auto-save: {})",
|
||||||
|
|
@ -2587,6 +2586,7 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
guild_id: if guild.is_empty() { None } else { Some(guild) },
|
||||||
allowed_users,
|
allowed_users,
|
||||||
listen_to_bots: false,
|
listen_to_bots: false,
|
||||||
|
mention_only: false,
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
2 => {
|
2 => {
|
||||||
|
|
@ -2799,22 +2799,14 @@ fn setup_channels() -> Result<ChannelsConfig> {
|
||||||
.header("Authorization", format!("Bearer {access_token_clone}"))
|
.header("Authorization", format!("Bearer {access_token_clone}"))
|
||||||
.send()?;
|
.send()?;
|
||||||
let ok = resp.status().is_success();
|
let ok = resp.status().is_success();
|
||||||
let data: serde_json::Value = resp.json().unwrap_or_default();
|
Ok::<_, reqwest::Error>(ok)
|
||||||
let user_id = data
|
|
||||||
.get("user_id")
|
|
||||||
.and_then(serde_json::Value::as_str)
|
|
||||||
.unwrap_or("unknown")
|
|
||||||
.to_string();
|
|
||||||
Ok::<_, reqwest::Error>((ok, user_id))
|
|
||||||
})
|
})
|
||||||
.join();
|
.join();
|
||||||
match thread_result {
|
match thread_result {
|
||||||
Ok(Ok((true, user_id))) => {
|
Ok(Ok(true)) => println!(
|
||||||
println!(
|
"\r {} Connection verified ",
|
||||||
"\r {} Connected as {user_id} ",
|
style("✅").green().bold()
|
||||||
style("✅").green().bold()
|
),
|
||||||
);
|
|
||||||
}
|
|
||||||
_ => {
|
_ => {
|
||||||
println!(
|
println!(
|
||||||
"\r {} Connection failed — check homeserver URL and token",
|
"\r {} Connection failed — check homeserver URL and token",
|
||||||
|
|
@ -3779,15 +3771,7 @@ fn print_summary(config: &Config) {
|
||||||
);
|
);
|
||||||
|
|
||||||
// Secrets
|
// Secrets
|
||||||
println!(
|
println!(" {} Secrets: configured", style("🔒").cyan());
|
||||||
" {} Secrets: {}",
|
|
||||||
style("🔒").cyan(),
|
|
||||||
if config.secrets.encrypt {
|
|
||||||
style("encrypted").green().to_string()
|
|
||||||
} else {
|
|
||||||
style("plaintext").yellow().to_string()
|
|
||||||
}
|
|
||||||
);
|
|
||||||
|
|
||||||
// Gateway
|
// Gateway
|
||||||
println!(
|
println!(
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,10 @@ pub fn ensure_arduino_cli() -> Result<()> {
|
||||||
anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/");
|
anyhow::bail!("brew install arduino-cli failed. Install manually: https://arduino.github.io/arduino-cli/");
|
||||||
}
|
}
|
||||||
println!("arduino-cli installed.");
|
println!("arduino-cli installed.");
|
||||||
|
if !arduino_cli_available() {
|
||||||
|
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
|
||||||
|
}
|
||||||
|
return Ok(());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(target_os = "linux")]
|
#[cfg(target_os = "linux")]
|
||||||
|
|
@ -54,11 +58,6 @@ pub fn ensure_arduino_cli() -> Result<()> {
|
||||||
println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/");
|
println!("arduino-cli not found. Install it: https://arduino.github.io/arduino-cli/");
|
||||||
anyhow::bail!("arduino-cli not installed.");
|
anyhow::bail!("arduino-cli not installed.");
|
||||||
}
|
}
|
||||||
|
|
||||||
if !arduino_cli_available() {
|
|
||||||
anyhow::bail!("arduino-cli still not found after install. Ensure it's in PATH.");
|
|
||||||
}
|
|
||||||
Ok(())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Ensure arduino:avr core is installed.
|
/// Ensure arduino:avr core is installed.
|
||||||
|
|
|
||||||
|
|
@ -112,6 +112,7 @@ pub struct SerialPeripheral {
|
||||||
|
|
||||||
impl SerialPeripheral {
|
impl SerialPeripheral {
|
||||||
/// Create and connect to a serial peripheral.
|
/// Create and connect to a serial peripheral.
|
||||||
|
#[allow(clippy::unused_async)]
|
||||||
pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> {
|
pub async fn connect(config: &PeripheralBoardConfig) -> anyhow::Result<Self> {
|
||||||
let path = config
|
let path = config
|
||||||
.path
|
.path
|
||||||
|
|
|
||||||
|
|
@ -106,17 +106,17 @@ struct NativeContentIn {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl AnthropicProvider {
|
impl AnthropicProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(credential: Option<&str>) -> Self {
|
||||||
Self::with_base_url(api_key, None)
|
Self::with_base_url(credential, None)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn with_base_url(api_key: Option<&str>, base_url: Option<&str>) -> Self {
|
pub fn with_base_url(credential: Option<&str>, base_url: Option<&str>) -> Self {
|
||||||
let base_url = base_url
|
let base_url = base_url
|
||||||
.map(|u| u.trim_end_matches('/'))
|
.map(|u| u.trim_end_matches('/'))
|
||||||
.unwrap_or("https://api.anthropic.com")
|
.unwrap_or("https://api.anthropic.com")
|
||||||
.to_string();
|
.to_string();
|
||||||
Self {
|
Self {
|
||||||
credential: api_key
|
credential: credential
|
||||||
.map(str::trim)
|
.map(str::trim)
|
||||||
.filter(|k| !k.is_empty())
|
.filter(|k| !k.is_empty())
|
||||||
.map(ToString::to_string),
|
.map(ToString::to_string),
|
||||||
|
|
@ -410,9 +410,9 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let p = AnthropicProvider::new(Some("sk-ant-test123"));
|
let p = AnthropicProvider::new(Some("anthropic-test-credential"));
|
||||||
assert!(p.credential.is_some());
|
assert!(p.credential.is_some());
|
||||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||||
assert_eq!(p.base_url, "https://api.anthropic.com");
|
assert_eq!(p.base_url, "https://api.anthropic.com");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -431,17 +431,19 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_whitespace_key() {
|
fn creates_with_whitespace_key() {
|
||||||
let p = AnthropicProvider::new(Some(" sk-ant-test123 "));
|
let p = AnthropicProvider::new(Some(" anthropic-test-credential "));
|
||||||
assert!(p.credential.is_some());
|
assert!(p.credential.is_some());
|
||||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test123"));
|
assert_eq!(p.credential.as_deref(), Some("anthropic-test-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_custom_base_url() {
|
fn creates_with_custom_base_url() {
|
||||||
let p =
|
let p = AnthropicProvider::with_base_url(
|
||||||
AnthropicProvider::with_base_url(Some("sk-ant-test"), Some("https://api.example.com"));
|
Some("anthropic-credential"),
|
||||||
|
Some("https://api.example.com"),
|
||||||
|
);
|
||||||
assert_eq!(p.base_url, "https://api.example.com");
|
assert_eq!(p.base_url, "https://api.example.com");
|
||||||
assert_eq!(p.credential.as_deref(), Some("sk-ant-test"));
|
assert_eq!(p.credential.as_deref(), Some("anthropic-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
|
||||||
|
|
@ -17,7 +17,7 @@ use serde::{Deserialize, Serialize};
|
||||||
pub struct OpenAiCompatibleProvider {
|
pub struct OpenAiCompatibleProvider {
|
||||||
pub(crate) name: String,
|
pub(crate) name: String,
|
||||||
pub(crate) base_url: String,
|
pub(crate) base_url: String,
|
||||||
pub(crate) api_key: Option<String>,
|
pub(crate) credential: Option<String>,
|
||||||
pub(crate) auth_header: AuthStyle,
|
pub(crate) auth_header: AuthStyle,
|
||||||
/// When false, do not fall back to /v1/responses on chat completions 404.
|
/// When false, do not fall back to /v1/responses on chat completions 404.
|
||||||
/// GLM/Zhipu does not support the responses API.
|
/// GLM/Zhipu does not support the responses API.
|
||||||
|
|
@ -37,11 +37,16 @@ pub enum AuthStyle {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiCompatibleProvider {
|
impl OpenAiCompatibleProvider {
|
||||||
pub fn new(name: &str, base_url: &str, api_key: Option<&str>, auth_style: AuthStyle) -> Self {
|
pub fn new(
|
||||||
|
name: &str,
|
||||||
|
base_url: &str,
|
||||||
|
credential: Option<&str>,
|
||||||
|
auth_style: AuthStyle,
|
||||||
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
auth_header: auth_style,
|
auth_header: auth_style,
|
||||||
supports_responses_fallback: true,
|
supports_responses_fallback: true,
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
|
|
@ -57,13 +62,13 @@ impl OpenAiCompatibleProvider {
|
||||||
pub fn new_no_responses_fallback(
|
pub fn new_no_responses_fallback(
|
||||||
name: &str,
|
name: &str,
|
||||||
base_url: &str,
|
base_url: &str,
|
||||||
api_key: Option<&str>,
|
credential: Option<&str>,
|
||||||
auth_style: AuthStyle,
|
auth_style: AuthStyle,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
name: name.to_string(),
|
name: name.to_string(),
|
||||||
base_url: base_url.trim_end_matches('/').to_string(),
|
base_url: base_url.trim_end_matches('/').to_string(),
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
auth_header: auth_style,
|
auth_header: auth_style,
|
||||||
supports_responses_fallback: false,
|
supports_responses_fallback: false,
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
|
|
@ -405,18 +410,18 @@ impl OpenAiCompatibleProvider {
|
||||||
fn apply_auth_header(
|
fn apply_auth_header(
|
||||||
&self,
|
&self,
|
||||||
req: reqwest::RequestBuilder,
|
req: reqwest::RequestBuilder,
|
||||||
api_key: &str,
|
credential: &str,
|
||||||
) -> reqwest::RequestBuilder {
|
) -> reqwest::RequestBuilder {
|
||||||
match &self.auth_header {
|
match &self.auth_header {
|
||||||
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {api_key}")),
|
AuthStyle::Bearer => req.header("Authorization", format!("Bearer {credential}")),
|
||||||
AuthStyle::XApiKey => req.header("x-api-key", api_key),
|
AuthStyle::XApiKey => req.header("x-api-key", credential),
|
||||||
AuthStyle::Custom(header) => req.header(header, api_key),
|
AuthStyle::Custom(header) => req.header(header, credential),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn chat_via_responses(
|
async fn chat_via_responses(
|
||||||
&self,
|
&self,
|
||||||
api_key: &str,
|
credential: &str,
|
||||||
system_prompt: Option<&str>,
|
system_prompt: Option<&str>,
|
||||||
message: &str,
|
message: &str,
|
||||||
model: &str,
|
model: &str,
|
||||||
|
|
@ -434,7 +439,7 @@ impl OpenAiCompatibleProvider {
|
||||||
let url = self.responses_url();
|
let url = self.responses_url();
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
@ -459,7 +464,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
self.name
|
self.name
|
||||||
|
|
@ -490,7 +495,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
|
|
||||||
let response = self
|
let response = self
|
||||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
@ -501,7 +506,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
if status == reqwest::StatusCode::NOT_FOUND && self.supports_responses_fallback {
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(api_key, system_prompt, message, model)
|
.chat_via_responses(credential, system_prompt, message, model)
|
||||||
.await
|
.await
|
||||||
.map_err(|responses_err| {
|
.map_err(|responses_err| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
|
|
@ -545,7 +550,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
"{} API key not set. Run `zeroclaw onboard` or set the appropriate env var.",
|
||||||
self.name
|
self.name
|
||||||
|
|
@ -569,7 +574,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
|
|
||||||
let url = self.chat_completions_url();
|
let url = self.chat_completions_url();
|
||||||
let response = self
|
let response = self
|
||||||
.apply_auth_header(self.client.post(&url).json(&request), api_key)
|
.apply_auth_header(self.client.post(&url).json(&request), credential)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
|
|
@ -584,7 +589,7 @@ impl Provider for OpenAiCompatibleProvider {
|
||||||
if let Some(user_msg) = last_user {
|
if let Some(user_msg) = last_user {
|
||||||
return self
|
return self
|
||||||
.chat_via_responses(
|
.chat_via_responses(
|
||||||
api_key,
|
credential,
|
||||||
system.map(|m| m.content.as_str()),
|
system.map(|m| m.content.as_str()),
|
||||||
&user_msg.content,
|
&user_msg.content,
|
||||||
model,
|
model,
|
||||||
|
|
@ -791,16 +796,20 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let p = make_provider("venice", "https://api.venice.ai", Some("vn-key"));
|
let p = make_provider(
|
||||||
|
"venice",
|
||||||
|
"https://api.venice.ai",
|
||||||
|
Some("venice-test-credential"),
|
||||||
|
);
|
||||||
assert_eq!(p.name, "venice");
|
assert_eq!(p.name, "venice");
|
||||||
assert_eq!(p.base_url, "https://api.venice.ai");
|
assert_eq!(p.base_url, "https://api.venice.ai");
|
||||||
assert_eq!(p.api_key.as_deref(), Some("vn-key"));
|
assert_eq!(p.credential.as_deref(), Some("venice-test-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_without_key() {
|
fn creates_without_key() {
|
||||||
let p = make_provider("test", "https://example.com", None);
|
let p = make_provider("test", "https://example.com", None);
|
||||||
assert!(p.api_key.is_none());
|
assert!(p.credential.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -894,6 +903,7 @@ mod tests {
|
||||||
make_provider("Groq", "https://api.groq.com/openai", None),
|
make_provider("Groq", "https://api.groq.com/openai", None),
|
||||||
make_provider("Mistral", "https://api.mistral.ai", None),
|
make_provider("Mistral", "https://api.mistral.ai", None),
|
||||||
make_provider("xAI", "https://api.x.ai", None),
|
make_provider("xAI", "https://api.x.ai", None),
|
||||||
|
make_provider("Astrai", "https://as-trai.com/v1", None),
|
||||||
];
|
];
|
||||||
|
|
||||||
for p in providers {
|
for p in providers {
|
||||||
|
|
|
||||||
705
src/providers/copilot.rs
Normal file
705
src/providers/copilot.rs
Normal file
|
|
@ -0,0 +1,705 @@
|
||||||
|
//! GitHub Copilot provider with OAuth device-flow authentication.
|
||||||
|
//!
|
||||||
|
//! Authenticates via GitHub's device code flow (same as VS Code Copilot),
|
||||||
|
//! then exchanges the OAuth token for short-lived Copilot API keys.
|
||||||
|
//! Tokens are cached to disk and auto-refreshed.
|
||||||
|
//!
|
||||||
|
//! **Note:** This uses VS Code's OAuth client ID (`Iv1.b507a08c87ecfe98`) and
|
||||||
|
//! editor headers. This is the same approach used by LiteLLM, Codex CLI,
|
||||||
|
//! and other third-party Copilot integrations. The Copilot token endpoint is
|
||||||
|
//! private; there is no public OAuth scope or app registration for it.
|
||||||
|
//! GitHub could change or revoke this at any time, which would break all
|
||||||
|
//! third-party integrations simultaneously.
|
||||||
|
|
||||||
|
use crate::providers::traits::{
|
||||||
|
ChatMessage, ChatRequest as ProviderChatRequest, ChatResponse as ProviderChatResponse,
|
||||||
|
Provider, ToolCall as ProviderToolCall,
|
||||||
|
};
|
||||||
|
use crate::tools::ToolSpec;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::path::{Path, PathBuf};
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
use tokio::sync::Mutex;
|
||||||
|
use tracing::warn;
|
||||||
|
|
||||||
|
/// GitHub OAuth client ID for Copilot (VS Code extension).
|
||||||
|
const GITHUB_CLIENT_ID: &str = "Iv1.b507a08c87ecfe98";
|
||||||
|
const GITHUB_DEVICE_CODE_URL: &str = "https://github.com/login/device/code";
|
||||||
|
const GITHUB_ACCESS_TOKEN_URL: &str = "https://github.com/login/oauth/access_token";
|
||||||
|
const GITHUB_API_KEY_URL: &str = "https://api.github.com/copilot_internal/v2/token";
|
||||||
|
const DEFAULT_API: &str = "https://api.githubcopilot.com";
|
||||||
|
|
||||||
|
// ── Token types ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct DeviceCodeResponse {
|
||||||
|
device_code: String,
|
||||||
|
user_code: String,
|
||||||
|
verification_uri: String,
|
||||||
|
#[serde(default = "default_interval")]
|
||||||
|
interval: u64,
|
||||||
|
#[serde(default = "default_expires_in")]
|
||||||
|
expires_in: u64,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_interval() -> u64 {
|
||||||
|
5
|
||||||
|
}
|
||||||
|
|
||||||
|
fn default_expires_in() -> u64 {
|
||||||
|
900
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct AccessTokenResponse {
|
||||||
|
access_token: Option<String>,
|
||||||
|
error: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct ApiKeyInfo {
|
||||||
|
token: String,
|
||||||
|
expires_at: i64,
|
||||||
|
#[serde(default)]
|
||||||
|
endpoints: Option<ApiEndpoints>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct ApiEndpoints {
|
||||||
|
api: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct CachedApiKey {
|
||||||
|
token: String,
|
||||||
|
api_endpoint: String,
|
||||||
|
expires_at: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Chat completions types ───────────────────────────────────────
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct ApiChatRequest {
|
||||||
|
model: String,
|
||||||
|
messages: Vec<ApiMessage>,
|
||||||
|
temperature: f64,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tools: Option<Vec<NativeToolSpec>>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_choice: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct ApiMessage {
|
||||||
|
role: String,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_call_id: Option<String>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
tool_calls: Option<Vec<NativeToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeToolSpec {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
kind: String,
|
||||||
|
function: NativeToolFunctionSpec,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
struct NativeToolFunctionSpec {
|
||||||
|
name: String,
|
||||||
|
description: String,
|
||||||
|
parameters: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct NativeToolCall {
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
id: Option<String>,
|
||||||
|
#[serde(rename = "type", skip_serializing_if = "Option::is_none")]
|
||||||
|
kind: Option<String>,
|
||||||
|
function: NativeFunctionCall,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
struct NativeFunctionCall {
|
||||||
|
name: String,
|
||||||
|
arguments: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ApiChatResponse {
|
||||||
|
choices: Vec<Choice>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct Choice {
|
||||||
|
message: ResponseMessage,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct ResponseMessage {
|
||||||
|
#[serde(default)]
|
||||||
|
content: Option<String>,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Option<Vec<NativeToolCall>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ── Provider ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// GitHub Copilot provider with automatic OAuth and token refresh.
|
||||||
|
///
|
||||||
|
/// On first use, prompts the user to visit github.com/login/device.
|
||||||
|
/// Tokens are cached to `~/.config/zeroclaw/copilot/` and refreshed
|
||||||
|
/// automatically.
|
||||||
|
pub struct CopilotProvider {
|
||||||
|
github_token: Option<String>,
|
||||||
|
/// Mutex ensures only one caller refreshes tokens at a time,
|
||||||
|
/// preventing duplicate device flow prompts or redundant API calls.
|
||||||
|
refresh_lock: Arc<Mutex<Option<CachedApiKey>>>,
|
||||||
|
http: Client,
|
||||||
|
token_dir: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CopilotProvider {
|
||||||
|
pub fn new(github_token: Option<&str>) -> Self {
|
||||||
|
let token_dir = directories::ProjectDirs::from("", "", "zeroclaw")
|
||||||
|
.map(|dir| dir.config_dir().join("copilot"))
|
||||||
|
.unwrap_or_else(|| {
|
||||||
|
// Fall back to a user-specific temp directory to avoid
|
||||||
|
// shared-directory symlink attacks.
|
||||||
|
let user = std::env::var("USER")
|
||||||
|
.or_else(|_| std::env::var("USERNAME"))
|
||||||
|
.unwrap_or_else(|_| "unknown".to_string());
|
||||||
|
std::env::temp_dir().join(format!("zeroclaw-copilot-{user}"))
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Err(err) = std::fs::create_dir_all(&token_dir) {
|
||||||
|
warn!(
|
||||||
|
"Failed to create Copilot token directory {:?}: {err}. Token caching is disabled.",
|
||||||
|
token_dir
|
||||||
|
);
|
||||||
|
} else {
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::os::unix::fs::PermissionsExt;
|
||||||
|
|
||||||
|
if let Err(err) =
|
||||||
|
std::fs::set_permissions(&token_dir, std::fs::Permissions::from_mode(0o700))
|
||||||
|
{
|
||||||
|
warn!(
|
||||||
|
"Failed to set Copilot token directory permissions on {:?}: {err}",
|
||||||
|
token_dir
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Self {
|
||||||
|
github_token: github_token
|
||||||
|
.filter(|token| !token.is_empty())
|
||||||
|
.map(String::from),
|
||||||
|
refresh_lock: Arc::new(Mutex::new(None)),
|
||||||
|
http: Client::builder()
|
||||||
|
.timeout(Duration::from_secs(120))
|
||||||
|
.connect_timeout(Duration::from_secs(10))
|
||||||
|
.build()
|
||||||
|
.unwrap_or_else(|_| Client::new()),
|
||||||
|
token_dir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Required headers for Copilot API requests (editor identification).
|
||||||
|
const COPILOT_HEADERS: [(&str, &str); 4] = [
|
||||||
|
("Editor-Version", "vscode/1.85.1"),
|
||||||
|
("Editor-Plugin-Version", "copilot/1.155.0"),
|
||||||
|
("User-Agent", "GithubCopilot/1.155.0"),
|
||||||
|
("Accept", "application/json"),
|
||||||
|
];
|
||||||
|
|
||||||
|
fn convert_tools(tools: Option<&[ToolSpec]>) -> Option<Vec<NativeToolSpec>> {
|
||||||
|
tools.map(|items| {
|
||||||
|
items
|
||||||
|
.iter()
|
||||||
|
.map(|tool| NativeToolSpec {
|
||||||
|
kind: "function".to_string(),
|
||||||
|
function: NativeToolFunctionSpec {
|
||||||
|
name: tool.name.clone(),
|
||||||
|
description: tool.description.clone(),
|
||||||
|
parameters: tool.parameters.clone(),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
fn convert_messages(messages: &[ChatMessage]) -> Vec<ApiMessage> {
|
||||||
|
messages
|
||||||
|
.iter()
|
||||||
|
.map(|message| {
|
||||||
|
if message.role == "assistant" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||||
|
if let Some(tool_calls_value) = value.get("tool_calls") {
|
||||||
|
if let Ok(parsed_calls) =
|
||||||
|
serde_json::from_value::<Vec<ProviderToolCall>>(tool_calls_value.clone())
|
||||||
|
{
|
||||||
|
let tool_calls = parsed_calls
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool_call| NativeToolCall {
|
||||||
|
id: Some(tool_call.id),
|
||||||
|
kind: Some("function".to_string()),
|
||||||
|
function: NativeFunctionCall {
|
||||||
|
name: tool_call.name,
|
||||||
|
arguments: tool_call.arguments,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
|
||||||
|
return ApiMessage {
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content,
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: Some(tool_calls),
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if message.role == "tool" {
|
||||||
|
if let Ok(value) = serde_json::from_str::<serde_json::Value>(&message.content) {
|
||||||
|
let tool_call_id = value
|
||||||
|
.get("tool_call_id")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
let content = value
|
||||||
|
.get("content")
|
||||||
|
.and_then(serde_json::Value::as_str)
|
||||||
|
.map(ToString::to_string);
|
||||||
|
|
||||||
|
return ApiMessage {
|
||||||
|
role: "tool".to_string(),
|
||||||
|
content,
|
||||||
|
tool_call_id,
|
||||||
|
tool_calls: None,
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ApiMessage {
|
||||||
|
role: message.role.clone(),
|
||||||
|
content: Some(message.content.clone()),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: None,
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Send a chat completions request with required Copilot headers.
|
||||||
|
async fn send_chat_request(
|
||||||
|
&self,
|
||||||
|
messages: Vec<ApiMessage>,
|
||||||
|
tools: Option<&[ToolSpec]>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
let (token, endpoint) = self.get_api_key().await?;
|
||||||
|
let url = format!("{}/chat/completions", endpoint.trim_end_matches('/'));
|
||||||
|
|
||||||
|
let native_tools = Self::convert_tools(tools);
|
||||||
|
let request = ApiChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages,
|
||||||
|
temperature,
|
||||||
|
tool_choice: native_tools.as_ref().map(|_| "auto".to_string()),
|
||||||
|
tools: native_tools,
|
||||||
|
};
|
||||||
|
|
||||||
|
let mut req = self
|
||||||
|
.http
|
||||||
|
.post(&url)
|
||||||
|
.header("Authorization", format!("Bearer {token}"))
|
||||||
|
.json(&request);
|
||||||
|
|
||||||
|
for (header, value) in &Self::COPILOT_HEADERS {
|
||||||
|
req = req.header(*header, *value);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = req.send().await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
return Err(super::api_error("GitHub Copilot", response).await);
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_response: ApiChatResponse = response.json().await?;
|
||||||
|
let choice = api_response
|
||||||
|
.choices
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("No response from GitHub Copilot"))?;
|
||||||
|
|
||||||
|
let tool_calls = choice
|
||||||
|
.message
|
||||||
|
.tool_calls
|
||||||
|
.unwrap_or_default()
|
||||||
|
.into_iter()
|
||||||
|
.map(|tool_call| ProviderToolCall {
|
||||||
|
id: tool_call
|
||||||
|
.id
|
||||||
|
.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
|
||||||
|
name: tool_call.function.name,
|
||||||
|
arguments: tool_call.function.arguments,
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
Ok(ProviderChatResponse {
|
||||||
|
text: choice.message.content,
|
||||||
|
tool_calls,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a valid Copilot API key, refreshing or re-authenticating as needed.
|
||||||
|
/// Uses a Mutex to ensure only one caller refreshes at a time.
|
||||||
|
async fn get_api_key(&self) -> anyhow::Result<(String, String)> {
|
||||||
|
let mut cached = self.refresh_lock.lock().await;
|
||||||
|
|
||||||
|
if let Some(cached_key) = cached.as_ref() {
|
||||||
|
if chrono::Utc::now().timestamp() + 120 < cached_key.expires_at {
|
||||||
|
return Ok((cached_key.token.clone(), cached_key.api_endpoint.clone()));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(info) = self.load_api_key_from_disk().await {
|
||||||
|
if chrono::Utc::now().timestamp() + 120 < info.expires_at {
|
||||||
|
let endpoint = info
|
||||||
|
.endpoints
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|e| e.api.clone())
|
||||||
|
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||||
|
let token = info.token;
|
||||||
|
|
||||||
|
*cached = Some(CachedApiKey {
|
||||||
|
token: token.clone(),
|
||||||
|
api_endpoint: endpoint.clone(),
|
||||||
|
expires_at: info.expires_at,
|
||||||
|
});
|
||||||
|
return Ok((token, endpoint));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let access_token = self.get_github_access_token().await?;
|
||||||
|
let api_key_info = self.exchange_for_api_key(&access_token).await?;
|
||||||
|
self.save_api_key_to_disk(&api_key_info).await;
|
||||||
|
|
||||||
|
let endpoint = api_key_info
|
||||||
|
.endpoints
|
||||||
|
.as_ref()
|
||||||
|
.and_then(|e| e.api.clone())
|
||||||
|
.unwrap_or_else(|| DEFAULT_API.to_string());
|
||||||
|
|
||||||
|
*cached = Some(CachedApiKey {
|
||||||
|
token: api_key_info.token.clone(),
|
||||||
|
api_endpoint: endpoint.clone(),
|
||||||
|
expires_at: api_key_info.expires_at,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok((api_key_info.token, endpoint))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Get a GitHub access token from config, cache, or device flow.
|
||||||
|
async fn get_github_access_token(&self) -> anyhow::Result<String> {
|
||||||
|
if let Some(token) = &self.github_token {
|
||||||
|
return Ok(token.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
let access_token_path = self.token_dir.join("access-token");
|
||||||
|
if let Ok(cached) = tokio::fs::read_to_string(&access_token_path).await {
|
||||||
|
let token = cached.trim();
|
||||||
|
if !token.is_empty() {
|
||||||
|
return Ok(token.to_string());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let token = self.device_code_login().await?;
|
||||||
|
write_file_secure(&access_token_path, &token).await;
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Run GitHub OAuth device code flow.
|
||||||
|
async fn device_code_login(&self) -> anyhow::Result<String> {
|
||||||
|
let response: DeviceCodeResponse = self
|
||||||
|
.http
|
||||||
|
.post(GITHUB_DEVICE_CODE_URL)
|
||||||
|
.header("Accept", "application/json")
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"client_id": GITHUB_CLIENT_ID,
|
||||||
|
"scope": "read:user"
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.error_for_status()?
|
||||||
|
.json()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let mut poll_interval = Duration::from_secs(response.interval.max(5));
|
||||||
|
let expires_in = response.expires_in.max(1);
|
||||||
|
let expires_at = tokio::time::Instant::now() + Duration::from_secs(expires_in);
|
||||||
|
|
||||||
|
eprintln!(
|
||||||
|
"\nGitHub Copilot authentication is required.\n\
|
||||||
|
Visit: {}\n\
|
||||||
|
Code: {}\n\
|
||||||
|
Waiting for authorization...\n",
|
||||||
|
response.verification_uri, response.user_code
|
||||||
|
);
|
||||||
|
|
||||||
|
while tokio::time::Instant::now() < expires_at {
|
||||||
|
tokio::time::sleep(poll_interval).await;
|
||||||
|
|
||||||
|
let token_response: AccessTokenResponse = self
|
||||||
|
.http
|
||||||
|
.post(GITHUB_ACCESS_TOKEN_URL)
|
||||||
|
.header("Accept", "application/json")
|
||||||
|
.json(&serde_json::json!({
|
||||||
|
"client_id": GITHUB_CLIENT_ID,
|
||||||
|
"device_code": response.device_code,
|
||||||
|
"grant_type": "urn:ietf:params:oauth:grant-type:device_code"
|
||||||
|
}))
|
||||||
|
.send()
|
||||||
|
.await?
|
||||||
|
.json()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if let Some(token) = token_response.access_token {
|
||||||
|
eprintln!("Authentication succeeded.\n");
|
||||||
|
return Ok(token);
|
||||||
|
}
|
||||||
|
|
||||||
|
match token_response.error.as_deref() {
|
||||||
|
Some("slow_down") => {
|
||||||
|
poll_interval += Duration::from_secs(5);
|
||||||
|
}
|
||||||
|
Some("authorization_pending") | None => {}
|
||||||
|
Some("expired_token") => {
|
||||||
|
anyhow::bail!("GitHub device authorization expired")
|
||||||
|
}
|
||||||
|
Some(error) => anyhow::bail!("GitHub auth failed: {error}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!("Timed out waiting for GitHub authorization")
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Exchange a GitHub access token for a Copilot API key.
|
||||||
|
async fn exchange_for_api_key(&self, access_token: &str) -> anyhow::Result<ApiKeyInfo> {
|
||||||
|
let mut request = self.http.get(GITHUB_API_KEY_URL);
|
||||||
|
for (header, value) in &Self::COPILOT_HEADERS {
|
||||||
|
request = request.header(*header, *value);
|
||||||
|
}
|
||||||
|
request = request.header("Authorization", format!("token {access_token}"));
|
||||||
|
|
||||||
|
let response = request.send().await?;
|
||||||
|
|
||||||
|
if !response.status().is_success() {
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
let sanitized = super::sanitize_api_error(&body);
|
||||||
|
|
||||||
|
if status.as_u16() == 401 || status.as_u16() == 403 {
|
||||||
|
let access_token_path = self.token_dir.join("access-token");
|
||||||
|
tokio::fs::remove_file(&access_token_path).await.ok();
|
||||||
|
}
|
||||||
|
|
||||||
|
anyhow::bail!(
|
||||||
|
"Failed to get Copilot API key ({status}): {sanitized}. \
|
||||||
|
Ensure your GitHub account has an active Copilot subscription."
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let info: ApiKeyInfo = response.json().await?;
|
||||||
|
Ok(info)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn load_api_key_from_disk(&self) -> Option<ApiKeyInfo> {
|
||||||
|
let path = self.token_dir.join("api-key.json");
|
||||||
|
let data = tokio::fs::read_to_string(&path).await.ok()?;
|
||||||
|
serde_json::from_str(&data).ok()
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn save_api_key_to_disk(&self, info: &ApiKeyInfo) {
|
||||||
|
let path = self.token_dir.join("api-key.json");
|
||||||
|
if let Ok(json) = serde_json::to_string_pretty(info) {
|
||||||
|
write_file_secure(&path, &json).await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Write a file with 0600 permissions (owner read/write only).
|
||||||
|
/// Uses `spawn_blocking` to avoid blocking the async runtime.
|
||||||
|
async fn write_file_secure(path: &Path, content: &str) {
|
||||||
|
let path = path.to_path_buf();
|
||||||
|
let content = content.to_string();
|
||||||
|
|
||||||
|
let result = tokio::task::spawn_blocking(move || {
|
||||||
|
#[cfg(unix)]
|
||||||
|
{
|
||||||
|
use std::io::Write;
|
||||||
|
use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};
|
||||||
|
|
||||||
|
let mut file = std::fs::OpenOptions::new()
|
||||||
|
.write(true)
|
||||||
|
.create(true)
|
||||||
|
.truncate(true)
|
||||||
|
.mode(0o600)
|
||||||
|
.open(&path)?;
|
||||||
|
file.write_all(content.as_bytes())?;
|
||||||
|
|
||||||
|
std::fs::set_permissions(&path, std::fs::Permissions::from_mode(0o600))?;
|
||||||
|
Ok::<(), std::io::Error>(())
|
||||||
|
}
|
||||||
|
#[cfg(not(unix))]
|
||||||
|
{
|
||||||
|
std::fs::write(&path, &content)?;
|
||||||
|
Ok::<(), std::io::Error>(())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
match result {
|
||||||
|
Ok(Ok(())) => {}
|
||||||
|
Ok(Err(err)) => warn!("Failed to write secure file: {err}"),
|
||||||
|
Err(err) => warn!("Failed to spawn blocking write: {err}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for CopilotProvider {
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
system_prompt: Option<&str>,
|
||||||
|
message: &str,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let mut messages = Vec::new();
|
||||||
|
if let Some(system) = system_prompt {
|
||||||
|
messages.push(ApiMessage {
|
||||||
|
role: "system".to_string(),
|
||||||
|
content: Some(system.to_string()),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: None,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
messages.push(ApiMessage {
|
||||||
|
role: "user".to_string(),
|
||||||
|
content: Some(message.to_string()),
|
||||||
|
tool_call_id: None,
|
||||||
|
tool_calls: None,
|
||||||
|
});
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.send_chat_request(messages, None, model, temperature)
|
||||||
|
.await?;
|
||||||
|
Ok(response.text.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let response = self
|
||||||
|
.send_chat_request(Self::convert_messages(messages), None, model, temperature)
|
||||||
|
.await?;
|
||||||
|
Ok(response.text.unwrap_or_default())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat(
|
||||||
|
&self,
|
||||||
|
request: ProviderChatRequest<'_>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
|
self.send_chat_request(
|
||||||
|
Self::convert_messages(request.messages),
|
||||||
|
request.tools,
|
||||||
|
model,
|
||||||
|
temperature,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
true
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
|
let _ = self.get_api_key().await?;
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_without_token() {
|
||||||
|
let provider = CopilotProvider::new(None);
|
||||||
|
assert!(provider.github_token.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn new_with_token() {
|
||||||
|
let provider = CopilotProvider::new(Some("ghp_test"));
|
||||||
|
assert_eq!(provider.github_token.as_deref(), Some("ghp_test"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn empty_token_treated_as_none() {
|
||||||
|
let provider = CopilotProvider::new(Some(""));
|
||||||
|
assert!(provider.github_token.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn cache_starts_empty() {
|
||||||
|
let provider = CopilotProvider::new(None);
|
||||||
|
let cached = provider.refresh_lock.lock().await;
|
||||||
|
assert!(cached.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn copilot_headers_include_required_fields() {
|
||||||
|
let headers = CopilotProvider::COPILOT_HEADERS;
|
||||||
|
assert!(headers
|
||||||
|
.iter()
|
||||||
|
.any(|(header, _)| *header == "Editor-Version"));
|
||||||
|
assert!(headers
|
||||||
|
.iter()
|
||||||
|
.any(|(header, _)| *header == "Editor-Plugin-Version"));
|
||||||
|
assert!(headers.iter().any(|(header, _)| *header == "User-Agent"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn default_interval_and_expiry() {
|
||||||
|
assert_eq!(default_interval(), 5);
|
||||||
|
assert_eq!(default_expires_in(), 900);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn supports_native_tools() {
|
||||||
|
let provider = CopilotProvider::new(None);
|
||||||
|
assert!(provider.supports_native_tools());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
pub mod anthropic;
|
pub mod anthropic;
|
||||||
pub mod compatible;
|
pub mod compatible;
|
||||||
|
pub mod copilot;
|
||||||
pub mod gemini;
|
pub mod gemini;
|
||||||
pub mod ollama;
|
pub mod ollama;
|
||||||
pub mod openai;
|
pub mod openai;
|
||||||
|
|
@ -37,9 +38,18 @@ fn token_end(input: &str, from: usize) -> usize {
|
||||||
|
|
||||||
/// Scrub known secret-like token prefixes from provider error strings.
|
/// Scrub known secret-like token prefixes from provider error strings.
|
||||||
///
|
///
|
||||||
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, and `xoxp-`.
|
/// Redacts tokens with prefixes like `sk-`, `xoxb-`, `xoxp-`, `ghp_`, `gho_`,
|
||||||
|
/// `ghu_`, and `github_pat_`.
|
||||||
pub fn scrub_secret_patterns(input: &str) -> String {
|
pub fn scrub_secret_patterns(input: &str) -> String {
|
||||||
const PREFIXES: [&str; 3] = ["sk-", "xoxb-", "xoxp-"];
|
const PREFIXES: [&str; 7] = [
|
||||||
|
"sk-",
|
||||||
|
"xoxb-",
|
||||||
|
"xoxp-",
|
||||||
|
"ghp_",
|
||||||
|
"gho_",
|
||||||
|
"ghu_",
|
||||||
|
"github_pat_",
|
||||||
|
];
|
||||||
|
|
||||||
let mut scrubbed = input.to_string();
|
let mut scrubbed = input.to_string();
|
||||||
|
|
||||||
|
|
@ -104,9 +114,12 @@ pub async fn api_error(provider: &str, response: reqwest::Response) -> anyhow::E
|
||||||
///
|
///
|
||||||
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
|
/// For Anthropic, the provider-specific env var is `ANTHROPIC_OAUTH_TOKEN` (for setup-tokens)
|
||||||
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
|
/// followed by `ANTHROPIC_API_KEY` (for regular API keys).
|
||||||
fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
fn resolve_provider_credential(name: &str, credential_override: Option<&str>) -> Option<String> {
|
||||||
if let Some(key) = api_key.map(str::trim).filter(|k| !k.is_empty()) {
|
if let Some(raw_override) = credential_override {
|
||||||
return Some(key.to_string());
|
let trimmed_override = raw_override.trim();
|
||||||
|
if !trimmed_override.is_empty() {
|
||||||
|
return Some(trimmed_override.to_owned());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let provider_env_candidates: Vec<&str> = match name {
|
let provider_env_candidates: Vec<&str> = match name {
|
||||||
|
|
@ -135,6 +148,7 @@ fn resolve_api_key(name: &str, api_key: Option<&str>) -> Option<String> {
|
||||||
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
|
"opencode" | "opencode-zen" => vec!["OPENCODE_API_KEY"],
|
||||||
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
|
"vercel" | "vercel-ai" => vec!["VERCEL_API_KEY"],
|
||||||
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
|
"cloudflare" | "cloudflare-ai" => vec!["CLOUDFLARE_API_KEY"],
|
||||||
|
"astrai" => vec!["ASTRAI_API_KEY"],
|
||||||
_ => vec![],
|
_ => vec![],
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
@ -182,19 +196,28 @@ fn parse_custom_provider_url(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Factory: create the right provider from config
|
/// Factory: create the right provider from config (without custom URL)
|
||||||
#[allow(clippy::too_many_lines)]
|
|
||||||
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
let resolved_key = resolve_api_key(name, api_key);
|
create_provider_with_url(name, api_key, None)
|
||||||
let key = resolved_key.as_deref();
|
}
|
||||||
|
|
||||||
|
/// Factory: create the right provider from config with optional custom base URL
|
||||||
|
#[allow(clippy::too_many_lines)]
|
||||||
|
pub fn create_provider_with_url(
|
||||||
|
name: &str,
|
||||||
|
api_key: Option<&str>,
|
||||||
|
api_url: Option<&str>,
|
||||||
|
) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
|
let resolved_credential = resolve_provider_credential(name, api_key);
|
||||||
|
#[allow(clippy::option_as_ref_deref)]
|
||||||
|
let key = resolved_credential.as_ref().map(String::as_str);
|
||||||
match name {
|
match name {
|
||||||
// ── Primary providers (custom implementations) ───────
|
// ── Primary providers (custom implementations) ───────
|
||||||
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
|
"openrouter" => Ok(Box::new(openrouter::OpenRouterProvider::new(key))),
|
||||||
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
|
"anthropic" => Ok(Box::new(anthropic::AnthropicProvider::new(key))),
|
||||||
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
|
"openai" => Ok(Box::new(openai::OpenAiProvider::new(key))),
|
||||||
// Ollama is a local service that doesn't use API keys.
|
// Ollama uses api_url for custom base URL (e.g. remote Ollama instance)
|
||||||
// The api_key parameter is ignored to avoid it being misinterpreted as a base_url.
|
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(api_url))),
|
||||||
"ollama" => Ok(Box::new(ollama::OllamaProvider::new(None))),
|
|
||||||
"gemini" | "google" | "google-gemini" => {
|
"gemini" | "google" | "google-gemini" => {
|
||||||
Ok(Box::new(gemini::GeminiProvider::new(key)))
|
Ok(Box::new(gemini::GeminiProvider::new(key)))
|
||||||
}
|
}
|
||||||
|
|
@ -257,7 +280,7 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
||||||
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
|
"Groq", "https://api.groq.com/openai", key, AuthStyle::Bearer,
|
||||||
))),
|
))),
|
||||||
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
"mistral" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||||
"Mistral", "https://api.mistral.ai", key, AuthStyle::Bearer,
|
"Mistral", "https://api.mistral.ai/v1", key, AuthStyle::Bearer,
|
||||||
))),
|
))),
|
||||||
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
"xai" | "grok" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||||
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
|
"xAI", "https://api.x.ai", key, AuthStyle::Bearer,
|
||||||
|
|
@ -277,11 +300,33 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
||||||
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
"cohere" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||||
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
"Cohere", "https://api.cohere.com/compatibility", key, AuthStyle::Bearer,
|
||||||
))),
|
))),
|
||||||
"copilot" | "github-copilot" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
"copilot" | "github-copilot" => {
|
||||||
"GitHub Copilot", "https://api.githubcopilot.com", key, AuthStyle::Bearer,
|
Ok(Box::new(copilot::CopilotProvider::new(api_key)))
|
||||||
))),
|
},
|
||||||
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
"lmstudio" | "lm-studio" => {
|
||||||
"NVIDIA NIM", "https://integrate.api.nvidia.com/v1", key, AuthStyle::Bearer,
|
let lm_studio_key = api_key
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|value| !value.is_empty())
|
||||||
|
.unwrap_or("lm-studio");
|
||||||
|
Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||||
|
"LM Studio",
|
||||||
|
"http://localhost:1234/v1",
|
||||||
|
Some(lm_studio_key),
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
)))
|
||||||
|
}
|
||||||
|
"nvidia" | "nvidia-nim" | "build.nvidia.com" => Ok(Box::new(
|
||||||
|
OpenAiCompatibleProvider::new(
|
||||||
|
"NVIDIA NIM",
|
||||||
|
"https://integrate.api.nvidia.com/v1",
|
||||||
|
key,
|
||||||
|
AuthStyle::Bearer,
|
||||||
|
),
|
||||||
|
)),
|
||||||
|
|
||||||
|
// ── AI inference routers ─────────────────────────────
|
||||||
|
"astrai" => Ok(Box::new(OpenAiCompatibleProvider::new(
|
||||||
|
"Astrai", "https://as-trai.com/v1", key, AuthStyle::Bearer,
|
||||||
))),
|
))),
|
||||||
|
|
||||||
// ── Bring Your Own Provider (custom URL) ───────────
|
// ── Bring Your Own Provider (custom URL) ───────────
|
||||||
|
|
@ -326,13 +371,14 @@ pub fn create_provider(name: &str, api_key: Option<&str>) -> anyhow::Result<Box<
|
||||||
pub fn create_resilient_provider(
|
pub fn create_resilient_provider(
|
||||||
primary_name: &str,
|
primary_name: &str,
|
||||||
api_key: Option<&str>,
|
api_key: Option<&str>,
|
||||||
|
api_url: Option<&str>,
|
||||||
reliability: &crate::config::ReliabilityConfig,
|
reliability: &crate::config::ReliabilityConfig,
|
||||||
) -> anyhow::Result<Box<dyn Provider>> {
|
) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||||
|
|
||||||
providers.push((
|
providers.push((
|
||||||
primary_name.to_string(),
|
primary_name.to_string(),
|
||||||
create_provider(primary_name, api_key)?,
|
create_provider_with_url(primary_name, api_key, api_url)?,
|
||||||
));
|
));
|
||||||
|
|
||||||
for fallback in &reliability.fallback_providers {
|
for fallback in &reliability.fallback_providers {
|
||||||
|
|
@ -340,21 +386,13 @@ pub fn create_resilient_provider(
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
if api_key.is_some() && fallback != "ollama" {
|
// Fallback providers don't use the custom api_url (it's specific to primary)
|
||||||
tracing::warn!(
|
|
||||||
fallback_provider = fallback,
|
|
||||||
primary_provider = primary_name,
|
|
||||||
"Fallback provider will use the primary provider's API key — \
|
|
||||||
this will fail if the providers require different keys"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
|
|
||||||
match create_provider(fallback, api_key) {
|
match create_provider(fallback, api_key) {
|
||||||
Ok(provider) => providers.push((fallback.clone(), provider)),
|
Ok(provider) => providers.push((fallback.clone(), provider)),
|
||||||
Err(e) => {
|
Err(_error) => {
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
fallback_provider = fallback,
|
fallback_provider = fallback,
|
||||||
"Ignoring invalid fallback provider: {e}"
|
"Ignoring invalid fallback provider during initialization"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -377,12 +415,13 @@ pub fn create_resilient_provider(
|
||||||
pub fn create_routed_provider(
|
pub fn create_routed_provider(
|
||||||
primary_name: &str,
|
primary_name: &str,
|
||||||
api_key: Option<&str>,
|
api_key: Option<&str>,
|
||||||
|
api_url: Option<&str>,
|
||||||
reliability: &crate::config::ReliabilityConfig,
|
reliability: &crate::config::ReliabilityConfig,
|
||||||
model_routes: &[crate::config::ModelRouteConfig],
|
model_routes: &[crate::config::ModelRouteConfig],
|
||||||
default_model: &str,
|
default_model: &str,
|
||||||
) -> anyhow::Result<Box<dyn Provider>> {
|
) -> anyhow::Result<Box<dyn Provider>> {
|
||||||
if model_routes.is_empty() {
|
if model_routes.is_empty() {
|
||||||
return create_resilient_provider(primary_name, api_key, reliability);
|
return create_resilient_provider(primary_name, api_key, api_url, reliability);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect unique provider names needed
|
// Collect unique provider names needed
|
||||||
|
|
@ -396,12 +435,19 @@ pub fn create_routed_provider(
|
||||||
// Create each provider (with its own resilience wrapper)
|
// Create each provider (with its own resilience wrapper)
|
||||||
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
let mut providers: Vec<(String, Box<dyn Provider>)> = Vec::new();
|
||||||
for name in &needed {
|
for name in &needed {
|
||||||
let key = model_routes
|
let routed_credential = model_routes
|
||||||
.iter()
|
.iter()
|
||||||
.find(|r| &r.provider == name)
|
.find(|r| &r.provider == name)
|
||||||
.and_then(|r| r.api_key.as_deref())
|
.and_then(|r| {
|
||||||
.or(api_key);
|
r.api_key.as_ref().and_then(|raw_key| {
|
||||||
match create_resilient_provider(name, key, reliability) {
|
let trimmed_key = raw_key.trim();
|
||||||
|
(!trimmed_key.is_empty()).then_some(trimmed_key)
|
||||||
|
})
|
||||||
|
});
|
||||||
|
let key = routed_credential.or(api_key);
|
||||||
|
// Only use api_url for the primary provider
|
||||||
|
let url = if name == primary_name { api_url } else { None };
|
||||||
|
match create_resilient_provider(name, key, url, reliability) {
|
||||||
Ok(provider) => providers.push((name.clone(), provider)),
|
Ok(provider) => providers.push((name.clone(), provider)),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
if name == primary_name {
|
if name == primary_name {
|
||||||
|
|
@ -409,7 +455,7 @@ pub fn create_routed_provider(
|
||||||
}
|
}
|
||||||
tracing::warn!(
|
tracing::warn!(
|
||||||
provider = name.as_str(),
|
provider = name.as_str(),
|
||||||
"Ignoring routed provider that failed to create: {e}"
|
"Ignoring routed provider that failed to initialize"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -441,27 +487,27 @@ mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resolve_api_key_prefers_explicit_argument() {
|
fn resolve_provider_credential_prefers_explicit_argument() {
|
||||||
let resolved = resolve_api_key("openrouter", Some(" explicit-key "));
|
let resolved = resolve_provider_credential("openrouter", Some(" explicit-key "));
|
||||||
assert_eq!(resolved.as_deref(), Some("explicit-key"));
|
assert_eq!(resolved, Some("explicit-key".to_string()));
|
||||||
}
|
}
|
||||||
|
|
||||||
// ── Primary providers ────────────────────────────────────
|
// ── Primary providers ────────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_openrouter() {
|
fn factory_openrouter() {
|
||||||
assert!(create_provider("openrouter", Some("sk-test")).is_ok());
|
assert!(create_provider("openrouter", Some("provider-test-credential")).is_ok());
|
||||||
assert!(create_provider("openrouter", None).is_ok());
|
assert!(create_provider("openrouter", None).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_anthropic() {
|
fn factory_anthropic() {
|
||||||
assert!(create_provider("anthropic", Some("sk-test")).is_ok());
|
assert!(create_provider("anthropic", Some("provider-test-credential")).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_openai() {
|
fn factory_openai() {
|
||||||
assert!(create_provider("openai", Some("sk-test")).is_ok());
|
assert!(create_provider("openai", Some("provider-test-credential")).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -556,6 +602,13 @@ mod tests {
|
||||||
assert!(create_provider("dashscope-us", Some("key")).is_ok());
|
assert!(create_provider("dashscope-us", Some("key")).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_lmstudio() {
|
||||||
|
assert!(create_provider("lmstudio", Some("key")).is_ok());
|
||||||
|
assert!(create_provider("lm-studio", Some("key")).is_ok());
|
||||||
|
assert!(create_provider("lmstudio", None).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
// ── Extended ecosystem ───────────────────────────────────
|
// ── Extended ecosystem ───────────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -614,6 +667,13 @@ mod tests {
|
||||||
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
|
assert!(create_provider("build.nvidia.com", Some("nvapi-test")).is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── AI inference routers ─────────────────────────────────
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn factory_astrai() {
|
||||||
|
assert!(create_provider("astrai", Some("sk-astrai-test")).is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
// ── Custom / BYOP provider ─────────────────────────────
|
// ── Custom / BYOP provider ─────────────────────────────
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -761,17 +821,33 @@ mod tests {
|
||||||
scheduler_retries: 2,
|
scheduler_retries: 2,
|
||||||
};
|
};
|
||||||
|
|
||||||
let provider = create_resilient_provider("openrouter", Some("sk-test"), &reliability);
|
let provider = create_resilient_provider(
|
||||||
|
"openrouter",
|
||||||
|
Some("provider-test-credential"),
|
||||||
|
None,
|
||||||
|
&reliability,
|
||||||
|
);
|
||||||
assert!(provider.is_ok());
|
assert!(provider.is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn resilient_provider_errors_for_invalid_primary() {
|
fn resilient_provider_errors_for_invalid_primary() {
|
||||||
let reliability = crate::config::ReliabilityConfig::default();
|
let reliability = crate::config::ReliabilityConfig::default();
|
||||||
let provider = create_resilient_provider("totally-invalid", Some("sk-test"), &reliability);
|
let provider = create_resilient_provider(
|
||||||
|
"totally-invalid",
|
||||||
|
Some("provider-test-credential"),
|
||||||
|
None,
|
||||||
|
&reliability,
|
||||||
|
);
|
||||||
assert!(provider.is_err());
|
assert!(provider.is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ollama_with_custom_url() {
|
||||||
|
let provider = create_provider_with_url("ollama", None, Some("http://10.100.2.32:11434"));
|
||||||
|
assert!(provider.is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn factory_all_providers_create_successfully() {
|
fn factory_all_providers_create_successfully() {
|
||||||
let providers = [
|
let providers = [
|
||||||
|
|
@ -794,6 +870,7 @@ mod tests {
|
||||||
"qwen",
|
"qwen",
|
||||||
"qwen-intl",
|
"qwen-intl",
|
||||||
"qwen-us",
|
"qwen-us",
|
||||||
|
"lmstudio",
|
||||||
"groq",
|
"groq",
|
||||||
"mistral",
|
"mistral",
|
||||||
"xai",
|
"xai",
|
||||||
|
|
@ -888,7 +965,7 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sanitize_preserves_unicode_boundaries() {
|
fn sanitize_preserves_unicode_boundaries() {
|
||||||
let input = format!("{} sk-abcdef123", "こんにちは".repeat(80));
|
let input = format!("{} sk-abcdef123", "hello🙂".repeat(80));
|
||||||
let result = sanitize_api_error(&input);
|
let result = sanitize_api_error(&input);
|
||||||
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
|
assert!(std::str::from_utf8(result.as_bytes()).is_ok());
|
||||||
assert!(!result.contains("sk-abcdef123"));
|
assert!(!result.contains("sk-abcdef123"));
|
||||||
|
|
@ -900,4 +977,32 @@ mod tests {
|
||||||
let result = sanitize_api_error(input);
|
let result = sanitize_api_error(input);
|
||||||
assert_eq!(result, input);
|
assert_eq!(result, input);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scrub_github_personal_access_token() {
|
||||||
|
let input = "auth failed with token ghp_abc123def456";
|
||||||
|
let result = scrub_secret_patterns(input);
|
||||||
|
assert_eq!(result, "auth failed with token [REDACTED]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scrub_github_oauth_token() {
|
||||||
|
let input = "Bearer gho_1234567890abcdef";
|
||||||
|
let result = scrub_secret_patterns(input);
|
||||||
|
assert_eq!(result, "Bearer [REDACTED]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scrub_github_user_token() {
|
||||||
|
let input = "token ghu_sessiontoken123";
|
||||||
|
let result = scrub_secret_patterns(input);
|
||||||
|
assert_eq!(result, "token [REDACTED]");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn scrub_github_fine_grained_pat() {
|
||||||
|
let input = "failed: github_pat_11AABBC_xyzzy789";
|
||||||
|
let result = scrub_secret_patterns(input);
|
||||||
|
assert_eq!(result, "failed: [REDACTED]");
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,6 +8,8 @@ pub struct OllamaProvider {
|
||||||
client: Client,
|
client: Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Request Structures ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(Debug, Serialize)]
|
#[derive(Debug, Serialize)]
|
||||||
struct ChatRequest {
|
struct ChatRequest {
|
||||||
model: String,
|
model: String,
|
||||||
|
|
@ -27,6 +29,8 @@ struct Options {
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Response Structures ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ApiChatResponse {
|
struct ApiChatResponse {
|
||||||
message: ResponseMessage,
|
message: ResponseMessage,
|
||||||
|
|
@ -34,9 +38,30 @@ struct ApiChatResponse {
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct ResponseMessage {
|
struct ResponseMessage {
|
||||||
|
#[serde(default)]
|
||||||
content: String,
|
content: String,
|
||||||
|
#[serde(default)]
|
||||||
|
tool_calls: Vec<OllamaToolCall>,
|
||||||
|
/// Some models return a "thinking" field with internal reasoning
|
||||||
|
#[serde(default)]
|
||||||
|
thinking: Option<String>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OllamaToolCall {
|
||||||
|
id: Option<String>,
|
||||||
|
function: OllamaFunction,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
struct OllamaFunction {
|
||||||
|
name: String,
|
||||||
|
#[serde(default)]
|
||||||
|
arguments: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
// ─── Implementation ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
impl OllamaProvider {
|
impl OllamaProvider {
|
||||||
pub fn new(base_url: Option<&str>) -> Self {
|
pub fn new(base_url: Option<&str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
|
|
@ -45,12 +70,145 @@ impl OllamaProvider {
|
||||||
.trim_end_matches('/')
|
.trim_end_matches('/')
|
||||||
.to_string(),
|
.to_string(),
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(300)) // Ollama runs locally, may be slow
|
.timeout(std::time::Duration::from_secs(300))
|
||||||
.connect_timeout(std::time::Duration::from_secs(10))
|
.connect_timeout(std::time::Duration::from_secs(10))
|
||||||
.build()
|
.build()
|
||||||
.unwrap_or_else(|_| Client::new()),
|
.unwrap_or_else(|_| Client::new()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Send a request to Ollama and get the parsed response
|
||||||
|
async fn send_request(
|
||||||
|
&self,
|
||||||
|
messages: Vec<Message>,
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<ApiChatResponse> {
|
||||||
|
let request = ChatRequest {
|
||||||
|
model: model.to_string(),
|
||||||
|
messages,
|
||||||
|
stream: false,
|
||||||
|
options: Options { temperature },
|
||||||
|
};
|
||||||
|
|
||||||
|
let url = format!("{}/api/chat", self.base_url);
|
||||||
|
|
||||||
|
tracing::debug!(
|
||||||
|
"Ollama request: url={} model={} message_count={} temperature={}",
|
||||||
|
url,
|
||||||
|
model,
|
||||||
|
request.messages.len(),
|
||||||
|
temperature
|
||||||
|
);
|
||||||
|
|
||||||
|
let response = self.client.post(&url).json(&request).send().await?;
|
||||||
|
let status = response.status();
|
||||||
|
tracing::debug!("Ollama response status: {}", status);
|
||||||
|
|
||||||
|
let body = response.bytes().await?;
|
||||||
|
tracing::debug!("Ollama response body length: {} bytes", body.len());
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
let raw = String::from_utf8_lossy(&body);
|
||||||
|
let sanitized = super::sanitize_api_error(&raw);
|
||||||
|
tracing::error!(
|
||||||
|
"Ollama error response: status={} body_excerpt={}",
|
||||||
|
status,
|
||||||
|
sanitized
|
||||||
|
);
|
||||||
|
anyhow::bail!(
|
||||||
|
"Ollama API error ({}): {}. Is Ollama running? (brew install ollama && ollama serve)",
|
||||||
|
status,
|
||||||
|
sanitized
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
let chat_response: ApiChatResponse = match serde_json::from_slice(&body) {
|
||||||
|
Ok(r) => r,
|
||||||
|
Err(e) => {
|
||||||
|
let raw = String::from_utf8_lossy(&body);
|
||||||
|
let sanitized = super::sanitize_api_error(&raw);
|
||||||
|
tracing::error!(
|
||||||
|
"Ollama response deserialization failed: {e}. body_excerpt={}",
|
||||||
|
sanitized
|
||||||
|
);
|
||||||
|
anyhow::bail!("Failed to parse Ollama response: {e}");
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(chat_response)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Convert Ollama tool calls to the JSON format expected by parse_tool_calls in loop_.rs
|
||||||
|
///
|
||||||
|
/// Handles quirky model behavior where tool calls are wrapped:
|
||||||
|
/// - `{"name": "tool_call", "arguments": {"name": "shell", "arguments": {...}}}`
|
||||||
|
/// - `{"name": "tool.shell", "arguments": {...}}`
|
||||||
|
fn format_tool_calls_for_loop(&self, tool_calls: &[OllamaToolCall]) -> String {
|
||||||
|
let formatted_calls: Vec<serde_json::Value> = tool_calls
|
||||||
|
.iter()
|
||||||
|
.map(|tc| {
|
||||||
|
let (tool_name, tool_args) = self.extract_tool_name_and_args(tc);
|
||||||
|
|
||||||
|
// Arguments must be a JSON string for parse_tool_calls compatibility
|
||||||
|
let args_str =
|
||||||
|
serde_json::to_string(&tool_args).unwrap_or_else(|_| "{}".to_string());
|
||||||
|
|
||||||
|
serde_json::json!({
|
||||||
|
"id": tc.id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tool_name,
|
||||||
|
"arguments": args_str
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
serde_json::json!({
|
||||||
|
"content": "",
|
||||||
|
"tool_calls": formatted_calls
|
||||||
|
})
|
||||||
|
.to_string()
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Extract the actual tool name and arguments from potentially nested structures
|
||||||
|
fn extract_tool_name_and_args(&self, tc: &OllamaToolCall) -> (String, serde_json::Value) {
|
||||||
|
let name = &tc.function.name;
|
||||||
|
let args = &tc.function.arguments;
|
||||||
|
|
||||||
|
// Pattern 1: Nested tool_call wrapper (various malformed versions)
|
||||||
|
// {"name": "tool_call", "arguments": {"name": "shell", "arguments": {"command": "date"}}}
|
||||||
|
// {"name": "tool_call><json", "arguments": {"name": "shell", ...}}
|
||||||
|
// {"name": "tool.call", "arguments": {"name": "shell", ...}}
|
||||||
|
if name == "tool_call"
|
||||||
|
|| name == "tool.call"
|
||||||
|
|| name.starts_with("tool_call>")
|
||||||
|
|| name.starts_with("tool_call<")
|
||||||
|
{
|
||||||
|
if let Some(nested_name) = args.get("name").and_then(|v| v.as_str()) {
|
||||||
|
let nested_args = args
|
||||||
|
.get("arguments")
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or(serde_json::json!({}));
|
||||||
|
tracing::debug!(
|
||||||
|
"Unwrapped nested tool call: {} -> {} with args {:?}",
|
||||||
|
name,
|
||||||
|
nested_name,
|
||||||
|
nested_args
|
||||||
|
);
|
||||||
|
return (nested_name.to_string(), nested_args);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern 2: Prefixed tool name (tool.shell, tool.file_read, etc.)
|
||||||
|
if let Some(stripped) = name.strip_prefix("tool.") {
|
||||||
|
return (stripped.to_string(), args.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pattern 3: Normal tool call
|
||||||
|
(name.clone(), args.clone())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
|
|
@ -76,27 +234,96 @@ impl Provider for OllamaProvider {
|
||||||
content: message.to_string(),
|
content: message.to_string(),
|
||||||
});
|
});
|
||||||
|
|
||||||
let request = ChatRequest {
|
let response = self.send_request(messages, model, temperature).await?;
|
||||||
model: model.to_string(),
|
|
||||||
messages,
|
|
||||||
stream: false,
|
|
||||||
options: Options { temperature },
|
|
||||||
};
|
|
||||||
|
|
||||||
let url = format!("{}/api/chat", self.base_url);
|
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||||
|
if !response.message.tool_calls.is_empty() {
|
||||||
let response = self.client.post(&url).json(&request).send().await?;
|
tracing::debug!(
|
||||||
|
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||||
if !response.status().is_success() {
|
response.message.tool_calls.len()
|
||||||
let err = super::api_error("Ollama", response).await;
|
);
|
||||||
anyhow::bail!("{err}. Is Ollama running? (brew install ollama && ollama serve)");
|
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||||
}
|
}
|
||||||
|
|
||||||
let chat_response: ApiChatResponse = response.json().await?;
|
// Plain text response
|
||||||
Ok(chat_response.message.content)
|
let content = response.message.content;
|
||||||
|
|
||||||
|
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||||
|
if content.is_empty() {
|
||||||
|
if let Some(thinking) = &response.message.thinking {
|
||||||
|
tracing::warn!(
|
||||||
|
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||||
|
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||||
|
);
|
||||||
|
return Ok(format!(
|
||||||
|
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||||
|
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_history(
|
||||||
|
&self,
|
||||||
|
messages: &[crate::providers::ChatMessage],
|
||||||
|
model: &str,
|
||||||
|
temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
let api_messages: Vec<Message> = messages
|
||||||
|
.iter()
|
||||||
|
.map(|m| Message {
|
||||||
|
role: m.role.clone(),
|
||||||
|
content: m.content.clone(),
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
let response = self.send_request(api_messages, model, temperature).await?;
|
||||||
|
|
||||||
|
// If model returned tool calls, format them for loop_.rs's parse_tool_calls
|
||||||
|
if !response.message.tool_calls.is_empty() {
|
||||||
|
tracing::debug!(
|
||||||
|
"Ollama returned {} tool call(s), formatting for loop parser",
|
||||||
|
response.message.tool_calls.len()
|
||||||
|
);
|
||||||
|
return Ok(self.format_tool_calls_for_loop(&response.message.tool_calls));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Plain text response
|
||||||
|
let content = response.message.content;
|
||||||
|
|
||||||
|
// Handle edge case: model returned only "thinking" with no content or tool calls
|
||||||
|
// This is a model quirk - it stopped after reasoning without producing output
|
||||||
|
if content.is_empty() {
|
||||||
|
if let Some(thinking) = &response.message.thinking {
|
||||||
|
tracing::warn!(
|
||||||
|
"Ollama returned empty content with only thinking: '{}'. Model may have stopped prematurely.",
|
||||||
|
if thinking.len() > 100 { &thinking[..100] } else { thinking }
|
||||||
|
);
|
||||||
|
// Return a message indicating the model's thought process but no action
|
||||||
|
return Ok(format!(
|
||||||
|
"I was thinking about this: {}... but I didn't complete my response. Could you try asking again?",
|
||||||
|
if thinking.len() > 200 { &thinking[..200] } else { thinking }
|
||||||
|
));
|
||||||
|
}
|
||||||
|
tracing::warn!("Ollama returned empty content with no tool calls");
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(content)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn supports_native_tools(&self) -> bool {
|
||||||
|
// Return false since loop_.rs uses XML-style tool parsing via system prompt
|
||||||
|
// The model may return native tool_calls but we convert them to JSON format
|
||||||
|
// that parse_tool_calls() understands
|
||||||
|
false
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ─── Tests ────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
@ -125,46 +352,6 @@ mod tests {
|
||||||
assert_eq!(p.base_url, "");
|
assert_eq!(p.base_url, "");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn request_serializes_with_system() {
|
|
||||||
let req = ChatRequest {
|
|
||||||
model: "llama3".to_string(),
|
|
||||||
messages: vec![
|
|
||||||
Message {
|
|
||||||
role: "system".to_string(),
|
|
||||||
content: "You are ZeroClaw".to_string(),
|
|
||||||
},
|
|
||||||
Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "hello".to_string(),
|
|
||||||
},
|
|
||||||
],
|
|
||||||
stream: false,
|
|
||||||
options: Options { temperature: 0.7 },
|
|
||||||
};
|
|
||||||
let json = serde_json::to_string(&req).unwrap();
|
|
||||||
assert!(json.contains("\"stream\":false"));
|
|
||||||
assert!(json.contains("llama3"));
|
|
||||||
assert!(json.contains("system"));
|
|
||||||
assert!(json.contains("\"temperature\":0.7"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
|
||||||
fn request_serializes_without_system() {
|
|
||||||
let req = ChatRequest {
|
|
||||||
model: "mistral".to_string(),
|
|
||||||
messages: vec![Message {
|
|
||||||
role: "user".to_string(),
|
|
||||||
content: "test".to_string(),
|
|
||||||
}],
|
|
||||||
stream: false,
|
|
||||||
options: Options { temperature: 0.0 },
|
|
||||||
};
|
|
||||||
let json = serde_json::to_string(&req).unwrap();
|
|
||||||
assert!(!json.contains("\"role\":\"system\""));
|
|
||||||
assert!(json.contains("mistral"));
|
|
||||||
}
|
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_deserializes() {
|
fn response_deserializes() {
|
||||||
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
let json = r#"{"message":{"role":"assistant","content":"Hello from Ollama!"}}"#;
|
||||||
|
|
@ -180,9 +367,98 @@ mod tests {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn response_with_multiline() {
|
fn response_with_missing_content_defaults_to_empty() {
|
||||||
let json = r#"{"message":{"role":"assistant","content":"line1\nline2\nline3"}}"#;
|
let json = r#"{"message":{"role":"assistant"}}"#;
|
||||||
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
assert!(resp.message.content.contains("line1"));
|
assert!(resp.message.content.is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_with_thinking_field_extracts_content() {
|
||||||
|
let json =
|
||||||
|
r#"{"message":{"role":"assistant","content":"hello","thinking":"internal reasoning"}}"#;
|
||||||
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
assert_eq!(resp.message.content, "hello");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn response_with_tool_calls_parses_correctly() {
|
||||||
|
let json = r#"{"message":{"role":"assistant","content":"","tool_calls":[{"id":"call_123","function":{"name":"shell","arguments":{"command":"date"}}}]}}"#;
|
||||||
|
let resp: ApiChatResponse = serde_json::from_str(json).unwrap();
|
||||||
|
assert!(resp.message.content.is_empty());
|
||||||
|
assert_eq!(resp.message.tool_calls.len(), 1);
|
||||||
|
assert_eq!(resp.message.tool_calls[0].function.name, "shell");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_tool_name_handles_nested_tool_call() {
|
||||||
|
let provider = OllamaProvider::new(None);
|
||||||
|
let tc = OllamaToolCall {
|
||||||
|
id: Some("call_123".into()),
|
||||||
|
function: OllamaFunction {
|
||||||
|
name: "tool_call".into(),
|
||||||
|
arguments: serde_json::json!({
|
||||||
|
"name": "shell",
|
||||||
|
"arguments": {"command": "date"}
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||||
|
assert_eq!(name, "shell");
|
||||||
|
assert_eq!(args.get("command").unwrap(), "date");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_tool_name_handles_prefixed_name() {
|
||||||
|
let provider = OllamaProvider::new(None);
|
||||||
|
let tc = OllamaToolCall {
|
||||||
|
id: Some("call_123".into()),
|
||||||
|
function: OllamaFunction {
|
||||||
|
name: "tool.shell".into(),
|
||||||
|
arguments: serde_json::json!({"command": "ls"}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||||
|
assert_eq!(name, "shell");
|
||||||
|
assert_eq!(args.get("command").unwrap(), "ls");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn extract_tool_name_handles_normal_call() {
|
||||||
|
let provider = OllamaProvider::new(None);
|
||||||
|
let tc = OllamaToolCall {
|
||||||
|
id: Some("call_123".into()),
|
||||||
|
function: OllamaFunction {
|
||||||
|
name: "file_read".into(),
|
||||||
|
arguments: serde_json::json!({"path": "/tmp/test"}),
|
||||||
|
},
|
||||||
|
};
|
||||||
|
let (name, args) = provider.extract_tool_name_and_args(&tc);
|
||||||
|
assert_eq!(name, "file_read");
|
||||||
|
assert_eq!(args.get("path").unwrap(), "/tmp/test");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn format_tool_calls_produces_valid_json() {
|
||||||
|
let provider = OllamaProvider::new(None);
|
||||||
|
let tool_calls = vec![OllamaToolCall {
|
||||||
|
id: Some("call_abc".into()),
|
||||||
|
function: OllamaFunction {
|
||||||
|
name: "shell".into(),
|
||||||
|
arguments: serde_json::json!({"command": "date"}),
|
||||||
|
},
|
||||||
|
}];
|
||||||
|
|
||||||
|
let formatted = provider.format_tool_calls_for_loop(&tool_calls);
|
||||||
|
let parsed: serde_json::Value = serde_json::from_str(&formatted).unwrap();
|
||||||
|
|
||||||
|
assert!(parsed.get("tool_calls").is_some());
|
||||||
|
let calls = parsed.get("tool_calls").unwrap().as_array().unwrap();
|
||||||
|
assert_eq!(calls.len(), 1);
|
||||||
|
|
||||||
|
let func = calls[0].get("function").unwrap();
|
||||||
|
assert_eq!(func.get("name").unwrap(), "shell");
|
||||||
|
// arguments should be a string (JSON-encoded)
|
||||||
|
assert!(func.get("arguments").unwrap().is_string());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub struct OpenAiProvider {
|
pub struct OpenAiProvider {
|
||||||
api_key: Option<String>,
|
credential: Option<String>,
|
||||||
client: Client,
|
client: Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenAiProvider {
|
impl OpenAiProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(credential: Option<&str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(120))
|
.timeout(std::time::Duration::from_secs(120))
|
||||||
.connect_timeout(std::time::Duration::from_secs(10))
|
.connect_timeout(std::time::Duration::from_secs(10))
|
||||||
|
|
@ -232,7 +232,7 @@ impl Provider for OpenAiProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
@ -259,7 +259,7 @@ impl Provider for OpenAiProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://api.openai.com/v1/chat/completions")
|
.post("https://api.openai.com/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.json(&request)
|
.json(&request)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -284,7 +284,7 @@ impl Provider for OpenAiProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
anyhow::anyhow!("OpenAI API key not set. Set OPENAI_API_KEY or edit config.toml.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
|
|
@ -300,7 +300,7 @@ impl Provider for OpenAiProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://api.openai.com/v1/chat/completions")
|
.post("https://api.openai.com/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.json(&native_request)
|
.json(&native_request)
|
||||||
.send()
|
.send()
|
||||||
.await?;
|
.await?;
|
||||||
|
|
@ -330,20 +330,20 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let p = OpenAiProvider::new(Some("sk-proj-abc123"));
|
let p = OpenAiProvider::new(Some("openai-test-credential"));
|
||||||
assert_eq!(p.api_key.as_deref(), Some("sk-proj-abc123"));
|
assert_eq!(p.credential.as_deref(), Some("openai-test-credential"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_without_key() {
|
fn creates_without_key() {
|
||||||
let p = OpenAiProvider::new(None);
|
let p = OpenAiProvider::new(None);
|
||||||
assert!(p.api_key.is_none());
|
assert!(p.credential.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_empty_key() {
|
fn creates_with_empty_key() {
|
||||||
let p = OpenAiProvider::new(Some(""));
|
let p = OpenAiProvider::new(Some(""));
|
||||||
assert_eq!(p.api_key.as_deref(), Some(""));
|
assert_eq!(p.credential.as_deref(), Some(""));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
||||||
|
|
@ -8,7 +8,7 @@ use reqwest::Client;
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
pub struct OpenRouterProvider {
|
pub struct OpenRouterProvider {
|
||||||
api_key: Option<String>,
|
credential: Option<String>,
|
||||||
client: Client,
|
client: Client,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -110,9 +110,9 @@ struct NativeResponseMessage {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl OpenRouterProvider {
|
impl OpenRouterProvider {
|
||||||
pub fn new(api_key: Option<&str>) -> Self {
|
pub fn new(credential: Option<&str>) -> Self {
|
||||||
Self {
|
Self {
|
||||||
api_key: api_key.map(ToString::to_string),
|
credential: credential.map(ToString::to_string),
|
||||||
client: Client::builder()
|
client: Client::builder()
|
||||||
.timeout(std::time::Duration::from_secs(120))
|
.timeout(std::time::Duration::from_secs(120))
|
||||||
.connect_timeout(std::time::Duration::from_secs(10))
|
.connect_timeout(std::time::Duration::from_secs(10))
|
||||||
|
|
@ -232,10 +232,10 @@ impl Provider for OpenRouterProvider {
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
// Hit a lightweight endpoint to establish TLS + HTTP/2 connection pool.
|
||||||
// This prevents the first real chat request from timing out on cold start.
|
// This prevents the first real chat request from timing out on cold start.
|
||||||
if let Some(api_key) = self.api_key.as_ref() {
|
if let Some(credential) = self.credential.as_ref() {
|
||||||
self.client
|
self.client
|
||||||
.get("https://openrouter.ai/api/v1/auth/key")
|
.get("https://openrouter.ai/api/v1/auth/key")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.send()
|
.send()
|
||||||
.await?
|
.await?
|
||||||
.error_for_status()?;
|
.error_for_status()?;
|
||||||
|
|
@ -250,7 +250,7 @@ impl Provider for OpenRouterProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref()
|
let credential = self.credential.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||||
|
|
||||||
let mut messages = Vec::new();
|
let mut messages = Vec::new();
|
||||||
|
|
@ -276,7 +276,7 @@ impl Provider for OpenRouterProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.header(
|
.header(
|
||||||
"HTTP-Referer",
|
"HTTP-Referer",
|
||||||
"https://github.com/theonlyhennygod/zeroclaw",
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
|
@ -306,7 +306,7 @@ impl Provider for OpenRouterProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<String> {
|
) -> anyhow::Result<String> {
|
||||||
let api_key = self.api_key.as_ref()
|
let credential = self.credential.as_ref()
|
||||||
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
.ok_or_else(|| anyhow::anyhow!("OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."))?;
|
||||||
|
|
||||||
let api_messages: Vec<Message> = messages
|
let api_messages: Vec<Message> = messages
|
||||||
|
|
@ -326,7 +326,7 @@ impl Provider for OpenRouterProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.header(
|
.header(
|
||||||
"HTTP-Referer",
|
"HTTP-Referer",
|
||||||
"https://github.com/theonlyhennygod/zeroclaw",
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
|
@ -356,7 +356,7 @@ impl Provider for OpenRouterProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||||
)
|
)
|
||||||
|
|
@ -374,7 +374,7 @@ impl Provider for OpenRouterProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.header(
|
.header(
|
||||||
"HTTP-Referer",
|
"HTTP-Referer",
|
||||||
"https://github.com/theonlyhennygod/zeroclaw",
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
|
@ -409,7 +409,7 @@ impl Provider for OpenRouterProvider {
|
||||||
model: &str,
|
model: &str,
|
||||||
temperature: f64,
|
temperature: f64,
|
||||||
) -> anyhow::Result<ProviderChatResponse> {
|
) -> anyhow::Result<ProviderChatResponse> {
|
||||||
let api_key = self.api_key.as_ref().ok_or_else(|| {
|
let credential = self.credential.as_ref().ok_or_else(|| {
|
||||||
anyhow::anyhow!(
|
anyhow::anyhow!(
|
||||||
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
"OpenRouter API key not set. Run `zeroclaw onboard` or set OPENROUTER_API_KEY env var."
|
||||||
)
|
)
|
||||||
|
|
@ -462,7 +462,7 @@ impl Provider for OpenRouterProvider {
|
||||||
let response = self
|
let response = self
|
||||||
.client
|
.client
|
||||||
.post("https://openrouter.ai/api/v1/chat/completions")
|
.post("https://openrouter.ai/api/v1/chat/completions")
|
||||||
.header("Authorization", format!("Bearer {api_key}"))
|
.header("Authorization", format!("Bearer {credential}"))
|
||||||
.header(
|
.header(
|
||||||
"HTTP-Referer",
|
"HTTP-Referer",
|
||||||
"https://github.com/theonlyhennygod/zeroclaw",
|
"https://github.com/theonlyhennygod/zeroclaw",
|
||||||
|
|
@ -494,14 +494,17 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_with_key() {
|
fn creates_with_key() {
|
||||||
let provider = OpenRouterProvider::new(Some("sk-or-123"));
|
let provider = OpenRouterProvider::new(Some("openrouter-test-credential"));
|
||||||
assert_eq!(provider.api_key.as_deref(), Some("sk-or-123"));
|
assert_eq!(
|
||||||
|
provider.credential.as_deref(),
|
||||||
|
Some("openrouter-test-credential")
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn creates_without_key() {
|
fn creates_without_key() {
|
||||||
let provider = OpenRouterProvider::new(None);
|
let provider = OpenRouterProvider::new(None);
|
||||||
assert!(provider.api_key.is_none());
|
assert!(provider.credential.is_none());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
|
|
||||||
|
|
@ -144,8 +144,8 @@ impl Provider for ReliableProvider {
|
||||||
async fn warmup(&self) -> anyhow::Result<()> {
|
async fn warmup(&self) -> anyhow::Result<()> {
|
||||||
for (name, provider) in &self.providers {
|
for (name, provider) in &self.providers {
|
||||||
tracing::info!(provider = name, "Warming up provider connection pool");
|
tracing::info!(provider = name, "Warming up provider connection pool");
|
||||||
if let Err(e) = provider.warmup().await {
|
if provider.warmup().await.is_err() {
|
||||||
tracing::warn!(provider = name, "Warmup failed (non-fatal): {e}");
|
tracing::warn!(provider = name, "Warmup failed (non-fatal)");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
|
|
@ -186,8 +186,15 @@ impl Provider for ReliableProvider {
|
||||||
let non_retryable = is_non_retryable(&e);
|
let non_retryable = is_non_retryable(&e);
|
||||||
let rate_limited = is_rate_limited(&e);
|
let rate_limited = is_rate_limited(&e);
|
||||||
|
|
||||||
|
let failure_reason = if rate_limited {
|
||||||
|
"rate_limited"
|
||||||
|
} else if non_retryable {
|
||||||
|
"non_retryable"
|
||||||
|
} else {
|
||||||
|
"retryable"
|
||||||
|
};
|
||||||
failures.push(format!(
|
failures.push(format!(
|
||||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||||
attempt + 1,
|
attempt + 1,
|
||||||
self.max_retries + 1
|
self.max_retries + 1
|
||||||
));
|
));
|
||||||
|
|
@ -284,8 +291,15 @@ impl Provider for ReliableProvider {
|
||||||
let non_retryable = is_non_retryable(&e);
|
let non_retryable = is_non_retryable(&e);
|
||||||
let rate_limited = is_rate_limited(&e);
|
let rate_limited = is_rate_limited(&e);
|
||||||
|
|
||||||
|
let failure_reason = if rate_limited {
|
||||||
|
"rate_limited"
|
||||||
|
} else if non_retryable {
|
||||||
|
"non_retryable"
|
||||||
|
} else {
|
||||||
|
"retryable"
|
||||||
|
};
|
||||||
failures.push(format!(
|
failures.push(format!(
|
||||||
"{provider_name}/{current_model} attempt {}/{}: {e}",
|
"{provider_name}/{current_model} attempt {}/{}: {failure_reason}",
|
||||||
attempt + 1,
|
attempt + 1,
|
||||||
self.max_retries + 1
|
self.max_retries + 1
|
||||||
));
|
));
|
||||||
|
|
|
||||||
|
|
@ -193,6 +193,13 @@ pub enum StreamError {
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
pub trait Provider: Send + Sync {
|
pub trait Provider: Send + Sync {
|
||||||
|
/// Query provider capabilities.
|
||||||
|
///
|
||||||
|
/// Default implementation returns minimal capabilities (no native tool calling).
|
||||||
|
/// Providers should override this to declare their actual capabilities.
|
||||||
|
fn capabilities(&self) -> ProviderCapabilities {
|
||||||
|
ProviderCapabilities::default()
|
||||||
|
}
|
||||||
/// Simple one-shot chat (single user message, no explicit system prompt).
|
/// Simple one-shot chat (single user message, no explicit system prompt).
|
||||||
///
|
///
|
||||||
/// This is the preferred API for non-agentic direct interactions.
|
/// This is the preferred API for non-agentic direct interactions.
|
||||||
|
|
@ -256,7 +263,7 @@ pub trait Provider: Send + Sync {
|
||||||
|
|
||||||
/// Whether provider supports native tool calls over API.
|
/// Whether provider supports native tool calls over API.
|
||||||
fn supports_native_tools(&self) -> bool {
|
fn supports_native_tools(&self) -> bool {
|
||||||
false
|
self.capabilities().native_tool_calling
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
/// Warm up the HTTP connection pool (TLS handshake, DNS, HTTP/2 setup).
|
||||||
|
|
@ -336,6 +343,27 @@ pub trait Provider: Send + Sync {
|
||||||
mod tests {
|
mod tests {
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
struct CapabilityMockProvider;
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Provider for CapabilityMockProvider {
|
||||||
|
fn capabilities(&self) -> ProviderCapabilities {
|
||||||
|
ProviderCapabilities {
|
||||||
|
native_tool_calling: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn chat_with_system(
|
||||||
|
&self,
|
||||||
|
_system_prompt: Option<&str>,
|
||||||
|
_message: &str,
|
||||||
|
_model: &str,
|
||||||
|
_temperature: f64,
|
||||||
|
) -> anyhow::Result<String> {
|
||||||
|
Ok("ok".into())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn chat_message_constructors() {
|
fn chat_message_constructors() {
|
||||||
let sys = ChatMessage::system("Be helpful");
|
let sys = ChatMessage::system("Be helpful");
|
||||||
|
|
@ -398,4 +426,32 @@ mod tests {
|
||||||
let json = serde_json::to_string(&tool_result).unwrap();
|
let json = serde_json::to_string(&tool_result).unwrap();
|
||||||
assert!(json.contains("\"type\":\"ToolResults\""));
|
assert!(json.contains("\"type\":\"ToolResults\""));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_capabilities_default() {
|
||||||
|
let caps = ProviderCapabilities::default();
|
||||||
|
assert!(!caps.native_tool_calling);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn provider_capabilities_equality() {
|
||||||
|
let caps1 = ProviderCapabilities {
|
||||||
|
native_tool_calling: true,
|
||||||
|
};
|
||||||
|
let caps2 = ProviderCapabilities {
|
||||||
|
native_tool_calling: true,
|
||||||
|
};
|
||||||
|
let caps3 = ProviderCapabilities {
|
||||||
|
native_tool_calling: false,
|
||||||
|
};
|
||||||
|
|
||||||
|
assert_eq!(caps1, caps2);
|
||||||
|
assert_ne!(caps1, caps3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn supports_native_tools_reflects_capabilities_default_mapping() {
|
||||||
|
let provider = CapabilityMockProvider;
|
||||||
|
assert!(provider.supports_native_tools());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -81,14 +81,17 @@ mod tests {
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bubblewrap_sandbox_name() {
|
fn bubblewrap_sandbox_name() {
|
||||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
let sandbox = BubblewrapSandbox;
|
||||||
|
assert_eq!(sandbox.name(), "bubblewrap");
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn bubblewrap_is_available_only_if_installed() {
|
fn bubblewrap_is_available_only_if_installed() {
|
||||||
// Result depends on whether bwrap is installed
|
// Result depends on whether bwrap is installed
|
||||||
let available = BubblewrapSandbox::is_available();
|
let sandbox = BubblewrapSandbox;
|
||||||
|
let _available = sandbox.is_available();
|
||||||
|
|
||||||
// Either way, the name should still work
|
// Either way, the name should still work
|
||||||
assert_eq!(BubblewrapSandbox.name(), "bubblewrap");
|
assert_eq!(sandbox.name(), "bubblewrap");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -184,7 +184,7 @@ fn generate_token() -> String {
|
||||||
use rand::RngCore;
|
use rand::RngCore;
|
||||||
let mut bytes = [0u8; 32];
|
let mut bytes = [0u8; 32];
|
||||||
rand::thread_rng().fill_bytes(&mut bytes);
|
rand::thread_rng().fill_bytes(&mut bytes);
|
||||||
format!("zc_{}", hex::encode(&bytes))
|
format!("zc_{}", hex::encode(bytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// SHA-256 hash a bearer token for storage. Returns lowercase hex.
|
/// SHA-256 hash a bearer token for storage. Returns lowercase hex.
|
||||||
|
|
|
||||||
|
|
@ -343,6 +343,7 @@ impl SecurityPolicy {
|
||||||
/// validates each sub-command against the allowlist
|
/// validates each sub-command against the allowlist
|
||||||
/// - Blocks single `&` background chaining (`&&` remains supported)
|
/// - Blocks single `&` background chaining (`&&` remains supported)
|
||||||
/// - Blocks output redirections (`>`, `>>`) that could write outside workspace
|
/// - Blocks output redirections (`>`, `>>`) that could write outside workspace
|
||||||
|
/// - Blocks dangerous arguments (e.g. `find -exec`, `git config`)
|
||||||
pub fn is_command_allowed(&self, command: &str) -> bool {
|
pub fn is_command_allowed(&self, command: &str) -> bool {
|
||||||
if self.autonomy == AutonomyLevel::ReadOnly {
|
if self.autonomy == AutonomyLevel::ReadOnly {
|
||||||
return false;
|
return false;
|
||||||
|
|
@ -350,7 +351,12 @@ impl SecurityPolicy {
|
||||||
|
|
||||||
// Block subshell/expansion operators — these allow hiding arbitrary
|
// Block subshell/expansion operators — these allow hiding arbitrary
|
||||||
// commands inside an allowed command (e.g. `echo $(rm -rf /)`)
|
// commands inside an allowed command (e.g. `echo $(rm -rf /)`)
|
||||||
if command.contains('`') || command.contains("$(") || command.contains("${") {
|
if command.contains('`')
|
||||||
|
|| command.contains("$(")
|
||||||
|
|| command.contains("${")
|
||||||
|
|| command.contains("<(")
|
||||||
|
|| command.contains(">(")
|
||||||
|
{
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -359,6 +365,15 @@ impl SecurityPolicy {
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Block `tee` — it can write to arbitrary files, bypassing the
|
||||||
|
// redirect check above (e.g. `echo secret | tee /etc/crontab`)
|
||||||
|
if command
|
||||||
|
.split_whitespace()
|
||||||
|
.any(|w| w == "tee" || w.ends_with("/tee"))
|
||||||
|
{
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
// Block background command chaining (`&`), which can hide extra
|
// Block background command chaining (`&`), which can hide extra
|
||||||
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
|
// sub-commands and outlive timeout expectations. Keep `&&` allowed.
|
||||||
if contains_single_ampersand(command) {
|
if contains_single_ampersand(command) {
|
||||||
|
|
@ -384,13 +399,9 @@ impl SecurityPolicy {
|
||||||
// Strip leading env var assignments (e.g. FOO=bar cmd)
|
// Strip leading env var assignments (e.g. FOO=bar cmd)
|
||||||
let cmd_part = skip_env_assignments(segment);
|
let cmd_part = skip_env_assignments(segment);
|
||||||
|
|
||||||
let base_cmd = cmd_part
|
let mut words = cmd_part.split_whitespace();
|
||||||
.split_whitespace()
|
let base_raw = words.next().unwrap_or("");
|
||||||
.next()
|
let base_cmd = base_raw.rsplit('/').next().unwrap_or("");
|
||||||
.unwrap_or("")
|
|
||||||
.rsplit('/')
|
|
||||||
.next()
|
|
||||||
.unwrap_or("");
|
|
||||||
|
|
||||||
if base_cmd.is_empty() {
|
if base_cmd.is_empty() {
|
||||||
continue;
|
continue;
|
||||||
|
|
@ -403,6 +414,12 @@ impl SecurityPolicy {
|
||||||
{
|
{
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Validate arguments for the command
|
||||||
|
let args: Vec<String> = words.map(|w| w.to_ascii_lowercase()).collect();
|
||||||
|
if !self.is_args_safe(base_cmd, &args) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// At least one command must be present
|
// At least one command must be present
|
||||||
|
|
@ -414,6 +431,29 @@ impl SecurityPolicy {
|
||||||
has_cmd
|
has_cmd
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Check for dangerous arguments that allow sub-command execution.
|
||||||
|
fn is_args_safe(&self, base: &str, args: &[String]) -> bool {
|
||||||
|
let base = base.to_ascii_lowercase();
|
||||||
|
match base.as_str() {
|
||||||
|
"find" => {
|
||||||
|
// find -exec and find -ok allow arbitrary command execution
|
||||||
|
!args.iter().any(|arg| arg == "-exec" || arg == "-ok")
|
||||||
|
}
|
||||||
|
"git" => {
|
||||||
|
// git config, alias, and -c can be used to set dangerous options
|
||||||
|
// (e.g. git config core.editor "rm -rf /")
|
||||||
|
!args.iter().any(|arg| {
|
||||||
|
arg == "config"
|
||||||
|
|| arg.starts_with("config.")
|
||||||
|
|| arg == "alias"
|
||||||
|
|| arg.starts_with("alias.")
|
||||||
|
|| arg == "-c"
|
||||||
|
})
|
||||||
|
}
|
||||||
|
_ => true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/// Check if a file path is allowed (no path traversal, within workspace)
|
/// Check if a file path is allowed (no path traversal, within workspace)
|
||||||
pub fn is_path_allowed(&self, path: &str) -> bool {
|
pub fn is_path_allowed(&self, path: &str) -> bool {
|
||||||
// Block null bytes (can truncate paths in C-backed syscalls)
|
// Block null bytes (can truncate paths in C-backed syscalls)
|
||||||
|
|
@ -982,12 +1022,43 @@ mod tests {
|
||||||
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
|
assert!(!p.is_command_allowed("ls >> /tmp/exfil.txt"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn command_argument_injection_blocked() {
|
||||||
|
let p = default_policy();
|
||||||
|
// find -exec is a common bypass
|
||||||
|
assert!(!p.is_command_allowed("find . -exec rm -rf {} +"));
|
||||||
|
assert!(!p.is_command_allowed("find / -ok cat {} \\;"));
|
||||||
|
// git config/alias can execute commands
|
||||||
|
assert!(!p.is_command_allowed("git config core.editor \"rm -rf /\""));
|
||||||
|
assert!(!p.is_command_allowed("git alias.st status"));
|
||||||
|
assert!(!p.is_command_allowed("git -c core.editor=calc.exe commit"));
|
||||||
|
// Legitimate commands should still work
|
||||||
|
assert!(p.is_command_allowed("find . -name '*.txt'"));
|
||||||
|
assert!(p.is_command_allowed("git status"));
|
||||||
|
assert!(p.is_command_allowed("git add ."));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn command_injection_dollar_brace_blocked() {
|
fn command_injection_dollar_brace_blocked() {
|
||||||
let p = default_policy();
|
let p = default_policy();
|
||||||
assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd"));
|
assert!(!p.is_command_allowed("echo ${IFS}cat${IFS}/etc/passwd"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn command_injection_tee_blocked() {
|
||||||
|
let p = default_policy();
|
||||||
|
assert!(!p.is_command_allowed("echo secret | tee /etc/crontab"));
|
||||||
|
assert!(!p.is_command_allowed("ls | /usr/bin/tee outfile"));
|
||||||
|
assert!(!p.is_command_allowed("tee file.txt"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn command_injection_process_substitution_blocked() {
|
||||||
|
let p = default_policy();
|
||||||
|
assert!(!p.is_command_allowed("cat <(echo pwned)"));
|
||||||
|
assert!(!p.is_command_allowed("ls >(cat /etc/passwd)"));
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn command_env_var_prefix_with_allowed_cmd() {
|
fn command_env_var_prefix_with_allowed_cmd() {
|
||||||
let p = default_policy();
|
let p = default_policy();
|
||||||
|
|
|
||||||
|
|
@ -854,7 +854,6 @@ impl BrowserTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[allow(clippy::too_many_lines)]
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl Tool for BrowserTool {
|
impl Tool for BrowserTool {
|
||||||
fn name(&self) -> &str {
|
fn name(&self) -> &str {
|
||||||
|
|
@ -1031,165 +1030,21 @@ impl Tool for BrowserTool {
|
||||||
return self.execute_computer_use_action(action_str, &args).await;
|
return self.execute_computer_use_action(action_str, &args).await;
|
||||||
}
|
}
|
||||||
|
|
||||||
let action = match action_str {
|
if is_computer_use_only_action(action_str) {
|
||||||
"open" => {
|
return Ok(ToolResult {
|
||||||
let url = args
|
success: false,
|
||||||
.get("url")
|
output: String::new(),
|
||||||
.and_then(|v| v.as_str())
|
error: Some(unavailable_action_for_backend_error(action_str, backend)),
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
|
});
|
||||||
BrowserAction::Open { url: url.into() }
|
}
|
||||||
}
|
|
||||||
"snapshot" => BrowserAction::Snapshot {
|
let action = match parse_browser_action(action_str, &args) {
|
||||||
interactive_only: args
|
Ok(a) => a,
|
||||||
.get("interactive_only")
|
Err(e) => {
|
||||||
.and_then(serde_json::Value::as_bool)
|
|
||||||
.unwrap_or(true), // Default to interactive for AI
|
|
||||||
compact: args
|
|
||||||
.get("compact")
|
|
||||||
.and_then(serde_json::Value::as_bool)
|
|
||||||
.unwrap_or(true),
|
|
||||||
depth: args
|
|
||||||
.get("depth")
|
|
||||||
.and_then(serde_json::Value::as_u64)
|
|
||||||
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
|
|
||||||
},
|
|
||||||
"click" => {
|
|
||||||
let selector = args
|
|
||||||
.get("selector")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
|
|
||||||
BrowserAction::Click {
|
|
||||||
selector: selector.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"fill" => {
|
|
||||||
let selector = args
|
|
||||||
.get("selector")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
|
|
||||||
let value = args
|
|
||||||
.get("value")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
|
|
||||||
BrowserAction::Fill {
|
|
||||||
selector: selector.into(),
|
|
||||||
value: value.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"type" => {
|
|
||||||
let selector = args
|
|
||||||
.get("selector")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
|
|
||||||
let text = args
|
|
||||||
.get("text")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
|
|
||||||
BrowserAction::Type {
|
|
||||||
selector: selector.into(),
|
|
||||||
text: text.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"get_text" => {
|
|
||||||
let selector = args
|
|
||||||
.get("selector")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
|
|
||||||
BrowserAction::GetText {
|
|
||||||
selector: selector.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"get_title" => BrowserAction::GetTitle,
|
|
||||||
"get_url" => BrowserAction::GetUrl,
|
|
||||||
"screenshot" => BrowserAction::Screenshot {
|
|
||||||
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
|
|
||||||
full_page: args
|
|
||||||
.get("full_page")
|
|
||||||
.and_then(serde_json::Value::as_bool)
|
|
||||||
.unwrap_or(false),
|
|
||||||
},
|
|
||||||
"wait" => BrowserAction::Wait {
|
|
||||||
selector: args
|
|
||||||
.get("selector")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(String::from),
|
|
||||||
ms: args.get("ms").and_then(serde_json::Value::as_u64),
|
|
||||||
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
|
||||||
},
|
|
||||||
"press" => {
|
|
||||||
let key = args
|
|
||||||
.get("key")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
|
|
||||||
BrowserAction::Press { key: key.into() }
|
|
||||||
}
|
|
||||||
"hover" => {
|
|
||||||
let selector = args
|
|
||||||
.get("selector")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
|
|
||||||
BrowserAction::Hover {
|
|
||||||
selector: selector.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"scroll" => {
|
|
||||||
let direction = args
|
|
||||||
.get("direction")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
|
|
||||||
BrowserAction::Scroll {
|
|
||||||
direction: direction.into(),
|
|
||||||
pixels: args
|
|
||||||
.get("pixels")
|
|
||||||
.and_then(serde_json::Value::as_u64)
|
|
||||||
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"is_visible" => {
|
|
||||||
let selector = args
|
|
||||||
.get("selector")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
|
|
||||||
BrowserAction::IsVisible {
|
|
||||||
selector: selector.into(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
"close" => BrowserAction::Close,
|
|
||||||
"find" => {
|
|
||||||
let by = args
|
|
||||||
.get("by")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
|
|
||||||
let value = args
|
|
||||||
.get("value")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
|
|
||||||
let action = args
|
|
||||||
.get("find_action")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
|
|
||||||
BrowserAction::Find {
|
|
||||||
by: by.into(),
|
|
||||||
value: value.into(),
|
|
||||||
action: action.into(),
|
|
||||||
fill_value: args
|
|
||||||
.get("fill_value")
|
|
||||||
.and_then(|v| v.as_str())
|
|
||||||
.map(String::from),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
_ => {
|
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: false,
|
success: false,
|
||||||
output: String::new(),
|
output: String::new(),
|
||||||
error: Some(format!(
|
error: Some(e.to_string()),
|
||||||
"Action '{action_str}' is unavailable for backend '{}'",
|
|
||||||
match backend {
|
|
||||||
ResolvedBackend::AgentBrowser => "agent_browser",
|
|
||||||
ResolvedBackend::RustNative => "rust_native",
|
|
||||||
ResolvedBackend::ComputerUse => "computer_use",
|
|
||||||
}
|
|
||||||
)),
|
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
@ -1871,6 +1726,161 @@ mod native_backend {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── Action parsing ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
/// Parse a JSON `args` object into a typed `BrowserAction`.
|
||||||
|
fn parse_browser_action(action_str: &str, args: &Value) -> anyhow::Result<BrowserAction> {
|
||||||
|
match action_str {
|
||||||
|
"open" => {
|
||||||
|
let url = args
|
||||||
|
.get("url")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'url' for open action"))?;
|
||||||
|
Ok(BrowserAction::Open { url: url.into() })
|
||||||
|
}
|
||||||
|
"snapshot" => Ok(BrowserAction::Snapshot {
|
||||||
|
interactive_only: args
|
||||||
|
.get("interactive_only")
|
||||||
|
.and_then(serde_json::Value::as_bool)
|
||||||
|
.unwrap_or(true),
|
||||||
|
compact: args
|
||||||
|
.get("compact")
|
||||||
|
.and_then(serde_json::Value::as_bool)
|
||||||
|
.unwrap_or(true),
|
||||||
|
depth: args
|
||||||
|
.get("depth")
|
||||||
|
.and_then(serde_json::Value::as_u64)
|
||||||
|
.map(|d| u32::try_from(d).unwrap_or(u32::MAX)),
|
||||||
|
}),
|
||||||
|
"click" => {
|
||||||
|
let selector = args
|
||||||
|
.get("selector")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for click"))?;
|
||||||
|
Ok(BrowserAction::Click {
|
||||||
|
selector: selector.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"fill" => {
|
||||||
|
let selector = args
|
||||||
|
.get("selector")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for fill"))?;
|
||||||
|
let value = args
|
||||||
|
.get("value")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for fill"))?;
|
||||||
|
Ok(BrowserAction::Fill {
|
||||||
|
selector: selector.into(),
|
||||||
|
value: value.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"type" => {
|
||||||
|
let selector = args
|
||||||
|
.get("selector")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for type"))?;
|
||||||
|
let text = args
|
||||||
|
.get("text")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'text' for type"))?;
|
||||||
|
Ok(BrowserAction::Type {
|
||||||
|
selector: selector.into(),
|
||||||
|
text: text.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"get_text" => {
|
||||||
|
let selector = args
|
||||||
|
.get("selector")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for get_text"))?;
|
||||||
|
Ok(BrowserAction::GetText {
|
||||||
|
selector: selector.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"get_title" => Ok(BrowserAction::GetTitle),
|
||||||
|
"get_url" => Ok(BrowserAction::GetUrl),
|
||||||
|
"screenshot" => Ok(BrowserAction::Screenshot {
|
||||||
|
path: args.get("path").and_then(|v| v.as_str()).map(String::from),
|
||||||
|
full_page: args
|
||||||
|
.get("full_page")
|
||||||
|
.and_then(serde_json::Value::as_bool)
|
||||||
|
.unwrap_or(false),
|
||||||
|
}),
|
||||||
|
"wait" => Ok(BrowserAction::Wait {
|
||||||
|
selector: args
|
||||||
|
.get("selector")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(String::from),
|
||||||
|
ms: args.get("ms").and_then(serde_json::Value::as_u64),
|
||||||
|
text: args.get("text").and_then(|v| v.as_str()).map(String::from),
|
||||||
|
}),
|
||||||
|
"press" => {
|
||||||
|
let key = args
|
||||||
|
.get("key")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'key' for press"))?;
|
||||||
|
Ok(BrowserAction::Press { key: key.into() })
|
||||||
|
}
|
||||||
|
"hover" => {
|
||||||
|
let selector = args
|
||||||
|
.get("selector")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for hover"))?;
|
||||||
|
Ok(BrowserAction::Hover {
|
||||||
|
selector: selector.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"scroll" => {
|
||||||
|
let direction = args
|
||||||
|
.get("direction")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'direction' for scroll"))?;
|
||||||
|
Ok(BrowserAction::Scroll {
|
||||||
|
direction: direction.into(),
|
||||||
|
pixels: args
|
||||||
|
.get("pixels")
|
||||||
|
.and_then(serde_json::Value::as_u64)
|
||||||
|
.map(|p| u32::try_from(p).unwrap_or(u32::MAX)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"is_visible" => {
|
||||||
|
let selector = args
|
||||||
|
.get("selector")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'selector' for is_visible"))?;
|
||||||
|
Ok(BrowserAction::IsVisible {
|
||||||
|
selector: selector.into(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
"close" => Ok(BrowserAction::Close),
|
||||||
|
"find" => {
|
||||||
|
let by = args
|
||||||
|
.get("by")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'by' for find"))?;
|
||||||
|
let value = args
|
||||||
|
.get("value")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'value' for find"))?;
|
||||||
|
let action = args
|
||||||
|
.get("find_action")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'find_action' for find"))?;
|
||||||
|
Ok(BrowserAction::Find {
|
||||||
|
by: by.into(),
|
||||||
|
value: value.into(),
|
||||||
|
action: action.into(),
|
||||||
|
fill_value: args
|
||||||
|
.get("fill_value")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(String::from),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
other => anyhow::bail!("Unsupported browser action: {other}"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// ── Helper functions ─────────────────────────────────────────────
|
// ── Helper functions ─────────────────────────────────────────────
|
||||||
|
|
||||||
fn is_supported_browser_action(action: &str) -> bool {
|
fn is_supported_browser_action(action: &str) -> bool {
|
||||||
|
|
@ -1901,6 +1911,28 @@ fn is_supported_browser_action(action: &str) -> bool {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn is_computer_use_only_action(action: &str) -> bool {
|
||||||
|
matches!(
|
||||||
|
action,
|
||||||
|
"mouse_move" | "mouse_click" | "mouse_drag" | "key_type" | "key_press" | "screen_capture"
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn backend_name(backend: ResolvedBackend) -> &'static str {
|
||||||
|
match backend {
|
||||||
|
ResolvedBackend::AgentBrowser => "agent_browser",
|
||||||
|
ResolvedBackend::RustNative => "rust_native",
|
||||||
|
ResolvedBackend::ComputerUse => "computer_use",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn unavailable_action_for_backend_error(action: &str, backend: ResolvedBackend) -> String {
|
||||||
|
format!(
|
||||||
|
"Action '{action}' is unavailable for backend '{}'",
|
||||||
|
backend_name(backend)
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
|
fn normalize_domains(domains: Vec<String>) -> Vec<String> {
|
||||||
domains
|
domains
|
||||||
.into_iter()
|
.into_iter()
|
||||||
|
|
@ -2342,4 +2374,28 @@ mod tests {
|
||||||
let tool = BrowserTool::new(security, vec![], None);
|
let tool = BrowserTool::new(security, vec![], None);
|
||||||
assert!(tool.validate_url("https://example.com").is_err());
|
assert!(tool.validate_url("https://example.com").is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn computer_use_only_action_detection_is_correct() {
|
||||||
|
assert!(is_computer_use_only_action("mouse_move"));
|
||||||
|
assert!(is_computer_use_only_action("mouse_click"));
|
||||||
|
assert!(is_computer_use_only_action("mouse_drag"));
|
||||||
|
assert!(is_computer_use_only_action("key_type"));
|
||||||
|
assert!(is_computer_use_only_action("key_press"));
|
||||||
|
assert!(is_computer_use_only_action("screen_capture"));
|
||||||
|
assert!(!is_computer_use_only_action("open"));
|
||||||
|
assert!(!is_computer_use_only_action("snapshot"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn unavailable_action_error_preserves_backend_context() {
|
||||||
|
assert_eq!(
|
||||||
|
unavailable_action_for_backend_error("mouse_move", ResolvedBackend::AgentBrowser),
|
||||||
|
"Action 'mouse_move' is unavailable for backend 'agent_browser'"
|
||||||
|
);
|
||||||
|
assert_eq!(
|
||||||
|
unavailable_action_for_backend_error("mouse_move", ResolvedBackend::RustNative),
|
||||||
|
"Action 'mouse_move' is unavailable for backend 'rust_native'"
|
||||||
|
);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -112,12 +112,12 @@ impl ComposioTool {
|
||||||
action_name: &str,
|
action_name: &str,
|
||||||
params: serde_json::Value,
|
params: serde_json::Value,
|
||||||
entity_id: Option<&str>,
|
entity_id: Option<&str>,
|
||||||
connected_account_id: Option<&str>,
|
connected_account_ref: Option<&str>,
|
||||||
) -> anyhow::Result<serde_json::Value> {
|
) -> anyhow::Result<serde_json::Value> {
|
||||||
let tool_slug = normalize_tool_slug(action_name);
|
let tool_slug = normalize_tool_slug(action_name);
|
||||||
|
|
||||||
match self
|
match self
|
||||||
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_id)
|
.execute_action_v3(&tool_slug, params.clone(), entity_id, connected_account_ref)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(result) => Ok(result),
|
Ok(result) => Ok(result),
|
||||||
|
|
@ -130,21 +130,17 @@ impl ComposioTool {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn execute_action_v3(
|
fn build_execute_action_v3_request(
|
||||||
&self,
|
|
||||||
tool_slug: &str,
|
tool_slug: &str,
|
||||||
params: serde_json::Value,
|
params: serde_json::Value,
|
||||||
entity_id: Option<&str>,
|
entity_id: Option<&str>,
|
||||||
connected_account_id: Option<&str>,
|
connected_account_ref: Option<&str>,
|
||||||
) -> anyhow::Result<serde_json::Value> {
|
) -> (String, serde_json::Value) {
|
||||||
let url = if let Some(connected_account_id) = connected_account_id
|
let url = format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute");
|
||||||
.map(str::trim)
|
let account_ref = connected_account_ref.and_then(|candidate| {
|
||||||
.filter(|id| !id.is_empty())
|
let trimmed_candidate = candidate.trim();
|
||||||
{
|
(!trimmed_candidate.is_empty()).then_some(trimmed_candidate)
|
||||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute/{connected_account_id}")
|
});
|
||||||
} else {
|
|
||||||
format!("{COMPOSIO_API_BASE_V3}/tools/{tool_slug}/execute")
|
|
||||||
};
|
|
||||||
|
|
||||||
let mut body = json!({
|
let mut body = json!({
|
||||||
"arguments": params,
|
"arguments": params,
|
||||||
|
|
@ -153,6 +149,26 @@ impl ComposioTool {
|
||||||
if let Some(entity) = entity_id {
|
if let Some(entity) = entity_id {
|
||||||
body["user_id"] = json!(entity);
|
body["user_id"] = json!(entity);
|
||||||
}
|
}
|
||||||
|
if let Some(account_ref) = account_ref {
|
||||||
|
body["connected_account_id"] = json!(account_ref);
|
||||||
|
}
|
||||||
|
|
||||||
|
(url, body)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute_action_v3(
|
||||||
|
&self,
|
||||||
|
tool_slug: &str,
|
||||||
|
params: serde_json::Value,
|
||||||
|
entity_id: Option<&str>,
|
||||||
|
connected_account_ref: Option<&str>,
|
||||||
|
) -> anyhow::Result<serde_json::Value> {
|
||||||
|
let (url, body) = Self::build_execute_action_v3_request(
|
||||||
|
tool_slug,
|
||||||
|
params,
|
||||||
|
entity_id,
|
||||||
|
connected_account_ref,
|
||||||
|
);
|
||||||
|
|
||||||
let resp = self
|
let resp = self
|
||||||
.client
|
.client
|
||||||
|
|
@ -474,11 +490,11 @@ impl Tool for ComposioTool {
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let params = args.get("params").cloned().unwrap_or(json!({}));
|
let params = args.get("params").cloned().unwrap_or(json!({}));
|
||||||
let connected_account_id =
|
let connected_account_ref =
|
||||||
args.get("connected_account_id").and_then(|v| v.as_str());
|
args.get("connected_account_id").and_then(|v| v.as_str());
|
||||||
|
|
||||||
match self
|
match self
|
||||||
.execute_action(action_name, params, Some(entity_id), connected_account_id)
|
.execute_action(action_name, params, Some(entity_id), connected_account_ref)
|
||||||
.await
|
.await
|
||||||
{
|
{
|
||||||
Ok(result) => {
|
Ok(result) => {
|
||||||
|
|
@ -594,9 +610,38 @@ async fn response_error(resp: reqwest::Response) -> String {
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(api_error) = extract_api_error_message(&body) {
|
if let Some(api_error) = extract_api_error_message(&body) {
|
||||||
format!("HTTP {}: {api_error}", status.as_u16())
|
return format!(
|
||||||
|
"HTTP {}: {}",
|
||||||
|
status.as_u16(),
|
||||||
|
sanitize_error_message(&api_error)
|
||||||
|
);
|
||||||
|
}
|
||||||
|
|
||||||
|
format!("HTTP {}", status.as_u16())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn sanitize_error_message(message: &str) -> String {
|
||||||
|
let mut sanitized = message.replace('\n', " ");
|
||||||
|
for marker in [
|
||||||
|
"connected_account_id",
|
||||||
|
"connectedAccountId",
|
||||||
|
"entity_id",
|
||||||
|
"entityId",
|
||||||
|
"user_id",
|
||||||
|
"userId",
|
||||||
|
] {
|
||||||
|
sanitized = sanitized.replace(marker, "[redacted]");
|
||||||
|
}
|
||||||
|
|
||||||
|
let max_chars = 240;
|
||||||
|
if sanitized.chars().count() <= max_chars {
|
||||||
|
sanitized
|
||||||
} else {
|
} else {
|
||||||
format!("HTTP {}: {body}", status.as_u16())
|
let mut end = max_chars;
|
||||||
|
while end > 0 && !sanitized.is_char_boundary(end) {
|
||||||
|
end -= 1;
|
||||||
|
}
|
||||||
|
format!("{}...", &sanitized[..end])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -948,4 +993,40 @@ mod tests {
|
||||||
fn composio_api_base_url_is_v3() {
|
fn composio_api_base_url_is_v3() {
|
||||||
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
|
assert_eq!(COMPOSIO_API_BASE_V3, "https://backend.composio.dev/api/v3");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_execute_action_v3_request_uses_fixed_endpoint_and_body_account_id() {
|
||||||
|
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||||
|
"gmail-send-email",
|
||||||
|
json!({"to": "test@example.com"}),
|
||||||
|
Some("workspace-user"),
|
||||||
|
Some("account-42"),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
url,
|
||||||
|
"https://backend.composio.dev/api/v3/tools/gmail-send-email/execute"
|
||||||
|
);
|
||||||
|
assert_eq!(body["arguments"]["to"], json!("test@example.com"));
|
||||||
|
assert_eq!(body["user_id"], json!("workspace-user"));
|
||||||
|
assert_eq!(body["connected_account_id"], json!("account-42"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn build_execute_action_v3_request_drops_blank_optional_fields() {
|
||||||
|
let (url, body) = ComposioTool::build_execute_action_v3_request(
|
||||||
|
"github-list-repos",
|
||||||
|
json!({}),
|
||||||
|
None,
|
||||||
|
Some(" "),
|
||||||
|
);
|
||||||
|
|
||||||
|
assert_eq!(
|
||||||
|
url,
|
||||||
|
"https://backend.composio.dev/api/v3/tools/github-list-repos/execute"
|
||||||
|
);
|
||||||
|
assert_eq!(body["arguments"], json!({}));
|
||||||
|
assert!(body.get("connected_account_id").is_none());
|
||||||
|
assert!(body.get("user_id").is_none());
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -16,8 +16,8 @@ const DELEGATE_TIMEOUT_SECS: u64 = 120;
|
||||||
/// summarization) to purpose-built sub-agents.
|
/// summarization) to purpose-built sub-agents.
|
||||||
pub struct DelegateTool {
|
pub struct DelegateTool {
|
||||||
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
agents: Arc<HashMap<String, DelegateAgentConfig>>,
|
||||||
/// Global API key fallback (from config.api_key)
|
/// Global credential fallback (from config.api_key)
|
||||||
fallback_api_key: Option<String>,
|
fallback_credential: Option<String>,
|
||||||
/// Depth at which this tool instance lives in the delegation chain.
|
/// Depth at which this tool instance lives in the delegation chain.
|
||||||
depth: u32,
|
depth: u32,
|
||||||
}
|
}
|
||||||
|
|
@ -25,11 +25,11 @@ pub struct DelegateTool {
|
||||||
impl DelegateTool {
|
impl DelegateTool {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
agents: HashMap<String, DelegateAgentConfig>,
|
agents: HashMap<String, DelegateAgentConfig>,
|
||||||
fallback_api_key: Option<String>,
|
fallback_credential: Option<String>,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
agents: Arc::new(agents),
|
agents: Arc::new(agents),
|
||||||
fallback_api_key,
|
fallback_credential,
|
||||||
depth: 0,
|
depth: 0,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -39,12 +39,12 @@ impl DelegateTool {
|
||||||
/// their DelegateTool via this method with `depth: parent.depth + 1`.
|
/// their DelegateTool via this method with `depth: parent.depth + 1`.
|
||||||
pub fn with_depth(
|
pub fn with_depth(
|
||||||
agents: HashMap<String, DelegateAgentConfig>,
|
agents: HashMap<String, DelegateAgentConfig>,
|
||||||
fallback_api_key: Option<String>,
|
fallback_credential: Option<String>,
|
||||||
depth: u32,
|
depth: u32,
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
agents: Arc::new(agents),
|
agents: Arc::new(agents),
|
||||||
fallback_api_key,
|
fallback_credential,
|
||||||
depth,
|
depth,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -165,13 +165,15 @@ impl Tool for DelegateTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create provider for this agent
|
// Create provider for this agent
|
||||||
let api_key = agent_config
|
let provider_credential_owned = agent_config
|
||||||
.api_key
|
.api_key
|
||||||
.as_deref()
|
.clone()
|
||||||
.or(self.fallback_api_key.as_deref());
|
.or_else(|| self.fallback_credential.clone());
|
||||||
|
#[allow(clippy::option_as_ref_deref)]
|
||||||
|
let provider_credential = provider_credential_owned.as_ref().map(String::as_str);
|
||||||
|
|
||||||
let provider: Box<dyn Provider> =
|
let provider: Box<dyn Provider> =
|
||||||
match providers::create_provider(&agent_config.provider, api_key) {
|
match providers::create_provider(&agent_config.provider, provider_credential) {
|
||||||
Ok(p) => p,
|
Ok(p) => p,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
|
|
@ -268,7 +270,7 @@ mod tests {
|
||||||
provider: "openrouter".to_string(),
|
provider: "openrouter".to_string(),
|
||||||
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
model: "anthropic/claude-sonnet-4-20250514".to_string(),
|
||||||
system_prompt: None,
|
system_prompt: None,
|
||||||
api_key: Some("sk-test".to_string()),
|
api_key: Some("delegate-test-credential".to_string()),
|
||||||
temperature: None,
|
temperature: None,
|
||||||
max_depth: 2,
|
max_depth: 2,
|
||||||
},
|
},
|
||||||
|
|
|
||||||
|
|
@ -28,13 +28,22 @@ impl GitOperationsTool {
|
||||||
if arg_lower.starts_with("--exec=")
|
if arg_lower.starts_with("--exec=")
|
||||||
|| arg_lower.starts_with("--upload-pack=")
|
|| arg_lower.starts_with("--upload-pack=")
|
||||||
|| arg_lower.starts_with("--receive-pack=")
|
|| arg_lower.starts_with("--receive-pack=")
|
||||||
|
|| arg_lower.starts_with("--pager=")
|
||||||
|
|| arg_lower.starts_with("--editor=")
|
||||||
|
|| arg_lower == "--no-verify"
|
||||||
|| arg_lower.contains("$(")
|
|| arg_lower.contains("$(")
|
||||||
|| arg_lower.contains('`')
|
|| arg_lower.contains('`')
|
||||||
|| arg.contains('|')
|
|| arg.contains('|')
|
||||||
|| arg.contains(';')
|
|| arg.contains(';')
|
||||||
|
|| arg.contains('>')
|
||||||
{
|
{
|
||||||
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
|
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
|
||||||
}
|
}
|
||||||
|
// Block `-c` config injection (exact match or `-c=...` prefix).
|
||||||
|
// This must not false-positive on `--cached` or `-cached`.
|
||||||
|
if arg_lower == "-c" || arg_lower.starts_with("-c=") {
|
||||||
|
anyhow::bail!("Blocked potentially dangerous git argument: {arg}");
|
||||||
|
}
|
||||||
result.push(arg.to_string());
|
result.push(arg.to_string());
|
||||||
}
|
}
|
||||||
Ok(result)
|
Ok(result)
|
||||||
|
|
@ -129,6 +138,9 @@ impl GitOperationsTool {
|
||||||
.and_then(|v| v.as_bool())
|
.and_then(|v| v.as_bool())
|
||||||
.unwrap_or(false);
|
.unwrap_or(false);
|
||||||
|
|
||||||
|
// Validate files argument against injection patterns
|
||||||
|
self.sanitize_git_args(files)?;
|
||||||
|
|
||||||
let mut git_args = vec!["diff", "--unified=3"];
|
let mut git_args = vec!["diff", "--unified=3"];
|
||||||
if cached {
|
if cached {
|
||||||
git_args.push("--cached");
|
git_args.push("--cached");
|
||||||
|
|
@ -267,6 +279,14 @@ impl GitOperationsTool {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn truncate_commit_message(message: &str) -> String {
|
||||||
|
if message.chars().count() > 2000 {
|
||||||
|
format!("{}...", message.chars().take(1997).collect::<String>())
|
||||||
|
} else {
|
||||||
|
message.to_string()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
async fn git_commit(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
let message = args
|
let message = args
|
||||||
.get("message")
|
.get("message")
|
||||||
|
|
@ -286,11 +306,7 @@ impl GitOperationsTool {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Limit message length
|
// Limit message length
|
||||||
let message = if sanitized.len() > 2000 {
|
let message = Self::truncate_commit_message(&sanitized);
|
||||||
format!("{}...", &sanitized[..1997])
|
|
||||||
} else {
|
|
||||||
sanitized
|
|
||||||
};
|
|
||||||
|
|
||||||
let output = self.run_git_command(&["commit", "-m", &message]).await;
|
let output = self.run_git_command(&["commit", "-m", &message]).await;
|
||||||
|
|
||||||
|
|
@ -314,6 +330,9 @@ impl GitOperationsTool {
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?;
|
.ok_or_else(|| anyhow::anyhow!("Missing 'paths' parameter"))?;
|
||||||
|
|
||||||
|
// Validate paths against injection patterns
|
||||||
|
self.sanitize_git_args(paths)?;
|
||||||
|
|
||||||
let output = self.run_git_command(&["add", "--", paths]).await;
|
let output = self.run_git_command(&["add", "--", paths]).await;
|
||||||
|
|
||||||
match output {
|
match output {
|
||||||
|
|
@ -574,6 +593,52 @@ mod tests {
|
||||||
assert!(tool.sanitize_git_args("arg; rm file").is_err());
|
assert!(tool.sanitize_git_args("arg; rm file").is_err());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sanitize_git_blocks_pager_editor_injection() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let tool = test_tool(tmp.path());
|
||||||
|
|
||||||
|
assert!(tool.sanitize_git_args("--pager=less").is_err());
|
||||||
|
assert!(tool.sanitize_git_args("--editor=vim").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sanitize_git_blocks_config_injection() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let tool = test_tool(tmp.path());
|
||||||
|
|
||||||
|
// Exact `-c` flag (config injection)
|
||||||
|
assert!(tool.sanitize_git_args("-c core.sshCommand=evil").is_err());
|
||||||
|
assert!(tool.sanitize_git_args("-c=core.pager=less").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sanitize_git_blocks_no_verify() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let tool = test_tool(tmp.path());
|
||||||
|
|
||||||
|
assert!(tool.sanitize_git_args("--no-verify").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sanitize_git_blocks_redirect_in_args() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let tool = test_tool(tmp.path());
|
||||||
|
|
||||||
|
assert!(tool.sanitize_git_args("file.txt > /tmp/out").is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn sanitize_git_cached_not_blocked() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let tool = test_tool(tmp.path());
|
||||||
|
|
||||||
|
// --cached must NOT be blocked by the `-c` check
|
||||||
|
assert!(tool.sanitize_git_args("--cached").is_ok());
|
||||||
|
// Other safe flags starting with -c prefix
|
||||||
|
assert!(tool.sanitize_git_args("-cached").is_ok());
|
||||||
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
fn sanitize_git_allows_safe() {
|
fn sanitize_git_allows_safe() {
|
||||||
let tmp = TempDir::new().unwrap();
|
let tmp = TempDir::new().unwrap();
|
||||||
|
|
@ -583,6 +648,8 @@ mod tests {
|
||||||
assert!(tool.sanitize_git_args("main").is_ok());
|
assert!(tool.sanitize_git_args("main").is_ok());
|
||||||
assert!(tool.sanitize_git_args("feature/test-branch").is_ok());
|
assert!(tool.sanitize_git_args("feature/test-branch").is_ok());
|
||||||
assert!(tool.sanitize_git_args("--cached").is_ok());
|
assert!(tool.sanitize_git_args("--cached").is_ok());
|
||||||
|
assert!(tool.sanitize_git_args("src/main.rs").is_ok());
|
||||||
|
assert!(tool.sanitize_git_args(".").is_ok());
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -691,4 +758,12 @@ mod tests {
|
||||||
.unwrap_or("")
|
.unwrap_or("")
|
||||||
.contains("Unknown operation"));
|
.contains("Unknown operation"));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn truncates_multibyte_commit_message_without_panicking() {
|
||||||
|
let long = "🦀".repeat(2500);
|
||||||
|
let truncated = GitOperationsTool::truncate_commit_message(&long);
|
||||||
|
|
||||||
|
assert_eq!(truncated.chars().count(), 2000);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -124,10 +124,11 @@ impl Tool for HardwareBoardInfoTool {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
output.push_str(&format!(
|
use std::fmt::Write;
|
||||||
"probe-rs attach failed: {}. Using static info.\n\n",
|
let _ = write!(
|
||||||
e
|
output,
|
||||||
));
|
"probe-rs attach failed: {e}. Using static info.\n\n"
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -135,13 +136,15 @@ impl Tool for HardwareBoardInfoTool {
|
||||||
if let Some(info) = self.static_info_for_board(board) {
|
if let Some(info) = self.static_info_for_board(board) {
|
||||||
output.push_str(&info);
|
output.push_str(&info);
|
||||||
if let Some(mem) = memory_map_static(board) {
|
if let Some(mem) = memory_map_static(board) {
|
||||||
output.push_str(&format!("\n\n**Memory map:**\n{}", mem));
|
use std::fmt::Write;
|
||||||
|
let _ = write!(output, "\n\n**Memory map:**\n{mem}");
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
output.push_str(&format!(
|
use std::fmt::Write;
|
||||||
"Board '{}' configured. No static info available.",
|
let _ = write!(
|
||||||
board
|
output,
|
||||||
));
|
"Board '{board}' configured. No static info available."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(ToolResult {
|
Ok(ToolResult {
|
||||||
|
|
|
||||||
|
|
@ -122,14 +122,16 @@ impl Tool for HardwareMemoryMapTool {
|
||||||
|
|
||||||
if !probe_ok {
|
if !probe_ok {
|
||||||
if let Some(map) = self.static_map_for_board(board) {
|
if let Some(map) = self.static_map_for_board(board) {
|
||||||
output.push_str(&format!("**{}** (from datasheet):\n{}", board, map));
|
use std::fmt::Write;
|
||||||
|
let _ = write!(output, "**{board}** (from datasheet):\n{map}");
|
||||||
} else {
|
} else {
|
||||||
|
use std::fmt::Write;
|
||||||
let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect();
|
let known: Vec<&str> = MEMORY_MAPS.iter().map(|(b, _)| *b).collect();
|
||||||
output.push_str(&format!(
|
let _ = write!(
|
||||||
"No memory map for board '{}'. Known boards: {}",
|
output,
|
||||||
board,
|
"No memory map for board '{board}'. Known boards: {}",
|
||||||
known.join(", ")
|
known.join(", ")
|
||||||
));
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -94,14 +94,16 @@ impl Tool for HardwareMemoryReadTool {
|
||||||
.get("address")
|
.get("address")
|
||||||
.and_then(|v| v.as_str())
|
.and_then(|v| v.as_str())
|
||||||
.unwrap_or("0x20000000");
|
.unwrap_or("0x20000000");
|
||||||
let address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
|
let _address = parse_hex_address(address_str).unwrap_or(NUCLEO_RAM_BASE);
|
||||||
|
|
||||||
let length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128) as usize;
|
let requested_length = args.get("length").and_then(|v| v.as_u64()).unwrap_or(128);
|
||||||
let length = length.min(256).max(1);
|
let _length = usize::try_from(requested_length)
|
||||||
|
.unwrap_or(256)
|
||||||
|
.clamp(1, 256);
|
||||||
|
|
||||||
#[cfg(feature = "probe")]
|
#[cfg(feature = "probe")]
|
||||||
{
|
{
|
||||||
match probe_read_memory(chip.unwrap(), address, length) {
|
match probe_read_memory(chip.unwrap(), _address, _length) {
|
||||||
Ok(output) => {
|
Ok(output) => {
|
||||||
return Ok(ToolResult {
|
return Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
|
|
|
||||||
|
|
@ -749,4 +749,54 @@ mod tests {
|
||||||
let _ = HttpRequestTool::redact_headers_for_display(&headers);
|
let _ = HttpRequestTool::redact_headers_for_display(&headers);
|
||||||
assert_eq!(headers[0].1, "Bearer real-token");
|
assert_eq!(headers[0].1, "Bearer real-token");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ── SSRF: alternate IP notation bypass defense-in-depth ─────────
|
||||||
|
//
|
||||||
|
// Rust's IpAddr::parse() rejects non-standard notations (octal, hex,
|
||||||
|
// decimal integer, zero-padded). These tests document that property
|
||||||
|
// so regressions are caught if the parsing strategy ever changes.
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ssrf_octal_loopback_not_parsed_as_ip() {
|
||||||
|
// 0177.0.0.1 is octal for 127.0.0.1 in some languages, but
|
||||||
|
// Rust's IpAddr rejects it — it falls through as a hostname.
|
||||||
|
assert!(!is_private_or_local_host("0177.0.0.1"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ssrf_hex_loopback_not_parsed_as_ip() {
|
||||||
|
// 0x7f000001 is hex for 127.0.0.1 in some languages.
|
||||||
|
assert!(!is_private_or_local_host("0x7f000001"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ssrf_decimal_loopback_not_parsed_as_ip() {
|
||||||
|
// 2130706433 is decimal for 127.0.0.1 in some languages.
|
||||||
|
assert!(!is_private_or_local_host("2130706433"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ssrf_zero_padded_loopback_not_parsed_as_ip() {
|
||||||
|
// 127.000.000.001 uses zero-padded octets.
|
||||||
|
assert!(!is_private_or_local_host("127.000.000.001"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn ssrf_alternate_notations_rejected_by_validate_url() {
|
||||||
|
// Even if is_private_or_local_host doesn't flag these, they
|
||||||
|
// fail the allowlist because they're treated as hostnames.
|
||||||
|
let tool = test_tool(vec!["example.com"]);
|
||||||
|
for notation in [
|
||||||
|
"http://0177.0.0.1",
|
||||||
|
"http://0x7f000001",
|
||||||
|
"http://2130706433",
|
||||||
|
"http://127.000.000.001",
|
||||||
|
] {
|
||||||
|
let err = tool.validate_url(notation).unwrap_err().to_string();
|
||||||
|
assert!(
|
||||||
|
err.contains("allowed_domains"),
|
||||||
|
"Expected allowlist rejection for {notation}, got: {err}"
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -87,7 +87,7 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn forget_existing() {
|
async fn forget_existing() {
|
||||||
let (_tmp, mem) = test_mem();
|
let (_tmp, mem) = test_mem();
|
||||||
mem.store("temp", "temporary", MemoryCategory::Conversation)
|
mem.store("temp", "temporary", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -55,7 +55,7 @@ impl Tool for MemoryRecallTool {
|
||||||
.and_then(serde_json::Value::as_u64)
|
.and_then(serde_json::Value::as_u64)
|
||||||
.map_or(5, |v| v as usize);
|
.map_or(5, |v| v as usize);
|
||||||
|
|
||||||
match self.memory.recall(query, limit).await {
|
match self.memory.recall(query, limit, None).await {
|
||||||
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
Ok(entries) if entries.is_empty() => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: "No memories found matching that query.".into(),
|
output: "No memories found matching that query.".into(),
|
||||||
|
|
@ -112,10 +112,10 @@ mod tests {
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
async fn recall_finds_match() {
|
async fn recall_finds_match() {
|
||||||
let (_tmp, mem) = seeded_mem();
|
let (_tmp, mem) = seeded_mem();
|
||||||
mem.store("lang", "User prefers Rust", MemoryCategory::Core)
|
mem.store("lang", "User prefers Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
mem.store("tz", "Timezone is EST", MemoryCategory::Core)
|
mem.store("tz", "Timezone is EST", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -134,6 +134,7 @@ mod tests {
|
||||||
&format!("k{i}"),
|
&format!("k{i}"),
|
||||||
&format!("Rust fact {i}"),
|
&format!("Rust fact {i}"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
|
||||||
|
|
@ -64,7 +64,7 @@ impl Tool for MemoryStoreTool {
|
||||||
_ => MemoryCategory::Core,
|
_ => MemoryCategory::Core,
|
||||||
};
|
};
|
||||||
|
|
||||||
match self.memory.store(key, content, category).await {
|
match self.memory.store(key, content, category, None).await {
|
||||||
Ok(()) => Ok(ToolResult {
|
Ok(()) => Ok(ToolResult {
|
||||||
success: true,
|
success: true,
|
||||||
output: format!("Stored memory: {key}"),
|
output: format!("Stored memory: {key}"),
|
||||||
|
|
|
||||||
|
|
@ -19,7 +19,9 @@ pub mod image_info;
|
||||||
pub mod memory_forget;
|
pub mod memory_forget;
|
||||||
pub mod memory_recall;
|
pub mod memory_recall;
|
||||||
pub mod memory_store;
|
pub mod memory_store;
|
||||||
|
pub mod pushover;
|
||||||
pub mod schedule;
|
pub mod schedule;
|
||||||
|
pub mod schema;
|
||||||
pub mod screenshot;
|
pub mod screenshot;
|
||||||
pub mod shell;
|
pub mod shell;
|
||||||
pub mod traits;
|
pub mod traits;
|
||||||
|
|
@ -45,7 +47,9 @@ pub use image_info::ImageInfoTool;
|
||||||
pub use memory_forget::MemoryForgetTool;
|
pub use memory_forget::MemoryForgetTool;
|
||||||
pub use memory_recall::MemoryRecallTool;
|
pub use memory_recall::MemoryRecallTool;
|
||||||
pub use memory_store::MemoryStoreTool;
|
pub use memory_store::MemoryStoreTool;
|
||||||
|
pub use pushover::PushoverTool;
|
||||||
pub use schedule::ScheduleTool;
|
pub use schedule::ScheduleTool;
|
||||||
|
pub use schema::{CleaningStrategy, SchemaCleanr};
|
||||||
pub use screenshot::ScreenshotTool;
|
pub use screenshot::ScreenshotTool;
|
||||||
pub use shell::ShellTool;
|
pub use shell::ShellTool;
|
||||||
pub use traits::Tool;
|
pub use traits::Tool;
|
||||||
|
|
@ -141,6 +145,10 @@ pub fn all_tools_with_runtime(
|
||||||
security.clone(),
|
security.clone(),
|
||||||
workspace_dir.to_path_buf(),
|
workspace_dir.to_path_buf(),
|
||||||
)),
|
)),
|
||||||
|
Box::new(PushoverTool::new(
|
||||||
|
security.clone(),
|
||||||
|
workspace_dir.to_path_buf(),
|
||||||
|
)),
|
||||||
];
|
];
|
||||||
|
|
||||||
if browser_config.enabled {
|
if browser_config.enabled {
|
||||||
|
|
@ -195,9 +203,13 @@ pub fn all_tools_with_runtime(
|
||||||
.iter()
|
.iter()
|
||||||
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
.map(|(name, cfg)| (name.clone(), cfg.clone()))
|
||||||
.collect();
|
.collect();
|
||||||
|
let delegate_fallback_credential = fallback_api_key.and_then(|value| {
|
||||||
|
let trimmed_value = value.trim();
|
||||||
|
(!trimmed_value.is_empty()).then(|| trimmed_value.to_owned())
|
||||||
|
});
|
||||||
tools.push(Box::new(DelegateTool::new(
|
tools.push(Box::new(DelegateTool::new(
|
||||||
delegate_agents,
|
delegate_agents,
|
||||||
fallback_api_key.map(String::from),
|
delegate_fallback_credential,
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -261,6 +273,7 @@ mod tests {
|
||||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||||
assert!(!names.contains(&"browser_open"));
|
assert!(!names.contains(&"browser_open"));
|
||||||
assert!(names.contains(&"schedule"));
|
assert!(names.contains(&"schedule"));
|
||||||
|
assert!(names.contains(&"pushover"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -298,6 +311,7 @@ mod tests {
|
||||||
);
|
);
|
||||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||||
assert!(names.contains(&"browser_open"));
|
assert!(names.contains(&"browser_open"));
|
||||||
|
assert!(names.contains(&"pushover"));
|
||||||
}
|
}
|
||||||
|
|
||||||
#[test]
|
#[test]
|
||||||
|
|
@ -432,7 +446,7 @@ mod tests {
|
||||||
&http,
|
&http,
|
||||||
tmp.path(),
|
tmp.path(),
|
||||||
&agents,
|
&agents,
|
||||||
Some("sk-test"),
|
Some("delegate-test-credential"),
|
||||||
&cfg,
|
&cfg,
|
||||||
);
|
);
|
||||||
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
let names: Vec<&str> = tools.iter().map(|t| t.name()).collect();
|
||||||
|
|
|
||||||
442
src/tools/pushover.rs
Normal file
442
src/tools/pushover.rs
Normal file
|
|
@ -0,0 +1,442 @@
|
||||||
|
use super::traits::{Tool, ToolResult};
|
||||||
|
use crate::security::SecurityPolicy;
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use reqwest::Client;
|
||||||
|
use serde_json::json;
|
||||||
|
use std::path::PathBuf;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
const PUSHOVER_API_URL: &str = "https://api.pushover.net/1/messages.json";
|
||||||
|
const PUSHOVER_REQUEST_TIMEOUT_SECS: u64 = 15;
|
||||||
|
|
||||||
|
pub struct PushoverTool {
|
||||||
|
client: Client,
|
||||||
|
security: Arc<SecurityPolicy>,
|
||||||
|
workspace_dir: PathBuf,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PushoverTool {
|
||||||
|
pub fn new(security: Arc<SecurityPolicy>, workspace_dir: PathBuf) -> Self {
|
||||||
|
let client = Client::builder()
|
||||||
|
.timeout(Duration::from_secs(PUSHOVER_REQUEST_TIMEOUT_SECS))
|
||||||
|
.build()
|
||||||
|
.unwrap_or_else(|_| Client::new());
|
||||||
|
|
||||||
|
Self {
|
||||||
|
client,
|
||||||
|
security,
|
||||||
|
workspace_dir,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parse_env_value(raw: &str) -> String {
|
||||||
|
let raw = raw.trim();
|
||||||
|
|
||||||
|
let unquoted = if raw.len() >= 2
|
||||||
|
&& ((raw.starts_with('"') && raw.ends_with('"'))
|
||||||
|
|| (raw.starts_with('\'') && raw.ends_with('\'')))
|
||||||
|
{
|
||||||
|
&raw[1..raw.len() - 1]
|
||||||
|
} else {
|
||||||
|
raw
|
||||||
|
};
|
||||||
|
|
||||||
|
// Keep support for inline comments in unquoted values:
|
||||||
|
// KEY=value # comment
|
||||||
|
unquoted.split_once(" #").map_or_else(
|
||||||
|
|| unquoted.trim().to_string(),
|
||||||
|
|(value, _)| value.trim().to_string(),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn get_credentials(&self) -> anyhow::Result<(String, String)> {
|
||||||
|
let env_path = self.workspace_dir.join(".env");
|
||||||
|
let content = std::fs::read_to_string(&env_path)
|
||||||
|
.map_err(|e| anyhow::anyhow!("Failed to read {}: {}", env_path.display(), e))?;
|
||||||
|
|
||||||
|
let mut token = None;
|
||||||
|
let mut user_key = None;
|
||||||
|
|
||||||
|
for line in content.lines() {
|
||||||
|
let line = line.trim();
|
||||||
|
if line.starts_with('#') || line.is_empty() {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
let line = line.strip_prefix("export ").map(str::trim).unwrap_or(line);
|
||||||
|
if let Some((key, value)) = line.split_once('=') {
|
||||||
|
let key = key.trim();
|
||||||
|
let value = Self::parse_env_value(value);
|
||||||
|
|
||||||
|
if key.eq_ignore_ascii_case("PUSHOVER_TOKEN") {
|
||||||
|
token = Some(value);
|
||||||
|
} else if key.eq_ignore_ascii_case("PUSHOVER_USER_KEY") {
|
||||||
|
user_key = Some(value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let token = token.ok_or_else(|| anyhow::anyhow!("PUSHOVER_TOKEN not found in .env"))?;
|
||||||
|
let user_key =
|
||||||
|
user_key.ok_or_else(|| anyhow::anyhow!("PUSHOVER_USER_KEY not found in .env"))?;
|
||||||
|
|
||||||
|
Ok((token, user_key))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl Tool for PushoverTool {
|
||||||
|
fn name(&self) -> &str {
|
||||||
|
"pushover"
|
||||||
|
}
|
||||||
|
|
||||||
|
fn description(&self) -> &str {
|
||||||
|
"Send a Pushover notification to your device. Requires PUSHOVER_TOKEN and PUSHOVER_USER_KEY in .env file."
|
||||||
|
}
|
||||||
|
|
||||||
|
fn parameters_schema(&self) -> serde_json::Value {
|
||||||
|
json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"message": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "The notification message to send"
|
||||||
|
},
|
||||||
|
"title": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Optional notification title"
|
||||||
|
},
|
||||||
|
"priority": {
|
||||||
|
"type": "integer",
|
||||||
|
"enum": [-2, -1, 0, 1, 2],
|
||||||
|
"description": "Message priority: -2 (lowest/silent), -1 (low/no sound), 0 (normal), 1 (high), 2 (emergency/repeating)"
|
||||||
|
},
|
||||||
|
"sound": {
|
||||||
|
"type": "string",
|
||||||
|
"description": "Notification sound override (e.g., 'pushover', 'bike', 'bugle', 'cashregister', etc.)"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"required": ["message"]
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
|
||||||
|
if !self.security.can_act() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Action blocked: autonomy is read-only".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if !self.security.record_action() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some("Action blocked: rate limit exceeded".into()),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let message = args
|
||||||
|
.get("message")
|
||||||
|
.and_then(|v| v.as_str())
|
||||||
|
.map(str::trim)
|
||||||
|
.filter(|v| !v.is_empty())
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Missing 'message' parameter"))?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let title = args.get("title").and_then(|v| v.as_str()).map(String::from);
|
||||||
|
|
||||||
|
let priority = match args.get("priority").and_then(|v| v.as_i64()) {
|
||||||
|
Some(value) if (-2..=2).contains(&value) => Some(value),
|
||||||
|
Some(value) => {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: String::new(),
|
||||||
|
error: Some(format!(
|
||||||
|
"Invalid 'priority': {value}. Expected integer in range -2..=2"
|
||||||
|
)),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
None => None,
|
||||||
|
};
|
||||||
|
|
||||||
|
let sound = args.get("sound").and_then(|v| v.as_str()).map(String::from);
|
||||||
|
|
||||||
|
let (token, user_key) = self.get_credentials()?;
|
||||||
|
|
||||||
|
let mut form = reqwest::multipart::Form::new()
|
||||||
|
.text("token", token)
|
||||||
|
.text("user", user_key)
|
||||||
|
.text("message", message);
|
||||||
|
|
||||||
|
if let Some(title) = title {
|
||||||
|
form = form.text("title", title);
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(priority) = priority {
|
||||||
|
form = form.text("priority", priority.to_string());
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(sound) = sound {
|
||||||
|
form = form.text("sound", sound);
|
||||||
|
}
|
||||||
|
|
||||||
|
let response = self
|
||||||
|
.client
|
||||||
|
.post(PUSHOVER_API_URL)
|
||||||
|
.multipart(form)
|
||||||
|
.send()
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let status = response.status();
|
||||||
|
let body = response.text().await.unwrap_or_default();
|
||||||
|
|
||||||
|
if !status.is_success() {
|
||||||
|
return Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: body,
|
||||||
|
error: Some(format!("Pushover API returned status {}", status)),
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
let api_status = serde_json::from_str::<serde_json::Value>(&body)
|
||||||
|
.ok()
|
||||||
|
.and_then(|json| json.get("status").and_then(|value| value.as_i64()));
|
||||||
|
|
||||||
|
if api_status == Some(1) {
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: true,
|
||||||
|
output: format!(
|
||||||
|
"Pushover notification sent successfully. Response: {}",
|
||||||
|
body
|
||||||
|
),
|
||||||
|
error: None,
|
||||||
|
})
|
||||||
|
} else {
|
||||||
|
Ok(ToolResult {
|
||||||
|
success: false,
|
||||||
|
output: body,
|
||||||
|
error: Some("Pushover API returned an application-level error".into()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::security::AutonomyLevel;
|
||||||
|
use std::fs;
|
||||||
|
use tempfile::TempDir;
|
||||||
|
|
||||||
|
fn test_security(level: AutonomyLevel, max_actions_per_hour: u32) -> Arc<SecurityPolicy> {
|
||||||
|
Arc::new(SecurityPolicy {
|
||||||
|
autonomy: level,
|
||||||
|
max_actions_per_hour,
|
||||||
|
workspace_dir: std::env::temp_dir(),
|
||||||
|
..SecurityPolicy::default()
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pushover_tool_name() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
assert_eq!(tool.name(), "pushover");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pushover_tool_description() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
assert!(!tool.description().is_empty());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pushover_tool_has_parameters_schema() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
let schema = tool.parameters_schema();
|
||||||
|
assert_eq!(schema["type"], "object");
|
||||||
|
assert!(schema["properties"].get("message").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pushover_tool_requires_message() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
let schema = tool.parameters_schema();
|
||||||
|
let required = schema["required"].as_array().unwrap();
|
||||||
|
assert!(required.contains(&serde_json::Value::String("message".to_string())));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_parsed_from_env_file() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let env_path = tmp.path().join(".env");
|
||||||
|
fs::write(
|
||||||
|
&env_path,
|
||||||
|
"PUSHOVER_TOKEN=testtoken123\nPUSHOVER_USER_KEY=userkey456\n",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
tmp.path().to_path_buf(),
|
||||||
|
);
|
||||||
|
let result = tool.get_credentials();
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let (token, user_key) = result.unwrap();
|
||||||
|
assert_eq!(token, "testtoken123");
|
||||||
|
assert_eq!(user_key, "userkey456");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_fail_without_env_file() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
tmp.path().to_path_buf(),
|
||||||
|
);
|
||||||
|
let result = tool.get_credentials();
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_fail_without_token() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let env_path = tmp.path().join(".env");
|
||||||
|
fs::write(&env_path, "PUSHOVER_USER_KEY=userkey456\n").unwrap();
|
||||||
|
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
tmp.path().to_path_buf(),
|
||||||
|
);
|
||||||
|
let result = tool.get_credentials();
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_fail_without_user_key() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let env_path = tmp.path().join(".env");
|
||||||
|
fs::write(&env_path, "PUSHOVER_TOKEN=testtoken123\n").unwrap();
|
||||||
|
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
tmp.path().to_path_buf(),
|
||||||
|
);
|
||||||
|
let result = tool.get_credentials();
|
||||||
|
|
||||||
|
assert!(result.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_ignore_comments() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let env_path = tmp.path().join(".env");
|
||||||
|
fs::write(&env_path, "# This is a comment\nPUSHOVER_TOKEN=realtoken\n# Another comment\nPUSHOVER_USER_KEY=realuser\n").unwrap();
|
||||||
|
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
tmp.path().to_path_buf(),
|
||||||
|
);
|
||||||
|
let result = tool.get_credentials();
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let (token, user_key) = result.unwrap();
|
||||||
|
assert_eq!(token, "realtoken");
|
||||||
|
assert_eq!(user_key, "realuser");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pushover_tool_supports_priority() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
let schema = tool.parameters_schema();
|
||||||
|
assert!(schema["properties"].get("priority").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn pushover_tool_supports_sound() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
let schema = tool.parameters_schema();
|
||||||
|
assert!(schema["properties"].get("sound").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn credentials_support_export_and_quoted_values() {
|
||||||
|
let tmp = TempDir::new().unwrap();
|
||||||
|
let env_path = tmp.path().join(".env");
|
||||||
|
fs::write(
|
||||||
|
&env_path,
|
||||||
|
"export PUSHOVER_TOKEN=\"quotedtoken\"\nPUSHOVER_USER_KEY='quoteduser'\n",
|
||||||
|
)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
tmp.path().to_path_buf(),
|
||||||
|
);
|
||||||
|
let result = tool.get_credentials();
|
||||||
|
|
||||||
|
assert!(result.is_ok());
|
||||||
|
let (token, user_key) = result.unwrap();
|
||||||
|
assert_eq!(token, "quotedtoken");
|
||||||
|
assert_eq!(user_key, "quoteduser");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn execute_blocks_readonly_mode() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::ReadOnly, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = tool.execute(json!({"message": "hello"})).await.unwrap();
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("read-only"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn execute_blocks_rate_limit() {
|
||||||
|
let tool = PushoverTool::new(test_security(AutonomyLevel::Full, 0), PathBuf::from("/tmp"));
|
||||||
|
|
||||||
|
let result = tool.execute(json!({"message": "hello"})).await.unwrap();
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("rate limit"));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn execute_rejects_priority_out_of_range() {
|
||||||
|
let tool = PushoverTool::new(
|
||||||
|
test_security(AutonomyLevel::Full, 100),
|
||||||
|
PathBuf::from("/tmp"),
|
||||||
|
);
|
||||||
|
|
||||||
|
let result = tool
|
||||||
|
.execute(json!({"message": "hello", "priority": 5}))
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
assert!(!result.success);
|
||||||
|
assert!(result.error.unwrap().contains("-2..=2"));
|
||||||
|
}
|
||||||
|
}
|
||||||
838
src/tools/schema.rs
Normal file
838
src/tools/schema.rs
Normal file
|
|
@ -0,0 +1,838 @@
|
||||||
|
//! JSON Schema cleaning and validation for LLM tool-calling compatibility.
|
||||||
|
//!
|
||||||
|
//! Different providers support different subsets of JSON Schema. This module
|
||||||
|
//! normalizes tool schemas to improve cross-provider compatibility while
|
||||||
|
//! preserving semantic intent.
|
||||||
|
//!
|
||||||
|
//! ## What this module does
|
||||||
|
//!
|
||||||
|
//! 1. Removes unsupported keywords per provider strategy
|
||||||
|
//! 2. Resolves local `$ref` entries from `$defs` and `definitions`
|
||||||
|
//! 3. Flattens literal `anyOf` / `oneOf` unions into `enum`
|
||||||
|
//! 4. Strips nullable variants from unions and `type` arrays
|
||||||
|
//! 5. Converts `const` to single-value `enum`
|
||||||
|
//! 6. Detects circular references and stops recursion safely
|
||||||
|
//!
|
||||||
|
//! # Example
|
||||||
|
//!
|
||||||
|
//! ```rust
|
||||||
|
//! use serde_json::json;
|
||||||
|
//! use zeroclaw::tools::schema::SchemaCleanr;
|
||||||
|
//!
|
||||||
|
//! let dirty_schema = json!({
|
||||||
|
//! "type": "object",
|
||||||
|
//! "properties": {
|
||||||
|
//! "name": {
|
||||||
|
//! "type": "string",
|
||||||
|
//! "minLength": 1, // Gemini rejects this
|
||||||
|
//! "pattern": "^[a-z]+$" // Gemini rejects this
|
||||||
|
//! },
|
||||||
|
//! "age": {
|
||||||
|
//! "$ref": "#/$defs/Age" // Needs resolution
|
||||||
|
//! }
|
||||||
|
//! },
|
||||||
|
//! "$defs": {
|
||||||
|
//! "Age": {
|
||||||
|
//! "type": "integer",
|
||||||
|
//! "minimum": 0 // Gemini rejects this
|
||||||
|
//! }
|
||||||
|
//! }
|
||||||
|
//! });
|
||||||
|
//!
|
||||||
|
//! let cleaned = SchemaCleanr::clean_for_gemini(dirty_schema);
|
||||||
|
//!
|
||||||
|
//! // Result:
|
||||||
|
//! // {
|
||||||
|
//! // "type": "object",
|
||||||
|
//! // "properties": {
|
||||||
|
//! // "name": { "type": "string" },
|
||||||
|
//! // "age": { "type": "integer" }
|
||||||
|
//! // }
|
||||||
|
//! // }
|
||||||
|
//! ```
|
||||||
|
//!
|
||||||
|
use serde_json::{json, Map, Value};
|
||||||
|
use std::collections::{HashMap, HashSet};
|
||||||
|
|
||||||
|
/// Keywords that Gemini rejects for tool schemas.
|
||||||
|
pub const GEMINI_UNSUPPORTED_KEYWORDS: &[&str] = &[
|
||||||
|
// Schema composition
|
||||||
|
"$ref",
|
||||||
|
"$schema",
|
||||||
|
"$id",
|
||||||
|
"$defs",
|
||||||
|
"definitions",
|
||||||
|
// Property constraints
|
||||||
|
"additionalProperties",
|
||||||
|
"patternProperties",
|
||||||
|
// String constraints
|
||||||
|
"minLength",
|
||||||
|
"maxLength",
|
||||||
|
"pattern",
|
||||||
|
"format",
|
||||||
|
// Number constraints
|
||||||
|
"minimum",
|
||||||
|
"maximum",
|
||||||
|
"multipleOf",
|
||||||
|
// Array constraints
|
||||||
|
"minItems",
|
||||||
|
"maxItems",
|
||||||
|
"uniqueItems",
|
||||||
|
// Object constraints
|
||||||
|
"minProperties",
|
||||||
|
"maxProperties",
|
||||||
|
// Non-standard
|
||||||
|
"examples", // OpenAPI keyword, not JSON Schema
|
||||||
|
];
|
||||||
|
|
||||||
|
/// Keywords that should be preserved during cleaning (metadata).
|
||||||
|
const SCHEMA_META_KEYS: &[&str] = &["description", "title", "default"];
|
||||||
|
|
||||||
|
/// Schema cleaning strategies for different LLM providers.
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum CleaningStrategy {
|
||||||
|
/// Gemini (Google AI / Vertex AI) - Most restrictive
|
||||||
|
Gemini,
|
||||||
|
/// Anthropic Claude - Moderately permissive
|
||||||
|
Anthropic,
|
||||||
|
/// OpenAI GPT - Most permissive
|
||||||
|
OpenAI,
|
||||||
|
/// Conservative: Remove only universally unsupported keywords
|
||||||
|
Conservative,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl CleaningStrategy {
|
||||||
|
/// Get the list of unsupported keywords for this strategy.
|
||||||
|
pub fn unsupported_keywords(&self) -> &'static [&'static str] {
|
||||||
|
match self {
|
||||||
|
Self::Gemini => GEMINI_UNSUPPORTED_KEYWORDS,
|
||||||
|
Self::Anthropic => &["$ref", "$defs", "definitions"], // Anthropic doesn't resolve refs
|
||||||
|
Self::OpenAI => &[], // OpenAI is most permissive
|
||||||
|
Self::Conservative => &["$ref", "$defs", "definitions", "additionalProperties"],
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// JSON Schema cleaner optimized for LLM tool calling.
|
||||||
|
pub struct SchemaCleanr;
|
||||||
|
|
||||||
|
impl SchemaCleanr {
|
||||||
|
/// Clean schema for Gemini compatibility (strictest).
|
||||||
|
///
|
||||||
|
/// This is the most aggressive cleaning strategy, removing all keywords
|
||||||
|
/// that Gemini's API rejects.
|
||||||
|
pub fn clean_for_gemini(schema: Value) -> Value {
|
||||||
|
Self::clean(schema, CleaningStrategy::Gemini)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean schema for Anthropic compatibility.
|
||||||
|
pub fn clean_for_anthropic(schema: Value) -> Value {
|
||||||
|
Self::clean(schema, CleaningStrategy::Anthropic)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean schema for OpenAI compatibility (most permissive).
|
||||||
|
pub fn clean_for_openai(schema: Value) -> Value {
|
||||||
|
Self::clean(schema, CleaningStrategy::OpenAI)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean schema with specified strategy.
|
||||||
|
pub fn clean(schema: Value, strategy: CleaningStrategy) -> Value {
|
||||||
|
// Extract $defs for reference resolution
|
||||||
|
let defs = if let Some(obj) = schema.as_object() {
|
||||||
|
Self::extract_defs(obj)
|
||||||
|
} else {
|
||||||
|
HashMap::new()
|
||||||
|
};
|
||||||
|
|
||||||
|
Self::clean_with_defs(schema, &defs, strategy, &mut HashSet::new())
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Validate that a schema is suitable for LLM tool calling.
|
||||||
|
///
|
||||||
|
/// Returns an error if the schema is invalid or missing required fields.
|
||||||
|
pub fn validate(schema: &Value) -> anyhow::Result<()> {
|
||||||
|
let obj = schema
|
||||||
|
.as_object()
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("Schema must be an object"))?;
|
||||||
|
|
||||||
|
// Must have 'type' field
|
||||||
|
if !obj.contains_key("type") {
|
||||||
|
anyhow::bail!("Schema missing required 'type' field");
|
||||||
|
}
|
||||||
|
|
||||||
|
// If type is 'object', should have 'properties'
|
||||||
|
if let Some(Value::String(t)) = obj.get("type") {
|
||||||
|
if t == "object" && !obj.contains_key("properties") {
|
||||||
|
tracing::warn!("Object schema without 'properties' field may cause issues");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
// Internal implementation
|
||||||
|
// --------------------------------------------------------------------
|
||||||
|
|
||||||
|
/// Extract $defs and definitions into a flat map for reference resolution.
|
||||||
|
fn extract_defs(obj: &Map<String, Value>) -> HashMap<String, Value> {
|
||||||
|
let mut defs = HashMap::new();
|
||||||
|
|
||||||
|
// Extract from $defs (JSON Schema 2019-09+)
|
||||||
|
if let Some(Value::Object(defs_obj)) = obj.get("$defs") {
|
||||||
|
for (key, value) in defs_obj {
|
||||||
|
defs.insert(key.clone(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract from definitions (JSON Schema draft-07)
|
||||||
|
if let Some(Value::Object(defs_obj)) = obj.get("definitions") {
|
||||||
|
for (key, value) in defs_obj {
|
||||||
|
defs.insert(key.clone(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defs
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Recursively clean a schema value.
|
||||||
|
fn clean_with_defs(
|
||||||
|
schema: Value,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
match schema {
|
||||||
|
Value::Object(obj) => Self::clean_object(obj, defs, strategy, ref_stack),
|
||||||
|
Value::Array(arr) => Value::Array(
|
||||||
|
arr.into_iter()
|
||||||
|
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
|
||||||
|
.collect(),
|
||||||
|
),
|
||||||
|
other => other,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean an object schema.
|
||||||
|
fn clean_object(
|
||||||
|
obj: Map<String, Value>,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
// Handle $ref resolution
|
||||||
|
if let Some(Value::String(ref_value)) = obj.get("$ref") {
|
||||||
|
return Self::resolve_ref(ref_value, &obj, defs, strategy, ref_stack);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle anyOf/oneOf simplification
|
||||||
|
if obj.contains_key("anyOf") || obj.contains_key("oneOf") {
|
||||||
|
if let Some(simplified) = Self::try_simplify_union(&obj, defs, strategy, ref_stack) {
|
||||||
|
return simplified;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build cleaned object
|
||||||
|
let mut cleaned = Map::new();
|
||||||
|
let unsupported: HashSet<&str> = strategy.unsupported_keywords().iter().copied().collect();
|
||||||
|
let has_union = obj.contains_key("anyOf") || obj.contains_key("oneOf");
|
||||||
|
|
||||||
|
for (key, value) in obj {
|
||||||
|
// Skip unsupported keywords
|
||||||
|
if unsupported.contains(key.as_str()) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Special handling for specific keys
|
||||||
|
match key.as_str() {
|
||||||
|
// Convert const to enum
|
||||||
|
"const" => {
|
||||||
|
cleaned.insert("enum".to_string(), json!([value]));
|
||||||
|
}
|
||||||
|
// Skip type if we have anyOf/oneOf (they define the type)
|
||||||
|
"type" if has_union => {
|
||||||
|
// Skip
|
||||||
|
}
|
||||||
|
// Handle type arrays (remove null)
|
||||||
|
"type" if matches!(value, Value::Array(_)) => {
|
||||||
|
let cleaned_value = Self::clean_type_array(value);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
// Recursively clean nested schemas
|
||||||
|
"properties" => {
|
||||||
|
let cleaned_value = Self::clean_properties(value, defs, strategy, ref_stack);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
"items" => {
|
||||||
|
let cleaned_value = Self::clean_with_defs(value, defs, strategy, ref_stack);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
"anyOf" | "oneOf" | "allOf" => {
|
||||||
|
let cleaned_value = Self::clean_union(value, defs, strategy, ref_stack);
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
// Keep all other keys, cleaning nested objects/arrays recursively.
|
||||||
|
_ => {
|
||||||
|
let cleaned_value = match value {
|
||||||
|
Value::Object(_) | Value::Array(_) => {
|
||||||
|
Self::clean_with_defs(value, defs, strategy, ref_stack)
|
||||||
|
}
|
||||||
|
other => other,
|
||||||
|
};
|
||||||
|
cleaned.insert(key, cleaned_value);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Value::Object(cleaned)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Resolve a $ref to its definition.
|
||||||
|
fn resolve_ref(
|
||||||
|
ref_value: &str,
|
||||||
|
obj: &Map<String, Value>,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
// Prevent circular references
|
||||||
|
if ref_stack.contains(ref_value) {
|
||||||
|
tracing::warn!("Circular $ref detected: {}", ref_value);
|
||||||
|
return Self::preserve_meta(obj, Value::Object(Map::new()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to resolve local ref (#/$defs/Name or #/definitions/Name)
|
||||||
|
if let Some(def_name) = Self::parse_local_ref(ref_value) {
|
||||||
|
if let Some(definition) = defs.get(def_name.as_str()) {
|
||||||
|
ref_stack.insert(ref_value.to_string());
|
||||||
|
let cleaned = Self::clean_with_defs(definition.clone(), defs, strategy, ref_stack);
|
||||||
|
ref_stack.remove(ref_value);
|
||||||
|
return Self::preserve_meta(obj, cleaned);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Can't resolve: return empty object with metadata
|
||||||
|
tracing::warn!("Cannot resolve $ref: {}", ref_value);
|
||||||
|
Self::preserve_meta(obj, Value::Object(Map::new()))
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse a local JSON Pointer ref (#/$defs/Name).
|
||||||
|
fn parse_local_ref(ref_value: &str) -> Option<String> {
|
||||||
|
ref_value
|
||||||
|
.strip_prefix("#/$defs/")
|
||||||
|
.or_else(|| ref_value.strip_prefix("#/definitions/"))
|
||||||
|
.map(Self::decode_json_pointer)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decode JSON Pointer escaping (`~0` = `~`, `~1` = `/`).
|
||||||
|
fn decode_json_pointer(segment: &str) -> String {
|
||||||
|
if !segment.contains('~') {
|
||||||
|
return segment.to_string();
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut decoded = String::with_capacity(segment.len());
|
||||||
|
let mut chars = segment.chars().peekable();
|
||||||
|
|
||||||
|
while let Some(ch) = chars.next() {
|
||||||
|
if ch == '~' {
|
||||||
|
match chars.peek().copied() {
|
||||||
|
Some('0') => {
|
||||||
|
chars.next();
|
||||||
|
decoded.push('~');
|
||||||
|
}
|
||||||
|
Some('1') => {
|
||||||
|
chars.next();
|
||||||
|
decoded.push('/');
|
||||||
|
}
|
||||||
|
_ => decoded.push('~'),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
decoded.push(ch);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
decoded
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to simplify anyOf/oneOf to a simpler form.
|
||||||
|
fn try_simplify_union(
|
||||||
|
obj: &Map<String, Value>,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Option<Value> {
|
||||||
|
let union_key = if obj.contains_key("anyOf") {
|
||||||
|
"anyOf"
|
||||||
|
} else if obj.contains_key("oneOf") {
|
||||||
|
"oneOf"
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
let variants = obj.get(union_key)?.as_array()?;
|
||||||
|
|
||||||
|
// Clean all variants first
|
||||||
|
let cleaned_variants: Vec<Value> = variants
|
||||||
|
.iter()
|
||||||
|
.map(|v| Self::clean_with_defs(v.clone(), defs, strategy, ref_stack))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// Strip null variants
|
||||||
|
let non_null: Vec<Value> = cleaned_variants
|
||||||
|
.into_iter()
|
||||||
|
.filter(|v| !Self::is_null_schema(v))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
// If only one variant remains after stripping nulls, return it
|
||||||
|
if non_null.len() == 1 {
|
||||||
|
return Some(Self::preserve_meta(obj, non_null[0].clone()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Try to flatten to enum if all variants are literals
|
||||||
|
if let Some(enum_value) = Self::try_flatten_literal_union(&non_null) {
|
||||||
|
return Some(Self::preserve_meta(obj, enum_value));
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Check if a schema represents null type.
|
||||||
|
fn is_null_schema(value: &Value) -> bool {
|
||||||
|
if let Some(obj) = value.as_object() {
|
||||||
|
// { const: null }
|
||||||
|
if let Some(Value::Null) = obj.get("const") {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
// { enum: [null] }
|
||||||
|
if let Some(Value::Array(arr)) = obj.get("enum") {
|
||||||
|
if arr.len() == 1 && matches!(arr[0], Value::Null) {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// { type: "null" }
|
||||||
|
if let Some(Value::String(t)) = obj.get("type") {
|
||||||
|
if t == "null" {
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
false
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Try to flatten anyOf/oneOf with only literal values to enum.
|
||||||
|
///
|
||||||
|
/// Example: `anyOf: [{const: "a"}, {const: "b"}]` -> `{type: "string", enum: ["a", "b"]}`
|
||||||
|
fn try_flatten_literal_union(variants: &[Value]) -> Option<Value> {
|
||||||
|
if variants.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
let mut all_values = Vec::new();
|
||||||
|
let mut common_type: Option<String> = None;
|
||||||
|
|
||||||
|
for variant in variants {
|
||||||
|
let obj = variant.as_object()?;
|
||||||
|
|
||||||
|
// Extract literal value from const or single-item enum
|
||||||
|
let literal_value = if let Some(const_val) = obj.get("const") {
|
||||||
|
const_val.clone()
|
||||||
|
} else if let Some(Value::Array(arr)) = obj.get("enum") {
|
||||||
|
if arr.len() == 1 {
|
||||||
|
arr[0].clone()
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return None;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Check type consistency
|
||||||
|
let variant_type = obj.get("type")?.as_str()?;
|
||||||
|
match &common_type {
|
||||||
|
None => common_type = Some(variant_type.to_string()),
|
||||||
|
Some(t) if t != variant_type => return None,
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
|
all_values.push(literal_value);
|
||||||
|
}
|
||||||
|
|
||||||
|
common_type.map(|t| {
|
||||||
|
json!({
|
||||||
|
"type": t,
|
||||||
|
"enum": all_values
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean type array, removing null.
|
||||||
|
fn clean_type_array(value: Value) -> Value {
|
||||||
|
if let Value::Array(types) = value {
|
||||||
|
let non_null: Vec<Value> = types
|
||||||
|
.into_iter()
|
||||||
|
.filter(|v| v.as_str() != Some("null"))
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
match non_null.len() {
|
||||||
|
0 => Value::String("null".to_string()),
|
||||||
|
1 => non_null
|
||||||
|
.into_iter()
|
||||||
|
.next()
|
||||||
|
.unwrap_or(Value::String("null".to_string())),
|
||||||
|
_ => Value::Array(non_null),
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean properties object.
|
||||||
|
fn clean_properties(
|
||||||
|
value: Value,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
if let Value::Object(props) = value {
|
||||||
|
let cleaned: Map<String, Value> = props
|
||||||
|
.into_iter()
|
||||||
|
.map(|(k, v)| (k, Self::clean_with_defs(v, defs, strategy, ref_stack)))
|
||||||
|
.collect();
|
||||||
|
Value::Object(cleaned)
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Clean union (anyOf/oneOf/allOf).
|
||||||
|
fn clean_union(
|
||||||
|
value: Value,
|
||||||
|
defs: &HashMap<String, Value>,
|
||||||
|
strategy: CleaningStrategy,
|
||||||
|
ref_stack: &mut HashSet<String>,
|
||||||
|
) -> Value {
|
||||||
|
if let Value::Array(variants) = value {
|
||||||
|
let cleaned: Vec<Value> = variants
|
||||||
|
.into_iter()
|
||||||
|
.map(|v| Self::clean_with_defs(v, defs, strategy, ref_stack))
|
||||||
|
.collect();
|
||||||
|
Value::Array(cleaned)
|
||||||
|
} else {
|
||||||
|
value
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Preserve metadata (description, title, default) from source to target.
|
||||||
|
fn preserve_meta(source: &Map<String, Value>, mut target: Value) -> Value {
|
||||||
|
if let Value::Object(target_obj) = &mut target {
|
||||||
|
for &key in SCHEMA_META_KEYS {
|
||||||
|
if let Some(value) = source.get(key) {
|
||||||
|
target_obj.insert(key.to_string(), value.clone());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
target
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_remove_unsupported_keywords() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1,
|
||||||
|
"maxLength": 100,
|
||||||
|
"pattern": "^[a-z]+$",
|
||||||
|
"description": "A lowercase string"
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["type"], "string");
|
||||||
|
assert_eq!(cleaned["description"], "A lowercase string");
|
||||||
|
assert!(cleaned.get("minLength").is_none());
|
||||||
|
assert!(cleaned.get("maxLength").is_none());
|
||||||
|
assert!(cleaned.get("pattern").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_resolve_ref() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"age": {
|
||||||
|
"$ref": "#/$defs/Age"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"Age": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["properties"]["age"]["type"], "integer");
|
||||||
|
assert!(cleaned["properties"]["age"].get("minimum").is_none()); // Stripped by Gemini strategy
|
||||||
|
assert!(cleaned.get("$defs").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_flatten_literal_union() {
|
||||||
|
let schema = json!({
|
||||||
|
"anyOf": [
|
||||||
|
{ "const": "admin", "type": "string" },
|
||||||
|
{ "const": "user", "type": "string" },
|
||||||
|
{ "const": "guest", "type": "string" }
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["type"], "string");
|
||||||
|
assert!(cleaned["enum"].is_array());
|
||||||
|
let enum_values = cleaned["enum"].as_array().unwrap();
|
||||||
|
assert_eq!(enum_values.len(), 3);
|
||||||
|
assert!(enum_values.contains(&json!("admin")));
|
||||||
|
assert!(enum_values.contains(&json!("user")));
|
||||||
|
assert!(enum_values.contains(&json!("guest")));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_strip_null_from_union() {
|
||||||
|
let schema = json!({
|
||||||
|
"oneOf": [
|
||||||
|
{ "type": "string" },
|
||||||
|
{ "type": "null" }
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
// Should simplify to just { type: "string" }
|
||||||
|
assert_eq!(cleaned["type"], "string");
|
||||||
|
assert!(cleaned.get("oneOf").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_const_to_enum() {
|
||||||
|
let schema = json!({
|
||||||
|
"const": "fixed_value",
|
||||||
|
"description": "A constant"
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["enum"], json!(["fixed_value"]));
|
||||||
|
assert_eq!(cleaned["description"], "A constant");
|
||||||
|
assert!(cleaned.get("const").is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_preserve_metadata() {
|
||||||
|
let schema = json!({
|
||||||
|
"$ref": "#/$defs/Name",
|
||||||
|
"description": "User's name",
|
||||||
|
"title": "Name Field",
|
||||||
|
"default": "Anonymous",
|
||||||
|
"$defs": {
|
||||||
|
"Name": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["type"], "string");
|
||||||
|
assert_eq!(cleaned["description"], "User's name");
|
||||||
|
assert_eq!(cleaned["title"], "Name Field");
|
||||||
|
assert_eq!(cleaned["default"], "Anonymous");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_circular_ref_prevention() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"parent": {
|
||||||
|
"$ref": "#/$defs/Node"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"Node": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"child": {
|
||||||
|
"$ref": "#/$defs/Node"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
// Should not panic on circular reference
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["properties"]["parent"]["type"], "object");
|
||||||
|
// Circular reference should be broken
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_validate_schema() {
|
||||||
|
let valid = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": { "type": "string" }
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(SchemaCleanr::validate(&valid).is_ok());
|
||||||
|
|
||||||
|
let invalid = json!({
|
||||||
|
"properties": {
|
||||||
|
"name": { "type": "string" }
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
assert!(SchemaCleanr::validate(&invalid).is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_strategy_differences() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1,
|
||||||
|
"description": "A string field"
|
||||||
|
});
|
||||||
|
|
||||||
|
// Gemini: Most restrictive (removes minLength)
|
||||||
|
let gemini = SchemaCleanr::clean_for_gemini(schema.clone());
|
||||||
|
assert!(gemini.get("minLength").is_none());
|
||||||
|
assert_eq!(gemini["type"], "string");
|
||||||
|
assert_eq!(gemini["description"], "A string field");
|
||||||
|
|
||||||
|
// OpenAI: Most permissive (keeps minLength)
|
||||||
|
let openai = SchemaCleanr::clean_for_openai(schema.clone());
|
||||||
|
assert_eq!(openai["minLength"], 1); // OpenAI allows validation keywords
|
||||||
|
assert_eq!(openai["type"], "string");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_nested_properties() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"user": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"name": {
|
||||||
|
"type": "string",
|
||||||
|
"minLength": 1
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"additionalProperties": false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert!(cleaned["properties"]["user"]["properties"]["name"]
|
||||||
|
.get("minLength")
|
||||||
|
.is_none());
|
||||||
|
assert!(cleaned["properties"]["user"]
|
||||||
|
.get("additionalProperties")
|
||||||
|
.is_none());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_type_array_null_removal() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": ["string", "null"]
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
// Should simplify to just "string"
|
||||||
|
assert_eq!(cleaned["type"], "string");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_type_array_only_null_preserved() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": ["null"]
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["type"], "null");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_ref_with_json_pointer_escape() {
|
||||||
|
let schema = json!({
|
||||||
|
"$ref": "#/$defs/Foo~1Bar",
|
||||||
|
"$defs": {
|
||||||
|
"Foo/Bar": {
|
||||||
|
"type": "string"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["type"], "string");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_skip_type_when_non_simplifiable_union_exists() {
|
||||||
|
let schema = json!({
|
||||||
|
"type": "object",
|
||||||
|
"oneOf": [
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"a": { "type": "string" }
|
||||||
|
}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "object",
|
||||||
|
"properties": {
|
||||||
|
"b": { "type": "number" }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
]
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert!(cleaned.get("type").is_none());
|
||||||
|
assert!(cleaned.get("oneOf").is_some());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_clean_nested_unknown_schema_keyword() {
|
||||||
|
let schema = json!({
|
||||||
|
"not": {
|
||||||
|
"$ref": "#/$defs/Age"
|
||||||
|
},
|
||||||
|
"$defs": {
|
||||||
|
"Age": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let cleaned = SchemaCleanr::clean_for_gemini(schema);
|
||||||
|
|
||||||
|
assert_eq!(cleaned["not"]["type"], "integer");
|
||||||
|
assert!(cleaned["not"].get("minimum").is_none());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -36,6 +36,7 @@ async fn compare_store_speed() {
|
||||||
&format!("key_{i}"),
|
&format!("key_{i}"),
|
||||||
&format!("Memory entry number {i} about Rust programming"),
|
&format!("Memory entry number {i} about Rust programming"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -49,6 +50,7 @@ async fn compare_store_speed() {
|
||||||
&format!("key_{i}"),
|
&format!("key_{i}"),
|
||||||
&format!("Memory entry number {i} about Rust programming"),
|
&format!("Memory entry number {i} about Rust programming"),
|
||||||
MemoryCategory::Core,
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
@ -127,8 +129,8 @@ async fn compare_recall_quality() {
|
||||||
];
|
];
|
||||||
|
|
||||||
for (key, content, cat) in &entries {
|
for (key, content, cat) in &entries {
|
||||||
sq.store(key, content, cat.clone()).await.unwrap();
|
sq.store(key, content, cat.clone(), None).await.unwrap();
|
||||||
md.store(key, content, cat.clone()).await.unwrap();
|
md.store(key, content, cat.clone(), None).await.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test queries and compare results
|
// Test queries and compare results
|
||||||
|
|
@ -145,8 +147,8 @@ async fn compare_recall_quality() {
|
||||||
println!("RECALL QUALITY (10 entries seeded):\n");
|
println!("RECALL QUALITY (10 entries seeded):\n");
|
||||||
|
|
||||||
for (query, desc) in &queries {
|
for (query, desc) in &queries {
|
||||||
let sq_results = sq.recall(query, 10).await.unwrap();
|
let sq_results = sq.recall(query, 10, None).await.unwrap();
|
||||||
let md_results = md.recall(query, 10).await.unwrap();
|
let md_results = md.recall(query, 10, None).await.unwrap();
|
||||||
|
|
||||||
println!(" Query: \"{query}\" — {desc}");
|
println!(" Query: \"{query}\" — {desc}");
|
||||||
println!(" SQLite: {} results", sq_results.len());
|
println!(" SQLite: {} results", sq_results.len());
|
||||||
|
|
@ -190,21 +192,21 @@ async fn compare_recall_speed() {
|
||||||
} else {
|
} else {
|
||||||
format!("TypeScript powers modern web apps, entry {i}")
|
format!("TypeScript powers modern web apps, entry {i}")
|
||||||
};
|
};
|
||||||
sq.store(&format!("e{i}"), &content, MemoryCategory::Core)
|
sq.store(&format!("e{i}"), &content, MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store(&format!("e{i}"), &content, MemoryCategory::Daily)
|
md.store(&format!("e{i}"), &content, MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Benchmark recall
|
// Benchmark recall
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let sq_results = sq.recall("Rust systems", 10).await.unwrap();
|
let sq_results = sq.recall("Rust systems", 10, None).await.unwrap();
|
||||||
let sq_dur = start.elapsed();
|
let sq_dur = start.elapsed();
|
||||||
|
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
let md_results = md.recall("Rust systems", 10).await.unwrap();
|
let md_results = md.recall("Rust systems", 10, None).await.unwrap();
|
||||||
let md_dur = start.elapsed();
|
let md_dur = start.elapsed();
|
||||||
|
|
||||||
println!("\n============================================================");
|
println!("\n============================================================");
|
||||||
|
|
@ -227,15 +229,25 @@ async fn compare_persistence() {
|
||||||
// Store in both, then drop and re-open
|
// Store in both, then drop and re-open
|
||||||
{
|
{
|
||||||
let sq = sqlite_backend(tmp_sq.path());
|
let sq = sqlite_backend(tmp_sq.path());
|
||||||
sq.store("persist_test", "I should survive", MemoryCategory::Core)
|
sq.store(
|
||||||
.await
|
"persist_test",
|
||||||
.unwrap();
|
"I should survive",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
{
|
{
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
md.store("persist_test", "I should survive", MemoryCategory::Core)
|
md.store(
|
||||||
.await
|
"persist_test",
|
||||||
.unwrap();
|
"I should survive",
|
||||||
|
MemoryCategory::Core,
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
// Re-open
|
// Re-open
|
||||||
|
|
@ -282,17 +294,17 @@ async fn compare_upsert() {
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
|
|
||||||
// Store twice with same key, different content
|
// Store twice with same key, different content
|
||||||
sq.store("pref", "likes Rust", MemoryCategory::Core)
|
sq.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("pref", "loves Rust", MemoryCategory::Core)
|
sq.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
md.store("pref", "likes Rust", MemoryCategory::Core)
|
md.store("pref", "likes Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("pref", "loves Rust", MemoryCategory::Core)
|
md.store("pref", "loves Rust", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -300,7 +312,7 @@ async fn compare_upsert() {
|
||||||
let md_count = md.count().await.unwrap();
|
let md_count = md.count().await.unwrap();
|
||||||
|
|
||||||
let sq_entry = sq.get("pref").await.unwrap();
|
let sq_entry = sq.get("pref").await.unwrap();
|
||||||
let md_results = md.recall("loves Rust", 5).await.unwrap();
|
let md_results = md.recall("loves Rust", 5, None).await.unwrap();
|
||||||
|
|
||||||
println!("\n============================================================");
|
println!("\n============================================================");
|
||||||
println!("UPSERT (store same key twice):");
|
println!("UPSERT (store same key twice):");
|
||||||
|
|
@ -328,10 +340,10 @@ async fn compare_forget() {
|
||||||
let sq = sqlite_backend(tmp_sq.path());
|
let sq = sqlite_backend(tmp_sq.path());
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
|
|
||||||
sq.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
sq.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("secret", "API key: sk-1234", MemoryCategory::Core)
|
md.store("secret", "API key: sk-1234", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
|
@ -372,37 +384,40 @@ async fn compare_category_filter() {
|
||||||
let md = markdown_backend(tmp_md.path());
|
let md = markdown_backend(tmp_md.path());
|
||||||
|
|
||||||
// Mix of categories
|
// Mix of categories
|
||||||
sq.store("a", "core fact 1", MemoryCategory::Core)
|
sq.store("a", "core fact 1", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("b", "core fact 2", MemoryCategory::Core)
|
sq.store("b", "core fact 2", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("c", "daily note", MemoryCategory::Daily)
|
sq.store("c", "daily note", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
sq.store("d", "convo msg", MemoryCategory::Conversation)
|
sq.store("d", "convo msg", MemoryCategory::Conversation, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
md.store("a", "core fact 1", MemoryCategory::Core)
|
md.store("a", "core fact 1", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("b", "core fact 2", MemoryCategory::Core)
|
md.store("b", "core fact 2", MemoryCategory::Core, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
md.store("c", "daily note", MemoryCategory::Daily)
|
md.store("c", "daily note", MemoryCategory::Daily, None)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let sq_core = sq.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let sq_core = sq.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
let sq_daily = sq.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let sq_daily = sq.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
let sq_conv = sq.list(Some(&MemoryCategory::Conversation)).await.unwrap();
|
let sq_conv = sq
|
||||||
let sq_all = sq.list(None).await.unwrap();
|
.list(Some(&MemoryCategory::Conversation), None)
|
||||||
|
.await
|
||||||
|
.unwrap();
|
||||||
|
let sq_all = sq.list(None, None).await.unwrap();
|
||||||
|
|
||||||
let md_core = md.list(Some(&MemoryCategory::Core)).await.unwrap();
|
let md_core = md.list(Some(&MemoryCategory::Core), None).await.unwrap();
|
||||||
let md_daily = md.list(Some(&MemoryCategory::Daily)).await.unwrap();
|
let md_daily = md.list(Some(&MemoryCategory::Daily), None).await.unwrap();
|
||||||
let md_all = md.list(None).await.unwrap();
|
let md_all = md.list(None, None).await.unwrap();
|
||||||
|
|
||||||
println!("\n============================================================");
|
println!("\n============================================================");
|
||||||
println!("CATEGORY FILTERING:");
|
println!("CATEGORY FILTERING:");
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue